In [None]:
import torch
import numpy as np
import torch.nn as nn
from CL.resnet import *
import torch.nn.functional as F
import kornia.augmentation as K

def info_nce_loss(features, temperature=0.007, n_views=2):
    """Build InfoNCE / NT-Xent logits and targets from stacked multi-view embeddings.

    ``features`` must contain ``n_views`` views of the same ``B`` samples stacked
    view-major along dim 0 (row ``k * B + i`` is view ``k`` of sample ``i``), which is
    what ``torch.cat((x_1, x_2), dim=0)`` produces.

    Args:
        features: ``(n_views * B, D)`` embedding matrix (normalised internally).
        temperature: value the cosine similarities are divided by.
            NOTE(review): common SimCLR implementations default to 0.07; 0.007 is
            kept as the default to preserve existing results - confirm it is intended.
        n_views: number of views stacked per sample.

    Returns:
        ``(logits, labels)``. ``logits`` has shape ``(N, N - 1)`` with the
        ``n_views - 1`` positive similarities in the leading columns followed by the
        negatives. ``labels`` is an all-zero ``long`` vector of length ``N``, so
        ``nn.CrossEntropyLoss()(logits, labels)`` is the InfoNCE loss.

    Raises:
        ValueError: if ``n_views < 2``, ``features`` is empty, or its row count is
            not a multiple of ``n_views``.
    """
    n = features.shape[0]
    if n_views < 2 or n == 0 or n % n_views != 0:
        raise ValueError(
            f'expected a non-empty batch whose size is a multiple of n_views={n_views} '
            f'(n_views >= 2), got {n} rows'
        )
    device = features.device

    # sample_ids[r] is the index of the sample that row r is a view of.
    sample_ids = torch.arange(n // n_views, device=device).repeat(n_views)
    same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

    features = F.normalize(features, dim=1)
    similarity = features @ features.T

    # Drop the diagonal: a row's similarity with itself is neither positive nor negative.
    off_diag = ~torch.eye(n, dtype=torch.bool, device=device)
    same_sample = same_sample[off_diag].view(n, n - 1)
    similarity = similarity[off_diag].view(n, n - 1)

    # Explicit sizes (instead of -1) keep the reshape valid when there are no negatives (B == 1).
    positives = similarity[same_sample].view(n, n_views - 1)
    negatives = similarity[~same_sample].view(n, n - n_views)

    logits = torch.cat([positives, negatives], dim=1) / temperature
    # Column 0 holds a positive for every row, hence all-zero targets.
    labels = torch.zeros(n, dtype=torch.long, device=device)
    return logits, labels

In [None]:

# Experiment switches read by AttackModel.__init__:
#   with_projection_head -- keep SimCLR's projection head (False swaps backbone.linear for nn.Identity)
#   with_augmentation    -- apply random SimCLR-style augmentation before normalisation
with_projection_head, with_augmentation = True, False

class AttackModel(nn.Module):
    """Pairs a supervised classifier with a SimCLR encoder behind one input pipeline.

    Both forward methods feed ``self.transform(self.data + delta)`` to their network.
    The caller must assign ``data`` (clean images) and ``label`` before using them.

    Args:
        projection_head: keep the SimCLR projection head. ``None`` (default) falls back
            to the notebook-level ``with_projection_head`` flag.
        augmentation: apply random SimCLR-style augmentation before normalisation.
            ``None`` (default) falls back to the notebook-level ``with_augmentation`` flag.
    """

    # Per-channel normalisation statistics (CIFAR-10 values, as used by the original code).
    MEAN = (0.4914, 0.4822, 0.4465)
    STD = (0.2023, 0.1994, 0.2010)

    def __init__(self, projection_head=None, augmentation=None):
        super().__init__()
        if projection_head is None:
            projection_head = with_projection_head
        if augmentation is None:
            augmentation = with_augmentation

        self.classifier = ResNet18()
        self.simclr = ResNetSimCLR(ResNet18(), 128)

        # These two models are essentially the same.
        # map_location='cpu' lets GPU-saved checkpoints load anywhere; the module is
        # moved to the target device by the caller afterwards.
        self.classifier.load_state_dict(torch.load('./CL/simclr_classifier.pt', map_location='cpu'))
        self.simclr.load_state_dict(torch.load('./CL/simclr.pt', map_location='cpu'))

        if not projection_head:
            # nn.Identity ignores constructor arguments, so none are passed.
            self.simclr.backbone.linear = nn.Identity()

        normalize = K.Normalize(mean=self.MEAN, std=self.STD)
        if augmentation:
            self.transform = nn.Sequential(
                K.RandomResizedCrop(size=(32, 32), scale=(0.2, 1.)),
                K.RandomHorizontalFlip(),
                K.ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
                K.RandomGrayscale(p=0.2),
                normalize,
            )
        else:
            self.transform = nn.Sequential(normalize)

        self.label = None
        self.data = None

    def _perturbed(self, delta):
        """Return ``self.data + delta``, failing clearly if no batch has been attached."""
        if self.data is None:
            raise RuntimeError('AttackModel.data is not set; assign a batch before calling a forward method')
        return self.data + delta

    def classify_forward(self, delta):
        """Classify the perturbed batch, print its accuracy, and return ``(logits, self.label)``."""
        logits = self.classifier(self.transform(self._perturbed(delta)))
        accuracy = (logits.argmax(dim=1) == self.label).float().mean().item()
        print(f'ACC = {accuracy}')
        return logits, self.label

    def simclr_forward(self, delta):
        """Return InfoNCE ``(logits, labels)`` for two transformed views of the perturbed batch."""
        x = self._perturbed(delta)
        # Two independent passes through self.transform give the two views
        # (they are identical when augmentation is disabled).
        x_two_view = torch.cat((self.transform(x), self.transform(x)), dim=0)
        return info_nce_loss(self.simclr(x_two_view))

# nn.Module starts in training mode; eval() switches normalisation/dropout layers (if the
# ResNet-18s contain any) to inference behaviour. The attack then neither depends on the
# statistics of each adversarial batch nor overwrites the pretrained running statistics.
model = AttackModel().cuda().eval()

In [None]:
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

# download=True fetches CIFAR-10 on a fresh machine; if './data' already holds it, the
# files are only verified, not downloaded again.
dataset = CIFAR10('./data', train=True, transform=transforms.ToTensor(), download=True)
# shuffle is left at its default (False), so the first batch is the same on every run.
dataloader = DataLoader(dataset, batch_size=512, drop_last=True)

In [None]:
# First attack the InfoNCE loss and observe how the cross-entropy (classification) loss changes.
from tqdm import tqdm
from torch.optim import SGD

# Attack hyper-parameters: L-inf budget and per-step size, both in [0, 1] pixel units.
loss_func = nn.CrossEntropyLoss()
epsilon, alpha = 8 / 255, 1 / 255

# The attack is run on one fixed batch taken from the loader.
data, label = next(iter(dataloader))

# nn.Module.cuda() moves the module in place (and returns it).
model.cuda()
model.data, model.label = data.cuda(), label.cuda()

def attack(target, no_target, eot, n_iters=100):
    """Run an L-inf PGD attack that maximises ``loss_func`` on ``target``'s outputs.

    Args:
        target: callable ``delta -> (logits, labels)`` whose loss is maximised by
            signed-gradient ascent on ``delta``.
        no_target: callable with the same signature; evaluated without gradients
            each step to monitor how the other objective reacts.
        eot: number of forward passes averaged per step (Expectation over
            Transformation; only changes anything when ``target`` is stochastic).
        n_iters: number of PGD steps (the plotting cell was written for 100).

    Returns:
        list of ``(target_loss, no_target_loss)`` pairs, one per step, as detached
        0-d tensors measured *before* that step's update.

    Reads the notebook globals ``model`` (``model.data`` is the clean batch),
    ``epsilon``, ``alpha`` and ``loss_func``.
    """
    if eot < 1:
        raise ValueError(f'eot must be >= 1, got {eot}')
    clean = model.data
    losses = []

    # Random start inside the eps-ball, projected so clean + delta is a valid [0, 1] image.
    delta = torch.empty_like(clean).uniform_(-epsilon, epsilon)
    delta = torch.min(torch.max(delta, -clean), 1 - clean)

    for i in range(n_iters):
        delta.requires_grad_(True)

        loss = sum(loss_func(*target(delta)) for _ in range(eot)) / eot
        # Only d(loss)/d(delta) is needed: autograd.grad avoids accumulating .grad on
        # every model parameter across iterations.
        grad, = torch.autograd.grad(loss, delta)

        with torch.no_grad():
            loss_no_target = loss_func(*no_target(delta))
            print(f'iter {i}, target loss = {loss.item()}, non target loss = {loss_no_target.item()}')

            # Signed-gradient ascent step, projected back onto the eps-ball and the valid pixel range.
            delta = torch.clamp(delta + alpha * grad.sign(), -epsilon, epsilon)
            delta = torch.min(torch.max(delta, -clean), 1 - clean)

        losses.append((loss.detach(), loss_no_target.detach()))
    return losses
    

In [None]:
# Experiment 1: maximise the InfoNCE loss (3 EOT samples per step) while monitoring the classifier's CE.
target, no_target = model.simclr_forward, model.classify_forward
losses_1 = attack(target, no_target, eot=3)

# Experiment 2 (reference): maximise the classifier's CE directly while monitoring InfoNCE.
target, no_target = model.classify_forward, model.simclr_forward
losses_2 = attack(target, no_target, eot=1)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# Classification (cross-entropy) loss over the course of each attack:
#   losses_1 -> attack maximises InfoNCE; CE is the monitored (non-target) loss, element [1]
#   losses_2 -> attack maximises CE directly; CE is the target loss, element [0]
ce_under_nce_attack = [float(pair[1]) for pair in losses_1]
ce_under_ce_attack = [float(pair[0]) for pair in losses_2]

fig, ax = plt.subplots()
# x values come from the data, so the plot stays correct if the attack's iteration count changes.
ax.plot(range(len(ce_under_nce_attack)), ce_under_nce_attack, color='orange', label='Attack InfoNCE', linewidth=3)
ax.plot(range(len(ce_under_ce_attack)), ce_under_ce_attack, label='Attack CE', linewidth=3)

ax.set_xlabel("Iteration", fontsize=10)
ax.set_ylabel("Loss Value", fontsize=10)
ax.tick_params(axis='both', labelsize=10)
ax.legend(loc='center right', fontsize=18)
fig.savefig('Change_of_NCE_SL_new.eps')

plt.show()
