This notebook trains a VGG16-style network on the CIFAR-10 dataset, comparing traditional backpropagation with a simulated forward-mode autodiff approach (forward-gradient estimates built from random directions).

In [1]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, TensorDataset

import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import os

# Restrict CUDA to the first GPU. CUDA_VISIBLE_DEVICES is only honoured if it is set
# before the CUDA runtime is initialised, so set it before any CUDA query below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Fall back to the CPU so the notebook still runs on machines without a GPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
def switch_to_device(dataset, device=None):
    """Materialise an iterable of (x, y) pairs into one in-memory TensorDataset.

    All x are stacked into a single tensor and all y collected into a single
    tensor, optionally moved to `device` so batches drawn later already live there.

    Args:
        dataset: iterable yielding (tensor, label) pairs, e.g. a torchvision dataset.
        device: target device, or None to leave the tensors where they were created.

    Returns:
        TensorDataset(X, Y) with X = stack of every x and Y = tensor of every label.
    """
    pairs = [(sample, target) for sample, target in dataset]
    images = torch.stack([sample for sample, _ in pairs])
    labels = torch.tensor([target for _, target in pairs])
    if device is None:
        return TensorDataset(images, labels)
    return TensorDataset(images.to(device), labels.to(device))

In [3]:
def get_Cifar10_dl(batch_size_train=256, batch_size_eval=1024, device=DEVICE,
                   root='./datasets', valid_size=5000, seed=None):
    """Build CIFAR-10 train / validation / test DataLoaders with all tensors pre-moved to `device`.

    The official training set is split into (len(train) - valid_size) training and
    `valid_size` validation samples; the official test set is the test loader. Images
    are only converted with `transforms.ToTensor()` (no normalisation, no augmentation).

    Args:
        batch_size_train: batch size of the shuffled training loader.
        batch_size_eval: batch size of the unshuffled validation and test loaders.
        device: device the whole dataset is copied to; None leaves it where ToTensor put it.
        root: directory CIFAR-10 is downloaded to / loaded from.
        valid_size: number of training images held out for validation.
        seed: if given, makes the train/valid split reproducible via a dedicated
            generator; None draws the split from torch's global RNG.

    Returns:
        (train_dl, valid_dl, test_dl)

    Raises:
        ValueError: if `valid_size` is outside [0, len(training set)].
    """
    transform = transforms.Compose([transforms.ToTensor()])

    data_train = CIFAR10(root, train=True, download=True, transform=transform)
    data_train = switch_to_device(data_train, device=device)
    num_train = len(data_train)
    if not 0 <= valid_size <= num_train:
        raise ValueError(f"valid_size must be in [0, {num_train}], got {valid_size}")
    # A seeded private generator makes the split reproducible without disturbing the global RNG.
    split_generator = (torch.Generator().manual_seed(seed) if seed is not None
                       else torch.default_generator)
    data_train, data_valid = torch.utils.data.random_split(
        data_train, [num_train - valid_size, valid_size], generator=split_generator)

    data_test = CIFAR10(root, train=False, download=True, transform=transform)
    data_test = switch_to_device(data_test, device=device)

    train_dl = DataLoader(data_train, batch_size=batch_size_train, shuffle=True)
    valid_dl = DataLoader(data_valid, batch_size=batch_size_eval, shuffle=False)
    test_dl = DataLoader(data_test, batch_size=batch_size_eval, shuffle=False)

    return train_dl, valid_dl, test_dl

In [4]:
class VGG16(nn.Module):
    """VGG-style CNN for small square RGB images (CIFAR-10 by default).

    Despite the name this is a truncated VGG16: only the first six 3x3 conv layers
    (64-64-pool-128-128-pool-256-256) are used, followed by three fully connected
    layers (flattened -> 1024 -> 1024 -> num_classes). Every conv layer and every
    hidden FC layer is followed by batch norm and ReLU.

    The attribute names (`layer1`..`layer6`, `fc`, `fc1`, `fc2`) and the index of each
    sub-module inside its `nn.Sequential` determine the state_dict keys and the
    parameter names (e.g. `layer1.0.weight`), which code elsewhere in this notebook
    maps back to modules, so keep them stable.

    Args:
        num_classes: number of output logits.
        input_size: side length of the square input. The two 2x2 max-pools shrink it
            by a factor of 4, which fixes the width of the first FC layer; the
            default 32 matches CIFAR-10 (8*8*256 flattened features).
    """

    def __init__(self, num_classes=10, input_size=32):
        super().__init__()
        self.layer1 = self._conv_bn_relu(3, 64)
        self.layer2 = self._conv_bn_relu(64, 64, pool=True)
        self.layer3 = self._conv_bn_relu(64, 128)
        self.layer4 = self._conv_bn_relu(128, 128, pool=True)
        self.layer5 = self._conv_bn_relu(128, 256)
        self.layer6 = self._conv_bn_relu(256, 256)
        # Two stride-2 pools: (input_size // 2) // 2 == input_size // 4.
        flat_features = 256 * (input_size // 4) ** 2
        self.fc = self._linear_bn_relu(flat_features, 1024)
        self.fc1 = self._linear_bn_relu(1024, 1024)
        self.fc2 = nn.Sequential(nn.Linear(1024, num_classes))

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, pool=False):
        """3x3 same-padding conv -> BatchNorm2d -> ReLU, optionally followed by a 2x2 max-pool."""
        modules = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        if pool:
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*modules)

    @staticmethod
    def _linear_bn_relu(in_features, out_features):
        """Linear -> BatchNorm1d -> ReLU."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
        )

    def forward(self, x):
        """Map images of shape (batch, 3, input_size, input_size) to (batch, num_classes) logits."""
        out = x
        for conv_block in (self.layer1, self.layer2, self.layer3,
                           self.layer4, self.layer5, self.layer6):
            out = conv_block(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)
        return self.fc2(out)

In [5]:
def print_stats(stats, save_path='testing.jpg'):
    """Plot the training-loss and validation-accuracy curves recorded during training.

    Despite its name this draws (rather than prints) two side-by-side panels and,
    unless `save_path` is None, writes the figure to `save_path`.

    Args:
        stats: dict with 'train-loss' and 'valid-acc' lists of (iteration, value) pairs.
        save_path: file the figure is saved to; None skips saving.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(7, 3), dpi=110)
    for ax, title in ((ax_loss, "ERM loss"), (ax_acc, "Valid Acc")):
        ax.grid()
        ax.set_title(title)
        ax.set_xlabel("iterations")
    ax_loss.set_ylabel("loss")
    ax_acc.set_ylabel("accuracy")

    loss_points = stats['train-loss']
    acc_points = stats['valid-acc']
    losses = [value for _, value in loss_points]
    ax_loss.plot([itr for itr, _ in loss_points], losses)
    ax_acc.plot([itr for itr, _ in acc_points], [value for _, value in acc_points])

    # max() of an empty list raises, and identical y-limits (all-zero loss) make
    # matplotlib warn, so only pin the loss axis when there is a positive maximum.
    if losses and max(losses) > 0:
        ax_loss.set_ylim(0.0, max(losses))
    ax_acc.set_ylim(0.0, 1.05)

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')

In [6]:
def count_parameters(model):
    """Return the total number of scalar elements in `model`'s trainable parameters.

    Parameters with `requires_grad=False` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [7]:
@torch.no_grad()
def get_acc(model, dl):
    """Return the top-1 accuracy of `model` over every batch of `dl` as a Python float.

    The model is evaluated in eval mode and is then returned to whichever mode
    (train or eval) it was in before the call, even if evaluation raises.

    Args:
        model: classifier whose output is argmax-ed over dim 1 to get the predicted class.
        dl: iterable of (X, y) batches with integer class labels y.

    Raises:
        ValueError: if `dl` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        for X, y in dl:
            preds = torch.argmax(model(X), dim=1)
            correct += (preds == y).sum().item()
            total += y.numel()
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)
    if total == 0:
        raise ValueError("get_acc: data loader yielded no samples")
    return correct / total

In [8]:
def Normalize(input, dim=1, eps=0.0):
    """Standardise `input` to zero mean and unit (unbiased) standard deviation along `dim`.

    With the defaults and a 2-D tensor this standardises each row. Statistics are
    kept with `keepdim=True`, so they broadcast correctly for inputs of any rank.

    Args:
        input: tensor to standardise.
        dim: dimension the mean and std are computed over.
        eps: added to the std before dividing; 0.0 (default) divides by the raw std,
            which yields nan/inf for constant slices, while a small positive value avoids that.

    Returns:
        Tensor of the same shape as `input`.
    """
    mean = torch.mean(input, dim=dim, keepdim=True)
    std = torch.std(input, dim=dim, keepdim=True)
    return (input - mean) / (std + eps)

In [9]:
def cal_proj_matrix(A):
    """Return the orthogonal projector onto the column space of `A`.

    For `A` of shape (D, B) the result is the (D, D) matrix P = A A^+, where A^+ is the
    Moore-Penrose pseudo-inverse. When A has full column rank this equals
    A (A^T A)^{-1} A^T. It stays well-defined and numerically stable when the columns
    are linearly dependent, which is exactly when the explicit Gram-matrix inverse
    fails or blows up.

    Note: P is dense with D*D entries, so this is memory-heavy for large D.
    """
    return A @ torch.linalg.pinv(A)

In [10]:
def _project_rows_onto_span(basis_rows, samples):
    """Orthogonally project every row of `samples` (num_dir, D) onto the row span of `basis_rows` (B, D).

    Computes A (A^+ s) with A = basis_rows.T. This is the same projector as
    A (A^T A)^{-1} A^T when A has full rank, but the (D, D) matrix is never
    materialised (D can be C*H*W of a conv input) and no possibly-singular Gram
    matrix is inverted. Returns a (num_dir, D) tensor.
    """
    basis = basis_rows.T                            # (D, B)
    coeffs = torch.linalg.pinv(basis) @ samples.T   # (B, num_dir)
    return (basis @ coeffs).T                       # (num_dir, D)


def _linear_directions(inputs, weight, num_dir):
    """Random directions for a Linear weight: outer(Gaussian output vector, projected input vector).

    A Linear weight gradient is a sum of outer products grad_out_b x_b^T, so its rows lie
    in the span of the batch inputs x_b; the input factor is drawn from that span.
    Returns a (num_dir, out_features * in_features) tensor in `weight.view(-1)` order.
    """
    out_features, in_features = weight.shape
    in_dirs = _project_rows_onto_span(
        inputs.reshape(-1, in_features),
        torch.randn(num_dir, in_features, device=weight.device, dtype=weight.dtype),
    )
    out_dirs = torch.randn(num_dir, out_features, device=weight.device, dtype=weight.dtype)
    return (out_dirs.unsqueeze(2) * in_dirs.unsqueeze(1)).reshape(num_dir, -1)


def _conv_directions(inputs, conv, num_dir):
    """Random directions for a Conv2d weight with the structure of a single-sample weight gradient.

    For one sample with input x (C, H, W) and output gradient g (F, H_out, W_out), the
    weight gradient is dW[f, c] = cross-correlation of zero-padded x[c] with g[f].
    Here x is Gaussian noise projected onto the span of the flattened batch inputs and
    g is Gaussian noise. Only stride 1, dilation 1, groups 1 and integer padding are
    supported. Returns a (num_dir, F * C * kH * kW) tensor in `conv.weight.view(-1)` order.
    """
    if conv.stride != (1, 1) or conv.dilation != (1, 1) or conv.groups != 1:
        raise NotImplementedError("projected conv directions need stride=1, dilation=1, groups=1")
    weight = conv.weight
    batch, channels, height, width = inputs.shape
    pad_h, pad_w = conv.padding
    out_h = height + 2 * pad_h - weight.shape[2] + 1
    out_w = width + 2 * pad_w - weight.shape[3] + 1
    in_dirs = _project_rows_onto_span(
        inputs.reshape(batch, -1),
        torch.randn(num_dir, channels * height * width, device=weight.device, dtype=weight.dtype),
    ).reshape(num_dir, channels, height, width)
    out_dirs = torch.randn(num_dir, weight.shape[0], out_h, out_w,
                           device=weight.device, dtype=weight.dtype)
    per_dir = []
    for x_s, g_s in zip(in_dirs, out_dirs):
        # Input channels act as the batch and output-gradient maps as the filters, so
        # corr[c, f] is the (kH, kW) correlation dW[f, c] (one per input channel, not
        # summed over channels); transpose to the weight layout (F, C, kH, kW).
        corr = torch.nn.functional.conv2d(x_s.unsqueeze(1), g_s.unsqueeze(1),
                                          padding=(pad_h, pad_w))
        per_dir.append(corr.transpose(0, 1).reshape(-1))
    return torch.stack(per_dir)


@torch.no_grad()
def _replace_with_forward_gradient(model, modules_by_name, layer_inputs, num_dir,
                                   use_linear_projection, use_cnn_projection):
    """Overwrite each parameter's `.grad` (the backprop gradient g) with a forward-gradient estimate.

    For every parameter, `num_dir` directions v_k are drawn. The directional derivative
    d_k = sum over parameters of v_k . g is what forward-mode AD would return along v_k
    (simulated here from g), and each parameter's new gradient is mean_k d_k v_k.
    """
    named_params = list(model.named_parameters())
    directions = []
    dir_derivs = 0
    for name, param in named_params:
        flat_grad = param.grad.reshape(-1)
        if param.dim() == 2 and use_linear_projection:
            owner = modules_by_name[name.rpartition('.')[0]]
            v = _linear_directions(layer_inputs[owner], param, num_dir)
        elif param.dim() == 4 and use_cnn_projection:
            owner = modules_by_name[name.rpartition('.')[0]]
            v = _conv_directions(layer_inputs[owner], owner, num_dir)
        else:
            v = torch.randn(num_dir, flat_grad.numel(), device=flat_grad.device, dtype=flat_grad.dtype)
        directions.append(v)
        dir_derivs = dir_derivs + v @ flat_grad
    for (_, param), v in zip(named_params, directions):
        param.grad = (dir_derivs.unsqueeze(1) * v).mean(dim=0).view_as(param.grad)


def run_experiment(model, opt, schedular, criterion, train_dl, valid_dl, test_dl, max_epochs, use_forward_grad, num_dir, use_linear_projection, use_cnn_projection, max_scheduler_itr=10000, eval_every=10):
    """Train `model` and optionally replace backprop gradients with simulated forward gradients.

    After `loss.backward()`, if `use_forward_grad` is set, every parameter gradient is
    replaced by a forward-gradient estimate built from `num_dir` random directions.
    With `use_linear_projection` / `use_cnn_projection`, directions for 2-D (Linear)
    and 4-D (Conv2d) weights are built from inputs projected onto the span of that
    layer's current batch inputs; all other parameters use isotropic Gaussian directions.

    Args:
        model, opt, schedular, criterion: the network, optimiser, LR scheduler and loss.
        train_dl, valid_dl, test_dl: data loaders yielding (X, y) batches.
        max_epochs: number of passes over `train_dl`.
        use_forward_grad: replace gradients with forward-gradient estimates.
        num_dir: number of random directions per step.
        use_linear_projection, use_cnn_projection: enable projected directions.
        max_scheduler_itr: `schedular.step()` is only called while the iteration index is <= this.
        eval_every: validation accuracy is computed and printed every `eval_every` iterations.

    Returns:
        dict with 'train-loss': [(itr, loss)] for every iteration and
        'valid-acc': [(itr, acc)] every `eval_every` iterations.
    """
    itr = -1
    stats = {'train-loss': [], 'valid-acc': []}

    # Latest input of each module, keyed by module, captured by forward-pre-hooks.
    layer_inputs = {}
    modules_by_name = dict(model.named_modules())

    def record_input(module, inputs):
        # Only read later under no_grad, so detach instead of keeping the autograd graph alive.
        if inputs:
            layer_inputs[module] = inputs[0].detach()

    hook_handles = []
    try:
        if use_forward_grad and (use_linear_projection or use_cnn_projection):
            for module in model.modules():
                hook_handles.append(module.register_forward_pre_hook(record_input))

        for epoch in range(max_epochs):
            for x, y in train_dl:
                itr += 1
                opt.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()

                if use_forward_grad:
                    _replace_with_forward_gradient(model, modules_by_name, layer_inputs, num_dir,
                                                   use_linear_projection, use_cnn_projection)

                opt.step()
                if itr <= max_scheduler_itr:
                    schedular.step()
                stats['train-loss'].append((itr, loss.item()))

                if itr % eval_every == 0:
                    valid_acc = get_acc(model, valid_dl)
                    stats['valid-acc'].append((itr, valid_acc))
                    s = f"{epoch}:{itr} [train] loss:{loss.item():.3f}, [valid] acc:{valid_acc:.3f}"
                    print(s)

        test_acc = get_acc(model, test_dl)
        print(f"[test] acc:{test_acc:.3f}")
    finally:
        # Remove the hooks even on error/KeyboardInterrupt so re-runs don't stack duplicates.
        for handle in hook_handles:
            handle.remove()

    return stats
        

In [11]:
SEED = 0
# Seed torch's CPU and CUDA RNGs so weight init, the train/valid split, shuffling and
# the random forward-gradient directions are reproducible across runs.
torch.manual_seed(SEED)

model = VGG16().to(DEVICE)
print(count_parameters(model))

train_batch_size = 128
test_batch_size = 1024

opt = torch.optim.Adam(model.parameters(), lr=1e-4)
# Halve the learning rate every 2500 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2500, gamma=0.5)
criterion = nn.CrossEntropyLoss()
max_epochs = 100

# Forward-gradient settings: number of random directions per step and whether
# Linear / Conv2d weight directions are projected onto the span of the layer inputs.
use_forward_grad = True
num_dir = 20
use_linear_projection = True
use_cnn_projection = True

18989386


In [12]:
# Load CIFAR-10 onto the device, train with the configuration above, and plot the curves.
train_dl, valid_dl, test_dl = get_Cifar10_dl(
    batch_size_train=train_batch_size,
    batch_size_eval=test_batch_size,
    device=DEVICE,
)
stats = run_experiment(
    model=model,
    opt=opt,
    schedular=scheduler,
    criterion=criterion,
    train_dl=train_dl,
    valid_dl=valid_dl,
    test_dl=test_dl,
    max_epochs=max_epochs,
    use_forward_grad=use_forward_grad,
    num_dir=num_dir,
    use_linear_projection=use_linear_projection,
    use_cnn_projection=use_cnn_projection,
)
print_stats(stats)

Files already downloaded and verified
Files already downloaded and verified
0:0 [train] loss:2.381, [valid] acc:0.104


KeyboardInterrupt: 