In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.fft as fft

class FrequencyPixelAttacks(nn.Module):
    """
    A PyTorch transform to apply various frequency-based and pixel-based attacks.
    """
    def __init__(self, attack_type='phase', epsilon=0.1, frequency_radius=0.1, num_pixels=100, noise_std=0.05, seed=None):
        super(FrequencyPixelAttacks, self).__init__()
        self.attack_type = attack_type
        self.epsilon = epsilon
        self.frequency_radius = frequency_radius
        self.num_pixels = num_pixels
        self.noise_std = noise_std
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
    
    def forward(self, img):
        perturbed_img = img.clone().detach()
        if self.attack_type == 'phase':
            perturbed_img = self._phase_attack(perturbed_img)
        elif self.attack_type == 'magnitude':
            perturbed_img = self._magnitude_attack(perturbed_img)
        elif self.attack_type == 'low_freq':
            perturbed_img = self._low_frequency_attack(perturbed_img)
        elif self.attack_type == 'high_freq':
            perturbed_img = self._high_frequency_attack(perturbed_img)
        elif self.attack_type == 'pixel':
            perturbed_img = self._pixel_attack(perturbed_img)
        elif self.attack_type == 'normal':
            perturbed_img = self._normal_noise_attack(perturbed_img)
        else:
            raise ValueError(f"Unknown attack type: {self.attack_type}")
        return torch.clamp(perturbed_img, 0, 1)
    
    def _phase_attack(self, img):
        batch, channels, height, width = img.shape
        perturbed_img = torch.zeros_like(img)
        for c in range(channels):
            f_transform = fft.fftshift(fft.fft2(img[:, c]))
            magnitude = torch.abs(f_transform)
            phase = torch.angle(f_transform)
            phase_noise = torch.randn_like(phase) * self.epsilon
            perturbed_phase = phase + phase_noise
            f_transform_perturbed = magnitude * torch.exp(1j * perturbed_phase)
            img_perturbed = fft.ifft2(fft.ifftshift(f_transform_perturbed)).real
            perturbed_img[:, c] = img_perturbed
        return perturbed_img
    
    def _magnitude_attack(self, img):
        batch, channels, height, width = img.shape
        perturbed_img = torch.zeros_like(img)
        for c in range(channels):
            f_transform = fft.fftshift(fft.fft2(img[:, c]))
            magnitude = torch.abs(f_transform)
            phase = torch.angle(f_transform)
            magnitude_noise = torch.randn_like(magnitude) * self.epsilon * magnitude
            perturbed_magnitude = magnitude + magnitude_noise
            f_transform_perturbed = perturbed_magnitude * torch.exp(1j * phase)
            img_perturbed = fft.ifft2(fft.ifftshift(f_transform_perturbed)).real
            perturbed_img[:, c] = img_perturbed
        return perturbed_img
    
    def _create_frequency_mask(self, height, width, is_low_freq=True):
        y_indices, x_indices = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
        y_indices = y_indices - height // 2
        x_indices = x_indices - width // 2
        distance = torch.sqrt(y_indices**2 + x_indices**2).float()
        max_distance = torch.sqrt(torch.tensor(height**2 + width**2, dtype=torch.float32))
        distance /= max_distance
        return (distance <= self.frequency_radius).float() if is_low_freq else (distance >= (1 - self.frequency_radius)).float()
    
    def _low_frequency_attack(self, img):
        batch, channels, height, width = img.shape
        perturbed_img = torch.zeros_like(img)
        mask = self._create_frequency_mask(height, width, is_low_freq=True).to(img.device)
        for c in range(channels):
            f_transform = fft.fftshift(fft.fft2(img[:, c]))
            noise = (torch.randn_like(f_transform.real) + 1j * torch.randn_like(f_transform.imag)) * self.epsilon
            f_transform_perturbed = f_transform + noise * mask
            img_perturbed = fft.ifft2(fft.ifftshift(f_transform_perturbed)).real
            perturbed_img[:, c] = img_perturbed
        return perturbed_img
    
    def _high_frequency_attack(self, img):
        batch, channels, height, width = img.shape
        perturbed_img = torch.zeros_like(img)
        mask = self._create_frequency_mask(height, width, is_low_freq=False).to(img.device)
        for c in range(channels):
            f_transform = fft.fftshift(fft.fft2(img[:, c]))
            noise = (torch.randn_like(f_transform.real) + 1j * torch.randn_like(f_transform.imag)) * self.epsilon
            f_transform_perturbed = f_transform + noise * mask
            img_perturbed = fft.ifft2(fft.ifftshift(f_transform_perturbed)).real
            perturbed_img[:, c] = img_perturbed
        return perturbed_img
    
    def _pixel_attack(self, img):
        batch, channels, height, width = img.shape
        perturbed_img = img.clone()
        num_pixels = min(self.num_pixels, height * width)
        pixel_indices = torch.randint(0, height * width, (batch, num_pixels))
        y_indices, x_indices = pixel_indices // width, pixel_indices % width
        for b in range(batch):
            for c in range(channels):
                for i in range(num_pixels):
                    perturbed_img[b, c, y_indices[b, i], x_indices[b, i]] += torch.randn(1).item() * self.epsilon
        return perturbed_img
    
    def _normal_noise_attack(self, img):
        noise = torch.randn_like(img) * self.noise_std
        return img + noise

In [2]:
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
def display_image(original, adversarial):
    """Display last batch of original and adversarial images."""
    fig, axes = plt.subplots(2, len(original), figsize=(16, 10))
    
    for i in range(len(original)):
        orig_img = original[i].permute(1, 2, 0).cpu().numpy()
        adv_img = adversarial[i].permute(1, 2, 0).cpu().numpy()
        
        orig_img = np.clip(orig_img, 0, 1)
        adv_img = np.clip(adv_img, 0, 1)

        
        axes[0, i].imshow(orig_img)
        axes[0, i].axis("off")

        axes[1, i].imshow(adv_img)
        axes[1, i].axis("off")

    plt.show()

def train_adversary(model,test_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 1. Freeze the model (so only adversary parameters are updated)
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    
    # 2. Evaluate adversarial attack on test data
    # test_loader = load_test(batch_size=1)  # Load test dataset
    test_dataset = test_loader.dataset  # Get dataset without batching

    attacks = [
        # {'attack_type': 'phase', 'epsilon': 0.1, 'frequency_radius': 0.1, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'phase_weak'},
        {'attack_type': 'phase', 'epsilon': 0.5, 'frequency_radius': 0.1, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'phase_strong'},
        # {'attack_type': 'magnitude', 'epsilon': 0.1, 'frequency_radius': 0.1, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'magnitude_weak'},
        {'attack_type': 'magnitude', 'epsilon': 0.5, 'frequency_radius': 0.1, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'magnitude_strong'},
        # {'attack_type': 'low_freq', 'epsilon': 0.2, 'frequency_radius': 0.1, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'low_freq_small_radius'},
        # {'attack_type': 'low_freq', 'epsilon': 0.2, 'frequency_radius': 0.3, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'low_freq_large_radius'},
        # {'attack_type': 'high_freq', 'epsilon': 0.2, 'frequency_radius': 0.1, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'high_freq_small_radius'},
        # {'attack_type': 'high_freq', 'epsilon': 0.2, 'frequency_radius': 0.3, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'high_freq_large_radius'},
        # {'attack_type': 'pixel', 'epsilon': 0.5, 'frequency_radius': 0.1, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'pixel_few'},
        # {'attack_type': 'pixel', 'epsilon': 0.5, 'frequency_radius': 0.1, 'num_pixels': 1000, 'noise_std': 0.05, 'name': 'pixel_many'},
        # {'attack_type': 'normal', 'epsilon': 0.1,  'frequency_radius': 0.1, 'num_pixels': 100, 'noise_std': 0.05, 'name': 'normal_weak'},
        # {'attack_type': 'normal', 'epsilon': 0.1, 'frequency_radius': 0.1, 'num_pixels': 100, 'noise_std': 0.2, 'name': 'normal_strong'}
    ]

    for attack in attacks:
        total_correct_original = 0
        total_correct_perturbed = 0
        total_samples = 0
        last_batch_images = []
        last_batch_perturbed = []
        
        adversary = FrequencyPixelAttacks(attack["attack_type"], attack["epsilon"], attack["frequency_radius"], attack["num_pixels"], attack["noise_std"])

        # Process each image individually
        for i in range(len(test_dataset)):
            image, label = test_dataset[i]  # Get single image and label
            image, label = image.to(device).unsqueeze(0), torch.tensor([label]).to(device)  # Add batch dimension

            # image = image.squeeze(0)  # Removes the batch dimension
            perturbed_image = adversary(image)
            
            with torch.no_grad():
                # y_hat = model(image)
                y_pred = model(perturbed_image)
            
            # pred_original = torch.argmax(y_hat, dim=1)
            pred_perturbed = torch.argmax(y_pred, dim=1)
            
            # total_correct_original += (pred_original == label).item()
            total_correct_perturbed += (pred_perturbed == label).item()
            total_samples += 1  # Each image is processed individually
            
            if i < 5:  # Store only a few images for visualization
                # last_batch_images.append(image.cpu().squeeze(0))
                last_batch_perturbed.append(perturbed_image.cpu().squeeze(0))

        # original_acc = (total_correct_original / total_samples) * 100
        perturbed_acc = (total_correct_perturbed / total_samples) * 100
    
        print(f"Attack: {attack['name']}")
        print("Total Samples: ", total_samples)
        # print(f"Accuracy on Original Images: {original_acc:.2f}%")
        print(f"Accuracy on Perturbed Images: {perturbed_acc:.2f}%")

        # display_image(last_batch_images, last_batch_perturbed)

In [3]:
%pip install adversarial-robustness-toolbox
%pip install ipywidgets

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [4]:

from art.estimators.classification import PyTorchClassifier
from art.attacks.evasion import FastGradientMethod, ProjectedGradientDescent
from art.attacks.evasion import CarliniL2Method
import matplotlib.pyplot as plt

# --- Utility function for visualization ---
def display_image(original, adversarial, title=""):
    """Display a batch of original and adversarial images."""
    n = len(original)
    fig, axes = plt.subplots(2, n, figsize=(4 * n, 8))
    
    for i in range(n):
        # Convert tensor to numpy (assumes tensor shape is [C, H, W])
        orig_img = original[i].permute(1, 2, 0).cpu().numpy()
        adv_img = adversarial[i].permute(1, 2, 0).cpu().numpy()
        
        orig_img = np.clip(orig_img, 0, 1)
        adv_img = np.clip(adv_img, 0, 1)
        
        axes[0, i].imshow(orig_img)
        axes[0, i].set_title("Original")
        axes[0, i].axis("off")
        
        axes[1, i].imshow(adv_img)
        axes[1, i].set_title("Adversarial")
        axes[1, i].axis("off")
    
    fig.suptitle(title)
    plt.show()


# --- Modified training function using ART attacks ---
def train_adversary_art(model, test_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # IMPORTANT: Freeze model parameters
    for param in model.parameters():
        param.requires_grad = False

    # Wrap the ViT model into an ART classifier.
    # Adjust input_shape and nb_classes according to your ViT model.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Not used for attacks.
    classifier = PyTorchClassifier(
        model=model,
        loss=loss_fn,
        optimizer=optimizer,
        input_shape=(3, 224, 224),  # Example shape for ViT (channels, height, width)
        nb_classes=10,           # Update based on your model's number of classes
        clip_values=(0, 1)
    )

    # Define ART attack objects with different parameters
    attacks = [
        {"name": "FGSM_weak", "attack": FastGradientMethod(estimator=classifier, eps=0.1)},
        {"name": "FGSM_strong", "attack": FastGradientMethod(estimator=classifier, eps=0.5)},
        {"name": "PGD", "attack": ProjectedGradientDescent(estimator=classifier, eps=0.3, max_iter=10)},
        # {"name": "CW", "attack": CarliniL2Method(classifier=classifier, max_iter=3)}
    ]

    # Loop over each attack type
    for attack_dict in attacks:
        attack_name = attack_dict["name"]
        attack_obj = attack_dict["attack"]
        total_correct_original = 0
        total_correct_adv = 0
        total_samples = 0
        
        # To store a few examples for visualization
        orig_examples = []
        adv_examples = []

        # Iterate over the test dataset (assumes test_loader returns (image, label))
        for i, (image, label) in enumerate(test_loader):
            image = image.to(device)  # image shape: [B, C, H, W]
            label = label.to(device)

            # ART expects numpy arrays. Convert and ensure the shape is [B, C, H, W]
            image_np = image.cpu().numpy()
            
            # Generate adversarial examples using ART
            adv_image_np = attack_obj.generate(x=image_np)
            adv_image = torch.from_numpy(adv_image_np).to(device)
            print("here")
            # Get predictions for original and adversarial images
            with torch.no_grad():
                output_orig = model(image)
                output_adv = model(adv_image)
            
            pred_orig = output_orig.argmax(dim=1)
            pred_adv = output_adv.argmax(dim=1)
            total_correct_original += (pred_orig == label).sum().item()
            total_correct_adv += (pred_adv == label).sum().item()
            total_samples += image.shape[0]

            # Save the first batch (or a few images) for visualization
            if len(orig_examples) < 5:
                orig_examples.append(image[0].cpu())
                adv_examples.append(adv_image[0].cpu())
            
            del image, label, adv_image, output_orig, output_adv
            torch.cuda.empty_cache()

        original_acc = (total_correct_original / total_samples) * 100
        adv_acc = (total_correct_adv / total_samples) * 100
        
        print(f"Attack: {attack_name}")
        print("Total Samples: ", total_samples)
        print(f"Accuracy on Original Images: {original_acc:.2f}%")
        print(f"Accuracy on Adversarial Images: {adv_acc:.2f}%")
        # display_image(orig_examples, adv_examples)

# --- Example usage ---
# Assuming you have a pre-trained ViT model and a test_loader defined.
# model = ...  # Your ViT model (e.g. from timm or torchvision.models)
# test_loader = ...  # Your DataLoader for the test set
# train_adversary_art(model, test_loader)

In [3]:
from eval.models.image.cct import AlgebraicCCT, SinusoidalCCT, AbsoluteCCT, AlgebraicSeqCCT,ourCCT

In [4]:
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import datasets


In [5]:
# Load Oxford-IIIT Pet training data
def load_train(batch_size=32):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
        # transforms.Normalize(mean=[0.5], std=[0.5])
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                      std=[0.229, 0.224, 0.225])
    ])
    
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    return loader


def load_test(batch_size=16):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.5], std=[0.5])
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    
    dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    return loader

In [6]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


In [8]:
from torchvision.datasets import CIFAR10
from eval.models.image.augmentations import CIFAR10Policy

In [73]:
dim = 256
num_heads = 4
num_layers = 7
mlp_ratio = 2
in_channels, num_classes, image_size = 3, 10, 10

augmentations = [CIFAR10Policy(),
                    transforms.RandomCrop((32, 32), padding=4),
                    transforms.RandomHorizontalFlip()]
transformations = [transforms.ToTensor(),
                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])]
train_set = CIFAR10("data" ,train=True, download=True,
                    transform=transforms.Compose([*augmentations, *transformations]))
test_set = CIFAR10("data", train=False, download=True,
                    transform=transforms.Compose(transformations))
in_channels, num_classes, image_size = 3, 10, 10

train_dl = DataLoader(train_set, batch_size = 1, shuffle=True, num_workers=4)
test_dl = DataLoader(test_set, batch_size= 1, shuffle=False, num_workers=4)

model = AlgebraicCCT(
                dim=dim,
                num_heads=num_heads,
                num_layers=num_layers,
                kernel_size=(4, 4),
                in_channels=in_channels,
                num_classes=num_classes,
                mlp_ratio=mlp_ratio).cuda()
model_path = "best_model_.pth"
if os.path.exists(model_path):
    print("Found model... Loading the model")
    model.load_state_dict(torch.load(model_path))

# train_adversary_art(model, test_loader)
train_adversary(model,test_dl)

Found model... Loading the model
Attack: phase_strong
Total Samples:  10000
Accuracy on Perturbed Images: 54.75%
Attack: magnitude_strong
Total Samples:  10000
Accuracy on Perturbed Images: 54.32%


In [22]:
dim = 256
num_heads = 4
num_layers = 7
mlp_ratio = 2
in_channels, num_classes, image_size = 3, 10, 10

augmentations = [CIFAR10Policy(),
                    transforms.RandomCrop((32, 32), padding=4),
                    transforms.RandomHorizontalFlip()]
transformations = [transforms.ToTensor(),
                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])]
train_set = CIFAR10("data" ,train=True, download=True,
                    transform=transforms.Compose([*augmentations, *transformations]))
test_set = CIFAR10("data", train=False, download=True,
                    transform=transforms.Compose(transformations))
in_channels, num_classes, image_size = 3, 10, 10

train_dl = DataLoader(train_set, batch_size = 1, shuffle=True, num_workers=4)
test_dl = DataLoader(test_set, batch_size= 1, shuffle=False, num_workers=4)

model = ourCCT(
                dim=dim,
                num_heads=num_heads,
                num_layers=num_layers,
                kernel_size=(4, 4),
                in_channels=in_channels,
                num_classes=num_classes,
                mlp_ratio=mlp_ratio).cuda()
model_path = "best_model_t.pth"
if os.path.exists(model_path):
    print("Found model... Loading the model")
    model.load_state_dict(torch.load(model_path))

# train_adversary_art(model, test_loader)
train_adversary(model,test_dl)

Found model... Loading the model
Attack: phase_strong
Total Samples:  10000
Accuracy on Perturbed Images: 53.88%
Attack: magnitude_strong
Total Samples:  10000
Accuracy on Perturbed Images: 53.68%


In [12]:
import torch.nn as nn
dim = 256
num_heads = 4
num_layers = 7
mlp_ratio = 2
in_channels, num_classes, image_size = 3, 10, 10

augmentations = [CIFAR10Policy(),
                    transforms.RandomCrop((32, 32), padding=4),
                    transforms.RandomHorizontalFlip()]
transformations = [transforms.ToTensor(),
                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])]
train_set = CIFAR10("data" ,train=True, download=True,
                    transform=transforms.Compose([*augmentations, *transformations]))
test_set = CIFAR10("data", train=False, download=True,
                    transform=transforms.Compose(transformations))
in_channels, num_classes, image_size = 3, 10, 10

train_dl = DataLoader(train_set, batch_size = 1, shuffle=True, num_workers=4)
test_dl = DataLoader(test_set, batch_size= 1, shuffle=False, num_workers=4)

model = SinusoidalCCT(
                dim=dim,
                num_heads=num_heads,
                num_layers=num_layers,
                kernel_size=(4, 4),
                in_channels=in_channels,
                num_classes=num_classes,
                mlp_ratio=mlp_ratio).cuda()
model_path = "best_model_s.pth"
if os.path.exists(model_path):
    print("Found model... Loading the model")
    model.load_state_dict(torch.load(model_path))

# train_adversary_art(model, test_loader)
train_adversary(model,test_dl)

Found model... Loading the model
Attack: phase_strong
Total Samples:  10000
Accuracy on Perturbed Images: 52.59%
Attack: magnitude_strong
Total Samples:  10000
Accuracy on Perturbed Images: 52.55%


In [10]:
dim = 256
num_heads = 4
num_layers = 7
mlp_ratio = 2
in_channels, num_classes, image_size = 3, 10, 10

augmentations = [CIFAR10Policy(),
                    transforms.RandomCrop((32, 32), padding=4),
                    transforms.RandomHorizontalFlip()]
transformations = [transforms.ToTensor(),
                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])]
train_set = CIFAR10("data" ,train=True, download=True,
                    transform=transforms.Compose([*augmentations, *transformations]))
test_set = CIFAR10("data", train=False, download=True,
                    transform=transforms.Compose(transformations))
in_channels, num_classes, image_size = 3, 10, 10

train_dl = DataLoader(train_set, batch_size = 1, shuffle=True, num_workers=4)
test_dl = DataLoader(test_set, batch_size= 1, shuffle=False, num_workers=4)

In [65]:
dim = 256
num_heads = 4
num_layers = 7
mlp_ratio = 2
in_channels, num_classes, image_size = 3, 10, 10

augmentations = [CIFAR10Policy(),
                    transforms.RandomCrop((32, 32), padding=4),
                    transforms.RandomHorizontalFlip()]
transformations = [transforms.ToTensor(),
                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])]
train_set = CIFAR10("data" ,train=True, download=True,
                    transform=transforms.Compose([*augmentations, *transformations]))
test_set = CIFAR10("data", train=False, download=True,
                    transform=transforms.Compose(transformations))
in_channels, num_classes, image_size = 3, 10, 10

train_dl = DataLoader(train_set, batch_size = 1, shuffle=True, num_workers=4)
test_dl = DataLoader(test_set, batch_size= 1, shuffle=False, num_workers=4)

model = AlgebraicCCT(
                dim=dim,
                num_heads=num_heads,
                num_layers=num_layers,
                kernel_size=(4, 4),
                in_channels=in_channels,
                num_classes=num_classes,
                mlp_ratio=mlp_ratio).cuda()
model_path = "best_model_.pth"
if os.path.exists(model_path):
    print("Found model... Loading the model")
    model.load_state_dict(torch.load(model_path))

Found model... Loading the model


In [62]:
# Get one batch from the training dataloader
images, labels = next(iter(train_dl))

# Print the shape of the image tensor
print("Image tensor shape:", images.shape)


Image tensor shape: torch.Size([1, 3, 32, 32])


In [14]:
import torch
from torch.nn.functional import softmax

def evaluate_accuracy(model, dataloader, device="cuda"):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total * 100

# Calculate accuracy
# train_acc = evaluate_accuracy(model, train_dl)
test_acc = evaluate_accuracy(model, test_dl)

# print(f"Train Accuracy: {train_acc:.2f}%")
print(f"Test Accuracy: {test_acc:.2f}%")


Test Accuracy: 86.52%


# Frequency Attacks


In [14]:
import torch.fft
from torch.fft import fft2, ifft2, fftshift, ifftshift
import torch.nn.functional as F


In [37]:
from tqdm import tqdm
def fourier_attack(model, image, label, epsilon=0.1, lambda_reg=0.1, num_iters=10):
    """
    Implements the Fourier-based adversarial attack.
    :param model: Target classifier
    :param image: Input image (normalized, tensor of shape [1, C, H, W])
    :param label: Ground truth label (integer tensor)
    :param epsilon: Perturbation strength
    :param lambda_reg: Regularization factor balancing L2 and CE loss
    :param num_iters: Number of attack iterations
    :return: Adversarial image
    """
    # Ensure the image requires gradients
    image = image.clone().detach().requires_grad_(True)
    label = label.to(image.device)
    
    # Compute Fourier Transform of the image
    fft_image = fft2(image)
    fft_image = fftshift(fft_image)  # Shift low frequencies to center
    magnitude = torch.abs(fft_image)
    phase = torch.angle(fft_image)
    
    # Initialize perturbations
    delta_mag = torch.ones_like(magnitude, requires_grad=True, device=image.device)
    delta_phase = torch.zeros_like(phase, requires_grad=True, device=image.device)
    delta_pixel = torch.zeros_like(image, requires_grad=True, device=image.device)
    
    optimizer = torch.optim.Adam([delta_mag, delta_phase, delta_pixel], lr=0.01)
    
    for _ in range(num_iters):
        optimizer.zero_grad()
        
        # Apply perturbations
        perturbed_magnitude = magnitude * delta_mag
        perturbed_phase = phase + delta_phase
        
        # Construct perturbed Fourier representation
        perturbed_fft = perturbed_magnitude * torch.exp(1j * perturbed_phase)
        perturbed_fft = ifftshift(perturbed_fft)  # Shift back before ifft2
        perturbed_image = ifft2(perturbed_fft).real + delta_pixel  # Ensure real values
        perturbed_image = torch.clamp(perturbed_image, 0, 1)  # Clip to valid range
        
        # Compute loss
        l2_loss = F.mse_loss(perturbed_image, image)
        ce_loss = F.cross_entropy(model(perturbed_image), label)
        loss = lambda_reg * l2_loss - ce_loss
        
        # Backpropagation
        loss.backward(retain_graph=True)
        optimizer.step()
        
        # Clip perturbations
        delta_mag.data = torch.clamp(delta_mag, 1 - epsilon, 1 + epsilon)
        delta_phase.data = torch.clamp(delta_phase, -epsilon, epsilon)
        delta_pixel.data = torch.clamp(delta_pixel, -epsilon, epsilon)
    
    return perturbed_image.detach()

In [38]:
def apply_attack_and_evaluate(model, data_loader, attack):
    model.eval()
    correct = 0
    total = 0
    
    
    for inputs, targets in tqdm(data_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        perturbed_inputs = attack(model,inputs,label=targets)

        with torch.no_grad():
            outputs = model(perturbed_inputs)
            # print(outputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    accuracy = 100. * correct / total
    return accuracy

In [29]:
from torch.utils.data import Subset, DataLoader

# Create a subset with the first 100 samples
subset_indices = list(range(1000))
test_subset = Subset(test_dl.dataset, subset_indices)

# Create a new DataLoader with this subset
test_100_dl = DataLoader(test_subset, batch_size=1, shuffle=False, num_workers=4)


In [11]:
model = SinusoidalCCT(
                dim=dim,
                num_heads=num_heads,
                num_layers=num_layers,
                kernel_size=(4, 4),
                in_channels=in_channels,
                num_classes=num_classes,
                mlp_ratio=mlp_ratio).cuda()
model_path = "best_model_s.pth"
if os.path.exists(model_path):
    print("Found model... Loading the model")
    model.load_state_dict(torch.load(model_path))

model1 = AlgebraicCCT(
            dim=dim,
            num_heads=num_heads,
            num_layers=num_layers,
            kernel_size=(4, 4),
            in_channels=in_channels,
            num_classes=num_classes,
            mlp_ratio=mlp_ratio).cuda()
model_path = "best_model_.pth"
if os.path.exists(model_path):
    print("Found model... Loading the model")
    model1.load_state_dict(torch.load(model_path))

model2 = ourCCT(
            dim=dim,
            num_heads=num_heads,
            num_layers=num_layers,
            kernel_size=(4, 4),
            in_channels=in_channels,
            num_classes=num_classes,
            mlp_ratio=mlp_ratio).cuda()
model_path = "best_model_t.pth"
if os.path.exists(model_path):
    print("Found model... Loading the model")
    model2.load_state_dict(torch.load(model_path))

Found model... Loading the model


  log = torch.tensor(logm(out)).real


Found model... Loading the model
Found model... Loading the model


In [39]:
print('Sinusoidal: ' ,apply_attack_and_evaluate(model,test_100_dl,fourier_attack))

100%|██████████| 1000/1000 [04:39<00:00,  3.58it/s]

Sinusoidal:  3.7





In [40]:
print('APE: ' ,apply_attack_and_evaluate(model1,test_100_dl,fourier_attack))

100%|██████████| 1000/1000 [06:40<00:00,  2.50it/s]

APE:  2.2





In [41]:

print('Ours: ' ,apply_attack_and_evaluate(model2,test_100_dl,fourier_attack))

100%|██████████| 1000/1000 [05:02<00:00,  3.30it/s]

Ours:  4.1



