In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from cleverhans.torch.attacks.carlini_wagner_l2 import carlini_wagner_l2


In [6]:
# Load FashionMNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_loader = DataLoader(
    datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform),
    batch_size=64, shuffle=True
)
test_loader = DataLoader(
    datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform),
    batch_size=1000, shuffle=False
)

In [7]:

# Define a ResNet50 model adapted for grayscale images and 10 classes
class ResNet50(nn.Module):
    def __init__(self):
        super(ResNet50, self).__init__()
        self.model = models.resnet50(weights=None)
        # Modify the first conv layer to accept 1-channel input
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # Remove maxpool layer to reduce spatial downsampling
        self.model.maxpool = nn.Identity()
        # Modify the fully connected layer to output 10 classes
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, 10)

    def forward(self, x):
        return self.model(x)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = ResNet50().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training function with CW-L2 adversarial examples
def train_with_cw(model, train_loader, optimizer, criterion, device, epoch):
    model.train()
    total_loss, total_adv_loss = 0, 0
    correct, correct_adv = 0, 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Standard forward pass
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        total_loss += loss.item()

        # Compute accuracy for clean data
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

        # Generate adversarial examples using CW-L2 from CleverHans
        adv_data = carlini_wagner_l2(model, data, n_classes=10, targeted=False, confidence=0, clip_min=0.0, clip_max=1.0, max_iterations=10)
        
        # Forward pass on adversarial examples
        adv_output = model(adv_data)
        adv_loss = criterion(adv_output, target)
        total_adv_loss += adv_loss.item()

        # Compute accuracy for adversarial data
        adv_pred = adv_output.argmax(dim=1, keepdim=True)
        correct_adv += adv_pred.eq(target.view_as(adv_pred)).sum().item()

        # Combined loss
        total_batch_loss = (loss + adv_loss) / 2

        # Backward pass and optimization
        total_batch_loss.backward()
        optimizer.step()

        total += target.size(0)

        # Log information for each batch
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch+1} [{batch_idx * len(data)}/{len(train_loader.dataset)}] '
                  f'Loss: {loss.item():.4f} | Adv Loss: {adv_loss.item():.4f} | '
                  f'Acc: {100. * correct / total:.2f}% | Adv Acc: {100. * correct_adv / total:.2f}%')

    # Log average loss and accuracy for epoch
    avg_loss = total_loss / len(train_loader)
    avg_adv_loss = total_adv_loss / len(train_loader)
    avg_acc = 100. * correct / total
    avg_adv_acc = 100. * correct_adv / total
    print(f'==> Epoch: {epoch+1} | Avg Loss: {avg_loss:.4f} | Avg Adv Loss: {avg_adv_loss:.4f} | '
          f'Avg Acc: {avg_acc:.2f}% | Avg Adv Acc: {avg_adv_acc:.2f}%')


cuda


In [8]:

# Training loop with CW-L2 adversarial examples
epochs = 5
for epoch in range(epochs):
    train_with_cw(model, train_loader, optimizer, criterion, device, epoch)
    print(f"Epoch {epoch+1}/{epochs} completed\n")

print("Adversarial training completed.")

Epoch: 1 [0/60000] Loss: 2.3446 | Adv Loss: 2.3619 | Acc: 15.62% | Adv Acc: 6.25%
Epoch: 1 [6400/60000] Loss: 12.2217 | Adv Loss: 12.1206 | Acc: 9.03% | Adv Acc: 9.50%
Epoch: 1 [12800/60000] Loss: 11.7812 | Adv Loss: 11.8770 | Acc: 9.28% | Adv Acc: 9.72%
Epoch: 1 [19200/60000] Loss: 10.5922 | Adv Loss: 10.8280 | Acc: 9.92% | Adv Acc: 10.10%
Epoch: 1 [25600/60000] Loss: 11.9562 | Adv Loss: 11.9950 | Acc: 10.27% | Adv Acc: 9.91%
Epoch: 1 [32000/60000] Loss: 19.1488 | Adv Loss: 19.0997 | Acc: 10.15% | Adv Acc: 10.08%
Epoch: 1 [38400/60000] Loss: 16.1204 | Adv Loss: 16.4417 | Acc: 9.98% | Adv Acc: 9.95%
Epoch: 1 [44800/60000] Loss: 21.0035 | Adv Loss: 20.8643 | Acc: 9.92% | Adv Acc: 9.94%
Epoch: 1 [51200/60000] Loss: 23.5379 | Adv Loss: 20.4228 | Acc: 9.94% | Adv Acc: 9.98%
Epoch: 1 [57600/60000] Loss: 48.4336 | Adv Loss: 49.6084 | Acc: 9.75% | Adv Acc: 9.87%
==> Epoch: 1 | Avg Loss: 19.3728 | Avg Adv Loss: 19.2713 | Avg Acc: 9.76% | Avg Adv Acc: 9.84%
Epoch 1/5 completed

Epoch: 2 [0/6000