In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import json
import os
import sys
from tqdm import tqdm
from datetime import datetime
import csv
import cv2
from PIL import Image

sys.path.append('..')
from src.data import create_dataloaders, denormalize

# Seed every RNG this notebook draws from so reruns are reproducible.
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(42)

# Query the MPS backend once and reuse the answer for the log line and device choice.
mps_available = torch.backends.mps.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {mps_available}")

# Prefer the MPS accelerator when present; otherwise fall back to the CPU.
device = torch.device('mps') if mps_available else torch.device('cpu')


PyTorch version: 2.9.0
MPS available: True


In [17]:
# Load counterfactual backgrounds
cf_image_dir = '../data/counterfactuals/val_counterfactual/images'
MAX_BACKGROUNDS = 100  # size of the background pool used for BR
background_images = []

print("Loading background images for BR...")
# Sort first so the pool is deterministic, then keep the first MAX_BACKGROUNDS files.
cf_files = sorted(f for f in os.listdir(cf_image_dir) if f.endswith('.jpg'))[:MAX_BACKGROUNDS]

for cf_file in tqdm(cf_files):
    cf_path = os.path.join(cf_image_dir, cf_file)
    # The context manager closes the file handle as soon as the pixels are copied out.
    with Image.open(cf_path) as img:
        background_images.append(np.array(img.convert('RGB')))

# Fail here with a clear message instead of later inside the transform's random draw.
if not background_images:
    raise FileNotFoundError(f"No .jpg background images found in {cf_image_dir}")

print(f"✓ Loaded {len(background_images)} background images")


Loading background images for BR...


100%|███████████████████████████████████| 100/100 [00:00<00:00, 1721.53it/s]

✓ Loaded 100 background images





In [18]:
class BackgroundRandomizationTransform:
    """
    Apply background randomization (BR) during training.

    With probability ``p`` the background pixels of an image (mask == 0) are
    replaced by a randomly chosen image from a fixed pool, while foreground
    pixels (mask == 1) are kept.
    """
    def __init__(self, background_images, p=0.5):
        """
        Args:
            background_images: Non-empty sequence of background images
                (numpy arrays, H x W x 3)
            p: Probability of applying background randomization, in [0, 1]

        Raises:
            ValueError: If the pool is empty or ``p`` is outside [0, 1]
        """
        if len(background_images) == 0:
            raise ValueError("background_images must contain at least one image")
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p}")
        self.background_images = background_images
        self.p = p

    @staticmethod
    def _to_2d_mask(mask):
        """
        Convert a mask (torch tensor or array-like) to a 2-D float32 numpy array.

        A single leading or trailing singleton dimension, i.e. (1, H, W) or
        (H, W, 1), is dropped.

        Raises:
            ValueError: If the mask cannot be reduced to two dimensions
        """
        if isinstance(mask, torch.Tensor):
            # .numpy() only works on CPU tensors that do not require grad
            mask = mask.detach().cpu().numpy()
        mask_np = np.asarray(mask, dtype=np.float32)
        if mask_np.ndim == 3:
            if mask_np.shape[0] == 1:
                mask_np = mask_np[0]
            elif mask_np.shape[-1] == 1:
                mask_np = mask_np[..., 0]
        if mask_np.ndim != 2:
            raise ValueError(f"mask must be 2-D (H, W), got shape {mask_np.shape}")
        return mask_np

    @staticmethod
    def _composite(image_np, background_np, mask_2d):
        """
        Blend foreground and background with a per-pixel mask.

        Args:
            image_np: H x W x C array holding the original image
            background_np: H x W x C array holding the replacement background
            mask_2d: H x W array, 1 = keep image pixel, 0 = take background pixel

        Returns:
            H x W x C uint8 array
        """
        mask_3ch = mask_2d[..., np.newaxis]  # broadcasts over the channel axis
        blended = image_np * mask_3ch + background_np * (1.0 - mask_3ch)
        return np.clip(blended, 0, 255).astype(np.uint8)

    def __call__(self, image_tensor, mask_tensor):
        """
        Args:
            image_tensor: PIL Image or H x W x 3 array-like
            mask_tensor: Mask tensor/array (1=foreground, 0=background); a
                singleton channel dimension is accepted

        Returns:
            The input object unchanged when BR is skipped, otherwise a new
            PIL Image with the background replaced
        """
        # random() is in [0, 1): `>=` makes p=0 never augment and p=1 always augment.
        if np.random.random() >= self.p:
            return image_tensor  # Don't apply BR

        image_np = np.asarray(image_tensor)
        h, w = image_np.shape[:2]

        mask_np = self._to_2d_mask(mask_tensor)
        if mask_np.shape != (h, w):
            # Nearest-neighbour keeps the mask binary instead of blurring its edges
            mask_np = cv2.resize(mask_np, (w, h), interpolation=cv2.INTER_NEAREST)

        # Get random background
        bg_idx = np.random.randint(0, len(self.background_images))
        bg_image = self.background_images[bg_idx]

        # Resize background to match image size (cv2 expects (width, height))
        bg_resized = cv2.resize(bg_image, (w, h))

        augmented = self._composite(image_np, bg_resized, mask_np)
        return Image.fromarray(augmented)

# Create BR transforms for different probabilities
BR_PROBABILITIES = [0.25, 0.50, 0.75]
br_transforms = {
    p: BackgroundRandomizationTransform(background_images, p=p)
    for p in BR_PROBABILITIES
}

# Build the message from the dict so it always reports what was actually created.
_p_list = ', '.join(f'{p:.2f}' for p in br_transforms)
print(f"✓ BR transform initialized for p=[{_p_list}]")


✓ BR transform initialized for p=[0.25, 0.50, 0.75, 1.0]


In [19]:
class OxfordPetDatasetWithBR:
    """
    Wrapper that applies BR to base dataset.

    Each base sample is read as a mapping with the keys 'image', 'mask' and
    'label'; when a BR transform is set, 'original_image' and 'original_mask'
    are read as well.
    """
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, base_dataset, br_transform=None, img_size=224):
        """
        Args:
            base_dataset: Indexable dataset yielding sample dicts
            br_transform: Optional callable ``(PIL.Image, mask) -> PIL.Image``;
                when None, the base dataset's own 'image' is returned
            img_size: Side length the BR-augmented image is resized to
        """
        self.base_dataset = base_dataset
        self.br_transform = br_transform
        self.img_size = img_size
        # Built on first use so the pipeline is created once, not per sample
        self._post_transform = None

    def __len__(self):
        return len(self.base_dataset)

    def _get_post_transform(self):
        """Return the cached resize -> tensor -> normalize pipeline."""
        if self._post_transform is None:
            from torchvision.transforms import Compose, Resize, ToTensor, Normalize
            self._post_transform = Compose([
                Resize((self.img_size, self.img_size)),
                ToTensor(),
                Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
            ])
        return self._post_transform

    def __getitem__(self, idx):
        """
        Returns:
            Dict with 'image' (BR-augmented and re-normalized when a transform
            is set, otherwise the base sample's image), 'mask' and 'label'
        """
        sample = self.base_dataset[idx]

        # Apply BR if transform provided
        if self.br_transform is not None:
            # Rebuild the image from the stored original so BR sees raw pixels
            image_pil = Image.fromarray(sample['original_image'])
            image_augmented = self.br_transform(image_pil, sample['original_mask'])

            # Apply standard transforms
            image = self._get_post_transform()(image_augmented)
        else:
            image = sample['image']

        return {
            'image': image,
            'mask': sample['mask'],
            'label': sample['label']
        }

print("✓ BR-augmented dataset class created")


✓ BR-augmented dataset class created


In [20]:
# Base config (intended to mirror the baseline run)
# NOTE(review): T_max was 30 while epochs was 20, so the cosine schedule never
# reached eta_min. Both now derive from NUM_EPOCHS; confirm against the baseline
# notebook if exact schedule parity matters.
NUM_EPOCHS = 20

base_config = {
    'training': {
        'epochs': NUM_EPOCHS,
        'batch_size': 32,
        'num_workers': 4,
        'learning_rate': 3e-4,
        'weight_decay': 0.01,
        'optimizer': 'AdamW',
    },
    'scheduler': {
        'type': 'cosine',
        # One full cosine half-period over the whole run
        'T_max': NUM_EPOCHS,
    },
    'data': {
        'image_size': 224,
        'data_root': '../data/raw',
        'split_metadata_path': '../data/processed/split_metadata.json',
    },
    'device': 'mps' if torch.backends.mps.is_available() else 'cpu',
    'checkpoint': {
        'save_every': 5,
        'save_dir': None,  # Will be set per variant
    },
}

# Define BR variants (key doubles as the checkpoint sub-directory name)
br_variants = {
    'br_p025': {'p': 0.25, 'name': 'BR (p=0.25)'},
    'br_p050': {'p': 0.50, 'name': 'BR (p=0.50)'},
    'br_p075': {'p': 0.75, 'name': 'BR (p=0.75)'},
}

print("BR variants to train:")
for key, config in br_variants.items():
    print(f"  {key}: {config['name']}")


BR variants to train:
  br_p025: BR (p=0.25)
  br_p050: BR (p=0.50)
  br_p075: BR (p=0.75)


In [21]:
def _apply_br_to_batch(images, masks, br_transform):
    """
    Apply background randomization to every image of a normalized batch.

    Args:
        images: CPU float tensor (B, 3, H, W); assumed to be normalized with
            the ImageNet mean/std below — TODO confirm against create_dataloaders
        masks: Sequence of B per-sample masks (tensor or array, 1=foreground);
            a None entry passes that image through untouched
        br_transform: Callable ``(PIL.Image, mask) -> PIL.Image`` that returns
            its input object unchanged when it decides not to augment

    Returns:
        Tensor with the same shape and dtype as ``images``
    """
    # float32 constants keep the result float32 (float64 tensors cannot be moved to MPS)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

    augmented = []
    for img, mask in zip(images, masks):
        if mask is None:
            augmented.append(img)
            continue

        mask_np = mask.numpy() if isinstance(mask, torch.Tensor) else np.asarray(mask)
        # Drop singleton dims such as (1, H, W) -> (H, W) before compositing
        mask_np = np.squeeze(mask_np).astype(np.float32)

        # Undo Normalize: x * std + mean is in [0, 1]; scale to [0, 255] afterwards
        pixels = (img.numpy() * std + mean) * 255.0
        img_uint8 = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
        img_pil = Image.fromarray(np.ascontiguousarray(img_uint8.transpose(1, 2, 0)))

        img_br = br_transform(img_pil, mask_np)
        if img_br is img_pil:
            # Transform skipped this sample: keep the exact tensor and avoid
            # the lossy uint8 round trip
            augmented.append(img)
            continue

        # Re-normalize the augmented image
        img_br_arr = np.asarray(img_br, dtype=np.float32).transpose(2, 0, 1) / 255.0
        img_br_arr = (img_br_arr - mean) / std
        augmented.append(torch.from_numpy(np.ascontiguousarray(img_br_arr)).to(img.dtype))

    return torch.stack(augmented)


def train_br_variant(variant_key, br_config, background_images, base_config):
    """
    Train BR model with specific probability.

    Fine-tunes a pretrained ResNet-50 (37 output classes) on the module-level
    ``device``, applying background randomization to each training batch.

    Args:
        variant_key: Name of the variant; used as the checkpoint sub-directory
        br_config: Dict with 'p' (BR probability) and 'name' (display name)
        background_images: Pool of background images handed to the BR transform
        base_config: Nested config dict ('training', 'scheduler', 'data', 'checkpoint')

    Returns:
        History dict with per-epoch lists: 'train_loss', 'train_acc',
        'val_loss', 'val_acc', 'learning_rate'

    Side effects:
        Writes periodic and best checkpoints plus ``training_history.csv`` to
        ``../models/checkpoints/{variant_key}``.
    """
    num_epochs = base_config['training']['epochs']
    batch_size = base_config['training']['batch_size']

    print(f"\n{'='*70}")
    print(f"TRAINING: {br_config['name']}")
    print(f"{'='*70}")

    # Create dataloaders (third return value is not needed here)
    train_loader, val_loader, _ = create_dataloaders(
        data_root=base_config['data']['data_root'],
        split_metadata_path=base_config['data']['split_metadata_path'],
        batch_size=batch_size,
        num_workers=base_config['training']['num_workers'],
        img_size=base_config['data']['image_size'],
        pin_memory=False
    )

    # Create BR transform
    br_transform = BackgroundRandomizationTransform(background_images, p=br_config['p'])

    # Raw dataset, only used as a fallback source of masks (see training loop)
    train_raw = train_loader.dataset

    # Create model
    model = models.resnet50(weights='DEFAULT')
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 37)
    model = model.to(device)

    # Optimizer and scheduler
    optimizer = optim.AdamW(
        model.parameters(),
        lr=base_config['training']['learning_rate'],
        weight_decay=base_config['training']['weight_decay']
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=base_config['scheduler']['T_max'], eta_min=1e-6)
    criterion = nn.CrossEntropyLoss()

    # Checkpoint directory
    checkpoint_dir = f"../models/checkpoints/{variant_key}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Training loop
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'learning_rate': []
    }

    best_val_acc = 0.0
    start_time = datetime.now()

    for epoch in range(num_epochs):
        # Train
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        pbar = tqdm(enumerate(train_loader), total=len(train_loader),
                   desc=f'Epoch {epoch+1}/{num_epochs} [TRAIN]')

        for batch_idx, batch in pbar:
            # Keep images on the CPU until BR is done; only the result goes to the device
            images_cpu = batch['image']
            labels = batch['label'].to(device)

            if 'mask' in batch:
                # Masks collated with the batch always belong to these exact images
                masks = list(batch['mask'])
            else:
                # Fallback: positional lookup in the dataset. Only correct when the
                # loader iterates in dataset order (shuffle=False) — verify before relying on it.
                first_idx = batch_idx * batch_size
                masks = [
                    train_raw[first_idx + i]['mask'] if first_idx + i < len(train_raw) else None
                    for i in range(len(images_cpu))
                ]

            images_br = _apply_br_to_batch(images_cpu, masks, br_transform).to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(images_br)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Statistics
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)

            acc = train_correct / train_total
            avg_loss = train_loss / (batch_idx + 1)
            pbar.set_postfix({'loss': f'{avg_loss:.4f}', 'acc': f'{acc*100:.1f}%'})

        epoch_train_loss = train_loss / len(train_loader)
        epoch_train_acc = train_correct / train_total

        # Validate
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        pbar_val = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [VAL]')

        with torch.no_grad():
            for val_batch_idx, batch in enumerate(pbar_val):
                images = batch['image'].to(device)
                labels = batch['label'].to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

                acc = val_correct / val_total
                # Divide by the number of batches seen so far
                avg_loss = val_loss / (val_batch_idx + 1)
                pbar_val.set_postfix({'loss': f'{avg_loss:.4f}', 'acc': f'{acc*100:.1f}%'})

        epoch_val_loss = val_loss / len(val_loader)
        epoch_val_acc = val_correct / val_total

        # Scheduler step
        scheduler.step()
        # param_groups is a list of dicts; there is a single group here
        current_lr = optimizer.param_groups[0]['lr']

        # Record history
        history['train_loss'].append(epoch_train_loss)
        history['train_acc'].append(epoch_train_acc)
        history['val_loss'].append(epoch_val_loss)
        history['val_acc'].append(epoch_val_acc)
        history['learning_rate'].append(current_lr)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {epoch_train_loss:.4f} | Train Acc: {epoch_train_acc*100:.1f}%")
        print(f"  Val Loss:   {epoch_val_loss:.4f} | Val Acc:   {epoch_val_acc*100:.1f}%")
        print(f"  LR: {current_lr:.2e}")

        # Save checkpoint
        if (epoch + 1) % base_config['checkpoint']['save_every'] == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'resnet50_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'history': history
            }, checkpoint_path)
            print(f"  ✓ Checkpoint saved")

        # Save best model
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            best_path = os.path.join(checkpoint_dir, 'resnet50_best.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'best_val_acc': best_val_acc,
                'history': history,
                'config': {**base_config, 'br_p': br_config['p']}
            }, best_path)
            print(f"  ✓ Best model saved")

    end_time = datetime.now()
    total_time = (end_time - start_time).total_seconds()

    print(f"\n{'='*70}")
    print(f"TRAINING COMPLETE: {br_config['name']}")
    print(f"{'='*70}")
    print(f"Total time: {total_time/3600:.1f} hours")
    print(f"Best validation accuracy: {best_val_acc*100:.2f}%")

    # Save training history
    history_csv_path = os.path.join(checkpoint_dir, 'training_history.csv')
    with open(history_csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc', 'lr'])
        rows = zip(history['train_loss'], history['train_acc'],
                   history['val_loss'], history['val_acc'], history['learning_rate'])
        for epoch_num, row in enumerate(rows, start=1):
            writer.writerow([epoch_num, *row])

    return history

print("✓ Training function defined")


✓ Training function defined


In [22]:
# Train each BR variant
br_histories = {}

# Create the output folder up front so a finished training run is never lost
# to a missing directory when the history is written.
metrics_dir = "../experiments/results/metrics"
os.makedirs(metrics_dir, exist_ok=True)

for variant_key, br_config in br_variants.items():
    history = train_br_variant(variant_key, br_config, background_images, base_config)
    br_histories[variant_key] = history

    # Save history
    history_path = f"{metrics_dir}/{variant_key}_history.json"
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=2)

print("\n" + "="*70)
print("✅ ALL BR VARIANTS TRAINED")
print("="*70)



TRAINING: BR (p=0.25)


RuntimeError: Dataset not found. You can use download=True to download it

In [None]:
# Create comparison plot
plots_dir = '../experiments/results/plots'
os.makedirs(plots_dir, exist_ok=True)  # savefig does not create missing folders

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Nominal epoch axis from the config. Each curve below uses its own history
# length, so a shortened or interrupted run still plots without a shape error.
epochs_range = range(1, base_config['training']['epochs'] + 1)


def plot_variant_curves(ax, metric, marker, markersize, scale=1.0):
    """Draw one line per BR variant for ``history[metric]`` (times ``scale``) on ``ax``."""
    for variant_key, history in br_histories.items():
        p = br_variants[variant_key]['p']
        values = np.array(history[metric]) * scale
        ax.plot(range(1, len(values) + 1), values,
                marker=marker, markersize=markersize, label=f"BR (p={p})")


# Plot 1: Validation Accuracy Comparison
ax = axes[0, 0]
plot_variant_curves(ax, 'val_acc', marker='o', markersize=3, scale=100)
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Accuracy (%)')
ax.set_title('Validation Accuracy: BR Variants')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 2: Training Loss Comparison
ax = axes[0, 1]
plot_variant_curves(ax, 'train_loss', marker='s', markersize=3)
ax.set_xlabel('Epoch')
ax.set_ylabel('Training Loss')
ax.set_title('Training Loss: BR Variants')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 3: Best Validation Accuracy
ax = axes[1, 0]
best_accs = [max(history['val_acc'])*100 for history in br_histories.values()]
ps = [br_variants[k]['p'] for k in br_histories.keys()]
ax.bar([f"p={p}" for p in ps], best_accs, color='steelblue', alpha=0.7)
ax.set_ylabel('Best Validation Accuracy (%)')
ax.set_title('Peak Performance by BR Probability')
ax.grid(True, alpha=0.3, axis='y')

# Plot 4: Learning Rate Schedule
ax = axes[1, 1]
plot_variant_curves(ax, 'learning_rate', marker='o', markersize=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedule')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)

plt.suptitle('Day 6: Background Randomization Variants Comparison', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(plots_dir, 'br_variants_comparison.png'), dpi=150)
plt.show()

print("✓ BR variants comparison plot saved")


In [None]:
print("\n" + "="*70)
print("DAY 6 SUMMARY: BACKGROUND RANDOMIZATION TRAINING")
print("="*70)

# Derive counts and probabilities from what was actually trained, so the
# summary cannot drift from the variant list.
n_variants = len(br_histories)
trained_ps = ', '.join(f"{br_variants[k]['p']:.2f}" for k in br_histories)

summary = f"""
BR VARIANTS TRAINED:
  {n_variants} probability variants: p={trained_ps}

RESULTS COMPARISON:
"""

for variant_key, history in br_histories.items():
    p = br_variants[variant_key]['p']
    best_acc = max(history['val_acc']) * 100
    best_epoch = int(np.argmax(history['val_acc'])) + 1
    final_acc = history['val_acc'][-1] * 100

    summary += f"\n  BR (p={p}):"
    summary += f"\n    Best Acc: {best_acc:.2f}% (Epoch {best_epoch})"
    summary += f"\n    Final Acc: {final_acc:.2f}%"
    # Checkpoints are written under the variant key (e.g. br_p025), not br_p25
    summary += f"\n    Model saved to: ../models/checkpoints/{variant_key}/resnet50_best.pth"

summary += f"""

FILES CREATED:
  ✓ {n_variants} trained models (one per BR probability)
  ✓ Training history CSV for each variant
  ✓ Comparison plots
  ✓ Model checkpoints every 5 epochs

NEXT STEPS (Day 7):
  1. Train Class-Balanced Fine-tuning (CBF) models
  2. Day 8: Evaluate all models on counterfactuals
  3. Compare BR vs CBF effectiveness
"""

print(summary)
print("="*70)
print("✅ DAY 6 COMPLETE - BR MODELS TRAINED")
print("="*70)
