### Instruction
- Start the mlflow server before running this code.

- Change the IP address (tracking uri) accordingly.

In [None]:
from torchinfo import summary
import torch
import mlflow
import mlflow.pytorch
import os
# connect the experiment to the tracking server 
# change the IP address accordingly
mlflow.set_tracking_uri("http://0.0.0.0:5001")

# set the experiment name
mlflow.set_experiment("post_baseline_EfficientNet")

# enable autologging
mlflow.pytorch.autolog()

In [2]:
# Set up logging to a file
import logging

# Get current working directory
log_file_path = os.path.join(os.getcwd(), 'training.log')

# Set up the logger manually
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Remove all previous handlers (important in Jupyter)
logger.handlers.clear()

# Create and add FileHandler
file_handler = logging.FileHandler(log_file_path, mode='w')
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# Optional: also log to console
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

In [3]:
from utils.seed import set_seed
# Set a global seed for reproducibility
global_seed = 42
set_seed(global_seed)

2025-08-19 22:29:22,693 - INFO - MPS backend is available. It uses the global PyTorch seed.
2025-08-19 22:29:22,694 - INFO - Deterministic algorithms set to True.
2025-08-19 22:29:22,694 - INFO - Random seed set to: 42


### Instruction 

- Only the following cell requires changes inorder to swap models being trained.

In [None]:
# Number of classes
num_classes = 5  # change according to your dataset
dropout = 0.5  # Set dropout rate
# learning rates
base_lr = 0.001          # head + classifier
encoder_lr = base_lr*0.1  # encoder when unfrozen

# ================================================================================
# Model selection - change this section to swap models
# ================================================================================
from models_SICC.EfficientNet import EfficientNetB0WithCosineHead
# Initialize the model
model = EfficientNetB0WithCosineHead(num_classes, dropout)
# ================================================================================


# Check device
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.backends.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logging.info(f"Using device: {device.type}")

# Move model to device
model.to(device)

2025-08-19 22:29:22,804 - INFO - Using device: mps


EfficientNetB0WithCosineHead(
  (head): CNNPatchDownscaleHead(
    (features): Sequential(
      (0): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    )
  )
  (encoder): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          

### Parameters used

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.regularization import EarlyStopping
from utils.train_validate import train, validate
from utils.transfer_learning import freeze_encoder, unfreeze_last_layers, get_optimizer

# batch size
dataloader_eval_batch = 8
dataloader_train_batch = 8

# Initialize optimizer
starting_lr = base_lr
optimizer = optimizer = get_optimizer(
    model,
    base_lr=starting_lr, 
    encoder_lr=0.0001, 
    weight_decay=0.0001
)


# Initialize learning rate scheduler
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',           # maximize validation accuracy or F1
    factor=0.5,
    patience=2,
    cooldown=2,
    min_lr=1e-6,
    threshold=0.0001,
    threshold_mode='abs',
    # verbose=True
)

# Initialize early stopping
early_stopping_patience = scheduler.cooldown + scheduler.patience + 5
early_stopping_delta = scheduler.threshold / 2
early_stopping = EarlyStopping(
    patience=early_stopping_patience,
    verbose=True,
    min_delta=early_stopping_delta
)

unfreeze_encoder_at = 10
max_num_epochs = 300



# Log params to MLflow
params = {
    "global_seed": global_seed,
    "train_batch_size": dataloader_train_batch,
    "eval_batch_size": dataloader_eval_batch,
    "loss_function": "CrossEntropyLoss",
    "learning_rate": starting_lr,
    "optimizer": "Adam",
    "scheduler_mode": "max",
    "scheduler_factor": 0.5,
    "scheduler_patience": 2,
    "scheduler_cooldown": 0,
    "scheduler_min_lr": 1e-6,
    "scheduler_threshold": 0.0001,
    "scheduler_threshold_mode": "abs",
    "scheduler": "ReduceLROnPlateau",
    "early_stopping_patience": early_stopping_patience,
    "early_stopping_delta": early_stopping_delta,
    "max_epochs": max_num_epochs,
    "device": device.type,
    "num_classes": num_classes,
    "dropout": dropout
}


In [None]:
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from utils.datasets import AugmentedDataset
from utils.custom_sampler_and_loss import effective_num_weights, ClassBalancedFocalLoss
import numpy as np
import torch.nn as nn
import logging
import mlflow
import os

def main(num_workers=4):
    input_transform = transforms.Compose([
        transforms.Resize((720, 1270)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
    ])

    g = torch.Generator()
    g.manual_seed(42)

    # Load datasets
    full_dataset = AugmentedDataset("augmented_data/train", transform=input_transform)
    test_dataset = AugmentedDataset("augmented_data/val", transform=input_transform)


    # ===== NEW: Extract labels directly =====
    if hasattr(full_dataset, "labels"):
        train_targets = np.array(full_dataset.labels)
    elif hasattr(full_dataset, "targets"):
        train_targets = np.array(full_dataset.targets)
    else:
        raise AttributeError("Your AugmentedDataset must store labels in .labels or .targets")

    if hasattr(test_dataset, "labels"):
        val_targets = np.array(test_dataset.labels)
    elif hasattr(test_dataset, "targets"):
        val_targets = np.array(test_dataset.targets)
    else:
        raise AttributeError("Your AugmentedDataset must store labels in .labels or .targets")

    num_classes = len(np.unique(train_targets))

    # ===== Compute class counts separately for train and val =====
    train_class_counts = np.bincount(train_targets, minlength=num_classes)
    val_class_counts   = np.bincount(val_targets, minlength=num_classes)

    # ===== Compute ENS class weights =====
    train_class_weights = effective_num_weights(train_class_counts, beta=0.9999, device=device)
    val_class_weights   = effective_num_weights(val_class_counts, beta=0.9999, device=device)

    logging.info(f"Train class counts: {train_class_counts.tolist()}")
    logging.info(f"Val   class counts: {val_class_counts.tolist()}")
    logging.info(f"ENS Train class weights: {train_class_weights.tolist()}")
    logging.info(f"ENS Val   class weights: {val_class_weights.tolist()}")

    # ===== Class-balanced Focal yLoss for train and val =====
    train_criterion = ClassBalancedFocalLoss(train_class_counts, beta=0.9999, gamma=2.0, device=device)
    val_criterion   = ClassBalancedFocalLoss(val_class_counts, beta=0.9999, gamma=2.0, device=device)


    # Compute sample weights for WeightedRandomSampler (train only)
    train_sample_weights = train_class_weights[train_targets]

    # Compute sample weights for WeightedRandomSampler (val only)
    val_sample_weights = val_class_weights[val_targets]

    # IMPORTANT: move to CPU + float32 (MPS cannot handle float64)
    train_sampler = WeightedRandomSampler(
        weights     = train_sample_weights.cpu().to(torch.float32),
        num_samples = len(train_sample_weights),
        replacement = True
    )

    # Create the sampler for the validation set
    val_sampler = WeightedRandomSampler(
        weights     = val_sample_weights.cpu().to(torch.float32),
        num_samples = len(val_sample_weights),
        replacement = True
    )
    
    # DataLoaders
    train_loader = DataLoader(
        full_dataset,
        batch_size=dataloader_train_batch,
        sampler=train_sampler,
        drop_last=True,
        shuffle=False,  
        num_workers=num_workers
    )
    val_loader = DataLoader(
        test_dataset,
        batch_size=dataloader_eval_batch,
        sampler = val_sampler,
        drop_last=True, 
        shuffle  = False,  # no need to shuffle validation data
        num_workers=num_workers
    )

    # ===== TRAINING LOOP =====
    with mlflow.start_run() as run:
        run_id = run.info.run_id
        logging.info(f"Started MLflow run with ID: {run_id}")
        mlflow.log_params(params)
        
        try:
            model_summary = summary(model, input_size=(1, 3, 720, 1270), device=device.type)

            with open("model_report.txt", "w") as f:
                f.write("### MODEL ARCHITECTURE ###\n")
                f.write(str(model))
                f.write("\n\n### MODEL SUMMARY ###\n")
                f.write(str(model_summary))

            mlflow.log_artifact("model_report.txt")
            logging.info("Model report logged to MLflow")
            os.remove("model_report.txt")
        except Exception as e:
            logging.error(f"Failed to log model report: {e}")
            raise
    
        logging.info("Starting training...")

        best_f1 = 0.0
        best_save_epoch = 0
        train_dir = "augmented_data/train"
        class_names = sorted([d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))])

        for epoch in range(max_num_epochs):
            # --- Freezing / unfreezing logic ---
            if epoch == 0:
                freeze_encoder(model)

            if epoch == unfreeze_encoder_at:  # after 20 epochs, unfreeze last few encoder layers
                unfreeze_last_layers(model, num_blocks=2)
                # lower LR for encoder param group
                for param_group in optimizer.param_groups:
                    if param_group["lr"] < base_lr:  # encoder group
                        param_group["lr"] = base_lr * 0.1
                logging.info("Adjusted LR for encoder fine-tuning")
            
            head_lr    = optimizer.param_groups[0]['lr']
            encoder_lr = optimizer.param_groups[1]['lr']

            mlflow.log_metric("head_lr", head_lr, step=epoch + 1)
            mlflow.log_metric("encoder_lr", encoder_lr, step=epoch + 1)

            logging.info(f"Epoch {epoch+1}, Head LR: {head_lr}, Encoder LR: {encoder_lr}")


            train(model, train_loader, train_criterion, optimizer, device, epoch, class_names)
            val_f1 = validate(model, val_loader, val_criterion, device, epoch, class_names)

            scheduler.step(val_f1)

            if val_f1 > best_f1:
                best_f1 = val_f1
                best_save_epoch = epoch
                mlflow.pytorch.log_model(model, artifact_path="best_model")
                logging.info(f"Improved F1 (weighted): {val_f1:.6f}. Saved best model.")

            mlflow.pytorch.log_model(model, artifact_path="latest_model")

            early_stopping(val_f1, model)
            if early_stopping.early_stop:
                mlflow.log_param("best_model_saved_at", best_save_epoch)
                mlflow.log_param("early_stopping_triggered_at", epoch)
                logging.warning("Early stopping triggered. Stopping training.")
                model.load_state_dict(early_stopping.best_model_state)
                break


if __name__ == "__main__":
    main(num_workers=4)  # safer for Jupyter

2025-08-19 22:29:23,048 - INFO - Train class counts: [320, 320, 801, 8000, 480]
2025-08-19 22:29:23,048 - INFO - Val   class counts: [80, 80, 200, 2000, 120]
2025-08-19 22:29:23,049 - INFO - ENS Train class weights: [1.5932106971740723, 1.5932106971740723, 0.6518386602401733, 0.0911238044500351, 1.070616602897644]
2025-08-19 22:29:23,049 - INFO - ENS Val   class weights: [1.6054655313491821, 1.6054655313491821, 0.6460433602333069, 0.07057320326566696, 1.0724523067474365]
2025-08-19 22:29:23,248 - INFO - Started MLflow run with ID: 27e44f7ad4874ddfbb9a6c7e07947f04
2025-08-19 22:29:24,203 - INFO - Model report logged to MLflow
2025-08-19 22:29:24,204 - INFO - Starting training...
2025-08-19 22:29:24,205 - INFO -  Encoder frozen (training head + classifier only)
2025-08-19 22:29:24,205 - INFO - Unfroze last 2 residual blocks of encoder.layer4
2025-08-19 22:29:24,206 - INFO - Adjusted LR for encoder fine-tuning
2025-08-19 22:29:24,212 - INFO - Epoch 1, Head LR: 0.001, Encoder LR: 0.0001
Ep