In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/cornet-s_btc_mnist_50/pytorch/default/1/mnist_manual_backup_epoch50.pth


In [1]:
!git clone https://github.com/dicarlolab/CORnet.git
# Navigate to the cloned repository folder
import os
os.chdir('/kaggle/working/CORnet')

# Install the package if needed
!pip install .

Cloning into 'CORnet'...
remote: Enumerating objects: 155, done.[K
remote: Counting objects: 100% (20/20), done.[K
remote: Compressing objects: 100% (12/12), done.[K
remote: Total 155 (delta 13), reused 9 (delta 8), pack-reused 135 (from 1)[K
Receiving objects: 100% (155/155), 68.11 KiB | 8.51 MiB/s, done.
Resolving deltas: 100% (87/87), done.
Processing /kaggle/working/CORnet
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting fire (from CORnet==0.1.0)
  Downloading fire-0.7.0.tar.gz (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=0.4.0->CORnet==0.1.0)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=0.4.0->CORnet==0.1.0)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_

In [2]:
from cornet import cornet_s

model = cornet_s(pretrained=False)
print (model)

DataParallel(
  (module): Sequential(
    (V1): Sequential(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (nonlin1): ReLU(inplace=True)
      (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (nonlin2): ReLU(inplace=True)
      (output): Identity()
    )
    (V2): CORblock_S(
      (conv_input): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (skip): Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (norm_skip): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (no

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
from cornet import cornet_s

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Gaussian blur function
def apply_gaussian_blur(img, sigma):
    return cv2.GaussianBlur(img, (5, 5), sigmaX=sigma, sigmaY=sigma)

# Custom MNIST dataset with blur
class BlurryMNIST(Dataset):
    def __init__(self, train=True, sigma=0):
        self.dataset = torchvision.datasets.MNIST(root="./data", train=train, download=True)
        self.sigma = sigma
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.446], std=[0.2023, 0.1994, 0.2010])
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        img = np.array(img)
        if self.sigma > 0:
            img = apply_gaussian_blur(img, self.sigma)
        img = Image.fromarray(img)
        img = self.transform(img)
        return img, label

# CORnet-S model
model = cornet_s(pretrained=False).to(device)
model.module.decoder.linear = nn.Linear(model.module.decoder.linear.in_features, 10).to(device)

# Training settings
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch_size = 64
num_epochs = 10
# Updated sigma schedule: 3 epochs with σ=2, 3 epochs with σ=1, 4 epochs with σ=0
sigma_schedule = [2]*3 + [1]*3 + [0]*4

Using device: cuda


In [4]:
# --------------------- CHECKPOINT SETUP ---------------------
checkpoint_dir = "./mnist_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

checkpoint_path = "/kaggle/input/mnist_epoch40/pytorch/default/1/mnist_manual_backup_epoch40.pth"
start_epoch = 0
val_acc_saved = 0.0

# Load test set
test_dataset = BlurryMNIST(train=False, sigma=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# --------------------- LOAD CHECKPOINT ---------------------
if os.path.exists(checkpoint_path):
    print("Resuming from checkpoint...")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # ✅ epoch = 39 → start_epoch = 40
    val_acc_saved = checkpoint.get('val_acc', 0.0)
    print(f"Resumed from epoch {start_epoch}, saved val acc: {val_acc_saved:.4f}")

    # Run one validation pass to confirm
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Quick Validation Check"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    val_acc_check = correct / total
    print(f"🔍 Confirmed resumed model val accuracy: {val_acc_check:.4f}")
else:
    print("No checkpoint found. Starting fresh.")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 37.3MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.12MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 9.11MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.87MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

No checkpoint found. Starting fresh.





In [5]:
# --------------------- TRAINING LOOP ---------------------
for epoch in range(start_epoch, num_epochs):
    # Inside training loop (no need to change anything else)
    sigma = sigma_schedule[epoch]
    print(f"\nEpoch {epoch+1}/{num_epochs} - Blur σ={sigma}")

    train_dataset = BlurryMNIST(train=True, sigma=sigma)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Training phase
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    train_pbar = tqdm(train_loader, desc="Training", leave=False)

    for images, labels in train_pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        train_pbar.set_postfix(loss=loss.item(), acc=100. * correct / total)

    train_acc = correct / total
    print(f"Train Loss: {running_loss / len(train_loader):.4f}, Train Acc: {train_acc:.4f}")

    # Validation phase
    model.eval()
    correct, total = 0, 0
    val_pbar = tqdm(test_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for images, labels in val_pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            val_pbar.set_postfix(acc=100. * correct / total)

    val_acc = correct / total
    print(f"Validation Top-1 Accuracy: {val_acc:.4f}")

    # Save checkpoint
    save_path = os.path.join(checkpoint_dir, "mnist_checkpoint.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc
    }, save_path)
    print(f"📦 Checkpoint saved to {save_path}")


Epoch 1/10 - Blur σ=2


                                                                                   

Train Loss: 0.1651, Train Acc: 0.9504


                                                                       

Validation Top-1 Accuracy: 0.4977
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth

Epoch 2/10 - Blur σ=2


                                                                                   

Train Loss: 0.0634, Train Acc: 0.9804


                                                                       

Validation Top-1 Accuracy: 0.8705
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth

Epoch 3/10 - Blur σ=2


                                                                                    

Train Loss: 0.0517, Train Acc: 0.9843


                                                                       

Validation Top-1 Accuracy: 0.7146
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth

Epoch 4/10 - Blur σ=1


                                                                                    

Train Loss: 0.0372, Train Acc: 0.9882


                                                                       

Validation Top-1 Accuracy: 0.8308
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth

Epoch 5/10 - Blur σ=1


                                                                                    

Train Loss: 0.0330, Train Acc: 0.9895


                                                                       

Validation Top-1 Accuracy: 0.9545
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth

Epoch 6/10 - Blur σ=1


                                                                                    

Train Loss: 0.0311, Train Acc: 0.9903


                                                                       

Validation Top-1 Accuracy: 0.9783
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth

Epoch 7/10 - Blur σ=0


                                                                                    

Train Loss: 0.0273, Train Acc: 0.9914


                                                                       

Validation Top-1 Accuracy: 0.9791
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth

Epoch 8/10 - Blur σ=0


                                                                                    

Train Loss: 0.0228, Train Acc: 0.9928


                                                                       

Validation Top-1 Accuracy: 0.9882
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth

Epoch 9/10 - Blur σ=0


                                                                                    

Train Loss: 0.0219, Train Acc: 0.9930


                                                                       

Validation Top-1 Accuracy: 0.9941
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth

Epoch 10/10 - Blur σ=0


                                                                                    

Train Loss: 0.0189, Train Acc: 0.9937


                                                                       

Validation Top-1 Accuracy: 0.9936
📦 Checkpoint saved to ./mnist_checkpoints/mnist_checkpoint.pth


In [6]:
import torch
import os

# Ensure checkpoint directory exists
checkpoint_dir = "./mnist_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Path to save manual checkpoint for epoch 50
manual_save_path = os.path.join(checkpoint_dir, "mnist_manual_backup_epoch10.pth")

# Set epoch to 49 since it's 0-indexed
manual_epoch = 9

# Re-evaluate validation accuracy before saving
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

val_acc = correct / total
print(f"✅ Final validation accuracy at epoch {manual_epoch+1}: {val_acc:.4f}")

# Save full checkpoint
torch.save({
    'epoch': manual_epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'val_acc': val_acc
}, manual_save_path)

print(f"📦 Manually saved checkpoint to: {manual_save_path}")

✅ Final validation accuracy at epoch 10: 0.9936
📦 Manually saved checkpoint to: ./mnist_checkpoints/mnist_manual_backup_epoch10.pth


# Perturbation Budgeting using CW and PGD

In [4]:
# Load the model
import torch
from tqdm import tqdm

# --------------------- MODEL LOADING ---------------------

# Set path to your saved model
checkpoint_path = "/kaggle/input/cornet-s_btc_mnist_50/pytorch/default/1/mnist_manual_backup_epoch50.pth"

# Initialize CORnet-S model for MNIST
model = cornet_s(pretrained=False).to(device)
model.module.decoder.linear = nn.Linear(model.module.decoder.linear.in_features, 10).to(device)

# Load model checkpoint
print(f"🔄 Loading model from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✅ Model loaded successfully from epoch {checkpoint['epoch'] + 1}")

# --------------------- VALIDATE MODEL ---------------------

# Load clean test set
test_dataset = BlurryMNIST(train=False, sigma=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Run validation once to confirm model is good
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Validating Loaded Model"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

val_acc = correct / total
print(f"🎯 Validation Accuracy on Clean Test Set: {val_acc:.4f}")

🔄 Loading model from: /kaggle/input/cornet-s_btc_mnist_50/pytorch/default/1/mnist_manual_backup_epoch50.pth


  checkpoint = torch.load(checkpoint_path)


✅ Model loaded successfully from epoch 50


Validating Loaded Model: 100%|██████████| 157/157 [00:50<00:00,  3.14it/s]

🎯 Validation Accuracy on Clean Test Set: 0.9926





In [7]:
!pip install torchattacks

Collecting torchattacks
  Downloading torchattacks-3.5.1-py3-none-any.whl.metadata (927 bytes)
Collecting requests~=2.25.1 (from torchattacks)
  Downloading requests-2.25.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting chardet<5,>=3.0.2 (from requests~=2.25.1->torchattacks)
  Downloading chardet-4.0.0-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting idna<3,>=2.5 (from requests~=2.25.1->torchattacks)
  Downloading idna-2.10-py2.py3-none-any.whl.metadata (9.1 kB)
Collecting urllib3<1.27,>=1.21.1 (from requests~=2.25.1->torchattacks)
  Downloading urllib3-1.26.20-py2.py3-none-any.whl.metadata (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.1/50.1 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
Downloading torchattacks-3.5.1-py3-none-any.whl (142 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.0/142.0 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading requests-2.25.1-py2.py3-none-any.whl (61 kB)
[2K   [90m

In [9]:
import torchattacks
from tqdm import tqdm

# ---------------- PGD Attack Setup ----------------
pgd_eps = 0.01         # Changeable
pgd_alpha = 2/255
pgd_steps = 40

pgd_attack = torchattacks.PGD(model, eps=pgd_eps, alpha=pgd_alpha, steps=pgd_steps)

# ---------------- PGD Adversarial Accuracy Function ----------------
def adversarial_test_pgd(attack, loader):
    model.eval()
    correct = 0
    total = 0
    
    for inputs, labels in tqdm(loader, desc=f"Adversarial Test (PGD)"):
        inputs, labels = inputs.to(device), labels.to(device)
        adv_inputs = attack(inputs, labels)
        outputs = model(adv_inputs)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
    
    acc = correct / total
    print(f"📊 PGD Accuracy (ε={pgd_eps}): {acc:.4f}")
    return acc

# ---------------- Run PGD Evaluation ----------------
pgd_acc = adversarial_test_pgd(pgd_attack, test_loader)

Adversarial Test (PGD): 100%|██████████| 157/157 [53:28<00:00, 20.44s/it]

📊 PGD Accuracy (ε=0.01): 0.0958





In [10]:
import torchattacks
from tqdm import tqdm

# ---------------- CW Attack Setup ----------------
cw_c = 1e-3
cw_kappa = 0
cw_steps = 100
cw_lr = 0.01

cw_attack = torchattacks.CW(model, c=cw_c, kappa=cw_kappa, steps=cw_steps, lr=cw_lr)

# ---------------- CW Adversarial Accuracy Function ----------------
def adversarial_test_cw(attack, loader):
    model.eval()
    correct = 0
    total = 0

    for inputs, labels in tqdm(loader, desc=f"Adversarial Test (CW)"):
        inputs, labels = inputs.to(device), labels.to(device)
        adv_inputs = attack(inputs, labels)
        outputs = model(adv_inputs)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    print(f"📊 CW Accuracy (c={cw_c}, kappa={cw_kappa}): {acc:.4f}")
    return acc

# ---------------- Run CW Evaluation ----------------
cw_acc = adversarial_test_cw(cw_attack, test_loader)

Adversarial Test (CW): 100%|██████████| 157/157 [1:37:54<00:00, 37.41s/it]

📊 CW Accuracy (c=0.001, kappa=0): 0.1167





In [None]:
!pip install -q torchattacks
