In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import roc_auc_score

# ---------------------------
# Config
# ---------------------------
# Root of the local project tree. Set the MINERAL_DATA_DIR environment
# variable to run on another machine without editing this cell; the original
# author's path is kept as the fallback.
DATA_DIR   = (os.environ.get("MINERAL_DATA_DIR")
              or "/Users/ramiab/Desktop/Mineral-Predictions-Local")
CSV_PATH   = os.path.join(DATA_DIR, "Training", "data", "preprocessed", "rock_features.csv")
IMG_DIR    = os.path.join(DATA_DIR, "Images", "Rock-Mag-Images")
BATCH_SIZE = 64
NUM_EPOCHS = 10
LR         = 5e-4  # Adam learning rate; the author reports this as the tuned value
VAL_SPLIT  = 0.2   # fraction of samples held out for validation
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
# Column i of the model output / targets is reported under LABEL_NAMES[i].
# NOTE: assumes the CSV's *_target columns appear in this same order.
LABEL_NAMES = ['AU','AG','CU','CO','NI']

# Make sure the checkpoint directory exists before training saves into it.
os.makedirs(os.path.join(DATA_DIR, "Training", "models"), exist_ok=True)

# ---------------------------
# Model Architecture
# ---------------------------
class MineralMultiModalNet(nn.Module):
    """Two-branch classifier: a small CNN over images plus an MLP over tabular features.

    The flattened CNN output and the 128-d tabular embedding are concatenated
    and fed to a fully connected head that emits one raw logit per class (no
    sigmoid is applied here).

    Args:
        num_classes: Number of output logits.
        num_geological_features: Width of the tabular feature vector. Required.
        image_size: Spatial size of the input images, either an int (square)
            or a ``(height, width)`` pair. The default of 170 reproduces the
            original hard-coded layer sizes, so existing state dicts still load.

    Raises:
        TypeError: If ``num_geological_features`` is not provided.
        ValueError: If ``image_size`` is too small to survive the three 2x poolings.
    """

    def __init__(self, num_classes=5, num_geological_features=None, image_size=170):
        super().__init__()

        if num_geological_features is None:
            raise TypeError(
                "num_geological_features is required (width of the geological feature vector)"
            )

        # 1. CNN branch: three [conv3x3(pad 1) -> BN -> ReLU -> MaxPool(2)] blocks,
        #    channels 3 -> 32 -> 64 -> 128.
        self.cnn_branch = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # 2. Geological (tabular) features branch -> 128-d embedding.
        self.geo_branch = nn.Sequential(
            nn.Linear(num_geological_features, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # The padded 3x3 convs keep H and W; each MaxPool2d(2) floor-halves them.
        # Three floor-halvings equal one floor division by 8 (e.g. 170 -> 21).
        if isinstance(image_size, int):
            height, width = image_size, image_size
        else:
            height, width = image_size
        out_h, out_w = height // 8, width // 8
        if out_h < 1 or out_w < 1:
            raise ValueError(f"image_size {image_size!r} is too small (need at least 8x8)")
        cnn_output_size = 128 * out_h * out_w

        # 3. Combination head over the concatenated embeddings.
        self.combiner = nn.Sequential(
            nn.Linear(cnn_output_size + 128, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, img, geo_features):
        """Return logits of shape ``(batch, num_classes)``.

        ``img`` is expected as ``(batch, 3, H, W)`` with H, W matching
        ``image_size``. ``geo_features`` is expected as
        ``(batch, num_geological_features)``.
        """
        x_img = torch.flatten(self.cnn_branch(img), 1)
        x_geo = self.geo_branch(geo_features)
        combined = torch.cat([x_img, x_geo], dim=1)
        return self.combiner(combined)

# ---------------------------
# Dataset
# ---------------------------
class MineralDataset(Dataset):
    """Pairs each CSV row with its magnetic image, tabular features and labels.

    Item ``idx`` is ``(image, geo_features, targets)``:

    * ``image``: ``<img_dir>/<UNIQUE_ID>.jpg`` converted to RGB, passed
      through ``transform`` when one is given.
    * ``geo_features``: float32 tensor of the ``feature_cols`` values.
    * ``targets``: float32 tensor of every column ending in ``_target``,
      in CSV column order.

    The numeric columns are pulled into NumPy arrays once in ``__init__``.
    This means ``__getitem__`` no longer builds a pandas row Series per sample.
    That was slow, and such a row is upcast to one common dtype (float64 for
    mixed int/float columns), which can corrupt very large integer IDs.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.df = pd.read_csv(csv_file)

        # Tabular model inputs (targets and the ID column are excluded).
        self.feature_cols = ['CODE_LITH_encoded', 'STRAT_encoded',
                             'dist_fault', 'dist_cont']
        self.label_cols = [col for col in self.df.columns if col.endswith('_target')]

        self.img_dir = img_dir
        self.transform = transform

        # Per-sample arrays, indexed positionally exactly like df.iloc[idx].
        self._ids = self.df['UNIQUE_ID'].to_numpy()  # native dtype: no float64 upcast
        self._features = self.df[self.feature_cols].to_numpy(dtype=np.float32)
        self._targets = self.df[self.label_cols].to_numpy(dtype=np.float32)

        print(f"Dataset initialized with {len(self.df)} samples")
        print(f"Features: {self.feature_cols}")
        print(f"Labels: {self.label_cols}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, f"{int(self._ids[idx])}.jpg")
        # The context manager closes the file even if decoding fails;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # torch.tensor copies, so callers cannot mutate the cached arrays.
        geo_features = torch.tensor(self._features[idx], dtype=torch.float32)
        targets = torch.tensor(self._targets[idx], dtype=torch.float32)

        return img, geo_features, targets

# ---------------------------
# Training Functions
# ---------------------------
def compute_class_weights(df, target_cols, *, tuned_weights=None):
    """Return the ``BCEWithLogitsLoss`` ``pos_weight`` tensor for ``target_cols``.

    Known targets use hand-tuned weights from earlier hyperparameter tuning.
    Entries in ``tuned_weights`` override or extend those weights. Any other
    column falls back to the negative/positive ratio ``(1 - p) / p``, where
    ``p`` is the column's mean (positive rate). If a column has no positives
    or no negatives, that ratio is undefined or zero, so it gets 1.0.

    Previously an unknown column either crashed with ``UnboundLocalError``
    or silently reused the preceding column's weight.

    Args:
        df: DataFrame holding the binary (0/1) target columns.
        target_cols: Target column names, in model-output order.
        tuned_weights: Optional ``{column: weight}`` merged over the defaults.

    Returns:
        1-D float32 tensor with one weight per entry of ``target_cols``.
    """
    # Optimized weights ("sweet spots") from hyperparameter tuning.
    table = {
        'AU_target': 12.0,
        'AG_target': 12.0,
        'CU_target': 7.0,
        'CO_target': 7.0,
        'NI_target': 5.0,
    }
    if tuned_weights:
        table.update(tuned_weights)

    weights = []
    for col in target_cols:
        weight = table.get(col)
        if weight is None:
            pos_rate = float(df[col].mean())
            weight = (1.0 - pos_rate) / pos_rate if 0.0 < pos_rate < 1.0 else 1.0
        weight = float(weight)
        weights.append(weight)
        print(f"{col} - Positive samples: {df[col].sum()}, Weight: {weight:.2f}")
    return torch.tensor(weights, dtype=torch.float32)

def train_one_epoch(model, train_loader, criterion, optimizer):
    """Run one optimization pass over ``train_loader``.

    Puts ``model`` in training mode and moves each
    ``(images, geo_features, targets)`` batch to ``DEVICE``. It takes one
    optimizer step per batch and shows the running mean loss on the progress bar.

    Returns:
        Mean of the per-batch losses, or NaN if the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    pbar = tqdm(train_loader, desc="Training", ncols=100)

    for num_batches, (images, geo_features, targets) in enumerate(pbar, start=1):
        images = images.to(DEVICE)
        geo_features = geo_features.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images, geo_features)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Use our own counter: tqdm only refreshes ``pbar.n`` when it redraws
        # (at most every ``mininterval`` seconds), so with fast batches it
        # lags behind and the displayed average was inflated.
        pbar.set_postfix({'loss': f"{running_loss / num_batches:.4f}"})

    return running_loss / num_batches if num_batches else float('nan')

def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Output column i is reported under ``LABEL_NAMES[i]``.

    Returns:
        dict with
        ``'val_loss'``: mean per-batch loss (NaN if the loader is empty);
        ``'aucs'``: ``{label: ROC AUC}`` computed from sigmoid probabilities,
        NaN when a label has a single class in this split;
        ``'pred_counts'``: ``{label: samples with probability >= 0.5}``;
        ``'true_counts'``: ``{label: positive targets}``.
    """
    model.eval()
    val_loss = 0.0
    num_batches = 0
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for images, geo_features, targets in tqdm(val_loader, desc="Validating", ncols=100):
            images = images.to(DEVICE)
            geo_features = geo_features.to(DEVICE)
            targets = targets.to(DEVICE)

            outputs = model(images, geo_features)
            val_loss += criterion(outputs, targets).item()
            num_batches += 1

            all_probs.append(torch.sigmoid(outputs).cpu())
            all_targets.append(targets.cpu())

    if not num_batches:
        return {
            'val_loss': float('nan'),
            'aucs': {label: float('nan') for label in LABEL_NAMES},
            'pred_counts': {label: 0.0 for label in LABEL_NAMES},
            'true_counts': {label: 0.0 for label in LABEL_NAMES},
        }

    probs = torch.cat(all_probs, dim=0)
    targets = torch.cat(all_targets, dim=0)
    preds = (probs >= 0.5).float()

    aucs = {}
    pred_counts = {}
    true_counts = {}
    for i, label in enumerate(LABEL_NAMES):
        pred_counts[label] = preds[:, i].sum().item()
        true_counts[label] = targets[:, i].sum().item()
        if len(torch.unique(targets[:, i])) > 1:
            # ROC AUC ranks continuous scores; feeding thresholded 0/1
            # predictions (as before) collapses the curve to one point.
            # float() keeps the metrics plain Python (no numpy scalars in checkpoints).
            aucs[label] = float(roc_auc_score(targets[:, i].numpy(), probs[:, i].numpy()))
        else:
            aucs[label] = float('nan')

    return {
        'val_loss': val_loss / num_batches,
        'aucs': aucs,
        'pred_counts': pred_counts,
        'true_counts': true_counts,
    }

# ---------------------------
# Main Training Loop
# ---------------------------
def main(seed=42):
    """Train MineralMultiModalNet on the rock dataset and checkpoint every epoch.

    Args:
        seed: Seed for the train/validation split, so the same samples are
            held out on every run and can be recovered when a checkpoint is
            evaluated later. Pass ``None`` for an unseeded split.

    Raises:
        ValueError: If the CSV's ``*_target`` columns do not match
            ``LABEL_NAMES`` in name and order, which would mislabel every metric.
    """
    # Resize to the 170x170 input the model's default flatten size expects.
    transform = transforms.Compose([
        transforms.Resize((170, 170)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    dataset = MineralDataset(CSV_PATH, IMG_DIR, transform=transform)

    expected_labels = [f"{name}_target" for name in LABEL_NAMES]
    if dataset.label_cols != expected_labels:
        raise ValueError(
            f"CSV target columns {dataset.label_cols} do not match LABEL_NAMES order {expected_labels}"
        )

    num_geo_features = len(dataset.feature_cols)

    # Reproducible train/val split.
    val_size = int(len(dataset) * VAL_SPLIT)
    train_size = len(dataset) - val_size
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = MineralMultiModalNet(num_classes=len(LABEL_NAMES),
                                 num_geological_features=num_geo_features).to(DEVICE)

    # Same columns, same order as the targets the dataset yields.
    class_weights = compute_class_weights(dataset.df, dataset.label_cols).to(DEVICE)

    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    print("\nStarting training with:")
    print(f"Device: {DEVICE}")
    print(f"Train samples: {train_size}")
    print(f"Validation samples: {val_size}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Learning rate: {LR}")

    # One stamp per run. Stamping each save with the current date split a
    # run that crosses midnight across two dates, and let two runs on the
    # same day overwrite each other's checkpoints.
    run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    models_dir = os.path.join(DATA_DIR, "Training", "models")

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch [{epoch+1}/{NUM_EPOCHS}]")

        train_loss = train_one_epoch(model, train_loader, criterion, optimizer)
        val_metrics = validate(model, val_loader, criterion)

        print(f"\nTraining Loss: {train_loss:.4f}")
        print(f"Validation Loss: {val_metrics['val_loss']:.4f}")
        print("\nPrediction Summary:")
        for label in LABEL_NAMES:
            print(f"{label}:")
            print(f"  Predicted positive: {val_metrics['pred_counts'][label]}")
            print(f"  Actually positive: {val_metrics['true_counts'][label]}")
            print(f"  AUC Score: {val_metrics['aucs'][label]:.4f}")

        model_path = os.path.join(models_dir, f"cnn_model_{run_stamp}_epoch{epoch+1}.pt")

        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_metrics': val_metrics,
            # Pickled module: loading needs torch.load(..., weights_only=False)
            # and this class importable; prefer 'model_state_dict'.
            'model': model
        }, model_path)

        print(f"\nSaved model to: {model_path}")

# Start training when this file (or notebook cell) is executed directly.
if __name__ == "__main__":
   main()

Dataset initialized with 434850 samples
Features: ['CODE_LITH_encoded', 'STRAT_encoded', 'dist_fault', 'dist_cont']
Labels: ['AU_target', 'AG_target', 'CU_target', 'CO_target', 'NI_target']
AU_target - Positive samples: 11700, Weight: 12.00
AG_target - Positive samples: 15123, Weight: 12.00
CU_target - Positive samples: 10662, Weight: 7.00
CO_target - Positive samples: 8549, Weight: 7.00
NI_target - Positive samples: 17449, Weight: 5.00

Starting training with:
Device: cpu
Train samples: 347880
Validation samples: 86970
Batch size: 64
Learning rate: 0.0005

Epoch [1/10]


Training: 100%|██████████████████████████████████| 5436/5436 [2:51:08<00:00,  1.89s/it, loss=0.4294]
Validating: 100%|███████████████████████████████████████████████| 1359/1359 [19:17<00:00,  1.17it/s]



Training Loss: 0.4294
Validation Loss: 0.3686

Prediction Summary:
AU:
  Predicted positive: 6299.0
  Actually positive: 2324.0
  AUC Score: 0.8118
AG:
  Predicted positive: 6397.0
  Actually positive: 3041.0
  AUC Score: 0.7052
CU:
  Predicted positive: 2889.0
  Actually positive: 2163.0
  AUC Score: 0.6871
CO:
  Predicted positive: 3212.0
  Actually positive: 1663.0
  AUC Score: 0.7736
NI:
  Predicted positive: 4841.0
  Actually positive: 3426.0
  AUC Score: 0.8039

Saved model to: /Users/ramiab/Desktop/Mineral-Predictions-Local/Training/models/cnn_model_20250125_epoch1.pt

Epoch [2/10]


Training:   0%|                                     | 8/5436 [00:22<4:10:51,  2.77s/it, loss=0.3118]