In [1]:
import copy
import getpass
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import transforms, models
from tqdm import tqdm

In [2]:
import gdown

Train_id = '1-iTQMYNtj4GlJW5k1q73OvDdsoemXSdr'

# Download the training data only when it is missing locally. This replaces a
# hand-toggled flag, so a fresh checkout can Restart & Run All without edits.
downloaded = Path('Train.pt').exists()

if not downloaded:
    gdown.download(f'https://drive.google.com/uc?id={Train_id}', 'Train.pt', quiet=False)

In [3]:
# Define the dataset class
class TaskDataset(Dataset):
    """Map-style dataset that pairs indexable images with their labels.

    Args:
        imgs: indexable collection of images (in whatever form `transform` accepts).
        labels: indexable collection of labels, aligned index-by-index with `imgs`.
        transform: optional callable applied to an image each time it is fetched.
    """

    def __init__(self, imgs, labels, transform=None):
        self.imgs, self.labels, self.transform = imgs, labels, transform

    def __len__(self):
        # The dataset size is taken from the images; labels are assumed to match.
        return len(self.imgs)

    def __getitem__(self, idx):
        sample = self.imgs[idx]
        target = self.labels[idx]
        return (self.transform(sample) if self.transform else sample), target


In [4]:

# Per-image preprocessing: convert to grayscale replicated across 3 channels
# (so images have the 3-channel layout a ResNet's first conv expects), then
# convert to a tensor.
_transform_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
]
transform = transforms.Compose(_transform_steps)

In [5]:
# NOTE: `Train.pt` holds a pickled Python object exposing `.imgs` / `.labels`, and
# torch.load unpickles it, so only load it from a trusted source. On newer torch
# releases whose default is `weights_only=True`, this load may need
# `weights_only=False`.
data = torch.load('Train.pt')

full_dataset = TaskDataset(data.imgs, data.labels, transform=transform)

# 90/10 train/validation split. The fixed-seed generator makes the split identical
# across runs, so validation scores from different sessions stay comparable.
SPLIT_SEED = 0
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [13]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# ImageNet-pretrained ResNet-34 with a fresh 10-way classification head.
# `pretrained=True` is deprecated in favour of the `weights=` enum, which the
# submission check in this notebook already uses. IMAGENET1K_V1 is the checkpoint
# that `pretrained=True` loaded.
# NOTE(review): these weights were trained on mean/std-normalized inputs, while
# `transform` only applies Grayscale + ToTensor. Confirm that is intended.
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

In [14]:
def pgd(model, X, y, epsilon=0.01, alpha=0.001, num_iter=10, randomize=False):
    """Compute an L-infinity PGD perturbation that increases cross-entropy loss.

    Takes `num_iter` signed-gradient ascent steps of size `alpha` on the input
    perturbation and clamps it to [-epsilon, epsilon] after every step.
    Gradients are taken with respect to the perturbation only (via
    `torch.autograd.grad`), so the model parameters' `.grad` buffers are left
    untouched.

    Args:
        model: classifier returning logits compatible with `nn.CrossEntropyLoss`.
        X: input batch.
        y: target class indices for `X`.
        epsilon: maximum absolute value of each perturbation element.
        alpha: step size of each ascent step.
        num_iter: number of ascent steps.
        randomize: if True, start from uniform noise in [-epsilon, epsilon]
            instead of zeros.

    Returns:
        A detached tensor shaped like `X`. Note that `X + delta` is NOT clamped
        to any valid pixel range.
    """
    # Local loss so the attack does not depend on the notebook's global `criterion`
    # (consistent with `fgsm`).
    loss_fn = nn.CrossEntropyLoss()
    if randomize:
        delta = torch.empty_like(X).uniform_(-epsilon, epsilon)
    else:
        delta = torch.zeros_like(X)
    for _ in range(num_iter):
        delta.requires_grad_(True)
        loss = loss_fn(model(X + delta), y)
        (grad,) = torch.autograd.grad(loss, delta)
        delta = (delta.detach() + alpha * grad.sign()).clamp(-epsilon, epsilon)
    return delta.detach()

In [15]:
def fgsm(model, X, y, epsilon=0.01):
    """Compute a single-step FGSM perturbation `epsilon * sign(grad_X loss)`.

    The gradient is taken with respect to the perturbation only (via
    `torch.autograd.grad`), so the model parameters' `.grad` buffers are not
    modified.

    Args:
        model: classifier returning logits compatible with `nn.CrossEntropyLoss`.
        X: input batch.
        y: target class indices for `X`.
        epsilon: magnitude of every non-zero perturbation element.

    Returns:
        A tensor shaped like `X` with elements in {-epsilon, 0, epsilon} that
        does not require grad. `X + delta` is not clamped to a pixel range.
    """
    delta = torch.zeros_like(X, requires_grad=True)
    loss = nn.CrossEntropyLoss()(model(X + delta), y)
    (grad,) = torch.autograd.grad(loss, delta)
    return epsilon * grad.sign()

In [16]:
def evaluate_model(model, data_loader, device, attack=None, epsilon=0.01, criterion=None):
    """Compute the average loss and accuracy of `model` over `data_loader`.

    Args:
        model: classifier returning per-class logits.
        data_loader: iterable of (imgs, labels) batches that also supports len()
            (used for the progress bar).
        device: device each batch is moved to.
        attack: optional callable `attack(model, imgs, labels, epsilon)` returning
            a perturbation that is added to `imgs` before the forward pass
            (e.g. `fgsm` or `pgd`).
        epsilon: perturbation budget forwarded to `attack`.
        criterion: loss function with mean reduction. Defaults to
            `nn.CrossEntropyLoss()`.

    Returns:
        (avg_loss, accuracy). The loss is averaged per sample and accuracy is in
        [0, 1]. Both are NaN if the loader yields no samples.

    The model's train/eval mode is restored on exit.
    """
    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion
    was_training = model.training
    model.eval()
    correct_predictions = 0
    total_predictions = 0
    loss_sum = 0.0

    try:
        with torch.no_grad():
            with tqdm(total=len(data_loader), desc='Evaluating', unit='batch') as pbar:
                for imgs, labels in data_loader:
                    imgs, labels = imgs.to(device), labels.to(device)

                    if attack is not None:
                        # Re-enable autograd inside the outer no_grad; the attack
                        # needs input gradients.
                        with torch.enable_grad():
                            delta = attack(model, imgs, labels, epsilon)
                        imgs = imgs + delta

                    outputs = model(imgs)
                    batch_size = labels.size(0)
                    # Weight each batch's mean loss by its size, so a short final
                    # batch does not skew the average.
                    loss_sum += loss_fn(outputs, labels).item() * batch_size
                    correct_predictions += (outputs.argmax(dim=1) == labels).sum().item()
                    total_predictions += batch_size

                    pbar.update(1)
    finally:
        model.train(was_training)

    if total_predictions == 0:
        return float('nan'), float('nan')
    return loss_sum / total_predictions, correct_predictions / total_predictions

In [None]:
# Quick sanity check: is a CUDA device visible to this kernel?
cuda_available = torch.cuda.is_available()
print(cuda_available)

In [None]:
def train_model(model, train_loader, val_loader, device, num_epochs=10, patience=3):
    """Adversarially train `model` with PGD, keeping the weights with the best robust accuracy.

    Every training batch is perturbed with `pgd(model, imgs, labels)` (using pgd's
    default epsilon/alpha/num_iter) before the optimizer step. After each epoch the
    model is evaluated on `val_loader` twice, under PGD and clean. Checkpoint
    selection and early stopping use the PGD ("Val") accuracy.

    NOTE(review): uses the notebook globals `optimizer` and `criterion` from the
    model-setup cell; they must have been built for this same `model`.

    Args:
        model: network to train. It is modified in place.
        train_loader: DataLoader of (imgs, labels) training batches. Must be non-empty.
        val_loader: DataLoader of (imgs, labels) validation batches.
        device: device each batch is moved to (the model is assumed to already be there).
        num_epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without a new best
            robust validation accuracy.

    Returns:
        `model` itself, with the best-epoch weights loaded.

    Raises:
        ValueError: if `train_loader` yields no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches")

    best_model_wts = copy.deepcopy(model.state_dict())
    # Start below any achievable accuracy, so the first epoch is always recorded.
    best_acc = float('-inf')
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        with tqdm(total=len(train_loader), desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
            for batch_idx, (imgs, labels) in enumerate(train_loader, start=1):
                imgs, labels = imgs.to(device), labels.to(device)

                # Build the adversarial perturbation first. zero_grad() below
                # clears any parameter gradients the attack may have left behind.
                delta = pgd(model, imgs, labels)

                optimizer.zero_grad()
                outputs = model(imgs + delta)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                correct_predictions += (outputs.argmax(dim=1) == labels).sum().item()
                total_predictions += labels.size(0)

                pbar.update(1)
                # Running mean of per-batch losses: divide by the number of batches
                # seen so far. Dividing by samples / current batch size is wrong
                # when the final batch is short.
                pbar.set_postfix({'loss': running_loss / batch_idx,
                                  'accuracy': correct_predictions / total_predictions})

        avg_train_loss = running_loss / len(train_loader)
        train_accuracy = correct_predictions / total_predictions

        val_loss, val_accuracy = evaluate_model(model, val_loader, device, attack=pgd)
        clean_loss, clean_accuracy = evaluate_model(model, val_loader, device)

        print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Clean Loss: {clean_loss:.4f}, Clean Accuracy: {clean_accuracy:.4f}')

        if val_accuracy > best_acc:
            best_acc = val_accuracy
            best_model_wts = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print('Early stopping')
                break

    model.load_state_dict(best_model_wts)
    return model

# Train the model with early stopping
best_model = train_model(model, train_loader, val_loader, device, num_epochs=10, patience=3)


In [None]:
# Final report: clean, FGSM and PGD results for the selected model.
def evaluate_all(model, val_loader, device):
    """Evaluate `model` on `val_loader` clean, under FGSM, and under PGD, then print a summary.

    Each evaluation announces itself as it starts; the three result lines are
    printed together once all evaluations have finished.
    """
    # (progress heading, accuracy label, loss label, attack or None for clean)
    settings = (
        ("Evaluating Clean Accuracy", "Clean Accuracy", "Clean Loss", None),
        ("Evaluating Robust Accuracy (FGSM)", "Robust Accuracy (FGSM)", "FGSM Loss", fgsm),
        ("Evaluating Robust Accuracy (PGD)", "Robust Accuracy (PGD)", "PGD Loss", pgd),
    )
    summary_lines = []
    for heading, acc_label, loss_label, attack_fn in settings:
        print(heading)
        loss_value, acc_value = evaluate_model(model, val_loader, device, attack=attack_fn)
        summary_lines.append(f'{acc_label}: {acc_value:.4f} \t {loss_label}: {loss_value:.4f}')

    for line in summary_lines:
        print(line)

# Evaluate on the validation set
evaluate_all(best_model, val_loader, device)

In [None]:
# Save the selected model's state_dict. `train_model` returns `model` itself with
# the best-epoch weights loaded, so `best_model` and `model` are the same object.
# Create the output directory first, so a fresh checkout does not fail here.
Path("out/models").mkdir(parents=True, exist_ok=True)
torch.save(best_model.state_dict(), "out/models/models4.pt")

In [None]:
#### SUBMISSION ####

# # Create a dummy model
# model = models.resnet18(weights=None)
# model.fc = nn.Linear(model.fc.weight.shape[1], 10)
# torch.save(model.state_dict(), "out/models/dummy_submission.pt")

#### Tests ####
# (these are run on the eval endpoint for every submission)
# Local replica of the endpoint's sanity check, kept as-is so it mirrors the grader:
# rebuild an untrained resnet34 with a 10-way head, strictly load the saved
# state_dict on CPU, and check that a single 1x3x32x32 forward pass yields shape
# (1, 10). Note that this rebinds the notebook's global `model`.

# Architectures accepted by the check (presumably selected by the "model-name"
# header sent with the submission; verify against the endpoint).
allowed_models = {
    "resnet18": models.resnet18,
    "resnet34": models.resnet34,
    "resnet50": models.resnet50,
}
with open("out/models/models4.pt", "rb") as f:
    try:
        # Fresh, untrained architecture; weights come only from the saved file.
        model: torch.nn.Module = allowed_models["resnet34"](weights=None)
        model.fc = torch.nn.Linear(model.fc.weight.shape[1], 10)
    except Exception as e:
        raise Exception(
            f"Invalid model class, {e=}, only {allowed_models.keys()} are allowed",
        )
    try:
        # strict=True: every key in the saved state_dict must match the architecture.
        state_dict = torch.load(f, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        out = model(torch.randn(1, 3, 32, 32))
    except Exception as e:
        raise Exception(f"Invalid model, {e=}")

    assert out.shape == (1, 10), "Invalid output shape"




In [None]:
# Send the model to the server
import requests

# Submit the checkpoint that this notebook saves and validates above
# (out/models/models4.pt), so a Restart & Run All uploads the model just trained.
submission_path = "out/models/models4.pt"

# Read the submission token from the environment, or prompt for it, instead of
# hard-coding it, so it does not leak when the notebook is shared or committed.
token = os.environ.get("ROBUSTNESS_TOKEN") or getpass.getpass("Robustness endpoint token: ")

# The context manager closes the file handle once the upload finishes.
with open(submission_path, "rb") as model_file:
    response = requests.post(
        "http://34.71.138.79:9090/robustness",
        files={"file": model_file},
        headers={"token": token, "model-name": "resnet34"},
    )

# Per the original template, a 400 status indicates a rejected submission
# (e.g. clean accuracy too low).
print(response.status_code)
print(response.json())