In [None]:
import os

import numpy as np                # import numpy
from tqdm import tqdm

#import torch packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Pick the first GPU when CUDA is available, otherwise fall back to the CPU,
# and report which one is in use.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('Running on Graphics' if device.type == 'cuda' else 'Running on Processor')

In [None]:
class bottleneck(nn.Module):
    """Squeeze-and-expand MLP: ``in_size -> bn_size -> in_size``.

    ``act`` is applied after both linear layers. The squeezed code computed by the
    most recent ``forward`` call is stored on the module as ``self.bn``.

    Note: the default ``act=nn.ReLU()`` is a single module instance shared by every
    bottleneck built with the default; that is harmless because ReLU has no state.
    """

    def __init__(self, in_size, bn_size, act=nn.ReLU()):
        super().__init__()
        self.L1 = nn.Linear(in_size, bn_size)   # squeeze
        self.L2 = nn.Linear(bn_size, in_size)   # expand back to the input width
        self.act = act

    def forward(self, x):
        code = self.act(self.L1(x))
        self.bn = code  # exposed so callers can inspect the code after a forward pass
        return self.act(self.L2(code))

class Encoder(nn.Module):
  """MLP encoder: flattens each sample and maps it to an ``in_size``-dim vector.

  Layout: Linear(input_dim, hidden_size) -> ReLU -> Linear(hidden_size, in_size) -> ReLU.

  Args:
    in_size: width of the encoder *output*. The name is kept for compatibility with
      ``DNA``, where this width is the input width of the bottlenecks.
    bn_size: unused here; accepted so the DNA sub-modules share one signature.
    hidden_size: width of the hidden layer (default 100).
    input_dim: number of values per flattened sample (default 28*28, i.e. MNIST).
  """
  def __init__(self, in_size, bn_size, hidden_size=100, input_dim=28*28):
    super().__init__()
    # Layer order/indices are unchanged, so existing state_dicts
    # ('enc.0.*', 'enc.2.*') still load.
    self.enc = nn.Sequential(
        nn.Linear(input_dim, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, in_size),
        nn.ReLU())

  def forward(self, x):
    # Keep the batch dimension and flatten everything else.
    return self.enc(torch.flatten(x, start_dim=1))

class Decoder(nn.Module):
  """MLP decoder mapping an ``in_size``-dim vector to an image with values in (0, 1).

  Layout: Linear(in_size, hidden_size) -> ReLU -> Linear(hidden_size, prod(out_shape))
  -> Sigmoid. The result is reshaped to ``(batch, *out_shape)``.

  Args:
    in_size: width of the incoming vector (the bottleneck output width in ``DNA``).
    bn_size: unused here; accepted so the DNA sub-modules share one signature.
    hidden_size: width of the hidden layer (default 100).
    out_shape: per-sample output shape (default ``(1, 28, 28)``, i.e. one MNIST image).
  """
  def __init__(self, in_size, bn_size, hidden_size=100, out_shape=(1, 28, 28)):
    super().__init__()
    self.out_shape = tuple(out_shape)
    # Layer order/indices are unchanged, so existing state_dicts
    # ('dec.0.*', 'dec.2.*') still load.
    self.dec = nn.Sequential(
        nn.Linear(in_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, int(np.prod(self.out_shape))),
        nn.Sigmoid())

  def forward(self, x):
    return self.dec(x).reshape(-1, *self.out_shape)

class DNA(nn.Module):
    """Shared encoder, two parallel bottlenecks, shared decoder.

    ``forward(x)`` encodes ``x`` once, passes the code through each bottleneck and
    decodes both. It returns the pair of reconstructions ``(r1, r2)``. Each
    bottleneck keeps its latest squeezed code in ``.bn``; ``total_loss`` reads
    ``dna.bn1.bn`` and ``dna.bn2.bn`` after a forward pass.
    """

    def __init__(self, in_size, bn_size):
        super().__init__()
        # Creation order (enc, dec, bn1, bn2) is kept so seeded initialisation matches.
        self.enc = Encoder(in_size, bn_size)
        self.dec = Decoder(in_size, bn_size)
        self.bn1 = bottleneck(in_size, bn_size)
        self.bn2 = bottleneck(in_size, bn_size)

    def forward(self, x):
        latent = self.enc(x)
        return tuple(self.dec(branch(latent)) for branch in (self.bn1, self.bn2))

In [None]:
class Classifier(nn.Module):
  """Small CNN producing 10 class logits from single-channel images.

  Three [Conv3x3 -> ReLU -> MaxPool2] stages (1 -> 8 -> 16 -> 32 channels), then
  flatten and a linear layer. The 288 input features equal 32 * 3 * 3, which is
  what a 28x28 input yields after three 2x poolings (28 -> 14 -> 7 -> 3).
  """
  def __init__(self):
    super().__init__()
    layers = []
    for c_in, c_out in ((1, 8), (8, 16), (16, 32)):
      layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
    layers += [nn.Flatten(start_dim=1), nn.Linear(288, 10)]
    # Same module order as a hand-written Sequential, so state_dict keys
    # (c.0, c.3, c.6, c.10) are unchanged.
    self.c = nn.Sequential(*layers)

  def forward(self, x):
    logits = self.c(x)
    self.out = logits  # the latest logits stay readable on the module
    return logits

In [None]:
def train(forward_fs, loss_f, optim, train_data, test_data, models, epochs, batch_size, auto=False, epsilon=0, loss_adv=None):
    """Train with each function in ``forward_fs`` and record the test loss every epoch.

    Args:
        forward_fs: list of callables ``f(x) -> prediction``. Each one gets its own
            forward pass, backward pass and optimizer step on every training batch.
        loss_f: ``loss_f(prediction, target)`` returning a scalar tensor.
        optim: optimizer over the parameters being trained. The name shadows the
            ``torch.optim`` module, but only inside this function.
        train_data, test_data: datasets yielding ``(x, y)`` pairs.
        models: modules whose gradients are cleared before each backward pass.
        epochs: number of passes over ``train_data``.
        batch_size: mini-batch size for both data loaders.
        auto: if True the target is the input ``x`` (autoencoder); otherwise ``y``.
        epsilon: if non-zero, train on inputs perturbed by ``gen_FGSM`` with this
            strength. ``gen_FGSM`` is not defined in the visible notebook and must
            exist elsewhere.
        loss_adv: loss forwarded to ``gen_FGSM``.

    Returns:
        1-D numpy array of length ``epochs``. Each entry is that epoch's test loss,
        summed over all test batches and all forward functions.
    """
    metrics = []
    for _ in tqdm(range(epochs)):
        for x, y in DataLoader(train_data, batch_size=batch_size, shuffle=True):
            x = x.to(device)
            y = y.to(device)
            for forward_f in forward_fs:
                # Perturb the clean batch separately for each forward function so
                # perturbations do not compound across forward functions.
                x_in = gen_FGSM(x, y, epsilon, loss_adv, models) if epsilon != 0 else x
                y_hat = forward_f(x_in)
                loss = loss_f(y_hat, x_in if auto else y)
                # Clear both the optimizer's params and every listed model's grads.
                optim.zero_grad()
                for model in models:
                    model.zero_grad()
                loss.backward()
                optim.step()

        # Evaluation builds no autograd graph. The loss accumulates over every test
        # batch of the epoch and is stored as a Python float, so CUDA tensors never
        # reach numpy.
        epoch_loss = 0.0
        with torch.no_grad():
            for x, y in DataLoader(test_data, batch_size=batch_size):
                x = x.to(device)
                y = y.to(device)
                for forward_f in forward_fs:
                    y_hat = forward_f(x)
                    epoch_loss += loss_f(y_hat, x if auto else y).item()
        metrics.append(epoch_loss)
    return np.asarray(metrics)

In [None]:
# MNIST train/test splits as tensors in [0, 1]; downloaded into MNIST_ROOT if missing.
MNIST_ROOT = '../../mnist_digits/'
to_tensor = torchvision.transforms.ToTensor()
train_data, test_data = (
    MNIST(MNIST_ROOT, train=is_train, download=True, transform=to_tensor)
    for is_train in (True, False)
)

In [None]:
# Redundancy term for the DNA loss: how linearly predictable Y is from X.
def get_R(X, Y, tol=0.0005, eta=0.001):
    """Log ratio of total to residual sum of squares for a least-squares fit of Y on X.

    Y is regressed on the non-degenerate columns of X plus an intercept column, using
    the pseudo-inverse. The function returns ``log(SStot + eta) - log(SSres + eta)``.
    The value grows as Y becomes more linearly predictable from X, and is ~0 when X
    explains nothing beyond Y's column means. As eta -> 0 it equals ``-log(1 - R^2)``.

    Args:
        X: 2-D tensor (samples, features) of regressors.
        Y: tensor with samples along dim 0, on the same device as X.
        tol: a column of X is dropped as (nearly) linearly dependent when its
            ``|R_ii|`` from a QR factorisation is at most ``tol`` times the largest
            ``|R_ii|``.
        eta: stabiliser added inside both logarithms.

    Returns:
        0-d tensor, differentiable w.r.t. X and Y.
    """
    # The QR only builds a boolean column mask, so it needs no autograd graph.
    # Compare magnitudes: QR diagonals can be negative, so a signed max is not a
    # valid scale.
    _, R = torch.linalg.qr(X.detach())
    diag = torch.diagonal(R).abs()
    X = X[:, diag > tol * diag.max()]

    # Intercept column sized from X itself (not a global batch size), so a smaller
    # final batch works too.
    ones = torch.ones(X.shape[0], 1, dtype=X.dtype, device=X.device)
    X = torch.cat([X, ones], dim=1)

    # Projection of Y onto span(X). pinv(X) @ Y first avoids forming an n x n matrix.
    Yhat = X @ (torch.linalg.pinv(X) @ Y)
    SSres = torch.sum(torch.square(Y - Yhat))
    SStot = torch.sum(torch.square(Y - torch.mean(Y, dim=0, keepdim=True)))
    return torch.log(SStot + eta) - torch.log(SSres + eta)

In [None]:
auto_loss = nn.MSELoss()
lambda1 = 0.05  # weight of the get_R redundancy term


def total_loss(xr, x):
    """DNA training objective.

    Args:
        xr: pair of reconstructions ``(r1, r2)`` as returned by ``dna(x)``.
        x: the input batch, which is also the reconstruction target.

    Returns:
        ``sqrt(MSE1**2 + MSE2**2) + lambda1 * get_R(code1, code2)``. The codes are the
        ``.bn`` tensors cached on the global ``dna``'s bottlenecks by the forward pass
        that produced ``xr``. The three terms are also printed; the value labelled
        'R2' is the ``get_R`` output.
    """
    mse_first = auto_loss(x, xr[0])
    mse_second = auto_loss(x, xr[1])
    redundancy = get_R(dna.bn1.bn, dna.bn2.bn)
    combined = torch.sqrt(mse_first ** 2 + mse_second ** 2) + lambda1 * redundancy
    print('MSE1:{} MSE2:{} R2:{}'.format(mse_first, mse_second, redundancy))
    return combined

In [None]:
# in_size: encoder output width (and bottleneck input/output width).
# bn_size: width of each bottleneck's squeezed code.
in_size = 128
bn_size = 64
SEED = 0

# Seed torch's global RNG before building the model. Weight initialisation and the
# later DataLoader shuffling then reproduce under Restart & Run All.
torch.manual_seed(SEED)
dna = DNA(in_size, bn_size).to(device)

In [None]:
# Train DNA
batch_size = 500
learning_rate = 5.0e-4
epochs = 40

optimizer = optim.Adam(dna.parameters(), lr = learning_rate)
forward_both = dna  # the model is itself the forward function; returns (r1, r2)
models = [dna]

# Per-epoch DNA test loss. It gets its own name because the classifier cell
# reassigns `metric`.
dna_metrics = train([forward_both], total_loss, optimizer, train_data, test_data, models, epochs, batch_size, auto=True)
metric = dna_metrics  # kept for any code that reads `metric`

In [None]:
# Train Classifier on the DNA's reconstructions; the DNA itself stays frozen.
cls = Classifier().to(device)
optimizer = optim.Adam(cls.parameters(), lr = learning_rate)
loss_ce = nn.CrossEntropyLoss()
models = [cls]


def classify_reconstruction(x, branch):
    """Classify the DNA reconstruction from bottleneck ``branch`` (0 or 1).

    Only ``cls`` is optimised here, so the DNA runs under ``torch.no_grad()``.
    This skips a wasted backward pass through it and avoids accumulating stale
    gradients on its parameters, which ``train`` never zeroes since ``models``
    holds only ``cls``. Note: an FGSM attack (epsilon != 0) could then not
    differentiate through the DNA.
    """
    with torch.no_grad():
        x_rec = dna(x)[branch]
    return cls(x_rec)


forward_f1 = lambda x: classify_reconstruction(x, 0)
forward_f2 = lambda x: classify_reconstruction(x, 1)

cls_metrics = train([forward_f1, forward_f2], loss_ce, optimizer, train_data, test_data, models, epochs, batch_size, auto=False)
metric = cls_metrics  # kept for any code that reads `metric`

In [None]:
# Persist the trained weights. torch.save does not create missing directories,
# so make sure 'models/' exists first.
os.makedirs('models', exist_ok=True)
torch.save(dna.state_dict(), 'models/dna')
torch.save(cls.state_dict(), 'models/cls_dna')