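This notebook implements a binarized version of VGG16 (a reduced variant with 7 convolutional layers) on the CIFAR-10 dataset. The network is trained either with traditional backpropagation or with a simulation of forward-mode autodiff (a forward-gradient estimate built from random directions).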

In [46]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, MNIST
from torch.utils.data import DataLoader, TensorDataset

import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import os
import warnings

# NOTE: this silences *all* warnings for the whole session.
warnings.filterwarnings("ignore", category=Warning)
# Restrict the process to GPU 3. This must be set before CUDA is initialised
# (e.g. by torch.cuda.is_available() just below), otherwise it has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Fall back to the CPU when no GPU is visible instead of failing later on .to('cuda').
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cuda


In [47]:
def switch_to_device(dataset, device=None):
    """Materialise an iterable of (x, y) samples into a single in-memory TensorDataset.

    Args:
        dataset: iterable yielding (tensor, label) pairs; all tensors must share a shape.
        device: if given, the stacked inputs and labels are moved to this device.

    Returns:
        TensorDataset(inputs, labels) where inputs = stack of all x and labels = tensor of all y.
    """
    pairs = [(sample_x, sample_y) for sample_x, sample_y in dataset]
    features = torch.stack([pair[0] for pair in pairs])
    targets = torch.tensor([pair[1] for pair in pairs])
    if device is None:
        return TensorDataset(features, targets)
    return TensorDataset(features.to(device), targets.to(device))

In [48]:
def get_Cifar10_dl(batch_size_train=256, batch_size_eval=1024, device=DEVICE, valid_size=5000, seed=None):
    """Build CIFAR-10 train/valid/test DataLoaders with the whole dataset preloaded on `device`.

    Images are converted to tensors and normalised per channel with mean 0.5 / std 0.5,
    which maps pixel values from [0, 1] to [-1, 1]. The training set is split at random
    into `len(train) - valid_size` training and `valid_size` validation samples.

    Args:
        batch_size_train: batch size of the (shuffled) training loader.
        batch_size_eval: batch size of the validation and test loaders.
        device: device the full tensors are moved to (None keeps them where ToTensor put them).
        valid_size: number of training images held out for validation (default 5000).
        seed: if given, makes the train/valid split reproducible; None draws from the
            global torch RNG as before.

    Returns:
        (train_dl, valid_dl, test_dl)
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    data_train = CIFAR10('./datasets', train=True, download=True, transform=transform)
    data_train = switch_to_device(data_train, device=device)
    # Only pass a generator when a seed is requested, so the default behaviour is unchanged.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    data_train, data_valid = torch.utils.data.random_split(
        data_train, [len(data_train) - valid_size, valid_size], **split_kwargs)

    data_test = CIFAR10('./datasets', train=False, download=True, transform=transform)
    data_test = switch_to_device(data_test, device=device)

    train_dl = DataLoader(data_train, batch_size=batch_size_train, shuffle=True)
    valid_dl = DataLoader(data_valid, batch_size=batch_size_eval, shuffle=False)
    test_dl = DataLoader(data_test, batch_size=batch_size_eval, shuffle=False)

    return train_dl, valid_dl, test_dl

In [49]:
def get_mnist_dl(batch_size_train=1024, batch_size_eval=1024, device=DEVICE, valid_size=5000, seed=None):
    """Build MNIST train/valid/test DataLoaders with the whole dataset preloaded on `device`.

    Images are only converted to tensors (pixel values in [0, 1], no normalisation).
    The training set is split at random into `len(train) - valid_size` training and
    `valid_size` validation samples.

    Args:
        batch_size_train: batch size of the (shuffled) training loader.
        batch_size_eval: batch size of the validation and test loaders.
        device: device the full tensors are moved to; defaults to the notebook-wide
            DEVICE, consistent with get_Cifar10_dl.
        valid_size: number of training images held out for validation (default 5000).
        seed: if given, makes the train/valid split reproducible; None draws from the
            global torch RNG as before.

    Returns:
        (train_dl, valid_dl, test_dl)
    """
    transform = transforms.Compose([transforms.ToTensor()])

    data_train = MNIST('./datasets', train=True, download=True, transform=transform)
    data_train = switch_to_device(data_train, device=device)
    # Only pass a generator when a seed is requested, so the default behaviour is unchanged.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    data_train, data_valid = torch.utils.data.random_split(
        data_train, [len(data_train) - valid_size, valid_size], **split_kwargs)

    data_test = MNIST('./datasets', train=False, download=True, transform=transform)
    data_test = switch_to_device(data_test, device=device)

    train_dl = DataLoader(data_train, batch_size=batch_size_train, shuffle=True)
    valid_dl = DataLoader(data_valid, batch_size=batch_size_eval, shuffle=False)
    test_dl = DataLoader(data_test, batch_size=batch_size_eval, shuffle=False)

    return train_dl, valid_dl, test_dl

In [50]:
def print_stats(stats):
    """Plot the training-loss curve and the validation-accuracy curve side by side.

    Args:
        stats: dict with 'train-loss' and 'valid-acc' lists of (iteration, value)
            pairs, as produced by run_experiment. Loss values may be plain numbers
            or one-element tensors on any device; accuracies are numbers in [0, 1].
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(7, 3), dpi=110)
    for ax, title in ((ax_loss, "ERM loss"), (ax_acc, "Valid Acc")):
        ax.grid()
        ax.set_title(title)
        ax.set_xlabel("iterations")

    loss_itrs = [itr for itr, _ in stats['train-loss']]
    # float() accepts plain numbers and one-element tensors (it calls .item() for tensors).
    losses = [float(value) for _, value in stats['train-loss']]
    ax_loss.plot(loss_itrs, losses)

    acc_itrs = [itr for itr, _ in stats['valid-acc']]
    accs = [acc for _, acc in stats['valid-acc']]
    ax_acc.plot(acc_itrs, accs)

    # max() of an empty list raises, so only set the loss range when there is data.
    if losses:
        ax_loss.set_ylim(0.0, max(losses))
    ax_acc.set_ylim(0.0, 1.05)

In [51]:
@torch.no_grad()
def get_acc(model, dl):
    """Return the top-1 accuracy of `model` over all batches of `dl`.

    The model is run in eval mode, so BatchNorm uses (and does not update) its
    running statistics. The model's previous train/eval mode is restored afterwards,
    even if evaluation raises.

    Args:
        model: classifier returning scores with the class dimension at dim 1.
        dl: iterable of (X, y) batches with integer class labels y.

    Returns:
        Accuracy in [0, 1] as a Python float; NaN if `dl` yields no samples.
    """
    was_training = model.training
    model.eval()
    try:
        n_correct = 0
        n_total = 0
        for X, y in dl:
            n_correct += (torch.argmax(model(X), dim=1) == y).sum()
            n_total += y.numel()
    finally:
        model.train(was_training)
    if n_total == 0:
        return float('nan')
    return (n_correct / n_total).item()

In [52]:
def count_parameters(model):
    """Return the total number of scalar entries in the trainable parameters of `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [53]:
def Binarize(x, quant_mode = 'det'):
    """Binarize a tensor to {-1, +1} without modifying it.

    Args:
        x: floating-point tensor.
        quant_mode: 'det' for deterministic x.sign() (note: exact zeros stay 0);
            any other value selects stochastic rounding, which returns +1 with
            probability clamp((x + 1) / 2, 0, 1) and -1 otherwise.

    Returns:
        A new tensor with the same shape, dtype and device as `x`.
    """
    if quant_mode == 'det':
        return x.sign()
    # Out-of-place ops: callers pass latent weights (e.g. `weight.org`), which must not be
    # overwritten. rand_like draws the noise on x's device and dtype.
    noise = torch.rand_like(x).add(-0.5)
    return x.add(1).div(2).add(noise).clamp(0, 1).round().mul(2).add(-1)

In [54]:
class BinarizeLinear(nn.Linear):
    """nn.Linear whose inputs, weights and bias are binarized on every forward pass.

    On the first call the real-valued ("latent") weight and bias are saved as
    `self.weight.org` / `self.bias.org`. Each forward then writes Binarize(org) into
    the parameters' `.data`, so the layer computes with binarized values. Because the
    binarization is done through `.data`, autograd does not see it: gradients are
    computed with respect to the binarized values and passed straight through the sign.
    Works with `bias=False` as well.
    """

    def __init__(self, *kargs, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        # Binarize the incoming activations in place via `.data` (invisible to autograd).
        input.data = Binarize(input.data)

        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)
        out = nn.functional.linear(input, self.weight)

        # Layers built with bias=False have self.bias = None.
        if self.bias is not None:
            if not hasattr(self.bias, 'org'):
                self.bias.org = self.bias.data.clone()
            self.bias.data = Binarize(self.bias.org)
            # Broadcasting add: same result as view(1, -1).expand_as(out) for batched
            # input, and also valid for unbatched (1-D) input.
            out += self.bias

        return out

In [55]:
class BinarizeConv2D(nn.Conv2d):
    """nn.Conv2d whose weights, bias and (optionally) inputs are binarized on every forward pass.

    On the first call the real-valued ("latent") weight and bias are saved as
    `self.weight.org` / `self.bias.org`. Each forward then writes Binarize(org) into the
    parameters' `.data`. Stride, padding, dilation and groups are honoured. Only zero
    padding is applied; `padding_mode` is not used. Works with `bias=False`.
    """

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2D, self).__init__(*kargs, **kwargs)

    def forward(self, input, ba = True):
        """Convolve `input` with the binarized kernel.

        Args:
            input: input feature map.
            ba: if True (default) the input activations are binarized in place via
                `.data`. The first layer of BVGG16 passes False to keep the raw image.
        """
        if ba:
            input.data = Binarize(input.data)

        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = Binarize(self.weight.org)

        if self.bias is not None:
            if not hasattr(self.bias, 'org'):
                # Clone `.data`, not the Parameter itself. Otherwise `org` would carry
                # autograd history, and every later `org.copy_(...)` in the training
                # loop would extend that graph.
                self.bias.org = self.bias.data.clone()
            self.bias.data = Binarize(self.bias.org)

        out = torch.nn.functional.conv2d(input, self.weight, self.bias, stride = self.stride,
                                         padding = self.padding, dilation = self.dilation,
                                         groups = self.groups)

        return out
    

In [56]:
class BVGG16(nn.Module):
    """Binarized VGG-style classifier for 32x32 RGB images (e.g. CIFAR-10).

    Despite the name, this is a reduced VGG with 7 binarized 3x3 conv layers:
    64-64-M-128-128-M-256-256-512-M, where M is a 2x2 max-pool. Three binarized
    fully connected layers follow: 8192 -> 1024 -> 1024 -> num_classes.
    Every hidden layer is conv/linear -> BatchNorm -> Hardtanh. The first conv sees
    the raw image (`ba=False`); all later layers binarize their inputs. The output is
    raw scores (no softmax).

    Args:
        num_classes: size of the final layer (default 10).
    """

    def __init__(self, num_classes=10):
        super(BVGG16, self).__init__()
        # Stage 1: 32x32 -> 16x16.
        self.conv1 = BinarizeConv2D(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.htanh1 = nn.Hardtanh()

        self.conv2 = BinarizeConv2D(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.htanh2 = nn.Hardtanh()
        self.maxpooling2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 16x16 -> 8x8.
        self.conv3 = BinarizeConv2D(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.htanh3 = nn.Hardtanh()

        self.conv4 = BinarizeConv2D(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.htanh4 = nn.Hardtanh()
        self.maxpooling4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: 8x8 -> 4x4.
        self.conv5 = BinarizeConv2D(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.htanh5 = nn.Hardtanh()

        self.conv6 = BinarizeConv2D(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.htanh6 = nn.Hardtanh()

        self.conv7 = BinarizeConv2D(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(512)
        self.htanh7 = nn.Hardtanh()
        self.maxpooling7 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head: 512 channels x 4 x 4 = 8192 flattened features.
        self.fc14 = BinarizeLinear(4*4*512, 1024)
        self.bn14 = nn.BatchNorm1d(1024)
        self.htanh14 = nn.Hardtanh()

        self.fc15 = BinarizeLinear(1024, 1024)
        self.bn15 = nn.BatchNorm1d(1024)
        self.htanh15 = nn.Hardtanh()

        self.fc16 = BinarizeLinear(1024, num_classes)

    def forward(self, input):
        # The first conv keeps the real-valued image (ba=False).
        x = self.htanh1(self.bn1(self.conv1(input, ba = False)))
        x = self.maxpooling2(self.htanh2(self.bn2(self.conv2(x))))
        x = self.htanh3(self.bn3(self.conv3(x)))
        x = self.maxpooling4(self.htanh4(self.bn4(self.conv4(x))))
        x = self.htanh5(self.bn5(self.conv5(x)))
        x = self.htanh6(self.bn6(self.conv6(x)))
        x = self.maxpooling7(self.htanh7(self.bn7(self.conv7(x))))
        x = x.reshape(x.size(0), -1)
        x = self.htanh14(self.bn14(self.fc14(x)))
        x = self.htanh15(self.bn15(self.fc15(x)))
        x = self.fc16(x)

        return x


In [57]:
class BMLP_1(nn.Module):
    """Binarized MLP for 28x28 single-image inputs: 784 -> 1024 -> 1024 -> 1024 -> 10.

    Each hidden stage applies BinarizeLinear -> BatchNorm1d -> Hardtanh. The last layer
    returns raw scores. All layers are created directly on the notebook-wide DEVICE.
    """

    def __init__(self):
        super(BMLP_1, self).__init__()
        self.fc_1 = BinarizeLinear(28*28, 1024, bias=True, device=DEVICE)
        self.htan_1 = nn.Hardtanh()
        self.bn_1 = nn.BatchNorm1d(1024, device=DEVICE)

        self.fc_2 = BinarizeLinear(1024, 1024, bias=True, device=DEVICE)
        self.htan_2 = nn.Hardtanh()
        self.bn_2 = nn.BatchNorm1d(1024, device=DEVICE)

        self.fc_3 = BinarizeLinear(1024, 1024, bias=True, device=DEVICE)
        self.htan_3 = nn.Hardtanh()
        self.bn_3 = nn.BatchNorm1d(1024, device=DEVICE)

        self.fc_4 = BinarizeLinear(1024, 10, bias=True, device=DEVICE)

    def forward(self, input):
        # Flatten every sample into one feature vector.
        x = torch.reshape(input, (input.shape[0], -1))
        hidden_stages = (
            (self.fc_1, self.bn_1, self.htan_1),
            (self.fc_2, self.bn_2, self.htan_2),
            (self.fc_3, self.bn_3, self.htan_3),
        )
        for linear, norm, activation in hidden_stages:
            x = activation(norm(linear(x)))
        return self.fc_4(x)

In [58]:
def _simulate_forward_gradient(model, num_dir):
    """Overwrite each parameter's `.grad` with a simulated forward-gradient estimate.

    For `num_dir` random Gaussian directions v_k over the full parameter vector, the
    directional derivative d_k = <v_k, grad> is what one forward-mode (JVP) pass would
    return. Here it is computed from the backprop gradient already stored in `.grad`.
    Each `.grad` is then replaced by mean_k(d_k * v_k), restricted to that parameter's
    slice of v_k. Assumes every parameter has a populated `.grad`.
    """
    params = list(model.parameters())
    if not params:
        return
    with torch.no_grad():
        dir_derivs = torch.zeros((num_dir, 1), device=params[0].device)
        directions = []
        for p in params:
            flat_grad = p.grad.view(-1)
            v = torch.randn(num_dir, len(flat_grad), device=flat_grad.device)
            dir_derivs += (v @ flat_grad).view(num_dir, 1)
            directions.append(v)
        for p, v in zip(params, directions):
            p.grad = torch.mean(dir_derivs * v, dim=0).view(p.grad.shape)


def run_experiment(model, opt, scheduler, softweight_flag, trick_flag, bits_storage, criterion, train_dl, valid_dl, test_dl, max_epochs, use_forward_grad, num_dir):
    """Train a binarized `model` and report validation accuracy every 100 iterations.

    Parameters that carry an `.org` attribute (set by the Binarize* layers) have
    latent real-valued weights stored in `.org`.

    Args:
        model: network to train.
        opt: optimizer over `model.parameters()`.
        scheduler: currently unused (no scheduler.step() calls); kept for API compatibility.
        softweight_flag: if True, the latent weights are copied back into `.data` before
            the optimizer step, and afterwards clamped to [-1, 1] and saved in `.org`.
            Otherwise `.org` is set to sign(.data) after each step.
        trick_flag: only used when softweight_flag is False. Accumulates gradient signs
            per element into a counter clamped to [-bits_storage, bits_storage]. The
            gradient is scaled by |counter| and zeroed where the counter's sign disagrees
            with it. Saturated elements get their gradient multiplied by 1e12 * lr. The
            counter is reset where sign(w) * grad > 1 / lr.
        bits_storage: initial counter range for the trick; doubled every 2000 iterations
            while it is below 128.
        criterion: loss function called as criterion(model(x), y).
        train_dl, valid_dl, test_dl: iterables of (x, y) batches.
        max_epochs: number of passes over `train_dl`.
        use_forward_grad: replace backprop gradients with the simulated forward-gradient
            estimate (see _simulate_forward_gradient).
        num_dir: number of random directions for that estimate.

    Returns:
        dict with 'train-loss': list of (iteration, detached loss tensor) and
        'valid-acc': list of (iteration, float accuracy).
    """
    itr = -1
    stats = {'train-loss' : [], 'valid-acc' : []}
    model.train()

    def evaluate(dl):
        # Evaluate in eval mode (BatchNorm running stats), then resume training mode.
        model.eval()
        acc = get_acc(model, dl)
        model.train()
        return acc

    # Per-parameter state for the accumulation trick, indexed by position in model.parameters().
    accumulation = {}
    change = {}
    if trick_flag:
        for i, p in enumerate(model.parameters()):
            accumulation[i] = torch.zeros_like(p.data)
            change[i] = torch.zeros_like(p.data)

    for epoch in range(max_epochs):
        for x, y in train_dl:
            itr += 1
            opt.zero_grad()
            loss = criterion(model(x), y)
            # Keep a detached copy; storing `loss` itself would keep its autograd history alive.
            stats['train-loss'].append((itr, loss.detach()))
            loss.backward()

            if use_forward_grad:
                _simulate_forward_gradient(model, num_dir)

            if softweight_flag:
                # Let the optimizer update the latent real-valued weights, not the binarized ones.
                for p in model.parameters():
                    if hasattr(p, 'org'):
                        p.data.copy_(p.org)
            elif trick_flag:
                lr = opt.param_groups[0]['lr']
                for i, p in enumerate(model.parameters()):
                    if hasattr(p, 'org'):
                        tmp_accumulation = accumulation[i] + p.grad.sign()
                        tmp_accumulation = tmp_accumulation.clamp_(-bits_storage, bits_storage)
                        possible_pos = (tmp_accumulation.sign() == p.grad.sign())
                        accumulation[i] = tmp_accumulation.clone()
                        p.grad = p.grad*torch.abs(tmp_accumulation)*possible_pos
                        p.grad[torch.abs(tmp_accumulation)>(bits_storage-0.5)]*=(1e12/(1/lr))
                        change[i] = ((p.data.sign()*p.grad)>(1/lr))
                        accumulation[i] = accumulation[i] * ~change[i]

            opt.step()
            if itr%2000 == 0 and itr != 0 and bits_storage<128:
                bits_storage *= 2

            # Refresh the latent weights; no_grad keeps these copies out of any autograd graph.
            with torch.no_grad():
                for p in model.parameters():
                    if hasattr(p, 'org'):
                        if softweight_flag:
                            p.org.copy_(p.data.clamp_(-1,1))
                        else:
                            p.org.copy_(p.data.sign())

            if itr % 100 == 0:
                valid_acc = evaluate(valid_dl)
                stats['valid-acc'].append((itr, valid_acc))
                s = f"{epoch}:{itr} [train] loss:{loss.item():.3f}, [valid] acc:{valid_acc:.3f}"
                print(s)

    test_acc = evaluate(test_dl)
    print(f"[test] acc:{test_acc:.3f}")

    return stats

In [59]:
# Seed torch's global RNG (CPU and CUDA). This makes weight initialisation, the
# train/valid split, mini-batch shuffling and the forward-gradient directions reproducible.
SEED = 0
torch.manual_seed(SEED)

model = BVGG16().to(DEVICE)
print(count_parameters(model))

train_batch_size = 128
test_batch_size = 1024

opt = torch.optim.Adam(model.parameters(), lr = 5e-3)
# Alternative optimizer:
#opt = torch.optim.SGD(model.parameters(), lr = 1e-4)

# NOTE: run_experiment never calls scheduler.step(), so the scheduler currently has no effect.
scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.98)
#scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1500, gamma=0.2)

# True: train latent real-valued weights clamped to [-1, 1].
softweight_flag = True
# Gradient-sign accumulation trick; only used when softweight_flag is False.
trick_flag = False
# Initial accumulator range for the trick; doubled every 2000 iterations up to 128.
bits_storage = 16

criterion = nn.CrossEntropyLoss()
#criterion = nn.MultiMarginLoss(p = 2)

max_epochs = 500

# Replace backprop gradients with a simulated forward-gradient estimate over num_dir random directions.
use_forward_grad = True
num_dir = 20

11781962


In [60]:
# Load CIFAR-10 (preloaded onto DEVICE) and train the model configured above.
train_dl, valid_dl, test_dl = get_Cifar10_dl(
    batch_size_train=train_batch_size,
    batch_size_eval=test_batch_size,
    device=DEVICE,
)
# Alternative dataset (28x28 MNIST, presumably paired with BMLP_1):
#train_dl, valid_dl, test_dl = get_mnist_dl(train_batch_size, test_batch_size, device = DEVICE)
stats = run_experiment(
    model=model,
    opt=opt,
    scheduler=scheduler,
    softweight_flag=softweight_flag,
    trick_flag=trick_flag,
    bits_storage=bits_storage,
    criterion=criterion,
    train_dl=train_dl,
    valid_dl=valid_dl,
    test_dl=test_dl,
    max_epochs=max_epochs,
    use_forward_grad=use_forward_grad,
    num_dir=num_dir,
)
print_stats(stats)

Files already downloaded and verified
Files already downloaded and verified
0:0 [train] loss:48.982, [valid] acc:0.104
0:100 [train] loss:52.259, [valid] acc:0.093
0:200 [train] loss:46.334, [valid] acc:0.107
0:300 [train] loss:45.851, [valid] acc:0.097
1:400 [train] loss:55.076, [valid] acc:0.100
1:500 [train] loss:44.620, [valid] acc:0.104
1:600 [train] loss:46.636, [valid] acc:0.098
1:700 [train] loss:48.984, [valid] acc:0.093
2:800 [train] loss:47.566, [valid] acc:0.094
2:900 [train] loss:53.127, [valid] acc:0.107
2:1000 [train] loss:51.077, [valid] acc:0.092
3:1100 [train] loss:47.480, [valid] acc:0.107
3:1200 [train] loss:46.908, [valid] acc:0.101


KeyboardInterrupt: 