<a href="https://colab.research.google.com/github/Aviral09/JacobianReg/blob/master/Jacobian_Reg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Imports
import torch
from torch import nn, norm, randn, ones, zeros, autograd, addcdiv
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
from torchsummary import summary
import gc

In [None]:
# Ask PyTorch to release cached-but-unused GPU memory before loading data and
# building the model — presumably so re-running the notebook in a live kernel
# starts from a clean allocator state.
torch.cuda.empty_cache()

In [None]:
# Fetch both CIFAR-10 splits into data/CIFAR (the download is skipped when the
# files are already there) and convert every image with ToTensor().
cifar_train, cifar_test = (
    datasets.CIFAR10(
        root="data/CIFAR",
        train=split_is_train,
        download=True,
        transform=ToTensor(),
    )
    for split_is_train in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
batch_size = 64
# Shuffle the training set so each epoch sees the mini-batches in a new order
# (an unshuffled loader replays the exact same batch sequence every epoch).
# Evaluation order does not affect the metrics, so the test loader stays sequential.
cifar_train_dataloader = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
cifar_test_dataloader = DataLoader(cifar_test, batch_size=batch_size)

In [None]:
# Select the compute device once; later cells move tensors and the model here.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} for training")

Using cuda for training


In [None]:
class Jreg(nn.Module):
  """Jacobian penalty: 0.5 * batch-mean of the squared Frobenius norm of dy/dx.

  The squared norm is obtained from ``n`` vector-Jacobian products per call:
  - if ``n`` equals the number of output columns, the standard basis vectors
    are used and the value is exact;
  - if there is a single output column, the all-ones vector is used;
  - otherwise each projection is a random unit vector per sample, and the
    result is rescaled by the number of outputs (an unbiased estimate).

  Args:
    n: number of projections per forward call; must be >= 1.
  """

  def __init__(self, n):
    # nn.Module must be initialised before attributes are assigned on it.
    super(Jreg, self).__init__()
    if n < 1:
      raise ValueError(f"Jreg needs at least one projection, got n={n}")
    self.n = n

  def forward(self, x, y):
    """Return the (estimated) Jacobian penalty of ``y`` with respect to ``x``.

    Args:
      x: input tensor with ``requires_grad=True`` from which ``y`` was computed.
      y: 2-D output tensor of shape (batch, outputs).

    Returns:
      A scalar tensor. The vector-Jacobian products are taken with
      ``create_graph=True``, so the penalty can itself be back-propagated
      (e.g. when added to a training loss).
    """
    nproj = self.n
    batch, n_out = y.shape[0], y.shape[1]
    # Build projections directly on y's device/dtype rather than relying on a
    # notebook-global `device` and copying a CPU float32 tensor over.
    opts = dict(device=y.device, dtype=y.dtype)
    jreg = 0
    for i in range(nproj):
      if nproj == n_out:
        # Exact: the i-th standard basis vector for every sample.
        v = zeros(batch, n_out, **opts)
        v[:, i] = 1
      elif n_out == 1:
        v = ones(batch, **opts)
      else:
        # Random direction per sample, normalised to unit length.
        v = randn(batch, n_out, **opts)
        v = v / norm(v, 2, 1, True)
      j, = autograd.grad(y.reshape(-1), x, v.reshape(-1), retain_graph=True, create_graph=True)
      # For a random unit v, E[||v^T J||^2] = ||J||_F^2 / n_out, hence the n_out factor;
      # with the full basis the n_out factor and the 1/nproj average cancel exactly.
      jreg += n_out * (norm(j, dim=None) ** 2) / (nproj * batch)
    return 0.5 * jreg


In [None]:
class CIFARModel(nn.Module):
  """Small CNN classifier for 3x32x32 images producing 10 logits.

  Two stages of 3x3 conv -> ReLU -> 2x2 max-pool (channels 3 -> 6 -> 16,
  spatial 32 -> 15 -> 6), then fully connected layers 576 -> 120 -> 84 -> 10
  with ReLU after the first two.
  """

  def __init__(self):
    super().__init__()
    # Keep this registration order: it fixes the parameter / state_dict ordering.
    self.conv1 = nn.Conv2d(3, 6, 3)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 3)
    self.fc1 = nn.Linear(16 * 6 * 6, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # Feature extractor: each stage is conv -> ReLU -> the shared max-pool.
    for conv in (self.conv1, self.conv2):
      x = self.pool(F.relu(conv(x)))
    # Classifier head on features flattened to (batch, 16*6*6).
    x = torch.flatten(x, 1)
    for hidden in (self.fc1, self.fc2):
      x = F.relu(hidden(x))
    return self.fc3(x)

In [None]:
# Seed torch's global RNG before building the model so the weight initialisation
# (and later torch random draws in this kernel, such as Jreg's random projections)
# is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
cifarmodel = CIFARModel().to(device)
# torchsummary's `summary` defaults to device="cuda"; pass the selected device so
# this cell also works on the CPU fallback path.
summary(cifarmodel, (3, 32, 32), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 30, 30]             168
         MaxPool2d-2            [-1, 6, 15, 15]               0
            Conv2d-3           [-1, 16, 13, 13]             880
         MaxPool2d-4             [-1, 16, 6, 6]               0
            Linear-5                  [-1, 120]          69,240
            Linear-6                   [-1, 84]          10,164
            Linear-7                   [-1, 10]             850
Total params: 81,302
Trainable params: 81,302
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.08
Params size (MB): 0.31
Estimated Total Size (MB): 0.40
----------------------------------------------------------------


In [None]:
def train(dataloader, model, loss_fn, optimizer, reg, jlambda=None):
  """Run one training epoch with a Jacobian-regularised loss.

  For every batch the optimised loss is
  ``loss_fn(y_hat, y) + jlambda * reg(x, y_hat)`` with ``y_hat = model(x)``.
  Batches are moved to the notebook-level ``device``; progress is printed
  every 100 batches.

  Args:
    dataloader: yields ``(x, y)`` batches.
    model: module mapping ``x`` to predictions; switched to train mode.
    loss_fn: main loss, called as ``loss_fn(y_hat, y)``.
    optimizer: optimizer over ``model``'s parameters.
    reg: regulariser called as ``reg(x, y_hat)`` (e.g. ``Jreg``); must return
      a tensor, since ``.item()`` is called on it for logging.
    jlambda: weight of the regulariser. When omitted, the notebook-level
      ``jlambda`` is used, which is what this function always did before.
  """
  if jlambda is None:
    # Backward compatibility with calls that relied on the global weight.
    jlambda = globals()["jlambda"]
  num_batches = len(dataloader)
  model.train()
  for batch, (x, y) in enumerate(dataloader):
    x, y = x.to(device), y.to(device)
    # The Jacobian penalty differentiates w.r.t. the input. requires_grad_ flags the
    # freshly loaded batch in place instead of copy-constructing it via torch.tensor(),
    # which warns and makes an extra copy.
    x.requires_grad_(True)
    optimizer.zero_grad()

    # Calculating model predictions
    y_hat = model(x)
    mainloss = loss_fn(y_hat, y)
    jloss = reg(x, y_hat)
    loss = mainloss + jlambda * jloss

    # Backpropagation
    loss.backward()
    optimizer.step()

    if batch % 100 == 0:
      # len(dataloader) is the true batch count; int(size/batch_size)+1 overcounts
      # when the dataset size is a multiple of the batch size.
      print(f"Training loss Total loss: {loss.item()} Main loss: {mainloss.item()} Jacob Loss: {jloss.item()} [{batch}/{num_batches}]")

In [None]:
def test(dataloader, model, loss_fn, jlambda):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct, mainloss, jloss = 0, 0, 0, 0
    flag = False
    reg = 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        X.requires_grad = True
        pred = model(X)
        if flag == False:
          reg = Jreg(pred.shape[1])
          flag = True
        mainloss += loss_fn(pred, y).item()
        jloss += reg(X, pred)
        test_loss += mainloss + jloss*jlambda
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    mainloss /= num_batches
    jloss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}, Main loss: {mainloss}, Jacob loss: {jloss} \n")
    return test_loss, correct, mainloss, jloss

In [None]:
# Training configuration.
# jlambda: weight of the Jacobian penalty; train() also reads it as a
# notebook-level global when no weight is passed explicitly.
jlambda = 0.1
# nproj: projections per batch. With the model's 10 outputs and nproj=2, Jreg
# uses random unit directions, i.e. a stochastic estimate of the penalty.
nproj = 2

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cifarmodel.parameters(), lr=1e-4)
reg = Jreg(nproj)
epochs = 20

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(cifar_train_dataloader, cifarmodel, loss_fn, optimizer, reg)
    gc.collect()
print("Done!")

# Final evaluation; the test-time penalty uses the exact (full-basis) Jacobian.
loss, correct, mainloss, jloss = test(cifar_test_dataloader, cifarmodel, loss_fn, jlambda)

Epoch 1
-------------------------------
Training loss Total loss: 2.307790994644165 Main loss: 2.307739734649658 Jacob Loss: 0.0005135788815096021 [0/782]


  


Training loss Total loss: 2.2897164821624756 Main loss: 2.289475440979004 Jacob Loss: 0.002411404624581337 [100/782]
Training loss Total loss: 2.1716253757476807 Main loss: 2.1626124382019043 Jacob Loss: 0.09012880176305771 [200/782]
Training loss Total loss: 2.149374008178711 Main loss: 2.119419574737549 Jacob Loss: 0.2995452880859375 [300/782]
Training loss Total loss: 2.149991750717163 Main loss: 2.1078574657440186 Jacob Loss: 0.42134392261505127 [400/782]
Training loss Total loss: 2.101565361022949 Main loss: 2.044743299484253 Jacob Loss: 0.5682212710380554 [500/782]
Training loss Total loss: 2.149871587753296 Main loss: 2.0887625217437744 Jacob Loss: 0.611090898513794 [600/782]
Training loss Total loss: 2.0146830081939697 Main loss: 1.9444307088851929 Jacob Loss: 0.7025219202041626 [700/782]
Epoch 2
-------------------------------
Training loss Total loss: 2.131319761276245 Main loss: 2.070870876312256 Jacob Loss: 0.6044883728027344 [0/782]
Training loss Total loss: 2.053279399871