In [None]:

# Reload edited project modules (e.g. src.classifier.*) automatically before each
# cell executes, so changes to the .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Login to wandb (only needed when config.use_wandb is True). Supply the API key via
# the WANDB_API_KEY environment variable or run `wandb login` interactively in a
# terminal -- never paste the key into the notebook. The key previously committed
# here is exposed in version history and must be revoked/rotated in wandb settings.
# !wandb login

In [None]:
import random

import numpy as np
import torch
from torch_lr_finder import LRFinder

# Import our custom modules
from src.classifier.configs import TrainConfig
from src.classifier.training.dataset import (
    create_data_loaders,
    load_and_prepare_data,
    plot_label_distribution,
)
from src.classifier.training.model_helpers import create_model, setup_model_for_training
from src.classifier.training.trainer import Trainer
from src.classifier.training.utils import compute_class_frequency


def set_seed(seed: int, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Seed applied to every random number generator.
        deterministic: If True, additionally force cuDNN to use deterministic
            kernels and disable its autotuner (``cudnn.benchmark``). Seeding
            alone does not make cuDNN convolutions reproducible run-to-run;
            this trades some GPU speed for reproducibility. Defaults to False,
            which leaves the cuDNN flags untouched (original behavior).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without a GPU: PyTorch silently ignores this when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# Seed Python, NumPy and PyTorch RNGs before anything stochastic runs.
set_seed(42)

# Prefer CUDA when a GPU is visible; otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if cuda_available else torch.device('cpu')
print(f"Using device: {device}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Build info for debugging environment issues (torch.version.cuda is None on CPU-only builds).
print(f"CUDA Version: {torch.version.cuda}")
print(f"PyTorch Version: {torch.__version__}")

In [None]:
# Create configuration object
config = TrainConfig()
config.num_epochs = 15
config.bce_power = 0.6788091730324309
config.tau_logit_adjust = 0.8612782621731778
config.use_wandb = False
config.wandb_tags = [config.model_name]
config.wandb_config()
config.info()

In [None]:
print("Loading and preparing dataset with distribution data...")

# Load the annotation table together with the image paths and label targets.
df, image_paths, labels = load_and_prepare_data(config=config)

# Per-class frequencies over every column of the table except `file_name`.
# NOTE(review): this uses the whole table (all splits), not only the training
# split, and assumes every non-`file_name` column is a label -- confirm both are
# intended, since the result is passed to setup_model_for_training and Trainer.
class_frequency = compute_class_frequency(df.drop(columns=['file_name']))

# Stratified split into train/validation loaders; also returns the label column
# names and the label arrays for the full dataset and each split.
# NOTE(review): the held-out labels are named `test_labels` while the loader is
# `val_loader` -- presumably the same split; confirm.
(train_loader, val_loader, label_columns,
 original_labels, train_labels, test_labels) = create_data_loaders(df, image_paths, labels, config)

# Report the shapes of the label arrays consumed by the distribution plot.
print(f"\nLabel distribution data available:")
print(f"  Original labels shape: {original_labels.shape}")
print(f"  Train labels shape: {train_labels.shape}")
print(f"  Test labels shape: {test_labels.shape}")


## Label Distribution Visualization

In [None]:
print("Creating label distribution visualization...")

# Plot the label distribution of the full dataset next to the train and test
# splits, so split imbalance is visible before training.
plot_label_distribution(
    original_labels=original_labels,
    train_labels=train_labels,
    test_labels=test_labels,
    label_columns=label_columns,
)

print("Label distribution analysis completed!")


## Model Creation and Setup

In [None]:
print("Creating and setting up model...")

# Create the model with one output per label column.
model = create_model(config=config, num_classes=len(label_columns))

# Setup model for training (freeze/unfreeze based on training mode)
model = setup_model_for_training(model=model, device=device, config=config, class_freq=class_frequency)

# Smoke-test a forward pass on random inputs. Run it in eval mode: in train mode
# BatchNorm layers update their running statistics even under torch.no_grad(),
# so random noise would leak into them, and dropout would make the printed range
# nondeterministic. The previous train/eval mode is restored afterwards.
print("\nTesting forward pass...")
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        sample_images = torch.randn(2, 3, config.img_size, config.img_size).to(device)
        sample_output = model(sample_images)
        print(f"  Input shape: {sample_images.shape}")
        print(f"  Output shape: {sample_output.shape}")
        print(f"  Output range: [{sample_output.min():.3f}, {sample_output.max():.3f}]")
finally:
    model.train(was_training)

## Training Setup

In [None]:
print("Setting up trainer...")

# The Trainer is given the config, target device, data loaders, label names and
# class frequencies.
# NOTE(review): the `model` built in the previous cell is not passed in, so the
# Trainer presumably constructs its own (the LR finder cell calls
# trainer.get_model()) -- confirm, and `del model` to free memory if so.
trainer = Trainer(
    config=config,
    device=str(device),
    train_loader=train_loader,
    val_loader=val_loader,
    label_columns=label_columns,
    class_freq=class_frequency,
)

In [None]:
# LR range test: sweep the learning rate exponentially from config.lr_range_start
# to config.lr_range_end over the training loader and plot the smoothed loss.
lr_finder = LRFinder(
    trainer.get_model(),
    trainer.get_optimizer(learning_rate=0.001),  # placeholder lr; range_test sets it to start_lr
    trainer.get_criterion(),
    device=device,  # follow the device selected above instead of assuming CUDA
)
lr_finder.range_test(
    trainer.train_loader,
    start_lr=config.lr_range_start,
    end_lr=config.lr_range_end,
    num_iter=config.lr_range_steps,
    step_mode="exp",  # exponential increase
    smooth_f=0.05,  # light smoothing
    diverge_th=4,  # early stop if loss > 4x best
)
lr_finder.plot()  # Log scale LR curve
# range_test takes real optimizer steps; restore the model weights and optimizer
# state captured when LRFinder was constructed, so training does not start from
# the perturbed weights (assuming get_model() returns the Trainer's own model).
lr_finder.reset()