In [1]:
#HERE TO TEST IDEA FROM LINE CONVERSATION
import numpy as np                # import numpy
import matplotlib.pyplot as plt   # import matplotlib, a python 2d plotting library
from tqdm import tqdm
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

#import torch packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Prefer the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on Graphics' if device.type == 'cuda' else 'Running on Processor')

Running on Graphics


In [2]:
class Classifier(nn.Module):
    """Small CNN: three 3x3 convolutions (1 -> 8 -> 16 -> 32 channels) and a linear head.

    After every forward pass the activations that follow the second
    conv/ReLU/max-pool stage are kept on the instance as ``self.feat`` so that
    callers can inspect the intermediate representation.
    """

    def __init__(self):
        super().__init__()
        # Same-padding convolutions: spatial size only changes in the pooling layers.
        self.c1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.c2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.c3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        # Linear head mapping the 32 pooled channels to 10 outputs.
        self.l = nn.Linear(in_features=32, out_features=10)
        self.pool = nn.MaxPool2d(kernel_size=2)
        # A 7x7 average pool; with 28x28 inputs the map is 7x7 here, so this yields 1x1.
        self.avgpool = nn.AvgPool2d(kernel_size=7)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the 10 output scores for ``x`` and refresh ``self.feat``."""
        stage1 = self.pool(self.act(self.c1(x)))
        # Cached on the instance (not returned) for later inspection.
        self.feat = self.pool(self.act(self.c2(stage1)))
        stage3 = self.act(self.c3(self.feat))
        pooled = self.avgpool(stage3).flatten(start_dim=1)
        return self.l(pooled)

In [3]:
# MNIST train/test splits, downloaded on first use and converted to tensors by ToTensor.
_mnist_kwargs = dict(download=True, transform=torchvision.transforms.ToTensor())
train_data = MNIST('../mnist_digits/', train=True, **_mnist_kwargs)
test_data = MNIST('../mnist_digits/', train=False, **_mnist_kwargs)

In [35]:
def train(optim, train_data, test_data, models, epochs, batch_size, target):
    """Jointly train the classifiers and record their test accuracy after each epoch.

    Args:
        optim: Optimizer whose ``step()`` is called once per training batch.
            (Inside this function the name shadows the ``torch.optim`` module alias.)
        train_data: Dataset of ``(x, y)`` pairs, iterated in shuffled mini-batches.
        test_data: Dataset of ``(x, y)`` pairs used for the per-epoch accuracy.
        models: Sequence of models. ``models[0]`` and ``models[1]`` produce the
            predictions handed to ``get_loss``; every entry is evaluated.
        epochs: Number of passes over ``train_data``.
        batch_size: Mini-batch size for both the training and the test loader.
        target: Forwarded unchanged to ``get_loss``.

    Returns:
        list: One CPU float32 tensor per epoch with the test accuracy of each
        model, in the order of ``models``.
    """
    metrics = []
    n = len(test_data)
    # Build the loaders once; a shuffling DataLoader still draws a fresh
    # permutation every time it is iterated.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size)
    for _ in tqdm(range(epochs)):
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            yhat1 = models[0](x)
            yhat2 = models[1](x)
            for model in models:
                model.zero_grad()
            loss = get_loss(y, yhat1, yhat2, models, target)
            loss.backward()
            optim.step()

        # One accuracy slot per model (was hard-coded to two).
        t_acc = torch.zeros(len(models), dtype=torch.float32)
        # Evaluation needs no gradients; skipping graph construction saves memory.
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                for j, model in enumerate(models):
                    y_hat = model(x)
                    t_acc[j] += (torch.argmax(y_hat, dim=1) == y).sum().item()
        t_acc = t_acc / n
        metrics.append(t_acc)
    return metrics

In [36]:
lambda1 = 1e-5  # weight of the parameter-distance penalty
loss_ce = nn.CrossEntropyLoss()
def get_loss(y, yhat1, yhat2, models, target):
    """Joint loss for two classifiers trained side by side.

    The result is the sum of both cross-entropy terms plus
    ``lambda1 * (D - target)**2``, where ``D`` is the summed squared difference
    between corresponding parameters of ``models[0]`` and ``models[1]``.

    Args:
        y: Ground-truth class indices.
        yhat1: Predictions of ``models[0]`` for the batch.
        yhat2: Predictions of ``models[1]`` for the batch.
        models: Sequence whose first two entries are the models being compared.
        target: Desired value of the squared parameter distance ``D``.

    Returns:
        Scalar tensor holding the combined loss.
    """
    ce_total = loss_ce(yhat1, y) + loss_ce(yhat2, y)

    params_first = list(models[0].parameters())
    params_second = list(models[1].parameters())
    sq_dist = torch.tensor(0)
    for idx, p_first in enumerate(params_first):
        sq_dist = sq_dist + (p_first - params_second[idx]).square().sum()

    return ce_total + lambda1 * torch.square(sq_dist - target)

In [37]:
def get2classfiers(epochs, target):
    """Train two fresh classifiers in parallel towards a target parametric distance.

    Both models are optimised by one Adam optimizer (learning rate 5e-3) over
    the concatenation of their parameters. ``train_data``, ``test_data``,
    ``batch_size`` and ``device`` are read from the notebook's globals.

    Args:
        epochs: Number of training epochs passed to ``train``.
        target: Target parameter distance passed on to ``train``.

    Returns:
        tuple: ``(cls1, cls2)``, the two trained ``Classifier`` instances.
    """
    cls1 = Classifier().to(device)
    cls2 = Classifier().to(device)
    optimizer = optim.Adam([*cls1.parameters(), *cls2.parameters()], lr=5.0e-3)
    models = [cls1, cls2]
    # The per-epoch accuracies returned by train() are not needed here.
    # (Removed an unused local criterion that shadowed the module-level loss_ce.)
    train(optimizer, train_data, test_data, models, epochs, batch_size, target)
    return cls1, cls2

In [38]:
def gen_FGSM(x, y, eta, model):
    """Generate FGSM (fast gradient sign method) adversarial samples.

    Each input is moved by ``eta`` along the sign of the gradient of the mean
    cross-entropy loss with respect to the input, then clamped to ``[0, 1]``.

    Unlike a ``loss.backward()`` based implementation, this version has no side
    effects on its arguments: ``x`` is not switched to ``requires_grad``, so
    non-leaf inputs and repeated calls on the same tensor are safe, and the
    ``.grad`` buffers of ``model``'s parameters are left untouched.

    Args:
        x: Batch of inputs with values expected in ``[0, 1]``.
        y: Ground-truth class indices for ``x``.
        eta: Perturbation step size.
        model: Classifier under attack; called once on a copy of ``x``.

    Returns:
        Tensor shaped like ``x``, detached from the autograd graph.
    """
    # Private leaf copy so the gradient can be taken without mutating the caller's tensor.
    x_adv = x.detach().clone().requires_grad_(True)
    # F.cross_entropy with default arguments is the same mean cross-entropy
    # that nn.CrossEntropyLoss() computes.
    loss = F.cross_entropy(model(x_adv), y)
    # autograd.grad returns d(loss)/d(x_adv) only; nothing is accumulated into .grad.
    (grad,) = torch.autograd.grad(loss, x_adv)
    perturbed_x = torch.clamp(x_adv.detach() + eta * grad.sign(), min=0, max=1.0)
    return perturbed_x

In [49]:
#Find mutual residual (proxy for mutual information)
def get_R(X,Y):
    """Coefficient of determination (R^2) of a linear fit of ``Y`` from ``X``.

    Both inputs are flattened to ``(rows, features)``. ``Y`` is regressed on the
    columns of ``X`` plus an intercept through the pseudo-inverse, and
    ``1 - SS_res / SS_tot`` is returned.

    Assumes there are at least as many rows as flattened ``X`` features, so the
    diagonal of the QR factor has one entry per column of ``X``.

    Args:
        X: Predictor tensor; the first dimension indexes samples.
        Y: Response tensor with the same number of samples as ``X``.

    Returns:
        Scalar tensor. No stabilising constant is added, so a constant ``Y``
        (``SS_tot == 0``) gives a non-finite result.
    """
    X = torch.flatten(X, start_dim=1)
    Y = torch.flatten(Y, start_dim=1)

    # Drop columns that are (nearly) linear combinations of earlier ones: their
    # diagonal entry in the QR factor R is tiny compared with the largest one.
    # Magnitudes are compared because the signs of R's diagonal are arbitrary.
    _, r_factor = torch.linalg.qr(X)
    pivots = torch.diag(r_factor).abs()
    keep = pivots / torch.max(pivots) > 0.0005
    X = X[:, keep]

    # Intercept column sized, typed and placed like X itself rather than from
    # the notebook-level batch_size/device globals.
    X = torch.cat([X, X.new_ones((X.shape[0], 1))], dim=1)
    # X @ pinv(X) projects onto the column space of X, giving the fitted values.
    Yhat = torch.matmul(torch.matmul(X, torch.linalg.pinv(X)), Y)
    Ehat = Y - Yhat
    SSres = torch.sum(torch.square(Ehat))
    Ybar = torch.mean(Y, dim=0).unsqueeze(0)
    SStot = torch.sum(torch.square(Y - Ybar))
    return 1 - SSres / SStot

In [50]:
def get_transfer(targets, etas):
    """Measure adversarial transfer and feature R^2 for each (target, eta) pair.

    For every ``target`` a fresh pair of classifiers is trained with
    ``get2classfiers``. For every ``eta`` the test set is attacked with FGSM
    examples crafted against the first classifier, and two numbers are stored:

    * transfer rate: of the adversarial samples the first classifier gets
      wrong, the fraction the second classifier also gets wrong. Counts are
      pooled over all batches before dividing, so a batch without any failure
      cannot inject a 0/0, and unequal batch sizes are weighted correctly.
      The entry is NaN when the first classifier makes no mistake at all.
    * R2: the mean over batches of ``get_R(cls1.feat, cls2.feat)``.

    ``epochs``, ``batch_size``, ``test_data`` and ``device`` are read from the
    notebook's globals.

    Args:
        targets: Parameter-distance targets, one classifier pair per entry.
        etas: FGSM step sizes evaluated for every classifier pair.

    Returns:
        tuple: ``(trans_rate, R2)``, two NumPy arrays of shape
        ``(len(targets), len(etas))``.
    """
    trans_rate = np.zeros([len(targets), len(etas)])
    R2 = np.zeros([len(targets), len(etas)])
    for tidx, target in enumerate(targets):
        cls1, cls2 = get2classfiers(epochs, target)
        for eidx, eta in enumerate(etas):
            fooled_first = 0
            fooled_both = 0
            r2_sum = 0.0
            n_batches = 0
            for x, y in DataLoader(test_data, batch_size, shuffle=True):
                x = x.to(device)
                y = y.to(device)
                x_adv = gen_FGSM(x, y, eta, cls1)
                # Pure evaluation from here on: no autograd graph is needed.
                with torch.no_grad():
                    wrong1 = torch.argmax(cls1(x_adv), dim=1) != y
                    wrong2 = torch.argmax(cls2(x_adv), dim=1) != y
                    # .feat was refreshed by the two forward passes just above.
                    r2_sum += get_R(cls1.feat, cls2.feat).item()
                fooled_first += int(wrong1.sum())
                fooled_both += int((wrong1 & wrong2).sum())
                n_batches += 1
            # The transfer rate is undefined when nothing fooled the first classifier.
            trans_rate[tidx, eidx] = fooled_both / fooled_first if fooled_first else np.nan
            if n_batches:
                R2[tidx, eidx] = r2_sum / n_batches

    return trans_rate, R2
                

In [51]:
# Sweep configuration. `epochs` and `batch_size` are picked up as globals by
# get_transfer and get2classfiers, so they must be set before the call below.
batch_size = 1000
epochs = 10
etas = [0.0, 0.05, 0.1, 0.15]            # FGSM step sizes
targets = list(range(1000, 3001, 500))   # 1000, 1500, ..., 3000
trans_rate, R2 = get_transfer(targets, etas)

100%|██████████| 10/10 [00:43<00:00,  4.34s/it]


[tensor([0.4844, 0.4184]), tensor([0.7726, 0.7767]), tensor([0.8581, 0.8782]), tensor([0.8923, 0.9061]), tensor([0.9079, 0.9194]), tensor([0.9172, 0.9281]), tensor([0.9250, 0.9400]), tensor([0.9334, 0.9382]), tensor([0.9282, 0.9457]), tensor([0.9441, 0.9538])]


100%|██████████| 10/10 [00:39<00:00,  3.90s/it]


[tensor([0.3605, 0.3857]), tensor([0.7788, 0.7471]), tensor([0.8579, 0.8583]), tensor([0.9016, 0.8798]), tensor([0.9141, 0.9005]), tensor([0.9201, 0.9148]), tensor([0.9330, 0.9162]), tensor([0.9340, 0.9247]), tensor([0.9443, 0.9341]), tensor([0.9502, 0.9330])]


100%|██████████| 10/10 [00:37<00:00,  3.71s/it]


[tensor([0.2788, 0.2874]), tensor([0.6293, 0.7081]), tensor([0.8308, 0.8521]), tensor([0.8813, 0.8840]), tensor([0.9133, 0.9093]), tensor([0.9301, 0.9216]), tensor([0.9353, 0.9294]), tensor([0.9402, 0.9358]), tensor([0.9487, 0.9337]), tensor([0.9540, 0.9437])]


100%|██████████| 10/10 [00:36<00:00,  3.66s/it]


[tensor([0.2577, 0.2896]), tensor([0.6796, 0.5628]), tensor([0.7928, 0.7589]), tensor([0.8707, 0.8456]), tensor([0.8782, 0.8787]), tensor([0.8962, 0.8914]), tensor([0.9128, 0.9078]), tensor([0.9260, 0.9277]), tensor([0.9241, 0.9345]), tensor([0.9399, 0.9381])]


100%|██████████| 10/10 [00:36<00:00,  3.69s/it]


[tensor([0.2231, 0.2193]), tensor([0.5888, 0.5722]), tensor([0.8040, 0.7796]), tensor([0.8465, 0.8492]), tensor([0.8863, 0.8667]), tensor([0.9015, 0.8951]), tensor([0.9196, 0.9079]), tensor([0.9208, 0.9083]), tensor([0.9340, 0.9185]), tensor([0.9403, 0.9215])]


In [54]:
# Persist the sweep results. The arrays are passed positionally, so np.savez
# stores them under its default keys arr_0 ... arr_3 in this order.
saved_arrays = (trans_rate, R2, np.array(targets), np.array(etas))
np.savez('cnn_transfer.npz', *saved_arrays)