## Dataset

In [None]:
import mydataset

In [None]:
test_dl = mydataset.test_dl
# Sanity check: print the shapes of the first (inputs, labels) batch only.
first_batch = next(iter(test_dl), None)
if first_batch is not None:
    xb, yb = first_batch
    print(xb.shape, yb.shape)

## Pre-trained Model

In [None]:
import mymodel

In [None]:
# Pre-trained network exposed by the local ``mymodel`` module.
model = mymodel.model

In [None]:
import torch

# Preferred GPU for this notebook. A CUDA ordinal that does not exist on the
# host fails as soon as tensors are moved there, so fall back to the first GPU
# when fewer GPUs are present, and to the CPU when CUDA is unavailable.
GPU_INDEX = 3
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f"cuda:{GPU_INDEX}")
else:
    device = torch.device("cuda:0")
model = model.to(device)

In [None]:
def freeze_model(model):
    """Disable gradient tracking for every parameter of ``model``.

    ``model.parameters()`` recurses through all submodules and also yields
    parameters registered directly on ``model`` itself, so nothing is missed
    (iterating ``model.children()`` would skip root-level parameters).

    Args:
        model: a ``torch.nn.Module``; modified in place.

    Returns:
        The same ``model`` object, for convenient reassignment.
    """
    for param in model.parameters():
        param.requires_grad = False
    print("model frozen")
    return model

In [None]:
model = freeze_model(model)

In [None]:
import numpy as np

def deploy_model(model, test_dl):
    """Run ``model`` over ``test_dl`` and collect predicted and true labels.

    Works for any batch size: every sample of every batch contributes exactly
    one entry to each returned list.

    Args:
        model: callable mapping an input batch to per-class scores of shape
            ``(batch, num_classes)``; the highest-scoring class is taken as
            the prediction.
        test_dl: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.

    Returns:
        ``(y_pred, y_gt)``: two equal-length lists of Python ints, in the
        order the samples are yielded by ``test_dl``.
    """
    y_pred = []
    y_gt = []
    # NOTE(review): model.eval() is not called here; if the model has dropout
    # or batch-norm layers, confirm it is already in eval mode.
    with torch.no_grad():
        for x, y in test_dl:
            # Flatten so both scalar and (batch,) label tensors are handled.
            y_gt.extend(y.reshape(-1).tolist())
            scores = model(x.to(device)).cpu().numpy()
            y_pred.extend(np.argmax(scores, axis=1).tolist())
    return y_pred, y_gt
# Clean (unperturbed) predictions and ground-truth labels for the test set.
y_pred, y_gt = deploy_model(model=model, test_dl=test_dl)

In [None]:
from sklearn.metrics import accuracy_score

# sklearn's signature is accuracy_score(y_true, y_pred): pass ground truth
# first so the call stays correct if it is swapped for an asymmetric metric
# (precision, recall, ...).
acc = accuracy_score(y_gt, y_pred)
print("accuracy: %.2f" % acc)

## FGSM Attack (Fast Gradient Sign Method)

In [None]:
def perturb_input(xb, yb, model, alfa, *, clip_min=0.0, clip_max=1.0):
    """Build an FGSM (fast gradient sign method) adversarial copy of ``xb``.

    Computes ``F.nll_loss`` of ``model`` on ``(xb, yb)``, takes the sign of
    its gradient with respect to the input, and steps ``alfa`` along it:
    ``xb_p = clamp(xb + alfa * sign(grad_x loss), clip_min, clip_max)``.

    Args:
        xb: input batch. It is copied before gradients are enabled, so the
            caller's tensor gets no ``requires_grad`` flag or ``.grad``.
        yb: target labels. They are compared with the model output after it
            is moved to CPU, so they are expected on CPU.
        model: module whose output is fed to ``F.nll_loss``, which expects
            log-probabilities (presumably the model ends in ``log_softmax`` —
            confirm).
        alfa: perturbation magnitude (epsilon) applied to each input element.
        clip_min, clip_max: valid input range after the step. The defaults
            assume inputs scaled to ``[0, 1]``; adjust if the data is normalized.

    Returns:
        ``(xb_p, out)``: the perturbed batch on ``device`` (detached, no
        autograd history) and the model output on the clean batch (CPU,
        detached).
    """
    # Local import: in this notebook ``F`` is otherwise imported in a later cell.
    import torch.nn.functional as F

    # Fresh leaf tensor: ``.to(device)`` returns the caller's own tensor when it
    # is already on ``device``, and gradients must not accumulate on it.
    x_adv = xb.to(device).detach().clone().requires_grad_(True)
    out = model(x_adv).cpu()
    loss = F.nll_loss(out, yb)
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        xb_p = x_adv + alfa * x_adv.grad.sign()
        xb_p = xb_p.clamp(clip_min, clip_max)
    return xb_p, out.detach()

In [None]:
from torchvision.transforms.functional import to_pil_image
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

y_pred = []
y_pred_p = []
for xb,yb in test_dl:
    xb_p, out = perturb_input(xb, yb, model, alfa = 0.005)
    
    with torch.no_grad():
        pred = out.argmax(dim=1, keepdim=False).item()
        y_pred.append(pred) 
        prob = torch.exp(out[:, 1])[0].item()

        out_p = model(xb_p).cpu()
        pred_p = out_p.argmax(dim=1, keepdim=False).item()
        y_pred_p.append(pred_p)
        prob_p = torch.exp(out_p[:, 1])[0].item()
        
    plt.subplot(1, 2, 1)
    plt.imshow(to_pil_image(xb[0].detach().cpu()))
    plt.title(prob)
    plt.subplot(1, 2, 2)
    plt.imshow(to_pil_image(xb_p[0].detach().cpu()))
    plt.title(prob_p)
    plt.show()
    

In [None]:
# Accuracy on the clean inputs, using predictions from the attack loop.
# sklearn's signature is accuracy_score(y_true, y_pred): ground truth first.
# NOTE(review): y_gt comes from an earlier, separate pass over test_dl; this
# assumes test_dl yields samples in the same order every time (no shuffling).
acc = accuracy_score(y_gt, y_pred)
print("accuracy: %.2f" % acc)

In [None]:
# Accuracy on the FGSM-perturbed inputs.
# sklearn's signature is accuracy_score(y_true, y_pred): ground truth first.
# NOTE(review): assumes test_dl yields samples in the same order as the pass
# that produced y_gt (no shuffling) — confirm.
acc = accuracy_score(y_gt, y_pred_p)
print("accuracy: %.2f" % acc)