In [None]:
import openslide
from pathlib import Path

# Inspect the metadata of one whole-slide image from the training set.
# The lexicographically smallest match is used so the same file is picked on
# every run, regardless of the order in which the OS lists the directory.
wsi_dir = Path("dataset/training_dataset/training_image_data")
svs_path = min(wsi_dir.glob("*.svs"), default=None)

if svs_path is None:
    # An empty or missing directory would otherwise surface as a bare StopIteration.
    print(f"No .svs files found in {wsi_dir}")
else:
    print(f"Inspecting: {svs_path}")
    try:
        # The context manager releases the slide handle once the metadata is printed.
        with openslide.OpenSlide(str(svs_path)) as slide:
            print(f"Levels: {slide.level_count}")
            print(f"Dimensions: {slide.dimensions}")
            print(f"Level dimensions: {slide.level_dimensions}")
            print(f"Level downsamples: {slide.level_downsamples}")
            print(f"Objective power: {slide.properties.get('openslide.objective-power')}")
            print(f"Magnification: {slide.properties.get('aperio.AppMag')}")
    except Exception as e:
        # Best-effort inspection cell: report the problem instead of aborting the notebook.
        print(f"Error: {e}")

# Visualization of Stage 1 Results

This notebook visualizes the results of the Stage 1 preprocessing pipeline. It compares the original WSI thumbnails with the normalized images, blue ratio images, and generated masks.

In [None]:
from typing import Optional

import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import openslide
import numpy as np
from pathlib import Path

# Set plot style: select the 'ggplot' style sheet for the figures drawn in this notebook.
plt.style.use('ggplot')

In [None]:
# Read the Stage 1 summary report into `df` when it is available.
# `df` is only defined on the success path; later cells re-check the path before using it.
report_path = Path("output/stage_one/reports/stage_one_summary.csv")
if report_path.exists():
    df = pd.read_csv(report_path)
    print(f"Loaded report with {len(df)} records.")
    display(df.head())
else:
    print(f"Report not found at {report_path}. Please run stage 1 pipeline first.")

In [None]:
def load_image(path: str | Path, is_svs_thumbnail: bool = False, thumb_size: tuple[int, int] = (1024, 1024)) -> Optional[np.ndarray]:
    """Load an image, handling SVS files by extracting a thumbnail."""
    path = Path(path)
    if not path.exists():
        print(f"Warning: File not found: {path}")
        return None
        
    if path.suffix.lower() == '.svs':
        try:
            with openslide.OpenSlide(str(path)) as slide:
                img = slide.get_thumbnail(thumb_size)
                return np.array(img)
        except Exception as e:
            print(f"Error loading SVS {path}: {e}")
            return None
    else:
        try:
            return np.array(Image.open(path))
        except Exception as e:
            print(f"Error loading image {path}: {e}")
            return None

def visualize_row(row):
    """Show the four Stage 1 artefacts of one report row side by side.

    Panels, left to right: original thumbnail, normalized image, blue-ratio
    image and mask. A panel whose image could not be loaded shows the text
    "Image Not Found" instead. `row` must provide the keys 'image_path',
    'normalized_path', 'blue_ratio_path' and 'mask_path'; 'blob_count' and
    'average_blob_area' are read only when the mask image loads.
    """
    orig_path = row['image_path']
    norm_path = row['normalized_path']
    blue_path = row['blue_ratio_path']
    mask_path = row['mask_path']

    # All images are loaded before the figure is created.
    loaded = (
        load_image(orig_path, is_svs_thumbnail=True),
        load_image(norm_path),
        load_image(blue_path),
        load_image(mask_path),
    )

    # One entry per panel: extra imshow keyword arguments and a title factory.
    # Titles are callables so they are only built for panels that have an image.
    panel_specs = (
        ({}, lambda: f"Original (Thumbnail)\n{Path(orig_path).name}"),
        ({}, lambda: "Macenko Normalized"),
        ({'cmap': 'jet'}, lambda: "Blue Ratio"),
        ({'cmap': 'gray'}, lambda: f"Stroma Mask\nBlobs: {row['blob_count']}, Avg Area: {row['average_blob_area']:.1f}"),
    )

    fig, axes = plt.subplots(1, 4, figsize=(24, 6))

    for ax, picture, (imshow_kwargs, make_title) in zip(axes, loaded, panel_specs):
        if picture is None:
            ax.text(0.5, 0.5, "Image Not Found", ha='center')
        else:
            ax.imshow(picture, **imshow_kwargs)
            ax.set_title(make_title())
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Render the four-panel comparison for the first 5 rows of the report.
# `df` exists only if the report was loaded, hence the same path check as the loading cell.
if report_path.exists():
    preview = df.head(5)
    for i, row in preview.iterrows():
        print(f"Visualizing Index {i}")
        visualize_row(row)

In [None]:
# We will use this notebook cell to verify the code logic before editing the main file.
# Simulating the split logic for _train_det_paper

import random
from dataclasses import dataclass

@dataclass
class MockRow:
    """Minimal mock row used in this cell to exercise ``split_rows``."""
    # Identifier that split_rows groups rows by.
    case_id: str
    # Class label carried by the mock; split_rows does not read it.
    label: int

def split_rows(rows, seed):
    """Partition rows into train/val/test lists by ``case_id``.

    The distinct case ids are sorted, shuffled with ``random.Random(seed)`` and
    cut at ``int(0.6 * n)`` and ``int(0.2 * n)`` cases; whatever remains goes to
    the test split. Every row follows its case, so a case never spans two splits.

    Returns:
        A ``(train_rows, val_rows, test_rows)`` tuple of lists, each keeping
        the original row order.
    """
    shuffled_cases = sorted({row.case_id for row in rows})
    random.Random(seed).shuffle(shuffled_cases)

    case_total = len(shuffled_cases)
    train_end = int(0.6 * case_total)
    val_end = train_end + int(0.2 * case_total)

    # Cut points applied to the shuffled ids: [0, train_end) / [train_end, val_end) / rest.
    case_groups = (
        set(shuffled_cases[:train_end]),
        set(shuffled_cases[train_end:val_end]),
        set(shuffled_cases[val_end:]),
    )

    def _members(cases):
        return [row for row in rows if row.case_id in cases]

    train_rows, val_rows, test_rows = (_members(group) for group in case_groups)
    return train_rows, val_rows, test_rows

# Mock data: 100 rows spread over 10 cases (10 rows per case), labels alternating 0/1.
rows = [MockRow(case_id=f"case_{i%10}", label=i%2) for i in range(100)]
tr, va, te = split_rows(rows, seed=1337)
# Report the size of each split, one per line.
print(
    f"Total rows: {len(rows)}",
    f"Train rows: {len(tr)}",
    f"Val rows: {len(va)}",
    f"Test rows: {len(te)}",
    sep="\n",
)


In [None]:
# Keep the proposed function source in a string and byte-compile it (without
# executing it) so that syntax errors surface here before the main file is edited.
code_to_insert = """
@torch.no_grad()
def _eval_det_loader(
    *,
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    max_batches: int = 0,
) -> dict[str, float]:
    criterion = nn.CrossEntropyLoss()
    model.eval()

    loss_sum = 0.0
    correct = 0
    total = 0
    tp = fp = fn = tn = 0
    kept = 0
    
    all_y_true = []
    all_y_score = []

    for images, labels in loader:
        if max_batches > 0 and kept >= max_batches:
            break
        images = images.to(device)
        labels = labels.to(device)
        images = _imagenet_normalize_batch(images)

        out = model(images)
        loss = criterion(out, labels)
        loss_sum += float(loss.detach().cpu())

        # For AUC (class 1 prob)
        probs = F.softmax(out, dim=1)[:, 1]
        all_y_score.append(probs.detach().cpu().numpy())
        all_y_true.append(labels.detach().cpu().numpy())

        preds = torch.argmax(out, dim=1)
        correct += int((preds == labels).sum().cpu())
        total += int(labels.numel())
        
        pred_pos = preds == 1
        gt_pos = labels == 1
        tp += int((pred_pos & gt_pos).sum().cpu())
        fp += int((pred_pos & (~gt_pos)).sum().cpu())
        fn += int(((~pred_pos) & gt_pos).sum().cpu())
        tn += int(((~pred_pos) & (~gt_pos)).sum().cpu())

        kept += 1

    avg_loss = float(loss_sum / max(1, kept))
    acc = float(correct / max(1, total))
    precision = float(tp / max(1, tp + fp))
    recall = float(tp / max(1, tp + fn))
    f1 = float((2 * precision * recall) / max(1e-12, precision + recall))
    
    auc = 0.0
    if all_y_true:
        y_true_flat = np.concatenate(all_y_true)
        y_score_flat = np.concatenate(all_y_score)
        if len(np.unique(y_true_flat)) > 1:
            auc = float(roc_auc_score(y_true_flat, y_score_flat))

    model.train()
    return {
        "loss": avg_loss,
        "acc": acc,
        "auc": auc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": float(tp),
        "fp": float(fp),
        "fn": float(fn),
        "tn": float(tn),
    }


def _train_det_paper(
    *,
    index_csv: Path,
    checkpoint_path: Path,
    metrics_csv: Path,
    batch_size: int,
    tune_epochs: int,
    final_epochs: int,
    device: str,
    split_seed: int,
    early_stop_patience: int = 5,
) -> None:
    \"\"\"Paper-aligned CNN_det training: 60/20/20 case split.

    - Reads 80x80 candidate patch index.
    - Splits by case_id.
    - Tuning phase: Train on 60%, Val on 20% (save best).
    - Final phase: Train on 80% (60+20), monitor on 20% test (early stop).
    - Model: AlexNet (pretrained), resized inputs to 227x227.
    - Hyperparameters: SGD lr=0.0001, momentum=0.9, weight_decay=0.0005.
    \"\"\"

    index_csv = Path(index_csv)
    checkpoint_path = Path(checkpoint_path)
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    metrics_csv = Path(metrics_csv)
    _append_metrics_row(metrics_csv, {"event": "start", "device": device})

    # Load and split
    rows = read_det_patch_index(index_csv)
    case_ids = sorted(list(set(r.case_id for r in rows if r.case_id)))
    rng = random.Random(int(split_seed))
    rng.shuffle(case_ids)

    n_total = len(case_ids)
    n_train = int(0.6 * n_total)
    n_val = int(0.2 * n_total)
    
    train_cases = set(case_ids[:n_train])
    val_cases = set(case_ids[n_train:n_train+n_val])
    test_cases = set(case_ids[n_train+n_val:])
    
    train_rows = [r for r in rows if r.case_id in train_cases]
    val_rows = [r for r in rows if r.case_id in val_cases]
    test_rows = [r for r in rows if r.case_id in test_cases]
    
    logger.info(
        "CNN_det split (cases): total=%d train=%d val=%d test=%d seed=%d",
        n_total, len(train_cases), len(val_cases), len(test_cases), int(split_seed)
    )
    logger.info(
        "CNN_det split (patches): total=%d train=%d val=%d test=%d",
        len(rows), len(train_rows), len(val_rows), len(test_rows)
    )

    # config implies using config.alexnet_input_size, but we can hardcode 227 as per paper/user request
    # or load from config. User said "AlexNet requires 227x227".
    cfg = load_mfc_cnn_config()
    
    # Datasets with normalization=False because we normalize in loop (standard practice here)
    train_ds = PreparedMitosisDetectionPatchDataset(rows=train_rows, output_size=227, normalize_imagenet=False)
    val_ds = PreparedMitosisDetectionPatchDataset(rows=val_rows, output_size=227, normalize_imagenet=False)
    test_ds = PreparedMitosisDetectionPatchDataset(rows=test_rows, output_size=227, normalize_imagenet=False)
    
    # Paper-aligned detection training typically shuffles training data.
    train_loader = DataLoader(train_ds, batch_size=int(batch_size), shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=int(batch_size), shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=int(batch_size), shuffle=False, num_workers=0)

    device_t = _resolve_torch_device(device)
    model = CNNDet(num_classes=2, pretrained=True).to(device_t)

    # Paper hypers: LR=0.0001, Momentum=0.9, WD=0.0005
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.0001,
        momentum=0.9,
        weight_decay=0.0005,
    )
    criterion = nn.CrossEntropyLoss()

    best_score: float | None = None
    best_state: dict[str, torch.Tensor] | None = None
    global_step = 0
    no_improve = 0

    def _run_one_epoch(*, loader: DataLoader) -> dict[str, float]:
        nonlocal global_step
        running_loss = 0.0
        steps = 0
        correct = 0
        total = 0
        tp = fp = fn = tn = 0

        for images, labels in loader:
            images = images.to(device_t)
            labels = labels.to(device_t)
            images = _imagenet_normalize_batch(images)

            optimizer.zero_grad(set_to_none=True)
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            running_loss += float(loss.detach().cpu())
            steps += 1
            global_step += 1
            
            with torch.no_grad():
                preds = torch.argmax(logits, dim=1)
                correct += int((preds == labels).sum().cpu())
                total += int(labels.numel())
                pred_pos = preds == 1
                gt_pos = labels == 1
                tp += int((pred_pos & gt_pos).sum().cpu())
                fp += int((pred_pos & (~gt_pos)).sum().cpu())
                fn += int(((~pred_pos) & gt_pos).sum().cpu())
                tn += int(((~pred_pos) & (~gt_pos)).sum().cpu())
            
            if steps == 1 or steps % 50 == 0:
                logger.info("CNN_det train step=%d loss=%.4f acc=%.3f", steps, float(loss), float(correct/total))

        avg = running_loss / max(1, steps)
        acc = float(correct / max(1, total))
        dice = float((2 * tp) / max(1, 2 * tp + fp + fn))
        return {"loss": avg, "acc": acc, "dice": dice}

    logger.info("Starting Tuning Phase (Train on 60%, Val on 20%)")
    for epoch in range(int(tune_epochs)):
        model.train()
        train_m = _run_one_epoch(loader=train_loader)
        val_m = _eval_det_loader(model=model, loader=val_loader, device=device_t)
        
        score = val_m["acc"]  # Primary metric for detection is often simple Acc or F1. Paper mentions "monitoring".
        # Let's use F1 as it's more robust for imbalance. Or Acc. User prompt says "monitoring training...".
        # Let's log all.
        
        logger.info(
            "CNN_det tune epoch=%d train_loss=%.4f train_acc=%.3f val_loss=%.4f val_acc=%.3f val_f1=%.3f val_auc=%.3f",
            epoch + 1, train_m["loss"], train_m["acc"], val_m["loss"], val_m["acc"], val_m["f1"], val_m["auc"]
        )
        
        _append_metrics_row(metrics_csv, {
            "phase": "tune",
            "epoch": epoch+1,
            "train_loss": train_m["loss"],
            "val_loss": val_m["loss"],
            "val_acc": val_m["acc"],
            "val_auc": val_m["auc"]
        })

        if best_score is None or score > best_score:
            best_score = score
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    
    if best_state is not None:
        model.load_state_dict(best_state)
        logger.info("Restored best model from tuning phase (acc=%.3f)", best_score)
    
    # Final Phase: Merge Train + Val
    logger.info("Starting Final Phase (Train on 80%, Monitor on 20% Test)")
    
    # Create merged dataset
    final_rows = train_rows + val_rows
    final_ds = PreparedMitosisDetectionPatchDataset(rows=final_rows, output_size=227, normalize_imagenet=False)
    final_loader = DataLoader(final_ds, batch_size=int(batch_size), shuffle=True, num_workers=0)
    
    no_improve = 0
    best_test_score = 0.0
    
    for epoch in range(int(final_epochs)):
        model.train()
        train_m = _run_one_epoch(loader=final_loader)
        test_m = _eval_det_loader(model=model, loader=test_loader, device=device_t)
        
        score = test_m["acc"] # Monitor metric
        
        logger.info(
            "CNN_det final epoch=%d train_loss=%.4f train_acc=%.3f test_loss=%.4f test_acc=%.3f test_f1=%.3f test_auc=%.3f",
            epoch + 1, train_m["loss"], train_m["acc"], test_m["loss"], test_m["acc"], test_m["f1"], test_m["auc"]
        )
        
        _append_metrics_row(metrics_csv, {
            "phase": "final",
            "epoch": epoch+1,
            "train_loss": train_m["loss"],
            "test_loss": test_m["loss"],
            "test_acc": test_m["acc"],
            "test_auc": test_m["auc"]
        })
        
        if score > best_test_score:
            best_test_score = score
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            
        if early_stop_patience > 0 and no_improve >= early_stop_patience:
             logger.info("Early stopping triggered at final epoch %d", epoch+1)
             break
             
    if best_state is not None:
        model.load_state_dict(best_state)

    torch.save(
        {
            "model": "CNNDet", 
            "paper_aligned": True,
            "state_dict": model.state_dict(),
            "config": cfg.model_dump(mode="json")
        },
        checkpoint_path
    )
    logger.info("Saved CNN_det(paper) checkpoint: %s", str(checkpoint_path))

"""
# compile() parses and byte-compiles the text without running it, so the names it
# references (torch, nn, DataLoader, ...) need not be importable here. A syntax
# error raises SyntaxError and the confirmation below is never printed.
compile(code_to_insert, "<code_to_insert>", "exec")
print("Syntax OK")


In [None]:
# Prototyping _train_global_paper logic

def split_global_rows_paper(rows, seed):
    """Split rows 60/20/20 into train/val/test by slide, stratified by label.

    Each slide is assigned the label of the first row that mentions it. Within
    every label group the slides are sorted, shuffled with a single
    ``random.Random(seed)`` and cut at ``int(0.6 * n)`` and ``int(0.2 * n)``
    slides; the remainder goes to the test split. Label groups and slides are
    visited in a canonical order, so the resulting split depends only on the
    set of (slide, label) pairs and the seed, not on the order of ``rows``.

    Args:
        rows: Iterable of objects exposing ``slide_id`` and ``label``.
        seed: Seed for the shuffle.

    Returns:
        A ``(train_rows, val_rows, test_rows)`` tuple of lists; all rows of a
        slide land in the same split and keep their input order.
    """
    # Materialize once so a one-shot iterable can be traversed several times.
    rows = list(rows)

    # First-seen label wins if a slide appears with more than one label.
    slide_to_label = {}
    for r in rows:
        slide_to_label.setdefault(r.slide_id, r.label)

    # Group slides by label for the stratified split.
    by_label = {}
    for slide_id, label in slide_to_label.items():
        by_label.setdefault(label, []).append(slide_id)

    train_slides = set()
    val_slides = set()
    test_slides = set()

    rng = random.Random(seed)

    # Sorting labels (by their string form, so mixed label types cannot raise)
    # and slide ids fixes the sequence fed to the shared RNG; without it the
    # split would change whenever the rows are reordered.
    for label in sorted(by_label, key=str):
        slides = sorted(by_label[label])
        rng.shuffle(slides)
        n = len(slides)
        n_train = int(0.6 * n)
        n_val = int(0.2 * n)

        train_slides.update(slides[:n_train])
        val_slides.update(slides[n_train:n_train + n_val])
        test_slides.update(slides[n_train + n_val:])

    train_rows = [r for r in rows if r.slide_id in train_slides]
    val_rows = [r for r in rows if r.slide_id in val_slides]
    test_rows = [r for r in rows if r.slide_id in test_slides]

    return train_rows, val_rows, test_rows

# Mock
@dataclass
class MockGlobalRow:
    """Minimal mock row used in this cell to exercise ``split_global_rows_paper``."""
    # Slide identifier; every row of a slide is kept in the same split.
    slide_id: str
    # Label used to stratify the slides.
    label: int
    
    
# 100 mock slides with labels cycling through 1..3, 10 patch rows per slide.
grows = [
    MockGlobalRow(f"S{i}", i % 3 + 1)
    for i in range(100)
    for _ in range(10)
]

tr, va, te = split_global_rows_paper(grows, 1337)
# Report the size of each split, one per line.
print(
    f"Total: {len(grows)} (100 slides)",
    f"Train: {len(tr)}",
    f"Val: {len(va)}",
    f"Test: {len(te)}",
    sep="\n",
)
