In [None]:
# Pinned to the version resolved in the install log below; `%pip` (unlike `!pip`)
# installs into the environment of the running kernel.
%pip install -q torchattacks==3.5.1

Collecting torchattacks
  Downloading torchattacks-3.5.1-py3-none-any.whl.metadata (927 bytes)
Collecting requests~=2.25.1 (from torchattacks)
  Downloading requests-2.25.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting chardet<5,>=3.0.2 (from requests~=2.25.1->torchattacks)
  Downloading chardet-4.0.0-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting idna<3,>=2.5 (from requests~=2.25.1->torchattacks)
  Downloading idna-2.10-py2.py3-none-any.whl.metadata (9.1 kB)
Collecting urllib3<1.27,>=1.21.1 (from requests~=2.25.1->torchattacks)
  Downloading urllib3-1.26.20-py2.py3-none-any.whl.metadata (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.1/50.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.7.1->torchattacks)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.7.1->torchattacks)
  Downloading 

#### Libraries

In [None]:
import csv
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchattacks import PGD

# Seed torch's global RNGs (all devices) so randomness drawn from them, e.g. the
# SimpleCNN weight init and the noise in input_preprocess, repeats across
# Restart & Run All. Some GPU kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cpu


#### Load the CIFAR-10 dataset

In [None]:
# ToTensor only, no normalization: pixels stay on a [0, 1] scale, which is the
# range input_preprocess clamps to.
transform = transforms.Compose([transforms.ToTensor()])

DATA_ROOT = './data'
BATCH_SIZE = 128

trainset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True,
)

testset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform,
)
# Fixed order so evaluation iterates the test set identically every time.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False,
)

100%|██████████| 170M/170M [00:03<00:00, 49.1MB/s]


#### Define a simple CNN model

In [None]:
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for 3x32x32 images.

    Layout: two [conv3x3 -> ReLU -> maxpool2] stages (3 -> 32 -> 64 channels),
    then flatten, a 256-unit hidden layer, and a 10-way linear head that
    returns raw logits.

    Every layer lives in one ``nn.Sequential`` attribute named ``net``, so the
    ``state_dict`` keys are ``net.<index>.*``. Keep this layout to stay
    compatible with checkpoints saved by this notebook.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(c_in, c_out):
            # padding=1 keeps H and W for a 3x3 kernel; the 2x2 pool halves them.
            return [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]

        layers = conv_stage(3, 32) + conv_stage(32, 64)
        layers += [
            nn.Flatten(),
            # Two 2x pools take 32x32 down to 8x8, with 64 channels.
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        return self.net(x)

# Build the classifier, then move its parameters to the selected device.
model = SimpleCNN()
model = model.to(device)

#### Input preprocessing (Gaussian noise, then 8-bit quantization)

In [None]:
def input_preprocess(x, noise_std: float = 0.01, levels: int = 255):
    """Randomized input transformation: add Gaussian noise, clamp, then quantize.

    Args:
        x: Image tensor, assumed to hold values in [0, 1] (as produced by
            ``transforms.ToTensor()``). Any shape is accepted; all ops are
            element-wise.
        noise_std: Standard deviation of the additive Gaussian noise. ``0``
            disables the noise and draws no random numbers.
        levels: Values are snapped to the grid ``{0, 1/levels, ..., 1}``. The
            default 255 gives the 256 values of an 8-bit image.

    Returns:
        Tensor with the same shape as ``x`` and values in [0, 1].

    Raises:
        ValueError: If ``levels`` < 1 (the grid would be undefined).

    Note:
        ``torch.round`` has zero gradient almost everywhere, so gradients taken
        *through* this function carry no useful signal.
    """
    if levels < 1:
        raise ValueError(f'levels must be >= 1, got {levels}')
    if noise_std:
        x = x + torch.randn_like(x) * noise_std
    # Clamp before quantizing so the result stays in the valid pixel range.
    x_noisy = torch.clamp(x, 0, 1)
    return torch.round(x_noisy * levels) / levels

#### Set up the loss, optimizer, and adversarial attacker (PGD)

In [None]:
# Hyper-parameters, named so they can be tuned in one place.
LEARNING_RATE = 0.001
PGD_EPS = 0.03     # perturbation budget, on the same [0, 1] scale as the inputs
PGD_ALPHA = 0.01   # step size of each PGD iteration
PGD_STEPS = 40     # number of PGD iterations

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# The attacker wraps `model` directly, so it targets the bare classifier and
# perturbs whatever tensor it is handed.
pgd = PGD(model, eps=PGD_EPS, alpha=PGD_ALPHA, steps=PGD_STEPS)

#### Training loop (adversarial training on a mix of clean and PGD examples)

In [None]:
def train(epoch):
    """Run one epoch of mixed clean/adversarial training over `trainloader`.

    Each batch is passed through `input_preprocess` and attacked with `pgd`.
    The model is then updated on the average of the clean-batch and
    adversarial-batch cross-entropy losses. The mean loss over every 100
    batches is printed.

    Args:
        epoch: Zero-based epoch index; used only in the progress message.
    """
    model.train()
    running_loss = 0
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        inputs_prep = input_preprocess(inputs)
        # Crafted before optimizer.zero_grad(), so any parameter gradients the
        # attack may leave behind are cleared before the update.
        adv_inputs = pgd(inputs_prep, labels)

        optimizer.zero_grad()
        logits_clean = model(inputs_prep)
        logits_adv = model(adv_inputs)
        loss_clean = criterion(logits_clean, labels)
        loss_adv = criterion(logits_adv, labels)
        loss = (loss_clean + loss_adv) / 2
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (batch_idx + 1) % 100 == 0:
            print(f'[Epoch {epoch+1}, Batch {batch_idx+1}] loss: {running_loss/100:.4f}')
            running_loss = 0

#### Evaluation function (clean and PGD accuracy on the test set)

In [None]:
def evaluate():
    """Report clean and adversarial top-1 accuracy over `testloader`.

    Clean accuracy is measured on preprocessed test images. Adversarial
    accuracy is measured on PGD examples crafted from those preprocessed
    images and fed to the model as-is.

    NOTE(review): `adv_inputs` is not passed through `input_preprocess` again,
    so the noise/quantization happens before the attack rather than between
    the attacker and the model. Confirm this is the intended threat model.

    Returns:
        Tuple ``(clean_acc, adv_acc)`` as percentages.
    """
    model.eval()
    n_seen = 0
    n_clean_ok = 0
    n_adv_ok = 0

    for inputs, labels in testloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        inputs_prep = input_preprocess(inputs)

        with torch.no_grad():
            pred_clean = model(inputs_prep).max(dim=1).indices

        # The attack needs gradients, so it runs outside no_grad.
        adv_inputs = pgd(inputs_prep, labels)
        with torch.no_grad():
            pred_adv = model(adv_inputs).max(dim=1).indices

        n_seen += labels.size(0)
        n_clean_ok += pred_clean.eq(labels).sum().item()
        n_adv_ok += pred_adv.eq(labels).sum().item()

    clean_acc = 100 * n_clean_ok / n_seen
    adv_acc = 100 * n_adv_ok / n_seen
    print(f'Clean accuracy: {clean_acc:.2f}%')
    print(f'Adversarial accuracy: {adv_acc:.2f}%')
    return clean_acc, adv_acc

#### Measure inference latency per batch

In [None]:
def measure_latency(n_warmup: int = 1, n_repeats: int = 10) -> float:
    """Time preprocessing + forward pass on one test batch.

    One batch is fetched from `testloader` and moved to `device` *before* the
    clock starts. This keeps DataLoader start-up, per-sample transforms and the
    host-to-device copy out of the inference latency. The batch is then run
    `n_warmup` times untimed and `n_repeats` times timed. On CUDA the device is
    synchronized around the timed region because kernel launches are
    asynchronous.

    Args:
        n_warmup: Untimed iterations run first (lazy initialisation, caches).
        n_repeats: Timed iterations to average over; must be >= 1.

    Returns:
        Mean wall-clock seconds per batch.

    Raises:
        ValueError: If `n_repeats` < 1 or `n_warmup` < 0.
        RuntimeError: If `testloader` yields no batches.
    """
    if n_repeats < 1:
        raise ValueError(f'n_repeats must be >= 1, got {n_repeats}')
    if n_warmup < 0:
        raise ValueError(f'n_warmup must be >= 0, got {n_warmup}')

    model.eval()
    batch = next(iter(testloader), None)
    if batch is None:
        raise RuntimeError('testloader yielded no batches')
    inputs = batch[0].to(device)

    def _sync():
        # Wait for queued CUDA work so the timer measures completed computation.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    with torch.no_grad():
        for _ in range(n_warmup):
            model(input_preprocess(inputs))
        _sync()
        start = time.perf_counter()
        for _ in range(n_repeats):
            model(input_preprocess(inputs))
        _sync()
        latency = (time.perf_counter() - start) / n_repeats

    print(f'Latency per batch of {inputs.size(0)} samples: {latency:.4f} seconds')
    return latency

#### Run training and evaluation

In [None]:
NUM_EPOCHS = 3

clean_acc_list = []
adv_acc_list = []
# -inf rather than 0, so the first epoch is always checkpointed even if its
# adversarial accuracy is 0%.
best_adv_acc = float('-inf')

for epoch in range(NUM_EPOCHS):
    train(epoch)
    clean_acc, adv_acc = evaluate()
    measure_latency()

    clean_acc_list.append(clean_acc)
    adv_acc_list.append(adv_acc)

    # Keep the checkpoint with the highest adversarial accuracy seen so far.
    if adv_acc > best_adv_acc:
        best_adv_acc = adv_acc
        torch.save(model.state_dict(), 'best_model_adv_acc.pth')
        print(f" Best model saved at epoch {epoch+1} with {adv_acc:.2f}% adversarial accuracy")

# Weights after the last epoch, independent of the best-checkpoint above.
torch.save(model.state_dict(), 'hybrid_defense_model.pth')
print(" Final model saved as 'hybrid_defense_model.pth'")

# One row per epoch: epoch number (1-based), clean accuracy %, adversarial
# accuracy %. Needs `csv` from the imports cell at the top of the notebook.
with open('accuracy_log.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Epoch', 'Clean Accuracy (%)', 'Adversarial Accuracy (%)'])
    writer.writerows(
        [epoch_num, clean, adv]
        for epoch_num, (clean, adv) in enumerate(zip(clean_acc_list, adv_acc_list), start=1)
    )
print(" Saved 'accuracy_log.csv'")

# Clean vs. adversarial accuracy per epoch. Needs `plt` (matplotlib.pyplot)
# from the imports cell at the top of the notebook.
epochs = range(1, len(clean_acc_list) + 1)
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epochs, clean_acc_list, label='Clean Accuracy', marker='o')
ax.plot(epochs, adv_acc_list, label='Adversarial Accuracy', marker='s')
ax.set(xlabel='Epoch', ylabel='Accuracy (%)',
       title='Clean vs Adversarial Accuracy Over Epochs')
# Epochs are whole numbers; avoid fractional ticks such as 1.5.
ax.set_xticks(list(epochs))
ax.legend()
ax.grid(True)
plt.show()

[Epoch 1, Batch 100] loss: 2.0896
