# Stage 1: Binary Classification Model (Healthy vs Wound)

This notebook trains the first stage of the cascade: **The Gatekeeper**.
Its goal is to filter out healthy/irrelevant images so the downstream models only see actual wounds.

**Strategy:**
- **Model**: EfficientNet-B2 (`tf_efficientnet_b2` from timm, pretrained on ImageNet)
- **Data**: All Wound Classes (Positives) vs Healthy (Negatives)
- **Loss**: BCEWithLogitsLoss
- **Validation**: 5-Fold Cross-Validation

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import timm
from pathlib import Path
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Add src to path for imports
sys.path.append("../src")
from dataset import WoundDataset

# Config: every tunable for this notebook lives in one dict.
CONFIG = {
    # Reproducibility
    "seed": 42,

    # Input pipeline
    "img_size": 224,
    "batch_size": 32,
    # Changed from 4 to 0 for Windows compatibility
    "num_workers": 0,

    # Optimisation
    "epochs": 5,
    "lr": 1e-3,

    # Backbone identifier (a timm model name)
    "model_name": "tf_efficientnet_b2",

    # Filesystem locations (relative paths)
    "model_dir": "../models/stage1_binary/",
    "data_csv": "../data/loaders/train_folds.csv",
    "root_dir": "../",
}

# Make sure the checkpoint directory exists before training starts.
Path(CONFIG["model_dir"]).mkdir(parents=True, exist_ok=True)

# Set Seed
def seed_everything(seed):
    """Seed every random number generator this notebook may draw from.

    Sets ``PYTHONHASHSEED`` and seeds Python's ``random`` module, NumPy and
    PyTorch (CPU and all CUDA devices), then puts cuDNN into deterministic
    mode with autotuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)
    # Python's own RNG must be seeded too; libraries that sample through
    # `random` (some augmentation libraries do) would otherwise stay unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuned kernel selection can vary between runs, so pin it off.
    torch.backends.cudnn.benchmark = False

# Fix all RNG seeds up front so the run is repeatable.
seed_everything(CONFIG['seed'])

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using Device: {device}")

## 1. Data Preparation
We use `Fold 0` for validation in this first run.

In [None]:
# Transforms
# Both pipelines resize to a square of CONFIG['img_size'] and normalise with
# the standard ImageNet channel statistics; only training adds augmentation.
_resize_hw = (CONFIG['img_size'], CONFIG['img_size'])
_norm_stats = {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

_train_steps = [
    A.Resize(*_resize_hw),
    # Random geometric / photometric augmentation (training only)
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=30, p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(**_norm_stats),
    ToTensorV2(),
]
train_transforms = A.Compose(_train_steps)

_val_steps = [
    A.Resize(*_resize_hw),
    A.Normalize(**_norm_stats),
    ToTensorV2(),
]
val_transforms = A.Compose(_val_steps)

# Load DF
df = pd.read_csv(CONFIG['data_csv'])

# Split by Fold: rows whose `fold` equals FOLD are held out for validation.
FOLD = 0
_is_val = df['fold'] == FOLD
train_df = df[~_is_val].reset_index(drop=True)
val_df = df[_is_val].reset_index(drop=True)

# Create Datasets (Binary Mode = True)
# NOTE(review): this dataset is built from the full CSV, so it is not
# restricted to the training folds. Nothing visible in this notebook uses it
# (the loaders are built from train_df / val_df) - confirm before relying on it.
train_dataset = WoundDataset(
    csv_file=CONFIG['data_csv'],
    root_dir=CONFIG['root_dir'],
    transform=train_transforms,
    binary_mode=True,
)

print(f"Train Samples: {len(train_df)} | Val Samples: {len(val_df)}")

In [None]:
# Dataset Wrapper for DataFrame
class WoundDatasetDF(Dataset):
    """Image dataset backed by an in-memory DataFrame.

    Each row must provide a ``path`` column (image location relative to
    ``root_dir``) and a ``label`` column (class name as a string).

    Args:
        df: DataFrame with ``path`` and ``label`` columns.
        root_dir: Directory that relative image paths are resolved against;
            defaults to the current directory.
        transform: Optional callable invoked as ``transform(image=ndarray)``
            that returns a mapping with an ``'image'`` entry.
        binary_mode: When True, ``'healthy'`` (case-insensitive) maps to 0 and
            every other label maps to 1. Only binary mode is implemented.
    """

    # Size of the black placeholder substituted for unreadable images.
    _PLACEHOLDER_SIZE = (224, 224)

    def __init__(self, df, root_dir=None, transform=None, binary_mode=False):
        self.annotations = df
        self.root_dir = Path(root_dir) if root_dir else Path(".")
        self.transform = transform
        self.binary_mode = binary_mode

    def __len__(self):
        return len(self.annotations)

    def _encode_label(self, label_str):
        """Convert a label string into its numeric target."""
        if not self.binary_mode:
            # Previously this path fell through and crashed on an unbound
            # `label` variable; fail with an explicit message instead.
            raise NotImplementedError(
                "WoundDatasetDF only supports binary_mode=True; "
                "no multi-class label encoding is defined."
            )
        return 0 if label_str.lower() == 'healthy' else 1

    def _load_image(self, rel_path):
        """Open the image as RGB, falling back to a black placeholder."""
        candidates = [self.root_dir / rel_path]
        if rel_path.startswith(".."):
            # Retry with the leading "../" (or "..\\") stripped.
            candidates.append(self.root_dir / rel_path[3:])

        for candidate in candidates:
            try:
                # convert() returns a fully loaded copy, so the file handle
                # can be closed as soon as the `with` block exits.
                with Image.open(candidate) as handle:
                    return handle.convert("RGB")
            except (OSError, ValueError):
                # Missing, unreadable or corrupt file: try the next candidate.
                # KeyboardInterrupt and friends are deliberately not caught.
                continue

        # Create a black image as desperate fallback to not crash training
        print(f"Warning: Could not open {candidates[-1]}, using black image.")
        return Image.new('RGB', self._PLACEHOLDER_SIZE, color='black')

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at ``index``.

        The label is a float32 tensor so it can be fed to BCEWithLogitsLoss.

        Raises:
            NotImplementedError: If the dataset was built with
                ``binary_mode=False``.
        """
        row = self.annotations.iloc[index]

        # Resolve the label first: it is cheap and fails fast on bad config.
        label = self._encode_label(row['label'])

        # Normalise both separator styles to the host OS separator.
        rel_path = str(row['path']).replace('\\', os.sep).replace('/', os.sep)
        image = self._load_image(rel_path)

        if self.transform:
            augmented = self.transform(image=np.array(image))
            image = augmented['image']

        return image, torch.tensor(label, dtype=torch.float32)  # Float for BCEWithLogits

# Init datasets
train_ds = WoundDatasetDF(
    train_df,
    root_dir=CONFIG['root_dir'],
    transform=train_transforms,
    binary_mode=True,
)
val_ds = WoundDatasetDF(
    val_df,
    root_dir=CONFIG['root_dir'],
    transform=val_transforms,
    binary_mode=True,
)

# Settings shared by both loaders. No pin_memory for Windows stability
_loader_kwargs = {
    "batch_size": CONFIG['batch_size'],
    "num_workers": CONFIG['num_workers'],
    "pin_memory": False,
}
# Only the training loader shuffles; validation keeps the row order.
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

# Test Batch: pull one batch to sanity-check tensor shapes.
_first_batch = next(iter(train_loader))
img, lab = _first_batch
print(f"Batch Shape: {img.shape}, Label Shape: {lab.shape}")

## 2. Model Factory & Training Loop

In [None]:
def get_model(model_name, num_classes=1, pretrained=True):
    """Create a timm model by name.

    Args:
        model_name: Identifier forwarded to ``timm.create_model``.
        num_classes: Value forwarded as ``num_classes`` (default 1).
        pretrained: Value forwarded as ``pretrained`` (default True).

    Returns:
        The model object produced by ``timm.create_model``.
    """
    return timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes,
    )

def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Run one training pass over ``loader`` and report loss and accuracy.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable of ``(images, labels)`` batches. ``loader.dataset``
            must support ``len()`` because the loss is averaged per sample.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device (``torch.device`` or string) the batches are moved to.
        scaler: Optional gradient scaler. Mixed precision is used only when a
            scaler is supplied AND ``device`` is a CUDA device.

    Returns:
        Tuple ``(epoch_loss, accuracy)`` where ``epoch_loss`` is the
        sample-weighted mean loss and ``accuracy`` thresholds the sigmoid
        outputs at 0.5.
    """
    model.train()

    # AMP is only valid on CUDA. Gating on the device avoids entering a CUDA
    # autocast context (and scaling gradients) when training runs on CPU.
    use_amp = scaler is not None and torch.device(device).type == "cuda"

    running_loss = 0.0
    prob_batches = []
    target_batches = []

    for images, labels in tqdm(loader, desc="Training", leave=False):
        images = images.to(device)
        # Add a trailing dim so labels line up with the model output
        # (assumes the model emits one logit per sample - TODO confirm).
        labels = labels.to(device).unsqueeze(1)

        optimizer.zero_grad()

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Weight by batch size so a smaller final batch is averaged correctly.
        running_loss += loss.item() * images.size(0)

        # Cast to float32: outputs may be half precision under autocast.
        prob_batches.append(torch.sigmoid(outputs).detach().float().cpu().numpy())
        target_batches.append(labels.detach().cpu().numpy())

    epoch_loss = running_loss / len(loader.dataset)

    # Binary metrics on flat 1-D arrays
    probs = np.concatenate(prob_batches).ravel()
    targets = np.concatenate(target_batches).ravel().astype(int)
    preds_binary = (probs > 0.5).astype(int)
    acc = accuracy_score(targets, preds_binary)

    return epoch_loss, acc

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches. ``loader.dataset``
            must support ``len()`` because the loss is averaged per sample.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, accuracy, f1, roc_auc)``. Accuracy and F1 use a
        0.5 threshold on the sigmoid outputs. ``roc_auc`` falls back to 0.5
        when ``roc_auc_score`` raises ``ValueError``.
    """
    model.eval()
    running_loss = 0.0
    prob_batches = []
    target_batches = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating", leave=False):
            images = images.to(device)
            # Add a trailing dim so labels line up with the model output
            # (assumes the model emits one logit per sample - TODO confirm).
            labels = labels.to(device).unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight by batch size so a smaller final batch is averaged correctly.
            running_loss += loss.item() * images.size(0)
            prob_batches.append(torch.sigmoid(outputs).float().cpu().numpy())
            target_batches.append(labels.cpu().numpy())

    epoch_loss = running_loss / len(loader.dataset)

    # Flatten to 1-D so every metric receives plain label / score vectors.
    probs = np.concatenate(prob_batches).ravel()
    targets = np.concatenate(target_batches).ravel().astype(int)
    preds_binary = (probs > 0.5).astype(int)

    acc = accuracy_score(targets, preds_binary)
    f1 = f1_score(targets, preds_binary)
    try:
        roc = roc_auc_score(targets, probs)
    except ValueError:
        # ROC AUC may be undefined for this fold (e.g. a single class);
        # report chance level. Other errors, including KeyboardInterrupt,
        # are no longer swallowed.
        roc = 0.5

    return epoch_loss, acc, f1, roc

In [None]:
# Initialize
model = get_model(CONFIG['model_name']).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr'])
# Gradient scaling only applies on CUDA; on CPU train in full precision
# rather than creating a scaler that cannot be used.
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

# Track the lowest validation loss seen so far; only that checkpoint is kept.
best_loss = float('inf')
# os.path.join avoids the doubled "/" produced when model_dir already ends in one.
save_path = os.path.join(CONFIG['model_dir'], f"best_model_fold_{FOLD}_b2.pth")

print(f"Starting Training for {CONFIG['epochs']} Epochs...")

for epoch in range(CONFIG['epochs']):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler)
    val_loss, val_acc, val_f1, val_roc = validate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{CONFIG['epochs']}")
    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f}   | Acc: {val_acc:.4f} | F1: {val_f1:.4f} | ROC: {val_roc:.4f}")

    # Checkpoint whenever validation loss improves.
    if val_loss < best_loss:
        # The leading emoji was mis-encoded (mojibake) in the original message.
        print(f"🔥 Loss Improved ({best_loss:.4f} -> {val_loss:.4f}). Saving Model...")
        best_loss = val_loss
        torch.save(model.state_dict(), save_path)
    print("-"*30)