In [None]:

# Load IPython's autoreload extension and select mode 2, which re-imports
# modules before each cell runs so edits to the project modules imported
# later in this notebook are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Login to wandb
# !wandb login  # API key redacted: supply it through the WANDB_API_KEY environment variable, and revoke the key that was previously committed on this line

In [None]:
import random
import numpy as np
import torch

# Import our custom modules
from config import Config
from dataset import create_data_loaders
from dataset import load_and_prepare_data
from model_helpers import create_model, setup_model_for_training
from trainer import Trainer
from utils import compute_class_frequency


def set_seed(seed: int) -> None:
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is applied to Python's ``random`` module, NumPy's
    global generator, and PyTorch (including all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same call order as listing the four calls one after another.
    for apply_seed in seeders:
        apply_seed(seed)


# Seed all RNGs before any data or model code runs
set_seed(42)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# The device name is only reported when the CUDA device was selected
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Report the CUDA toolkit PyTorch was built against and the PyTorch release
print(
    f"CUDA Version: {torch.version.cuda}",
    f"PyTorch Version: {torch.__version__}",
    sep="\n",
)

In [None]:
# Instantiate the project configuration without overrides; later cells read
# and adjust attributes on this `config` object.
config = Config()
# NOTE(review): presumably prints a summary of the active settings — confirm in config.py
config.info()

In [None]:
# Weights & Biases logging is switched OFF for this run
# (set `config.use_wandb = True` to turn it on).
config.use_wandb = False
# Tag the run with the model name taken from the config
config.wandb_tags = [config.model_name]
# NOTE(review): presumably applies the wandb settings assigned above — confirm in config.py
config.wandb_config()

In [None]:
print("Loading and preparing dataset with distribution data...")

# Load the annotation table together with the image paths and labels
df, image_paths, labels = load_and_prepare_data(config=config)

# Class frequencies are computed over every column except the file name
class_frequency = compute_class_frequency(df.drop(columns=['file_name']))

# Build the loaders; the call also returns the label column names and the
# label arrays that later cells use
train_loader, val_loader, label_columns, original_labels, train_labels, test_labels = (
    create_data_loaders(df, image_paths, labels, config)
)

# Report the shape of each label array kept for later use
print(
    "\nLabel distribution data available:",
    f"  Original labels shape: {original_labels.shape}",
    f"  Train labels shape: {train_labels.shape}",
    f"  Test labels shape: {test_labels.shape}",
    sep="\n",
)


## Label Distribution Visualization

In [None]:
print("Creating label distribution visualization...")

# Plotting helper lives in the project's dataset module
from dataset import plot_label_distribution

# Plot the original, train and test label arrays, one entry per label column
plot_label_distribution(
    label_columns=label_columns,
    original_labels=original_labels,
    train_labels=train_labels,
    test_labels=test_labels,
)

print("Label distribution analysis completed!")


## Class Imbalance Analysis and Configuration

In [None]:
print("Analyzing class imbalance and configuring class weights...")

# Report the weighting method and loss currently selected in the config
print(
    "\nClass Weight Configuration:",
    f"  Class weight method: {config.class_weight_method}",
    f"  Loss function: {config.loss_type}",
    sep="\n",
)

# Weights are only computed when a weighting method is configured
if config.class_weight_method != 'none':
    from losses import print_class_weights

    # Labels held by the training dataset (this rebinds `train_labels`)
    train_labels = train_loader.dataset.labels

    # Print the analysis and keep the returned weights, both as a notebook
    # variable and on the config for later use
    class_weights = config.class_weights = print_class_weights(
        train_labels,
        label_columns,
        method=config.class_weight_method,
    )

    print("\nClass weights calculated and stored in config.class_weights")
    print(f"Class weights shape: {class_weights.shape}")
    print(f"Class weights device: {class_weights.device}")


## Model Creation and Setup

In [None]:
print("Creating and setting up model...")

# Create the model with one output per label column
model = create_model(config=config, num_classes=len(label_columns))

# Setup model for training (freeze/unfreeze based on training mode)
model = setup_model_for_training(
    model=model, device=device, config=config, class_freq=class_frequency
)

# Smoke-test a forward pass on random input.
# torch.no_grad() only disables gradient tracking: in train mode BatchNorm
# layers still update their running statistics and dropout still fires.
# Run the check in eval mode so random noise cannot alter the model, then
# restore every module's original train/eval flag (per module, because the
# setup step may have left some layers in eval mode on purpose).
# Assumes `model` is a torch.nn.Module, as its use in this notebook implies.
print("\nTesting forward pass...")
module_modes = [(module, module.training) for module in model.modules()]
model.eval()
try:
    with torch.no_grad():
        sample_images = torch.randn(2, 3, config.img_size, config.img_size).to(device)
        sample_output = model(sample_images)
        print(f"  Input shape: {sample_images.shape}")
        print(f"  Output shape: {sample_output.shape}")
        print(f"  Output range: [{sample_output.min():.3f}, {sample_output.max():.3f}]")
finally:
    for module, was_training in module_modes:
        module.training = was_training

# Show class weight configuration
print("\nClass Weight Configuration:")
print(f"  Class weight method: {config.class_weight_method}")
print(f"  Loss function: {config.loss_type}")

# `config.class_weights` is only assigned by the class-imbalance cell when a
# weighting method is enabled, so here it may be missing, None, or some
# placeholder. Comparing it with the string 'none' let None through and then
# failed on `.shape`; instead, report weights only when an array-like object
# with a shape is actually stored.
stored_class_weights = getattr(config, "class_weights", None)
if hasattr(stored_class_weights, "shape"):
    print("  Class weights calculated: Yes")
    print(f"  Class weights shape: {stored_class_weights.shape}")
else:
    print("  Class weights calculated: No")

## Training Setup

In [None]:
print("Setting up trainer...")

# Assemble the trainer from the model, the data loaders and the shared config.
# The device is handed over as a plain string (e.g. "cuda" or "cpu").
trainer = Trainer(
    config=config,
    model=model,
    device=str(device),
    train_loader=train_loader,
    val_loader=val_loader,
    label_columns=label_columns,
    class_freq=class_frequency,
)

In [None]:
from torch_lr_finder import LRFinder

# Run the range test on the device selected at the top of the notebook rather
# than a hard-coded "cuda", so this cell also works on CPU-only machines.
lr_finder = LRFinder(
    trainer.model, trainer.optimizer, trainer.criterion, device=device
)

# Sweep the learning rate up to 10 over 100 iterations of the training loader
lr_finder.range_test(train_loader, end_lr=10, num_iter=100)
lr_finder.plot()  # Log scale LR curve

# The range test really trains the model with ever larger learning rates.
# Restore the model and optimizer to their pre-test state so the training
# cell that follows does not start from the weights left by the sweep.
lr_finder.reset()

## Training Execution

In [None]:
# Opening banner
print("Starting training...", "=" * 50, sep="\n")

# Run the trainer and keep whatever it returns for later inspection
history = trainer.train()

# Closing banner
print("=" * 50, "Training completed!", sep="\n")

## Model Evaluation and Analysis

In [None]:
from utils import find_best_thresholds

# One validation pass: ground truth, predictions, probabilities and loss
y_true, y_pred, y_prob, total_loss = trainer.validate_single_epoch()

# Thresholds chosen by find_best_thresholds from the validation probabilities
dynamic_threshold = find_best_thresholds(y_true=y_true, y_prob=y_prob)

# Per-class metrics at the fixed threshold from the config
static_per_class_metrics = trainer.metrics_calculator.compute_per_class_metrics(
    y_true=y_true,
    y_pred=y_pred,
    class_names=label_columns,
    threshold=config.threshold,
)

# Per-class metrics at the thresholds found above
# NOTE(review): both calls pass `y_pred`, not `y_prob`. If `y_pred` is already
# binarized, the threshold argument cannot change the outcome and the static
# and dynamic reports will be identical — confirm what
# compute_per_class_metrics expects before relying on the comparison.
dynamic_per_class_metrics = trainer.metrics_calculator.compute_per_class_metrics(
    y_true=y_true,
    y_pred=y_pred,
    class_names=label_columns,
    threshold=dynamic_threshold,
)

# Combine both reports per label. They come from the same function and so share
# metric names: merging them with dict.update() replaced every static value
# with the dynamic one (and mutated static_per_class_metrics in place). Prefix
# the keys so both sets of numbers survive side by side.
label_eval = {
    label: {
        **{
            f"static_{metric}": value
            for metric, value in static_per_class_metrics[label].items()
        },
        **{
            f"dynamic_{metric}": value
            for metric, value in dynamic_per_class_metrics[label].items()
        },
    }
    for label in label_columns
}