In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

# CIFAR-10 via Keras: HWC uint8 image arrays plus label arrays indexed as labels[i][0]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Human-readable class names; position i is the name of integer label i
class_names = ("airplane automobile bird cat deer "
               "dog frog horse ship truck").split()

In [None]:
# Import libraries
import torch
import torchvision.transforms as transforms
from torchvision import models

# ImageNet channel statistics used by both pipelines below
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing (validation/test and the "clean" copy of training data):
# uint8 HWC array -> PIL image -> 224x224 -> float tensor in [0, 1] -> ImageNet-normalized.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Data augmentation for training: same final normalization as `transform`, with random
# geometric and photometric perturbations applied to the PIL image first.
robust_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(15),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),  # sigma drawn per call
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

In [None]:
import torchvision.transforms.functional as F

def unnormalize(img_tensor):
    """Undo ImageNet mean/std normalization.

    img_tensor: a batch (B, 3, H, W) when 4-D; any other rank is treated as a single
        image laid out as (3, H, W).
    Returns ``img_tensor * std + mean`` on the same device. No clamping is applied.
    """
    stat_shape = (1, 3, 1, 1) if img_tensor.dim() == 4 else (3, 1, 1)
    dev = img_tensor.device
    channel_mean = torch.tensor([0.485, 0.456, 0.406], device=dev).view(*stat_shape)
    channel_std = torch.tensor([0.229, 0.224, 0.225], device=dev).view(*stat_shape)
    return img_tensor * channel_std + channel_mean

In [None]:
def convert_normalization(imgs):
    """
    Re-express a batch normalized with mean=0.5 / std=0.5 (pixels in [-1, 1]) in
    GoogLeNet's ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).

    imgs: torch.Tensor of shape (B, 3, H, W)
    Returns: torch.Tensor of the same shape and device, normalized for GoogLeNet
    """
    dev = imgs.device
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=dev).view(1, 3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=dev).view(1, 3, 1, 1)
    unit_range = imgs * 0.5 + 0.5  # [-1, 1] -> [0, 1]
    return (unit_range - imagenet_mean) / imagenet_std

In [None]:
import torch.nn as nn
from torchvision.models import GoogLeNet_Weights

def get_googlenet(pretrained=True, num_classes=10):
    """
    Build a torchvision GoogLeNet whose final ``fc`` layer is replaced by a fresh
    ``num_classes``-way linear head (10 for CIFAR-10).

    pretrained: start from ImageNet ``IMAGENET1K_V1`` weights when True, otherwise from
        random initialization.
    num_classes: output size of the new classifier head (default 10).
    Returns: the ``torch.nn.Module``; its forward returns a plain logits tensor.
    """
    if pretrained:
        # With weights, torchvision's builder removes the auxiliary heads unless
        # aux_logits=True is requested explicitly.
        model = models.googlenet(weights=GoogLeNet_Weights.IMAGENET1K_V1)
    else:
        # Without weights, GoogLeNet defaults to aux_logits=True. In train mode it then
        # returns a GoogLeNetOutputs tuple, which nn.CrossEntropyLoss cannot consume, and the
        # aux heads would keep 1000 ImageNet outputs. Disable them explicitly.
        model = models.googlenet(weights=None, aux_logits=False, init_weights=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

In [None]:
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from sklearn.model_selection import train_test_split

# Dataset for original images
class CIFAR10TorchDataset(Dataset):
    """Map-style dataset over in-memory CIFAR-10 arrays.

    images: indexable of HWC image arrays (cast to uint8 before transforming).
    labels: indexable of per-image labels, either shape (N, 1) (as returned by Keras) or (N,).
    transform: optional callable applied to each uint8 image.

    Items are ``(image, label)`` where ``label`` is a Python ``int``.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx].astype('uint8')
        # Keras stores CIFAR-10 labels as uint8. A numpy uint8 scalar would be collated into a
        # torch.uint8 tensor, which CrossEntropyLoss rejects because it needs int64 class
        # indices. A Python int collates to int64. ravel() accepts (N, 1) and (N,) labels.
        label = int(np.ravel(self.labels[idx])[0])
        if self.transform:
            img = self.transform(img)
        return img, label

# Stratify the split on a flat copy of the label array
y_train_flat = y_train.flatten()

# Hold out 20% of the training images for validation, stratified by class and seeded
x_train_split, x_val, y_train_split, y_val = train_test_split(
    x_train,
    y_train,
    test_size=0.2,
    random_state=42,
    stratify=y_train_flat,
)

# Training data contains the split twice: once through the deterministic `transform`
# and once through `robust_transform` augmentation. Val/test use only `transform`.
train_dataset_normal, train_dataset_robust = (
    CIFAR10TorchDataset(x_train_split, y_train_split, transform=tfm)
    for tfm in (transform, robust_transform)
)
from torch.utils.data import ConcatDataset
train_dataset = ConcatDataset([train_dataset_normal, train_dataset_robust])

val_dataset = CIFAR10TorchDataset(x_val, y_val, transform=transform)
test_dataset = CIFAR10TorchDataset(x_test, y_test, transform=transform)

# DataLoaders: only the training loader is shuffled
_loader_kwargs = dict(batch_size=128)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
from tqdm import tqdm
import torch.optim as optim
import foolbox as fb

def _validate_clean(model, loader, criterion, device, desc):
    """Evaluate ``model`` (in eval mode) on clean batches; return (mean_loss, accuracy)."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    loop_val = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for images, labels in loop_val:
            images = images.to(device)
            labels = labels.to(device).reshape(-1).long()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss_sum += loss.item() * images.size(0)
            n_correct += outputs.argmax(dim=1).eq(labels).sum().item()
            n_seen += labels.size(0)
            loop_val.set_postfix(loss=loss.item())
    if n_seen == 0:
        raise ValueError("validation loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


def adversarial_train_pgd(
    model, train_loader, val_loader, device, epochs=5, lr=1e-3, patience=5,
    epsilon=0.05, alpha=0.5, steps=10, rel_stepsize=0.1
):
    """
    Fine-tune ``model`` with PGD adversarial training and clean-accuracy early stopping.

    Every training batch is attacked with Foolbox ``LinfPGD``. The optimized loss is
    ``alpha * CE(clean) + (1 - alpha) * CE(adversarial)``.

    model: classifier taking ImageNet-normalized (B, 3, H, W) input and returning logits.
    train_loader / val_loader: yield (images, labels) with ImageNet-normalized images.
    device: torch device used for training and attacks.
    epochs: maximum number of epochs.
    lr: Adam learning rate.
    patience: consecutive epochs without val-accuracy improvement before stopping.
    epsilon: L-inf attack radius in the [-1, 1] pixel scale of the Foolbox wrapper
        (i.e. epsilon / 2 in [0, 1] pixel units).
    alpha: weight of the clean loss; ``1 - alpha`` weights the adversarial loss.
    steps / rel_stepsize: LinfPGD iterations and step size relative to epsilon.

    Returns (best_val_acc, best_weights): the best clean validation accuracy and a deep
    copy of the matching ``state_dict``. On return, ``model`` holds those best weights.
    """
    import copy

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The attack works on pixels in [-1, 1]. Foolbox feeds the model (x - mean) / std, so pick
    # constants that reproduce ImageNet normalization:
    # ((x + 1) / 2 - m) / s == (x - (2m - 1)) / (2s).
    # Without this the PGD gradients would be taken on mis-normalized inputs.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    fb_preprocessing = dict(
        mean=[2 * m - 1 for m in imagenet_mean],
        std=[2 * s for s in imagenet_std],
        axis=-3,
    )
    model.eval()  # build the wrapper in eval mode (Foolbox warns about train-mode models)
    fmodel = fb.PyTorchModel(model, bounds=(-1, 1), device=device, preprocessing=fb_preprocessing)
    attack_pgd = fb.attacks.LinfPGD(steps=steps, rel_stepsize=rel_stepsize)

    patience_counter = 0
    best_val_acc = 0.0
    best_weights = copy.deepcopy(model.state_dict())

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [AdvTrain]", leave=False)
        for images, labels in loop:
            images = images.to(device)
            # reshape(-1) instead of squeeze(): squeeze() turns a size-1 batch into a 0-d tensor
            labels = labels.to(device).reshape(-1).long()

            # ImageNet-normalized -> [-1, 1]; clamp float round-off so the attack's input
            # bounds check cannot fail
            images_pgd = (unnormalize(images) * 2 - 1).clamp(-1, 1)

            # Generate adversarial examples with BN/dropout frozen
            model.eval()
            advs, _, _ = attack_pgd(fmodel, images_pgd, labels, epsilons=epsilon)
            advs = convert_normalization(advs.detach())  # back to ImageNet normalization
            model.train()

            # Weighted clean/adversarial loss (Goodfellow style)
            outputs_clean = model(images)
            outputs_adv = model(advs)
            loss_clean = criterion(outputs_clean, labels)
            loss_adv = criterion(outputs_adv, labels)
            loss = alpha * loss_clean + (1 - alpha) * loss_adv

            # zero_grad also clears any parameter gradients left by the attack's backward passes
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accuracy on clean data (for info)
            correct += outputs_clean.argmax(dim=1).eq(labels).sum().item()
            total += labels.size(0)
            running_loss += loss.item() * labels.size(0)
            loop.set_postfix(loss=loss.item())
        train_loss = running_loss / total
        train_acc = correct / total

        # Validation (on clean data)
        val_loss, val_acc = _validate_clean(
            model, val_loader, criterion, device, desc=f"Epoch {epoch+1}/{epochs} [Val]"
        )

        print(f"Epoch {epoch+1}/{epochs} | "
              f"AdvTrain Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Early stopping on clean validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            best_weights = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Keep the in-memory model consistent with the returned weights whether or not we stopped early
    model.load_state_dict(best_weights)
    print(f"Best validation accuracy for alpha={alpha}: {best_val_acc:.4f}")
    return best_val_acc, best_weights

In [None]:
def train_multiple_alphas_epsilons_and_save_best_pgd(
    model_fn, train_loader, val_loader, device, alphas, epsilons,
    epochs=5, lr=1e-3, patience=5, steps=10, rel_stepsize=0.1, save_path="best_adv_pgd_aug.pth"
):
    """
    Grid-search (alpha, epsilon) for PGD adversarial training and save the winning weights.

    For each pair in ``alphas`` x ``epsilons`` (alpha-major order), a fresh model is built by
    ``model_fn()`` and trained with ``adversarial_train_pgd``. A run replaces the current best
    only if its returned validation accuracy is strictly greater (starting from 0.0). The best
    weights are written with ``torch.save`` to ``save_path``.

    NOTE(review): selection uses clean validation accuracy, which likely favours the weakest
    adversarial setting; consider selecting on robust accuracy instead.

    Returns (best_acc, best_alpha, best_epsilon, save_path). When no run beats 0.0, nothing
    is saved, best_alpha / best_epsilon are None, and save_path is still returned.
    """
    best_acc, best_weights, best_alpha, best_epsilon = 0.0, None, None, None

    grid = [(alpha, epsilon) for alpha in alphas for epsilon in epsilons]
    for alpha, epsilon in grid:
        print(f"\nTraining with alpha={alpha}, epsilon={epsilon}")
        val_acc, weights = adversarial_train_pgd(
            model_fn(), train_loader, val_loader, device,
            epochs=epochs, lr=lr, patience=patience,
            epsilon=epsilon, alpha=alpha, steps=steps, rel_stepsize=rel_stepsize,
        )
        if val_acc > best_acc:
            best_acc, best_weights, best_alpha, best_epsilon = val_acc, weights, alpha, epsilon

    if best_weights is not None:
        torch.save(best_weights, save_path)
        print(f"Saved best model (alpha={best_alpha}, epsilon={best_epsilon}, val_acc={best_acc:.4f}) to {save_path}")
    return best_acc, best_alpha, best_epsilon, save_path

In [None]:
# Report the CUDA environment; safe to run on CPU-only machines too.
cuda_ok = torch.cuda.is_available()
print(cuda_ok)  # True when a CUDA device is usable
print(torch.cuda.device_count())  # number of visible CUDA devices (0 without CUDA)
# get_device_name(0) raises when no CUDA device exists, so only query it when one does
print(torch.cuda.get_device_name(0) if cuda_ok else "No CUDA device available")

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
model_fgsm = get_googlenet(True)  # NOTE(review): appears unused by the visible cells; TODO confirm or remove

In [None]:
print(f"Using device: {device}")
# NOTE(review): num_epochs / learning_rate are not used by the search below, which passes
# epochs=5 and lr=1e-3 explicitly. TODO: decide which values are intended.
num_epochs = 100
learning_rate = 1e-4

alphas = [0.3, 0.5, 0.8]  # weight on the clean-loss term (1 - alpha weights the adversarial term)
epsilons = [0.05, 0.1]    # L-inf PGD radii in the [-1, 1] pixel scale
best_acc, best_alpha, best_epsilon, best_path = train_multiple_alphas_epsilons_and_save_best_pgd(
    lambda: get_googlenet(pretrained=True),
    train_loader, val_loader, device,
    alphas=alphas,
    epsilons=epsilons,  # keyword must match the parameter name `epsilons`
    epochs=5,
    lr=1e-3,
    patience=5,
    save_path="best_adv_pgd_aug.pth",
)

In [None]:
import pandas as pd
from skimage.metrics import peak_signal_noise_ratio as psnr

def _per_image_psnr(clean01, adv01):
    """PSNR (data_range=1.0) of each adversarial image against its clean counterpart.

    clean01, adv01: numpy arrays of identical shape (B, ...) with values in [0, 1].
    Returns a list of B PSNR values.
    """
    return [psnr(clean, adv, data_range=1.0) for clean, adv in zip(clean01, adv01)]


def run_attacks_metrics(model, test_loader, device, epsilons=(0.01, 0.03, 0.05)):
    """
    Evaluate ``model`` under Foolbox FGSM and L-inf PGD attacks at several epsilons.

    model: classifier taking ImageNet-normalized (B, 3, H, W) input and returning logits.
    test_loader: yields (images, labels) whose images are normalized with mean=0.5 / std=0.5,
        i.e. pixels in [-1, 1]. The attacks operate in this space. The loader is iterated
        once per epsilon.
    device: device on which attacks and inference run.
    epsilons: iterable of L-inf attack radii in the [-1, 1] pixel scale.

    Returns a dict with:
      - "fgsm_conf_matrices" / "pgd_conf_matrices": one 10x10 confusion matrix per epsilon
        (rows = true label, cols = prediction under attack);
      - "fgsm_agg_conf" / "pgd_agg_conf": those matrices summed over all epsilons;
      - "class_names": tuple of CIFAR-10 class names;
      - "results": DataFrame with epsilon, clean_acc, fgsm_acc, pgd_acc (percent);
      - "<attack>_psnr_per_eps" (epsilon -> mean PSNR) and "<attack>_psnr_per_class"
        (epsilon -> list of per-class mean PSNR, nan for empty classes), on [0, 1] pixels.
    """
    import foolbox as fb

    model.eval()
    # Foolbox evaluates model((x - mean) / std). With x in [-1, 1] these constants yield exactly
    # ImageNet normalization, since ((x + 1) / 2 - m) / s == (x - (2m - 1)) / (2s). Without
    # them the attacks would be computed against mis-normalized inputs.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    preprocessing = dict(
        mean=[2 * m - 1 for m in imagenet_mean],
        std=[2 * s for s in imagenet_std],
        axis=-3,
    )
    fmodel = fb.PyTorchModel(model, bounds=(-1, 1), device=device, preprocessing=preprocessing)
    attack_fgsm = fb.attacks.FGSM()
    attack_pgd = fb.attacks.LinfPGD(steps=10, rel_stepsize=0.1)

    class_names = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    num_classes = len(class_names)
    results = []

    # Confusion summed over every epsilon
    fgsm_agg_conf = np.zeros((num_classes, num_classes), dtype=int)
    pgd_agg_conf = np.zeros((num_classes, num_classes), dtype=int)
    fgsm_conf_matrices = []
    pgd_conf_matrices = []

    # PSNR metrics
    fgsm_psnr_per_eps = {}
    fgsm_psnr_per_class = {}
    pgd_psnr_per_eps = {}
    pgd_psnr_per_class = {}

    for eps in epsilons:
        clean_correct = fgsm_correct = pgd_correct = total = 0
        fgsm_confusion = np.zeros((num_classes, num_classes), dtype=int)
        pgd_confusion = np.zeros((num_classes, num_classes), dtype=int)
        fgsm_psnr_list, pgd_psnr_list = [], []
        fgsm_psnr_by_class = [[] for _ in range(num_classes)]
        pgd_psnr_by_class = [[] for _ in range(num_classes)]

        # Stream batches instead of caching the whole upsampled test set in memory
        # (10k x 3 x 224 x 224 float32 is roughly 6 GB).
        for batch_imgs, batch_lbls in test_loader:
            batch_imgs = batch_imgs.to(device)
            # Foolbox and cross-entropy need 1-D int64 labels
            batch_lbls = batch_lbls.to(device).reshape(-1).long()

            advs_fgsm, _, _ = attack_fgsm(fmodel, batch_imgs, batch_lbls, epsilons=eps)
            advs_pgd, _, _ = attack_pgd(fmodel, batch_imgs, batch_lbls, epsilons=eps)

            with torch.no_grad():  # inference and metrics only; no autograd graphs needed
                clean_pred = model(convert_normalization(batch_imgs)).argmax(dim=1)
                fgsm_pred = model(convert_normalization(advs_fgsm)).argmax(dim=1)
                pgd_pred = model(convert_normalization(advs_pgd)).argmax(dim=1)

                # PSNR on [0, 1] pixels: map [-1, 1] -> [0, 1] and move each batch to CPU once
                clean01 = ((batch_imgs + 1) / 2).clamp(0, 1).cpu().numpy()
                fgsm01 = ((advs_fgsm + 1) / 2).clamp(0, 1).cpu().numpy()
                pgd01 = ((advs_pgd + 1) / 2).clamp(0, 1).cpu().numpy()

            clean_correct += (clean_pred == batch_lbls).sum().item()
            fgsm_correct += (fgsm_pred == batch_lbls).sum().item()
            pgd_correct += (pgd_pred == batch_lbls).sum().item()
            total += batch_lbls.size(0)

            labels_np = batch_lbls.cpu().numpy()
            # np.add.at accumulates correctly when (true, pred) pairs repeat within a batch
            np.add.at(fgsm_confusion, (labels_np, fgsm_pred.cpu().numpy()), 1)
            np.add.at(pgd_confusion, (labels_np, pgd_pred.cpu().numpy()), 1)

            fgsm_batch_psnr = _per_image_psnr(clean01, fgsm01)
            pgd_batch_psnr = _per_image_psnr(clean01, pgd01)
            for label, p_fgsm, p_pgd in zip(labels_np, fgsm_batch_psnr, pgd_batch_psnr):
                fgsm_psnr_list.append(p_fgsm)
                pgd_psnr_list.append(p_pgd)
                fgsm_psnr_by_class[int(label)].append(p_fgsm)
                pgd_psnr_by_class[int(label)].append(p_pgd)

        fgsm_agg_conf += fgsm_confusion
        pgd_agg_conf += pgd_confusion

        def _pct(n_correct):
            return 100 * n_correct / total if total else float('nan')

        results.append({'epsilon': eps, 'clean_acc': _pct(clean_correct),
                        'fgsm_acc': _pct(fgsm_correct), 'pgd_acc': _pct(pgd_correct)})

        # Store PSNR metrics (nan where no samples were seen)
        fgsm_psnr_per_eps[eps] = np.mean(fgsm_psnr_list) if fgsm_psnr_list else float('nan')
        fgsm_psnr_per_class[eps] = [np.mean(vals) if vals else float('nan') for vals in fgsm_psnr_by_class]
        pgd_psnr_per_eps[eps] = np.mean(pgd_psnr_list) if pgd_psnr_list else float('nan')
        pgd_psnr_per_class[eps] = [np.mean(vals) if vals else float('nan') for vals in pgd_psnr_by_class]

        # Fresh arrays every epsilon, so no defensive copy is needed
        fgsm_conf_matrices.append(fgsm_confusion)
        pgd_conf_matrices.append(pgd_confusion)

    df_results = pd.DataFrame(results)
    return {
        "fgsm_conf_matrices": fgsm_conf_matrices,
        "pgd_conf_matrices": pgd_conf_matrices,
        "fgsm_agg_conf": fgsm_agg_conf,
        "pgd_agg_conf": pgd_agg_conf,
        "class_names": class_names,
        "results": df_results,
        "fgsm_psnr_per_eps": fgsm_psnr_per_eps,
        "fgsm_psnr_per_class": fgsm_psnr_per_class,
        "pgd_psnr_per_eps": pgd_psnr_per_eps,
        "pgd_psnr_per_class": pgd_psnr_per_class
    }

In [None]:
def print_attack_metrics(metrics, attack_type="pgd", save_prefix=None):
    """
    Print and optionally save the aggregate confusion matrix, a per-class "most confused
    with" summary, and the accuracy table for FGSM or PGD attacks.

    metrics: dict with "class_names", "results" (DataFrame with an "epsilon" column),
        "<attack>_agg_conf" and, optionally, "<attack>_psnr_per_eps" (epsilon -> mean PSNR)
        and "<attack>_psnr_per_class" (epsilon -> list of per-class mean PSNR).
    attack_type: "fgsm" or "pgd"
    save_prefix: if provided, saves CSVs with this prefix (e.g., "student_pgd")

    The per-class summary uses the confusion aggregated over all epsilons. Its "Mean PSNR"
    column reports the last epsilon in the per-class PSNR dict.
    Raises AssertionError for an unknown attack_type.
    """
    import pandas as pd
    import numpy as np

    assert attack_type in ("fgsm", "pgd"), "attack_type must be 'fgsm' or 'pgd'"

    def _as_text_table(df):
        # DataFrame.to_markdown needs the optional `tabulate` package; fall back to plain text
        try:
            return df.to_markdown(index=False)
        except ImportError:
            return df.to_string(index=False)

    class_names = metrics["class_names"]
    psnr_per_eps = metrics.get(f"{attack_type}_psnr_per_eps")
    psnr_per_class = metrics.get(f"{attack_type}_psnr_per_class")

    print(f"\nAggregate {attack_type.upper()} confusion (all epsilons):")
    agg_conf = metrics[f"{attack_type}_agg_conf"]
    agg_df = pd.DataFrame(agg_conf, index=class_names, columns=class_names)
    print(agg_df)
    if save_prefix:
        agg_df.to_csv(f"{save_prefix}_agg_conf_{attack_type}.csv")

    if psnr_per_eps is None or psnr_per_class is None:
        print(f"Warning: {attack_type.upper()} PSNR metrics not found in metrics dict; PSNR columns will be empty.")

    # Per-class PSNR comes from the last epsilon in the dict (insertion order)
    last_eps_class_psnr = list(psnr_per_class.values())[-1] if psnr_per_class else None

    # Per-class confusion summary with mean PSNR per class
    summary = []
    for idx, row in enumerate(agg_conf):
        off_diag = np.array(row, copy=True)
        off_diag[idx] = 0  # ignore correct predictions
        total_confused = off_diag.sum()
        if total_confused == 0:
            most_confused, count, percentage = "-", 0, 0.0
        else:
            most_confused_idx = int(np.argmax(off_diag))
            most_confused = class_names[most_confused_idx]
            count = off_diag[most_confused_idx]
            # share of this class's misclassifications, not of all its samples
            percentage = 100.0 * count / total_confused
        mean_psnr = None if last_eps_class_psnr is None else last_eps_class_psnr[idx]
        summary.append({
            "True Label": class_names[idx],
            "Most Confused With": most_confused,
            "Count": count,
            "Percentage": f"{percentage:.2f}%",
            "Mean PSNR": f"{mean_psnr:.2f}" if mean_psnr is not None else "-"
        })
    summary_df = pd.DataFrame(summary)
    print(_as_text_table(summary_df))
    if save_prefix:
        summary_df.to_csv(f"{save_prefix}_perclass_{attack_type}.csv", index=False)

    # Accuracy table with mean PSNR per epsilon (nan when an epsilon has no PSNR entry)
    print(f"\nAccuracy Table ({attack_type.upper()}):")
    df_results = metrics["results"]
    if psnr_per_eps is not None:
        df_results = df_results.copy()
        df_results[f"mean_psnr_{attack_type}"] = df_results["epsilon"].map(
            lambda eps: f"{psnr_per_eps.get(eps, float('nan')):.2f}"
        )
    print(_as_text_table(df_results))
    if save_prefix:
        df_results.to_csv(f"{save_prefix}_accuracy_{attack_type}.csv", index=False)

In [None]:
# Foolbox setup: the attack/evaluation pipeline works on pixels in [-1, 1]
# (mean=0.5, std=0.5 normalization) rather than ImageNet statistics.
transform_fgsm = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # float CHW in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ]
)

test_dataset_attack = CIFAR10TorchDataset(x_test, y_test, transform=transform_fgsm)
test_loader_attack = DataLoader(test_dataset_attack, shuffle=False, batch_size=128)

In [None]:
def evaluate_model_metrics(model_path, pretrained=True, epsilons=None):
    """
    Load GoogLeNet weights from ``model_path``, evaluate them against FGSM and PGD on
    ``test_loader_attack``, print both reports, and return the metrics dict.

    model_path: path to a ``state_dict`` saved with ``torch.save``.
    pretrained: forwarded to ``get_googlenet``. Keep it True for checkpoints fine-tuned from
        ImageNet weights: the pretrained builder may configure architecture attributes that
        the state_dict does not store (e.g. torchvision GoogLeNet's ``transform_input``).
    epsilons: attack radii in the [-1, 1] pixel scale; defaults to 0.05, 0.10, 0.15, 0.20.

    Returns the dict produced by ``run_attacks_metrics``.
    """
    model = get_googlenet(pretrained=pretrained)
    model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()

    if epsilons is None:
        # Round away np.arange float noise (e.g. 0.15000000000000002) in the table keys
        epsilons = np.round(np.arange(0.05, 0.21, 0.05), 2)
    metrics = run_attacks_metrics(model, test_loader_attack, device, epsilons=epsilons)
    print_attack_metrics(metrics, attack_type="fgsm")
    print_attack_metrics(metrics, attack_type="pgd")
    return metrics

In [None]:
adv_pgd_metrics = evaluate_model_metrics(model_path="best_adv_pgd_aug.pth")