In [1]:
# from google.colab import drive
#
# drive.mount('/content/gdrive')
#!pip install timm wandb torchattacks einops torch-optimizer
import torch_optimizer
import os
import timm
import torch
from torch.utils.data import DataLoader
import numpy as np
import wandb

import torchattacks
import torchvision.transforms as transforms
from torchvision import datasets
import albumentations as A
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from utils import training_loop,validation_loop,randAugment,PGD
from tqdm.notebook import tqdm
from vit_small import ViT
from torch.utils.data import Dataset
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Class names of the ten categories; `num_classes` is derived from this list.
names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
num_classes = len(names)

# Directory holding the dataset download and the model checkpoints.
# Colab alternative (Drive mounted): "/content/gdrive/MyDrive/Adversarial_paper"
root = "."

In [2]:

# Per-channel (R, G, B) normalization statistics shared by every transform in
# this notebook; the values match the commonly used ImageNet statistics.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# First network: pretrained ResNeXt-50d with its classifier resized to
# `num_classes` outputs.
# Disabled alternative (trained on imagenet):
#   timm.create_model("resnext50_32x4d", pretrained=True, num_classes=num_classes).to(device)
model1 = timm.create_model(
    "resnext50d_32x4d",
    pretrained=True,
    num_classes=num_classes,
)
model1 = model1.to(device)

# Second network: pretrained ConvNeXt-base, will be trained on cifar10.
model2 = timm.create_model(
    "convnext_base",
    pretrained=True,
    num_classes=num_classes,
)
model2 = model2.to(device)

# Disabled alternative for model2: the small ViT from `vit_small`. It has no
# timm `default_cfg`, so one must be attached because later cells read
# `default_cfg["architecture"]` to name checkpoints.
#model2 = ViT(image_size=32,patch_size=4,num_classes=10,dim=512,depth=6,heads=8,mlp_dim=512,dropout=0,emb_dropout=0)
# setattr(model2,"default_cfg", {
#     "architecture" : "ViT_small"
# })

In [3]:

#ops = randAugment(N=2,M=5,p=0.5,mode="all",cut_out=True)
# Training pipeline: random flip, padded random crop and colour jitter on the
# PIL image, then conversion to a tensor and per-channel normalization.
# Disabled options kept for reference:
#   transforms.RandAugment(num_ops=5, magnitude=2)
#   transforms.GaussianBlur(kernel_size=3)
train_transform = transforms.Compose(
    transforms=[
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(size=32, padding=4),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Validation pipeline: no augmentation, only tensor conversion and the same
# normalization as training.
# Disabled options kept for reference:
#   transforms.Resize(32)
#   transforms.CenterCrop(32)
val_transform = transforms.Compose(
    transforms=[
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

def target_transform(label):
    """Turn a class index into a one-hot vector of length `num_classes`.

    The result is a float tensor of zeros with a single 1 at position `label`.
    """
    one_hot = torch.zeros(num_classes)
    one_hot[label] = 1
    return one_hot

# class Cifar10SearchDataset(datasets.CIFAR10):
#     def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
#         super().__init__(root=root, train=train, download=download, transform=transform)
#
#     def __getitem__(self, index):
#         image, label = self.data[index], self.targets[index]
#
#         if self.transform is not None:
#             transformed = self.transform(image=image)
#             image = transformed["image"]
#
#         return image, label
# CIFAR-10 training split, augmented with `train_transform`.
# Labels stay as integer class indices; pass
# `target_transform=target_transform` to get one-hot vectors instead.
training_data = datasets.CIFAR10(
    root=root,
    train=True,
    transform=train_transform,
    download=True,
)

# CIFAR-10 test split, pre-processed with `val_transform` only.
test_data = datasets.CIFAR10(
    root=root,
    train=False,
    transform=val_transform,
    download=True,
)





Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified


In [None]:
# Shuffled mini-batches of 512 training images, loaded by 4 worker processes.
train_loader = DataLoader(
    training_data,
    batch_size=512,
    shuffle=True,
    num_workers=4,
)

# Test batches keep the dataset order (no shuffling).
test_loader = DataLoader(
    test_data,
    batch_size=512,
    num_workers=4,
)

In [None]:

# Per-epoch training history.
# NOTE: both models append to the same lists, in the order they are trained.
metrics = {
    "accuracy": [],
    "train_loss": [],
    "val_loss": [],
}
keep_training = True


def criterion(pred, true):
    """Cross-entropy between the raw logits `pred` and the targets `true`.

    `cross_entropy` applies log-softmax internally, so the logits are passed
    through unchanged. Soft-maxing them first (as this cell used to do)
    squashes them into [0, 1], which flattens the loss and weakens gradients.
    """
    return torch.nn.functional.cross_entropy(pred, true)


def _accuracy(results):
    """Fraction of samples whose arg-max prediction equals the label.

    Assumes `results[0]` holds the integer labels and `results[1]` the logits,
    which is how this cell has always read the output of
    `utils.validation_loop` — TODO confirm against that helper.
    Soft-max is monotonic, so the arg-max is taken on the logits directly.
    """
    labels = results[0].cpu().numpy()
    predictions = results[1].cpu().float().argmax(dim=1).numpy()
    return (labels == predictions).mean()


for model in [model2, model1]:
    checkpoint_path = f"{root}/{model.default_cfg['architecture']}.pt"

    if os.path.exists(checkpoint_path):
        # `map_location` lets a checkpoint written on a GPU machine load on a
        # CPU-only one.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model = model.to(device)
        val_loss, results = validation_loop(model, test_loader, criterion, device, autocast=False)
        accuracy = _accuracy(results)
        print(f"Model : {model.default_cfg['architecture']} , Accuracy : {accuracy}")
    else:
        accuracy = 0

    # Train only when there is no checkpoint or it is below 90 % test accuracy.
    if keep_training and accuracy < 0.90:
        model = model.to(device)
        optimizer = torch_optimizer.Lamb(model.parameters(), lr=1e-3)

        scaler = torch.cuda.amp.GradScaler()
        max_epoch = 50
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            max_lr=1e-3,
            optimizer=optimizer,
            epochs=max_epoch,
            steps_per_epoch=len(train_loader),
        )
        #scheduler=torch.optim.lr_scheduler.CyclicLR(max_lr=1e-3,optimizer=optimizer,base_lr=1e-4)

        for epoch in (pbar := tqdm(range(0, max_epoch), position=0, leave=True)):
            # NOTE(review): `epoch=max_epoch` passes the total epoch count on
            # every iteration, not the current `epoch`; confirm which one
            # `utils.training_loop` expects.
            loss = training_loop(
                model,
                tqdm(train_loader, position=1, leave=False),
                optimizer=optimizer,
                criterion=criterion,
                device=device,
                scaler=scaler,
                clip_norm=0,
                autocast=True,
                scheduler=scheduler,
                epoch=max_epoch,
            )

            val_loss, results = validation_loop(model, test_loader, criterion, device, autocast=True)
            accuracy = _accuracy(results)

            # The loops return summed losses; divide by the batch count.
            train_loss_avg = loss / len(train_loader)
            val_loss_avg = val_loss / len(test_loader)
            pbar.set_description(f"Training loss : {train_loss_avg} ,  validation loss : {val_loss_avg} , accuracy : {accuracy}")
            metrics["accuracy"].append(accuracy)
            metrics["train_loss"].append(train_loss_avg)
            metrics["val_loss"].append(val_loss_avg)

        torch.save(model.state_dict(), checkpoint_path)
        # Move the trained model off the accelerator before the next one.
        model.to("cpu")



In [None]:
# Persist each model under its *own* architecture name. The paths used to be
# built from the leftover loop variable `model`, so both saves targeted the
# same file and model2's weights overwrote model1's checkpoint.
torch.save(model1.state_dict(), f"{root}/{model1.default_cfg['architecture']}.pt")
torch.save(model2.state_dict(), f"{root}/{model2.default_cfg['architecture']}.pt")

In [None]:



# Keep Weights & Biases in "dryrun" mode; set to "online" to sync the runs.
os.environ["WANDB_MODE"] = "dryrun"

# Test-time pipeline for the adversarial experiments: tensor conversion only.
# Normalization is deliberately left out here (the images must stay in [0, 1]
# for the attack) and is applied later through `normalize`.
# NOTE: this rebinds `val_transform`, `test_data` and `test_loader` from the
# earlier cells; re-running the training cell after this one would validate
# on un-normalized images.
# Disabled options kept for reference:
#   transforms.Resize(32)
#   transforms.CenterCrop(32)
#   transforms.Normalize(mean, std)
val_transform = transforms.Compose(
    transforms=[
        transforms.ToTensor(),
    ]
)

# Un-normalized CIFAR-10 test split.
test_data = datasets.CIFAR10(
    root=root,
    train=False,
    transform=val_transform,
    download=True,
)

# Larger batches (1024) than in the training cell; order is not shuffled.
test_loader = DataLoader(
    test_data,
    batch_size=1024,
    num_workers=4,
)

# Applied by hand to clean and adversarial images before the forward passes.
normalize = transforms.Normalize(mean=mean, std=std)



# If the images passed to the attack are already normalized, tell it so:
# atk.set_normalization_used(mean=[...], std=[...])

def _batch_accuracy(logits, labels):
    """Top-1 accuracy of `logits` against `labels`, rounded to two decimals.

    Soft-max is monotonic, so the arg-max is taken on the logits directly.
    """
    predictions = logits.detach().float().argmax(dim=1)
    return (predictions == labels).cpu().numpy().mean().round(decimals=2)


# Epoch 0 only measures the starting point; updates begin at epoch 1.
n_epochs = 25

# For each perturbation budget: reload both checkpoints, then fine-tune model1
# so that its logits on PGD adversarial images match model2's, while logging
# clean / adversarial accuracy and the logit sensitivity of model2.
for eps in [4, 8, 16, 32]:
    # Start every budget from the same saved weights. `map_location` lets
    # GPU-saved checkpoints load on a CPU-only machine.
    model1.load_state_dict(torch.load(f"{root}/{model1.default_cfg['architecture']}.pt", map_location=device))
    model2.load_state_dict(torch.load(f"{root}/{model2.default_cfg['architecture']}.pt", map_location=device))
    # Convert the budget from 8-bit pixel units to the [0, 1] image range
    # (note: divides by 256, not 255).
    eps /= 256
    model1 = model1.to(device)
    model2 = model2.to(device)
    #train_loader.batch_size=128
    optimizer = torch.optim.AdamW(model1.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-4,
        epochs=n_epochs,
        steps_per_epoch=len(test_loader),
    )

    mse = torch.nn.MSELoss()

    # Only model1 is updated; model2 is a frozen reference.
    model1.train()
    model2.eval()
    config = {
        "eps": eps,
        "model1": model1.default_cfg["architecture"],
        "model2": model2.default_cfg["architecture"],
    }
    wandb.init(project="adversarial-paper-ift6164-2", entity="j-bytes", config=config)

    wandb.watch(model1)
    wandb.watch(model2)

    for epoch in (pbar := tqdm(range(0, n_epochs), leave=True, position=0)):
        running_loss = 0
        # model2 logits on the clean images, one tensor per batch.
        clean_logits = []
        sensitivity = torch.zeros((num_classes,))

        accuracy1 = []
        accuracy2 = []
        adv_accuracy1 = []
        adv_accuracy2 = []
        for images, labels in tqdm(test_loader, position=1, leave=False):
            # NOTE(review): the attack sees un-normalized images in [0, 1],
            # while both models are evaluated on normalized inputs below.
            # Whether `utils.PGD` normalizes internally is not visible here —
            # confirm.
            atk = PGD(model1, eps=eps, alpha=eps / 3, steps=10)
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            assert torch.min(images) >= 0 and torch.max(images) <= 1, f"{torch.min(images)},{torch.max(images)}"
            adv_images = atk(images, labels)  # supposed to use labels! -> sensitivity 1.06 with eps=.01/256

            adv_images = adv_images.to(device, non_blocking=True)
            # `Normalize` accepts a (B, C, H, W) batch, so no per-sample loop.
            images = normalize(images)
            adv_images = normalize(adv_images)

            # NOTE(review): the forward pass runs under autocast but the
            # backward pass below uses no GradScaler; consider adding one if
            # gradients underflow in half precision.
            with torch.cuda.amp.autocast(enabled=True):
                logit_model1_adv_img = model1(adv_images)
                with torch.no_grad():
                    logit_model1_img = model1(images)
                    logit_model2_img = model2(images).detach()
                    logit_model2_adv_img = model2(adv_images).detach()

            if epoch != 0:
                # Regression target: model2's adversarial logits, except that
                # each sample's true-class entry is replaced by model1's own
                # value so that entry contributes no loss.
                # The previous `[:, labels]` indexing replaced, for EVERY row,
                # every column whose index occurs anywhere in the batch's
                # labels. With large batches that is every column, which made
                # the loss identically zero.
                target = logit_model2_adv_img.clone()
                rows = torch.arange(labels.shape[0], device=labels.device)
                target[rows, labels] = logit_model1_adv_img[rows, labels].detach().to(target.dtype)
                loss = mse(target, logit_model1_adv_img)

                loss.backward()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                running_loss += loss.item()

            # Blocking device-to-host copies: the values are read right away.
            clean_logits.append(logit_model2_img.cpu())

            # Mean absolute per-class logit shift of model2 under the attack.
            sensitivity += torch.mean(torch.abs(logit_model2_img - logit_model2_adv_img), dim=0).cpu()
            # `baseline` keeps only the last batch's clean accuracy of model2.
            baseline = _batch_accuracy(logit_model2_img, labels)

            accuracy1.append(_batch_accuracy(logit_model1_img, labels))
            accuracy2.append(_batch_accuracy(logit_model2_img, labels))

            adv_accuracy1.append(_batch_accuracy(logit_model1_adv_img, labels))
            adv_accuracy2.append(_batch_accuracy(logit_model2_adv_img, labels))

        # Concatenate once per epoch instead of growing a tensor per batch.
        std_logits = torch.cat(clean_logits, dim=0)
        # Per-class shift averaged over batches, relative to |clean logit|.
        sensitivity_score = torch.mean(sensitivity / len(test_loader) / torch.abs(std_logits)).item()
        mean_loss = running_loss / len(test_loader)

        pbar.set_description(f"Sensitivity : {round(sensitivity_score,ndigits=2)} ,  loss : {round(mean_loss,ndigits=2)} , acc1 : {np.mean(accuracy1).round(2)} , acc2 : {np.mean(accuracy2).round(2)},adv_acc1 : {np.mean(adv_accuracy1).round(2)},adv_acc2 : {np.mean(adv_accuracy2).round(2)} , baseline : {baseline}")

        # Separate name so the training cell's `metrics` history survives.
        epoch_metrics = {
            "sensitivity": round(sensitivity_score, 2),
            "loss": mean_loss,
            "Adversarial accuracy 1": np.mean(adv_accuracy1),
            "Adversarial accuracy 2": np.mean(adv_accuracy2),
            "Accuracy 1": np.mean(accuracy1),
            "Accuracy 2": np.mean(accuracy2),
            "epoch": epoch,
        }
        wandb.log(epoch_metrics)

    wandb.finish()
            #adv_pred = torch.softmax(model1(adv_images),dim=1)


