In [None]:
# Automatically reload edited project modules (config, dataset, model, ...)
# before each cell executes, so changes to them are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Login to wandb (only needed when config.use_wandb is True).
# Never paste the API key into the notebook: export WANDB_API_KEY in the
# environment before starting Jupyter, or run `wandb login` once
# interactively in a terminal.
# NOTE(review): an API key used to be hard-coded on this line; revoke it in
# the wandb account settings and issue a new one.
# !wandb login

In [None]:
import os
import random
import numpy as np
import torch

# Import our custom modules
from config import Config
from dataset import create_data_loaders
from metrics import plot_roc_curves, plot_precision_recall_curves
from model import create_model, setup_model_for_training
from trainer import Trainer
from utils import compute_class_frequency
from dataset import load_and_prepare_data


def set_seed(seed: int, deterministic: bool = False) -> None:
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, and
    PyTorch's CPU and (all) CUDA generators.

    Args:
        seed: Seed value applied to every generator.
        deterministic: If True, also force cuDNN to use deterministic
            kernels and disable its autotuner (``cudnn.benchmark``).
            Seeding alone does not make GPU convolutions reproducible.
            This trades speed for run-to-run repeatability. Defaults to
            False, which leaves the cuDNN settings untouched.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without a GPU: PyTorch silently ignores this when CUDA is absent.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# Seed all RNGs up front so every later random operation is repeatable.
set_seed(42)

# Pick the compute device: CUDA when a GPU is visible, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Toolkit versions (the CUDA version may be None on CPU-only PyTorch builds).
print(f"CUDA Version: {torch.version.cuda}")
print(f"PyTorch Version: {torch.__version__}")

In [None]:
# Create configuration object
config = Config()
config.info()

In [None]:
# Enable wandb logging
config.use_wandb = False
config.wandb_tags = [config.model_name]
config.wandb_config()

## Dataset Loading and Label Distribution Data

In [None]:
print("Loading and preparing dataset with distribution data...")

# Load the annotation dataframe together with the image paths and labels.
df, image_paths, labels = load_and_prepare_data(config=config)

# Per-class frequencies over every column except the file name.
# NOTE(review): this counts the whole dataframe, validation rows included.
# If Trainer uses class_freq to re-weight the loss, computing it on the
# training split only would avoid leaking validation statistics.
class_frequency = compute_class_frequency(df.drop(columns=['file_name']))

# Build the stratified train/validation loaders. The factory also returns
# the label data for the full set and for each split.
(
    train_loader,
    val_loader,
    label_columns,
    original_labels,
    train_labels,
    test_labels,
) = create_data_loaders(df, image_paths, labels, config)

# Report the shapes of the label data kept for the distribution plot.
print("\nLabel distribution data available:")
print(f"  Original labels shape: {original_labels.shape}")
print(f"  Train labels shape: {train_labels.shape}")
print(f"  Test labels shape: {test_labels.shape}")


## Label Distribution Visualization

In [None]:
print("Creating label distribution visualization...")

# The plotting helper lives in the project's dataset module.
from dataset import plot_label_distribution

# Plot how labels are distributed in the full label set and in the
# train/test splits returned by create_data_loaders.
plot_label_distribution(
    label_columns=label_columns,
    original_labels=original_labels,
    train_labels=train_labels,
    test_labels=test_labels,
)

print("Label distribution analysis completed!")


## Class Imbalance Analysis and Configuration

In [None]:
print("Analyzing class imbalance and configuring class weights...")

# Show current configuration
print("\nClass Weight Configuration:")
print(f"  Class weight method: {config.class_weight_method}")
print(f"  Loss function: {config.loss_type}")

# Calculate and display class weights only when a weighting method is set.
if config.class_weight_method != 'none':
    from losses import print_class_weights

    # Labels of the training split as held by the DataLoader's dataset.
    # They get a distinct name so that `train_labels` from the data-loading
    # cell is not overwritten. Re-running the label-distribution plot after
    # this cell would otherwise silently receive a different object.
    train_dataset_labels = train_loader.dataset.labels

    # Prints the weight analysis and returns the weights (tensor-like: the
    # code below reads .shape and .device).
    class_weights = print_class_weights(
        train_dataset_labels,
        label_columns,
        method=config.class_weight_method,
    )

    # Stored on the config so later cells can read them.
    config.class_weights = class_weights

    print("\nClass weights calculated and stored in config.class_weights")
    print(f"Class weights shape: {class_weights.shape}")
    print(f"Class weights device: {class_weights.device}")


## Model Creation and Setup

In [None]:
print("Creating and setting up model...")

# Create model
model = create_model(config=config, num_classes=len(label_columns))

# Setup model for training (freeze/unfreeze based on training mode)
model = setup_model_for_training(model=model, device=device, config=config)

# Smoke-test a forward pass on random input.
# The check runs in eval mode for two reasons:
#   - under torch.no_grad(), BatchNorm layers in train mode still update
#     their running mean/var, so random noise here would perturb the
#     (possibly pretrained) normalization statistics;
#   - dropout would make the printed range non-deterministic.
# The previous mode is restored afterwards, so training starts as configured.
print("\nTesting forward pass...")
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        sample_images = torch.randn(2, 3, config.img_size, config.img_size).to(device)
        sample_output = model(sample_images)
        print(f"  Input shape: {sample_images.shape}")
        print(f"  Output shape: {sample_output.shape}")
        print(f"  Output range: [{sample_output.min():.3f}, {sample_output.max():.3f}]")
finally:
    model.train(was_training)

# Show class weight configuration
print("\nClass Weight Configuration:")
print(f"  Class weight method: {config.class_weight_method}")
print(f"  Loss function: {config.loss_type}")

# Decide on the configured *method*, not on the weights object. Weights are
# only stored when the method is not 'none', so the attribute may be missing
# or None. Comparing a tensor/array against the string 'none' does not test
# what was intended.
if config.class_weight_method != 'none' and getattr(config, 'class_weights', None) is not None:
    print("  Class weights calculated: Yes")
    print(f"  Class weights shape: {config.class_weights.shape}")
else:
    print("  Class weights calculated: No")

## Training Setup

In [None]:
print("Setting up trainer...")

# Wire the model, data loaders and config into the Trainer. The class
# frequencies come from the data-loading cell, and the device is passed in
# its string form ('cuda' / 'cpu').
trainer = Trainer(
    model=model,
    config=config,
    train_loader=train_loader,
    val_loader=val_loader,
    label_columns=label_columns,
    class_freq=class_frequency,
    device=str(device),
)

## Training Execution

In [None]:
print("Starting training...")
print("=" * 50)

# Run the full training loop and keep the returned history.
history = trainer.train()

print("=" * 50)
print("Training completed!")
print(f"Best validation F1 Micro: {trainer.best_val_f1:.4f}")

# Save the training-curve figure into the configured output directory.
print("\nPlotting training history...")
trainer.plot_training_history(
    save_path=os.path.join(config.output_dir, 'training_history.png'),
)


## Model Evaluation and Analysis

In [None]:
# Re-run the trainer's validation pass to report the final metrics.
# NOTE(review): trainer.model holds whatever weights trainer.train() left
# behind. Unless train() restores the best checkpoint, these numbers may
# differ from trainer.best_val_f1 printed above.
val_loss, val_metrics = trainer.validate_epoch()

print("\nFinal Validation Metrics:")
print(f"  Loss: {val_loss:.4f}")
for metric_label, metric_key in (
    ("F1 Micro", 'f1_micro'),
    ("F1 Macro", 'f1_macro'),
    ("F1 Samples", 'f1_samples'),
    ("Precision Micro", 'precision_micro'),
    ("Precision Macro", 'precision_macro'),
    ("Recall Micro", 'recall_micro'),
    ("Recall Macro", 'recall_macro'),
):
    print(f"  {metric_label}: {val_metrics[metric_key]:.4f}")

# Ranking metrics are only reported when validate_epoch produced them.
if 'roc_auc_micro' in val_metrics:
    for metric_label, metric_key in (
        ("ROC AUC Micro", 'roc_auc_micro'),
        ("ROC AUC Macro", 'roc_auc_macro'),
        ("PR AUC Micro", 'pr_auc_micro'),
        ("PR AUC Macro", 'pr_auc_macro'),
    ):
        print(f"  {metric_label}: {val_metrics[metric_key]:.4f}")

# Per-class metrics on the raw model outputs (logits, no sigmoid applied).
# Inference runs in eval mode under no_grad, so no autograd graph is built
# and dropout/BatchNorm behave as during validation. Batches are already
# tensors, so they go to the device directly instead of being re-wrapped with
# torch.tensor(), which copies them and emits a UserWarning.
# NOTE(review): y_true comes from val_loader.dataset.labels, so its rows only
# line up with y_pred if val_loader does not shuffle. Confirm this in
# create_data_loaders.
print("\nComputing per-class metrics...")
trainer.model.eval()
with torch.no_grad():
    val_logits = np.concatenate(
        [trainer.model(batch[0].to(device)).cpu().numpy() for batch in val_loader]
    )

per_class_metrics = trainer.metrics_calculator.compute_per_class_metrics(
    y_true=val_loader.dataset.labels,
    y_pred=val_logits,
    class_names=label_columns,
)

print("\nPer-class F1 Scores:")
for class_name, metrics in per_class_metrics.items():
    print(f"  {class_name}: {metrics['f1']:.4f} (support: {metrics['support']})")


## Visualization and Analysis

In [None]:
print("Generating visualizations...")

# Run the model once over the validation set. For every sample, keep the raw
# logits, the sigmoid probabilities and the ground-truth labels.
# Loop variables are named batch_* so that this cell does not overwrite the
# notebook-level `images` / `labels` names. Labels stay on the host: they are
# only converted to NumPy, so sending them to the GPU and back was a wasted
# round-trip per batch.
trainer.model.eval()
all_predictions = []    # raw logits (pre-sigmoid)
all_labels = []
all_probabilities = []

with torch.no_grad():
    for batch_images, batch_labels in val_loader:
        batch_logits = trainer.model(batch_images.to(device))
        all_predictions.append(batch_logits.cpu().numpy())
        all_probabilities.append(torch.sigmoid(batch_logits).cpu().numpy())
        all_labels.append(batch_labels.cpu().numpy())

all_predictions = np.concatenate(all_predictions, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
all_probabilities = np.concatenate(all_probabilities, axis=0)

# The curves are drawn from the probabilities, and the figures are written
# into config.output_dir.
print("Plotting ROC curves...")
plot_roc_curves(
    all_labels,
    all_probabilities,
    label_columns,
    save_path=os.path.join(config.output_dir, 'roc_curves.png')
)

print("Plotting Precision-Recall curves...")
plot_precision_recall_curves(
    all_labels,
    all_probabilities,
    label_columns,
    save_path=os.path.join(config.output_dir, 'pr_curves.png')
)


In [None]:
from utils import visualize_predictions

SAMPLES = 5

# Visualize predictions for the first few validation samples.
# Samples are fetched by index and capped at the dataset size.
# The results use vis_* names for two reasons:
#   - this cell does not overwrite `test_labels` from the data-loading cell;
#     re-running the label-distribution plot afterwards would otherwise
#     receive this small tensor instead of the split's labels;
#   - it avoids reusing `dataset`, which is also the project module name.
val_dataset = val_loader.dataset
num_vis_samples = min(SAMPLES, len(val_dataset))
if num_vis_samples == 0:
    raise ValueError("Validation dataset is empty; nothing to visualize.")

vis_samples = [val_dataset[idx] for idx in range(num_vis_samples)]
vis_images = torch.stack([image for image, _ in vis_samples], dim=0)
vis_labels = torch.stack([label for _, label in vis_samples], dim=0)

visualize_predictions(
    model=trainer.model,
    test_images=vis_images,
    test_labels=vis_labels,
    label_columns=label_columns,
    threshold=config.threshold,
)