## **1. Google Drive**

In [1]:
LOCAL = False


## **2. Import Libraries**

In [2]:
# Set seed for reproducibility
SEED = 42

# Import necessary libraries
import os

# Set environment variables before importing modules
os.environ['PYTHONHASHSEED'] = str(SEED)
os.environ['MPLCONFIGDIR'] = os.getcwd() + '/configs/'


# Suppress warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

# Import necessary modules

import random
import numpy as np
from optuna.samplers import TPESampler
# Set seeds for random number generators in NumPy and Python
np.random.seed(SEED)
random.seed(SEED)

# Import PyTorch
import torch
torch.manual_seed(SEED)
from torch import nn
from torchvision.transforms import v2 as transforms
from torch.utils.data import  DataLoader

from PIL import Image
import torch.nn.functional as F



if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")

import optuna
from optuna.samplers import GridSampler
import torch.optim as optim

# Import other libraries
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from PIL import Image

from tqdm import tqdm



from sklearn.preprocessing import LabelEncoder
import torch.nn.functional as F

# Configure plot display settings
sns.set(font_scale=1.4)
sns.set_style('white')
plt.rc('font', size=14)
%matplotlib inline

PyTorch version: 2.8.0+cu128
Device: cuda


## **3. Config**

In [3]:
USE_MASKED_PATCHES = False

In [4]:
datasets_path = os.path.join(os.path.pardir, "an2dl2526c2")

train_data_path = os.path.join(datasets_path, "train_data")
train_labels_path = os.path.join(datasets_path, "train_labels.csv")
test_data_path = os.path.join(datasets_path, "test_data")

CSV_PATH = train_labels_path                # Path to the CSV file with labels
SOURCE_FOLDER = train_data_path

if USE_MASKED_PATCHES:
  PATCHES_OUT = os.path.join(datasets_path, "preprocessing_results_masked","train_patches_masked")
  SUBMISSION_PATCHES_OUT = os.path.join(datasets_path, "preprocessing_results_masked","submission_patches_masked")
else:
  PATCHES_OUT = os.path.join(datasets_path, "preprocessing_results","train_patches")
  SUBMISSION_PATCHES_OUT = os.path.join(datasets_path, "preprocessing_results","submission_patches")

print(f"Dataset path: {datasets_path}")
print(f"Train data path: {train_data_path}")
print(f"Train labels path: {train_labels_path}")
print(f"Test data path: {test_data_path}")
print(f"Patches output path: {PATCHES_OUT}")
print(f"Submission patches output path: {SUBMISSION_PATCHES_OUT}")




TARGET_SIZE = (224, 224)                    # Target size for the resized images and masks

Dataset path: ../an2dl2526c2
Train data path: ../an2dl2526c2/train_data
Train labels path: ../an2dl2526c2/train_labels.csv
Test data path: ../an2dl2526c2/test_data
Patches output path: ../an2dl2526c2/preprocessing_results/train_patches
Submission patches output path: ../an2dl2526c2/preprocessing_results/submission_patches


## **4. Train/Val Split**

In [5]:
def create_metadata_dataframe(patches_dir, labels_csv_path):
    """
    Creates a DataFrame mapping patch filenames to their Bag IDs and Labels.
    """
    # 1. Load the labels CSV
    # Assuming CSV structure: [image_id, label] or similar
    df_labels = pd.read_csv(labels_csv_path)

    # Standardize column names for easier merging
    # We assume the first column is the ID and the second is the Label
    id_col = df_labels.columns[0]
    label_col = df_labels.columns[1]

    # Ensure IDs in CSV are strings (to match filenames)
    df_labels[id_col] = df_labels[id_col].astype(str)

    # If the CSV IDs contain extensions (e.g., 'img_001.png'), remove them
    # because our parsed Bag IDs won't have them.
    df_labels[id_col] = df_labels[id_col].apply(lambda x: os.path.splitext(x)[0])

    # 2. List all patch files
    patch_files = [f for f in os.listdir(patches_dir) if f.endswith('.png')]

    # 3. Parse filenames to get Bag IDs
    data = []
    print(f"Found {len(patch_files)} patches. Parsing metadata...")

    for filename in patch_files:
        # Expected format from your preprocessing: {base_name}_p{i}.png
        # Example: "img_0015_p12.png" -> Bag ID should be "img_0015"

        # Split from the right on '_p' to separate Bag ID from Patch Index
        # "img_0015_p12.png" -> ["img_0015", "12.png"]
        try:
            bag_id = filename.rsplit('_p', 1)[0]

            data.append({
                'filename': filename,
                'sample_id': bag_id,
                'path': os.path.join(patches_dir, filename)
            })
        except IndexError:
            print(f"Skipping malformed filename: {filename}")

    # Create temporary patches DataFrame
    df_patches = pd.DataFrame(data)

    # 4. Merge patches with labels
    # This assigns the correct Bag Label to every Patch in that Bag
    df = pd.merge(df_patches, df_labels, left_on='sample_id', right_on=id_col, how='inner')

    # 5. Clean up and Rename
    # Keep only required columns
    df = df[['filename', label_col, 'sample_id', 'path']]

    # Rename label column to standard 'label' if it isn't already
    df = df.rename(columns={label_col: 'label'})

    print(f"Successfully created DataFrame with {len(df)} rows.")
    return df

In [6]:
patches_metadata_df = create_metadata_dataframe(PATCHES_OUT, CSV_PATH)

# Verify the result
print("\nFirst 5 rows:")
print(patches_metadata_df.head().drop(columns=['path']))
print("\nPatches per Bag (Distribution):")
print(patches_metadata_df['sample_id'].value_counts().describe())

Found 3097 patches. Parsing metadata...
Successfully created DataFrame with 3097 rows.

First 5 rows:
          filename      label sample_id
0  img_0690_p2.png  Luminal A  img_0690
1  img_0690_p1.png  Luminal A  img_0690
2  img_0690_p0.png  Luminal A  img_0690
3  img_0689_p3.png  Luminal A  img_0689
4  img_0689_p2.png  Luminal A  img_0689

Patches per Bag (Distribution):
count    631.000000
mean       4.908082
std        2.913207
min        1.000000
25%        3.000000
50%        4.000000
75%        6.000000
max       23.000000
Name: count, dtype: float64


In [7]:
# Add Label Encoding
print("\n" + "="*50)
print("Label Encoding")
print("="*50)

label_encoder = LabelEncoder()
patches_metadata_df['label_encoded'] = label_encoder.fit_transform(patches_metadata_df['label'])

print(f"\nOriginal Labels: {label_encoder.classes_}")
print(f"Encoded as: {list(range(len(label_encoder.classes_)))}")
print(f"\nLabel Mapping:")
for orig, enc in zip(label_encoder.classes_, range(len(label_encoder.classes_))):
    print(f"  {orig} -> {enc}")


Label Encoding

Original Labels: ['HER2(+)' 'Luminal A' 'Luminal B' 'Triple negative']
Encoded as: [0, 1, 2, 3]

Label Mapping:
  HER2(+) -> 0
  Luminal A -> 1
  Luminal B -> 2
  Triple negative -> 3


In [8]:
# Train/Val Split on Original Images (not patches)
print("\n" + "="*50)
print("Train/Val Split on Original Images")
print("="*50)

# Get unique sample IDs
unique_samples = patches_metadata_df['sample_id'].unique()
print(f"\nTotal unique samples (original images): {len(unique_samples)}")

# Split samples into train (80%) and val (20%)
train_samples, val_samples = train_test_split(
    unique_samples,
    test_size=0.2,
    random_state=SEED,
    stratify=patches_metadata_df.drop_duplicates('sample_id').set_index('sample_id').loc[unique_samples, 'label_encoded'].values
)

print(f"Train samples: {len(train_samples)}")
print(f"Val samples: {len(val_samples)}")

# Create train and val DataFrames by filtering patches
df_train = patches_metadata_df[patches_metadata_df['sample_id'].isin(train_samples)].reset_index(drop=True)
df_val = patches_metadata_df[patches_metadata_df['sample_id'].isin(val_samples)].reset_index(drop=True)

print(f"\nTrain patches: {len(df_train)}")
print(f"Val patches: {len(df_val)}")
print(f"\nTrain label distribution:\n{df_train['label'].value_counts()}")
print(f"\nVal label distribution:\n{df_val['label'].value_counts()}")

# Print percentage distribution
print(f"\n" + "="*50)
print("Percentage Distribution")
print("="*50)
print(f"\nTrain label percentage:\n{df_train['label'].value_counts(normalize=True) * 100}")
print(f"\nVal label percentage:\n{df_val['label'].value_counts(normalize=True) * 100}")


Train/Val Split on Original Images

Total unique samples (original images): 631
Train samples: 504
Val samples: 127

Train patches: 2445
Val patches: 652

Train label distribution:
label
Luminal B          852
Luminal A          684
HER2(+)            676
Triple negative    233
Name: count, dtype: int64

Val label distribution:
label
Luminal B          238
HER2(+)            160
Luminal A          159
Triple negative     95
Name: count, dtype: int64

Percentage Distribution

Train label percentage:
label
Luminal B          34.846626
Luminal A          27.975460
HER2(+)            27.648262
Triple negative     9.529652
Name: proportion, dtype: float64

Val label percentage:
label
Luminal B          36.503067
HER2(+)            24.539877
Luminal A          24.386503
Triple negative    14.570552
Name: proportion, dtype: float64


## **5. Transformations & Augmentation**

In [9]:
# Define augmentation for training with enhanced transformations
train_augmentation = transforms.Compose([
    # Geometric transformations
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),  # Small rotations to handle orientation variations
    transforms.RandomAffine(
        degrees=0,
        translate=(0.1, 0.1),  # Reduced from 0.2 for more conservative shifts
        scale=None,  # Add scale variation
        shear=10  # Add shear transformation
    ),

    # Color/appearance transformations
    transforms.ColorJitter(
        brightness=0.2,  # Adjust brightness
        contrast=0.2,    # Adjust contrast
        saturation=0.2,  # Adjust saturation
        hue=0.1          # Slight hue variation
    ),
    #transforms.RandomGrayscale(p=0.1),  # Occasionally convert to grayscale to improve robustness

    # Occlusion simulation
    #transforms.RandomErasing(
    #    p=0.3,  # Reduced probability for more balanced augmentation
    #    scale=(0.02, 0.15),  # Reduced max scale
    #    ratio=(0.3, 3.3)  # Aspect ratio range
    #),

    # Optional: Add Gaussian blur for noise robustness
    # transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
])

## **6. Custom Dataset Class**

In [10]:
# ImageNet normalization statistics
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class TissueDataset(torch.utils.data.Dataset):
    def __init__(self, df, augmentation=None, normalize_imagenet=False, cache_images=True):
        self.augmentation = augmentation
        self.normalize_imagenet = normalize_imagenet
        self.df = df
        
        # CRITICAL OPTIMIZATION: Pre-convert to lists
        self.paths = df['path'].tolist()
        self.labels = df['label_encoded'].tolist()
        
        # Define transforms
        self.to_tensor = transforms.Compose([
            transforms.Resize(TARGET_SIZE),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True)
        ])
        
        if normalize_imagenet:
            self.normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        else:
            self.normalize = None
        
        # For external use
        self.transform = transforms.Compose([
            transforms.Resize(TARGET_SIZE),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) if normalize_imagenet else transforms.Identity()
        ])
        
        # --- IMAGE CACHING (Pre-load all images) ---
        self.image_cache = {}
        if cache_images:
            self._preload_images()
    
    def _preload_images(self):
        """Pre-load all images into memory for O(1) access during training."""
        total_images = len(self.paths)
        
        with tqdm(total=total_images, desc="Pre-loading images", unit="img") as pbar:
            for img_path in self.paths:
                if img_path not in self.image_cache:
                    try:
                        image = Image.open(img_path).convert("RGB")
                        image = self.to_tensor(image)
                        self.image_cache[img_path] = image
                    except Exception as e:
                        print(f"Error loading image {img_path}: {e}")
                        self.image_cache[img_path] = None
                
                pbar.update(1)
        
        print(f"Successfully cached {len(self.image_cache)} images.")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        label = self.labels[idx]
        
        # Load image (from cache or disk)
        if img_path in self.image_cache:
            image = self.image_cache[img_path]
            if image is None:
                # Fallback for corrupted cached images
                image = torch.ones((3, TARGET_SIZE[0], TARGET_SIZE[1]), dtype=torch.float32) * 0.5
        else:
            # Fallback: Load on-the-fly if not cached
            try:
                img = Image.open(img_path).convert('RGB')
                image = self.to_tensor(img)
            except:
                image = torch.ones((3, TARGET_SIZE[0], TARGET_SIZE[1]), dtype=torch.float32) * 0.5
        
        # Apply augmentation
        if self.augmentation:
            image = self.augmentation(image)
        
        # Apply normalization
        if self.normalize:
            image = self.normalize(image)
        
        return image, label

## **7. Data Loaders**

In [11]:
if LOCAL: 
    num_workers = 0
    CACHE_IMAGES = False
else:
    num_workers = os.cpu_count()//2
    CACHE_IMAGES = True

# Instantiate Datasets
train_dataset = TissueDataset(
    df_train, 
    augmentation=train_augmentation, 
    normalize_imagenet=True,
    cache_images=CACHE_IMAGES  # Enable image pre-loading
)
val_dataset = TissueDataset(
    df_val, 
    augmentation=None, 
    normalize_imagenet=True,
    cache_images=CACHE_IMAGES  # Enable image pre-loading
)

Pre-loading images: 100%|██████████| 2445/2445 [00:10<00:00, 225.61img/s]


Successfully cached 2445 images.


Pre-loading images: 100%|██████████| 652/652 [00:02<00:00, 232.03img/s]

Successfully cached 652 images.





## **9. Model Definition (Transfer Learning - MobileNetV3)**

In [12]:
import torch
import torch.nn as nn

# IMPORTANT: this is RetCCL's ResNet implementation (copy ResNet.py from the RetCCL repo)
import ResNet as RetCCLResNet


def _clean_state_dict(sd: dict) -> dict:
    """Strip common prefixes (DataParallel, wrapper modules)."""
    out = {}
    for k, v in sd.items():
        for p in ("module.", "model.", "encoder.", "backbone."):
            if k.startswith(p):
                k = k[len(p):]
        out[k] = v
    return out


class RetCCLResNet50(nn.Module):
    """
    Drop-in replacement for your ResNet18 class, but using RetCCL (CNN) ResNet50 backbone.

    Args match your original:
      - num_classes
      - dropout_rate
      - freeze_backbone

    Extra:
      - ckpt_path: path to RetCCL checkpoint (e.g., best_ckpt.pth)
      - unfreeze_last_block: often helps on small datasets
    """
    def __init__(
        self,
        num_classes: int,
        dropout_rate: float = 0.2,
        freeze_backbone: bool = True,
        ckpt_path: str = "best_ckpt.pth",
        unfreeze_last_block: bool = True,
    ):
        super().__init__()

        # 1) Build RetCCL ResNet50 (their script uses num_classes=128 for the pretext head)
        self.backbone = RetCCLResNet.resnet50(
            num_classes=128, mlp=False, two_branch=False, normlinear=True
        )

        # 2) Load RetCCL pretrained weights
        sd = torch.load(ckpt_path, map_location="cpu")
        if isinstance(sd, dict) and "state_dict" in sd:
            sd = sd["state_dict"]
        sd = _clean_state_dict(sd)

        # Drop any fc keys from the checkpoint (we replace the head anyway)
        sd = {k: v for k, v in sd.items() if not k.startswith("fc.")}

        msg = self.backbone.load_state_dict(sd, strict=False)
        # Uncomment for debugging:
        # print("Missing keys:", msg.missing_keys)
        # print("Unexpected keys:", msg.unexpected_keys)

        # 3) Replace fc with your custom multi-layer head
        # RetCCL fc may not expose .in_features (e.g., NormLinear), so use weight shape
        if hasattr(self.backbone.fc, "in_features"):
            in_features = self.backbone.fc.in_features
        elif hasattr(self.backbone.fc, "weight"):
            in_features = self.backbone.fc.weight.shape[1]
        else:
            # ResNet50 default is 2048 if all else fails
            in_features = 2048

        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=dropout_rate),
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate),
            nn.Linear(256, num_classes),
        )

        # 4) Freeze backbone (optional) + optionally unfreeze last block
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

            # always train the new head
            for p in self.backbone.fc.parameters():
                p.requires_grad = True

            # often beneficial for small pathology datasets
            if unfreeze_last_block and hasattr(self.backbone, "layer4"):
                for p in self.backbone.layer4.parameters():
                    p.requires_grad = True

    def forward(self, x):
        return self.backbone(x)


In [13]:
import torch
import torch.nn as nn
import ResNet as RetCCLResNet # Ensure this matches your import

class RetCCLResNet50_Flexible(RetCCLResNet50):
    """
    Extends your original class to support dynamic MLP heads for Grid Search.
    """
    def __init__(
        self,
        num_classes: int,
        dropout_rate: float = 0.2,
        freeze_backbone: bool = True,
        ckpt_path: str = "best_ckpt.pth",
        unfreeze_last_block: bool = True,
        # New parameters for head search
        head_type: str = 'original', # 'linear', 'mlp_1_layer', 'original'
        head_hidden_dim: int = 1024
    ):
        # Initialize the base model (loads weights, sets up backbone)
        super().__init__(
            num_classes=num_classes,
            dropout_rate=dropout_rate,
            freeze_backbone=freeze_backbone,
            ckpt_path=ckpt_path,
            unfreeze_last_block=unfreeze_last_block
        )

        # Detect input features (ResNet50 usually 2048)
        # We look at the first layer of the head created by the parent class to find input size
        if isinstance(self.backbone.fc, nn.Sequential):
            in_features = self.backbone.fc[0].in_features
        elif isinstance(self.backbone.fc, nn.Linear):
            in_features = self.backbone.fc.in_features
        else:
            in_features = 2048 

        # --- Re-define the Head based on 'head_type' ---
        
        # Variation 1: Simple Linear (Logistic Regression equivalent)
        if head_type == 'linear':
            self.backbone.fc = nn.Linear(in_features, num_classes)
            
        # Variation 2: 1 Hidden Layer (Standard MLP)
        elif head_type == 'mlp_1_layer':
            self.backbone.fc = nn.Sequential(
                nn.Linear(in_features, head_hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout_rate),
                nn.Linear(head_hidden_dim, num_classes),
            )
            
        # Variation 3: Your Original Complex Head (2 Hidden Layers + Hardswish)
        elif head_type == 'original':
            self.backbone.fc = nn.Sequential(
                nn.Linear(in_features, 1024),
                nn.Hardswish(inplace=True),
                nn.Dropout(p=dropout_rate),
                nn.Linear(1024, 256),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout_rate),
                nn.Linear(256, num_classes),
            )
            
        # Ensure the new head is trainable
        for p in self.backbone.fc.parameters():
            p.requires_grad = True

## **10. Loss and Optimizer**

In [14]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        """
        Args:
            alpha (Tensor, optional): Weights for each class. Shape [C].
            gamma (float): Focusing parameter. Higher value = more focus on hard examples.
                           Default is 2.0 (standard from the paper).
            reduction (str): 'mean', 'sum', or 'none'.
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # inputs: [Batch, C] (Logits)
        # targets: [Batch] (Class Indices)
        
        # 1. Standard Cross Entropy Loss (element-wise, no reduction yet)
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        
        # 2. Get the probability of the true class (pt)
        # pt = exp(-ce_loss) because ce_loss = -log(pt)
        pt = torch.exp(-ce_loss)
        
        # 3. Calculate Focal Component: (1 - pt)^gamma
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        # 4. Apply Class Weights (alpha) if provided
        if self.alpha is not None:
            # Gather the alpha value corresponding to the target class for each sample
            if self.alpha.device != inputs.device:
                self.alpha = self.alpha.to(inputs.device)
            
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_loss

        # 5. Reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

## **11. Function: Training & Validation Loop**

In [15]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0

    # Lists to store all predictions and labels for F1 calculation
    all_preds = []
    all_labels = []

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics accumulation
        running_loss += loss.item() * images.size(0)

        _, predicted = torch.max(outputs, 1)

        # Move to CPU and convert to numpy for sklearn metrics
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())


    epoch_loss = running_loss / len(loader.dataset)
    # Calculate F1 Score (Macro for imbalanced data)
    epoch_f1 = f1_score(all_labels, all_preds, average='macro')

    return epoch_loss, epoch_f1

def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    epoch_loss = running_loss / len(loader.dataset)
    epoch_f1 = f1_score(all_labels, all_preds, average='macro')

    return epoch_loss, epoch_f1

## **OPTUNA**

In [16]:
# Define explicit combinations to avoid redundant runs in Grid Search
LOSS_CONFIGS = {
    # --- Focal Loss Variations ---
    'focal_g2_weighted':   {'type': 'Focal', 'gamma': 2.0, 'use_weights': True,  'smoothing': 0.0},
    'focal_g5_weighted':   {'type': 'Focal', 'gamma': 5.0, 'use_weights': True,  'smoothing': 0.0},
    'focal_g2_no_weight':  {'type': 'Focal', 'gamma': 2.0, 'use_weights': False, 'smoothing': 0.0},
    
    # --- CrossEntropy Variations ---
    'ce_plain':            {'type': 'CE',    'gamma': None, 'use_weights': False, 'smoothing': 0.0},
    'ce_weighted':         {'type': 'CE',    'gamma': None, 'use_weights': True,  'smoothing': 0.0},
    'ce_smooth_0.1':       {'type': 'CE',    'gamma': None, 'use_weights': False, 'smoothing': 0.1},
    'ce_weighted_smooth':  {'type': 'CE',    'gamma': None, 'use_weights': True,  'smoothing': 0.1},
}

In [17]:
def objective(trial):
    # --- 1. Clean Hyperparameters (TPE supports ranges) ---
    lr = trial.suggest_float('lr', 1e-4, 1e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [64])
    optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'RAdam'])
    l2_reg = trial.suggest_float('l2_reg', 1e-5, 1e-2, log=True)
    
    # Corrected Dropout (removed the duplicate)
    dropout_rate = trial.suggest_float('dropout_rate', 0.3, 0.5) 
    
    # Head Params
    head_type = trial.suggest_categorical('head_type', ['linear', 'mlp_1_layer', 'original'])
    head_hidden_dim = trial.suggest_categorical('head_hidden_dim', [512, 1024])

    # --- 2. Smart Weight Logic ---
    # We select the Loss Config first
    loss_config_name = trial.suggest_categorical('loss_config', list(LOSS_CONFIGS.keys()))
    current_loss_params = LOSS_CONFIGS[loss_config_name]
    
    # Only suggest weights if the loss config actually USES them
    # This prevents TPE from wasting time tuning w1..w4 for unweighted losses
    if current_loss_params['use_weights']:
        w1 = trial.suggest_float('w1', 0.5, 1.0)
        w2 = trial.suggest_float('w2', 0.5, 1.0)
        w3 = trial.suggest_float('w3', 0.5, 1.0)
        w4 = trial.suggest_float('w4', 1.0, 1.5)
        final_weights = torch.tensor([w1, w2, w3, w4], dtype=torch.float32).to(device)
    else:
        final_weights = None

    # --- 3. Data Loaders ---
    opt_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count()//2, pin_memory=True)
    opt_val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=os.cpu_count()//2, pin_memory=True)

    # --- 4. Model Setup ---
    model = RetCCLResNet50_Flexible(
        num_classes=4, 
        dropout_rate=dropout_rate,
        freeze_backbone=True, 
        ckpt_path=os.path.join("models", "best_ckpt.pth"), 
        unfreeze_last_block=False,
        head_type=head_type,
        head_hidden_dim=head_hidden_dim
    ).to(device)

    # --- 5. Loss Setup ---
    if current_loss_params['type'] == 'Focal':
        criterion = FocalLoss(alpha=final_weights, gamma=current_loss_params['gamma'])
    elif current_loss_params['type'] == 'CE':
        criterion = nn.CrossEntropyLoss(weight=final_weights, label_smoothing=current_loss_params['smoothing'])
    
    # --- 6. Optimizer ---
    if optimizer_name == 'AdamW':
        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=l2_reg)
    elif optimizer_name == 'SGD':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, momentum=0.9, weight_decay=l2_reg)
    else:  # RAdam
        optimizer = optim.RAdam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=l2_reg, eps=1e-8, betas=(0.9, 0.999))

    # --- 7. Training Loop ---
    SEARCH_EPOCHS = 10 
    best_f1_in_trial = 0.0

    for epoch in range(SEARCH_EPOCHS):
        # 1. Run Train & Val (Silent functions, no tqdm)
        train_loss, train_f1 = train_one_epoch(model, opt_train_loader, criterion, optimizer, device)
        val_loss, val_f1 = validate(model, opt_val_loader, criterion, device)

        # 2. Update Best Score
        if val_f1 > best_f1_in_trial:
            best_f1_in_trial = val_f1
        
        # 3. PRINT PROGRESS (The Solution)
        # This prints ONE line per epoch. Clean and informative.
        print(f" Trial {trial.number} | Epoch {epoch+1}/{SEARCH_EPOCHS} | "
              f"Train Loss: {train_loss:.4f} | Train F1: {train_f1:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val F1: {val_f1:.4f}")

        # 4. Pruning
        trial.report(val_f1, epoch)
        if trial.should_prune():
            print(f"Trial {trial.number} Pruned at Epoch {epoch+1}")
            raise optuna.TrialPruned()

    return best_f1_in_trial

In [19]:
# --- 3. Run the Expanded Study ---
sampler = TPESampler(seed=42)
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=3)

study = optuna.create_study(
    study_name="retccl_head_search", 
    direction="maximize", 
    sampler=sampler, 
    pruner=pruner
)
N_TRIALS = 100
print(f"Starting TPE Search with {N_TRIALS} trials...") # Fixed print statement

# Ensure n_trials is sufficient for TPE to converge (100 is a good start)
study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True, n_jobs=1) 

print("Best Params:", study.best_params)

[I 2025-12-13 12:25:59,286] A new study created in memory with name: retccl_head_search


Starting TPE Search with 100 trials...


  0%|          | 0/100 [00:00<?, ?it/s]

 Trial 0 | Epoch 1/10 | Train Loss: 0.2416 | Train F1: 0.2347 | Val Loss: 0.2544 | Val F1: 0.2853
 Trial 0 | Epoch 2/10 | Train Loss: 0.2233 | Train F1: 0.3310 | Val Loss: 0.2501 | Val F1: 0.2759
 Trial 0 | Epoch 3/10 | Train Loss: 0.2127 | Train F1: 0.3537 | Val Loss: 0.2430 | Val F1: 0.3117
 Trial 0 | Epoch 4/10 | Train Loss: 0.2010 | Train F1: 0.4308 | Val Loss: 0.2448 | Val F1: 0.3313
 Trial 0 | Epoch 5/10 | Train Loss: 0.1947 | Train F1: 0.4503 | Val Loss: 0.2358 | Val F1: 0.3584
 Trial 0 | Epoch 6/10 | Train Loss: 0.1883 | Train F1: 0.4517 | Val Loss: 0.2462 | Val F1: 0.3756
 Trial 0 | Epoch 7/10 | Train Loss: 0.1852 | Train F1: 0.4800 | Val Loss: 0.2515 | Val F1: 0.3835
 Trial 0 | Epoch 8/10 | Train Loss: 0.1811 | Train F1: 0.4776 | Val Loss: 0.2353 | Val F1: 0.3896
 Trial 0 | Epoch 9/10 | Train Loss: 0.1777 | Train F1: 0.4976 | Val Loss: 0.2672 | Val F1: 0.3916
 Trial 0 | Epoch 10/10 | Train Loss: 0.1753 | Train F1: 0.5006 | Val Loss: 0.2353 | Val F1: 0.4012
[I 2025-12-13 12:26

In [20]:
# --- Display Results ---
print("\n" + "="*50)
print("Study Statistics")
print("="*50)
print(f"Number of finished trials: {len(study.trials)}")
print(f"Best F1 Score: {study.best_value:.4f}")
print("Best Hyperparameters:")
for key, value in study.best_params.items():
    print(f"  {key}: {value}")


Study Statistics
Number of finished trials: 100
Best F1 Score: 0.4486
Best Hyperparameters:
  lr: 0.0008367433034670895
  batch_size: 64
  optimizer: AdamW
  l2_reg: 0.00023922618482545056
  dropout_rate: 0.39563913357721187
  head_type: original
  head_hidden_dim: 1024
  loss_config: focal_g2_no_weight


In [21]:
from optuna.visualization import plot_optimization_history, plot_param_importances, plot_parallel_coordinate


# 1. History (Did it improve over time?)
fig1 = plot_optimization_history(study)
fig1.show()

# 2. Importance (Which parameter affects the F1 score the most?)
fig2 = plot_param_importances(study)
fig2.show()

# 3. Parallel Coordinate (How do parameters interact?)
fig3 = plot_parallel_coordinate(study)
fig3.show()