# Experiment 07: Attention U-Net with Cyclical Learning Rate (CLR)

This notebook addresses the training instability observed in previous experiments. We replace the constant learning rate with a Cyclical Learning Rate (CLR) schedule, which repeatedly raises and lowers the learning rate between a lower and an upper bound. The goal is to help the optimizer escape poor local minima and find a more robust solution.

### **Methodology**

This is a two-part experiment:
1.  **LR Range Test**: We first run a short training process in which the learning rate increases exponentially from a very small value to a large one. We plot the loss against the learning rate (on a log scale) to identify the range in which the loss decreases most rapidly.
2.  **Full Training**: We use the identified optimal range to train our full model for 50 epochs using a CLR scheduler.
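The link between the two steps can be sketched in plain Python. This is a minimal illustration, not the procedure used in this notebook: it assumes log-spaced learning rates and picks the point where the loss falls fastest per decade of learning rate. The helper names `sweep_lrs` and `steepest_descent_lr` and the loss values are hypothetical.

```python
import math

def sweep_lrs(start_lr, end_lr, num_steps):
    """Return num_steps log-spaced learning rates from start_lr to end_lr (inclusive)."""
    if num_steps < 2:
        return [start_lr]
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_steps - 1)) for i in range(num_steps)]

def steepest_descent_lr(lrs, losses):
    """Return the LR at the end of the interval where the loss drops fastest
    per decade of LR, or None if the loss never decreases."""
    best_lr, best_slope = None, 0.0
    for i in range(1, len(lrs)):
        decades = math.log10(lrs[i]) - math.log10(lrs[i - 1])
        slope = (losses[i] - losses[i - 1]) / decades
        if slope < best_slope:
            best_lr, best_slope = lrs[i], slope
    return best_lr

# Made-up losses for illustration only: flat, then falling, then diverging.
example_lrs = sweep_lrs(1e-7, 1e-1, 7)
example_losses = [1.0, 0.98, 0.9, 0.6, 0.5, 0.7, 2.0]
candidate_lr = steepest_descent_lr(example_lrs, example_losses)
```

In practice the curve from a real range test is noisy, so smoothing the losses before taking slopes and confirming the choice by eye on the plot is usually advisable.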

### **Model Configuration**

*   **Objective**: Stabilize training and improve performance using a CLR schedule.
*   **Model Architecture**: Attention U-Net.
*   **Dataset**: D2_TCPW, eligible patients.
*   **Preprocessing**: RAovSeg custom preprocessing.
*   **Loss Function**: Focal Tversky Loss.
*   **Optimizer**: Adam.
*   **Learning Rate**: **Cyclical, varying between a min and max bound.**
*   **Epochs**: 50.
*   **Batch Size**: 16.
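To make "varying between a min and max bound" concrete, here is a minimal sketch of a triangular cyclical schedule in plain Python. It assumes the triangular policy; the notebook does not state which CLR variant the full training run uses, and the bounds and `step_size` below are hypothetical placeholders rather than values from the range test.

```python
import math

def triangular_clr(step, base_lr, max_lr, step_size):
    """Triangular cyclical LR: rises linearly from base_lr to max_lr over
    step_size steps, then falls back to base_lr over the next step_size steps."""
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)

# Hypothetical bounds and half-cycle length, for illustration only.
schedule = [triangular_clr(s, base_lr=1e-4, max_lr=1e-2, step_size=4) for s in range(9)]
```

PyTorch provides a comparable scheduler, `torch.optim.lr_scheduler.CyclicLR`, which is stepped once per batch. When it is paired with Adam, `cycle_momentum=False` is commonly passed so that the scheduler does not attempt to cycle a momentum parameter.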

In [None]:
# --- Imports and Setup ---
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
import numpy as np

project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

from src.data_loader import UterusDatasetWithPreprocessing 
from src.models import AttentionUNet
from src.losses import FocalTverskyLoss

# --- Configuration for LR Range Test ---
manifest_path = '../data/d2_manifest_t2fs_ovary_eligible.csv'
image_size = 256
batch_size = 16
start_lr = 1e-7
end_lr = 1e-1
num_steps = 20 # Number of steps to increase the LR

# --- Data Loading ---
print("--- Loading Data for LR Range Test ---")
# Only need the training data for this test
train_full_dataset = UterusDatasetWithPreprocessing(manifest_path=manifest_path, image_size=image_size, augment=True)
patient_ids = train_full_dataset.manifest['patient_id'].unique()
split_idx = int(len(patient_ids) * 0.8)
# Split by patient (first 80% of unique IDs) so no patient appears in both splits
train_ids = set(patient_ids[:split_idx])
train_indices = [
    i for i, sm in enumerate(train_full_dataset.slice_map)
    if train_full_dataset.manifest.loc[sm['patient_index'], 'patient_id'] in train_ids
]
train_dataset = Subset(train_full_dataset, train_indices)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# --- LR Range Test Logic ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")

model = AttentionUNet(n_channels=1, n_classes=1).to(device)
optimizer = Adam(model.parameters(), lr=start_lr) # Start with a very small LR
criterion = FocalTverskyLoss(alpha=0.7, beta=0.3, gamma=4/3)

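# Exponentially (geometrically) increase the LR from start_lr toward end_lr over num_steps.
# LambdaLR multiplies the optimizer's base LR (start_lr) by the factor returned here.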
lr_lambda = lambda step: (end_lr / start_lr) ** (step / num_steps)
scheduler = LambdaLR(optimizer, lr_lambda)

learning_rates = []
losses = []

model.train()
iterator = iter(train_loader)
for step in tqdm(range(num_steps), desc="LR Range Test"):
    try:
        images, masks = next(iterator)
    except StopIteration:
        iterator = iter(train_loader)
        images, masks = next(iterator)

    images, masks = images.to(device), masks.to(device)
    
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, masks)
    
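    # Stop if the loss is NaN or exceeds 4x the lowest loss recorded so far
    # (on the first step, before any loss is recorded, the reference value is 1.0)
    if torch.isnan(loss) or loss.item() > 4 * min(losses, default=1.0):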
        print("Loss exploded, stopping test.")
        break
        
    loss.backward()
    optimizer.step()
    
    learning_rates.append(scheduler.get_last_lr()[0])
    losses.append(loss.item())
    
    scheduler.step()

# --- Plot the Results ---
print("Plotting LR Range Test Results...")
plt.figure(figsize=(10, 6))
plt.plot(learning_rates, losses)
plt.xscale("log")
plt.xlabel("Learning Rate (log scale)")
plt.ylabel("Loss")
plt.title("Learning Rate Range Test")
plt.grid(True)
plt.show()

--- Loading Data for LR Range Test ---
Loading manifest from ../data/d2_manifest_t2fs_ovary_eligible.csv and creating slice map...
Slice map created. Found 278 slices containing the ovary.

Using device: cuda

