# AI-Driven Lunar Soil Composition Analysis - Data Acquisition & Training

## Overview
This notebook handles:
1. **Data Acquisition**: Downloading Chang'e 3 PCAM and TCAM data.
2. **Data Preprocessing**: Creating a custom PyTorch Dataset.
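3. **Model Training**: Fine-tuning a pretrained ResNet-18 model for terrain classification (Regolith, Crater, Boulder). Note: the dataset currently assigns placeholder random labels, so training results are not meaningful until real labels are supplied.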

## Setup
Ensure you are running this in an environment with internet access (e.g., Google Colab).

In [None]:
import os
import requests
from bs4 import BeautifulSoup
import pandas as pd
from urllib.parse import urljoin
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim

# Local directories that hold the downloaded PCAM and TCAM camera images.
DATA_DIR = '../data'
PCAM_DIR = os.path.join(DATA_DIR, 'pcam')
TCAM_DIR = os.path.join(DATA_DIR, 'tcam')

# Create both directories up front; existing directories are left untouched.
for _camera_dir in (PCAM_DIR, TCAM_DIR):
    os.makedirs(_camera_dir, exist_ok=True)

print(f"Data directories created: {PCAM_DIR}, {TCAM_DIR}")

## 1. Data Downloading Functions

In [None]:
def get_image_links(url, timeout=30, extensions=('.png', '.jpg')):
    """Return the absolute URLs of image files linked from an HTML index page.

    Every ``<a href=...>`` on the page whose target ends with one of
    ``extensions`` (case-insensitive) is resolved against ``url``. Duplicate
    links are dropped while keeping first-seen order. Index pages often link
    the same image more than once, and duplicates would otherwise use up the
    download limit in ``download_images``.

    Args:
        url: URL of the index page (e.g. a Planetary Society data listing).
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.
        extensions: Filename suffixes (lower-case) that count as images.

    Returns:
        A list of absolute image URLs. The list is empty if the page could
        not be fetched; the error is printed.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"Error fetching links from {url}: {e}")
        return []

    suffixes = tuple(extensions)  # str.endswith requires a tuple, not a list
    soup = BeautifulSoup(response.content, 'html.parser')
    candidates = (
        urljoin(url, a['href'])
        for a in soup.find_all('a', href=True)
        if a['href'].lower().endswith(suffixes)
    )
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(candidates))

def download_images(image_urls, save_dir, limit=20, timeout=30):
    """Download up to ``limit`` images from ``image_urls`` into ``save_dir``.

    Files already present in ``save_dir`` are not fetched again, but they
    still count toward ``limit``. Each download is first written to a
    temporary ``<name>.part`` file and renamed into place only after it
    completes. A failed or interrupted request therefore never leaves a
    truncated image that a later run would skip as "already exists".

    Args:
        image_urls: Iterable of absolute image URLs.
        save_dir: Existing directory to write the images into.
        limit: Maximum number of images (downloaded or already present) to
            account for before stopping.
        timeout: Per-request timeout in seconds passed to ``requests.get``.
    """
    count = 0
    for url in image_urls:
        if count >= limit:
            break
        filename = os.path.basename(url)
        save_path = os.path.join(save_dir, filename)
        if os.path.exists(save_path):
            count += 1
            continue

        tmp_path = save_path + '.part'
        try:
            response = requests.get(url, timeout=timeout)
            # Without this check an HTTP error page (404/500 HTML) would be
            # written to disk under an image filename.
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                f.write(response.content)
            os.replace(tmp_path, save_path)  # atomic rename on the same filesystem
        except (requests.RequestException, OSError) as e:
            print(f"Failed to download {url}: {e}")
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass  # nothing was written before the failure
            continue

        print(f"Downloaded {filename}")
        count += 1
    print(f"Finished downloading {count} images to {save_dir}")

In [None]:
# Index pages (referenced by the Planetary Society article) that link to the
# Chang'e 3 PCAM and TCAM images.
PCAM_INDEX_URL = 'http://planetary.s3.amazonaws.com/data/change3/pcam.html'
TCAM_INDEX_URL = 'http://planetary.s3.amazonaws.com/data/change3/tcam.html'

# Scrape each index page in turn, reporting how many image links were found.
_fetched_links = {}
for _camera, _index_url in (("PCAM", PCAM_INDEX_URL), ("TCAM", TCAM_INDEX_URL)):
    print(f"Fetching {_camera} links...")
    _fetched_links[_camera] = get_image_links(_index_url)
    print(f"Found {len(_fetched_links[_camera])} {_camera} images.")

pcam_links = _fetched_links["PCAM"]
tcam_links = _fetched_links["TCAM"]

In [None]:
# Download a small sample: at most 50 images per camera. A camera whose
# index page yielded no links is skipped.
_SAMPLE_LIMIT = 50
for _camera, _links, _target_dir in (
    ("PCAM", pcam_links, PCAM_DIR),
    ("TCAM", tcam_links, TCAM_DIR),
):
    if not _links:
        continue
    print(f"Downloading sample {_camera} images...")
    download_images(_links, _target_dir, limit=_SAMPLE_LIMIT)

## 2. Data Loading & Preprocessing
We define a custom Dataset to load the images we just downloaded.

In [None]:
class LunarDataset(Dataset):
    """Image dataset over every ``.png``/``.jpg`` file in a list of directories.

    Images are loaded lazily in ``__getitem__`` and converted to RGB.

    Labels are PLACEHOLDERS. Each image is assigned a pseudo-random class in
    ``[0, num_classes)`` once, at construction time, from a seeded generator.
    A given image therefore keeps the same label across epochs and across the
    train/val split. In a real scenario, replace ``self.labels`` with labels
    looked up (e.g. from a CSV) by filename.
    Classes: 0=Regolith, 1=Crater, 2=Boulder.
    """

    def __init__(self, root_dirs, transform=None, num_classes=3, seed=0):
        """
        Args:
            root_dirs: Directories to scan (non-recursively). Paths that are
                not existing directories are skipped.
            transform: Optional callable applied to each PIL image.
            num_classes: Number of placeholder classes to draw labels from.
            seed: Seed for the placeholder-label generator.
        """
        self.transform = transform
        self.image_paths = []
        for d in root_dirs:
            if not os.path.isdir(d):
                continue
            # Sort so the sample order (and thus the split) is reproducible;
            # os.listdir returns entries in arbitrary order.
            for f in sorted(os.listdir(d)):
                if f.lower().endswith(('.png', '.jpg')):
                    self.image_paths.append(os.path.join(d, f))

        # DUMMY LABEL GENERATION: drawn once so labels are stable per image.
        generator = torch.Generator().manual_seed(seed)
        self.labels = torch.randint(
            0, num_classes, (len(self.image_paths),), generator=generator
        ).tolist()

        print(f"Total images loaded: {len(self.image_paths)}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``: the transformed RGB image and an int label."""
        img_path = self.image_paths[idx]
        # Convert to RGB as some PNGs might be RGBA or grayscale. The context
        # manager closes the underlying file handle; convert() returns a new,
        # fully loaded image, so it stays usable after the close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]

# Transforms for ResNet: resize to 224x224, convert to a tensor, and normalize
# with the standard ImageNet mean/std that torchvision's pretrained models
# expect.
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = LunarDataset([PCAM_DIR, TCAM_DIR], transform=data_transform)

# 80/20 train/validation split. A seeded generator makes the split
# reproducible between runs; without it random_split draws from the global RNG.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

## 3. Model Definition
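We use a pretrained ResNet-18 for transfer learning, replacing its final fully connected layer with a dropout + linear head for the 3 terrain classes.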

In [None]:
def get_model(num_classes=3, device='cpu', pretrained=True, dropout=0.5):
    """Build a ResNet-18 classifier with a new ``num_classes``-way output head.

    Args:
        num_classes: Number of output classes for the new final layer.
        device: Device (``torch.device`` or string) the model is moved to.
        pretrained: If True (default), start from torchvision's default
            pretrained ResNet-18 weights, which may require a download. If
            False, use random initialization, which needs no network access.
        dropout: Dropout probability applied before the final linear layer.

    Returns:
        The model, already moved to ``device``.
    """
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    model = models.resnet18(weights=weights)
    # Replace the original classifier, keeping its input feature width.
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(num_ftrs, num_classes)
    )
    model.to(device)
    return model

# Prefer a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

model = get_model(device=device)

## 4. Training Loop

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``; return the mean batch loss (0.0 if empty)."""
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    return running_loss / num_batches if num_batches else 0.0


def _evaluate(model, loader, device):
    """Return top-1 accuracy (percent) of ``model`` on ``loader`` (0.0 if empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def train_model(model, train_loader, val_loader, num_epochs=5, lr=0.001, device=None):
    """Train ``model`` with Adam + cross-entropy, validating after every epoch.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        train_loader: Iterable of ``(images, labels)`` batches for training.
        val_loader: Iterable of ``(images, labels)`` batches for validation.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device to move batches to. Defaults to the device of the
            model's first parameter (CPU if it has none), so the function
            does not depend on a notebook-global ``device`` variable.

    Returns:
        ``{'loss': [...], 'val_acc': [...]}``: the mean training loss and the
        validation accuracy (percent) for each epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = {'loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

        acc = _evaluate(model, val_loader, device)
        print(f"   Val Accuracy: {acc:.2f}%")

        history['loss'].append(avg_loss)
        history['val_acc'].append(acc)
    return history

# Start training only if the download step produced at least one image.
if len(dataset) == 0:
    print("No data found. Please check download step.")
else:
    train_model(model, train_loader, val_loader, num_epochs=5)

## 5. Save & Download Model
When running in Colab, the model file is saved on the remote runtime, so the cell below explicitly triggers a download to your local machine. When running locally, the file simply stays at the saved path.

In [None]:
def _offer_colab_download(path):
    """Ask google.colab to send ``path`` to the browser; report what happened.

    Outside Colab the import fails, and the absolute local path is printed
    instead. Any other failure is printed rather than raised.
    """
    print("Attempting to download file to your local machine...")
    try:
        from google.colab import files
        files.download(path)
        print("Download should start in your browser.")
    except ImportError:
        print("Could not import google.colab. If you are running locally, the file is already at:", os.path.abspath(path))
    except Exception as e:
        print(f"Download Error: {e}")


os.makedirs('../models', exist_ok=True)
save_path = '../models/lunar_terrain_model.pth'

# Only persist a model when there was data to train it on.
if len(dataset) > 0:
    torch.save(model.state_dict(), save_path)
    print(f"Model saved remotely to {save_path}")
    _offer_colab_download(save_path)