In [4]:
from IPython.display import clear_output
!pip install cleverhans
clear_output()

In [5]:
import numpy as np
from tqdm import tqdm
import pandas as pd

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from torch.optim.lr_scheduler import StepLR

import torchvision.transforms as transforms
import torchvision.datasets as datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
# Load the CIFAR-10 training split with only ToTensor applied, purely to
# inspect the per-channel statistics that the normalization constants in the
# next cell are copied from.
train_transform = transforms.Compose([transforms.ToTensor()])

# NOTE(review): this cell stores the dataset under '../data/' while the later
# loading cell uses root="./data", so a fresh environment downloads CIFAR-10
# twice — confirm which root is intended.
train_set = datasets.CIFAR10(root='../data/', train=True, download=True, transform=train_transform)
print(train_set.data.shape)
# Reduce over axes (0, 1, 2) = every axis except the last (channel) one.
# Dividing by 255 presumably rescales 0-255 pixel values to [0, 1] — the
# dtype of `data` is not visible here.
print(train_set.data.mean(axis=(0, 1, 2)) / 255)
print(train_set.data.std(axis=(0, 1, 2)) / 255)

Files already downloaded and verified
(50000, 32, 32, 3)
[0.49139968 0.48215841 0.44653091]
[0.24703223 0.24348513 0.26158784]


In [7]:
# Batch size used by the train/val/test DataLoaders created further down.
batch_size = 128
# Per-channel statistics of the CIFAR-10 training images; the values match
# the mean/std printed by the previous cell.
mean       = [0.49139968, 0.48215841, 0.44653091]
std        = [0.24703223, 0.24348513, 0.26158784]

# Convert to a tensor, then standardise each channel with the constants above.
# Everything downstream (models, perturbations, clamping) therefore operates
# in this normalized space, not in raw [0, 1] pixel space.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

# Class names indexed by integer label; used for plot titles.
CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

In [8]:
def denormalize(img, mean, std):
    """Undo per-channel normalization and return an array suitable for ``imshow``.

    Args:
        img: Channel-first tensor (C, H, W). It may live on any device and may
            be attached to an autograd graph.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided out.

    Returns:
        A channel-last ``numpy.ndarray`` (H, W, C) with values clipped to [0, 1].
    """
    # detach()/cpu() are no-ops for plain CPU tensors, but without them
    # .numpy() raises for CUDA tensors and for tensors that require grad.
    img = img.detach().cpu().numpy().transpose((1, 2, 0))
    # mean/std broadcast over the trailing channel axis.
    img = img * std + mean
    return np.clip(img, 0, 1)

def show_samples(data_loader):
    """Show the first five images of the next batch with their class names.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` batches; images are
            de-normalized with the module-level ``mean``/``std`` before display.
    """
    batch_images, batch_labels = next(iter(data_loader))
    head_images, head_labels = batch_images[:5], batch_labels[:5]

    _, axes = plt.subplots(1, 5, figsize=(15, 3))
    for col in range(len(axes)):
        picture = denormalize(head_images[col], mean, std)
        class_name = CLASSES[head_labels[col].item()]
        panel = axes[col]
        panel.imshow(picture)
        panel.set_title(f"{class_name}", fontsize=20)
        panel.axis("off")
    plt.show()

def show_adv_images_with_labels(adv_examples, true_labels, pred_labels_orig, pred_labels_adv):
    """Plot clean images, their adversarial versions and the absolute difference.

    Only the first ``(original, adversarial)`` pair in ``adv_examples`` is used,
    and only its first five images. The label sequences are indexed 0..4 and
    looked up in ``CLASSES``.

    Args:
        adv_examples: Sequence of ``(original_batch, adversarial_batch)`` pairs.
        true_labels: Ground-truth class indices.
        pred_labels_orig: Predicted class indices for the clean images.
        pred_labels_adv: Predicted class indices for the adversarial images.
    """
    clean_batch, adv_batch = adv_examples[0]
    _, axes = plt.subplots(3, 5, figsize=(15, 8))

    for col in range(5):
        clean_img = denormalize(clean_batch[col].detach().cpu(), mean, std)
        adv_img = denormalize(adv_batch[col].detach().cpu(), mean, std)
        delta = np.abs(adv_img - clean_img)

        true_name = CLASSES[true_labels[col]]
        clean_pred_name = CLASSES[pred_labels_orig[col]]
        adv_pred_name = CLASSES[pred_labels_adv[col]]

        # One (image, caption) pair per row of the figure, top to bottom.
        panels = (
            (clean_img, f"True: {true_name}\nPred: {clean_pred_name}"),
            (adv_img, f"Adversarial\nPred: {adv_pred_name}"),
            (delta, "Difference"),
        )
        for row, (picture, caption) in enumerate(panels):
            cell = axes[row, col]
            cell.imshow(picture)
            cell.set_title(caption)
            cell.axis("off")

    plt.tight_layout()
    plt.show()

def visualize_adversarial_samples(generator, test_loader):
    """Plot five clean images, their adversarial versions and the perturbations.

    Args:
        generator: Module mapping a batch of images to additive perturbations.
        test_loader: Iterable yielding ``(images, labels)`` batches. Assumes the
            images are normalized with the module-level ``mean``/``std``, as
            the notebook's loaders are.
    """
    generator.eval()

    images, labels = next(iter(test_loader))
    images, labels = images[:5].to(device), labels[:5].to(device)

    # After Normalize(mean, std) a valid pixel value in [0, 1] maps to
    # [(0 - mean) / std, (1 - mean) / std] per channel. The previous fixed
    # clamp to [-1, 1] cut off legitimate image content in normalized space.
    channel_mean = torch.tensor(mean, device=images.device, dtype=images.dtype).view(1, -1, 1, 1)
    channel_std = torch.tensor(std, device=images.device, dtype=images.dtype).view(1, -1, 1, 1)
    lower = (0.0 - channel_mean) / channel_std
    upper = (1.0 - channel_mean) / channel_std

    with torch.no_grad():
        perturbations = generator(images)
        adv_images = torch.clamp(images + perturbations, min=lower, max=upper)

    images = images.cpu()
    adv_images = adv_images.cpu()
    perturbations = perturbations.cpu()

    fig, axes = plt.subplots(3, 5, figsize=(15, 8))
    for i in range(5):
        axes[0, i].imshow(denormalize(images[i], mean, std))
        axes[0, i].set_title(f"Original: {CLASSES[labels[i]]}")
        axes[0, i].axis("off")
        axes[1, i].imshow(denormalize(adv_images[i], mean, std))
        axes[1, i].set_title("Adversarial")
        axes[1, i].axis("off")
        # The perturbation is shown through the same de-normalization, i.e.
        # offset by the channel means. NOTE(review): matplotlib ignores `cmap`
        # for RGB input, so "seismic" has no visible effect here.
        axes[2, i].imshow(denormalize(perturbations[i], mean, std), cmap="seismic")
        axes[2, i].set_title("Perturbation")
        axes[2, i].axis("off")
    plt.tight_layout()
    plt.show()

def plot_losses(epochs, g_losses, d_losses, attack_success_rates, val_g_losses, val_d_losses, val_attack_success_rates):
    """Draw three panels: training losses, attack success rates, validation losses.

    The x-axis length is taken from ``len(g_losses)``, so the plot also works
    when training stopped early. ``epochs`` is accepted but not used.

    Args:
        epochs: Unused.
        g_losses, d_losses: Per-epoch training generator/discriminator losses.
        attack_success_rates, val_attack_success_rates: Per-epoch success rates.
        val_g_losses, val_d_losses: Per-epoch validation losses.
    """
    epoch_axis = range(1, len(g_losses) + 1)

    # (y-label, title, [(series, legend label, colour), ...]) for each panel.
    panels = [
        ('Loss', 'Generator and Discriminator Losses', [
            (g_losses, 'Generator Loss', 'blue'),
            (d_losses, 'Discriminator Loss', 'red'),
        ]),
        ('Success Rate', 'Attack Success Rates', [
            (attack_success_rates, 'Attack Success Rate', 'orange'),
            (val_attack_success_rates, 'Validation Attack Success Rate', 'green'),
        ]),
        ('Loss', 'Validation Generator and Discriminator Losses', [
            (val_g_losses, 'Validation Generator Loss', 'blue'),
            (val_d_losses, 'Validation Discriminator Loss', 'red'),
        ]),
    ]

    plt.figure(figsize=(18, 10))
    for position, (y_label, heading, curves) in enumerate(panels, start=1):
        plt.subplot(2, 3, position)
        for series, legend_name, colour in curves:
            plt.plot(epoch_axis, series, label=legend_name, color=colour)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.title(heading)
        plt.legend()
    plt.tight_layout()
    plt.show()

def plot_confidence_histograms(target_model, test_loader, generator):
    """Overlay histograms of the classifier's top softmax score on clean vs adversarial inputs.

    Args:
        target_model: Classifier returning logits.
        test_loader: Iterable yielding ``(images, labels)`` batches. Assumes the
            images are normalized with the module-level ``mean``/``std``, as
            the notebook's loaders are.
        generator: Module mapping a batch of images to additive perturbations.
    """
    normal_confidences = []
    adversarial_confidences = []

    # After Normalize(mean, std) a valid pixel value in [0, 1] maps to
    # [(0 - mean) / std, (1 - mean) / std] per channel. The previous fixed
    # clamp to [-1, 1] removed legitimate image content and inflated the
    # apparent effect of the attack.
    channel_mean = torch.tensor(mean, device=device).view(1, -1, 1, 1)
    channel_std = torch.tensor(std, device=device).view(1, -1, 1, 1)
    lower = (0.0 - channel_mean) / channel_std
    upper = (1.0 - channel_mean) / channel_std

    generator.eval()
    target_model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            perturbations = generator(images)
            adv_images = torch.clamp(images + perturbations, min=lower, max=upper)
            # Softmax confidence of the predicted class for clean and adversarial inputs.
            normal_outputs = torch.softmax(target_model(images), dim=1)
            adv_outputs = torch.softmax(target_model(adv_images), dim=1)

            normal_confidences.extend(normal_outputs.max(1)[0].cpu().numpy())
            adversarial_confidences.extend(adv_outputs.max(1)[0].cpu().numpy())
    plt.figure(figsize=(10, 5))
    plt.hist(adversarial_confidences, bins=20, alpha=0.7, label="Adversarial Samples", color="red")
    plt.hist(normal_confidences, bins=20, alpha=0.7, label="Normal Samples", color="blue")
    plt.xlabel("Confidence")
    plt.ylabel("Frequency")
    plt.title("Confidence Histogram: Normal vs Adversarial")
    plt.legend()
    plt.show()

In [9]:
def attack_success_rate(model, adv_examples):
    """Fraction of samples whose predicted class changes under attack.

    Success is measured against the model's own prediction on the clean image,
    not against a ground-truth label. The model is switched to eval mode.

    Args:
        model: Classifier returning logits.
        adv_examples: Iterable of ``(original_batch, adversarial_batch)`` pairs.

    Returns:
        Number of changed predictions divided by the number of samples seen.
    """
    flipped = 0
    seen = 0
    model.eval()
    with torch.no_grad():
        for clean_batch, adv_batch in adv_examples:
            clean_pred = model(clean_batch).argmax(dim=1)
            adv_pred = model(adv_batch).argmax(dim=1)
            flipped += clean_pred.ne(adv_pred).sum().item()
            seen += clean_batch.size(0)
    return flipped / seen

def calculate_attack_success(target_model, test_loader, generator):
    """Measure how often generator perturbations make the classifier miss the true label.

    Args:
        target_model: Classifier returning logits over 10 classes.
        test_loader: Iterable yielding ``(images, labels)`` batches. Assumes the
            images are normalized with the module-level ``mean``/``std``, as
            the notebook's loaders are.
        generator: Module mapping a batch of images to additive perturbations.

    Returns:
        ``(overall_rate, per_class_rate)`` where ``per_class_rate`` is a
        length-10 array indexed by true label. Classes that never occur in the
        loader yield NaN (0 / 0).
    """
    overall_success = 0
    per_class_success = np.zeros(10)
    per_class_count = np.zeros(10)

    # After Normalize(mean, std) a valid pixel value in [0, 1] maps to
    # [(0 - mean) / std, (1 - mean) / std] per channel. The previous fixed
    # clamp to [-1, 1] destroyed image content on its own and so overstated
    # the success of the attack.
    channel_mean = torch.tensor(mean, device=device).view(1, -1, 1, 1)
    channel_std = torch.tensor(std, device=device).view(1, -1, 1, 1)
    lower = (0.0 - channel_mean) / channel_std
    upper = (1.0 - channel_mean) / channel_std

    generator.eval()
    target_model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            perturbations = generator(images)
            adv_images = torch.clamp(images + perturbations, min=lower, max=upper)
            # Predictions on the adversarial samples.
            adv_outputs = target_model(adv_images)
            _, adv_preds = adv_outputs.max(1)
            # A sample counts as fooled when the prediction differs from the true label.
            fooled = (adv_preds != labels).cpu().numpy()
            overall_success += fooled.sum()
            for label, is_fooled in zip(labels.cpu().numpy(), fooled):
                per_class_count[label] += 1
                per_class_success[label] += is_fooled
    overall_rate = overall_success / sum(per_class_count)
    per_class_rate = per_class_success / per_class_count
    return overall_rate, per_class_rate

def calculate_targeted_attack_success(target_model, test_loader, generator, target_class):
    """Measure how often adversarial inputs are classified as ``target_class``.

    Args:
        target_model: Classifier returning logits over 10 classes.
        test_loader: Iterable yielding ``(images, labels)`` batches. Assumes the
            images are normalized with the module-level ``mean``/``std``, as
            the notebook's loaders are.
        generator: Module mapping a batch of images to additive perturbations.
        target_class: Class index the attack tries to force.

    Returns:
        ``(overall_rate, per_class_rate)`` where ``per_class_rate`` is a
        length-10 array indexed by true label. Classes that never occur in the
        loader yield NaN (0 / 0). Samples whose true label already equals
        ``target_class`` are included in the counts.
    """
    overall_success = 0
    per_class_success = np.zeros(10)
    per_class_count = np.zeros(10)

    # After Normalize(mean, std) a valid pixel value in [0, 1] maps to
    # [(0 - mean) / std, (1 - mean) / std] per channel. The previous fixed
    # clamp to [-1, 1] cut off legitimate image content in normalized space.
    channel_mean = torch.tensor(mean, device=device).view(1, -1, 1, 1)
    channel_std = torch.tensor(std, device=device).view(1, -1, 1, 1)
    lower = (0.0 - channel_mean) / channel_std
    upper = (1.0 - channel_mean) / channel_std

    generator.eval()
    target_model.eval()

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            perturbations = generator(images)
            adv_images = torch.clamp(images + perturbations, min=lower, max=upper)

            # Predictions on the adversarial samples.
            adv_outputs = target_model(adv_images)
            _, adv_preds = adv_outputs.max(1)

            # Success means the classifier outputs the requested class.
            targeted = (adv_preds == target_class).cpu().numpy()
            overall_success += targeted.sum()

            # Per-class bookkeeping, keyed by the true label.
            for label, is_targeted in zip(labels.cpu().numpy(), targeted):
                per_class_count[label] += 1
                per_class_success[label] += is_targeted

    overall_rate = overall_success / sum(per_class_count)
    per_class_rate = per_class_success / per_class_count

    return overall_rate, per_class_rate

In [10]:
# Normalized CIFAR-10 training and test splits.
dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
print(dataset.data.shape)
print(dataset.data.mean(axis=(0, 1, 2)) / 255)
print(dataset.data.std(axis=(0, 1, 2)) / 255)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

# Split training dataset into training (80%) and validation (20%).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A dedicated, seeded generator keeps the train/val membership identical on
# every Restart & Run All; without it each run validates on different images.
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Create DataLoaders (only the training loader is shuffled).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)

Files already downloaded and verified
(50000, 32, 32, 3)
[0.49139968 0.48215841 0.44653091]
[0.24703223 0.24348513 0.26158784]
Files already downloaded and verified


In [11]:
# show_samples(train_loader)
# show_samples(val_loader)
# show_samples(test_loader)

In [12]:
# Pretrained CIFAR-10 ResNet-20 that the generator is trained to fool.
target_model = torch.hub.load(
    "chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True
).to(device)
# The classifier is never optimized in this notebook. Freezing its parameters
# stops every generator backward pass from computing and accumulating
# gradients for them; gradients with respect to the inputs still flow.
target_model.requires_grad_(False)
target_model.eval()

Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

In [13]:
# # Evaluate accuracy on test set
# correct = 0
# total = 0
# with torch.no_grad():
#     for images, labels in test_loader:
#         images, labels = images.to(device), labels.to(device)
#         outputs = target_model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
# print(f"Test Accuracy: {100 * correct / total}%")

# ## ................................................ ##

# # Generate adversarial images using FGSM
# epsilon = 0.01

# adv_examples = []
# true_labels = []
# pred_labels_orig = []
# pred_labels_adv = []

# target_model.eval()
# for images, labels in test_loader:
#     images, labels = images.to(device), labels.to(device)

#     adv_images = fast_gradient_method(target_model, images, epsilon, norm=np.inf)
#     adv_examples.append((images, adv_images))
#     true_labels.extend(labels.cpu().numpy())
    
#     with torch.no_grad():
#         orig_preds = target_model(images).argmax(dim=1)
#         adv_preds = target_model(adv_images).argmax(dim=1)
#         pred_labels_orig.extend(orig_preds.cpu().numpy())
#         pred_labels_adv.extend(adv_preds.cpu().numpy())

# success_rate = attack_success_rate(target_model, adv_examples)
# print(f"Attack Success Rate: {success_rate * 100:.2f}%")

# ## ................................................ ##

# show_samples(train_loader)
# show_samples(val_loader)
# show_samples(test_loader)

# show_adv_images_with_labels(adv_examples, true_labels, pred_labels_orig, pred_labels_adv)

In [14]:
class Generator(nn.Module):
    """Convolutional encoder-decoder that maps an image to a same-sized perturbation.

    Four stride-2 convolutions reduce a 32x32 input to 2x2, and four stride-2
    transposed convolutions bring it back to 32x32. The final ``Tanh`` bounds
    every output value to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Keyword sets shared by all down- and up-sampling convolutions.
        down = dict(kernel_size=3, stride=2, padding=1)
        up = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        encoder_layers = [
            nn.Conv2d(3, 64, **down),      # 32x32 -> 16x16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, **down),    # 16x16 -> 8x8
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, **down),   # 8x8 -> 4x4
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, **down),   # 4x4 -> 2x2
            nn.BatchNorm2d(512),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, **up),  # 2x2 -> 4x4
            nn.BatchNorm2d(256),
            nn.Dropout(0.5),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, **up),  # 4x4 -> 8x8
            nn.BatchNorm2d(128),
            nn.Dropout(0.5),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, **up),   # 8x8 -> 16x16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, **up),     # 16x16 -> 32x32
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it back to input resolution."""
        return self.decoder(self.encoder(x))

class Discriminator(nn.Module):
    """Small convolutional critic producing one score in [0, 1] per image.

    Three stride-2 convolutions reduce a 32x32 input to 4x4. A final 4x4
    convolution collapses that to a single value, which a sigmoid squashes
    into [0, 1].
    """

    def __init__(self):
        super().__init__()

        halve = dict(kernel_size=4, stride=2, padding=1)
        layers = [
            nn.Conv2d(3, 8, **halve),        # 32x32 -> 16x16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(8, 16, **halve),       # 16x16 -> 8x8
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 32, **halve),      # 8x8 -> 4x4
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 1, kernel_size=4), # 4x4 -> 1x1
            nn.Sigmoid(),                    # squash the score to [0, 1]
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the scores flattened to a 1-D tensor."""
        scores = self.model(x)
        return scores.view(-1)

In [15]:
# def gan_loss(discriminator, original_images, perturbations):
#     mse_loss = nn.MSELoss()
#     real_preds = discriminator(original_images)
#     real_targets = torch.ones_like(real_preds)  # Targets for real images are 1
#     real_loss = mse_loss(real_preds, real_targets)
#     fake_images = original_images + perturbations
#     fake_preds = discriminator(fake_images)
#     fake_targets = torch.zeros_like(fake_preds)  # Targets for fake images are 0
#     fake_loss = mse_loss(fake_preds, fake_targets)
#     total_gan_loss = real_loss + fake_loss
#     return total_gan_loss
# def adversarial_loss(target_model, adv_images, target_labels):
#     outputs = target_model(adv_images)
#     return F.cross_entropy(outputs, target_labels)
# def hinge_loss(perturbations, c):
#     norm = torch.norm(perturbations.view(perturbations.size(0), -1), dim=1)  # L2 norm
#     hinge = F.relu(norm - c)
#     return torch.mean(hinge)

class LossFunctions:
    """Namespace for the loss terms used to train the generator and discriminator."""

    @staticmethod
    def gan_loss():
        """Return a fresh mean-squared-error criterion for discriminator scores."""
        return nn.MSELoss()

    @staticmethod
    def adv_loss(logits, target, num_classes=10, kappa=0):
        """Margin loss that is positive while the true class still wins.

        For each sample this takes the true-class logit minus the largest
        other logit, floors it at ``kappa``, and sums over the batch.

        Args:
            logits: Classifier outputs, one row per sample.
            target: True class indices, one per sample.
            num_classes: Number of logit columns.
            kappa: Lower bound applied to each per-sample margin.

        Returns:
            Scalar tensor: the sum (not the mean) of the floored margins, so
            its scale grows with batch size.
        """
        # Build the one-hot mask directly with the logits' dtype and device
        # instead of going through the legacy string-based ``Tensor.type``.
        target_one_hot = torch.eye(
            num_classes, dtype=logits.dtype, device=logits.device
        )[target.long()]
        real = torch.sum(target_one_hot * logits, 1)
        # Subtracting 10000 at the true-class position keeps it from being
        # selected as the "best other" logit.
        other = torch.max((1 - target_one_hot) * logits - (target_one_hot * 10000), 1)[0]
        return torch.sum(torch.clamp(real - other, min=kappa))

    @staticmethod
    def hinge_loss(perturb, c):
        """Mean over the batch of ``max(||perturb||_2 - c, 0)``.

        Args:
            perturb: Batch of perturbations; each sample is flattened before
                its norm is taken.
            c: Norm below which a perturbation incurs no penalty.
        """
        norm = torch.norm(perturb.view(perturb.size(0), -1), dim=1)
        clamped = torch.clamp(norm - c, min=0)
        return torch.mean(clamped)

    @staticmethod
    def total_loss(adv_loss_value, gan_loss_value, hinge_loss_value, alpha, beta):
        """Generator objective: ``alpha * gan + beta * hinge + adv``."""
        return alpha * gan_loss_value + beta * hinge_loss_value + adv_loss_value

    @staticmethod
    def total_gan_loss(real_loss, fake_loss):
        """Discriminator objective: sum of the losses on real and fake batches."""
        # Two statements that used to follow this return were unreachable and
        # referenced names not defined in this scope; they have been removed.
        return real_loss + fake_loss

In [16]:
def attack_success(model, adversarial_images, labels):
    """Count adversarial samples whose prediction differs from the true label.

    Args:
        model: Classifier returning logits.
        adversarial_images: Batch of perturbed inputs.
        labels: True class indices for the batch.

    Returns:
        Python int: number of samples with ``argmax(model(x)) != label``.
    """
    with torch.no_grad():
        predicted = model(adversarial_images).argmax(dim=1)
        misclassified = predicted.ne(labels).sum().item()
    return misclassified

def train_epoch(epoch, G, D, train_loader, optimizer_G, optimizer_D, target_model, alpha, beta, epochs):
    """Run one training epoch, alternating a discriminator step and a generator step.

    Args:
        epoch: Zero-based epoch index (used only for the progress-bar label).
        G, D: Generator and discriminator modules.
        train_loader: Iterable yielding ``(images, labels)`` batches; must
            support ``len()``.
        optimizer_G, optimizer_D: Optimizers for ``G`` and ``D``.
        target_model: Classifier under attack.
        alpha, beta: Weights of the GAN and hinge terms in the generator loss.
        epochs: Total number of epochs (progress-bar label only).

    Returns:
        ``(mean generator loss per batch, mean discriminator loss per batch,
        attack success rate over the samples seen this epoch)``.

    Note:
        The hinge bound is read from the module-level ``c``, which must be
        defined before this function is called.
    """
    G.train()
    D.train()
    train_g_loss, train_d_loss, train_attack_success_counts = 0, 0, 0
    # Count the samples actually iterated instead of reading the global
    # ``train_dataset``, so the rate stays correct for whichever loader is
    # passed in.
    samples_seen = 0
    tepoch = tqdm(train_loader, unit="batch", desc=f"Training Epoch {epoch+1}/{epochs}", ncols=100, leave=False)
    for batch_idx, (images, labels) in enumerate(tepoch):
        images, labels = images.to(device), labels.to(device)
        batch_size = images.size(0)
        samples_seen += batch_size
        perturbations = G(images)
        adversarial_images = images + perturbations

        # Discriminator step: real images -> 1, adversarial images -> 0.
        # detach() keeps this update from back-propagating into G.
        optimizer_D.zero_grad()
        d_loss = LossFunctions.total_gan_loss(
            real_loss=LossFunctions.gan_loss()(D(images), torch.ones(batch_size, device=device)),
            fake_loss=LossFunctions.gan_loss()(D(adversarial_images.detach()), torch.zeros(batch_size, device=device))
        )
        d_loss.backward()
        optimizer_D.step()

        # Generator step: fool D (target 1), fool the classifier, and keep
        # the perturbation norm under the bound.
        optimizer_G.zero_grad()
        g_loss = LossFunctions.total_loss(
            adv_loss_value=LossFunctions.adv_loss(target_model(adversarial_images), labels),
            gan_loss_value=LossFunctions.gan_loss()(D(adversarial_images), torch.ones(batch_size, device=device)),
            hinge_loss_value=LossFunctions.hinge_loss(perturbations, c),
            alpha=alpha, beta=beta
        )
        g_loss.backward()
        optimizer_G.step()

        train_g_loss += g_loss.item()
        train_d_loss += d_loss.item()
        train_attack_success_counts += attack_success(target_model, adversarial_images, labels)
        tepoch.set_postfix({'Batch': batch_idx + 1, 'G_Loss': g_loss.item(), 'D_Loss': d_loss.item()})
    return (train_g_loss/len(train_loader)), (train_d_loss/len(train_loader)), (train_attack_success_counts/samples_seen)

def validate_epoch(G, D, val_loader, target_model, alpha, beta):
    """Evaluate generator and discriminator losses and attack success without updating weights.

    Args:
        G, D: Generator and discriminator modules (switched to eval mode).
        val_loader: Iterable yielding ``(images, labels)`` batches; must
            support ``len()``.
        target_model: Classifier under attack.
        alpha, beta: Weights of the GAN and hinge terms in the generator loss.

    Returns:
        ``(mean generator loss per batch, mean discriminator loss per batch,
        attack success rate over the samples seen)``.

    Note:
        The hinge bound is read from the module-level ``c``, which must be
        defined before this function is called.
    """
    G.eval()
    D.eval()
    val_g_loss, val_d_loss, val_attack_success_counts = 0, 0, 0
    # Count the samples actually iterated instead of reading the global
    # ``val_dataset``, so the rate stays correct for whichever loader is
    # passed in.
    samples_seen = 0
    with torch.no_grad():
        for val_images, val_labels in val_loader:
            val_images, val_labels = val_images.to(device), val_labels.to(device)
            batch_size = val_images.size(0)
            samples_seen += batch_size
            val_perturbations = G(val_images)
            val_adversarial_images = val_images + val_perturbations

            # D is in eval mode under no_grad, so its score on the adversarial
            # batch can be computed once and used in both losses.
            d_on_adversarial = D(val_adversarial_images)

            val_d_loss_batch = LossFunctions.total_gan_loss(
                real_loss=LossFunctions.gan_loss()(D(val_images), torch.ones(batch_size, device=device)),
                fake_loss=LossFunctions.gan_loss()(d_on_adversarial, torch.zeros(batch_size, device=device))
            )
            val_g_loss_batch = LossFunctions.total_loss(
                adv_loss_value=LossFunctions.adv_loss(target_model(val_adversarial_images), val_labels),
                gan_loss_value=LossFunctions.gan_loss()(d_on_adversarial, torch.ones(batch_size, device=device)),
                hinge_loss_value=LossFunctions.hinge_loss(val_perturbations, c),
                alpha=alpha, beta=beta
            )

            val_g_loss += val_g_loss_batch.item()
            val_d_loss += val_d_loss_batch.item()
            val_attack_success_counts += attack_success(target_model, val_adversarial_images, val_labels)
    return (val_g_loss/len(val_loader)), (val_d_loss/len(val_loader)), (val_attack_success_counts/samples_seen)

In [17]:
G = Generator().to(device)
D = Discriminator().to(device)

epochs = 50
lr = 0.001
alpha = 1.0    # weight of the GAN (discriminator-fooling) term in the generator loss
beta = 10.0    # weight of the hinge (perturbation-norm) term in the generator loss
c = 8/255      # hinge bound: per-sample L2 norm above which a perturbation is penalised
patience = 12  # epochs without validation improvement before early stopping

# NOTE(review): a stray expression that evaluated 8 / (255 * max(std)) and
# discarded the result was removed from this cell. Perturbations are added to
# normalized images while `c` is written in raw pixel units — TODO confirm
# whether the bound was meant to be rescaled by the channel std.

optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# Halve both learning rates every 10 epochs.
g_scheduler = StepLR(optimizer_G, step_size=10, gamma=0.5)
d_scheduler = StepLR(optimizer_D, step_size=10, gamma=0.5)

In [None]:
# Train with early stopping on the validation generator loss, checkpointing
# both networks whenever that loss improves.
best_val_loss = float('inf')
patience_counter = 0
g_losses, d_losses, g_accuracies = [], [], []
val_g_losses, val_d_losses = [], []
attack_success_rates, val_attack_success_rates = [], []

for epoch in range(epochs):
    # The rate is bound to `train_attack_success_rate`, not
    # `attack_success_rate`: the latter is the name of a helper function
    # defined earlier in the notebook, and rebinding it to a float here made
    # that function unusable after this cell had run.
    g_loss, d_loss, train_attack_success_rate = train_epoch(
        epoch, G, D, train_loader, optimizer_G, optimizer_D, target_model, alpha, beta, epochs
    )
    g_losses.append(g_loss)
    d_losses.append(d_loss)
    attack_success_rates.append(train_attack_success_rate)

    val_g_loss, val_d_loss, val_attack_success_rate = validate_epoch(
        G, D, val_loader, target_model, alpha, beta
    )
    val_g_losses.append(val_g_loss)
    val_d_losses.append(val_d_loss)
    val_attack_success_rates.append(val_attack_success_rate)

    print(f"Epoch {epoch+1}/{epochs} | g_loss: {g_loss:.4f}, d_loss: {d_loss:.4f}, val_g_loss: {val_g_loss:.4f}, val_d_loss: {val_d_loss:.4f}, attack_success_rate: {train_attack_success_rate:.4f}, val_attack_success_rate: {val_attack_success_rate:.4f}")
    g_scheduler.step()
    d_scheduler.step()
    print(f"Learning rate for generator: {g_scheduler.get_last_lr()[0]}, Learning rate for discriminator: {d_scheduler.get_last_lr()[0]}")

    if val_g_loss < best_val_loss:
        best_val_loss = val_g_loss
        patience_counter = 0
        torch.save(G.state_dict(), 'best_generator.pth')
        torch.save(D.state_dict(), 'best_discriminator.pth')
    else:
        print(f"Patience counter: {patience_counter + 1}/{patience}")
        patience_counter += 1
    if patience_counter >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

                                                                                                    

Epoch 1/50 | g_loss: 169.3494, d_loss: 0.0154, val_g_loss: 127.7969, val_d_loss: 0.2079, attack_success_rate: 0.8641, val_attack_success_rate: 0.8797
Learning rate for generator: 0.001, Learning rate for discriminator: 0.001


                                                                                                    

Epoch 2/50 | g_loss: 129.4487, d_loss: 0.0000, val_g_loss: 118.6016, val_d_loss: 0.1348, attack_success_rate: 0.8860, val_attack_success_rate: 0.8984
Learning rate for generator: 0.001, Learning rate for discriminator: 0.001


                                                                                                    

Epoch 3/50 | g_loss: 118.6614, d_loss: 0.0000, val_g_loss: 110.9631, val_d_loss: 0.2069, attack_success_rate: 0.8967, val_attack_success_rate: 0.8882
Learning rate for generator: 0.001, Learning rate for discriminator: 0.001


                                                                                                    

Epoch 4/50 | g_loss: 112.9163, d_loss: 0.0000, val_g_loss: 111.0995, val_d_loss: 0.4258, attack_success_rate: 0.9020, val_attack_success_rate: 0.8950
Learning rate for generator: 0.001, Learning rate for discriminator: 0.001
Patience counter: 1/12


                                                                                                    

Epoch 5/50 | g_loss: 108.1064, d_loss: 0.0000, val_g_loss: 104.4267, val_d_loss: 0.5058, attack_success_rate: 0.9091, val_attack_success_rate: 0.9002
Learning rate for generator: 0.001, Learning rate for discriminator: 0.001


                                                                                                    

Epoch 6/50 | g_loss: 106.1960, d_loss: 0.0000, val_g_loss: 102.7951, val_d_loss: 0.5144, attack_success_rate: 0.9112, val_attack_success_rate: 0.9244
Learning rate for generator: 0.001, Learning rate for discriminator: 0.001


                                                                                                    

Epoch 7/50 | g_loss: 103.2523, d_loss: 0.0000, val_g_loss: 107.6912, val_d_loss: 0.5481, attack_success_rate: 0.9121, val_attack_success_rate: 0.9028
Learning rate for generator: 0.001, Learning rate for discriminator: 0.001
Patience counter: 1/12


                                                                                                    

Epoch 8/50 | g_loss: 100.5570, d_loss: 0.0000, val_g_loss: 98.7040, val_d_loss: 0.4503, attack_success_rate: 0.9167, val_attack_success_rate: 0.9286
Learning rate for generator: 0.001, Learning rate for discriminator: 0.001


                                                                                                    

Epoch 9/50 | g_loss: 98.2213, d_loss: 0.0000, val_g_loss: 100.8930, val_d_loss: 0.4654, attack_success_rate: 0.9198, val_attack_success_rate: 0.8936
Learning rate for generator: 0.001, Learning rate for discriminator: 0.001
Patience counter: 1/12


                                                                                                    

Epoch 10/50 | g_loss: 97.5302, d_loss: 0.0000, val_g_loss: 100.0644, val_d_loss: 0.4982, attack_success_rate: 0.9218, val_attack_success_rate: 0.9152
Learning rate for generator: 0.0005, Learning rate for discriminator: 0.0005
Patience counter: 2/12


                                                                                                    

Epoch 11/50 | g_loss: 91.8798, d_loss: 0.0000, val_g_loss: 100.0839, val_d_loss: 0.6150, attack_success_rate: 0.9264, val_attack_success_rate: 0.8870
Learning rate for generator: 0.0005, Learning rate for discriminator: 0.0005
Patience counter: 3/12


                                                                                                    

Epoch 12/50 | g_loss: 89.0996, d_loss: 0.0000, val_g_loss: 91.8105, val_d_loss: 0.5694, attack_success_rate: 0.9297, val_attack_success_rate: 0.9075
Learning rate for generator: 0.0005, Learning rate for discriminator: 0.0005


                                                                                                    

Epoch 13/50 | g_loss: 87.7193, d_loss: 0.0000, val_g_loss: 94.3415, val_d_loss: 0.6295, attack_success_rate: 0.9332, val_attack_success_rate: 0.9308
Learning rate for generator: 0.0005, Learning rate for discriminator: 0.0005
Patience counter: 1/12


                                                                                                    

Epoch 14/50 | g_loss: 86.2813, d_loss: 0.0000, val_g_loss: 90.6174, val_d_loss: 0.7004, attack_success_rate: 0.9360, val_attack_success_rate: 0.9338
Learning rate for generator: 0.0005, Learning rate for discriminator: 0.0005


                                                                                                    

Epoch 15/50 | g_loss: 85.0467, d_loss: 0.0000, val_g_loss: 92.9536, val_d_loss: 0.5762, attack_success_rate: 0.9372, val_attack_success_rate: 0.9112
Learning rate for generator: 0.0005, Learning rate for discriminator: 0.0005
Patience counter: 1/12


                                                                                                    

Epoch 16/50 | g_loss: 83.4629, d_loss: 0.0000, val_g_loss: 91.4607, val_d_loss: 0.5922, attack_success_rate: 0.9388, val_attack_success_rate: 0.9103
Learning rate for generator: 0.0005, Learning rate for discriminator: 0.0005
Patience counter: 2/12


Training Epoch 17/50:  72%|▋| 226/313 [00:15<00:05, 14.60batch/s, Batch=228, G_Loss=85, D_Loss=2.38e

In [None]:
# best_val_loss = float('inf')
# patience_counter = 0
# g_losses, d_losses, g_accuracies = [], [], []
# val_g_losses, val_d_losses = [], []
# attack_success_rates, val_attack_success_rates = [], []
# # g_accuracies, val_g_accuracies = [], []
# # model_accuracies, val_model_accuracies = [], []
# for epoch in range(epochs):
#     g_loss, d_loss, attack_success_rate = train_epoch(
#         epoch, G, D, train_loader, optimizer_G, optimizer_D, target_model, alpha, beta
#     )
#     g_losses.append(g_loss)
#     d_losses.append(d_loss)
#     attack_success_rates.append(attack_success_rate)
    
#     val_g_loss, val_d_loss, val_attack_success_rate = validate_epoch(
#         G, D, val_loader, target_model, alpha, beta
#     )
#     val_g_losses.append(val_g_loss)
#     val_d_losses.append(val_d_loss)
#     val_attack_success_rates.append(val_attack_success_rate)
    
#     print(f"Epoch {epoch+1}/{epochs} | g_loss: {g_loss:.4f}, d_loss: {d_loss:.4f}, val_g_loss: {val_g_loss:.4f}, val_d_loss: {val_d_loss:.4f}, attack_success_rate: {attack_success_rate:.4f}, val_attack_success_rate: {val_attack_success_rate:.4f}")    
#     if val_g_loss < best_val_loss:
#         best_val_loss = val_g_loss
#         patience_counter = 0
#         torch.save(G.state_dict(), 'best_generator.pth')
#         torch.save(D.state_dict(), 'best_discriminator.pth')
#     else:
#         print(f"Patience counter: {patience_counter + 1}/{patience}")
#         patience_counter += 1

#     if patience_counter >= patience:
#         print(f"Early stopping triggered after {epoch + 1} epochs.")
#         print("Early stopping triggered.")
#         break

In [None]:
# This cell used to repeat the body of plot_losses() line for line; call the
# helper so the figure is defined in one place. plot_losses takes the x-axis
# length from len(g_losses), so it also handles an early-stopped run.
plot_losses(
    epochs,
    g_losses, d_losses, attack_success_rates,
    val_g_losses, val_d_losses, val_attack_success_rates,
)

In [None]:
# def train_gan(
#     G, D, target_model, train_loader, val_loader,
#     epochs, optimizer_G, optimizer_D, c, alpha, beta, 
#     g_scheduler, d_scheduler, patience
# ):
#     g_losses, d_losses, g_accuracies = [], [], []
#     val_g_losses, val_d_losses, val_g_accuracies = [], [], []
#     model_accuracies, attack_success_rates = [], []
#     val_model_accuracies, val_attack_success_rates = [], []
#     best_g_loss, best_d_loss = float('inf'), float('inf')
#     epochs_no_improve = 0

#     for epoch in range(epochs):
#         correct_adv, correct_benign, total_samples = 0, 0, 0
#         g_loss_epoch, d_loss_epoch = 0.0, 0.0

#         tepoch = tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}/{epochs}", ncols=100, leave=False)

#         for batch_idx, (images, labels) in enumerate(tepoch):
#             images, labels = images.to(device), labels.to(device)
#             G.train()
#             D.train()
#             target_model.eval()
        
#             """
#             Generator Update
#             """
#             optimizer_G.zero_grad()
#             perturbations = G(images)  # Recompute perturbations for generator update
#             g_loss = LossFunctions.total_loss(
#                 D, target_model,
#                 alpha, beta, c, images, perturbations, labels
#             )
#             g_loss.backward()
#             optimizer_G.step()
#             """
#             Discriminator Update
#             """
#             optimizer_D.zero_grad()
#             with torch.no_grad():  # Disable gradient tracking for generator during discriminator update
#                 perturbations = G(images)
#                 adv_images = torch.clamp(images + perturbations, -1, 1)
        
#             d_loss = LossFunctions.gan_loss(D, images, perturbations)
#             d_loss.backward()
#             optimizer_D.step() 

            
#             # Track losses and accuracy
#             g_loss_epoch += g_loss.item()
#             d_loss_epoch += d_loss.item()
        
#             with torch.no_grad():
#                 benign_output = target_model(images)
#                 _, benign_predicted = benign_output.max(1)
#                 adv_output = target_model(adv_images)
#                 _, adv_predicted = adv_output.max(1)
        
#                 correct_benign += (benign_predicted == labels).sum().item()
#                 correct_adv += (adv_predicted != benign_predicted).sum().item()
#                 total_samples += labels.size(0)
        
#             tepoch.set_postfix({'Batch': batch_idx + 1, 'G_Loss': g_loss.item(), 'D_Loss': d_loss.item()})
        
#         g_losses.append(g_loss_epoch / len(train_loader))
#         d_losses.append(d_loss_epoch / len(train_loader))
#         # len(train_loader):313 O total_samples:40000
#         print(f"len(train_loader):{len(train_loader)} O total_samples:{total_samples}")
#         print(f"correct_adv:{correct_adv} O total_samples:{total_samples}")
#         g_accuracies.append(correct_adv / total_samples)
#         model_accuracies.append(correct_benign / total_samples)
#         attack_success_rates.append(correct_adv / total_samples)

#         G.eval()
#         D.eval()
#         target_model.eval()
#         val_g_loss_epoch, val_d_loss_epoch = 0.0, 0.0
#         val_correct_adv, val_correct_benign, val_samples = 0, 0, 0
#         with torch.no_grad():
#             for val_images, val_labels in val_loader:
#                 val_images, val_labels = val_images.to(device), val_labels.to(device)

#                 val_perturbations = G(val_images)
#                 val_adv_images = torch.clamp(val_images + val_perturbations, -1, 1)

#                 val_d_loss = LossFunctions.gan_loss(D, val_images, val_perturbations)
#                 val_g_loss = LossFunctions.total_loss(
#                     D, target_model,
#                     alpha, beta, c, val_images, val_perturbations, val_labels
#                 )

#                 val_g_loss_epoch += val_g_loss.item()
#                 val_d_loss_epoch += val_d_loss.item()

#                 val_benign_output = target_model(val_images)
#                 _, val_benign_predicted = val_benign_output.max(1)
#                 val_adv_output = target_model(val_adv_images)
#                 _, val_adv_predicted = val_adv_output.max(1)

#                 val_correct_benign += (val_benign_predicted == val_labels).sum().item()
#                 val_correct_adv += (val_adv_predicted != val_labels).sum().item()
#                 val_samples += val_labels.size(0)

#         g_scheduler.step()
#         d_scheduler.step()
        
#         val_g_losses.append(val_g_loss_epoch / len(val_loader))
#         val_d_losses.append(val_d_loss_epoch / len(val_loader))
#         val_g_accuracies.append(val_correct_adv / val_samples)
#         val_model_accuracies.append(val_correct_benign / val_samples)
#         val_attack_success_rates.append(val_correct_adv / val_samples)

#         print(f"Epoch [{epoch+1}/{epochs}], "
#               f"Train G_Loss: {g_losses[-1]:.4f}, D_Loss: {d_losses[-1]:.4f}, Model Accuracy: {model_accuracies[-1]:.4f}, "
#               f"Attack Success Rate: {attack_success_rates[-1]:.4f}, "
#               f"Val G_Loss: {val_g_losses[-1]:.4f}, Val D_Loss: {val_d_losses[-1]:.4f}, "
#               f"Val Model Accuracy: {val_model_accuracies[-1]:.4f}, Val Attack Success Rate: {val_attack_success_rates[-1]:.4f}")
#         if val_g_losses[-1] < best_g_loss or val_d_losses[-1] < best_d_loss:
#             best_g_loss = min(val_g_losses[-1], best_g_loss)
#             best_d_loss = min(val_d_losses[-1], best_d_loss)
#             epochs_no_improve = 0
#         else:
#             epochs_no_improve += 1
#             print(f"Epochs without improvement: {epochs_no_improve}/{patience}")
#             if epochs_no_improve >= patience:
#                 print(f"Early stopping triggered after {epoch + 1} epochs.")
#                 break
#     return (g_losses, d_losses, g_accuracies, model_accuracies, attack_success_rates,
#             val_g_losses, val_d_losses, val_g_accuracies, val_model_accuracies, val_attack_success_rates)

# (g_losses, d_losses, g_accuracies, model_accuracies, attack_success_rates, 
#  val_g_losses, val_d_losses, val_g_accuracies, val_model_accuracies, 
#  val_attack_success_rates) = train_gan(
#     G, D, target_model, train_loader, val_loader,
#     epochs, optimizer_G, optimizer_D, c, alpha, beta, 
#     g_scheduler, d_scheduler, patience
# )

# plot_losses(
#     epochs, g_losses, d_losses, g_accuracies, model_accuracies, 
#     attack_success_rates, val_g_losses, val_d_losses, 
#     val_g_accuracies, val_model_accuracies, val_attack_success_rates
# )

In [None]:
# Score the generator against the target model on the test loader, then print
# the overall rate followed by one line per entry of CLASSES.
overall_rate, per_class_rate = calculate_attack_success(target_model, test_loader, generator)

print(f"Overall Attack Success Rate: {overall_rate * 100:.2f}%")
for class_idx in range(len(CLASSES)):
    # per_class_rate is indexed by class position, in the same order as CLASSES.
    class_pct = per_class_rate[class_idx] * 100
    print(f"Class {CLASSES[class_idx]}: {class_pct:.2f}%")

In [None]:
# Run the notebook's visualize_adversarial_samples helper with the generator and the test loader.
# NOTE(review): presumably displays clean vs. adversarial test images — confirm against the helper's definition.
visualize_adversarial_samples(generator, test_loader)

In [None]:
# Run the notebook's plot_confidence_histograms helper on the target model, the test loader and the generator.
# NOTE(review): presumably compares target-model confidence on clean vs. adversarial inputs — confirm against the helper's definition.
plot_confidence_histograms(target_model, test_loader, generator)

In [None]:
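# NOTE(review): every line of this cell is commented out, so nothing here runs; it is an inactive
# train-only draft of train_gan. In this draft g_accuracies divides correct_adv (a per-sample count)
# by len(train_loader) (the number of batches), so the value is not a fraction of samples; divide by
# the number of samples seen if the cell is restored.
# def train_gan(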
#     generator, discriminator, target_model, train_loader, 
#     epochs, g_optimizer, d_optimizer, c, alpha, beta, 
#     g_scheduler, d_scheduler, patience
# ):
#     g_losses, d_losses, g_accuracies = [], [], []
#     best_g_loss, best_d_loss = float('inf'), float('inf')
#     epochs_no_improve = 0

#     for epoch in range(epochs):
#         correct_adv = 0
#         g_loss_epoch, d_loss_epoch = 0.0, 0.0
#         tepoch = tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}/{epochs}", ncols=100, leave=False)

#         generator.train()
#         discriminator.train()
#         for batch_idx, (images, labels) in enumerate(tepoch):
#             images, labels = images.to(device), labels.to(device)

#             perturbations = generator(images)
#             adv_images = torch.clamp(images + perturbations, -1, 1)

#             d_optimizer.zero_grad()
#             d_loss = LossFunctions.gan_loss(
#                 discriminator, images, perturbations
#             )
#             d_loss.backward()
#             d_optimizer.step()

#             g_optimizer.zero_grad()
#             g_loss = LossFunctions.total_loss(
#                 discriminator, target_model, 
#                 alpha, beta, c, images, perturbations
#             )
#             g_loss.backward()
#             g_optimizer.step()

#             g_loss_epoch += g_loss.item()
#             d_loss_epoch += d_loss.item()
#             adv_output = target_model(adv_images)
#             _, predicted = adv_output.max(1)
#             correct_adv += (predicted != labels).sum().item()
#             tepoch.set_postfix({'Batch': batch_idx + 1, 'G_Loss': g_loss.item(), 'D_Loss': d_loss.item()})

#         g_losses.append(g_loss_epoch / len(train_loader))
#         d_losses.append(d_loss_epoch / len(train_loader))
#         g_accuracies.append(correct_adv / len(train_loader))
        
#         print(f"Epoch [{epoch+1}/{epochs}], Generator Loss: {g_losses[-1]:.4f}, Discriminator Loss: {d_losses[-1]:.4f}, Fooling Accuracy: {g_accuracies[-1]:.4f}")
#         generator.eval()
#         g_scheduler.step()
        
#         discriminator.eval()
#         d_scheduler.step()

#         if g_losses[-1] < best_g_loss or d_losses[-1] < best_d_loss:
#             best_g_loss = min(g_losses[-1], best_g_loss)
#             best_d_loss = min(d_losses[-1], best_d_loss)
#             epochs_no_improve = 0
#         else:
#             print(f"Epochs without improvement: {epochs_no_improve + 1}/{patience}")
#             epochs_no_improve += 1
#             if epochs_no_improve >= patience:
#                 print(f"Early stopping triggered after {epoch + 1} epochs.")
#                 break

#     return g_losses, d_losses, g_accuracies

# g_losses, d_losses, g_accuracies = train_gan(
#     generator, discriminator, target_model, train_loader, 
#     epochs, g_optimizer, d_optimizer, c, alpha, beta, 
#     g_scheduler, d_scheduler, patience
# )

# plot_losses(epochs, g_losses, d_losses, g_accuracies)

In [None]:
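# NOTE(review): every line of this cell is commented out, so nothing here runs; it is an inactive
# targeted-attack experiment. The fooling_loss line below uses adv_loss_fn, adv_output and labels,
# none of which are assigned in this cell.
# target_class = 0  # Example: Force adversarial images to be classified as "airplane" (class 0)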
# fooling_loss = adv_loss_fn(adv_output, torch.full_like(labels, target_class))

# generator = Generator().to(device)
# discriminator = Discriminator().to(device)

# epochs = 50 
# lr = 0.001
# alpha = 1.0  # Weight for GAN loss
# beta = 10.0  # Weight for hinge loss
# # c = 0.1  # Perturbation bound
# c = 8/255 # Perturbation bound (c) = 8/255 for CIFAR-10

# g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
# d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# g_scheduler = StepLR(g_optimizer, step_size=10, gamma=0.5)
# d_scheduler = StepLR(d_optimizer, step_size=10, gamma=0.5)

# g_losses, d_losses, g_accuracies = train_gan(
#     generator, discriminator, target_model, train_loader, 
#     epochs, g_optimizer, d_optimizer, c, alpha, beta, 
#     g_scheduler, d_scheduler, patience
# )

# plot_losses(epochs, g_losses, d_losses, g_accuracies)
# overall_rate, per_class_rate = calculate_targeted_attack_success(target_model, test_loader, generator, target_class)

# print(f"Overall Targeted Attack Success Rate: {overall_rate * 100:.2f}%")
# for i, class_name in enumerate(CLASSES):
#     print(f"Class {class_name}: {per_class_rate[i] * 100:.2f}%")

# visualize_adversarial_samples(generator, test_loader)
# plot_confidence_histograms(target_model, test_loader, generator)