# MedNeXt Fine-Tuning (Differential Learning Rates)
This notebook fine-tunes MedNeXt on BraTS-style data with a differential learning rate strategy, then provides training plots and inference loading helpers.

In [None]:
!git clone https://github.com/MIC-DKFZ/MedNeXt.git mednext
!pip install -e ./mednext

Cloning into 'mednext'...
remote: Enumerating objects: 762, done.[K
remote: Counting objects: 100% (320/320), done.[K
remote: Compressing objects: 100% (76/76), done.[K
remote: Total 762 (delta 270), reused 244 (delta 244), pack-reused 442 (from 1)[K
Receiving objects: 100% (762/762), 522.43 KiB | 11.61 MiB/s, done.
Resolving deltas: 100% (459/459), done.
Obtaining file:///kaggle/working/mednext
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dicom2nifti (from mednextv1==1.7.0)
  Downloading dicom2nifti-2.6.2-py3-none-any.whl.metadata (1.5 kB)
Collecting medpy (from mednextv1==1.7.0)
  Downloading medpy-0.5.2.tar.gz (156 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m156.3/156.3 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting batchgenerators>=0.23 (from mednextv1==1.7.0)
  Downloading batchgene

In [None]:
import os
import sys
import math
import random
import glob
import numpy as np
import nibabel as nib
from tqdm import tqdm
from typing import List, Tuple, Dict
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

# Put the cloned MedNeXt checkout at the front of sys.path (only once) so that
# `nnunet_mednext` can be imported from this kernel.
repo_path = os.path.abspath('mednext')
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)

# Best-effort import: a failure is only printed, not re-raised, so a later cell
# that calls create_mednext_v1 will fail with NameError if this import failed.
try:
    from nnunet_mednext import create_mednext_v1, MedNeXt
    print("MedNeXt library imported successfully.")
except ImportError as e:
    print(f"Error importing MedNeXt: {e}")

class TrainingConfig:
    """Static configuration for the fine-tuning run; every setting is a class attribute."""

    # Filesystem locations: pretrained checkpoint, training data root, output checkpoint.
    PRETRAINED_PATH = "/kaggle/input/mednext/pytorch/default/1/best_model.pt"
    TRAIN_DIR = "/kaggle/input/instant-odc-ai-hackathon/Train"
    OUTPUT_MODEL_PATH = "/kaggle/working/best_finetuned_model.pt"

    # Architecture arguments forwarded to create_mednext_v1.
    MODEL_SIZE = 'B'
    KERNEL_SIZE = 3
    IN_CHANNELS = 4   # one channel per stacked modality (t1, t1ce, t2, flair)
    NUM_CLASSES = 4   # labels 0..3 after the dataset remaps raw label 4 -> 3

    # Training-loop and data-loading sizes.
    NUM_EPOCHS = 10
    BATCH_SIZE = 1
    PATCH_SIZE = (128, 128, 128)
    SAMPLES_PER_VOLUME = 4   # random patches drawn per subject per epoch
    NUM_WORKERS = 4

    # Base learning rate (applied to the later layers) and AdamW weight decay.
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5

    # Weights of the Dice and cross-entropy terms in the combined loss.
    DICE_WEIGHT = 1.0
    CE_WEIGHT = 1.0

    # Number of epochs over which the LR multiplier ramps up linearly.
    WARMUP_EPOCHS = 2

    # Fraction of subjects held out for validation.
    VAL_SPLIT = 0.15

    # Mixed-precision switch and compute device.
    USE_AMP = True
    # Evaluated once, when the class body is executed.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate the config (all settings live on the class) and count visible GPUs.
config = TrainingConfig()
num_gpus = torch.cuda.device_count()

# Print a human-readable summary of the run configuration.
print("=" * 60)
print("MEDNEXT-B FINE-TUNING WITH DIFFERENTIAL LEARNING RATES")
print("=" * 60)
print(f"Pretrained model: {config.PRETRAINED_PATH}")
print(f"Training data:    {config.TRAIN_DIR}")
print(f"Output model:     {config.OUTPUT_MODEL_PATH}")
print("-" * 60)
print(f"Model: MedNeXt-{config.MODEL_SIZE} (kernel={config.KERNEL_SIZE})")
print("Strategy: Differential Learning Rates (all layers trainable)")
# NOTE: the 0.01 factor (and the "100x" wording) is hard-coded here; it mirrors
# EARLY_LR_FACTOR, which is defined in a later cell, and must be kept in sync by hand.
print(f"   Early layers (stem, enc_0/1): {config.LEARNING_RATE * 0.01:.2e} (100x smaller)")
print(f"   Later layers (enc_2+, decoder): {config.LEARNING_RATE:.2e}")
print("-" * 60)
print(f"Epochs: {config.NUM_EPOCHS}")
print(f"Batch size: {config.BATCH_SIZE} per GPU")
print(f"Patch size: {config.PATCH_SIZE}")
print(f"Samples per volume: {config.SAMPLES_PER_VOLUME}")
print("-" * 60)
print(f"Warmup: {config.WARMUP_EPOCHS} epochs")
print(f"Loss weights: Dice={config.DICE_WEIGHT}, CE={config.CE_WEIGHT}")
print("-" * 60)
print(f"Device: {config.DEVICE}")
print(f"GPUs available: {num_gpus}")
if num_gpus > 1:
    print("Multi-GPU training enabled with DataParallel")
    print(f"Effective batch size: {config.BATCH_SIZE * num_gpus}")
print("=" * 60)

✅ MedNeXt library imported successfully!
🧠 MEDNEXT-B FINE-TUNING WITH DIFFERENTIAL LEARNING RATES
📁 Pretrained model: /kaggle/input/mednext/pytorch/default/1/best_model.pt
📁 Training data:    /kaggle/input/instant-odc-ai-hackathon/Train
📁 Output model:     /kaggle/working/best_finetuned_model.pt
------------------------------------------------------------
🔧 Model: MedNeXt-B (kernel=3)
🎯 Strategy: Differential Learning Rates (all layers trainable)
   • Early layers (stem, enc_0/1): 1.00e-06 (100x smaller)
   • Later layers (enc_2+, decoder): 1.00e-04
------------------------------------------------------------
📊 Epochs: 10
📊 Batch size: 1 per GPU
📊 Patch size: (128, 128, 128)
📊 Samples per volume: 4
------------------------------------------------------------
📈 Warmup: 2 epochs (protects early training)
📈 Loss weights: Dice=1.0, CE=1.0
------------------------------------------------------------
🖥️ Device: cuda
🖥️ GPUs available: 2
üö

In [None]:
class NestedFolderDataset(Dataset):
    """Random-patch dataset over BraTS-style data stored one folder per subject.

    Every sub-directory of ``data_dir`` is one subject. For each subject the
    four modalities (t1, t1ce, t2, flair) and the segmentation are located
    either in a nested per-modality folder or as
    ``<subject>_<modality>.nii[.gz]`` directly inside the subject folder.
    Each item is one randomly positioned patch; a subject contributes
    ``samples_per_volume`` items per epoch.
    """

    def __init__(self, data_dir: str, subject_ids: List[str] = None,
                 patch_size: Tuple[int, int, int] = (128, 128, 128),
                 samples_per_volume: int = 2,
                 augment: bool = True):
        """
        Args:
            data_dir: Root directory containing one folder per subject.
            subject_ids: Explicit subject folder names; when None, every
                sub-directory of ``data_dir`` is used (sorted by name).
            patch_size: Spatial size (D, H, W) of the returned patches.
            samples_per_volume: Number of dataset items per subject.
            augment: Whether random flips / intensity scaling are applied.
        """
        self.data_dir = data_dir
        # Stored as a tuple so the shape comparison in _extract_random_patch
        # also behaves when a list is passed in.
        self.patch_size = tuple(patch_size)
        self.samples_per_volume = samples_per_volume
        self.augment = augment

        if subject_ids is not None:
            self.subject_ids = subject_ids
        elif os.path.isdir(data_dir):
            self.subject_ids = sorted(
                d for d in os.listdir(data_dir)
                if os.path.isdir(os.path.join(data_dir, d))
            )
        else:
            self.subject_ids = []

        print(f"Found {len(self.subject_ids)} subjects in dataset")
        if len(self.subject_ids) > 0:
            print(f"Sample subjects: {self.subject_ids[:3]}")

    def __len__(self):
        """Number of items per epoch: subjects x samples_per_volume."""
        return len(self.subject_ids) * self.samples_per_volume

    def _find_nifti_in_folder(self, folder_path):
        """Return the path of a NIfTI file inside ``folder_path`` or None.

        ``.nii.gz`` files are preferred over ``.nii``. The listing is sorted
        so the choice is deterministic when several files are present.
        """
        if not os.path.isdir(folder_path):
            return None

        all_files = sorted(os.listdir(folder_path))
        for suffix in ('.nii.gz', '.nii'):
            for f in all_files:
                if f.endswith(suffix):
                    return os.path.join(folder_path, f)
        return None

    @staticmethod
    def _name_matches_modality(name_lower, modality_lower):
        """True if a lower-cased folder name denotes the requested modality.

        't1' is a substring of 't1ce'/'t1gd', so those names are excluded
        when plain T1 is requested.
        """
        if modality_lower not in name_lower:
            return False
        if modality_lower == 't1':
            return 't1ce' not in name_lower and 't1gd' not in name_lower
        return True

    def _find_modality_file(self, patient_path, subject_id, modality):
        """Locate the NIfTI file of one modality for one subject.

        Nested modality folders are searched first, then the flat
        ``<subject_id>_<modality>.nii[.gz]`` naming inside ``patient_path``.

        Raises:
            FileNotFoundError: if neither layout yields a file.
        """
        modality_lower = modality.lower()

        # Sorted so the result does not depend on filesystem listing order.
        for item in sorted(os.listdir(patient_path)):
            item_path = os.path.join(patient_path, item)
            if not os.path.isdir(item_path):
                continue
            if self._name_matches_modality(item.lower(), modality_lower):
                nifti_file = self._find_nifti_in_folder(item_path)
                if nifti_file:
                    return nifti_file

        for suffix in ('.nii.gz', '.nii'):
            candidate = os.path.join(patient_path, f"{subject_id}_{modality}{suffix}")
            if os.path.exists(candidate):
                return candidate

        raise FileNotFoundError(f"Could not find {modality} for {subject_id}")

    def _load_nifti(self, filepath):
        """Load a NIfTI file and return its data as a float32 array."""
        img = nib.load(filepath)
        return img.get_fdata().astype(np.float32)

    def _normalize(self, data):
        """Z-score normalize the voxels that are > 0, after percentile clipping.

        Voxels <= 0 are treated as background and set to 0 in the result.
        Intensities are clipped to the [0.5, 99.5] percentile range of the
        foreground before mean/std are computed. An array with no positive
        voxel is returned unchanged.
        """
        mask = data > 0
        if mask.sum() == 0:
            return data
        pixels = data[mask]
        p_low, p_high = np.percentile(pixels, 0.5), np.percentile(pixels, 99.5)
        data = np.clip(data, p_low, p_high)
        pixels = data[mask]
        mean, std = pixels.mean(), pixels.std()
        data = (data - mean) / (std + 1e-8)
        data[~mask] = 0
        return data

    def _extract_random_patch(self, volume, seg):
        """Cut one patch of ``self.patch_size`` from a (C, D, H, W) volume and its (D, H, W) mask.

        When the mask has foreground, the patch is centred on a random
        foreground voxel (with up to +/-20 voxels of jitter per axis) about
        80% of the time; otherwise the position is uniform. Volumes smaller
        than the patch are zero-padded at the far end of each axis.
        """
        D, H, W = volume.shape[1:]
        pd, ph, pw = self.patch_size

        tumor_coords = np.where(seg > 0)

        if len(tumor_coords[0]) > 0 and random.random() > 0.2:
            idx = random.randint(0, len(tumor_coords[0]) - 1)
            center_d = tumor_coords[0][idx]
            center_h = tumor_coords[1][idx]
            center_w = tumor_coords[2][idx]

            # Clamp so that the patch stays inside the volume.
            d_start = max(0, min(D - pd, center_d - pd // 2 + random.randint(-20, 20)))
            h_start = max(0, min(H - ph, center_h - ph // 2 + random.randint(-20, 20)))
            w_start = max(0, min(W - pw, center_w - pw // 2 + random.randint(-20, 20)))
        else:
            d_start = random.randint(0, max(0, D - pd))
            h_start = random.randint(0, max(0, H - ph))
            w_start = random.randint(0, max(0, W - pw))

        vol_patch = volume[:, d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw]
        seg_patch = seg[d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw]

        if vol_patch.shape[1:] != self.patch_size:
            pad_d = pd - vol_patch.shape[1]
            pad_h = ph - vol_patch.shape[2]
            pad_w = pw - vol_patch.shape[3]
            vol_patch = np.pad(vol_patch, ((0, 0), (0, pad_d), (0, pad_h), (0, pad_w)))
            seg_patch = np.pad(seg_patch, ((0, pad_d), (0, pad_h), (0, pad_w)))

        return vol_patch, seg_patch

    def _augment(self, volume, seg):
        """Apply random flips and per-channel intensity scaling.

        Each spatial axis is flipped with probability 0.5 (image and mask
        together). With probability 0.5 every channel is multiplied by its
        own factor drawn from [0.9, 1.1]. The input arrays are not modified.
        """
        # volume is (C, D, H, W), seg is (D, H, W): spatial axis k of seg is axis k+1 of volume.
        for seg_axis in range(3):
            if random.random() > 0.5:
                volume = np.flip(volume, axis=seg_axis + 1).copy()
                seg = np.flip(seg, axis=seg_axis).copy()

        if random.random() > 0.5:
            scales = np.array(
                [random.uniform(0.9, 1.1) for _ in range(volume.shape[0])],
                dtype=volume.dtype,
            )
            # Builds a new array instead of scaling in place, so a caller's
            # array is never mutated when no flip produced a copy above.
            volume = volume * scales.reshape(-1, 1, 1, 1)

        return volume, seg

    def __getitem__(self, idx):
        """Load one subject, normalize it and return one random patch.

        Returns:
            dict with 'volume' (float tensor, C x D x H x W), 'segmentation'
            (long tensor, D x H x W, labels 0..3) and 'subject_id'.

        Raises:
            Exception: whatever loading the first modality raised; later
            modalities that fail are replaced by an all-zero channel.
        """
        subject_idx = idx // self.samples_per_volume
        subject_id = self.subject_ids[subject_idx]
        patient_path = os.path.join(self.data_dir, subject_id)

        modalities = ['t1', 't1ce', 't2', 'flair']
        modality_data = []

        for mod in modalities:
            try:
                filepath = self._find_modality_file(patient_path, subject_id, mod)
                data = self._load_nifti(filepath)
                data = self._normalize(data)
                modality_data.append(data)
            except Exception as e:
                print(f"Error loading {mod} for {subject_id}: {e}")
                if modality_data:
                    # Best effort: substitute a blank channel of the same shape.
                    modality_data.append(np.zeros_like(modality_data[0]))
                else:
                    # No reference shape available yet; re-raise with the original traceback.
                    raise

        volume = np.stack(modality_data, axis=0)

        try:
            seg_file = self._find_modality_file(patient_path, subject_id, 'seg')
            seg = self._load_nifti(seg_file)
            # Remap raw labels {1, 2, 4} to contiguous classes {1, 2, 3};
            # any other value becomes background (0).
            new_seg = np.zeros_like(seg)
            new_seg[seg == 1] = 1
            new_seg[seg == 2] = 2
            new_seg[seg == 4] = 3
            seg = new_seg
        except Exception as e:
            # Still best-effort (an empty mask is used), but the failure is
            # reported instead of being silently hidden, and KeyboardInterrupt /
            # SystemExit are no longer swallowed as they were by a bare except.
            print(f"Warning: could not load segmentation for {subject_id}: {e}; using empty mask")
            seg = np.zeros(volume.shape[1:], dtype=np.float32)

        vol_patch, seg_patch = self._extract_random_patch(volume, seg)

        if self.augment:
            vol_patch, seg_patch = self._augment(vol_patch, seg_patch)

        return {
            'volume': torch.from_numpy(vol_patch.copy()).float(),
            'segmentation': torch.from_numpy(seg_patch.copy()).long(),
            'subject_id': subject_id
        }

print("NestedFolderDataset defined.")

✅ NestedFolderDataset defined!


In [None]:
def diagnose_dataset_structure(data_dir: str, num_samples: int = 2):
    """Print which modality files can be located for a few subject folders.

    For up to ``num_samples`` subject folders (in sorted order) the function
    looks for t1, t1ce, t2, flair and seg, either as a nested folder whose
    name contains the modality and which holds a NIfTI file, or as a NIfTI
    file directly inside the subject folder whose name contains the modality.
    Plain T1 never matches names containing 't1ce' or 't1gd'.

    Args:
        data_dir: Root directory with one folder per subject.
        num_samples: Number of subject folders to inspect.

    Returns:
        True if every modality was found for every inspected subject,
        False otherwise (including a missing/empty ``data_dir``).
    """
    # isdir also rejects a path that exists but is a regular file.
    if not os.path.isdir(data_dir):
        print(f"Directory does not exist: {data_dir}")
        return False

    nifti_suffixes = ('.nii.gz', '.nii')

    def _matches_modality(name_lower: str, mod: str) -> bool:
        # 't1' is a substring of 't1ce'/'t1gd'; exclude those for plain T1.
        # This applies to folders AND flat files (flat files used to match
        # 't1' against '*_t1ce.nii').
        if mod not in name_lower:
            return False
        if mod == 't1':
            return 't1ce' not in name_lower and 't1gd' not in name_lower
        return True

    def _locate(patient_path: str, contents, mod: str):
        # Return the first matching NIfTI path for `mod`, or None.
        for item in contents:
            if not _matches_modality(item.lower(), mod):
                continue
            item_path = os.path.join(patient_path, item)
            if os.path.isdir(item_path):
                nii_files = sorted(f for f in os.listdir(item_path) if f.endswith(nifti_suffixes))
                if nii_files:
                    return os.path.join(item_path, nii_files[0])
            elif item.endswith(nifti_suffixes):
                return item_path
        return None

    # Sorted so that the inspected subjects are the same on every run.
    subject_folders = sorted(
        d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
    )

    print(f"\nData directory: {data_dir}")
    print(f"Total subject folders found: {len(subject_folders)}")

    if len(subject_folders) == 0:
        print("No subject folders found.")
        return False

    print(f"\nSample subjects: {subject_folders[:5]}")

    modalities_to_check = ['t1', 't1ce', 't2', 'flair', 'seg']
    all_ok = True

    for i, subject_id in enumerate(subject_folders[:num_samples]):
        print(f"\n{'‚îÄ' * 50}")
        print(f"Subject {i+1}: {subject_id}")
        patient_path = os.path.join(data_dir, subject_id)

        contents = sorted(os.listdir(patient_path))
        print(f"   Contents: {contents}")

        for mod in modalities_to_check:
            found_path = _locate(patient_path, contents, mod)
            found = found_path is not None

            status = "OK" if found else "MISSING"
            if found:
                print(f"   {status:7s} {mod.upper():6s} -> {os.path.basename(found_path)}")
            else:
                print(f"   {status:7s} {mod.upper():6s} -> NOT FOUND")
                all_ok = False

    print(f"\n{'=' * 60}")
    if all_ok:
        print("Dataset structure looks correct.")
        print("All modalities found for sampled subjects.")
    else:
        print("Some modalities were not found.")
        print("Check the folder naming conventions.")
    print("=" * 60)

    return all_ok

# Sanity-check the on-disk layout of three subject folders before building datasets.
diagnose_dataset_structure(TrainingConfig.TRAIN_DIR, num_samples=3)

üîç DATASET STRUCTURE DIAGNOSTIC

üìÅ Data directory: /kaggle/input/instant-odc-ai-hackathon/Train
üìä Total subject folders found: 917

üìã Sample subjects: ['BraTS2021_01030', 'BraTS2021_00656', 'BraTS2021_00466', 'BraTS2021_01070', 'BraTS2021_01057']

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
üß™ Subject 1: BraTS2021_01030
   Contents: ['BraTS2021_01030_t1ce.nii', 'BraTS2021_01030_flair.nii', 'BraTS2021_01030_t2.nii', 'BraTS2021_01030_t1.nii', 'BraTS2021_01030_seg.nii']
   ‚úÖ T1     ‚Üí BraTS2021_01030_t1ce.nii
   ‚úÖ T1CE   ‚Üí BraTS2021_01030_t1ce.nii
   ‚úÖ T2     ‚Üí BraTS2021_01030_t2.nii
   ‚úÖ FLAIR  ‚Üí BraTS2021_01030_flair.nii
   ‚úÖ SEG    ‚Üí BraTS2021_01030_seg.nii

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
üß™ Subject 2: BraTS2021_00656
   Contents

True

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over the foreground classes (channel 0 is skipped)."""

    def __init__(self, smooth: float = 1e-5):
        """
        Args:
            smooth: Constant added to numerator and denominator so that a
                class absent from both prediction and target scores 1.
        """
        super().__init__()
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: Softmax probabilities (B, C, D, H, W)
            target: One-hot encoded targets (B, C, D, H, W)

        Returns:
            Scalar tensor: 1 - mean Dice over batch and classes 1..C-1.
        """
        # reshape (not view) so that non-contiguous inputs, e.g. permuted or
        # transposed tensors, are accepted instead of raising a RuntimeError.
        pred_flat = pred.reshape(pred.size(0), pred.size(1), -1)
        target_flat = target.reshape(target.size(0), target.size(1), -1)

        intersection = (pred_flat * target_flat).sum(-1)
        union = pred_flat.sum(-1) + target_flat.sum(-1)

        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        # Column 0 is the background class and is excluded from the mean.
        return 1 - dice[:, 1:].mean()

class CombinedLoss(nn.Module):
    """Cross-entropy plus Dice, where the Dice term averages a per-class
    soft Dice loss with a region-based (WT / TC / ET) soft Dice loss."""

    def __init__(self, dice_weight: float = 0.5, ce_weight: float = 0.5):
        """
        Args:
            dice_weight: Multiplier of the averaged Dice term.
            ce_weight: Multiplier of the cross-entropy term.
        """
        super().__init__()
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.dice_loss = DiceLoss()
        self.ce_loss = nn.CrossEntropyLoss()

    @staticmethod
    def _soft_dice(p, t, smooth=1e-5):
        """Soft Dice score of two tensors, pooled over all elements."""
        overlap = (p * t).sum()
        return (2 * overlap + smooth) / (p.sum() + t.sum() + smooth)

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict:
        """
        Args:
            pred: Logits (B, C, D, H, W), C=4 classes
            target: Class indices (B, D, H, W)

        Returns:
            dict with 'total' (loss tensor to back-propagate) and Python
            floats: 'ce' and 'dice' (loss values) plus 'dice_wt', 'dice_tc',
            'dice_et' (region Dice *scores*, higher is better).
        """
        probs = F.softmax(pred, dim=1)

        # One-hot targets with the class axis moved next to the batch axis.
        onehot = F.one_hot(target, num_classes=pred.size(1)).permute(0, 4, 1, 2, 3).float()

        ce = self.ce_loss(pred, target)
        dice = self.dice_loss(probs, onehot)

        # (prediction, target) pairs for the three nested regions:
        # whole tumor = classes 1..3, tumor core = classes 1 and 3, enhancing = class 3.
        region_pairs = (
            (probs[:, 1:].sum(dim=1, keepdim=True),
             (target >= 1).float().unsqueeze(1)),
            (probs[:, 1:2] + probs[:, 3:4],
             ((target == 1) | (target == 3)).float().unsqueeze(1)),
            (probs[:, 3:4],
             (target == 3).float().unsqueeze(1)),
        )
        wt_loss, tc_loss, et_loss = [1 - self._soft_dice(p, t) for p, t in region_pairs]

        region_dice = (wt_loss + tc_loss + et_loss) / 3

        total = self.ce_weight * ce + self.dice_weight * (dice + region_dice) / 2

        report = {'total': total, 'ce': ce.item(), 'dice': dice.item()}
        # Region entries are reported as scores (1 - loss).
        report['dice_wt'] = 1 - wt_loss.item()
        report['dice_tc'] = 1 - tc_loss.item()
        report['dice_et'] = 1 - et_loss.item()
        return report

print("Loss functions defined.")

✅ Loss functions defined!


In [None]:
# Substrings of parameter names that mark a parameter as belonging to the
# "early" layers; matching parameters get the reduced learning rate.
EARLY_LAYER_PATTERNS = ['stem', 'enc_block_0', 'enc_block_1', 'downsample_0', 'downsample_1']
# Early-layer learning rate as a fraction of the base learning rate.
EARLY_LR_FACTOR = 0.01

def load_pretrained_model(
    pretrained_path: str,
    model_size: str = 'B',
    in_channels: int = 4,
    num_classes: int = 4,
    kernel_size: int = 3,
):
    """Build a MedNeXt model and, if available, load pretrained weights into it.

    The checkpoint may be a raw state dict or a dict holding it under
    'state_dict' or 'model_state_dict'. A leading 'module.' (DataParallel)
    prefix is stripped from every key. Weights are loaded non-strictly and
    any missing / unexpected keys are reported. When ``pretrained_path``
    does not exist the randomly initialized model is returned.

    Args:
        pretrained_path: Path to the checkpoint file.
        model_size: MedNeXt model id passed to ``create_mednext_v1``.
        in_channels: Number of input channels.
        num_classes: Number of output classes.
        kernel_size: Convolution kernel size.

    Returns:
        The model with every parameter set to ``requires_grad=True``.
    """
    model = create_mednext_v1(
        num_input_channels=in_channels,
        num_classes=num_classes,
        model_id=model_size,
        kernel_size=kernel_size,
        deep_supervision=False
    )

    if os.path.exists(pretrained_path):
        print(f"Loading pretrained weights from: {pretrained_path}")
        try:
            checkpoint = torch.load(pretrained_path, map_location='cpu')
        except Exception as e:
            # SECURITY: weights_only=False unpickles arbitrary Python objects
            # and can execute code embedded in the file. Only acceptable for
            # checkpoints from a trusted source.
            print(f"torch.load default failed: {e}")
            print("Retrying with weights_only=False")
            checkpoint = torch.load(pretrained_path, map_location='cpu', weights_only=False)

        # Unwrap the state dict from the common checkpoint container layouts.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for container_key in ('state_dict', 'model_state_dict'):
                if container_key in checkpoint:
                    state_dict = checkpoint[container_key]
                    break

        # Strip the prefix that nn.DataParallel adds to parameter names.
        new_state_dict = {
            (k[7:] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }

        # strict=False never fails on name mismatches, so inspect the result:
        # otherwise a checkpoint whose keys do not match would "load" nothing
        # while still being reported as a success.
        load_result = model.load_state_dict(new_state_dict, strict=False)
        missing = list(load_result.missing_keys)
        unexpected = list(load_result.unexpected_keys)
        expected_total = len(model.state_dict())
        matched = expected_total - len(missing)

        if missing:
            print(f"Warning: {len(missing)} model keys missing from checkpoint (e.g. {missing[:5]})")
        if unexpected:
            print(f"Warning: {len(unexpected)} checkpoint keys not used by the model (e.g. {unexpected[:5]})")

        if matched == 0 and expected_total > 0:
            print("Warning: no checkpoint key matched the model; weights remain randomly initialized.")
        else:
            print("Pretrained weights loaded successfully.")
            print(f"   Matched tensors: {matched}/{expected_total}")
    else:
        print(f"Pretrained weights not found at {pretrained_path}")
        print("Training from scratch...")

    # Every layer is fine-tuned; per-layer learning rates are set by the optimizer groups.
    for param in model.parameters():
        param.requires_grad = True

    total_params = sum(p.numel() for p in model.parameters())
    print("\nModel Parameter Summary:")
    print(f"   Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
    print("   All parameters trainable (using differential learning rates)")

    return model


def get_parameter_groups(model, base_lr: float, early_lr_factor: float = 0.01,
                         early_patterns: List[str] = None):
    """Split trainable parameters into 'early' and 'later' optimizer groups.

    A parameter is "early" when its name contains any of the pattern
    substrings. Early parameters get ``base_lr * early_lr_factor``; all
    others get ``base_lr``. Parameters with ``requires_grad=False`` are
    left out of both groups.

    Args:
        model: The network, optionally wrapped in ``nn.DataParallel``.
        base_lr: Learning rate of the later layers.
        early_lr_factor: Multiplier applied to ``base_lr`` for early layers.
        early_patterns: Name substrings marking early layers; defaults to the
            module-level ``EARLY_LAYER_PATTERNS`` when None.

    Returns:
        List of two param-group dicts named 'early_layers' and 'later_layers'
        (in that order), suitable for passing to a torch optimizer.
    """
    patterns = EARLY_LAYER_PATTERNS if early_patterns is None else early_patterns
    early_lr = base_lr * early_lr_factor

    # Look through the DataParallel wrapper so names carry no 'module.' prefix.
    actual_model = model.module if isinstance(model, nn.DataParallel) else model

    early_params = []
    rest_params = []
    for name, param in actual_model.named_parameters():
        if not param.requires_grad:
            continue
        if any(pattern in name for pattern in patterns):
            early_params.append(param)
        else:
            rest_params.append(param)

    early_count = sum(p.numel() for p in early_params)
    rest_count = sum(p.numel() for p in rest_params)

    print("\nDifferential Learning Rate Setup:")
    print(f"   Early layers (LR={early_lr:.2e}): {len(early_params)} tensors, {early_count:,} params ({early_count/1e6:.2f}M)")
    print(f"   Later layers (LR={base_lr:.2e}): {len(rest_params)} tensors, {rest_count:,} params ({rest_count/1e6:.2f}M)")
    if early_lr_factor > 0:
        print(f"   LR ratio: 1:{int(1/early_lr_factor)}")
    else:
        # A zero factor would otherwise divide by zero in the ratio above.
        print("   LR ratio: early layers have LR 0 (effectively frozen)")

    return [
        {'params': early_params, 'lr': early_lr, 'name': 'early_layers'},
        {'params': rest_params, 'lr': base_lr, 'name': 'later_layers'}
    ]


# Build the network and load the pretrained checkpoint (if present).
model = load_pretrained_model(
    pretrained_path=TrainingConfig.PRETRAINED_PATH,
    model_size=TrainingConfig.MODEL_SIZE,
    in_channels=TrainingConfig.IN_CHANNELS,
    num_classes=TrainingConfig.NUM_CLASSES,
    kernel_size=TrainingConfig.KERNEL_SIZE,
)

# Same selection rule as TrainingConfig.DEVICE, recomputed here as a module-level name.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nDevice: {device}")

# With several GPUs the model is wrapped in DataParallel, so its state-dict keys
# gain a 'module.' prefix from here on.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
    model = nn.DataParallel(model)

model = model.to(device)
print("Model ready for fine-tuning with differential learning rates.")

üì• Loading pretrained weights from: /kaggle/input/mednext/pytorch/default/1/best_model.pt
‚ö†Ô∏è torch.load default failed: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy._core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy._core.multiarray.scalar])` or the `torch.serialization.safe_globals([numpy._core.multiarray.scalar])` context manage

In [None]:
# Temporary dataset used only to enumerate the (sorted) subject folders.
_temp_dataset = NestedFolderDataset(
    data_dir=TrainingConfig.TRAIN_DIR,
    patch_size=TrainingConfig.PATCH_SIZE,
    samples_per_volume=1,
    augment=False
)

# Seed Python's global RNG so the shuffle, and therefore the split, is reproducible.
all_subjects = _temp_dataset.subject_ids.copy()
random.seed(42)
random.shuffle(all_subjects)

# Subject-level split: a subject is either in train or in validation, never both.
split_idx = int(len(all_subjects) * (1 - TrainingConfig.VAL_SPLIT))
train_subjects = all_subjects[:split_idx]
val_subjects = all_subjects[split_idx:]

print("\nTrain/Validation Split:")
print(f"   Total subjects: {len(all_subjects)}")
print(f"   Training subjects: {len(train_subjects)} ({100*(1-TrainingConfig.VAL_SPLIT):.0f}%)")
print(f"   Validation subjects: {len(val_subjects)} ({100*TrainingConfig.VAL_SPLIT:.0f}%)")

# Training set: several augmented random patches per subject per epoch.
train_dataset = NestedFolderDataset(
    data_dir=TrainingConfig.TRAIN_DIR,
    subject_ids=train_subjects,
    patch_size=TrainingConfig.PATCH_SIZE,
    samples_per_volume=TrainingConfig.SAMPLES_PER_VOLUME,
    augment=True
)

# Validation set: one un-augmented patch per subject. The patch position is
# still drawn randomly by the dataset, so validation patches differ between passes.
val_dataset = NestedFolderDataset(
    data_dir=TrainingConfig.TRAIN_DIR,
    subject_ids=val_subjects,
    patch_size=TrainingConfig.PATCH_SIZE,
    samples_per_volume=1,
    augment=False
)

print("\nDataset Statistics:")
print(f"   Training samples per epoch: {len(train_dataset)}")
print(f"   Validation samples per epoch: {len(val_dataset)}")
print(f"   Patch size: {TrainingConfig.PATCH_SIZE}")
print(f"   Batch size: {TrainingConfig.BATCH_SIZE}")

# drop_last=True discards a final incomplete training batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=TrainingConfig.BATCH_SIZE,
    shuffle=True,
    num_workers=TrainingConfig.NUM_WORKERS,
    pin_memory=True,
    drop_last=True
)

# Validation keeps every sample and a fixed subject order.
val_loader = DataLoader(
    val_dataset,
    batch_size=TrainingConfig.BATCH_SIZE,
    shuffle=False,
    num_workers=TrainingConfig.NUM_WORKERS,
    pin_memory=True,
    drop_last=False
)

print(f"   Training batches per epoch: {len(train_loader)}")
print(f"   Validation batches per epoch: {len(val_loader)}")

# Smoke test: load one training item in the main process and report its shapes
# and label distribution.
if len(train_dataset) > 0:
    print("\nTesting dataset loading...")
    sample = train_dataset[0]
    print(f"   Volume shape: {sample['volume'].shape}")
    print(f"   Segmentation shape: {sample['segmentation'].shape}")
    print(f"   Subject ID: {sample['subject_id']}")

    seg = sample['segmentation']
    unique, counts = torch.unique(seg, return_counts=True)
    print("   Label distribution:")
    for u, c in zip(unique.tolist(), counts.tolist()):
        pct = 100 * c / seg.numel()
        print(f"      Class {u}: {c:,} voxels ({pct:.2f}%)")

    print("Dataset loading test passed.")
else:
    print("No subjects found in training directory.")

# The enumeration helper is no longer needed.
del _temp_dataset

Found 917 subjects in dataset
Sample subjects: ['BraTS2021_00000', 'BraTS2021_00002', 'BraTS2021_00003']

üìä Train/Validation Split:
   Total subjects: 917
   Training subjects: 779 (85%)
   Validation subjects: 138 (15%)
Found 779 subjects in dataset
Sample subjects: ['BraTS2021_00044', 'BraTS2021_01061', 'BraTS2021_01328']
Found 138 subjects in dataset
Sample subjects: ['BraTS2021_01308', 'BraTS2021_01178', 'BraTS2021_00352']

üìä Dataset Statistics:
   Training samples per epoch: 3116
   Validation samples per epoch: 138
   Patch size: (128, 128, 128)
   Batch size: 1
   Training batches per epoch: 3116
   Validation batches per epoch: 138

üß™ Testing dataset loading...
   Volume shape: torch.Size([4, 128, 128, 128])
   Segmentation shape: torch.Size([128, 128, 128])
   Subject ID: BraTS2021_00044
   Label distribution:
      Class 0: 2,080,516 voxels (99.21%)
      Class 1: 601 voxels (0.03%)
      Class 2: 11,200 voxels (0.53%)
      Class 3: 4,835 voxels (0.23%)
‚úÖ Dataset 

In [None]:
# Two parameter groups: early layers at a reduced LR, everything else at the base LR.
param_groups = get_parameter_groups(
    model,
    base_lr=TrainingConfig.LEARNING_RATE,
    early_lr_factor=EARLY_LR_FACTOR
)

# Each group carries its own 'lr'; weight decay is shared by both groups.
optimizer = optim.AdamW(
    param_groups,
    weight_decay=TrainingConfig.WEIGHT_DECAY
)

def get_lr_with_warmup(epoch: int, batch_idx: int, total_batches: int,
                       warmup_epochs: int = None, num_epochs: int = None) -> float:
    """Learning-rate multiplier: linear warmup per step, then cosine decay per epoch.

    During warmup the multiplier grows linearly with the global step from 0
    to 1. Afterwards it follows half a cosine over the remaining epochs and
    is constant within an epoch. Progress is clamped to [0, 1], so calling
    with ``epoch`` beyond ``num_epochs`` (e.g. when training is extended)
    yields 0 instead of a learning rate that climbs back up.

    Args:
        epoch: Zero-based epoch index.
        batch_idx: Zero-based batch index within the epoch.
        total_batches: Number of batches per epoch.
        warmup_epochs: Warmup length; defaults to TrainingConfig.WARMUP_EPOCHS.
        num_epochs: Schedule length; defaults to TrainingConfig.NUM_EPOCHS.

    Returns:
        Multiplier in [0, 1] to apply to the base learning rates.
    """
    if warmup_epochs is None:
        warmup_epochs = TrainingConfig.WARMUP_EPOCHS
    if num_epochs is None:
        num_epochs = TrainingConfig.NUM_EPOCHS

    current_step = epoch * total_batches + batch_idx
    warmup_steps = warmup_epochs * total_batches

    # warmup_steps == 0 never enters this branch, so there is no division by zero.
    if current_step < warmup_steps:
        return current_step / warmup_steps

    progress = (epoch - warmup_epochs) / max(1, num_epochs - warmup_epochs)
    # Beyond progress == 1 the cosine would rise again; keep it on the decaying half.
    progress = min(1.0, max(0.0, progress))
    return (1 + math.cos(math.pi * progress)) / 2

def set_lr_with_differential(optimizer, lr_multiplier: float, base_lr: float, early_lr_factor: float):
    """Write scheduled learning rates into every optimizer param group.

    The group named 'early_layers' receives the base rate scaled by
    ``early_lr_factor``; every other group receives the plain base rate.
    Both are multiplied by ``lr_multiplier``.
    """
    reduced_lr = base_lr * early_lr_factor * lr_multiplier
    regular_lr = base_lr * lr_multiplier

    for group in optimizer.param_groups:
        is_early_group = group.get('name') == 'early_layers'
        group['lr'] = reduced_lr if is_early_group else regular_lr

# Combined cross-entropy + Dice objective with the configured weights.
criterion = CombinedLoss(
    dice_weight=TrainingConfig.DICE_WEIGHT,
    ce_weight=TrainingConfig.CE_WEIGHT
)

# Honor the USE_AMP switch: with enabled=False the scaler's scale/unscale_/step/
# update calls become pass-throughs, so the training loop needs no other changes.
scaler = GradScaler(enabled=TrainingConfig.USE_AMP)

print("Optimizer configured with differential learning rates.")
print(f"   Later layers LR: {TrainingConfig.LEARNING_RATE:.2e}")
print(f"   Early layers LR: {TrainingConfig.LEARNING_RATE * EARLY_LR_FACTOR:.2e}")
print(f"   Warmup epochs: {TrainingConfig.WARMUP_EPOCHS}")
print(f"   Total epochs: {TrainingConfig.NUM_EPOCHS}")
print(f"   Weight decay: {TrainingConfig.WEIGHT_DECAY}")


üéØ Differential Learning Rate Setup:
   Early layers (LR=1.00e-06): 34 tensors, 63,968 params (0.06M)
   Later layers (LR=1.00e-04): 195 tensors, 10,462,501 params (10.46M)
   LR ratio: 1:100 (later layers train 100x faster)

‚úÖ Optimizer configured with Differential Learning Rates!
   Later layers LR: 1.00e-04
   Early layers LR: 1.00e-06 (100x smaller)
   Warmup epochs: 2
   Total epochs: 10
   Weight decay: 1e-05


  scaler = GradScaler()


In [None]:
def train_one_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scaler: GradScaler,
    epoch: int,
    device: torch.device
) -> dict:
    """Train for one epoch with differential learning rates.

    For every batch the warmup/cosine multiplier is recomputed and written
    into the optimizer groups, then a mixed-precision forward/backward pass
    is run with gradient-norm clipping at 1.0.

    Args:
        model: Network to train (put into train mode here).
        train_loader: Yields dicts with 'volume' and 'segmentation' tensors.
        criterion: Returns a dict with a 'total' loss tensor plus float
            entries 'ce', 'dice', 'dice_wt', 'dice_tc', 'dice_et'.
        optimizer: Optimizer whose param groups are re-scheduled per batch.
        scaler: Gradient scaler used for the backward pass and the step.
        epoch: Zero-based epoch index.
        device: Device the batches are moved to.

    Returns:
        dict mapping each metric name ('loss', 'ce', 'dice', 'dice_wt',
        'dice_tc', 'dice_et') to its mean over the epoch.
    """
    model.train()

    metric_names = ('ce', 'dice', 'dice_wt', 'dice_tc', 'dice_et')
    epoch_metrics = {'loss': []}
    for metric_name in metric_names:
        epoch_metrics[metric_name] = []

    # Autocast only when AMP is requested AND the scaler is active, so the
    # forward pass never runs in reduced precision without loss scaling.
    use_amp = TrainingConfig.USE_AMP and scaler.is_enabled()

    total_batches = len(train_loader)
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{TrainingConfig.NUM_EPOCHS} [Train]")

    for batch_idx, batch in enumerate(pbar):
        lr_multiplier = get_lr_with_warmup(epoch, batch_idx, total_batches)
        set_lr_with_differential(
            optimizer,
            lr_multiplier,
            TrainingConfig.LEARNING_RATE,
            EARLY_LR_FACTOR
        )

        # Displayed LR is that of the later layers (base LR x multiplier).
        current_lr = TrainingConfig.LEARNING_RATE * lr_multiplier

        # non_blocking pairs with pinned host memory to overlap the copy with compute.
        volumes = batch['volume'].to(device, non_blocking=True)
        targets = batch['segmentation'].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=use_amp):
            outputs = model(volumes)
            losses = criterion(outputs, targets)
            loss = losses['total']

        scaler.scale(loss).backward()

        # Gradients must be unscaled before clipping so max_norm applies to true values.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        # One host sync for the loss value, reused for logging below.
        loss_value = loss.item()
        epoch_metrics['loss'].append(loss_value)
        for metric_name in metric_names:
            epoch_metrics[metric_name].append(losses[metric_name])

        pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'dice_wt': f"{losses['dice_wt']:.4f}",
            'lr': f"{current_lr:.2e}",
            'grad_norm': f"{grad_norm:.2f}"
        })

    avg_metrics = {k: np.mean(v) for k, v in epoch_metrics.items()}
    return avg_metrics

@torch.no_grad()
def validate_one_epoch(
    model: nn.Module,
    val_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    epoch: int
) -> dict:
    """Validate for one epoch (no gradient computation).

    Args:
        model: Network to evaluate (put into eval mode here).
        val_loader: Yields dicts with 'volume' and 'segmentation' tensors.
        criterion: Returns a dict with a 'total' loss tensor plus float
            entries 'ce', 'dice', 'dice_wt', 'dice_tc', 'dice_et'.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (used for the progress-bar label).

    Returns:
        dict mapping each metric name ('loss', 'ce', 'dice', 'dice_wt',
        'dice_tc', 'dice_et') to its mean over the validation set.
    """
    model.eval()

    metric_names = ('ce', 'dice', 'dice_wt', 'dice_tc', 'dice_et')
    val_metrics = {'loss': []}
    for metric_name in metric_names:
        val_metrics[metric_name] = []

    pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{TrainingConfig.NUM_EPOCHS} [Val]")

    for batch in pbar:
        # non_blocking pairs with pinned host memory to overlap the copy with compute.
        volumes = batch['volume'].to(device, non_blocking=True)
        targets = batch['segmentation'].to(device, non_blocking=True)

        # Mixed precision follows the USE_AMP switch instead of being always on.
        with autocast(enabled=TrainingConfig.USE_AMP):
            outputs = model(volumes)
            losses = criterion(outputs, targets)

        total_value = losses['total'].item()
        val_metrics['loss'].append(total_value)
        for metric_name in metric_names:
            val_metrics[metric_name].append(losses[metric_name])

        pbar.set_postfix({
            'loss': f"{total_value:.4f}",
            'dice_wt': f"{losses['dice_wt']:.4f}",
            'dice_et': f"{losses['dice_et']:.4f}"
        })

    avg_metrics = {k: np.mean(v) for k, v in val_metrics.items()}
    return avg_metrics

def compute_dice_score(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5) -> float:
    """Return the smoothed Dice coefficient of two same-sized tensors.

    Both inputs are flattened, so only the number of elements has to match.
    The score is ``(2 * sum(pred * target) + smooth) /
    (sum(pred) + sum(target) + smooth)``; ``smooth`` keeps the ratio defined
    (and equal to 1) when both inputs are all zeros.

    Args:
        pred: Prediction tensor.
        target: Reference tensor.
        smooth: Additive smoothing term used in numerator and denominator.

    Returns:
        float: The Dice score as a plain Python float.
    """
    p = pred.flatten()
    t = target.flatten()

    overlap = (p * t).sum()
    numerator = 2 * overlap + smooth
    denominator = p.sum() + t.sum() + smooth

    return float(numerator / denominator)

# Printed so the cell output confirms that the training/validation helpers above were defined.
print("Training and validation loops defined.")

✅ Training & Validation loops defined with Differential Learning Rates!


In [None]:
# Checkpoint to resume from: a path to a checkpoint file, or None to start a
# fresh run. Passed to main_training(resume_from=...) at the end of this cell.
RESUME_CHECKPOINT = None

def load_checkpoint_for_resume(checkpoint_path: str, model, optimizer, scaler=None):
    """Restore training state from a checkpoint so a run can be continued.

    Args:
        checkpoint_path: Path of the checkpoint file to load.
        model: Model to receive the weights. An ``nn.DataParallel`` wrapper is
            handled by loading into ``model.module``.
        optimizer: Optimizer whose state is restored when the checkpoint
            contains ``'optimizer_state_dict'``.
        scaler: Optional AMP gradient scaler. When given and the checkpoint
            contains a non-empty ``'scaler_state_dict'``, its state is restored
            as well. Defaults to ``None`` (scaler left untouched), which keeps
            existing three-argument callers working.

    Returns:
        tuple: ``(start_epoch, best_val_dice, history)`` where ``start_epoch``
        is the checkpoint's ``'epoch'`` value (0 if absent), ``best_val_dice``
        is read from ``'best_val_dice'`` (falling back to ``'best_dice'``, then
        0.0), and ``history`` is the stored history or a fresh empty one.
    """
    print(f"\nResuming from checkpoint: {checkpoint_path}")

    # Checkpoints written by the training loop carry more than tensors (the
    # metric history holds NumPy scalars produced by np.mean), which the
    # restricted unpickler used by default since PyTorch 2.6 rejects.
    # SECURITY: weights_only=False performs a full unpickle, so only resume
    # from checkpoint files you created yourself or otherwise trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # A bare state dict (no wrapping dict) is accepted as well.
    state_dict = checkpoint.get('state_dict', checkpoint)
    target_model = model.module if isinstance(model, nn.DataParallel) else model
    target_model.load_state_dict(state_dict)

    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Optimizer state restored")

    # The scaler state is saved with every checkpoint; restore it when the
    # caller hands in a scaler. An empty dict is skipped because there is
    # nothing to restore from it.
    scaler_state = checkpoint.get('scaler_state_dict')
    if scaler is not None and scaler_state:
        scaler.load_state_dict(scaler_state)
        print("GradScaler state restored")

    start_epoch = checkpoint.get('epoch', 0)
    best_val_dice = checkpoint.get('best_val_dice', checkpoint.get('best_dice', 0.0))
    history = checkpoint.get('history', {
        'epoch': [],
        'train_loss': [], 'train_dice_wt': [], 'train_dice_tc': [], 'train_dice_et': [],
        'val_loss': [], 'val_dice_wt': [], 'val_dice_tc': [], 'val_dice_et': [],
        'lr': []
    })

    print(f"Resuming from epoch {start_epoch + 1}")
    print(f"Best validation Dice so far: {best_val_dice:.4f}")

    return start_epoch, best_val_dice, history


def main_training(resume_from: str = None):
    """Fine-tune the model with per-epoch validation, model selection and checkpointing.

    The function works on notebook-level objects that must exist before it is
    called: ``model``, ``optimizer``, ``scaler``, ``criterion``, ``device``,
    ``train_loader``, ``val_loader``, ``train_dataset`` and ``val_dataset``.
    Settings are read from ``TrainingConfig`` and ``EARLY_LR_FACTOR``.

    After every epoch it
      * appends the train/validation metrics and the learning rate to the history,
      * overwrites ``TrainingConfig.OUTPUT_MODEL_PATH`` when the mean of the
        validation ``dice_wt``/``dice_tc``/``dice_et`` values improves, and
      * writes a resumable ``<output>_epoch<N><ext>`` checkpoint.

    When all epochs are done the history is also written to
    ``<output>_history.json``.

    Args:
        resume_from: Optional checkpoint path. If the file exists, training
            continues from the epoch stored in it; if a path is given but the
            file is missing, a notice is printed and training starts from
            scratch.

    Returns:
        dict: The history, one list per metric with one entry per epoch.
    """
    import json  # local import: only needed for the final history dump

    print("=" * 60)
    print("STARTING FINE-TUNING WITH VALIDATION")
    print("=" * 60)
    print(f"\nTraining data: {TrainingConfig.TRAIN_DIR}")
    print(f"Pretrained weights: {TrainingConfig.PRETRAINED_PATH}")
    print(f"Output path: {TrainingConfig.OUTPUT_MODEL_PATH}")
    print(f"\nTraining Configuration:")
    print(f"   Epochs: {TrainingConfig.NUM_EPOCHS}")
    print(f"   Warmup epochs: {TrainingConfig.WARMUP_EPOCHS}")
    print(f"   Batch size: {TrainingConfig.BATCH_SIZE} (effective: {TrainingConfig.BATCH_SIZE * max(1, torch.cuda.device_count())})")
    print(f"   Later layers LR: {TrainingConfig.LEARNING_RATE:.2e}")
    print(f"   Early layers LR: {TrainingConfig.LEARNING_RATE * EARLY_LR_FACTOR:.2e} (100x smaller)")
    print("   Strategy: Differential Learning Rates (all layers trainable)")
    print(f"\nData Split:")
    print(f"   Training subjects: {len(train_dataset.subject_ids)}")
    print(f"   Validation subjects: {len(val_dataset.subject_ids)}")
    print(f"   Validation split: {TrainingConfig.VAL_SPLIT*100:.0f}%")
    print("=" * 60)

    # Derive sibling file names from the output path by splitting off the
    # extension. A plain str.replace('.pt', ...) would rewrite every '.pt'
    # occurrence, mangle '.pth' paths, and return the path unchanged when it
    # has no '.pt' at all, so the epoch checkpoint would overwrite the best model.
    output_root, output_ext = os.path.splitext(TrainingConfig.OUTPUT_MODEL_PATH)

    start_epoch = 0
    best_val_dice = 0.0
    best_epoch = 0

    history = {
        'epoch': [],
        'train_loss': [],
        'train_dice_wt': [],
        'train_dice_tc': [],
        'train_dice_et': [],
        'val_loss': [],
        'val_dice_wt': [],
        'val_dice_tc': [],
        'val_dice_et': [],
        'lr': []
    }

    if resume_from and os.path.exists(resume_from):
        start_epoch, best_val_dice, history = load_checkpoint_for_resume(
            resume_from, model, optimizer
        )
        # The checkpoint's epoch is the last completed epoch, which is not
        # necessarily the best one. Recover the best epoch from the restored
        # history when it is available; otherwise fall back to start_epoch.
        best_epoch = start_epoch
        resumed_epochs = history.get('epoch', [])
        resumed_val_avgs = [
            (wt + tc + et) / 3
            for wt, tc, et in zip(
                history.get('val_dice_wt', []),
                history.get('val_dice_tc', []),
                history.get('val_dice_et', []),
            )
        ]
        if resumed_val_avgs and len(resumed_epochs) == len(resumed_val_avgs):
            best_epoch = resumed_epochs[resumed_val_avgs.index(max(resumed_val_avgs))]
    elif resume_from:
        print(f"Resume checkpoint not found: {resume_from}")
        print("Starting from scratch...")

    for epoch in range(start_epoch, TrainingConfig.NUM_EPOCHS):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch + 1}/{TrainingConfig.NUM_EPOCHS}")
        print(f"{'='*60}")

        train_metrics = train_one_epoch(
            model=model,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            scaler=scaler,
            epoch=epoch,
            device=device
        )

        val_metrics = validate_one_epoch(
            model=model,
            val_loader=val_loader,
            criterion=criterion,
            device=device,
            epoch=epoch
        )

        # get_lr_with_warmup returns a multiplier (the training loop scales
        # TrainingConfig.LEARNING_RATE by it), so apply the same scaling here
        # to log the actual learning rate reached at the end of the epoch.
        lr_multiplier = get_lr_with_warmup(epoch, len(train_loader) - 1, len(train_loader))
        current_lr = TrainingConfig.LEARNING_RATE * lr_multiplier

        # Store plain Python floats: the epoch metrics are np.mean results, and
        # keeping NumPy scalars out of the history keeps both the JSON dump and
        # the saved checkpoints free of NumPy-specific objects.
        history['epoch'].append(epoch + 1)
        history['train_loss'].append(float(train_metrics['loss']))
        history['train_dice_wt'].append(float(train_metrics['dice_wt']))
        history['train_dice_tc'].append(float(train_metrics['dice_tc']))
        history['train_dice_et'].append(float(train_metrics['dice_et']))
        history['val_loss'].append(float(val_metrics['loss']))
        history['val_dice_wt'].append(float(val_metrics['dice_wt']))
        history['val_dice_tc'].append(float(val_metrics['dice_tc']))
        history['val_dice_et'].append(float(val_metrics['dice_et']))
        history['lr'].append(float(current_lr))

        val_avg_dice = float((val_metrics['dice_wt'] + val_metrics['dice_tc'] + val_metrics['dice_et']) / 3)
        train_avg_dice = float((train_metrics['dice_wt'] + train_metrics['dice_tc'] + train_metrics['dice_et']) / 3)

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"   {'Metric':<12} {'Train':>10} {'Val':>10}")
        print(f"   {'-'*34}")
        print(f"   {'Loss':<12} {train_metrics['loss']:>10.4f} {val_metrics['loss']:>10.4f}")
        print(f"   {'Dice WT':<12} {train_metrics['dice_wt']:>10.4f} {val_metrics['dice_wt']:>10.4f}")
        print(f"   {'Dice TC':<12} {train_metrics['dice_tc']:>10.4f} {val_metrics['dice_tc']:>10.4f}")
        print(f"   {'Dice ET':<12} {train_metrics['dice_et']:>10.4f} {val_metrics['dice_et']:>10.4f}")
        print(f"   {'Avg Dice':<12} {train_avg_dice:>10.4f} {val_avg_dice:>10.4f}")
        print(f"   LR: {current_lr:.2e}")

        # Always save the unwrapped model's weights so the checkpoint keys do
        # not depend on whether nn.DataParallel was used.
        if isinstance(model, nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()

        improved = val_avg_dice > best_val_dice
        if improved:
            best_val_dice = val_avg_dice
            best_epoch = epoch + 1

        # Everything needed to resume; built after the best-score update so
        # 'best_val_dice' already reflects this epoch.
        epoch_checkpoint = {
            'epoch': epoch + 1,
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_dice': best_val_dice,
            'history': history,
        }

        if improved:
            # The best-model file additionally records the architecture settings.
            best_checkpoint = dict(
                epoch_checkpoint,
                config={
                    'model_size': TrainingConfig.MODEL_SIZE,
                    'in_channels': TrainingConfig.IN_CHANNELS,
                    'num_classes': TrainingConfig.NUM_CLASSES,
                    'kernel_size': TrainingConfig.KERNEL_SIZE,
                },
            )
            torch.save(best_checkpoint, TrainingConfig.OUTPUT_MODEL_PATH)
            print(f"New best model saved (Val Dice: {best_val_dice:.4f})")
        else:
            print(f"Val Dice did not improve (best: {best_val_dice:.4f} at epoch {best_epoch})")

        checkpoint_path = f"{output_root}_epoch{epoch+1}{output_ext}"
        torch.save(epoch_checkpoint, checkpoint_path)
        print(f"Epoch checkpoint saved: {checkpoint_path}")

    print("\n" + "=" * 60)
    print("FINE-TUNING COMPLETE")
    print("=" * 60)
    print(f"   Best Epoch: {best_epoch}")
    print(f"   Best Validation Dice: {best_val_dice:.4f}")
    print(f"   Model saved to: {TrainingConfig.OUTPUT_MODEL_PATH}")

    history_path = f"{output_root}_history.json"
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"   History saved to: {history_path}")

    return history


# Launch fine-tuning (resuming when RESUME_CHECKPOINT is set) and keep the
# returned per-epoch history at notebook level for the plotting cell below.
history = main_training(resume_from=RESUME_CHECKPOINT)

🚀 STARTING FINE-TUNING WITH VALIDATION

📁 Training data: /kaggle/input/instant-odc-ai-hackathon/Train
📁 Pretrained weights: /kaggle/input/mednext/pytorch/default/1/best_model.pt
📁 Output path: /kaggle/working/best_finetuned_model.pt

⚙️ Training Configuration:
   Epochs: 10
   Warmup epochs: 2
   Batch size: 1 (effective: 2)
   Later layers LR: 1.00e-04
   Early layers LR: 1.00e-06 (100x smaller)
   Strategy: Differential Learning Rates (all layers trainable)

📊 Data Split:
   Training subjects: 779
   Validation subjects: 138
   Validation split: 15%

📅 Epoch 1/10


  with autocast():
Epoch 1/10 [Train]:   1%|          | 30/3116 [00:33<56:58,  1.11s/it, loss=0.0579, dice_wt=0.9763, lr=4.65e-07, grad_norm=0.31] 


KeyboardInterrupt: 

In [None]:
def plot_training_history(history: dict):
    """Plot training and validation metrics over epochs and save the figure.

    Draws a 2x2 figure: loss curves, per-region Dice curves (WT/TC/ET),
    the mean of the three Dice values with the best validation epoch marked,
    and a text panel showing the configured model output path. The figure is
    written to ``training_history.png`` and then shown.

    Args:
        history: Mapping with the lists ``'epoch'``, ``'train_loss'``,
            ``'val_loss'`` and ``'train_dice_*'`` / ``'val_dice_*'`` for the
            suffixes ``wt``, ``tc`` and ``et``, one entry per epoch.

    Returns:
        None. If the history contains no epochs, a notice is printed and
        nothing is plotted.
    """
    # Without at least one completed epoch there is nothing to draw, and
    # np.argmax below would raise on the empty average-Dice list.
    if not history.get('epoch'):
        print("Training history is empty; nothing to plot.")
        return

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Top-left: training vs. validation loss.
    ax1 = axes[0, 0]
    ax1.plot(history['epoch'], history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    ax1.plot(history['epoch'], history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training vs Validation Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Top-right: per-region Dice; solid lines are train, dashed with markers are validation.
    ax2 = axes[0, 1]
    ax2.plot(history['epoch'], history['train_dice_wt'], 'r-', label='Train WT', linewidth=1.5, alpha=0.7)
    ax2.plot(history['epoch'], history['train_dice_tc'], 'g-', label='Train TC', linewidth=1.5, alpha=0.7)
    ax2.plot(history['epoch'], history['train_dice_et'], 'b-', label='Train ET', linewidth=1.5, alpha=0.7)
    ax2.plot(history['epoch'], history['val_dice_wt'], 'r--o', label='Val WT', linewidth=2, markersize=5)
    ax2.plot(history['epoch'], history['val_dice_tc'], 'g--s', label='Val TC', linewidth=2, markersize=5)
    ax2.plot(history['epoch'], history['val_dice_et'], 'b--^', label='Val ET', linewidth=2, markersize=5)

    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Dice Score')
    ax2.set_title('Region Dice Scores (Train vs Val)')
    ax2.legend(loc='lower right', fontsize=8)
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim([0, 1])

    # Bottom-left: mean of the three region Dice values per epoch.
    ax3 = axes[1, 0]
    train_avg_dice = [(wt + tc + et) / 3 for wt, tc, et in
                      zip(history['train_dice_wt'], history['train_dice_tc'], history['train_dice_et'])]
    val_avg_dice = [(wt + tc + et) / 3 for wt, tc, et in
                    zip(history['val_dice_wt'], history['val_dice_tc'], history['val_dice_et'])]

    ax3.plot(history['epoch'], train_avg_dice, 'b-o', label='Train Avg Dice', linewidth=2, markersize=5)
    ax3.plot(history['epoch'], val_avg_dice, 'r-s', label='Val Avg Dice', linewidth=2, markersize=5)

    # Mark the epoch with the highest validation average.
    best_val_idx = np.argmax(val_avg_dice)
    best_val_epoch = history['epoch'][best_val_idx]
    best_val_score = val_avg_dice[best_val_idx]
    ax3.axvline(x=best_val_epoch, color='green', linestyle='--', alpha=0.7, label=f'Best Val (Epoch {best_val_epoch})')
    ax3.scatter([best_val_epoch], [best_val_score], s=200, c='green', marker='*', zorder=5)

    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Average Dice Score')
    ax3.set_title('Average Dice Score (Train vs Val)')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.set_ylim([0, 1])

    # Shade the gap between the train and validation averages.
    ax3.fill_between(history['epoch'], train_avg_dice, val_avg_dice, alpha=0.2, color='gray')

    # Bottom-right: text-only panel.
    ax4 = axes[1, 1]
    ax4.axis('off')

    summary_text = f"""

    Model saved to:
    {TrainingConfig.OUTPUT_MODEL_PATH}
    """

    ax4.text(0.1, 0.5, summary_text, transform=ax4.transAxes, fontsize=11,
             verticalalignment='center', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.5))

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("Training history plot saved to: training_history.png")

# Plot only when a training run has left a usable `history` in the notebook namespace.
if globals().get('history') is not None:
    plot_training_history(history)

In [None]:
def load_finetuned_model(checkpoint_path: str, device: torch.device):
    """Build a MedNeXt model and load fine-tuned weights for inference.

    The architecture is created from ``TrainingConfig`` (input channels,
    number of classes, model size, kernel size) with deep supervision
    disabled. The checkpoint may either be a dict with a ``'state_dict'``
    entry or a bare state dict; a leading ``'module.'`` on parameter names
    (as produced by ``nn.DataParallel``) is stripped before loading.

    Args:
        checkpoint_path: Path of the checkpoint file.
        device: Device the model is moved to after loading.

    Returns:
        The model in eval mode on ``device``.
    """
    print(f"Loading fine-tuned model from: {checkpoint_path}")

    model = create_mednext_v1(
        num_input_channels=TrainingConfig.IN_CHANNELS,
        num_classes=TrainingConfig.NUM_CLASSES,
        model_id=TrainingConfig.MODEL_SIZE,
        kernel_size=TrainingConfig.KERNEL_SIZE,
        deep_supervision=False
    )

    # Checkpoints written by the training loop carry more than tensors (the
    # metric history holds NumPy scalars produced by np.mean), which the
    # restricted unpickler used by default since PyTorch 2.6 rejects.
    # SECURITY: weights_only=False performs a full unpickle, so only load
    # checkpoint files you created yourself or otherwise trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        print(f"Loaded from epoch {checkpoint.get('epoch', 'unknown')}")
        # The training loop stores the score under 'best_val_dice'; 'best_dice'
        # is kept as a fallback, matching load_checkpoint_for_resume.
        best_dice = checkpoint.get('best_val_dice', checkpoint.get('best_dice', 'unknown'))
        print(f"Best Dice: {best_dice}")
    else:
        state_dict = checkpoint

    # Drop the 'module.' prefix that nn.DataParallel adds to parameter names.
    prefix = 'module.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(new_state_dict)
    model = model.to(device)
    model.eval()

    print("Fine-tuned model loaded and ready for inference.")
    return model

# Example usage (left commented out so this cell only defines the loader):
# finetuned_model = load_finetuned_model(TrainingConfig.OUTPUT_MODEL_PATH, device)

# Print a copy-pasteable call with the configured output path filled in.
print("Inference loading function defined.")
print("To load the fine-tuned model for inference, run:")
print(f"  finetuned_model = load_finetuned_model('{TrainingConfig.OUTPUT_MODEL_PATH}', device)")