In [4]:
import numpy as np
import matplotlib.pyplot as plt

from sklearn.svm import SVC

T=5

np.set_printoptions(threshold=np.inf)

In [5]:
class SVM_Auxiliary_Classifier:
    def __init__(self, C, k, n_features):
        self.C = C  # Regularization parameter
        self.k = k  # Number of classes
        self.omega = np.random.rand(k, n_features)  # Weights
        self.b = np.random.rand(k)

    def fit(self, X, y):
        n_samples, n_features = X.shape
        L = np.zeros((n_samples, self.k))
        mask = np.zeros((n_samples, self.k))

        print("SVM Shapes", self.omega.T.shape, X[0].shape, self.b.shape, y.shape)

        for i in range(n_samples):
            for j in range(self.k):
                if j != y[i]:
                    margin = self.omega[y[i], j] @ X[i] - self.omega[y[i]] @ X[i] + self.b[j] - self.b[y[i]]
                    L[i, j] = max(0, 1 - margin)
                    if L[i, j] > 0:
                        mask[i, j] = 1

        # print(X.shape, mask.shape, L.shape, y.shape)

        for j in range(self.k):
            if j in y:
                indices = np.where(y == j)[0]
                omega_grad = np.sum(X[indices][:, np.newaxis, :] * mask[indices][:, :, np.newaxis], axis=0)
                self.omega[j] -= omega_grad.mean(axis=0) - self.C * self.omega[j]

                b_grad = mask[indices].sum(axis=0)
                self.b[j] -= b_grad.mean() - self.C * self.b[j]

        return mask, L

    def predict(self, X):
        decision_function = X @ self.omega.T + self.b
        return np.argmax(decision_function, axis=1)

In [6]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import random

def pgd_attack(image, epsilon, data_grad, num_steps=10, step_size=2/255):
    perturbed_image = image.clone().detach()

    for _ in range(num_steps):
        sign_data_grad = data_grad.sign()
        perturbed_image += step_size * sign_data_grad
        perturbed_image = torch.clamp(perturbed_image, image - epsilon, image + epsilon)
        perturbed_image = torch.clamp(perturbed_image, 0, 1)

    return perturbed_image.detach().to("cpu").numpy()

def generate_pgd_attacks(model, images, labels):
    perturbed_images = []
    images.requires_grad = True
    output = model(images)
    loss = criterion(output, labels)
    model.zero_grad()
    loss.backward()
    data_grad = images.grad.data
    for i in range(images.shape[0]):
        perturbed_images.append(pgd_attack(images[i], 8/255, data_grad[i]))
    
    return np.array(perturbed_images)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=625, shuffle=True, drop_last=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(test_dataset, batch_size=625, shuffle=False, drop_last=True)

resnet18 = models.resnet18(num_classes=10)
feature_extractor = torch.nn.Sequential(*list(resnet18.children())[:-1])

device = torch.device("mps")

resnet18 = resnet18.to(device)

criterion = nn.CrossEntropyLoss()
criterion2 = nn.CrossEntropyLoss(reduction='none')

optimizer = optim.SGD(resnet18.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = StepLR(optimizer, step_size=75, gamma=0.1)

print(resnet18.fc.in_features)

# svm_classifier = SVM_Auxiliary_Classifier(C=1, k=10, n_features=trainloader.batch_size * 2)
svm_classifier = SVC(decision_function_shape='ovr')
resnet18.train()

for epoch in range(T):
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        inputs, labels = inputs.to(device), labels.to(device)

        adv_inputs = torch.tensor(generate_pgd_attacks(resnet18, inputs, labels))
        adv_inputs = adv_inputs.to(device)

        total_inputs = torch.cat((adv_inputs, inputs), 0)

        labels = torch.cat((torch.tensor([0] * labels.shape[0]).to(device), labels), 0)

        adv_inputs = adv_inputs.to(device)
 
        # Get features before the last FC layer
        features = feature_extractor(total_inputs)
        # print(features.shape)

        mask = np.zeros((trainloader.batch_size * 2))
        l_margin = np.zeros((trainloader.batch_size * 2, 10))

        svm_classifier.fit(features.cpu().detach().squeeze().numpy(), labels.cpu().numpy())

        distances = svm_classifier.decision_function(features.cpu().detach().squeeze().numpy())

        l_margin = np.maximum(0, 1 - distances + distances[range(trainloader.batch_size * 2), labels.cpu().numpy()][:, np.newaxis])

        mask = (l_margin > 0).any(axis=1).astype(int)

        # print(mask)

        # exit()
        # print(mask.shape)

        outputs = resnet18(total_inputs)

        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)


        labels = labels.to(device)
        outputs = outputs.to(device)

        loss = 0

        # for i in range(outputs.size(0)):
        #     loss += criterion2(outputs[i], labels[i]).detach().cpu() * mask[i]

        # print(mask.shape, l_margin.shape, loss.shape)
        # print("BRUH", criterion(outputs, labels).shape)

        loss = (criterion2(outputs, labels) * torch.from_numpy(mask).float().to(device)).sum()

        # print(loss)

        loss = loss.to(device, dtype=torch.float32)
        l_margin = torch.tensor(l_margin).float().to(device)

        total_loss = (loss + 0.1 * l_margin).mean().clone().detach().requires_grad_(True)

        # print(total_loss)
        # total_loss = total_loss.requires_grad_(True)
        # print(total_loss, loss, L_margin.mean())
        total_loss.backward()

        optimizer.zero_grad()
        optimizer.step()

    print(f'Epoch {epoch+1}/{T} finished, Loss: {total_loss.item()}, Accuracy: {100 * correct / total}%')

print('Finished Training. SVM auxiliary classifier will be removed for inference.')

resnet18.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = resnet18(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy: {100 * correct / total}%')

Files already downloaded and verified
Files already downloaded and verified




512
Epoch 1/5 finished, Loss: 3138.901611328125, Accuracy: 7.535%
Epoch 2/5 finished, Loss: 3174.666259765625, Accuracy: 7.421%
Epoch 3/5 finished, Loss: 3183.661376953125, Accuracy: 7.503%
Epoch 4/5 finished, Loss: 3186.893310546875, Accuracy: 7.461%
Epoch 5/5 finished, Loss: 3141.361083984375, Accuracy: 7.488%
Finished Training. SVM auxiliary classifier will be removed for inference.
Accuracy: 11.41%
