In [None]:
import copy
import json
import os
import random
from collections import defaultdict

import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from tqdm import tqdm

# Reproducibility: seed Python's, NumPy's and PyTorch's global RNGs with the
# same value, then pin cuDNN to deterministic (non-autotuned) kernels.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


In [None]:
class TaskDataset(Dataset):
    """Map-style dataset of ``(id, image, label)`` triplets loaded with ``torch.load``.

    The object stored at ``data_path`` may be any of:

    * an object exposing ``ids``, ``imgs`` and ``labels`` attributes (e.g. a
      previously pickled ``TaskDataset``, even one pickled under a different
      module path, for which an ``isinstance`` check would fail);
    * a dict with keys ``'ids'``, ``'imgs'`` and ``'labels'``;
    * a non-empty list/tuple whose first element is a 3-element ``(id, img, label)``
      list/tuple (all elements are unzipped into the three fields).

    Args:
        data_path: path passed to ``torch.load``.
        transform: optional callable applied to each image in ``__getitem__``.

    Raises:
        RuntimeError: if the loaded object matches none of the layouts above, or
            if ``ids``, ``imgs`` and ``labels`` have different lengths.
    """

    def __init__(self, data_path, transform=None):
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
        # execute code while loading. Only point this at trusted files.
        data = torch.load(data_path, weights_only=False)

        fields = ('ids', 'imgs', 'labels')
        if all(hasattr(data, k) for k in fields):
            # Dataset-like object (e.g. a saved TaskDataset): reuse its containers as-is.
            self.ids = data.ids
            self.imgs = data.imgs
            self.labels = data.labels

        elif isinstance(data, dict) and all(k in data for k in fields):
            self.ids = list(data['ids'])
            self.imgs = list(data['imgs'])
            self.labels = list(data['labels'])

        elif (isinstance(data, (list, tuple)) and len(data) > 0
              and isinstance(data[0], (list, tuple)) and len(data[0]) == 3):
            # Sequence of (id, img, label) triplets: unzip into three parallel lists.
            self.ids, self.imgs, self.labels = map(list, zip(*data))

        else:
            raise RuntimeError(f"Unrecognized Train.pt format: got {type(data)}")

        # Parallel fields must line up, otherwise __getitem__ would silently pair
        # images with the wrong labels or fail late with an IndexError.
        if not (len(self.ids) == len(self.imgs) == len(self.labels)):
            raise RuntimeError(
                f"Mismatched dataset field lengths: ids={len(self.ids)}, "
                f"imgs={len(self.imgs)}, labels={len(self.labels)}"
            )

        self.transform = transform

    def __getitem__(self, index):
        """Return ``(id, image, label)``; ``transform`` (if any) is applied to the image only."""
        id_ = self.ids[index]
        img = self.imgs[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        return id_, img, label

    def __len__(self):
        """Number of samples (length of ``ids``)."""
        return len(self.ids)


def fgsm_attack(model, loss_fn, images, labels, epsilon):
    """Generate Fast Gradient Sign Method (FGSM) adversarial examples.

    Takes one step of size ``epsilon`` along the sign of the gradient of
    ``loss_fn(model(images), labels)`` with respect to the input, then clamps
    the result to the [0, 1] pixel range.

    The input gradient is computed with ``torch.autograd.grad`` instead of
    ``loss.backward()``, so the model parameters' ``.grad`` buffers are left
    untouched. With ``backward()`` the attack's parameter gradients would remain
    in place. A training loop that had already called ``optimizer.zero_grad()``
    would then silently add them to its own update.

    Args:
        model: differentiable classifier returning logits.
        loss_fn: loss to maximise, called as ``loss_fn(outputs, labels)``.
        images: input batch; not modified.
        labels: targets for ``loss_fn``.
        epsilon: step size (L-inf perturbation budget).

    Returns:
        Detached adversarial batch, same shape as ``images``, clamped to [0, 1].
    """
    images = images.clone().detach().requires_grad_(True)

    loss = loss_fn(model(images), labels)
    (input_grad,) = torch.autograd.grad(loss, images)

    adv_images = images.detach() + epsilon * input_grad.sign()
    return torch.clamp(adv_images, 0, 1).detach()


def pgd_attack(model, loss_fn, images, labels, epsilon, alpha, iters):
    """Generate L-inf Projected Gradient Descent (PGD) adversarial examples.

    Starts from a uniform random perturbation in ``[-epsilon, epsilon]``. It then
    takes ``iters`` signed-gradient ascent steps of size ``alpha``. After every
    step the perturbation is projected back into the epsilon ball and the
    [0, 1] pixel box.

    Input gradients come from ``torch.autograd.grad`` rather than
    ``loss.backward()``, so the model parameters' ``.grad`` buffers are never
    written. A training step that crafts PGD examples after
    ``optimizer.zero_grad()`` therefore does not pick up attack gradients.

    Args:
        model: differentiable classifier returning logits.
        loss_fn: loss to maximise, called as ``loss_fn(outputs, labels)``.
        images: clean input batch; not modified.
        labels: targets for ``loss_fn``.
        epsilon: L-inf perturbation budget.
        alpha: step size per iteration.
        iters: number of ascent steps (0 returns the random start).

    Returns:
        Detached adversarial batch, same shape as ``images``, within ``epsilon``
        of ``images`` (L-inf) and inside [0, 1].
    """
    orig_images = images.clone().detach()

    # Random start, kept inside the valid [0, 1] image range.
    delta = torch.zeros_like(orig_images).uniform_(-epsilon, epsilon)
    delta = torch.clamp(delta, -orig_images, 1 - orig_images)

    for _ in range(iters):
        delta.requires_grad_(True)
        loss = loss_fn(model(orig_images + delta), labels)
        (delta_grad,) = torch.autograd.grad(loss, delta)

        with torch.no_grad():
            delta = delta + alpha * delta_grad.sign()
            # Project back onto the epsilon ball, then onto the image box.
            delta = torch.clamp(delta, -epsilon, epsilon)
            delta = torch.clamp(delta, -orig_images, 1 - orig_images)

    return (orig_images + delta).detach()


def _adversarial_curriculum(epoch, epsilon):
    """Return ``(clean_ratio, current_epsilon, use_adv)`` for a 1-based ``epoch``.

    Conservative curriculum designed to keep clean accuracy high:

    * epochs 1-20:  clean data only (``use_adv=False``, epsilon 0);
    * epochs 21-50: clean ratio ramps linearly towards 0.8 and epsilon towards
      0.4*eps, reaching them at epoch 50 (ramp starts from 0.9 / 0.1*eps);
    * epochs 51-80: clean ratio ramps towards 0.75, epsilon towards 0.8*eps;
    * epochs 81+:   clean ratio 0.75, full ``epsilon``.
    """
    if epoch <= 20:
        return 1.0, 0.0, False
    if epoch <= 50:
        progress = (epoch - 20) / 30.0
        return 0.9 - 0.1 * progress, epsilon * (0.1 + 0.3 * progress), True
    if epoch <= 80:
        progress = (epoch - 50) / 30.0
        return 0.8 - 0.05 * progress, epsilon * (0.4 + 0.4 * progress), True
    return 0.75, epsilon, True


def _mixed_adversarial_batch(model, loss_fn, imgs, labels, epoch, clean_ratio,
                             current_epsilon, alpha, pgd_iters):
    """Return ``(imgs, labels)`` with a clean head and an adversarial tail.

    The first ``int(batch_size * clean_ratio)`` examples are kept clean (at
    least one). The remainder is replaced by adversarial versions: pure FGSM up
    to epoch 60, afterwards half FGSM and half PGD. PGD uses
    ``max(3, pgd_iters // 3)`` iterations to stay less aggressive.
    """
    batch_size = imgs.size(0)
    clean_size = int(batch_size * clean_ratio)
    if clean_size == 0:
        # Ensure we always have clean examples.
        clean_size = max(1, batch_size // 2)

    out_imgs = [imgs[:clean_size]]
    out_labels = [labels[:clean_size]]
    adv_imgs, adv_labels = imgs[clean_size:], labels[clean_size:]

    if len(adv_imgs) > 0:
        if epoch <= 60:
            out_imgs.append(fgsm_attack(model, loss_fn, adv_imgs, adv_labels, current_epsilon))
            out_labels.append(adv_labels)
        else:
            split = len(adv_imgs) // 2
            if split > 0:
                out_imgs.append(fgsm_attack(model, loss_fn, adv_imgs[:split],
                                            adv_labels[:split], current_epsilon))
                out_labels.append(adv_labels[:split])
            # len(adv_imgs) - split >= 1 whenever adv_imgs is non-empty.
            out_imgs.append(pgd_attack(model, loss_fn, adv_imgs[split:], adv_labels[split:],
                                       current_epsilon, alpha, max(3, pgd_iters // 3)))
            out_labels.append(adv_labels[split:])

    return torch.cat(out_imgs, dim=0), torch.cat(out_labels, dim=0)


def train_epoch(model, device, train_loader, optimizer, loss_fn, epoch, total_epochs,
                epsilon, alpha, pgd_iters):
    """Run one epoch of curriculum-based mixed clean/adversarial training.

    Args:
        model: classifier being trained (put in train mode here).
        device: device batches are moved to.
        train_loader: yields ``(ids, imgs, labels)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: training loss (also maximised by the attacks).
        epoch: 1-based epoch index; selects the curriculum stage.
        total_epochs: only used for the progress-bar label.
        epsilon: full L-inf budget; the curriculum scales it down early on.
        alpha: PGD step size.
        pgd_iters: nominal PGD iterations (training uses ``max(3, pgd_iters // 3)``).

    Returns:
        ``(avg_loss, accuracy)``: mean per-batch loss and accuracy over every
        example trained on (clean and adversarial). Both are 0.0 if the loader
        yields no batches.
    """
    model.train()
    clean_ratio, current_epsilon, use_adv = _adversarial_curriculum(epoch, epsilon)

    total_loss = 0.0
    correct = 0
    total_samples = 0
    num_batches = 0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{total_epochs}")
    for _, imgs, labels in train_bar:
        imgs, labels = imgs.to(device), labels.to(device)

        if use_adv:
            # NOTE: attacks run with the model in train mode, so any BatchNorm
            # layers also update running statistics on these extra forward passes.
            batch_imgs, batch_labels = _mixed_adversarial_batch(
                model, loss_fn, imgs, labels, epoch, clean_ratio,
                current_epsilon, alpha, pgd_iters)
        else:
            batch_imgs, batch_labels = imgs, labels

        # Zero gradients only AFTER crafting the adversarial examples. The attack
        # helpers may call loss.backward(), and any parameter gradients they leave
        # behind must not be accumulated into this training step.
        optimizer.zero_grad()
        outputs = model(batch_imgs)
        loss = loss_fn(outputs, batch_labels)
        loss.backward()

        # More conservative gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        with torch.no_grad():
            correct += outputs.argmax(dim=1).eq(batch_labels).sum().item()
        total_samples += batch_labels.size(0)

        train_bar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{correct/total_samples:.4f}',
            'eps': f'{current_epsilon:.4f}',
            'clean_ratio': f'{clean_ratio:.2f}'
        })

    avg_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / total_samples if total_samples else 0.0
    return avg_loss, accuracy


def evaluate_model(model, device, data_loader, loss_fn, epsilon, alpha, pgd_iters):
    """Evaluate ``model`` on clean, FGSM and PGD versions of ``data_loader``.

    All three accuracies are computed in a single pass over the loader. The
    model is switched to eval mode and left in it.

    Args:
        model: classifier returning logits.
        device: device batches are moved to.
        data_loader: yields ``(ids, imgs, labels)`` batches.
        loss_fn: loss the attacks maximise.
        epsilon: L-inf budget for both FGSM and PGD.
        alpha: PGD step size.
        pgd_iters: number of PGD iterations.

    Returns:
        ``(clean_acc, fgsm_acc, pgd_acc)`` as fractions of the evaluated samples.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    clean_correct = fgsm_correct = pgd_correct = total = 0

    eval_bar = tqdm(data_loader, desc="Evaluating")
    for _, imgs, labels in eval_bar:
        imgs, labels = imgs.to(device), labels.to(device)

        with torch.no_grad():
            clean_correct += model(imgs).argmax(dim=1).eq(labels).sum().item()

        # The attacks need input gradients, so they must run outside no_grad().
        fgsm_imgs = fgsm_attack(model, loss_fn, imgs, labels, epsilon)
        with torch.no_grad():
            fgsm_correct += model(fgsm_imgs).argmax(dim=1).eq(labels).sum().item()

        pgd_imgs = pgd_attack(model, loss_fn, imgs, labels, epsilon, alpha, pgd_iters)
        with torch.no_grad():
            pgd_correct += model(pgd_imgs).argmax(dim=1).eq(labels).sum().item()

        total += imgs.size(0)
        eval_bar.set_postfix({
            'clean': f'{clean_correct/total:.4f}',
            'fgsm': f'{fgsm_correct/total:.4f}',
            'pgd': f'{pgd_correct/total:.4f}',
            'total': total
        })

    if total == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")

    return clean_correct / total, fgsm_correct / total, pgd_correct / total


def submit_model(token, model_name, model_path,
                 url="http://34.122.51.94:9090/robustness", timeout=None):
    """Upload a saved model file to the evaluation server (best-effort).

    Args:
        token: API token, sent in the ``token`` header.
        model_name: value sent in the ``model-name`` header.
        model_path: path of the file uploaded as the multipart field ``file``.
        url: endpoint to POST to (defaults to the original hard-coded server).
        timeout: passed to ``requests.post``. ``None`` (the default, matching the
            previous behaviour) waits indefinitely; pass a number of seconds to
            avoid hanging forever on an unresponsive server.

    Returns:
        The decoded JSON response, or ``None`` if opening the file, the request,
        or JSON decoding failed. Failures are printed rather than raised, by design.
    """
    # NOTE(review): the default URL is plain HTTP, so the token travels unencrypted.
    try:
        # Context manager guarantees the file handle is closed even if the request fails.
        with open(model_path, "rb") as model_file:
            response = requests.post(
                url,
                files={"file": model_file},
                headers={"token": token, "model-name": model_name},
                timeout=timeout,
            )
        result = response.json()
        print("Submission response:", result)
        return result
    except Exception as e:
        print(f"Submission failed: {e}")
        return None

In [None]:
if __name__ == '__main__':
    # ---- Configuration -------------------------------------------------------
    data_path = '/kaggle/input/tml-t3/Train.pt'
    model_name = 'resnet34'
    batch_size = 128
    epochs = 100
    lr = 0.01  # Lower learning rate for more stable training
    epsilon = 8/255
    alpha = 2/255
    pgd_iters = 10
    train_fraction = 0.9             # the remaining 10% is held out for validation
    split_seed = 42                  # seeds the split itself, so re-running this cell reproduces it
    min_clean_acc_to_save = 0.60     # never checkpoint a model below this clean accuracy
    score_weights = (0.7, 0.2, 0.1)  # (clean, FGSM, PGD) weights for model selection
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {device}")
    print(f"Model: {model_name}")
    print(f"Epsilon: {epsilon:.4f}, Alpha: {alpha:.4f}, PGD iters: {pgd_iters}")

    def combined_score_of(clean_acc, fgsm_acc, pgd_acc):
        """Model-selection score: weighted sum using ``score_weights`` (heavily clean-weighted)."""
        w_clean, w_fgsm, w_pgd = score_weights
        return w_clean * clean_acc + w_fgsm * fgsm_acc + w_pgd * pgd_acc

    # More conservative data augmentation.
    # NOTE(review): no transforms.Normalize is applied, so the ImageNet-pretrained
    # weights loaded below see raw [0, 1] pixels rather than ImageNet-normalized
    # input. The attacks also assume a [0, 1] range, so any normalization should
    # live inside the model rather than in these transforms.
    train_transform = transforms.Compose([
        transforms.Lambda(lambda img: img.convert('RGB') if hasattr(img, 'convert') else img),
        transforms.RandomHorizontalFlip(p=0.3),  # Reduced probability
        transforms.RandomCrop(32, padding=2),    # Reduced padding
        transforms.ToTensor(),
    ])

    val_transform = transforms.Compose([
        transforms.Lambda(lambda img: img.convert('RGB') if hasattr(img, 'convert') else img),
        transforms.ToTensor(),
    ])

    def collate_fn(batch):
        ids, imgs, labels = zip(*batch)
        imgs = torch.stack(imgs, dim=0)
        labels = torch.tensor(labels)
        return list(ids), imgs, labels

    # ---- Train / validation split -------------------------------------------
    full_dataset = TaskDataset(data_path, transform=train_transform)
    train_size = int(train_fraction * len(full_dataset))
    val_size = len(full_dataset) - train_size

    # A dedicated generator keeps the split independent of how much global RNG
    # state earlier (or re-run) cells have consumed.
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed),
    )

    # Validation view with the deterministic transform. A shallow copy shares the
    # already-loaded ids/imgs/labels, so the file is not loaded a second time.
    val_dataset_clean = copy.copy(full_dataset)
    val_dataset_clean.transform = val_transform
    val_subset = torch.utils.data.Subset(val_dataset_clean, val_dataset.indices)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=True
    )

    # ---- Model setup ---------------------------------------------------------
    model = getattr(models, model_name)(weights='DEFAULT')

    # Replace final layer with a fresh 10-way head (Xavier init for stability).
    if hasattr(model, 'fc'):
        model.fc = nn.Linear(model.fc.in_features, 10)
        nn.init.xavier_uniform_(model.fc.weight)
        nn.init.constant_(model.fc.bias, 0)

    model = model.to(device)

    # More conservative optimizer settings
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

    # More gradual learning rate schedule
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 85], gamma=0.2)

    loss_fn = nn.CrossEntropyLoss()

    # ---- Training ------------------------------------------------------------
    best_score = 0.0
    os.makedirs('out/models', exist_ok=True)

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_epoch': [],  # epoch index of each validation entry (validation is not run every epoch)
        'val_clean_acc': [],
        'val_fgsm_acc': [],
        'val_pgd_acc': [],
        'combined_score': []
    }

    print("Starting conservative adversarial training...")

    for epoch in range(1, epochs + 1):
        print(f"\n=== Epoch {epoch}/{epochs} ===")

        train_loss, train_acc = train_epoch(
            model, device, train_loader, optimizer, loss_fn, epoch, epochs,
            epsilon, alpha, pgd_iters
        )

        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Learning Rate: {scheduler.get_last_lr()[0]:.6f}")

        # Validate every epoch for the first 20, then every 5th epoch and the last one.
        if epoch % 5 == 0 or epoch == epochs or epoch <= 20:
            print("Evaluating on validation set...")
            clean_acc, fgsm_acc, pgd_acc = evaluate_model(
                model, device, val_loader, loss_fn, epsilon, alpha, pgd_iters
            )

            combined_score = combined_score_of(clean_acc, fgsm_acc, pgd_acc)

            history['val_epoch'].append(epoch)
            history['val_clean_acc'].append(clean_acc)
            history['val_fgsm_acc'].append(fgsm_acc)
            history['val_pgd_acc'].append(pgd_acc)
            history['combined_score'].append(combined_score)

            print(f"Val Clean Acc: {clean_acc:.4f}")
            print(f"Val FGSM Acc:  {fgsm_acc:.4f}")
            print(f"Val PGD Acc:   {pgd_acc:.4f}")
            print(f"Combined Score: {combined_score:.4f}")

            # Save best model with strong preference for clean accuracy
            if combined_score > best_score and clean_acc > min_clean_acc_to_save:
                best_score = combined_score
                torch.save(model.state_dict(), f"out/models/{model_name}_best.pt")
                print(f"🎯 New best model saved! Score: {combined_score:.4f}")

    # ---- Final evaluation ----------------------------------------------------
    print("\n=== Final Evaluation ===")
    clean_acc, fgsm_acc, pgd_acc = evaluate_model(
        model, device, val_loader, loss_fn, epsilon, alpha, pgd_iters
    )
    final_combined_score = combined_score_of(clean_acc, fgsm_acc, pgd_acc)

    print(f"Final Clean Accuracy: {clean_acc:.4f}")
    print(f"Final FGSM Accuracy:  {fgsm_acc:.4f}")
    print(f"Final PGD Accuracy:   {pgd_acc:.4f}")
    print(f"Final Combined Score: {final_combined_score:.4f}")
    print(f"Best Combined Score:  {best_score:.4f}")

    # Save final model and history
    torch.save(model.state_dict(), f"out/models/{model_name}_final.pt")
    with open("out/models/training_history.json", "w") as f:
        json.dump(history, f, indent=2)

    print("\nTraining complete!")