# **Install & Import Libraries**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torch.serialization import add_safe_globals

# **Dataset Preparation**

In [None]:
# Train.pt was saved as a pickled TaskDataset instance, so this class must be
# defined (same name, same attribute names) before torch.load can restore it.
class TaskDataset(Dataset):
    """Indexable (image, label) dataset backing the pickled Train.pt file.

    ``__getitem__`` and ``__len__`` read the ``imgs``, ``labels`` and
    ``transform`` attributes. These names must match the state stored in the
    pickle, so do not rename them. ``ids`` is kept for the same reason.
    """

    def __init__(self, transform=None):
        self.ids, self.imgs, self.labels = [], [], []
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, label)`` at ``idx``.

        The image is converted to RGB first when its ``mode`` differs (a
        PIL-style ``mode``/``convert`` API is assumed). It is then passed
        through ``self.transform`` when one is set.
        """
        sample = self.imgs[idx]
        target = self.labels[idx]
        rgb = sample if sample.mode == "RGB" else sample.convert("RGB")
        return (self.transform(rgb) if self.transform else rgb), target

    def __len__(self):
        """Number of stored images."""
        return len(self.imgs)

# Register the class with torch's weights_only unpickler allow-list.
# add_safe_globals expects a list of class objects. A dict would only
# register its string key "TaskDataset", which is not a usable global.
add_safe_globals([TaskDataset])

# Load the dataset object from the .pt file.
# NOTE(security): weights_only=False uses the full pickle loader, which can
# execute arbitrary code embedded in the file. Only load a trusted Train.pt.
dataset = torch.load("Train.pt", weights_only=False)

# Add transform: resize to 32x32, then convert to a float tensor in [0, 1].
# ToTensor only rescales; no mean/std normalization is applied. This keeps
# inputs in the [0, 1] range that the FGSM/PGD attacks clamp to.
transform_fn = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
dataset.transform = transform_fn

# Data loader: shuffled mini-batches of 32 samples.
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# **FGSM and PGD Attack**

In [None]:
# FGSM attack function
def fgsm(model, x, y, eps):
    """Craft FGSM adversarial examples with one signed-gradient step.

    Args:
        model: classifier returning logits for ``x``.
        x: input batch. It is not modified.
        y: class-index targets for ``nn.CrossEntropyLoss``.
        eps: L-infinity step size.

    Returns:
        A detached tensor ``clamp(x + eps * sign(d loss / d x), 0, 1)``.
    """
    # Work on a detached copy. The caller's tensor is left untouched, and
    # non-leaf inputs (which cannot have requires_grad set) are accepted.
    x_adv = x.detach().clone().requires_grad_(True)
    loss = nn.CrossEntropyLoss()(model(x_adv), y)
    # autograd.grad computes only d(loss)/d(x_adv). Unlike loss.backward(),
    # it leaves the parameters' .grad untouched, so crafting the attack
    # cannot leak gradients into a later optimizer step.
    (grad,) = torch.autograd.grad(loss, x_adv)
    return torch.clamp(x_adv.detach() + eps * grad.sign(), 0, 1)

# PGD attack function
def pgd(model, x, y, eps=0.03, alpha=0.01, steps=3):
    """Craft L-infinity PGD adversarial examples.

    The attack starts from ``x`` plus small Gaussian noise (std 0.001). It
    then takes ``steps`` signed-gradient ascent steps of size ``alpha`` on the
    cross-entropy loss. After every step it projects back into the
    ``eps``-ball around ``x`` and into [0, 1].

    Args:
        model: classifier returning logits.
        x: clean input batch. It is not modified.
        y: class-index targets for ``nn.CrossEntropyLoss``.
        eps: L-infinity radius of the perturbation.
        alpha: size of each ascent step.
        steps: number of ascent steps. ``0`` returns the projected random start.

    Returns:
        A batch in [0, 1] that does not require grad. When ``x`` is itself in
        [0, 1], ``|x_adv - x| <= eps`` holds elementwise.
    """
    loss_fn = nn.CrossEntropyLoss()
    x_orig = x.detach().clone()
    # Project the random start like every later iterate. The first forward
    # pass, and the steps=0 result, then already respect the constraints.
    noise = torch.clamp(0.001 * torch.randn_like(x_orig), -eps, eps)
    x_adv = torch.clamp(x_orig + noise, 0, 1)

    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = loss_fn(model(x_adv), y)
        # autograd.grad returns only d(loss)/d(x_adv). Unlike loss.backward(),
        # it does not write into the model parameters' .grad.
        (grad,) = torch.autograd.grad(loss, x_adv)
        delta = torch.clamp(x_adv.detach() + alpha * grad.sign() - x_orig, -eps, eps)
        x_adv = torch.clamp(x_orig + delta, 0, 1).detach()

    return x_adv

# **Model Training**

In [None]:
# Create a ResNet34 model with a 10-class output layer. Only ``fc`` is
# replaced, so the state_dict still loads strictly into models.resnet34.
model = models.resnet34(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)

# Device and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

# Training settings
epochs = 20
epsilon = 0.03  # L-infinity attack budget shared by FGSM and PGD
alpha = 0.01  # PGD step size
pgd_steps = 3

# Training loop: each batch is optimized on the equally weighted mean of the
# clean, FGSM and PGD cross-entropy losses.
for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    for x_batch, y_batch in data_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # Clean loss
        out_clean = model(x_batch)
        loss_clean = loss_fn(out_clean, y_batch)

        # FGSM loss
        x_fgsm = fgsm(model, x_batch, y_batch, epsilon)
        out_fgsm = model(x_fgsm)
        loss_fgsm = loss_fn(out_fgsm, y_batch)

        # PGD loss
        x_pgd = pgd(model, x_batch, y_batch, epsilon, alpha, pgd_steps)
        out_pgd = model(x_pgd)
        loss_pgd = loss_fn(out_pgd, y_batch)

        # Combine and backprop. Gradients are cleared here, after the attacks,
        # because the attack helpers run their own backward passes. Those can
        # leave gradients on the model parameters. If zeroing happened before
        # the attacks, those attack-crafting gradients would be added to this
        # update.
        loss_total = (loss_clean + loss_fgsm + loss_pgd) / 3
        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

        total_loss += loss_total.item()

    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
    avg_loss = total_loss / max(len(data_loader), 1)
    print(f"Epoch {epoch+1:02d}/{epochs} | Avg Loss: {avg_loss:.4f}")

# Save model
torch.save(model.state_dict(), "robust_model.pt")
print("Model saved to 'robust_model.pt'")

# **Evaluation Script to Load and Validate Model**

In [None]:
##### Evaluation Script #####
# Architectures the evaluator is willing to instantiate, keyed by name.
allowed_models = {"resnet18": models.resnet18, "resnet34": models.resnet34, "resnet50": models.resnet50}

with open("robust_model.pt", "rb") as f:
    # Rebuild the architecture with the same 10-class head used in training.
    try:
        eval_model = allowed_models["resnet34"](weights=None)
        eval_model.fc = nn.Linear(eval_model.fc.in_features, 10)
    except Exception as err:
        raise Exception(
            f"Model architecture not permitted. {err=}, allowed: {allowed_models.keys()}"
        ) from err
    # Load the weights strictly, then run one dummy forward pass as a smoke test.
    try:
        # weights_only=True: the file should hold only a state_dict of tensors,
        # so refuse to unpickle arbitrary objects.
        weights = torch.load(f, map_location=torch.device("cpu"), weights_only=True)
        eval_model.load_state_dict(weights, strict=True)
        eval_model.eval()
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            eval_model(torch.randn(1, 3, 32, 32))
    except Exception as err:
        raise Exception(f"Model failed to load or execute. {err=}") from err

# **Submission**

In [None]:
# --- Submission ---
import requests

# NOTE(review): the API token is hard-coded and is sent over plain HTTP.
# Consider supplying it from an environment variable rather than committing
# it with the notebook.
SUBMISSION_URL = "http://34.122.51.94:9090/robustness"

# Open the file in a context manager so the handle is closed after upload.
# The (connect, read) timeout prevents the call from hanging forever on an
# unresponsive server.
with open("robust_model.pt", "rb") as model_file:
    response = requests.post(
        SUBMISSION_URL,
        files={"file": model_file},
        headers={"token": "93145372", "model-name": "resnet34"},
        timeout=(10, 300),
    )

try:
    print("Submission response:", response.json())
except ValueError:
    # The server replied with something other than JSON, e.g. an HTML error page.
    print("Submission response (non-JSON):", response.status_code, response.text)