In [1]:
import torch

print(f'PyTorch version: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'CUDA version: {torch.version.cuda}')
    print(f'GPU device: {torch.cuda.get_device_name(0)}')
    print(f'Number of GPUs: {torch.cuda.device_count()}')
else:
    print('CUDA not available - GPU support not working')


PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1
GPU device: NVIDIA GeForce RTX 3060 Laptop GPU
Number of GPUs: 1


# Tumor Detector (Completed)

Complete CNN/FasterRCNN/YOLOv8 training + validation notebook.


## 1. Environment
Install deps (PyTorch/torchvision, YAML, Pillow, optional ultralytics for YOLOv8).


In [2]:
# %pip install -q torch torchvision pyyaml pillow
%pip install -q pyyaml pillow
# Uncomment for YOLOv8
%pip install -q ultralytics


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


## 2. Paths and config
Set dataset locations and training hyperparameters. Adjust to your Task01-processed PNGs and JSONs.


In [3]:
import sys
from pathlib import Path
import yaml
import torch

# Add project root so `models` imports work
PROJECT_ROOT = Path.cwd().resolve()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Update these to your data
image_root = Path("data/processed")  # folder with PNG/JPG slices
train_annotations = Path("data/annotations/train.json")
val_annotations = Path("data/annotations/val.json")

default_labels = ["No Tumor", "Glioma", "Meningioma", "Pituitary"]

config = {
    "experiment_name": "detector_complete",
    "task": "detection",  # 'classification' or 'detection'
    "model": {
        "name": "fasterrcnn_resnet50_fpn",  # or resnet18 for classification
        "pretrained": True,
    },
    "data": {
        "image_root": str(image_root),
        "train_annotations": str(train_annotations),
        "val_annotations": str(val_annotations),
        "labels": default_labels,
    },
    "train": {
        "epochs": 5,
        "batch_size": 2,
        "lr": 2e-4,
        "weight_decay": 5e-4,
        "num_workers": 2,
        "grad_clip": 1.0,
        "checkpoint_dir": "runs/detector",
        "save_every": 1,
    },
}

# Save a copy of the config for reference
Path("configs").mkdir(parents=True, exist_ok=True)
with open("configs/detector_completed.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(config, f)

config



{'experiment_name': 'detector_complete',
 'task': 'detection',
 'model': {'name': 'fasterrcnn_resnet50_fpn', 'pretrained': True},
 'data': {'image_root': 'data\\processed',
  'train_annotations': 'data\\annotations\\train.json',
  'val_annotations': 'data\\annotations\\val.json',
  'labels': ['No Tumor', 'Glioma', 'Meningioma', 'Pituitary']},
 'train': {'epochs': 5,
  'batch_size': 2,
  'lr': 0.0002,
  'weight_decay': 0.0005,
  'num_workers': 2,
  'grad_clip': 1.0,
  'checkpoint_dir': 'runs/detector',
  'save_every': 1}}

In [4]:
# Ensure repo root (with `models/`) is on sys.path
import sys
from pathlib import Path

# Try current directory, then walk up to find 'models'
root = Path.cwd().resolve()
if not (root / "models").exists():
    for parent in root.parents:
        if (parent / "models").exists():
            root = parent
            break

if str(root) not in sys.path:
    sys.path.insert(0, str(root))

# Override data paths to be absolute from project root
image_root = root / "data/processed"
train_annotations = root / "data/annotations/train.json"
val_annotations = root / "data/annotations/val.json"
config["data"].update(
    {
        "image_root": str(image_root),
        "train_annotations": str(train_annotations),
        "val_annotations": str(val_annotations),
    }
)

print("Using project root for imports:", root)
print("models exists:", (root / "models").exists())
print("Image root:", image_root)
print("Train annotations:", train_annotations)
print("Val annotations:", val_annotations)


Using project root for imports: C:\Users\FERRA\OneDrive\Documents\github\BT-pipe
models exists: True
Image root: C:\Users\FERRA\OneDrive\Documents\github\BT-pipe\data\processed
Train annotations: C:\Users\FERRA\OneDrive\Documents\github\BT-pipe\data\annotations\train.json
Val annotations: C:\Users\FERRA\OneDrive\Documents\github\BT-pipe\data\annotations\val.json


## Task01 converter (NIfTI â†’ PNG + JSON)
Run once to generate `data/processed/*.png` and `data/annotations/train.json` / `val.json` from BRATS Task01 volumes.


In [5]:
%pip install -q nibabel scikit-learn


Note: you may need to restart the kernel to use updated packages.


In [13]:
import json
import numpy as np
import nibabel as nib
from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split

# Paths to your Task01 BRATS data (use project root)
img_dir = root / "Data/Task01_BrainTumour/imagesTr"
lbl_dir = root / "Data/Task01_BrainTumour/labelsTr"

# Output locations for slices and annotations (absolute)
out_images = root / "data/processed"
out_ann = root / "data/annotations"
out_images.mkdir(parents=True, exist_ok=True)
out_ann.mkdir(parents=True, exist_ok=True)

# Label list for config; using single-class "Glioma" here
label_name = "Glioma"

# Quick sanity checks - filter out macOS resource fork files
img_files = [f for f in sorted(img_dir.glob("*.nii.gz")) if not f.name.startswith("._")]
lbl_files = [f for f in sorted(lbl_dir.glob("*.nii.gz")) if not f.name.startswith("._")]
# Also check for .nii files (without .gz)
img_files_nii = [f for f in sorted(img_dir.glob("*.nii")) if not f.name.startswith("._") and not f.name.endswith(".gz")]
lbl_files_nii = [f for f in sorted(lbl_dir.glob("*.nii")) if not f.name.startswith("._") and not f.name.endswith(".gz")]
img_files.extend(img_files_nii)
lbl_files.extend(lbl_files_nii)

print("Found images:", len(img_files), "labels:", len(lbl_files))
if not img_files:
    raise RuntimeError("No image volumes found at " + str(img_dir))
if not lbl_files:
    raise RuntimeError("No label volumes found at " + str(lbl_dir))

def get_case_name(path: Path) -> str:
    """Extract case name from BRATS file, handling both .nii and .nii.gz"""
    name = path.name
    if name.endswith(".nii.gz"):
        return name[:-7]  # Remove .nii.gz
    elif name.endswith(".nii"):
        return name[:-4]  # Remove .nii
    return path.stem

samples = []
for img_path in img_files:
    case = get_case_name(img_path)
    # Try both .nii.gz and .nii extensions
    lbl_path = lbl_dir / f"{case}.nii.gz"
    if not lbl_path.exists():
        lbl_path = lbl_dir / f"{case}.nii"
    if not lbl_path.exists():
        print(f"Skipping {case}, missing label file (tried {case}.nii.gz and {case}.nii)")
        continue
    img_vol = nib.load(img_path).get_fdata()
    lbl_vol = nib.load(lbl_path).get_fdata()

    # Use the first modality channel if 4D
    if img_vol.ndim == 4:
        img_vol = img_vol[..., 0]

    # Count non-empty slices for this case
    non_empty = int(np.sum(lbl_vol.max(axis=(0, 1)) > 0))
    if non_empty == 0:
        print(f"All-empty label volume: {case}")
        continue
    else:
        print(f"{case}: {non_empty} non-empty slices")

    for z in range(img_vol.shape[2]):
        lbl_slice = lbl_vol[..., z]
        if lbl_slice.max() < 1:  # skip empty slices
            continue
        ys, xs = np.where(lbl_slice > 0)
        x1, x2 = xs.min(), xs.max()
        y1, y2 = ys.min(), ys.max()

        # Validate bounding box: ensure it has positive width and height
        width = x2 - x1
        height = y2 - y1
        if width <= 0 or height <= 0:
            continue  # Skip invalid bounding boxes

        img_slice = img_vol[..., z]
        # normalize to 0-255
        arr = (img_slice - img_slice.min()) / (img_slice.max() - img_slice.min() + 1e-8)
        arr = (arr * 255).clip(0, 255).astype(np.uint8)

        png_name = f"{case}_z{z:03d}.png"
        Image.fromarray(arr).save(out_images / png_name)

        samples.append(
            {
                "image": png_name,
                "label": label_name,
                "bbox": [int(x1), int(y1), int(x2), int(y2)],
            }
        )

print("Total slices with tumors:", len(samples))

if len(samples) == 0:
    raise RuntimeError(
        "No labeled slices were found. Check that nibabel can read your label "
        "volumes and that `labelsTr` contains non-empty masks."
    )

# Train/val split
train, val = train_test_split(samples, test_size=0.2, random_state=42)
with open(out_ann / "train.json", "w", encoding="utf-8") as f:
    json.dump(train, f)
with open(out_ann / "val.json", "w", encoding="utf-8") as f:
    json.dump(val, f)

print("Wrote", len(train), "train and", len(val), "val samples")


Found images: 484 labels: 484
BRATS_001: 74 non-empty slices
BRATS_002: 61 non-empty slices
BRATS_003: 82 non-empty slices
BRATS_004: 73 non-empty slices
BRATS_005: 76 non-empty slices
BRATS_006: 84 non-empty slices
BRATS_007: 58 non-empty slices
BRATS_008: 76 non-empty slices
BRATS_009: 74 non-empty slices
BRATS_010: 41 non-empty slices
BRATS_011: 66 non-empty slices
BRATS_012: 44 non-empty slices
BRATS_013: 36 non-empty slices
BRATS_014: 72 non-empty slices
BRATS_015: 81 non-empty slices
BRATS_016: 57 non-empty slices
BRATS_017: 60 non-empty slices
BRATS_018: 76 non-empty slices
BRATS_019: 97 non-empty slices
BRATS_020: 80 non-empty slices
BRATS_021: 48 non-empty slices
BRATS_022: 80 non-empty slices
BRATS_023: 57 non-empty slices
BRATS_024: 41 non-empty slices
BRATS_025: 33 non-empty slices
BRATS_026: 56 non-empty slices
BRATS_027: 49 non-empty slices
BRATS_028: 53 non-empty slices
BRATS_029: 27 non-empty slices
BRATS_030: 44 non-empty slices
BRATS_031: 58 non-empty slices
BRATS_032

## 3. Data loaders
Uses the repo dataset utilities.


In [14]:
import json, os
from pathlib import Path
import nibabel as nib
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

# Use project root from cell 5 (ensure it's defined)
try:
    # root should be defined from cell 5
    pass
except NameError:
    # Fallback if cell 5 wasn't run
    root = Path.cwd().resolve()
    if not (root / "models").exists():
        for parent in root.parents:
            if (parent / "models").exists():
                root = parent
                break

img_dir = root / "Data/Task01_BrainTumour/imagesTr"
lbl_dir = root / "Data/Task01_BrainTumour/labelsTr"
out_images = root / "data/processed"
out_ann = root / "data/annotations"
out_images.mkdir(parents=True, exist_ok=True)
out_ann.mkdir(parents=True, exist_ok=True)

# Filter out macOS resource fork files and handle both .nii and .nii.gz
img_files = [f for f in sorted(img_dir.glob("*.nii.gz")) if not f.name.startswith("._")]
img_files_nii = [f for f in sorted(img_dir.glob("*.nii")) if not f.name.startswith("._") and not f.name.endswith(".gz")]
img_files.extend(img_files_nii)

# Get label files too
lbl_files = [f for f in sorted(lbl_dir.glob("*.nii.gz")) if not f.name.startswith("._")]
lbl_files_nii = [f for f in sorted(lbl_dir.glob("*.nii")) if not f.name.startswith("._") and not f.name.endswith(".gz")]
lbl_files.extend(lbl_files_nii)

print(f"Found {len(img_files)} image files and {len(lbl_files)} label files")
print(f"Image dir: {img_dir}")
print(f"Label dir: {lbl_dir}")
if img_files:
    print(f"First image file: {img_files[0].name}")
if lbl_files:
    print(f"First label file: {lbl_files[0].name}")

def get_case_name(path: Path) -> str:
    """Extract case name from BRATS file, handling both .nii and .nii.gz"""
    name = path.name
    if name.endswith(".nii.gz"):
        return name[:-7]  # Remove .nii.gz
    elif name.endswith(".nii"):
        return name[:-4]  # Remove .nii
    return path.stem

# Create a set of available label case names for faster lookup
available_label_cases = {get_case_name(f) for f in lbl_files}
print(f"Available label cases (first 10): {sorted(list(available_label_cases))[:10]}")

samples = []
matched = 0
skipped_no_label = 0
skipped_empty = 0

for img_path in img_files:
    case = get_case_name(img_path)
    # Try both .nii.gz and .nii extensions
    lbl_path = lbl_dir / f"{case}.nii.gz"
    if not lbl_path.exists():
        lbl_path = lbl_dir / f"{case}.nii"
    if not lbl_path.exists():
        skipped_no_label += 1
        if skipped_no_label <= 5:  # Show first 5 mismatches
            print(f"  No label found for {case} (tried {case}.nii.gz and {case}.nii)")
        continue
    matched += 1
    img_vol = nib.load(img_path).get_fdata()
    lbl_vol = nib.load(lbl_path).get_fdata()
    # use one modality (e.g., FLAIR channel 0); adjust if needed
    if img_vol.ndim == 4:
        img_vol = img_vol[..., 0]

    # Check if this volume has any non-empty slices
    has_tumors = lbl_vol.max() > 0
    if not has_tumors:
        skipped_empty += 1
        continue

    for z in range(img_vol.shape[2]):
        img_slice = img_vol[..., z]
        lbl_slice = lbl_vol[..., z]
        if lbl_slice.max() < 1:  # skip empty
            continue
        ys, xs = np.where(lbl_slice > 0)
        x1, x2 = xs.min(), xs.max()
        y1, y2 = ys.min(), ys.max()
        
        # Validate bounding box: ensure it has positive width and height
        width = x2 - x1
        height = y2 - y1
        if width <= 0 or height <= 0:
            continue  # Skip invalid bounding boxes
        
        # normalize to 0-255
        arr = img_slice
        arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
        arr = (arr * 255).clip(0, 255).astype(np.uint8)
        png_name = f"{case}_z{z:03d}.png"
        Image.fromarray(arr).save(out_images / png_name)
        samples.append({
            "image": png_name,
            "label": "Glioma",           # set your class name
            "bbox": [int(x1), int(y1), int(x2), int(y2)],
        })

print(f"\nSummary:")
print(f"  Matched image-label pairs: {matched}")
print(f"  Skipped (no label file): {skipped_no_label}")
print(f"  Skipped (empty label volume): {skipped_empty}")
print(f"  Total slices with tumors: {len(samples)}")

if len(samples) == 0:
    print(f"\nDebugging info:")
    print(f"  Image files found: {len(img_files)}")
    print(f"  Label files found: {len(lbl_files)}")
    if img_files and lbl_files:
        img_case = get_case_name(img_files[0])
        lbl_case = get_case_name(lbl_files[0])
        print(f"  First image case: '{img_case}'")
        print(f"  First label case: '{lbl_case}'")
        print(f"  Do they match? {img_case == lbl_case}")
    raise RuntimeError(
        "No labeled slices were found. Check that nibabel can read your label "
        "volumes and that `labelsTr` contains non-empty masks."
    )

# split train/val
train, val = train_test_split(samples, test_size=0.2, random_state=42)
json.dump(train, open(out_ann / "train.json", "w"))
json.dump(val, open(out_ann / "val.json", "w"))
print("Wrote", len(train), "train and", len(val), "val samples")

Found 484 image files and 484 label files
Image dir: C:\Users\FERRA\OneDrive\Documents\github\BT-pipe\Data\Task01_BrainTumour\imagesTr
Label dir: C:\Users\FERRA\OneDrive\Documents\github\BT-pipe\Data\Task01_BrainTumour\labelsTr
First image file: BRATS_001.nii.gz
First label file: BRATS_001.nii.gz
Available label cases (first 10): ['BRATS_001', 'BRATS_002', 'BRATS_003', 'BRATS_004', 'BRATS_005', 'BRATS_006', 'BRATS_007', 'BRATS_008', 'BRATS_009', 'BRATS_010']

Summary:
  Matched image-label pairs: 484
  Skipped (no label file): 0
  Skipped (empty label volume): 0
  Total slices with tumors: 33438
Wrote 26750 train and 6688 val samples


In [17]:
from models.detector import TumorDataset, detection_collate_fn, build_model
from models.detector.dataset import _build_transforms
from torch.utils.data import DataLoader

# Resolve paths and verify that annotation files exist
from pathlib import Path

def _resolve_with_fallback(p: Path) -> Path:
    """Resolve a path; if missing, search parents for a matching relative path."""
    p = p.expanduser()
    if p.is_absolute():
        if p.exists():
            return p
        # If absolute but under notebooks/, try parent project
        candidates = [p, Path.cwd().resolve().parent / p.relative_to(p.anchor)] if "notebooks" in str(p) else [p]
    else:
        candidates = [Path.cwd() / p]
    # Walk up parents to find matching relative location
    rel = p.name if p.is_absolute() else p
    for base in [Path.cwd()] + list(Path.cwd().parents):
        candidate = base / rel
        candidates.append(candidate)
    for c in candidates:
        try:
            if c.exists():
                return c.resolve()
        except OSError:
            continue
    return p.resolve()

train_ann = _resolve_with_fallback(Path(config["data"]["train_annotations"]))
val_ann = _resolve_with_fallback(Path(config["data"]["val_annotations"]))
img_root = _resolve_with_fallback(Path(config["data"]["image_root"]))
config["data"].update(
    {
        "train_annotations": str(train_ann),
        "val_annotations": str(val_ann),
        "image_root": str(img_root),
    }
)

for p, label in [(train_ann, "train"), (val_ann, "val")]:
    if not p.exists():
        raise FileNotFoundError(
            f"Missing {label} annotations at {p}. Run the converter cell (Task01 -> PNG/JSON) to generate them."
        )
if not img_root.exists():
    raise FileNotFoundError(
        f"Image root not found at {img_root}. Ensure the converter wrote PNGs there."
    )

# Build datasets
train_ds = TumorDataset(
    annotation_path=str(train_ann),
    image_root=str(img_root),
    labels=config["data"]["labels"],
    task=config["task"],
    transforms=_build_transforms(config["task"], is_train=True),
)
val_ds = TumorDataset(
    annotation_path=str(val_ann),
    image_root=str(img_root),
    labels=config["data"]["labels"],
    task=config["task"],
    transforms=_build_transforms(config["task"], is_train=False),
)

collate_fn = detection_collate_fn if config["task"] == "detection" else None
train_loader = DataLoader(
    train_ds,
    batch_size=config["train"]["batch_size"],
    shuffle=True,
    num_workers=config["train"]["num_workers"],
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_ds,
    batch_size=config["train"]["batch_size"],
    shuffle=False,
    num_workers=config["train"]["num_workers"],
    collate_fn=collate_fn,
)

print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)}")
print("Sample[0] keys:", train_ds.samples[0].keys())

Train samples: 26750 | Val samples: 6688
Sample[0] keys: dict_keys(['image', 'label', 'bbox'])


## 4. Build model


In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(config["data"]["labels"])
model = build_model(
    task=config["task"],
    num_classes=num_classes,
    model_name=config["model"]["name"],
    pretrained=config["model"]["pretrained"],
).to(device)
print(f"Model: {config['model']['name']} | Task: {config['task']} | Device: {device}")


Model: fasterrcnn_resnet50_fpn | Task: detection | Device: cuda


## 5. Training + validation
Lightweight loop for both classification (acc/F1) and detection (loss proxy).


In [20]:
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from collections import defaultdict

# Set up matplotlib for inline plotting
%matplotlib inline
# Try different style options for compatibility
try:
    plt.style.use('seaborn-v0_8-darkgrid')
except OSError:
    try:
        plt.style.use('seaborn-darkgrid')
    except OSError:
        plt.style.use('ggplot')

def plot_training_curves(history, save_path=None):
    """Plot training curves for loss and accuracy/metrics"""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training Progress', fontsize=16, fontweight='bold')
    
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Plot 1: Loss curves
    ax1 = axes[0, 0]
    ax1.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    if 'val_loss' in history:
        ax1.plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Loss Evolution', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Accuracy curves (for classification)
    ax2 = axes[0, 1]
    if 'train_acc' in history:
        ax2.plot(epochs, history['train_acc'], 'b-', label='Train Accuracy', linewidth=2)
        if 'val_acc' in history:
            ax2.plot(epochs, history['val_acc'], 'r-', label='Val Accuracy', linewidth=2)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Accuracy', fontsize=12)
        ax2.set_title('Accuracy Evolution', fontsize=14, fontweight='bold')
        ax2.legend(fontsize=11)
        ax2.grid(True, alpha=0.3)
    else:
        ax2.text(0.5, 0.5, 'Accuracy metrics\nnot available\nfor detection task', 
                ha='center', va='center', fontsize=12, transform=ax2.transAxes)
        ax2.set_title('Accuracy Evolution', fontsize=14, fontweight='bold')
    
    # Plot 3: Detection loss components (for detection task)
    ax3 = axes[1, 0]
    if 'loss_components' in history and history['loss_components']:
        loss_comp = history['loss_components']
        for comp_name, values in loss_comp.items():
            if values:  # Check if list is not empty
                ax3.plot(epochs, values, label=comp_name.replace('loss_', '').title(), linewidth=2)
        ax3.set_xlabel('Epoch', fontsize=12)
        ax3.set_ylabel('Loss', fontsize=12)
        ax3.set_title('Detection Loss Components', fontsize=14, fontweight='bold')
        ax3.legend(fontsize=10)
        ax3.grid(True, alpha=0.3)
    else:
        ax3.text(0.5, 0.5, 'Loss components\nnot tracked', 
                ha='center', va='center', fontsize=12, transform=ax3.transAxes)
        ax3.set_title('Loss Components', fontsize=14, fontweight='bold')
    
    # Plot 4: Learning rate (if tracked)
    ax4 = axes[1, 1]
    if 'learning_rate' in history and history['learning_rate']:
        ax4.plot(epochs, history['learning_rate'], 'g-', linewidth=2)
        ax4.set_xlabel('Epoch', fontsize=12)
        ax4.set_ylabel('Learning Rate', fontsize=12)
        ax4.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax4.set_yscale('log')
        ax4.grid(True, alpha=0.3)
    else:
        ax4.text(0.5, 0.5, 'Learning rate\nnot tracked', 
                ha='center', va='center', fontsize=12, transform=ax4.transAxes)
        ax4.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved training curves to {save_path}")
    
    plt.show()


def train_and_validate(model, train_loader, val_loader, config, device):
    task = config['task']
    epochs = config['train']['epochs']
    lr = config['train']['lr']
    weight_decay = config['train']['weight_decay']
    grad_clip = config['train'].get('grad_clip', None)
    
    # Initialize history tracking
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': [],
        'loss_components': defaultdict(list),
        'learning_rate': []
    }
    
    # Create output directory for plots
    output_dir = Path(config['train'].get('checkpoint_dir', 'runs/detector'))
    output_dir.mkdir(parents=True, exist_ok=True)

    if task == 'classification':
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        
        for epoch in range(1, epochs + 1):
            model.train()
            total_loss = correct = total = 0
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                if grad_clip:
                    clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

                total_loss += loss.item() * images.size(0)
                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

            train_loss = total_loss / total
            train_acc = correct / total
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['learning_rate'].append(optimizer.param_groups[0]['lr'])

            # validation
            model.eval()
            v_loss = v_correct = v_total = 0
            y_true, y_pred = [], []
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    v_loss += loss.item() * images.size(0)
                    preds = outputs.argmax(dim=1)
                    v_correct += (preds == labels).sum().item()
                    v_total += labels.size(0)
                    y_true.extend(labels.cpu().tolist())
                    y_pred.extend(preds.cpu().tolist())
            val_loss = v_loss / v_total
            val_acc = v_correct / v_total
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            
            print(f"Epoch {epoch}/{epochs}: train_loss={train_loss:.4f} train_acc={train_acc:.4f} val_loss={val_loss:.4f} val_acc={val_acc:.4f}")
            
            # Update plots every epoch
            if epoch % 1 == 0:  # Update every epoch
                plot_training_curves(history, save_path=output_dir / f'training_curves_epoch_{epoch}.png')

        print("\n" + "="*60)
        print("Final Classification Report:")
        print("="*60)
        print(classification_report(y_true, y_pred, target_names=config['data']['labels']))
        print("\nConfusion Matrix:")
        print(confusion_matrix(y_true, y_pred))
        
        # Final plot
        plot_training_curves(history, save_path=output_dir / 'training_curves_final.png')

    else:  # detection
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay)
        
        for epoch in range(1, epochs + 1):
            model.train()
            running_loss = 0.0
            epoch_loss_components = defaultdict(float)
            num_batches = 0
            
            for batch_idx, (images, targets) in enumerate(train_loader):
                images = [img.to(device) for img in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                optimizer.zero_grad()
                loss_dict = model(images, targets)
                loss = sum(loss for loss in loss_dict.values())
                loss.backward()
                if grad_clip:
                    clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
                
                running_loss += loss.item()
                for key, value in loss_dict.items():
                    epoch_loss_components[key] += value.item()
                num_batches += 1
                
                # Print progress every 100 batches
                if (batch_idx + 1) % 100 == 0:
                    print(f"  Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

            avg_train_loss = running_loss / len(train_loader)
            history['train_loss'].append(avg_train_loss)
            
            # Store loss components
            for key, value in epoch_loss_components.items():
                history['loss_components'][key].append(value / num_batches)
            
            history['learning_rate'].append(optimizer.param_groups[0]['lr'])

            # validation
            val_loss = 0.0
            val_loss_components = defaultdict(float)
            val_batches = 0
            model.eval()
            with torch.no_grad():
                for images, targets in val_loader:
                    images = [img.to(device) for img in images]
                    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                    loss_dict = model(images, targets)
                    loss = sum(loss for loss in loss_dict.values())
                    val_loss += loss.item()
                    for key, value in loss_dict.items():
                        val_loss_components[key] += value.item()
                    val_batches += 1

            avg_val_loss = val_loss / len(val_loader)
            history['val_loss'].append(avg_val_loss)
            
            # Print detailed loss breakdown
            loss_str = ", ".join([f"{k}={v/val_batches:.4f}" for k, v in val_loss_components.items()])
            print(f"Epoch {epoch}/{epochs}: train_loss={avg_train_loss:.4f}, val_loss={avg_val_loss:.4f}")
            print(f"  Val loss components: {loss_str}")
            
            # Update plots every epoch
            if epoch % 1 == 0:  # Update every epoch
                plot_training_curves(history, save_path=output_dir / f'training_curves_epoch_{epoch}.png')
        
        # Final plot
        plot_training_curves(history, save_path=output_dir / 'training_curves_final.png')
        
        print("\n" + "="*60)
        print("Training completed!")
        print("="*60)
        print(f"Final train loss: {history['train_loss'][-1]:.4f}")
        print(f"Final val loss: {history['val_loss'][-1]:.4f}")
    
    return history

# Uncomment to run training once config and data are set
history = train_and_validate(model, train_loader, val_loader, config, device)



KeyboardInterrupt: 

## 5.1. Visualization Helpers

Helper functions to visualize training progress and model predictions.


In [None]:
def visualize_detections(model, dataset, config, device, num_samples=6, threshold=0.5):
    """Visualize model predictions on validation samples"""
    model.eval()
    
    # Get random samples from dataset
    indices = np.random.choice(len(dataset), min(num_samples, len(dataset)), replace=False)
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Sample Detection Predictions', fontsize=16, fontweight='bold')
    axes = axes.flatten()
    
    with torch.no_grad():
        for idx, ax in zip(indices, axes):
            image, target = dataset[idx]
            
            # Get prediction
            img_tensor = image.unsqueeze(0).to(device)
            outputs = model([img_tensor])[0]
            
            # Convert image tensor to numpy for display
            img_np = image.permute(1, 2, 0).cpu().numpy()
            img_np = np.clip(img_np, 0, 1)
            
            ax.imshow(img_np, cmap='gray' if img_np.shape[2] == 1 else None)
            ax.set_title(f'Sample {idx}', fontsize=12, fontweight='bold')
            ax.axis('off')
            
            # Draw ground truth boxes
            gt_boxes = target['boxes'].cpu().numpy()
            for box in gt_boxes:
                x1, y1, x2, y2 = box
                rect = plt.Rectangle((x1, y1), x2-x1, y2-y1, 
                                    fill=False, color='green', linewidth=2, label='Ground Truth')
                ax.add_patch(rect)
            
            # Draw predicted boxes
            pred_boxes = outputs['boxes'].cpu().numpy()
            pred_scores = outputs['scores'].cpu().numpy()
            pred_labels = outputs['labels'].cpu().numpy()
            
            for box, score, label in zip(pred_boxes, pred_scores, pred_labels):
                if score >= threshold:
                    x1, y1, x2, y2 = box
                    label_name = config['data']['labels'][label]
                    rect = plt.Rectangle((x1, y1), x2-x1, y2-y1, 
                                        fill=False, color='red', linewidth=2, linestyle='--')
                    ax.add_patch(rect)
                    ax.text(x1, y1-5, f'{label_name}: {score:.2f}', 
                           color='red', fontsize=10, fontweight='bold',
                           bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
            
            # Add legend only to first subplot
            if idx == indices[0]:
                from matplotlib.patches import Patch
                legend_elements = [
                    Patch(facecolor='none', edgecolor='green', linewidth=2, label='Ground Truth'),
                    Patch(facecolor='none', edgecolor='red', linewidth=2, linestyle='--', label='Prediction')
                ]
                ax.legend(handles=legend_elements, loc='upper right', fontsize=9)
    
    plt.tight_layout()
    plt.show()
    
    model.train()  # Set back to training mode


def plot_confusion_matrix_heatmap(y_true, y_pred, labels, save_path=None):
    """Plot confusion matrix as a heatmap"""
    from sklearn.metrics import confusion_matrix
    
    try:
        import seaborn as sns
        use_seaborn = True
    except ImportError:
        use_seaborn = False
    
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(10, 8))
    if use_seaborn:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                    xticklabels=labels, yticklabels=labels,
                    cbar_kws={'label': 'Count'})
    else:
        # Fallback to matplotlib if seaborn not available
        plt.imshow(cm, interpolation='nearest', cmap='Blues')
        plt.colorbar(label='Count')
        tick_marks = np.arange(len(labels))
        plt.xticks(tick_marks, labels, rotation=45)
        plt.yticks(tick_marks, labels)
        for i in range(len(labels)):
            for j in range(len(labels)):
                plt.text(j, i, str(cm[i, j]), ha='center', va='center', fontweight='bold')
    
    plt.title('Confusion Matrix', fontsize=16, fontweight='bold', pad=20)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved confusion matrix to {save_path}")
    
    plt.show()

print("Visualization helpers loaded!")
print("\nUsage examples:")
print("  # Visualize detection predictions:")
print("  visualize_detections(model, val_ds, config, device, num_samples=6)")
print("  # Plot confusion matrix (for classification):")
print("  plot_confusion_matrix_heatmap(y_true, y_pred, config['data']['labels'])")


## 6. Save checkpoint


In [None]:
from pathlib import Path

def save_checkpoint(model, path="runs/detector/completed.pt"):
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Saved {path}")

save_checkpoint(model)

## 7. YOLOv8 (optional detection)
Requires images/labels in YOLO format and `ultralytics` installed.


In [None]:
# Build YOLO data yaml (adjust paths/names)
import yaml

yolo_data = {
    "path": "data/yolo",      # root containing images/ and labels/
    "train": "images/train",
    "val": "images/val",
    "names": {0: "Tumor"},     # change if multiple classes
}

with open("configs/yolo_task01.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(yolo_data, f)
print("Wrote configs/yolo_task01.yaml")

# Example training (uncomment when data ready)
from ultralytics import YOLO
yolo_model = YOLO("yolov8n.pt")
yolo_model.train(data="configs/yolo_task01.yaml", epochs=20, imgsz=640, batch=8, project="runs/yolo", name="task01")
yolo_model.val(data="configs/yolo_task01.yaml")




Wrote configs/yolo_task01.yaml


## 8. Quick inference helper


In [None]:
from PIL import Image
import torchvision.transforms as T

def infer_image(model, image_path, config, device):
    model.eval()
    img = Image.open(image_path).convert('RGB')
    if config['task'] == 'classification':
        tfm = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        tensor = tfm(img).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(tensor)
            probs = torch.softmax(logits, dim=1)
            conf, pred = probs.max(dim=1)
        return {"label": config['data']['labels'][pred.item()], "confidence": conf.item()}
    else:
        tfm = T.Compose([T.ToTensor()])
        tensor = tfm(img).to(device)
        with torch.no_grad():
            outputs = model([tensor])[0]
        if len(outputs['boxes']) == 0:
            return {"detections": []}
        detections = []
        for b, s, l in zip(outputs['boxes'].cpu(), outputs['scores'].cpu(), outputs['labels'].cpu()):
            detections.append({
                "bbox": [round(x.item(), 2) for x in b],
                "score": round(float(s), 4),
                "label": config['data']['labels'][l],
            })
        return {"detections": detections}

# Example usage (set a real image path):
result = infer_image(model, "data/processed/sample.png", config, device)
print(result)

