In [1]:
#HERE TO TEST IDEA FROM LINE CONVERSATION
import numpy as np                # import numpy
import matplotlib.pyplot as plt   # import matplotlib, a python 2d plotting library
from tqdm import tqdm
#from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

#import torch packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Prefer the first CUDA GPU; otherwise run everything on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on Graphics' if device.type == 'cuda' else 'Running on Processor')

Running on Graphics


In [2]:
class Classifier(nn.Module):
    """Small three-stage CNN mapping single-channel images to 10 class logits.

    ``forward`` also caches two intermediate activations on the module:
      ``feat`` -- second conv stage after ReLU and 2x2 max-pool
      ``maps`` -- ReLU output of the third conv (used for CAM heat maps)

    NOTE: the 7x7 average pool acts as a global pool only for 28x28 inputs
    (28 -> 14 -> 7 after the two max-pools); other sizes are not supported.
    """

    def __init__(self):
        super().__init__()
        # Parameterised layers are created in this exact order so a given torch
        # seed produces the same initial weights and state_dict key order.
        self.c1 = nn.Conv2d(1, 8, 3, padding=1)
        self.c2 = nn.Conv2d(8, 16, 3, padding=1)
        self.c3 = nn.Conv2d(16, 32, 3, padding=1)
        self.l = nn.Linear(32, 10)
        self.pool = nn.MaxPool2d(2)
        self.avgpool = nn.AvgPool2d(7)
        self.act = nn.ReLU()

    def _stage(self, conv, inp):
        # conv -> ReLU -> 2x2 max-pool
        return self.pool(self.act(conv(inp)))

    def forward(self, x):
        hidden = self._stage(self.c1, x)
        self.feat = self._stage(self.c2, hidden)
        self.maps = self.act(self.c3(self.feat))
        pooled = self.avgpool(self.maps)
        return self.l(torch.flatten(pooled, 1))

In [4]:
# Both MNIST splits share the root directory and the tensor conversion;
# the generator is unpacked in order, so the training split is built first.
train_data, test_data = (
    MNIST('../mnist_digits/', train=is_train_split, download=True,
          transform=torchvision.transforms.ToTensor())
    for is_train_split in (True, False)
)

RuntimeError: Dataset not found. You can use download=True to download it

In [None]:
def train(optim, train_data, test_data, models, epochs, batch_size):
    """Jointly optimise ``models`` and record per-epoch test accuracy.

    Args:
        optim: optimizer over the parameters of ``models`` (the name shadows the
            ``torch.optim`` module inside this function only).
        train_data, test_data: datasets yielding ``(image, label)`` pairs.
        models: sequence of classifiers; ``models[0]`` and ``models[1]`` are
            handed to the module-level ``get_loss``.
        epochs: number of passes over ``train_data``.
        batch_size: mini-batch size for both loaders.

    Returns:
        list with one float32 CPU tensor of shape ``(len(models),)`` per epoch,
        holding each model's accuracy on ``test_data``.
    """
    metrics = []
    n = len(test_data)
    # A shuffling DataLoader reshuffles on every iteration, so it can be built once.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size)
    for _ in tqdm(range(epochs)):
        for model in models:
            model.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            yhat1 = models[0](x)
            yhat2 = models[1](x)
            for model in models:
                model.zero_grad()
            loss = get_loss(y, yhat1, yhat2, models[0], models[1])
            loss.backward()
            optim.step()

        # Evaluation needs no autograd graph; hit counts are kept as Python ints
        # instead of mixing a CPU accumulator with device tensors.
        correct = [0] * len(models)
        for model in models:
            model.eval()
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                for j, model in enumerate(models):
                    correct[j] += (torch.argmax(model(x), dim=1) == y).sum().item()
        metrics.append(torch.tensor(correct, dtype=torch.float32) / n)
    return metrics

In [None]:
#Find mutual residual (proxy for mutual information)
def get_R(X, Y, eta=0.001, tol=0.0005):
    """Log ratio of total to residual sum of squares when Y is regressed on X.

    Y is projected onto the column space of [X, 1] (least squares with an
    intercept). Larger values mean X linearly explains more of Y; this serves
    as a proxy for the mutual information between the two feature tensors.

    Args:
        X, Y: tensors whose first dimension is the batch; trailing dimensions
            are flattened.
        eta: stabiliser added to both sums of squares before taking logs.
        tol: columns of X whose |QR diagonal| is below ``tol`` times the largest
            |QR diagonal| are dropped as (near-)linearly dependent.

    Returns:
        0-d tensor ``log(SStot + eta) - log(SSres + eta)``.
    """
    X = torch.flatten(X, start_dim=1)
    Y = torch.flatten(Y, start_dim=1)
    n_rows, n_cols = X.shape

    # Householder QR returns diagonal entries of either sign, so compare
    # magnitudes against the largest magnitude.
    _, r_mat = torch.linalg.qr(X)
    diag = torch.abs(torch.diag(r_mat))
    keep = torch.ones(n_cols, dtype=torch.bool, device=X.device)
    # Reduced QR yields only min(n_rows, n_cols) diagonal entries; columns past
    # that cannot be assessed here and are kept.
    keep[:diag.shape[0]] = diag / torch.max(diag) > tol
    X = X[:, keep]

    # Intercept column sized from this batch (not a global batch size).
    ones = torch.ones((n_rows, 1), dtype=X.dtype, device=X.device)
    X = torch.cat([X, ones], dim=1)
    Yhat = X @ torch.linalg.pinv(X) @ Y
    ss_res = torch.sum(torch.square(Y - Yhat))
    ss_tot = torch.sum(torch.square(Y - torch.mean(Y, dim=0, keepdim=True)))
    return torch.log(ss_tot + eta) - torch.log(ss_res + eta)

In [None]:
lambda1 = 0.0  # weight of the mutual-residual term L3; 0 disables it
loss_ce = nn.CrossEntropyLoss()
def get_loss(y, yhat1, yhat2, cls1, cls2, verbose=False):
    """Training objective ``L1 + L2 + lambda1 * L3`` for a pair of classifiers.

    Only L1, the cross-entropy of ``yhat1``, is active. The second
    cross-entropy (L2) and the feature mutual-residual term (L3, via
    ``get_R(cls1.feat, cls2.feat)``) are disabled and held at 0.

    NOTE(review): with L2 == 0 and lambda1 * L3 == 0, no gradient flows through
    ``yhat2``/``cls2``, so the second model is never trained -- confirm this is
    the intended baseline.

    Args:
        y: target labels.
        yhat1, yhat2: logits of the two classifiers.
        cls1, cls2: the classifiers themselves (only needed for L3).
        verbose: print the individual terms. This is off by default because the
            function runs once per training batch.

    Returns:
        scalar loss tensor.
    """
    L1 = loss_ce(yhat1, y)
    L2 = 0  # disabled: loss_ce(yhat2, y)
    L3 = 0  # disabled: get_R(cls1.feat, cls2.feat)
    if verbose:
        print('L1:{} L2:{} L3:{}'.format(L1, L2, L3))
    return L1 + L2 + lambda1 * L3

In [None]:
def get2classfiers(epochs, lr=5.0e-3):
    """Build two fresh Classifiers on ``device``, train them jointly, return both.

    Uses the module-level ``train_data``, ``test_data`` and ``batch_size``, and
    prints the per-epoch test accuracies returned by ``train``.

    Args:
        epochs: number of training epochs.
        lr: Adam learning rate.

    Returns:
        ``(cls1, cls2)``.
    """
    cls1 = Classifier().to(device)
    cls2 = Classifier().to(device)
    # One optimizer over both models; get_loss decides which of them receive gradients.
    optimizer = optim.Adam([*cls1.parameters(), *cls2.parameters()], lr=lr)
    models = [cls1, cls2]
    metric = train(optimizer, train_data, test_data, models, epochs, batch_size)
    print(metric)
    return cls1, cls2

In [None]:
def gen_FGSM(x, y, eta, model):
    """Fast Gradient Sign Method: ``clamp(x + eta * sign(d loss / d x), 0, 1)``.

    The loss is the mean cross-entropy of ``model(x)`` against ``y`` (the same as
    ``nn.CrossEntropyLoss()`` with default arguments). The clamp keeps pixels in
    [0, 1]. This presumably matches the ToTensor-scaled inputs; confirm for
    other data.

    Side-effect free:
    - the caller's ``x`` is not modified;
    - the model's parameter ``.grad`` fields are left untouched;
    - the result is detached, so later forward passes on it build no graph back
      through the attack.

    Args:
        x: input batch.
        y: target labels.
        eta: perturbation size (L-infinity step).
        model: differentiable classifier returning logits.

    Returns:
        detached adversarial batch with the same shape as ``x``.
    """
    x_adv = x.detach().clone().requires_grad_(True)
    loss = F.cross_entropy(model(x_adv), y)
    # Gradient w.r.t. the input only; parameter grads are not accumulated.
    (grad,) = torch.autograd.grad(loss, x_adv)
    return torch.clamp(x_adv.detach() + eta * grad.sign(), min=0, max=1.0)

In [None]:
def get_transfer(etas):
    """Measure how often FGSM examples crafted against cls1 also fool cls2.

    Trains a fresh classifier pair with ``get2classfiers(epochs)``. Then, for
    each attack strength in ``etas``, it sweeps the test set in batches of the
    module-level ``batch_size``.

    Args:
        etas: iterable of FGSM step sizes.

    Returns:
        trans_rate: numpy array, per eta, of (# examples fooling both models) /
            (# examples fooling cls1), pooled over the whole test set. The value
            is NaN if no example fooled cls1 at that eta.
        R2: numpy array, per eta, of the mean over test batches of
            ``get_R(cls1.feat, cls2.feat)`` on the adversarial batch.
        cls1, cls2: the trained models.
    """
    trans_rate = np.zeros(len(etas))
    R2 = np.zeros(len(etas))
    cls1, cls2 = get2classfiers(epochs)
    for eidx, eta in enumerate(etas):
        fooled1 = 0       # examples misclassified by cls1
        fooled_both = 0   # ... that cls2 also misclassifies
        r_sum = 0.0
        n_batches = 0
        for x, y in DataLoader(test_data, batch_size, shuffle=True):
            x = x.to(device)
            y = y.to(device)
            # The attack needs autograd, so it stays outside no_grad.
            x = gen_FGSM(x, y, eta, cls1)
            with torch.no_grad():
                wrong1 = torch.argmax(cls1(x), dim=1) != y
                wrong2 = torch.argmax(cls2(x), dim=1) != y
                fooled1 += wrong1.sum().item()
                fooled_both += (wrong1 & wrong2).sum().item()
                # cls1.feat / cls2.feat now hold features of this adversarial batch.
                r_sum += get_R(cls1.feat, cls2.feat).item()
            n_batches += 1
        # Pooled counts avoid a 0/0 when a single batch has no cls1 errors and
        # do not assume len(test_data) is a multiple of batch_size.
        trans_rate[eidx] = fooled_both / fooled1 if fooled1 else float('nan')
        R2[eidx] = r_sum / n_batches if n_batches else float('nan')
    return trans_rate, R2, cls1, cls2
                

In [None]:
# Experiment configuration (epochs and batch_size are read as globals by
# get_transfer / get2classfiers).
SEED = 0
etas = [0.0, 0.05, 0.1, 0.15]  # FGSM step sizes to sweep
epochs = 20
batch_size = 1000
# Seed torch's global RNG so weight init and DataLoader shuffling repeat across
# runs (some GPU kernels may still be nondeterministic).
torch.manual_seed(SEED)
trans_rate, R2, cls1, cls2 = get_transfer(etas)

In [None]:
# Display the per-eta transfer rates returned by get_transfer.
trans_rate

In [None]:
# Save this run's results to different_cnns.npz. The arrays are stored
# positionally: arr_0 = trans_rate, arr_1 = R2.
np.savez('different_cnns', trans_rate, R2)

In [None]:
# Shape of cls1's final linear weight: (out_features, in_features).
cls1.l.weight.shape

In [None]:
# Sanity check for the CAM arithmetic used by the heat-map cells: contracting the
# channel axis of cls1.maps with a one-hot weight vector must select that channel.
a = cls1.maps[0:5,:,:,:]
b = torch.zeros(32, dtype=torch.float32).to(device)
b[7] = 1.0
c = torch.matmul(a.transpose(1,3),b).transpose(1,2)
assert torch.allclose(c[3], a[3,7])
print(c[3,:,:])
print(a[3,7,:,:])

In [None]:
#Generate Heat Maps: class activation maps (CAMs) for the 5 test digits that
#cls1 classifies with the lowest loss, picked from a random batch of 100.
import cv2
det = torch.nn.CrossEntropyLoss(reduction='none')  # per-sample loss for ranking
eta = 0.0  # only used by the optional FGSM line below
x,y = next(iter(DataLoader(test_data, batch_size=100, shuffle=True)))
x = x.to(device)
y = y.to(device)
#x = gen_FGSM(x, y, eta, cls1)
yhat1 = cls1(x)

#find best examples: ascending per-sample loss of cls1
order = torch.argsort(det(yhat1,y), descending=False)
x = x[order[0:5]]
# Re-running both models on the selection refreshes cls1.maps / cls2.maps.
yhat1 = torch.argmax(cls1(x), dim=1)
yhat2 = torch.argmax(cls2(x), dim=1)

# CAM weights: each model's OWN final-layer row for the class it predicted.
w1 = cls1.state_dict()['l.weight'][yhat1,:]
w2 = cls2.state_dict()['l.weight'][yhat2,:]

maps1 = []
maps2 = []
ims = []
for i in range(len(w1)):
    for net, w, maps in ((cls1, w1, maps1), (cls2, w2, maps2)):
        # Channel-weighted sum of the last conv activations -> (H, W) map.
        cam = torch.matmul(net.maps[i].transpose(0,2), w[i,:]).transpose(0,1).detach().cpu().numpy()
        # Keep positive evidence only; negative values would wrap around in the uint8 cast.
        cam = np.maximum(cv2.resize(cam, (28,28)), 0)
        peak = np.max(cam)
        heat = cv2.applyColorMap(np.uint8(cam*255/peak if peak > 0 else cam), cv2.COLORMAP_JET)
        # OpenCV colormaps are BGR; matplotlib's imshow expects RGB.
        maps.append(cv2.cvtColor(heat, cv2.COLOR_BGR2RGB))
    ims.append(cv2.cvtColor(np.uint8(255*torch.squeeze(x[i]).detach().cpu().numpy()), cv2.COLOR_GRAY2RGB))

opacity=.85
fig, ax = plt.subplots(3,len(ims))
for i in range(len(ims)):
    ax[0,i].imshow(ims[i])
    for row, maps in ((1, maps1), (2, maps2)):
        overlay = np.float32(opacity*maps[i]+ims[i])
        ax[row,i].imshow(overlay/np.max(overlay))
    for s in ax[:,i]:
        s.axes.xaxis.set_ticks([])
        s.axes.yaxis.set_ticks([])
for row, label in enumerate(['input', 'CAM cls1', 'CAM cls2']):
    ax[row,0].set_ylabel(label)

In [None]:
#Generate Heat Maps: class activation maps (CAMs) for 5 random test digits.
import cv2
eta = 0.0  # only used by the optional FGSM line below
x,y = next(iter(DataLoader(test_data, batch_size=5, shuffle=True)))
x = x.to(device)
y = y.to(device)
#x = gen_FGSM(x, y, eta, cls1)
cls1.eval()
cls2.eval()
# These forward passes also refresh cls1.maps / cls2.maps for the 5 digits.
yhat1 = torch.argmax(cls1(x), dim=1)
yhat2 = torch.argmax(cls2(x), dim=1)

# CAM weights: each model's OWN final-layer row for the class it predicted.
w1 = cls1.state_dict()['l.weight'][yhat1,:]
w2 = cls2.state_dict()['l.weight'][yhat2,:]

maps1 = []
maps2 = []
ims = []
for i in range(len(w1)):
    for net, w, maps in ((cls1, w1, maps1), (cls2, w2, maps2)):
        # Channel-weighted sum of the last conv activations -> (H, W) map.
        cam = torch.matmul(net.maps[i].transpose(0,2), w[i,:]).transpose(0,1).detach().cpu().numpy()
        # Keep positive evidence only; negative values would wrap around in the uint8 cast.
        cam = np.maximum(cv2.resize(cam, (28,28)), 0)
        peak = np.max(cam)
        heat = cv2.applyColorMap(np.uint8(cam*255/peak if peak > 0 else cam), cv2.COLORMAP_JET)
        # OpenCV colormaps are BGR; matplotlib's imshow expects RGB.
        maps.append(cv2.cvtColor(heat, cv2.COLOR_BGR2RGB))
    ims.append(cv2.cvtColor(np.uint8(255*torch.squeeze(x[i]).detach().cpu().numpy()), cv2.COLOR_GRAY2RGB))

opacity=.85
fig, ax = plt.subplots(3,len(ims))
for i in range(len(ims)):
    ax[0,i].imshow(ims[i])
    for row, maps in ((1, maps1), (2, maps2)):
        overlay = np.float32(opacity*maps[i]+ims[i])
        ax[row,i].imshow(overlay/np.max(overlay))
    for s in ax[:,i]:
        s.axes.xaxis.set_ticks([])
        s.axes.yaxis.set_ticks([])
for row, label in enumerate(['input', 'CAM cls1', 'CAM cls2']):
    ax[row,0].set_ylabel(label)

In [None]:
# Write the most recently drawn figure to disk as a 600-dpi JPEG.
fig.savefig('heat_map.jpg', dpi=600, format='jpg')

In [None]:
# Compare ground-truth digits, their perturbed versions, and two model outputs
# for the first n entries of `ind` (presumably reconstructions via two
# encoder/decoder branches -- confirm).
# NOTE(review): decoder, encoder, bn1, bn2, x_adv and ind are not defined anywhere
# in this notebook, so this cell raises NameError on a fresh kernel; it appears
# to be carried over from another experiment. `x` here is whatever the last
# executed cell left behind.
n = 5
x1 = decoder(bn1(encoder(x_adv[ind[0:n]]))).detach().cpu().numpy()
x2 = decoder(bn2(encoder(x_adv[ind[0:n]]))).detach().cpu().numpy()
x_n = x_adv[ind[0:n]].detach().cpu().numpy()
x_gt = x[ind[0:n]].detach().cpu().numpy()


    
# Rows: 0 = x_gt, 1 = x_n (perturbed input), 2 = x1 (via bn1), 3 = x2 (via bn2).
fig, ax = plt.subplots(4,n)
for i in range(len(ax[0])):
    ax[0,i].imshow(np.squeeze(x_gt[i]))
    ax[1,i].imshow(np.squeeze(x_n[i]))
    ax[2,i].imshow(np.squeeze(x1[i]))
    ax[3,i].imshow(np.squeeze(x2[i]))
    for s in ax[:,i]:
        s.axes.xaxis.set_ticks([])
        s.axes.yaxis.set_ticks([])

# NOTE(review): these labels cover only 3 of the 4 rows and do not match the row
# order above (row 1 is the perturbed input, not a reconstruction).
#ax[0,0].set_ylabel('GT')
#ax[1,0].set_ylabel('Recon 1')
#ax[2,0].set_ylabel('Recon 2')