In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from PIL import Image, ImageFilter
from tqdm import tqdm
import os

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Hyperparameters ---
learning_rate = 0.01
momentum = 0.9
weight_decay = 1e-4
batch_size = 128
num_epochs_total = 50
epochs_per_blur = 10
blur_levels = [8, 4, 2, 1, 0]

# --- Gaussian Blur Wrapper ---
class GaussianBlur:
    def __init__(self, sigma):
        self.sigma = sigma

    def __call__(self, img):
        return img.filter(ImageFilter.GaussianBlur(self.sigma))

# --- DataLoader Generator ---
def get_dataloader(blur_level, batch_size):
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.CenterCrop((256, 256)),
        transforms.Resize((224, 224)),
        transforms.Lambda(lambda img: GaussianBlur(blur_level)(img)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Test Loader (no blur) ---
test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.CenterCrop((256, 256)),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# --- Compute Validation Accuracy with BN-safe Evaluation ---
def evaluate_accuracy(model, val_loader, device):
    model.train()  # Avoid BatchNorm instability
    correct = 0
    total = 0
    val_pbar = tqdm(val_loader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for inputs, labels in val_pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            if torch.isnan(outputs).any():
                print("⚠️ NaNs detected, skipping batch.")
                continue
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            acc = 100 * correct / total if total > 0 else 0
            val_pbar.set_postfix(acc=f"{acc:.2f}%")
    return 100 * correct / total if total > 0 else 0

# --- Training Loop with single rolling checkpoint ---
def train_model(model, criterion, optimizer, train_loader, val_loader, device,
                num_epochs=10, blur_level=0, current_epoch_offset=0):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Blur {blur_level})", leave=False)
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        acc = evaluate_accuracy(model, val_loader, device)
        print(f"✅ Epoch {current_epoch_offset + epoch + 1} done | Loss: {running_loss:.4f} | Accuracy: {acc:.2f}%")

        # Overwrite the same checkpoint file
        checkpoint = {
            'epoch': current_epoch_offset + epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'blur_level': blur_level
        }
        torch.save(checkpoint, "btc_checkpoint.pth")
        print(f"💾 Checkpoint saved: btc_checkpoint.pth (epoch {current_epoch_offset + epoch + 1})")

    return model

# --- BTC Trainer with resume ---
def adjust_blur_and_train(model):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                          momentum=momentum, weight_decay=weight_decay)

    # Resume if btc_checkpoint.pth exists
    start_epoch = 0
    start_blur = 0

    if os.path.exists("btc_checkpoint.pth"):
        print(f"🔄 Found checkpoint: btc_checkpoint.pth")
        checkpoint = torch.load("btc_checkpoint.pth")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_blur = checkpoint['blur_level']
        print(f"Resuming from blur {start_blur}, epoch {start_epoch}")

    current_epoch = start_epoch
    blur_index = blur_levels.index(start_blur) if start_epoch > 0 else 0

    for blur in blur_levels[blur_index:]:
        if current_epoch >= num_epochs_total:
            break
        print(f"\n--- Training with blur level {blur} ---")
        epochs_to_train = min(epochs_per_blur, num_epochs_total - current_epoch)
        train_loader = get_dataloader(blur, batch_size)
        model = train_model(model, criterion, optimizer, train_loader, test_loader, device,
                            num_epochs=epochs_to_train, blur_level=blur, current_epoch_offset=current_epoch)
        current_epoch += epochs_to_train

    return model

# --- Main ---
if __name__ == '__main__':
    resnet = models.resnet18(pretrained=False)
    resnet.fc = nn.Linear(resnet.fc.in_features, 10)
    resnet = resnet.to(device)

    print("🚀 Starting BTC Training for ResNet-18...")
    trained_model = adjust_blur_and_train(resnet)

    torch.save(trained_model.state_dict(), "resnet18_btc_mnist_final.pth")
    print("✅ Final model saved: resnet18_btc_mnist_final.pth")

100%|██████████| 9.91M/9.91M [00:00<00:00, 35.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.04MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.21MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.80MB/s]


🚀 Starting BTC Training for ResNet-18...

--- Training with blur level 8 ---


                                                                                   

✅ Epoch 1 done | Loss: 513.0292 | Accuracy: 14.77%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 1)


                                                                                   

✅ Epoch 2 done | Loss: 245.0312 | Accuracy: 7.88%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 2)


                                                                                   

✅ Epoch 3 done | Loss: 183.5884 | Accuracy: 8.06%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 3)


                                                                                   

✅ Epoch 4 done | Loss: 151.5682 | Accuracy: 8.44%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 4)


                                                                                   

✅ Epoch 5 done | Loss: 133.2015 | Accuracy: 10.24%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 5)


                                                                                   

✅ Epoch 6 done | Loss: 122.4183 | Accuracy: 9.31%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 6)


                                                                                   

✅ Epoch 7 done | Loss: 111.1967 | Accuracy: 9.47%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 7)


                                                                                   

✅ Epoch 8 done | Loss: 102.2539 | Accuracy: 10.03%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 8)


                                                                                   

✅ Epoch 9 done | Loss: 98.2429 | Accuracy: 9.83%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 9)


                                                                                    

✅ Epoch 10 done | Loss: 90.9857 | Accuracy: 9.53%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 10)

--- Training with blur level 4 ---


                                                                                   

✅ Epoch 11 done | Loss: 160.1027 | Accuracy: 59.62%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 11)


                                                                                   

✅ Epoch 12 done | Loss: 72.6438 | Accuracy: 55.47%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 12)


                                                                                   

✅ Epoch 13 done | Loss: 54.1382 | Accuracy: 54.49%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 13)


                                                                                   

✅ Epoch 14 done | Loss: 45.2932 | Accuracy: 54.54%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 14)


                                                                                   

✅ Epoch 15 done | Loss: 39.5331 | Accuracy: 56.64%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 15)


                                                                                   

✅ Epoch 16 done | Loss: 34.2580 | Accuracy: 56.88%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 16)


                                                                                   

✅ Epoch 17 done | Loss: 31.7464 | Accuracy: 58.74%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 17)


                                                                                   

✅ Epoch 18 done | Loss: 28.4596 | Accuracy: 58.50%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 18)


                                                                                   

✅ Epoch 19 done | Loss: 24.2696 | Accuracy: 59.01%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 19)


                                                                                    

✅ Epoch 20 done | Loss: 22.7817 | Accuracy: 59.40%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 20)

--- Training with blur level 2 ---


                                                                                   

✅ Epoch 21 done | Loss: 40.6033 | Accuracy: 93.13%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 21)


                                                                                   

✅ Epoch 22 done | Loss: 19.0708 | Accuracy: 92.73%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 22)


                                                                                   

✅ Epoch 23 done | Loss: 15.0785 | Accuracy: 92.70%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 23)


                                                                                   

✅ Epoch 24 done | Loss: 12.2748 | Accuracy: 92.50%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 24)


                                                                                   

✅ Epoch 25 done | Loss: 9.6729 | Accuracy: 93.00%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 25)


                                                                                   

✅ Epoch 26 done | Loss: 8.7599 | Accuracy: 92.50%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 26)


                                                                                   

✅ Epoch 27 done | Loss: 7.0702 | Accuracy: 91.51%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 27)


                                                                                   

✅ Epoch 28 done | Loss: 6.0791 | Accuracy: 92.35%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 28)


                                                                                   

✅ Epoch 29 done | Loss: 4.6430 | Accuracy: 92.03%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 29)


                                                                                    

✅ Epoch 30 done | Loss: 4.7250 | Accuracy: 91.99%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 30)

--- Training with blur level 1 ---


                                                                                   

✅ Epoch 31 done | Loss: 14.6885 | Accuracy: 97.60%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 31)


                                                                                   

✅ Epoch 32 done | Loss: 6.0830 | Accuracy: 97.74%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 32)


                                                                                   

✅ Epoch 33 done | Loss: 4.1256 | Accuracy: 97.57%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 33)


                                                                                   

✅ Epoch 34 done | Loss: 2.5385 | Accuracy: 98.38%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 34)


                                                                                   

✅ Epoch 35 done | Loss: 2.6151 | Accuracy: 98.05%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 35)


                                                                                   

✅ Epoch 36 done | Loss: 1.6030 | Accuracy: 98.11%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 36)


                                                                                   

✅ Epoch 37 done | Loss: 1.7847 | Accuracy: 98.23%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 37)


                                                                                   

✅ Epoch 38 done | Loss: 1.0738 | Accuracy: 98.35%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 38)


                                                                                   

✅ Epoch 39 done | Loss: 0.4408 | Accuracy: 98.18%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 39)


                                                                                    

✅ Epoch 40 done | Loss: 0.3065 | Accuracy: 98.32%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 40)

--- Training with blur level 0 ---


                                                                                   

✅ Epoch 41 done | Loss: 7.7662 | Accuracy: 98.73%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 41)


                                                                                   

✅ Epoch 42 done | Loss: 3.2238 | Accuracy: 98.89%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 42)


                                                                                   

✅ Epoch 43 done | Loss: 1.9509 | Accuracy: 98.86%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 43)


                                                                                   

✅ Epoch 44 done | Loss: 0.8664 | Accuracy: 98.96%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 44)


                                                                                   

✅ Epoch 45 done | Loss: 0.7454 | Accuracy: 98.93%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 45)


                                                                                   

✅ Epoch 46 done | Loss: 0.4349 | Accuracy: 98.94%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 46)


                                                                                   

✅ Epoch 47 done | Loss: 0.1952 | Accuracy: 98.92%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 47)


                                                                                   

✅ Epoch 48 done | Loss: 0.1370 | Accuracy: 98.97%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 48)


                                                                                   

✅ Epoch 49 done | Loss: 0.1250 | Accuracy: 99.00%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 49)


                                                                                    

✅ Epoch 50 done | Loss: 0.1156 | Accuracy: 98.96%
💾 Checkpoint saved: btc_checkpoint.pth (epoch 50)
✅ Final model saved: resnet18_btc_mnist_final.pth


In [3]:
!pip install torchattacks

Collecting torchattacks
  Downloading torchattacks-3.5.1-py3-none-any.whl.metadata (927 bytes)
Collecting requests~=2.25.1 (from torchattacks)
  Downloading requests-2.25.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting chardet<5,>=3.0.2 (from requests~=2.25.1->torchattacks)
  Downloading chardet-4.0.0-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting idna<3,>=2.5 (from requests~=2.25.1->torchattacks)
  Downloading idna-2.10-py2.py3-none-any.whl.metadata (9.1 kB)
Collecting urllib3<1.27,>=1.21.1 (from requests~=2.25.1->torchattacks)
  Downloading urllib3-1.26.20-py2.py3-none-any.whl.metadata (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.1/50.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.7.1->torchattacks)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.7.1->torchattacks)
  Downloading nvidia_cublas_cu

In [8]:
# ---------------- Load BTC-trained ResNet-18 ----------------
from torchvision import models
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 10)  # MNIST 10 classes
model.load_state_dict(torch.load("resnet18_btc_mnist_final.pth"))
model = model.to(device)
model.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
import torchattacks
from tqdm import tqdm

# ---------------- PGD Attack Setup ----------------
pgd_eps = 0.01         # Changeable
pgd_alpha = 2/255
pgd_steps = 40

pgd_attack = torchattacks.PGD(model, eps=pgd_eps, alpha=pgd_alpha, steps=pgd_steps)

# ---------------- PGD Adversarial Accuracy Function ----------------
def adversarial_test_pgd(attack, loader):
    model.eval()
    correct = 0
    total = 0
    
    for inputs, labels in tqdm(loader, desc=f"Adversarial Test (PGD)"):
        inputs, labels = inputs.to(device), labels.to(device)
        adv_inputs = attack(inputs, labels)
        outputs = model(adv_inputs)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
    
    acc = correct / total
    print(f"📊 PGD Accuracy (ε={pgd_eps}): {acc:.4f}")
    return acc

# ---------------- Run PGD Evaluation ----------------
pgd_acc = adversarial_test_pgd(pgd_attack, test_loader)

Adversarial Test (PGD):  48%|████▊     | 38/79 [07:05<07:39, 11.22s/it]

In [None]:
import torchattacks
from tqdm import tqdm

# ---------------- CW Attack Setup ----------------
cw_c = 1e-3
cw_kappa = 0
cw_steps = 100
cw_lr = 0.01

cw_attack = torchattacks.CW(model, c=cw_c, kappa=cw_kappa, steps=cw_steps, lr=cw_lr)

# ---------------- CW Adversarial Accuracy Function ----------------
def adversarial_test_cw(attack, loader):
    model.eval()
    correct = 0
    total = 0

    for inputs, labels in tqdm(loader, desc=f"Adversarial Test (CW)"):
        inputs, labels = inputs.to(device), labels.to(device)
        adv_inputs = attack(inputs, labels)
        outputs = model(adv_inputs)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    print(f"📊 CW Accuracy (c={cw_c}, kappa={cw_kappa}): {acc:.4f}")
    return acc

# ---------------- Run CW Evaluation ----------------
cw_acc = adversarial_test_cw(cw_attack, test_loader)