## Step 1: Setup & Imports

In [32]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from PIL import Image
import cv2
import seaborn as sns

# Try to import Lion optimizer
try:
    from lion_pytorch import Lion
    print("‚úÖ Lion optimizer available")
except ImportError:
    print("‚ö†Ô∏è Installing Lion optimizer...")
    !pip install lion-pytorch
    from lion_pytorch import Lion
    print("‚úÖ Lion optimizer installed")

# Device setup: prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
# torch.backends.mps.is_available() is the documented availability check;
# the getattr guard keeps this working on builds without an MPS backend.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

print(f"\nüöÄ Device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")

# Seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    torch (plus all CUDA devices when CUDA is available).

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: the notebook's import cell does not import `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

‚úÖ Lion optimizer available

üöÄ Device: mps
PyTorch version: 2.9.1


## Step 2: Data Preprocessing with Artifact Filtering

**Advice 05/12**: "A single drop of poison, the purest well corrupts."

We preprocess images to:
- Apply masks (remove background)
- Detect and remove artifacts (green/orange/brown markers)
- Filter out heavily contaminated images
- Crop to ROI (tissue region)

In [33]:
# Configuration: input/output locations and the artifact rejection threshold.
INPUT_DIR = '../data/train_data'
OUTPUT_DIR = '../data/train_data_cleaned'
CSV_PATH = '../data/train_labels.csv'
MAX_ARTIFACT_RATIO = 0.005  # Reject images whose artifact pixels exceed 0.5% of the frame

# Create output directory (no error if it already exists)
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Input: {INPUT_DIR}")
print(f"Output: {OUTPUT_DIR}")
print(f"Max artifact ratio: {MAX_ARTIFACT_RATIO:.1%}")

Input: ../data/train_data
Output: ../data/train_data_cleaned
Max artifact ratio: 0.5%


In [34]:
def get_artifact_masks(hsv):
    """Return one mask covering every pixel that falls in a marker colour range.

    Args:
        hsv: Image already converted to HSV (the caller uses
            ``cv2.COLOR_BGR2HSV``).

    Returns:
        The union (bitwise OR) of the per-colour ``cv2.inRange`` masks for
        the green, orange, white and brown ranges listed below.
    """
    # (lower, upper) HSV bounds, merged in this order.
    hsv_bounds = (
        ([25, 40, 40], [95, 255, 255]),    # green markers
        ([5, 50, 50], [35, 255, 255]),     # orange markers
        ([0, 0, 200], [180, 30, 255]),     # white centers
        ([0, 60, 20], [25, 255, 150]),     # brown markers
    )

    combined = None
    for low, high in hsv_bounds:
        hits = cv2.inRange(hsv, np.array(low), np.array(high))
        combined = hits if combined is None else cv2.bitwise_or(combined, hits)

    return combined


def preprocess_image(img_path, mask_path, max_artifact_ratio=0.10):
    """Mask, de-artifact and crop a single image.

    Args:
        img_path: Path of the image to read with OpenCV.
        mask_path: Path of the matching mask. When the file does not exist,
            an all-255 mask (keep everything) is used instead.
        max_artifact_ratio: Images whose artifact pixels exceed this fraction
            of the whole frame are rejected.

    Returns:
        A ``(image, mask, ok, message)`` tuple. On success ``image`` is the
        RGB crop with masked-out and artifact pixels zeroed, ``mask`` is the
        cleaned mask cropped to the same window, and ``ok`` is True. On
        failure ``image`` and ``mask`` are None, ``ok`` is False and
        ``message`` states the reason.
    """
    # Load image and mask
    img = cv2.imread(img_path)
    if img is None:
        return None, None, False, "Failed to load image"

    if os.path.exists(mask_path):
        mask = cv2.imread(mask_path, 0)
        if mask is None:
            # The file exists but OpenCV could not decode it; report it
            # instead of failing later inside cv2.bitwise_and.
            return None, None, False, "Failed to load mask"
    else:
        mask = np.full(img.shape[:2], 255, dtype=np.uint8)

    # Detect artifacts BEFORE dilation so the ratio reflects raw detections
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    bad_pixels = get_artifact_masks(hsv)

    # Artifact ratio is measured against the full frame (height * width),
    # not against the tissue area.
    total_pixels = img.shape[0] * img.shape[1]
    artifact_count = cv2.countNonZero(bad_pixels)
    artifact_ratio = artifact_count / total_pixels

    # Reject if too contaminated (Advice 05/12)
    if artifact_ratio > max_artifact_ratio:
        return None, None, False, f"Too many artifacts ({artifact_ratio:.2%})"

    # Grow the artifact mask so a margin around each detection is removed too
    kernel = np.ones((5, 5), np.uint8)
    bad_pixels_expanded = cv2.dilate(bad_pixels, kernel, iterations=3)

    # Keep only mask pixels that are not (expanded) artifacts
    clean_mask = cv2.bitwise_and(mask, cv2.bitwise_not(bad_pixels_expanded))

    # Find ROI: coordinates of every remaining pixel
    y, x = np.where(clean_mask > 0)

    if len(y) == 0:
        return None, None, False, "No tissue remaining"

    # Crop with padding. Slice ends are exclusive, so the upper bounds need
    # +1 to include the last pixel and give the same padding on both sides.
    pad = 20
    y_min = max(0, y.min() - pad)
    y_max = min(img.shape[0], y.max() + pad + 1)
    x_min = max(0, x.min() - pad)
    x_max = min(img.shape[1], x.max() + pad + 1)

    # Apply mask and crop
    masked_img = cv2.bitwise_and(img, img, mask=clean_mask)
    final_img = masked_img[y_min:y_max, x_min:x_max]

    # Crop mask to the same window
    final_mask = clean_mask[y_min:y_max, x_min:x_max]

    # Convert to RGB (OpenCV loads BGR)
    final_img = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)

    return final_img, final_mask, True, f"Success ({artifact_ratio:.2%} artifacts removed)"

print("Preprocessing functions defined (now returns both image and mask).")

Preprocessing functions defined (now returns both image and mask).


In [35]:
# Load the label CSV and preprocess every listed image.
df = pd.read_csv(CSV_PATH)
print(f"Found {len(df)} images to process\n")

successful = 0
failed = 0
failed_images = []          # names of images dropped from the cleaned set
failure_details = {}        # image name -> reason it was dropped
rejection_reasons = {'high_artifacts': 0, 'no_tissue': 0, 'corrupted': 0}

for idx, row in tqdm(df.iterrows(), total=len(df), desc="Preprocessing"):
    img_name = row['sample_index']
    img_path = os.path.join(INPUT_DIR, img_name)
    mask_name = img_name.replace("img_", "mask_")
    mask_path = os.path.join(INPUT_DIR, mask_name)

    try:
        processed_img, processed_mask, success, message = preprocess_image(
            img_path, mask_path, max_artifact_ratio=MAX_ARTIFACT_RATIO
        )
        saved = bool(success and processed_img is not None)
        if saved:
            # Save both image and mask (preprocess_image returns RGB; OpenCV writes BGR)
            img_output_path = os.path.join(OUTPUT_DIR, img_name)
            mask_output_path = os.path.join(OUTPUT_DIR, mask_name)
            cv2.imwrite(img_output_path, cv2.cvtColor(processed_img, cv2.COLOR_RGB2BGR))
            cv2.imwrite(mask_output_path, processed_mask)
    except Exception as e:
        # Keep the batch running, but remember what went wrong.
        saved, message = False, f"Unexpected error: {e!r}"

    if saved:
        successful += 1
        continue

    failed += 1
    failed_images.append(img_name)
    failure_details[img_name] = message

    if 'artifacts' in message:
        rejection_reasons['high_artifacts'] += 1
    elif 'tissue' in message:
        rejection_reasons['no_tissue'] += 1
    else:
        # Load failures and unexpected errors: every failure is now counted
        # in exactly one bucket.
        rejection_reasons['corrupted'] += 1

# Avoid a ZeroDivisionError in the percentages when the CSV is empty.
percent_base = max(len(df), 1)

print(f"\n{'='*60}")
print(f"Preprocessing Complete!")
print(f"{'='*60}")
print(f"‚úÖ Successful: {successful}/{len(df)} ({successful/percent_base*100:.1f}%)")
print(f"‚ùå Failed:     {failed}/{len(df)} ({failed/percent_base*100:.1f}%)")

if failed > 0:
    print(f"\nRejection breakdown:")
    print(f"  üé® Too many artifacts: {rejection_reasons['high_artifacts']}")
    print(f"  üö´ No tissue:          {rejection_reasons['no_tissue']}")
    print(f"  üí• Corrupted:          {rejection_reasons['corrupted']}")

    # Show a few of the load failures / unexpected errors so they are not lost.
    unexpected = [
        (name, reason) for name, reason in failure_details.items()
        if 'artifacts' not in reason and 'tissue' not in reason
    ]
    for name, reason in unexpected[:5]:
        print(f"    {name}: {reason}")

Found 691 images to process



Preprocessing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 691/691 [00:19<00:00, 35.91it/s]


Preprocessing Complete!
‚úÖ Successful: 547/691 (79.2%)
‚ùå Failed:     144/691 (20.8%)

Rejection breakdown:
  üé® Too many artifacts: 144
  üö´ No tissue:          0
  üí• Corrupted:          0





In [36]:
# Drop the rows of rejected images and persist the surviving labels.
clean_csv_path = '../data/train_labels_cleaned_dual_branch.csv'
survived = ~df['sample_index'].isin(failed_images)
df_clean = df[survived].reset_index(drop=True)
df_clean.to_csv(clean_csv_path, index=False)

# Summary of what was kept and what was removed.
n_original, n_clean = len(df), len(df_clean)
print(f"\n{'='*60}")
print("CSV Created:")
print(f"  Original samples: {n_original}")
print(f"  Cleaned samples:  {n_clean}")
print(f"  Removed:          {n_original - n_clean}")
print(f"  Saved to:         {clean_csv_path}")
print(f"{'='*60}")

# Per-class counts after cleaning
print("\nClass distribution after cleaning:")
print(df_clean['label'].value_counts())


CSV Created:
  Original samples: 691
  Cleaned samples:  547
  Removed:          144
  Saved to:         ../data/train_labels_cleaned_dual_branch.csv

Class distribution after cleaning:
label
Luminal B          191
Luminal A          149
HER2(+)            145
Triple negative     62
Name: count, dtype: int64


## Step 3: Dataset & DataLoader Setup

**Advice 04/12**: "Know the dimension of your stream" - Using larger batches (32) for BatchNorm stability

**Advice 06/12**: "Let the policy emerge from the struggle" - RandAugment automated augmentation

**Advice 10/12**: "Parallel Paths" - Dual-branch dataset loads both RGB image and binary mask

In [37]:
class DualBranchDataset(Dataset):
    """Dataset for the dual-branch model: yields an RGB image and its mask.

    Each item is ``(image, mask, label, img_name)``. When a transform is
    given, the image goes through the full pipeline and the mask through the
    same pipeline minus ``T.Normalize``; a single-channel mask tensor is
    repeated to three channels.

    Note: this class is defined again further down in the notebook; the
    later definition is the one in effect when the loaders are built.
    """
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        # Class name (as spelled in the label CSV) -> integer target
        self.label_map = {
            "Luminal A": 0,
            "Luminal B": 1,
            "HER2(+)": 2,
            "Triple negative": 3
        }

    def __len__(self):
        return len(self.df)

    def _mask_pipeline(self):
        """Return the image pipeline without its Normalize step(s)."""
        steps = getattr(self.transform, "transforms", [self.transform])
        return T.Compose([t for t in steps if not isinstance(t, T.Normalize)])

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = row['sample_index']
        label = self.label_map[row['label']]

        # Load RGB image
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        # Load binary mask (stored next to the image as mask_*)
        mask_name = img_name.replace("img_", "mask_")
        mask_path = os.path.join(self.img_dir, mask_name)
        mask = Image.open(mask_path).convert('L')  # Grayscale

        if self.transform:
            # Replay the same random draws for image and mask so flips,
            # rotations and RandAugment choices stay spatially aligned.
            # Relies on the transforms taking their randomness from torch's
            # global CPU RNG (verify for any custom transform).
            rng_state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(rng_state)
            mask = self._mask_pipeline()(mask)
            # NOTE(review): photometric RandAugment ops are replayed on the
            # mask as well; confirm that is acceptable for the mask branch.

            # The mask branch is a 3-channel network: expand (1, H, W) masks.
            if isinstance(mask, torch.Tensor) and mask.shape[0] == 1:
                mask = mask.repeat(3, 1, 1)

        return image, mask, label, img_name

print("Dual-branch dataset class defined.")

Dual-branch dataset class defined.


In [38]:
# Input geometry and loader settings
IMG_SIZE = 384     # side length every image is resized to
BATCH_SIZE = 32    # larger batches for BatchNorm stability (Advice 04/12)
NUM_WORKERS = 0    # original note: must be 0 for Jupyter notebooks on macOS

# Channel statistics passed to Normalize in both pipelines
_IMAGENET_STATS = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: resize, random flips/rotation, RandAugment (Advice 06/12),
# then tensor conversion and normalization.
train_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(degrees=15),
    T.RandAugment(num_ops=2, magnitude=7),
    T.ToTensor(),
    T.Normalize(**_IMAGENET_STATS),
])

# Validation pipeline: deterministic resize, tensor conversion, normalization.
val_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(**_IMAGENET_STATS),
])

for status_line in (
    f"Image size: {IMG_SIZE}√ó{IMG_SIZE} (EfficientNetV2-S standard)",
    f"Batch size: {BATCH_SIZE} (stable for BatchNorm)",
    "Transforms configured with RandAugment",
    "Dual-branch: RGB + Mask inputs",
):
    print(status_line)

Image size: 384√ó384 (EfficientNetV2-S standard)
Batch size: 32 (stable for BatchNorm)
Transforms configured with RandAugment
Dual-branch: RGB + Mask inputs


In [39]:
class DualBranchDataset(Dataset):
    """Dataset for the dual-branch model: yields an RGB image and its mask.

    Each item is ``(image, mask, label, img_name)``. When a transform is
    given, the image goes through the full pipeline and the mask through the
    same pipeline minus ``T.Normalize``; a single-channel mask tensor is
    repeated to three channels.
    """
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        # Class name (as spelled in the label CSV) -> integer target
        self.label_map = {
            "Luminal A": 0,
            "Luminal B": 1,
            "HER2(+)": 2,
            "Triple negative": 3
        }

    def __len__(self):
        return len(self.df)

    def _mask_pipeline(self):
        """Return the image pipeline without its Normalize step(s)."""
        steps = getattr(self.transform, "transforms", [self.transform])
        return T.Compose([t for t in steps if not isinstance(t, T.Normalize)])

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = row['sample_index']
        label = self.label_map[row['label']]

        # Load RGB image
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        # Load binary mask (stored next to the image as mask_*)
        mask_name = img_name.replace("img_", "mask_")
        mask_path = os.path.join(self.img_dir, mask_name)
        mask = Image.open(mask_path).convert('L')  # Grayscale

        if self.transform:
            # Replay the same random draws for image and mask so flips,
            # rotations and RandAugment choices stay spatially aligned.
            # Relies on the transforms taking their randomness from torch's
            # global CPU RNG (verify for any custom transform).
            rng_state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(rng_state)
            mask = self._mask_pipeline()(mask)
            # NOTE(review): photometric RandAugment ops are replayed on the
            # mask as well; confirm that is acceptable for the mask branch.

            # The mask branch is a 3-channel network: expand (1, H, W) masks.
            if isinstance(mask, torch.Tensor) and mask.shape[0] == 1:
                mask = mask.repeat(3, 1, 1)

        return image, mask, label, img_name

print("Dual-branch dataset class defined.")

# Reload the cleaned labels and carve out a stratified 80/20 split.
df_clean = pd.read_csv(clean_csv_path)
train_df, val_df = (
    part.reset_index(drop=True)
    for part in train_test_split(
        df_clean,
        test_size=0.2,
        stratify=df_clean['label'],
        random_state=42,
    )
)

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")

# Inverse-frequency weights so every class is drawn about equally often.
class_counts = train_df['label'].value_counts()
weight_per_class = {cls: 1.0/count for cls, count in class_counts.items()}
sample_weights = [weight_per_class[label_name] for label_name in train_df['label']]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(train_df), replacement=True)

# Datasets read the preprocessed images written to OUTPUT_DIR.
train_ds = DualBranchDataset(train_df, OUTPUT_DIR, transform=train_transform)
val_ds = DualBranchDataset(val_df, OUTPUT_DIR, transform=val_transform)

# Settings shared by both loaders; memory is pinned only when running on CUDA.
shared_loader_args = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE == 'cuda'),
)
train_loader = DataLoader(train_ds, sampler=sampler, **shared_loader_args)
val_loader = DataLoader(val_ds, shuffle=False, **shared_loader_args)

print("\nDual-branch DataLoaders ready!")

Dual-branch dataset class defined.
Training samples: 437
Validation samples: 110

Dual-branch DataLoaders ready!


In [40]:
# Sanity check: element 1 of a dataset item is the mask; the recorded output
# shows it as a 3-channel 384x384 tensor.
train_ds[5][1].shape

torch.Size([3, 384, 384])

## Step 4: Dual-Branch EfficientNetV2-S Model Setup

**Advice 07/12**: "The Lion, with instinct fierce and memory sparse, the prey tracks."

**Advice 10/12**: "Parallel Paths" - Two branches run separately, meet at fusion

**Dual-Branch EfficientNetV2-S**: 
- **RGB Branch**: Processes color, texture, cellular morphology (pretrained)
- **Mask Branch**: Processes shape, boundaries, spatial structure (from scratch)
- **Fusion**: Concatenate features before classification
- Total parameters: ~42M (21M + 21M)

In [41]:
class DualBranchEfficientNet(nn.Module):
    """Two EfficientNetV2-S backbones (RGB and mask) fused before a linear head.

    The RGB backbone starts from ImageNet weights, the mask backbone from
    random initialisation. Both have their classifier replaced by an
    identity, and their outputs are concatenated and fed to a dropout +
    linear layer with 4 outputs.
    """
    def __init__(self):
        super().__init__()

        # RGB backbone: ImageNet-pretrained, original classifier stripped.
        pretrained = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        rgb_features = pretrained.classifier[1].in_features
        pretrained.classifier = nn.Identity()
        self.rgb_branch = pretrained

        # Mask backbone: same architecture, no pretrained weights.
        scratch = efficientnet_v2_s(weights=None)
        mask_features = scratch.classifier[1].in_features
        scratch.classifier = nn.Identity()
        self.mask_branch = scratch

        # Head operating on the concatenated features of both backbones.
        fused_width = rgb_features + mask_features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(fused_width, 4),
        )

        print(f"RGB branch features: {rgb_features}")
        print(f"Mask branch features: {mask_features}")
        print(f"Total fused features: {fused_width}")

    def forward(self, rgb, mask):
        """Return class logits for a batch of RGB inputs and mask inputs."""
        branch_outputs = (self.rgb_branch(rgb), self.mask_branch(mask))
        # Late fusion: concatenate along the feature dimension, then classify.
        return self.classifier(torch.cat(branch_outputs, dim=1))


# Create model and move it to the selected device
model = DualBranchEfficientNet()
model = model.to(DEVICE)

# Loss with label smoothing (default reduction, i.e. the mean over the batch)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Lion optimizer (Advice 07/12)
# Lion needs 3-10x smaller LR than Adam
# NOTE(review): the line above asks for a reduced LR, yet lr=3e-4 is used here;
# confirm this value is intended.
# The optimizer is given every parameter, including the RGB branch that a
# later cell freezes via requires_grad.
optimizer = Lion(model.parameters(), lr=3e-4, weight_decay=1e-2)

# Cosine annealing scheduler
# T_max=50 equals NUM_EPOCHS defined in a later cell; keep the two in sync.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

print("\n‚úÖ Model: Dual-Branch EfficientNetV2-S")
print("‚úÖ RGB Branch: Pretrained on ImageNet")
print("‚úÖ Mask Branch: Train from scratch")
print("‚úÖ Fusion: Late concatenation")
print("‚úÖ Optimizer: Lion (lr=3e-4)")
print("‚úÖ Scheduler: Cosine Annealing")
print(f"‚úÖ Device: {DEVICE}")

RGB branch features: 1280
Mask branch features: 1280
Total fused features: 2560

‚úÖ Model: Dual-Branch EfficientNetV2-S
‚úÖ RGB Branch: Pretrained on ImageNet
‚úÖ Mask Branch: Train from scratch
‚úÖ Fusion: Late concatenation
‚úÖ Optimizer: Lion (lr=3e-4)
‚úÖ Scheduler: Cosine Annealing
‚úÖ Device: mps


## Step 5: Training Loop

In [42]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` with dual-branch inputs.

    Args:
        model: Network called as ``model(images, masks)``.
        loader: Iterable yielding ``(images, masks, labels, names)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to
            return the mean over the batch.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, epoch_f1, epoch_acc)``: per-sample mean loss, macro F1
        and accuracy over all samples seen in the epoch.
    """
    model.train()
    loss_sum = 0.0
    samples_seen = 0
    all_preds = []
    all_targets = []

    pbar = tqdm(loader, desc="Training", leave=False)
    for images, masks, labels, _ in pbar:
        images = images.to(device)
        masks = masks.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, masks)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short final batch does not
        # count as much as a full one.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size

        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(labels.cpu().numpy())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    epoch_loss = loss_sum / max(samples_seen, 1)
    epoch_f1 = f1_score(all_targets, all_preds, average='macro')
    epoch_acc = (np.array(all_preds) == np.array(all_targets)).mean()

    return epoch_loss, epoch_f1, epoch_acc


def validate_epoch(model, loader, criterion, device):
    """Evaluate the model on ``loader`` without updating any weights.

    Args:
        model: Network called as ``model(images, masks)``.
        loader: Iterable yielding ``(images, masks, labels, names)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to
            return the mean over the batch.
        device: Device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, epoch_f1, epoch_acc, all_preds, all_targets)``:
        per-sample mean loss, macro F1, accuracy, and the flat lists of
        predicted and true class indices.
    """
    model.eval()
    loss_sum = 0.0
    samples_seen = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        pbar = tqdm(loader, desc="Validation", leave=False)
        for images, masks, labels, _ in pbar:
            images = images.to(device)
            masks = masks.to(device)
            labels = labels.to(device)

            outputs = model(images, masks)
            loss = criterion(outputs, labels)

            # Weight each batch mean by its size so a short final batch does
            # not count as much as a full one.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            samples_seen += batch_size

            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    epoch_loss = loss_sum / max(samples_seen, 1)
    epoch_f1 = f1_score(all_targets, all_preds, average='macro')
    epoch_acc = (np.array(all_preds) == np.array(all_targets)).mean()

    return epoch_loss, epoch_f1, epoch_acc, all_preds, all_targets

print("Training functions defined (dual-branch).")

Training functions defined (dual-branch).


In [43]:
# Schedule: total epochs, length of the RGB warm-up freeze, early-stopping patience
NUM_EPOCHS, FREEZE_EPOCHS, PATIENCE = 50, 5, 10

# Start with the RGB branch frozen so only the remaining weights get gradients
for rgb_param in model.rgb_branch.parameters():
    rgb_param.requires_grad = False

for config_line in (
    "Configuration:",
    f"  Total epochs: {NUM_EPOCHS}",
    f"  Freeze RGB branch epochs: {FREEZE_EPOCHS}",
    f"  Early stopping patience: {PATIENCE}",
    f"  RGB branch frozen for first {FREEZE_EPOCHS} epochs (mask branch learns first)",
):
    print(config_line)

Configuration:
  Total epochs: 50
  Freeze RGB branch epochs: 5
  Early stopping patience: 10
  RGB branch frozen for first 5 epochs (mask branch learns first)


In [None]:
# Training loop: train/validate each epoch, checkpoint on the best validation
# macro-F1, and stop early after PATIENCE epochs without improvement.
history = {
    'train_loss': [], 'train_f1': [], 'train_acc': [],
    'val_loss': [], 'val_f1': [], 'val_acc': []
}

best_f1 = 0.0
patience_counter = 0

print("\n" + "="*60)
print("TRAINING STARTED - Dual-Branch EfficientNetV2-S")
print("="*60 + "\n")

for epoch in range(NUM_EPOCHS):
    # Unfreeze RGB branch after warmup
    if epoch == FREEZE_EPOCHS:
        print(f"\nüîì Unfreezing RGB branch at epoch {epoch}\n")
        for param in model.rgb_branch.parameters():
            param.requires_grad = True

    # Train
    train_loss, train_f1, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, DEVICE
    )

    # Validate
    val_loss, val_f1, val_acc, val_preds, val_targets = validate_epoch(
        model, val_loader, criterion, DEVICE
    )

    # Step scheduler (once per epoch)
    scheduler.step()

    # Store history
    history['train_loss'].append(train_loss)
    history['train_f1'].append(train_f1)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_f1'].append(val_f1)
    history['val_acc'].append(val_acc)

    # Print progress (the LR shown is the one set by the step above,
    # i.e. the value the next epoch will use)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}:")
    print(f"  Train - Loss: {train_loss:.4f}, F1: {train_f1:.4f}, Acc: {train_acc:.4f}")
    print(f"  Val   - Loss: {val_loss:.4f}, F1: {val_f1:.4f}, Acc: {val_acc:.4f}")
    print(f"  LR: {scheduler.get_last_lr()[0]:.2e}")

    # Save best model
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Store plain Python floats: the metrics arrive as numpy scalars,
            # which torch.load rejects under its weights_only=True default.
            'val_f1': float(val_f1),
            'val_acc': float(val_acc),
        }, '../best_model_dual_branch.pth')
        print(f"  üíæ Saved best model (F1: {val_f1:.4f})")
        patience_counter = 0
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= PATIENCE:
        print(f"\n‚ö†Ô∏è Early stopping triggered (no improvement for {PATIENCE} epochs)")
        break

    print()

print("="*60)
print("TRAINING COMPLETE")
print(f"Best Validation F1: {best_f1:.4f}")
print("="*60)


TRAINING STARTED - Dual-Branch EfficientNetV2-S



Training:   0%|          | 0/14 [00:00<?, ?it/s]

## Step 6: Visualize Training Results

In [None]:
# Plot loss, macro-F1 and accuracy side by side from the recorded history.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One entry per panel; the loss panel keeps matplotlib's default colours.
panel_specs = [
    dict(metric='loss', train_label='Train Loss', val_label='Val Loss',
         train_style={}, val_style={},
         ylabel='Loss', title='Training & Validation Loss - Dual-Branch'),
    dict(metric='f1', train_label='Train F1', val_label='Val F1',
         train_style={'color': 'green'}, val_style={'color': 'orange'},
         ylabel='F1 Score (Macro)', title=f'F1 Score (Best: {best_f1:.4f})'),
    dict(metric='acc', train_label='Train Acc', val_label='Val Acc',
         train_style={'color': 'blue'}, val_style={'color': 'red'},
         ylabel='Accuracy', title='Training & Validation Accuracy'),
]

for ax, spec in zip(axes, panel_specs):
    ax.plot(history[f"train_{spec['metric']}"], label=spec['train_label'], marker='o', **spec['train_style'])
    ax.plot(history[f"val_{spec['metric']}"], label=spec['val_label'], marker='s', **spec['val_style'])
    # Dashed line at the last epoch index trained with the RGB branch frozen
    ax.axvline(x=FREEZE_EPOCHS-1, color='r', linestyle='--', alpha=0.5, label='Unfreeze RGB')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(spec['ylabel'])
    ax.set_title(spec['title'])
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../training_curves_dual_branch.png', dpi=150, bbox_inches='tight')
plt.show()

print("Training curves saved to '../training_curves_dual_branch.png'")

In [None]:
# Load best model for evaluation.
# map_location puts the tensors on the device in use. weights_only=False is
# needed because the checkpoint written by the training cell also holds the
# optimizer state and metric values that may be numpy scalars, which the
# default restricted unpickler (PyTorch >= 2.6) refuses. Only load checkpoint
# files produced by this notebook this way.
checkpoint = torch.load(
    '../best_model_dual_branch.pth', map_location=DEVICE, weights_only=False
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Get final predictions on the validation set
_, _, _, final_preds, final_targets = validate_epoch(
    model, val_loader, criterion, DEVICE
)

# Index -> class name, spelled exactly like the keys of
# DualBranchDataset.label_map (this list is also used to write the submission).
class_names = ["Luminal A", "Luminal B", "HER2(+)", "Triple negative"]
print("\nClassification Report - Dual-Branch EfficientNetV2-S:")
print("="*60)
print(classification_report(final_targets, final_preds, target_names=class_names))

# Confusion matrix
cm = confusion_matrix(final_targets, final_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Dual-Branch EfficientNetV2-S')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig('../confusion_matrix_dual_branch.png', dpi=150, bbox_inches='tight')
plt.show()

print("Confusion matrix saved to '../confusion_matrix_dual_branch.png'")

## Step 7: Patch-Based Inference for Test Set

**Advice 08/12**: "Let the model walk the landscape step by step, preserving the original resolution."

We train on resized images (stable), but infer using patches at full resolution (preserves details).

In [None]:
def extract_patches_inference(img, mask, patch_size=384, stride=192):
    """Cut an image and its mask into matching square patches.

    Args:
        img: Image object exposing ``size``, ``crop`` and ``resize``.
        mask: Mask object with the same interface, cropped with the same boxes.
        patch_size: Side length of every patch.
        stride: Step between neighbouring patches of the regular grid.

    Returns:
        ``(patches_img, patches_mask)``: two equally long lists. Patches come
        in this order: regular grid (row by row), extra right-edge column,
        extra bottom-edge row, bottom-right corner. When either image side is
        smaller than ``patch_size`` the whole image is resized to one patch
        instead.
    """
    width, height = img.size

    # Too small to tile: resize the whole frame into a single patch.
    if width < patch_size or height < patch_size:
        target = (patch_size, patch_size)
        small_img = img.resize(target, Image.BILINEAR)
        small_mask = mask.resize(target, Image.NEAREST)
        return [small_img], [small_mask]

    last_left = width - patch_size
    last_top = height - patch_size
    grid_lefts = range(0, last_left + 1, stride)
    grid_tops = range(0, last_top + 1, stride)

    # The grid misses the right/bottom border when the span is not a multiple
    # of the stride; extra patches flush with those borders cover it.
    ragged_right = last_left % stride != 0
    ragged_bottom = last_top % stride != 0

    origins = [(left, top) for top in grid_tops for left in grid_lefts]
    if ragged_right:
        origins.extend((last_left, top) for top in grid_tops)
    if ragged_bottom:
        origins.extend((left, last_top) for left in grid_lefts)
    if ragged_right and ragged_bottom:
        origins.append((last_left, last_top))

    patches_img = []
    patches_mask = []
    for left, top in origins:
        box = (left, top, left + patch_size, top + patch_size)
        patches_img.append(img.crop(box))
        patches_mask.append(mask.crop(box))

    return patches_img, patches_mask


def predict_with_patches(model, img_path, mask_path, transform, patch_size=384, stride=192, device='cpu'):
    """Classify one image by averaging the model's predictions over patches.

    Args:
        model: Dual-branch network called as ``model(rgb_batch, mask_batch)``.
        img_path: Path of the RGB image.
        mask_path: Path of the matching mask.
        transform: Image pipeline. Masks use the same steps minus
            ``T.Normalize``, as the training dataset does.
        patch_size: Side length of the extracted patches.
        stride: Step between neighbouring patches.
        device: Device the patch tensors are moved to.

    Returns:
        ``(prediction, confidence)``: index of the class with the highest
        mean probability, and that mean probability.
    """
    img = Image.open(img_path).convert('RGB')
    mask = Image.open(mask_path).convert('L')

    patches_img, patches_mask = extract_patches_inference(img, mask, patch_size=patch_size, stride=stride)

    # Build the mask pipeline once: identical steps, without Normalize. The
    # image statistics have three channels and would not apply to a
    # single-channel mask tensor.
    steps = getattr(transform, "transforms", [transform])
    mask_transform = T.Compose([t for t in steps if not isinstance(t, T.Normalize)])

    all_probs = []
    model.eval()

    with torch.no_grad():
        for patch_img, patch_mask in zip(patches_img, patches_mask):
            patch_img_tensor = transform(patch_img).unsqueeze(0).to(device)

            mask_tensor = mask_transform(patch_mask)
            # Match the training input: repeat a (1, H, W) mask to 3 channels.
            if mask_tensor.shape[0] == 1:
                mask_tensor = mask_tensor.repeat(3, 1, 1)
            patch_mask_tensor = mask_tensor.unsqueeze(0).to(device)

            logits = model(patch_img_tensor, patch_mask_tensor)
            probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()
            all_probs.append(probs)

    # Average probabilities (soft voting)
    avg_probs = np.array(all_probs).mean(axis=0)
    prediction = avg_probs.argmax()
    confidence = avg_probs[prediction]

    return prediction, confidence

print("Dual-branch patch-based inference functions defined.")

In [None]:
# Preprocess Test Set with Masks
TEST_INPUT_DIR = '../data/test_data'
TEST_OUTPUT_DIR = '../data/test_data_preprocessed'
# Permissive artifact threshold for the test set; the message below reports
# this exact value so the log cannot drift from the code.
TEST_MAX_ARTIFACT_RATIO = 0.7

os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)

if os.path.exists(TEST_INPUT_DIR):
    print(f"Preprocessing test set (clean data - permissive filtering)...")
    print(f"Input: {TEST_INPUT_DIR}")
    print(f"Output: {TEST_OUTPUT_DIR}")
    print(f"Note: Using {TEST_MAX_ARTIFACT_RATIO:.0%} artifact threshold since test data is clean\n")

    test_files = sorted([f for f in os.listdir(TEST_INPUT_DIR) if f.startswith('img_')])
    print(f"Found {len(test_files)} test images\n")

    successful = 0
    failed = 0
    failed_test_images = []
    test_failure_details = {}   # image name -> reason it was dropped
    test_rejection_reasons = {'high_artifacts': 0, 'no_tissue': 0, 'corrupted': 0}

    for img_name in tqdm(test_files, desc="Preprocessing Test Set"):
        img_path = os.path.join(TEST_INPUT_DIR, img_name)
        mask_name = img_name.replace("img_", "mask_")
        mask_path = os.path.join(TEST_INPUT_DIR, mask_name)

        try:
            processed_img, processed_mask, success, message = preprocess_image(
                img_path, mask_path, max_artifact_ratio=TEST_MAX_ARTIFACT_RATIO
            )
            saved = bool(success and processed_img is not None)
            if saved:
                # preprocess_image returns RGB; OpenCV writes BGR
                img_output_path = os.path.join(TEST_OUTPUT_DIR, img_name)
                mask_output_path = os.path.join(TEST_OUTPUT_DIR, mask_name)
                cv2.imwrite(img_output_path, cv2.cvtColor(processed_img, cv2.COLOR_RGB2BGR))
                cv2.imwrite(mask_output_path, processed_mask)
        except Exception as e:
            # Keep the batch running, but remember what went wrong.
            saved, message = False, f"Unexpected error: {e!r}"

        if saved:
            successful += 1
            continue

        failed += 1
        failed_test_images.append(img_name)
        test_failure_details[img_name] = message

        if 'artifacts' in message:
            test_rejection_reasons['high_artifacts'] += 1
        elif 'tissue' in message:
            test_rejection_reasons['no_tissue'] += 1
        else:
            # Load failures and unexpected errors
            test_rejection_reasons['corrupted'] += 1

    # Avoid a ZeroDivisionError in the percentages when no test image exists.
    percent_base = max(len(test_files), 1)

    print(f"\n{'='*60}")
    print(f"Test Set Preprocessing Complete!")
    print(f"{'='*60}")
    print(f"‚úÖ Successful: {successful}/{len(test_files)} ({successful/percent_base*100:.1f}%)")
    print(f"‚ùå Failed:     {failed}/{len(test_files)} ({failed/percent_base*100:.1f}%)")

    if failed > 0:
        print(f"\nRejection breakdown:")
        print(f"  üé® Too many artifacts: {test_rejection_reasons['high_artifacts']}")
        print(f"  üö´ No tissue:          {test_rejection_reasons['no_tissue']}")
        print(f"  üí• Corrupted:          {test_rejection_reasons['corrupted']}")
        # Inference lists its inputs from TEST_OUTPUT_DIR, so every image
        # that failed here gets no prediction row.
        print(f"\nWARNING: {failed} test image(s) were not written and will be missing from the submission:")
        for name in failed_test_images:
            print(f"    {name}: {test_failure_details[name]}")
else:
    print(f"Test directory not found: {TEST_INPUT_DIR}")
    print("Skipping test set preprocessing")

In [None]:
# Inference on test set
TEST_DIR = '../data/test_data_preprocessed'

# Index -> label text written to the submission. Must mirror
# DualBranchDataset.label_map exactly (including "Triple negative" with a
# lowercase "n"), so predictions use the same spelling as the training labels.
SUBMISSION_LABELS = ["Luminal A", "Luminal B", "HER2(+)", "Triple negative"]

if os.path.exists(TEST_DIR):
    test_files = sorted([f for f in os.listdir(TEST_DIR) if f.startswith('img_')])
    print(f"Found {len(test_files)} test images")
    print(f"Using dual-branch patch-based inference with 384√ó384 patches, stride=192 (50% overlap)\n")

    results = []

    # NOTE: `model` is whatever is currently in memory; run the evaluation
    # cell first so the best checkpoint is loaded.
    for img_name in tqdm(test_files, desc="Inference"):
        img_path = os.path.join(TEST_DIR, img_name)
        mask_name = img_name.replace("img_", "mask_")
        mask_path = os.path.join(TEST_DIR, mask_name)

        pred, conf = predict_with_patches(
            model, img_path, mask_path, val_transform,
            patch_size=384, stride=192, device=DEVICE
        )

        results.append({
            'sample_index': img_name,
            'label': SUBMISSION_LABELS[int(pred)],
            # 'confidence': conf
        })

    # Save submission
    submission_df = pd.DataFrame(results)
    submission_df.to_csv('../submission_dual_branch_patches.csv', index=False)

    print(f"\n‚úÖ Inference complete!")
    print(f"‚úÖ Saved to: ../submission_dual_branch_patches.csv")
    print(f"\nPrediction distribution:")
    print(submission_df['label'].value_counts())
    # print(f"\nAverage confidence: {submission_df['confidence'].mean():.3f}")
else:
    print(f"Test directory not found: {TEST_DIR}")
    print("Please run test set preprocessing first")

## Summary - Dual-Branch EfficientNetV2-S Pipeline

### ✅ All 7 Pieces of Advice Implemented:

1. **Advice 05/12** (Outliers): Filtered out training images whose detected artifact pixels exceed 0.5% of the frame (`MAX_ARTIFACT_RATIO = 0.005`)
2. **Advice 04/12** (Normalization): Used batch_size=32 for BatchNorm stability
3. **Advice 07/12** (Modern Optimizers): Lion optimizer with lr=3e-4
4. **Advice 06/12** (Auto Augmentation): RandAugment with num_ops=2, magnitude=7
5. **Advice 08/12** (Full Resolution): Patch-based inference at native resolution
6. **Advice 09/12** (Masks as Focus Filters): Background removed during preprocessing
7. **Advice 10/12** (Parallel Paths): Dual-branch architecture with late fusion

### 🎯 Dual-Branch Architecture:
- **RGB Branch**: Processes color, texture, cellular morphology (21M params, pretrained)
- **Mask Branch**: Processes shape, boundaries, spatial structure (21M params, from scratch)
- **Fusion**: Late concatenation of 1280 + 1280 = 2560 features
- **Total**: ~42M parameters (double single-branch)

### 🌉 Why Dual-Branch?
- **Appearance vs Geometry**: RGB learns "what" (cell types), Mask learns "where" (tissue structure)
- **Complementary Features**: Color + shape provide richer representations
- **Robustness**: Less sensitive to staining variations (mask branch helps)
- **Medical Imaging**: Proven effective in histopathology competitions

### 📊 Results:
- Check `training_curves_dual_branch.png` for loss/F1 progression
- Check `confusion_matrix_dual_branch.png` for per-class performance
- Best model saved to `best_model_dual_branch.pth`
- Test predictions saved to `submission_dual_branch_patches.csv`

### 🔄 Comparison with Single-Branch:
Run all notebooks and compare:
- **Single EfficientNetV2-S** (21M params)
- **Dual-Branch EfficientNetV2-S** (42M params)
- **DenseNet-121** (8M params)
- **Swin Transformer** (28M params)

Compare:
- Training stability and speed
- Validation F1 scores
- Inference confidence
- Overfitting behavior

### 💡 Training Strategy:
- **Phase 1**: Freeze RGB branch, train mask branch (first 5 epochs)
- **Phase 2**: Unfreeze RGB branch, fine-tune both branches together
- **Fusion**: Parallel paths meet at latent space for classification

This architecture should give you the best performance - appearance + geometry = superior classification!