In [None]:
import timm
import torch
from torch.utils.data import DataLoader
import torchattacks
import torchvision.transforms as transforms
from torchvision import datasets
import albumentations as A
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from utils import training_loop,validation_loop
import tqdm
from torch.utils.data import Dataset
# CIFAR-10 class names, listed in the dataset's label-index order (index i <-> label i).
names = "airplane automobile bird cat deer dog frog horse ship truck".split()
num_classes = len(names)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
import torchattacks
# Per-channel ImageNet statistics; every CIFAR-10 image is normalised with these.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# model1: ResNet-50 loaded with ImageNet-pretrained weights.
# NOTE(review): num_classes (10) differs from the ImageNet checkpoint, so timm
# builds a fresh, untrained classification head -- only the backbone is
# ImageNet-trained. Attacks against model1 therefore target a random head
# until model1 is fine-tuned.
model1 = timm.create_model("resnet50", pretrained=True, num_classes=num_classes).to(device)
# model2: ConvNeXt-Small from random initialisation; trained on CIFAR-10 later on.
model2 = timm.create_model("convnext_small", pretrained=False, num_classes=num_classes).to(device)


def _build_cifar_preprocessing():
    """Return a new pipeline: resize/crop to 32x32, convert to tensor, normalise.

    CIFAR-10 images are already 32x32, so Resize/CenterCrop leave them
    unchanged. No data augmentation is included.
    """
    return transforms.Compose(
        [
            transforms.Resize(32),
            transforms.CenterCrop(32),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )


# Train and validation share the same augmentation-free preprocessing.
train_transform = _build_cifar_preprocessing()
val_transform = _build_cifar_preprocessing()

def target_transform(label, n_classes=None):
    """Encode an integer class label as a one-hot float vector.

    Used as the torchvision ``target_transform`` for the CIFAR-10 datasets, so
    every target the models see is a length-``num_classes`` probability vector.

    Args:
        label: integer class index, expected in ``[0, n_classes)``.
        n_classes: length of the output vector. When omitted, the module-level
            ``num_classes`` is used, which keeps the single-argument call
            ``target_transform(label)`` working unchanged.

    Returns:
        A float32 tensor of shape ``(n_classes,)`` with 1 at ``label`` and 0 elsewhere.
    """
    # Only read the notebook-level constant when the caller did not supply a size.
    size = num_classes if n_classes is None else n_classes
    label_vector = torch.zeros(size)
    label_vector[label] = 1
    return label_vector
# Local directory where CIFAR-10 is downloaded / cached.
DATA_ROOT = "/mnt/f/data"

# Settings shared by both splits: same cache dir, download if missing, and
# one-hot label encoding.
_cifar_common = dict(root=DATA_ROOT, download=True, target_transform=target_transform)

training_data = datasets.CIFAR10(train=True, transform=train_transform, **_cifar_common)
test_data = datasets.CIFAR10(train=False, transform=val_transform, **_cifar_common)

# PGD attack on model1: eps = 8/255, step size 8/(3*255), 100 iterations.
# NOTE(review): torchattacks expects inputs in [0, 1], but both datasets yield
# ImageNet-normalised tensors; images must be de-normalised (and
# atk.set_normalization_used(mean, std) called) before they are attacked.
atk = torchattacks.PGD(model1, eps=8/255, alpha=8/3/255, steps=100)


In [None]:
# Both loaders use 4 worker processes and batches of 128; only training data is shuffled.
_loader_kwargs = {"num_workers": 4, "batch_size": 128}
train_loader = DataLoader(training_data, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_data, **_loader_kwargs)

In [None]:
# Train model2 on clean CIFAR-10 with SGD and a one-cycle learning-rate schedule.
for model in [model2]:

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    def criterion(pred, true):
        """Cross-entropy between raw logits ``pred`` and one-hot targets ``true``.

        ``cross_entropy`` applies log-softmax to the logits itself and accepts
        class-probability (one-hot float) targets. Do NOT softmax ``pred``
        first: that normalises twice and flattens the gradients.
        """
        return torch.nn.functional.cross_entropy(pred, true)

    # NOTE(review): the scaler is passed along, but autocast=False below, so
    # mixed precision is effectively disabled.
    scaler = torch.cuda.amp.GradScaler()
    max_epoch = 10
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        max_lr=1e-3,
        optimizer=optimizer,
        epochs=max_epoch,
        steps_per_epoch=len(train_loader),
    )
    model = model.to(device)
    for epoch in range(0, max_epoch):
        # Pass the *current* epoch index. NOTE(review): confirm this matches
        # how utils.training_loop interprets `epoch`.
        loss = training_loop(
            model,
            tqdm.tqdm(train_loader),
            optimizer=optimizer,
            criterion=criterion,
            device=device,
            scaler=scaler,
            clip_norm=10,
            autocast=False,
            scheduler=scheduler,
            epoch=epoch,
        )

        val_loss, results = validation_loop(model, test_loader, criterion, device, autocast=False)

        # Presumably both loops return losses summed over batches, hence the
        # division by the number of batches -- TODO confirm against utils.
        print(f"Training loss : {loss/len(train_loader)} ,  validation loss : {val_loss/len(test_loader)}")

In [None]:

# Fine-tune model1 so that, on PGD adversarial examples crafted against model1
# itself, its predicted class distribution matches model2's. model2 is used
# only as a frozen reference: this optimizer never updates it.

# torchattacks expects inputs in [0, 1] (and recent versions reject anything
# else), but the loaders yield ImageNet-normalised tensors. Register the
# normalisation with the attack and feed it de-normalised images; the attack
# then normalises internally before calling model1.
atk.set_normalization_used(mean=list(mean), std=list(std))
_mean_t = torch.tensor(mean, device=device).view(1, -1, 1, 1)
_std_t = torch.tensor(std, device=device).view(1, -1, 1, 1)

optimizer = torch.optim.SGD(model1.parameters(), lr=1e-3)
accumulate = 1000  # approximate number of samples per optimizer step
# Number of batches whose gradients are summed (not averaged) before each step.
accum_steps = max(1, accumulate // train_loader.batch_size)

model1.train()
model2.eval()  # reference model only; it is never stepped here

optimizer.zero_grad()
pending = 0  # batches accumulated since the last optimizer step
for images, labels in tqdm.tqdm(train_loader):
    images, labels = images.to(device), labels.to(device)

    # Attack in [0, 1] pixel space, then map the result back into the
    # normalised space the models are called with directly below.
    # clamp() absorbs float round-off from the de-normalisation.
    pixels = (images * _std_t + _mean_t).clamp(0, 1)
    adv_pixels = atk(pixels, labels).to(device)
    adv_images = (adv_pixels - _mean_t) / _std_t

    pred_model1_adv_img = torch.softmax(model1(adv_images), dim=1)
    # No graph for model2: its gradients would never be used and would pile
    # up in model2's .grad buffers.
    with torch.no_grad():
        pred_model2_adv_img = torch.softmax(model2(adv_images), dim=1)

    loss = torch.mean((pred_model2_adv_img - pred_model1_adv_img) ** 2)
    loss.backward()
    pending += 1

    if pending == accum_steps:
        optimizer.step()
        optimizer.zero_grad()
        pending = 0
        # Sensitivity: mean absolute change in model2's prediction caused by
        # the perturbation on this batch (diagnostic only).
        with torch.no_grad():
            pred_model2_img = torch.softmax(model2(images), dim=1)
            sensitivity = torch.mean(torch.abs(pred_model2_img - pred_model2_adv_img))
        print(f"Sensitivity : {sensitivity.item()} ,  loss : {loss.item()}")

# Apply gradients left over from a final, partial accumulation window.
if pending:
    optimizer.step()
    optimizer.zero_grad()




In [None]:
# Show model1's predictions on adversarial images from the last batch of the
# adversarial fine-tuning loop.
# NOTE(review): relies on `adv_images` left over from that loop and assumes it
# is in the same normalised space the models consume -- TODO confirm.
N_SHOW = 10  # number of examples to display
adv_im = adv_images[:N_SHOW]

model1.eval()  # inference only: don't update BatchNorm running statistics
with torch.no_grad():
    adv_out = model1(adv_im)
print(adv_out.shape, adv_im.shape)

# The images are ImageNet-normalised; undo that so imshow receives values in [0, 1].
_mean_t = torch.tensor(mean, device=adv_im.device).view(1, -1, 1, 1)
_std_t = torch.tensor(std, device=adv_im.device).view(1, -1, 1, 1)
adv_pixels = (adv_im * _std_t + _mean_t).clamp(0, 1)

for im, logits in zip(adv_pixels, adv_out):
    probs = torch.softmax(logits, dim=0)
    label = names[int(torch.argmax(probs))]
    print(label, probs)
    fig, ax = plt.subplots()
    ax.imshow(im.cpu().permute(1, 2, 0).numpy())
    ax.set_title(f"model1 prediction: {label}")
    ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures in the loop