In [1]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, Subset
from torchvision import transforms
from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import roc_auc_score, confusion_matrix

# ---------------------------
# Config
# ---------------------------
DATA_DIR   = "/Users/ramiab/Desktop/Mineral-Predictions-Local"
CSV_PATH   = os.path.join(DATA_DIR, "Training", "data", "preprocessed", "rock_features.csv")
IMG_DIR    = os.path.join(DATA_DIR, "Images", "Rock-Mag-Images")
SUBSET_SIZE = 0.1  # Use 20% of data for quick tuning
NUM_EPOCHS = 1
BATCH_SIZE = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LABEL_NAMES = ['AU','AG','CU','CO','NI']

# Configurations to test - including different class weight strategies
configs = [
    {
        'name': 'Medium_Balanced',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [256, 128],  # Base geological architecture
        'lr': 5e-4,
        'dropout': 0.3,
        'weights': {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}
    },
    {
        'name': 'Medium_DeepGeo',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [512, 256, 128],  # Deeper geological branch
        'lr': 5e-4,
        'dropout': 0.3,
        'weights': {'AU': 12.0, 'AG': 12.0, 'CU': 6.0, 'CO': 6.0, 'NI': 6.0}
    },
    {
        'name': 'Medium_WideGeo',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [1024, 512],  # Wider geological branch
        'lr': 5e-4,
        'dropout': 0.4,
        'weights': {'AU': 8.0, 'AG': 8.0, 'CU': 4.0, 'CO': 4.0, 'NI': 4.0}
    },
    {
        'name': 'Medium_LowDropout',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [256, 128],
        'lr': 5e-4,
        'dropout': 0.2,  # Lower dropout
        'weights': {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}
    },
    {
        'name': 'Medium_HighDropout',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [256, 128],
        'lr': 5e-4,
        'dropout': 0.5,  # Higher dropout
        'weights': {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}
    },
    {
        'name': 'Medium_SlowLearning',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [256, 128],
        'lr': 1e-4,  # Lower learning rate
        'dropout': 0.3,
        'weights': {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}
    },
    {
        'name': 'Medium_SmallGeo',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [128, 64],  # Smaller geological branch
        'lr': 5e-4,
        'dropout': 0.3,
        'weights': {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}
    },
    {
        'name': 'Medium_ResidualGeo',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [256, 256],  # Same size layers for residual
        'lr': 5e-4,
        'dropout': 0.3,
        'weights': {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0},
        'use_residual': True  # New parameter for residual connections
    },
    {
        'name': 'Medium_BatchNorm',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [256, 128],
        'lr': 5e-4,
        'dropout': 0.3,
        'weights': {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0},
        'use_batchnorm': True  # New parameter for batch normalization
    },
    {
        'name': 'Medium_Combined',
        'cnn_layers': [(32, 3), (64, 3), (128, 3)],
        'geo_layers': [512, 256],
        'lr': 5e-4,
        'dropout': 0.4,
        'weights': {'AU': 12.0, 'AG': 12.0, 'CU': 6.0, 'CO': 6.0, 'NI': 6.0},
        'use_batchnorm': True,
        'use_residual': True
    }
]
# ---------------------------
# Dataset
# ---------------------------
class MineralDataset(Dataset):
   def __init__(self, data, img_dir, transform=None):
       # Accept either DataFrame or file path
       self.df = data if isinstance(data, pd.DataFrame) else pd.read_csv(data)
       self.feature_cols = ['CODE_LITH_encoded', 'STRAT_encoded', 
                          'dist_fault', 'dist_cont']
       self.label_cols = [col for col in self.df.columns if col.endswith('_target')]
       self.img_dir = img_dir
       self.transform = transform
       
       # Print class distributions
       print("\nClass distributions:")
       for col in self.label_cols:
           pos_count = self.df[col].sum()
           pos_ratio = pos_count / len(self.df) * 100
           print(f"{col}: {pos_count} positives ({pos_ratio:.2f}%)")
   
   def __len__(self):
       return len(self.df)
   
   def __getitem__(self, idx):
       row = self.df.iloc[idx]
       img_id = int(row['UNIQUE_ID'])
       img_path = os.path.join(self.img_dir, f"{img_id}.jpg")
       img = Image.open(img_path).convert('RGB')
       if self.transform:
           img = self.transform(img)
       geo_features = torch.tensor(row[self.feature_cols].values, dtype=torch.float32)
       targets = torch.tensor(row[self.label_cols].values, dtype=torch.float32)
       return img, geo_features, targets

# ---------------------------
# Model Architecture
# ---------------------------
def create_model(config, num_geological_features):
    class QuickModel(nn.Module):
        def __init__(self):
            super().__init__()
            
            # CNN Branch
            cnn_layers = []
            in_channels = 3
            for out_channels, kernel_size in config['cnn_layers']:
                if config.get('use_batchnorm', False):
                    cnn_layers.extend([
                        nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU(),
                        nn.MaxPool2d(2)
                    ])
                else:
                    cnn_layers.extend([
                        nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
                        nn.ReLU(),
                        nn.MaxPool2d(2)
                    ])
                in_channels = out_channels
            self.cnn_branch = nn.Sequential(*cnn_layers)
            
            # Calculate CNN output size
            self.cnn_output_size = in_channels * (170 // (2**len(config['cnn_layers'])))**2
            
            # Geological Branch
            geo_layers = []
            in_features = num_geological_features
            
            for out_features in config['geo_layers']:
                if config.get('use_batchnorm', False):
                    geo_layers.extend([
                        nn.Linear(in_features, out_features),
                        nn.BatchNorm1d(out_features),
                        nn.ReLU(),
                        nn.Dropout(config['dropout'])
                    ])
                else:
                    geo_layers.extend([
                        nn.Linear(in_features, out_features),
                        nn.ReLU(),
                        nn.Dropout(config['dropout'])
                    ])
                
                # Add residual connection if layers are same size
                if config.get('use_residual', False) and in_features == out_features:
                    geo_layers.append(Lambda(lambda x: x + geo_layers[-4](x)))
                
                in_features = out_features
            
            self.geo_branch = nn.Sequential(*geo_layers)
            
            # Combiner
            final_geo_size = config['geo_layers'][-1]
            self.combiner = nn.Sequential(
                nn.Linear(self.cnn_output_size + final_geo_size, 512),
                nn.ReLU(),
                nn.Dropout(config['dropout']),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 5)
            )
        
        def forward(self, img, geo_features):
            x_img = self.cnn_branch(img)
            x_img = x_img.view(x_img.size(0), -1)
            x_geo = self.geo_branch(geo_features)
            combined = torch.cat([x_img, x_geo], dim=1)
            return self.combiner(combined)
    
    return QuickModel()

# Helper class for residual connections
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func
    
    def forward(self, x):
        return self.func(x)

def evaluate_predictions(targets, predictions, prefix=""):
   """Detailed evaluation metrics for imbalanced data"""
   results = {}
   for i, label in enumerate(LABEL_NAMES):
       tn, fp, fn, tp = confusion_matrix(targets[:, i], predictions[:, i]).ravel()
       precision = tp / (tp + fp) if (tp + fp) > 0 else 0
       recall = tp / (tp + fn) if (tp + fn) > 0 else 0
       
       results[label] = {
           'true_positives': int(tp),
           'false_positives': int(fp),
           'true_negatives': int(tn),
           'false_negatives': int(fn),
           'precision': precision,
           'recall': recall,
           'positive_ratio': (tp + fp) / len(targets) * 100  # % predicted positive
       }
       
       print(f"\n{prefix}{label}:")
       print(f"  True Positives: {tp}")
       print(f"  False Positives: {fp}")
       print(f"  Precision: {precision:.4f}")
       print(f"  Recall: {recall:.4f}")
       print(f"  Predicted Positive Ratio: {results[label]['positive_ratio']:.2f}%")
   
   return results

def validate(model, val_loader):
   model.eval()
   all_targets = []
   all_predictions = []
   
   with torch.no_grad():
       for images, geo_features, targets in val_loader:
           images = images.to(DEVICE)
           geo_features = geo_features.to(DEVICE)
           outputs = model(images, geo_features)
           predictions = (torch.sigmoid(outputs) >= 0.5).float()
           
           all_targets.append(targets.cpu().numpy())
           all_predictions.append(predictions.cpu().numpy())
   
   all_targets = np.concatenate(all_targets)
   all_predictions = np.concatenate(all_predictions)
   
   return evaluate_predictions(all_targets, all_predictions, prefix="Validation ")

def quick_train_and_evaluate(config, train_loader, val_loader, num_geological_features):
   print(f"\n=== Testing Configuration: {config['name']} ===")
   print("Class weights:", config['weights'])
   
   model = create_model(config, num_geological_features).to(DEVICE)
   optimizer = optim.Adam(model.parameters(), lr=config['lr'])
   
   # Create weight tensor from config
   class_weights = torch.tensor([config['weights'][label] for label in LABEL_NAMES]).to(DEVICE)
   criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
   
   for epoch in range(NUM_EPOCHS):
       # Training
       model.train()
       train_loss = 0
       for images, geo_features, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
           images = images.to(DEVICE)
           geo_features = geo_features.to(DEVICE)
           targets = targets.to(DEVICE)
           
           optimizer.zero_grad()
           outputs = model(images, geo_features)
           loss = criterion(outputs, targets)
           loss.backward()
           optimizer.step()
           train_loss += loss.item()
       
       avg_loss = train_loss / len(train_loader)
       print(f"\nEpoch {epoch+1} Loss: {avg_loss:.4f}")
       
       # Validation
       if epoch == NUM_EPOCHS - 1:  # Only validate on last epoch
           results = validate(model, val_loader)
           return results

def main():
   # Set up transforms
   transform = transforms.Compose([
       transforms.Resize((170, 170)),
       transforms.ToTensor(),
       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
   ])
   
   # Load full dataset
   full_dataset = MineralDataset(CSV_PATH, IMG_DIR, transform)
   
   # Create smaller subset by subsetting the DataFrame first
   subset_size = int(len(full_dataset) * SUBSET_SIZE)
   subset_indices = torch.randperm(len(full_dataset))[:subset_size].numpy()
   
   # Create new dataset with subset of DataFrame
   subset_df = full_dataset.df.iloc[subset_indices].reset_index(drop=True)
   dataset = MineralDataset(subset_df, IMG_DIR, transform)
   
   # Split into train/val
   val_size = int(len(dataset) * 0.2)
   train_size = len(dataset) - val_size
   train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
   
   train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
   val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
   
   print(f"\nQuick tuning with {subset_size} samples")
   print(f"Train: {train_size}, Val: {val_size}")
   
   # Test all configurations
   results = {}
   for config in configs:
       results[config['name']] = quick_train_and_evaluate(
           config, train_loader, val_loader, 
           len(dataset.feature_cols)
       )
   
   # Print final comparison
   print("\n=== Final Comparison ===")
   for config_name, config_results in results.items():
       print(f"\n{config_name}:")
       for label in LABEL_NAMES:
           mineral_results = config_results[label]
           print(f"\n{label}:")
           print(f"  Precision: {mineral_results['precision']:.4f}")
           print(f"  Recall: {mineral_results['recall']:.4f}")
           print(f"  Positive Predictions: {mineral_results['true_positives'] + mineral_results['false_positives']}")
           print(f"  True Positives: {mineral_results['true_positives']}")
           print(f"  Predicted Positive Ratio: {mineral_results['positive_ratio']:.2f}%")

if __name__ == "__main__":
   main()


Class distributions:
AU_target: 11700 positives (2.69%)
AG_target: 15123 positives (3.48%)
CU_target: 10662 positives (2.45%)
CO_target: 8549 positives (1.97%)
NI_target: 17449 positives (4.01%)

Class distributions:
AU_target: 1187 positives (2.73%)
AG_target: 1505 positives (3.46%)
CU_target: 991 positives (2.28%)
CO_target: 816 positives (1.88%)
NI_target: 1741 positives (4.00%)

Quick tuning with 43485 samples
Train: 34788, Val: 8697

=== Testing Configuration: Medium_Balanced ===
Class weights: {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [25:10<00:00,  2.78s/it]



Epoch 1 Loss: 0.5517

Validation AU:
  True Positives: 74
  False Positives: 140
  Precision: 0.3458
  Recall: 0.3289
  Predicted Positive Ratio: 2.46%

Validation AG:
  True Positives: 63
  False Positives: 336
  Precision: 0.1579
  Recall: 0.2150
  Predicted Positive Ratio: 4.59%

Validation CU:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation CO:
  True Positives: 27
  False Positives: 22
  Precision: 0.5510
  Recall: 0.1849
  Predicted Positive Ratio: 0.56%

Validation NI:
  True Positives: 81
  False Positives: 70
  Precision: 0.5364
  Recall: 0.2341
  Predicted Positive Ratio: 1.74%

=== Testing Configuration: Medium_DeepGeo ===
Class weights: {'AU': 12.0, 'AG': 12.0, 'CU': 6.0, 'CO': 6.0, 'NI': 6.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [24:35<00:00,  2.71s/it]



Epoch 1 Loss: 0.5843

Validation AU:
  True Positives: 80
  False Positives: 183
  Precision: 0.3042
  Recall: 0.3556
  Predicted Positive Ratio: 3.02%

Validation AG:
  True Positives: 59
  False Positives: 353
  Precision: 0.1432
  Recall: 0.2014
  Predicted Positive Ratio: 4.74%

Validation CU:
  True Positives: 11
  False Positives: 20
  Precision: 0.3548
  Recall: 0.0591
  Predicted Positive Ratio: 0.36%

Validation CO:
  True Positives: 62
  False Positives: 112
  Precision: 0.3563
  Recall: 0.4247
  Predicted Positive Ratio: 2.00%

Validation NI:
  True Positives: 146
  False Positives: 205
  Precision: 0.4160
  Recall: 0.4220
  Predicted Positive Ratio: 4.04%

=== Testing Configuration: Medium_WideGeo ===
Class weights: {'AU': 8.0, 'AG': 8.0, 'CU': 4.0, 'CO': 4.0, 'NI': 4.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [19:05<00:00,  2.11s/it]



Epoch 1 Loss: 0.5331

Validation AU:
  True Positives: 62
  False Positives: 114
  Precision: 0.3523
  Recall: 0.2756
  Predicted Positive Ratio: 2.02%

Validation AG:
  True Positives: 26
  False Positives: 30
  Precision: 0.4643
  Recall: 0.0887
  Predicted Positive Ratio: 0.64%

Validation CU:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation CO:
  True Positives: 19
  False Positives: 6
  Precision: 0.7600
  Recall: 0.1301
  Predicted Positive Ratio: 0.29%

Validation NI:
  True Positives: 41
  False Positives: 12
  Precision: 0.7736
  Recall: 0.1185
  Predicted Positive Ratio: 0.61%

=== Testing Configuration: Medium_LowDropout ===
Class weights: {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [17:02<00:00,  1.88s/it]



Epoch 1 Loss: 0.5295

Validation AU:
  True Positives: 93
  False Positives: 439
  Precision: 0.1748
  Recall: 0.4133
  Predicted Positive Ratio: 6.12%

Validation AG:
  True Positives: 32
  False Positives: 67
  Precision: 0.3232
  Recall: 0.1092
  Predicted Positive Ratio: 1.14%

Validation CU:
  True Positives: 5
  False Positives: 22
  Precision: 0.1852
  Recall: 0.0269
  Predicted Positive Ratio: 0.31%

Validation CO:
  True Positives: 43
  False Positives: 71
  Precision: 0.3772
  Recall: 0.2945
  Predicted Positive Ratio: 1.31%

Validation NI:
  True Positives: 117
  False Positives: 139
  Precision: 0.4570
  Recall: 0.3382
  Predicted Positive Ratio: 2.94%

=== Testing Configuration: Medium_HighDropout ===
Class weights: {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [16:37<00:00,  1.83s/it]



Epoch 1 Loss: 0.6219

Validation AU:
  True Positives: 0
  False Positives: 1
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.01%

Validation AG:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation CU:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation CO:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation NI:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

=== Testing Configuration: Medium_SlowLearning ===
Class weights: {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [17:02<00:00,  1.88s/it]



Epoch 1 Loss: 0.5522

Validation AU:
  True Positives: 28
  False Positives: 37
  Precision: 0.4308
  Recall: 0.1244
  Predicted Positive Ratio: 0.75%

Validation AG:
  True Positives: 11
  False Positives: 5
  Precision: 0.6875
  Recall: 0.0375
  Predicted Positive Ratio: 0.18%

Validation CU:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation CO:
  True Positives: 19
  False Positives: 6
  Precision: 0.7600
  Recall: 0.1301
  Predicted Positive Ratio: 0.29%

Validation NI:
  True Positives: 95
  False Positives: 115
  Precision: 0.4524
  Recall: 0.2746
  Predicted Positive Ratio: 2.41%

=== Testing Configuration: Medium_SmallGeo ===
Class weights: {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [16:17<00:00,  1.80s/it]



Epoch 1 Loss: 0.5450

Validation AU:
  True Positives: 68
  False Positives: 150
  Precision: 0.3119
  Recall: 0.3022
  Predicted Positive Ratio: 2.51%

Validation AG:
  True Positives: 28
  False Positives: 50
  Precision: 0.3590
  Recall: 0.0956
  Predicted Positive Ratio: 0.90%

Validation CU:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation CO:
  True Positives: 22
  False Positives: 15
  Precision: 0.5946
  Recall: 0.1507
  Predicted Positive Ratio: 0.43%

Validation NI:
  True Positives: 37
  False Positives: 17
  Precision: 0.6852
  Recall: 0.1069
  Predicted Positive Ratio: 0.62%

=== Testing Configuration: Medium_ResidualGeo ===
Class weights: {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [16:20<00:00,  1.80s/it]



Epoch 1 Loss: 0.6133

Validation AU:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation AG:
  True Positives: 38
  False Positives: 269
  Precision: 0.1238
  Recall: 0.1297
  Predicted Positive Ratio: 3.53%

Validation CU:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation CO:
  True Positives: 18
  False Positives: 6
  Precision: 0.7500
  Recall: 0.1233
  Predicted Positive Ratio: 0.28%

Validation NI:
  True Positives: 26
  False Positives: 5
  Precision: 0.8387
  Recall: 0.0751
  Predicted Positive Ratio: 0.36%

=== Testing Configuration: Medium_BatchNorm ===
Class weights: {'AU': 10.0, 'AG': 10.0, 'CU': 5.0, 'CO': 5.0, 'NI': 5.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [19:23<00:00,  2.14s/it]



Epoch 1 Loss: 0.5332

Validation AU:
  True Positives: 90
  False Positives: 389
  Precision: 0.1879
  Recall: 0.4000
  Predicted Positive Ratio: 5.51%

Validation AG:
  True Positives: 29
  False Positives: 41
  Precision: 0.4143
  Recall: 0.0990
  Predicted Positive Ratio: 0.80%

Validation CU:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation CO:
  True Positives: 46
  False Positives: 109
  Precision: 0.2968
  Recall: 0.3151
  Predicted Positive Ratio: 1.78%

Validation NI:
  True Positives: 95
  False Positives: 115
  Precision: 0.4524
  Recall: 0.2746
  Predicted Positive Ratio: 2.41%

=== Testing Configuration: Medium_Combined ===
Class weights: {'AU': 12.0, 'AG': 12.0, 'CU': 6.0, 'CO': 6.0, 'NI': 6.0}


Epoch 1/1: 100%|██████████████████████████████| 544/544 [18:57<00:00,  2.09s/it]



Epoch 1 Loss: 0.6268

Validation AU:
  True Positives: 89
  False Positives: 492
  Precision: 0.1532
  Recall: 0.3956
  Predicted Positive Ratio: 6.68%

Validation AG:
  True Positives: 14
  False Positives: 24
  Precision: 0.3684
  Recall: 0.0478
  Predicted Positive Ratio: 0.44%

Validation CU:
  True Positives: 0
  False Positives: 0
  Precision: 0.0000
  Recall: 0.0000
  Predicted Positive Ratio: 0.00%

Validation CO:
  True Positives: 21
  False Positives: 17
  Precision: 0.5526
  Recall: 0.1438
  Predicted Positive Ratio: 0.44%

Validation NI:
  True Positives: 48
  False Positives: 21
  Precision: 0.6957
  Recall: 0.1387
  Predicted Positive Ratio: 0.79%

=== Final Comparison ===

Medium_Balanced:

AU:
  Precision: 0.3458
  Recall: 0.3289
  Positive Predictions: 214
  True Positives: 74
  Predicted Positive Ratio: 2.46%

AG:
  Precision: 0.1579
  Recall: 0.2150
  Positive Predictions: 399
  True Positives: 63
  Predicted Positive Ratio: 4.59%

CU:
  Precision: 0.0000
  Recall: 