This notebook use the pytorch's build in API to implement forward-mode automatic differentation.

Model: BVGG16

Training method: Forward mode autodiff + Adam

In [37]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, MNIST
from torch.utils.data import DataLoader, TensorDataset

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

import matplotlib.pyplot as plt
import os
import warnings

warnings.filterwarnings("ignore", category=Warning)
DEVICE = torch.device('cuda')
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
print(DEVICE)


cuda


In [38]:
def switch_to_device(dataset,device=None):
    final_X, final_Y = [], []
    for x, y in dataset:
        final_X.append(x)
        final_Y.append(y)
    X = torch.stack(final_X)
    Y = torch.tensor(final_Y)
    if device is not None:
        X = X.to(device)
        Y = Y.to(device)
    return TensorDataset(X, Y)

In [39]:
def get_Cifar10_dl(batch_size_train=256, batch_size_eval=1024, device=DEVICE):
    transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    data_train = CIFAR10('./datasets', train=True, download=True, transform=transform)
    data_train = switch_to_device(data_train, device=device)
    data_train, data_valid = torch.utils.data.random_split(data_train, [45000,5000])
    
    data_test = CIFAR10('./datasets', train=False, download=True, transform=transform)
    data_test = switch_to_device(data_test, device=device)
    
    train_dl = DataLoader(data_train, batch_size=batch_size_train, shuffle=True)
    valid_dl = DataLoader(data_valid, batch_size=batch_size_eval, shuffle=False)
    test_dl = DataLoader(data_test, batch_size=batch_size_eval, shuffle=False)
    
    return train_dl, valid_dl, test_dl

In [40]:
def print_stats(stats):

  fig, (ax1, ax2) = plt.subplots(1,2,figsize=(7,3), dpi=110)
  ax1.grid()
  ax2.grid()

  ax1.set_title("ERM loss")
  ax2.set_title("Valid Acc")
  
  ax1.set_xlabel("iterations")
  ax2.set_xlabel("iterations")

  itrs = [x[0] for x in stats['train-loss']]
  loss = [x[1].cpu().detach().numpy() for x in stats['train-loss']]
  ax1.plot(itrs, loss)

  itrs = [x[0] for x in stats['valid-acc']]
  acc = [x[1] for x in stats['valid-acc']]
  ax2.plot(itrs, acc)

  ax1.set_ylim(0.0, max(loss))
  ax2.set_ylim(0.0, 1.05)

  plt.tight_layout()
  fig.savefig('testing.jpg', bbox_inches = 'tight')

In [41]:
@torch.no_grad()
def get_acc(model, dl, device = DEVICE):
    model.eval()
    acc = []

    for X, y in dl:
        #one_hot_y = torch.zeros((X.shape[0], 10), device = DEVICE)
        #one_hot_y[[i for i in range(X.shape[0])], [k.item() for k in y]] = 1
        model.forward(X, y)
        acc.append(torch.argmax(model.output, dim=1) == y)

    model.train()
    acc = torch.cat(acc)
    acc = torch.sum(acc)/len(acc)

    return acc.item()

In [42]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [43]:
#这里现在用的还是最简单的binarize的方法
def Binarize(x, quant_mode = 'det'):
    if quant_mode == 'det':
        return x.sign()
    else:
        return x.add_(1).div_(2).add_(torch.rand(x.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)

In [44]:
class self_BinarizeLinear():
    def __init__(self, input_dim, output_dim, num_dir, ba_flag, bda_flag, sw_flag, bits_storage = 0):
        self.input_dim, self.output_dim = input_dim, output_dim
        self.num_dir = num_dir
        self.ba_flag = ba_flag
        self.bda_flag = bda_flag
        self.sw_flag = sw_flag
        self.bits_storage = bits_storage

        self.linear = nn.Linear(input_dim, output_dim).to(DEVICE)

        if sw_flag:
            self.soft_w = self.linear.weight.data.clone()
            self.soft_b = self.linear.bias.data.clone()
        else:
            self.soft_w = self.linear.weight.data.sign()
            self.soft_b = self.linear.bias.data.sign()
            self.accumulate_w = torch.zeros(self.linear.weight.data.shape, device = DEVICE)
            self.change_w = torch.zeros(self.linear.weight.data.shape, device = DEVICE)
            self.accumulate_b = torch.zeros(self.linear.bias.data.shape, device = DEVICE)
            self.change_b = torch.zeros(self.linear.bias.data.shape, device = DEVICE)

    def train_forward(self, input, da = None):
        if self.ba_flag:
            input.data = Binarize(input.data)

        self.linear.weight.data = Binarize(self.soft_w)
        self.vector_w = torch.randn((self.num_dir, self.output_dim, self.input_dim), device = DEVICE)
        self.linear.bias.data = Binarize(self.soft_b)
        self.vector_b = torch.randn((self.num_dir, self.output_dim), device = DEVICE)

        if self.bda_flag:
            self.vector_w = self.vector_w.sign()
            self.vector_b = self.vector_b.sign()
            if torch.is_tensor(da):
                da = da.sign()
        
        new_da = torch.zeros((self.num_dir, input.shape[0], self.output_dim), device = DEVICE)

        params = {name: p for name, p in self.linear.named_parameters()}

        with fwAD.dual_level():
            for i in range(self.num_dir):
                for name, p in params.items():
                    delattr(self.linear, name)            
                    if name == "weight":
                        setattr(self.linear, name, fwAD.make_dual(p, self.vector_w[i]))
                    elif name == "bias":
                        setattr(self.linear, name, fwAD.make_dual(p, self.vector_b[i]))
                
                if torch.is_tensor(da):
                    dual_input = fwAD.make_dual(input, da[i])
                else:
                    dual_input = input
                
                out = self.linear(dual_input)
                new_da[i] = fwAD.unpack_dual(out).tangent

        for name, p in params.items():
            setattr(self.linear, name, p)

        return out, new_da
    
    def eval_forward(self, input):
        self.linear.weight.data = self.soft_w.sign()
        self.linear.bias.data = self.soft_b.sign()
        out = self.linear(input)
        return out
    
    def update(self, da, lr):
        gw = da.view(-1, 1, 1) * self.vector_w
        gw = torch.mean(gw, dim = 0)
        gb = da.view(-1, 1)*self.vector_b
        gb = torch.mean(gb, dim = 0)
        if self.sw_flag:
            self.soft_w -= lr*gw
            self.soft_b -= lr*gb

            self.soft_w = self.soft_w.clamp_(-1, 1)
            self.soft_b = self.soft_b.clamp_(-1, 1)
        
        else:
            new_accumulate_w = self.accumulate_w + gw.sign()
            new_accumulate_w = new_accumulate_w.clamp(-self.bits_storage, self.bits_storage)
            self.accumulate_w = new_accumulate_w.clone()
            possible_pos_w = (new_accumulate_w.sign() == gw.sign())
            gw = gw * torch.abs(new_accumulate_w) * possible_pos_w
            gw[torch.abs(new_accumulate_w) > (self.bits_storage - 0.5)] *= 1e10/lr

            new_accumulate_b = self.accumulate_b + gb.sign()
            new_accumulate_b = new_accumulate_b.clamp(-self.bits_storage, self.bits_storage)
            self.accumulate_b = new_accumulate_b.clone()
            possible_pos_b = (new_accumulate_b.sign() == gb.sign())
            gb = gb * torch.abs(new_accumulate_b) * possible_pos_b
            gb[torch.abs(new_accumulate_b) > (self.bits_storage - 0.5)] *= 1e10/lr

            self.soft_w -= (lr*gw)
            self.soft_w = self.soft_w.sign()
            self.soft_b -= (lr*gb)
            self.soft_b = self.soft_b.sign()

            self.change_w = ((self.soft_w.sign() * gw) > 1/lr)
            self.accumulate_w *= ~self.change_w
            self.change_b = ((self.soft_b.sign() * gb) > 1/lr)
            self.accumulate_b *= ~self.change_b      

In [45]:
class self_BinarizeConv2D():
    def __init__(self, in_channel, out_channel, kernel_size, stride, padding, num_dir, ba_flag, bda_flag, sw_flag, bits_storage = 0):
        self.in_channel, self.out_channel = in_channel, out_channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.num_dir = num_dir
        self.ba_flag = ba_flag
        self.bda_flag = bda_flag
        self.sw_flag = sw_flag
        self.bits_storage = bits_storage

        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding).to(DEVICE)

        if sw_flag:
            self.soft_w = self.conv.weight.data.clone()
            self.soft_b = self.conv.bias.data.clone()
        else:
            self.soft_w = self.conv.weight.data.sign()
            self.soft_b = self.conv.bias.data.sign()
            self.accumulate_w = torch.zeros(self.conv.weight.data.shape, device = DEVICE)
            self.change_w = torch.zeros(self.conv.weight.data.shape, device = DEVICE)
            self.accumulate_b = torch.zeros(self.conv.bias.data.shape, device = DEVICE)
            self.change_b = torch.zeros(self.conv.bias.data.shape, device = DEVICE)
        
    def train_forward(self, input, da = None):
        if self.ba_flag:
            input.data = input.data.sign()
        
        self.conv.weight.data = self.soft_w.sign()
        self.vector_w = torch.randn((self.num_dir, self.out_channel, self.in_channel, self.kernel_size, self.kernel_size), device = DEVICE)
        self.conv.bias.data = self.soft_b.sign()
        self.vector_b = torch.randn((self.num_dir, self.out_channel), device = DEVICE)

        if self.bda_flag:
            self.vector_w = self.vector_w.sign()
            self.vector_b = self.vector_b.sign()
            if torch.is_tensor(da):
                da = da.sign()
        
        new_da = None

        params = {name: p for name, p in self.conv.named_parameters()}

        with fwAD.dual_level():
            for i in range(self.num_dir):
                for name, p in params.items():
                    delattr(self.conv, name)
                    if name == "weight":
                        setattr(self.conv, name, fwAD.make_dual(p, self.vector_w[i]))
                    elif name== "bias":
                        setattr(self.conv, name, fwAD.make_dual(p, self.vector_b[i]))

                if torch.is_tensor(da):
                    dual_input = fwAD.make_dual(input, da[i])
                else:
                    dual_input = input
                out = self.conv(dual_input)
                tmp_da = fwAD.unpack_dual(out).tangent.unsqueeze(0)
                
                if torch.is_tensor(new_da):
                    new_da = torch.cat((new_da, tmp_da), dim = 0)
                else:
                    new_da = tmp_da
            
        for name, p in params.items():
            setattr(self.conv, name, p)
        
        return out, new_da
    
    def eval_forward(self, input):
        self.conv.weight.data = self.soft_w.sign()
        self.conv.bias.data = self.soft_b.sign()
        out = self.conv(input)
        return out

    def update(self, da, lr):
        gw = da.view(-1, 1, 1, 1, 1) * self.vector_w
        gw = torch.mean(gw, dim = 0)
        gb = da.view(-1, 1)*self.vector_b
        gb = torch.mean(gb, dim = 0)
        if self.sw_flag:
            self.soft_w -= lr*gw
            self.soft_b -= lr*gb

            self.soft_w = self.soft_w.clamp_(-1, 1)
            self.soft_b = self.soft_b.clamp_(-1, 1)
        
        else:
            new_accumulate_w = self.accumulate_w + gw.sign()
            new_accumulate_w = new_accumulate_w.clamp(-self.bits_storage, self.bits_storage)
            self.accumulate_w = new_accumulate_w.clone()
            possible_pos_w = (new_accumulate_w.sign() == gw.sign())
            gw = gw * torch.abs(new_accumulate_w) * possible_pos_w
            gw[torch.abs(new_accumulate_w) > (self.bits_storage - 0.5)] *= 1e10/lr

            new_accumulate_b = self.accumulate_b + gb.sign()
            new_accumulate_b = new_accumulate_b.clamp(-self.bits_storage, self.bits_storage)
            self.accumulate_b = new_accumulate_b.clone()
            possible_pos_b = (new_accumulate_b.sign() == gb.sign())
            gb = gb * torch.abs(new_accumulate_b) * possible_pos_b
            gb[torch.abs(new_accumulate_b) > (self.bits_storage - 0.5)] *= 1e10/lr

            self.soft_w -= (lr*gw)
            self.soft_w = self.soft_w.sign()
            self.soft_b -= (lr*gb)
            self.soft_b = self.soft_b.sign()

            self.change_w = ((self.soft_w.sign() * gw) > 1/lr)
            self.accumulate_w *= ~self.change_w
            self.change_b = ((self.soft_b.sign() * gb) > 1/lr)
            self.accumulate_b *= ~self.change_b    


        

In [46]:
class self_Hardtanh():
    def train_forward(self, input, da):
        '''
        input: batch_size * input_dim
        da: num_dir * batch * input_dim
        '''
        num_dir = da.shape[0]
        new_da = torch.zeros(da.shape, device = DEVICE)
        with fwAD.dual_level():
            for i in range(num_dir):
                dual_input = fwAD.make_dual(input, da[i])
                out = nn.functional.hardtanh(dual_input)
                new_da[i] = fwAD.unpack_dual(out).tangent
        return out, new_da
    
    def eval_forward(self, input):
        return nn.functional.hardtanh(input)

In [47]:
class self_MaxPool2d():
    def __init__(self, kernel_size, stride):
        self.kernel_size = kernel_size
        self.stride = stride
        self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride).to(DEVICE)
    
    def train_forward(self, input, da):
        num_dir = da.shape[0]
        new_da = None

        with fwAD.dual_level():
            for i in range(num_dir):
                dual_input = fwAD.make_dual(input, da[i])
                out = self.maxpool(dual_input)
                tmp_da = fwAD.unpack_dual(out).tangent.unsqueeze(0)
                if torch.is_tensor(new_da):
                    new_da = torch.cat((new_da, tmp_da), dim = 0)
                else:
                    new_da = tmp_da
        return out, new_da
    
    def eval_forward(self, input):
        return self.maxpool(input)

In [48]:
class self_Batchnorm1d():
    def __init__(self, dim, bda_flag=False):
        self.dim = dim
        self.BN = nn.BatchNorm1d(self.dim).to(DEVICE)
        self.bda_flag = bda_flag
        self.BN.train()
    
    def train_forward(self, input, da):
        num_dir = da.shape[0]
        self.vector_w = torch.randn((num_dir, self.dim), device = DEVICE)#.sign()
        self.vector_b = torch.randn((num_dir, self.dim), device = DEVICE)#.sign()

        if self.bda_flag:
            self.vector_w = self.vector_w.sign()
            self.vector_b = self.vector_b.sign()

        new_da = torch.zeros(da.shape, device = DEVICE)

        params = {name: p for name, p in self.BN.named_parameters()}

        with fwAD.dual_level():
            for i in range(num_dir):
                for name, p in params.items():
                    if name == 'weight':
                        delattr(self.BN, name)
                        setattr(self.BN, name, fwAD.make_dual(p, self.vector_w[i]))
                    elif name == 'bias':
                        delattr(self.BN, name)
                        setattr(self.BN, name, fwAD.make_dual(p, self.vector_b[i]))
                dual_input = fwAD.make_dual(input, da[i])

                out = self.BN(dual_input)
                new_da[i] = fwAD.unpack_dual(out).tangent

        for name, p in params.items():
            setattr(self.BN, name, p)
        return out, da
    
    def eval_forward(self, input):
        self.BN.eval()
        out = self.BN(input)
        self.BN.train()
        return out
    
    def update(self, da, lr):
        gw = da.view(-1, 1)*self.vector_w
        gw = torch.mean(gw, dim = 0)
        self.BN.weight.data -= lr*gw

        gb = da.view(-1, 1)*self.vector_b
        gb = torch.mean(gb, dim = 0)
        self.BN.bias.data -= lr*gb

In [49]:
class self_Batchnorm2d():
    def __init__(self, dim, bda_flag=False):
        self.dim = dim
        self.BN = nn.BatchNorm2d(self.dim).to(DEVICE)
        self.bda_flag = bda_flag
        self.BN.train()

    def train_forward(self, input, da):
        num_dir = da.shape[0]
        self.vector_w = torch.randn((num_dir, self.dim), device = DEVICE)
        self.vector_b = torch.randn((num_dir, self.dim), device = DEVICE)

        if self.bda_flag:
            self.vector_w = self.vector_w.sign()
            self.vector_b = self.vector_b.sign()
        
        new_da = torch.zeros(da.shape, device = DEVICE)

        params = {name: p for name, p in self.BN.named_parameters()}

        with fwAD.dual_level():
            for i in range(num_dir):
                for name, p in params.items():
                    delattr(self.BN, name)
                    if name == "weight":
                        setattr(self.BN, name, fwAD.make_dual(p, self.vector_w[i]))
                    elif name == "bias":
                        setattr(self.BN, name, fwAD.make_dual(p, self.vector_b[i]))
                dual_input = fwAD.make_dual(input, da[i])

                out = self.BN(dual_input)
                new_da[i] = fwAD.unpack_dual(out).tangent
        
        for name, p in params.items():
            setattr(self.BN, name, p)
        return out, da
    
    def eval_forward(self, input):
        self.BN.eval()
        out = self.BN(input)
        self.BN.train()
        return out
    
    def update(self, da, lr):
        gw = da.view(-1, 1)*self.vector_w
        gw = torch.mean(gw, dim = 0)
        self.BN.weight.data -= lr*gw

        gb = da.view(-1, 1)*self.vector_b
        gb = torch.mean(gb, dim = 0)
        self.BN.bias.data -= lr*gb

In [50]:
class self_Softmax_CrossEntropy():
    def __init__(self):
        self.CE = nn.CrossEntropyLoss().to(DEVICE)
    
    def train_forward(self, input, y, da):
        num_dir = da.shape[0]
        new_da = torch.zeros(num_dir, device = DEVICE)
        with fwAD.dual_level():
            for i in range(num_dir):
                dual_input = fwAD.make_dual(input, da[i])
                out = self.CE(dual_input, y)
                new_da[i] = fwAD.unpack_dual(out).tangent
        return out, new_da

In [51]:
class self_BVGG(nn.Module):
    def __init__(self, num_dir, bda_flag=False, sw_flag=True, bits_storage=0, device=DEVICE):
        self.device = DEVICE
        self.is_tarin = True

        self.num_dir = num_dir
        self.bda_flag = bda_flag
        self.sw_flag = sw_flag
        self.bits_storage = bits_storage

        self.conv1 = self_BinarizeConv2D(3, 64, 3, 1, 1, num_dir=num_dir, ba_flag=False, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.bn1 = self_Batchnorm2d(64, bda_flag=bda_flag)
        self.htanh1 = self_Hardtanh()

        self.conv2 = self_BinarizeConv2D(64, 64, 3, 1, 1, num_dir=num_dir, ba_flag=True, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.bn2 = self_Batchnorm2d(64, bda_flag=bda_flag)
        self.htanh2 = self_Hardtanh()
        self.maxpooling2 = self_MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = self_BinarizeConv2D(64, 128, 3, 1, 1, num_dir=num_dir, ba_flag=True, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.bn3 = self_Batchnorm2d(128, bda_flag=bda_flag)
        self.htanh3 = self_Hardtanh()

        self.conv4 = self_BinarizeConv2D(128, 128, 3, 1, 1, num_dir=num_dir, ba_flag=True, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.bn4 = self_Batchnorm2d(128, bda_flag=bda_flag)
        self.htanh4 = self_Hardtanh()
        self.maxpooling4 = self_MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = self_BinarizeConv2D(128, 256, 3, 1, 1, num_dir=num_dir, ba_flag=True, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.bn5 = self_Batchnorm2d(256, bda_flag=bda_flag)
        self.htanh5 = self_Hardtanh()

        self.conv6 = self_BinarizeConv2D(256, 256, 3, 1, 1, num_dir=num_dir, ba_flag=True, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.bn6 = self_Batchnorm2d(256, bda_flag=bda_flag)
        self.htanh6 = self_Hardtanh()

        self.conv7 = self_BinarizeConv2D(256, 512, 3, 1, 1, num_dir=num_dir, ba_flag=True, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.bn7 = self_Batchnorm2d(512, bda_flag=bda_flag)
        self.htanh7 = self_Hardtanh()
        self.maxpooling7 = self_MaxPool2d(kernel_size=2, stride=2)

        self.fc14 = self_BinarizeLinear(4*4*512, 1024, num_dir=num_dir, ba_flag=True, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.bn14 = self_Batchnorm1d(1024, bda_flag=bda_flag)
        self.htanh14 = self_Hardtanh()

        self.fc15 = self_BinarizeLinear(1024, 1024, num_dir=num_dir, ba_flag=True, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.bn15 = self_Batchnorm1d(1024, bda_flag=bda_flag)
        self.htanh15 = self_Hardtanh()

        self.fc16 = self_BinarizeLinear(1024, 10, num_dir=num_dir, ba_flag=True, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
        self.CrossEntropy = self_Softmax_CrossEntropy()
    
    def train(self):
        self.is_train = True

    def eval(self):
        self.is_train = False

    def change_bits(self, num):
        self.bits_storage = self.bits_storage*(2**num)
        self.conv1.bits_storage = self.bits_storage
        self.conv2.bits_storage = self.bits_storage
        self.conv3.bits_storage = self.bits_storage
        self.conv4.bits_storage = self.bits_storage
        self.conv5.bits_storage = self.bits_storage
        self.conv6.bits_storage = self.bits_storage
        self.conv7.bits_storage = self.bits_storage

        self.fc14.bits_storage = self.bits_storage
        self.fc15.bits_storage = self.bits_storage
        self.fc16.bits_storage = self.bits_storage
    
    def forward(self, input, labels):
        if self.is_train:
            self.train_forward(input, labels)
        
        else:
            self.eval_forward(input)
    
    def train_forward(self, input, labels):
        da = None
        x, da = self.conv1.train_forward(input, da)
        x, da = self.bn1.train_forward(x, da)
        x, da = self.htanh1.train_forward(x, da)

        x, da = self.conv2.train_forward(x, da)
        x, da = self.bn2.train_forward(x, da)
        x, da = self.htanh2.train_forward(x, da)
        x, da = self.maxpooling2.train_forward(x, da)

        x, da = self.conv3.train_forward(x, da)
        x, da = self.bn3.train_forward(x, da)
        x, da = self.htanh3.train_forward(x, da)

        x, da = self.conv4.train_forward(x, da)
        x, da = self.bn4.train_forward(x, da)
        x, da = self.htanh4.train_forward(x, da)
        x, da = self.maxpooling4.train_forward(x, da)

        x, da = self.conv5.train_forward(x, da)
        x, da = self.bn5.train_forward(x, da)
        x, da = self.htanh5.train_forward(x, da)

        x, da = self.conv6.train_forward(x, da)
        x, da = self.bn6.train_forward(x, da)
        x, da = self.htanh6.train_forward(x, da)

        x, da = self.conv7.train_forward(x, da)
        x, da = self.bn7.train_forward(x, da)
        x, da = self.htanh7.train_forward(x, da)
        x, da = self.maxpooling7.train_forward(x, da)


        x = x.reshape(x.size(0), -1)
        da = da.reshape(da.size(0), da.size(1), -1)

        x, da = self.fc14.train_forward(x, da)
        x, da = self.bn14.train_forward(x, da)
        x, da = self.htanh14.train_forward(x, da)

        x, da = self.fc15.train_forward(x, da)
        x, da = self.bn15.train_forward(x, da)
        x, da = self.htanh15.train_forward(x, da)

        x, da = self.fc16.train_forward(x, da)

        self.output = x
        loss, da = self.CrossEntropy.train_forward(x, labels, da)
        self.loss, self.da = loss, da

    def eval_forward(self, input):
        x = self.conv1.eval_forward(input)
        x = self.bn1.eval_forward(x)
        x = self.htanh1.eval_forward(x)

        x = self.conv2.eval_forward(x)
        x = self.bn2.eval_forward(x)
        x = self.htanh2.eval_forward(x)
        x = self.maxpooling2.eval_forward(x)

        x = self.conv3.eval_forward(x)
        x = self.bn3.eval_forward(x)
        x = self.htanh3.eval_forward(x)

        x = self.conv4.eval_forward(x)
        x = self.bn4.eval_forward(x)
        x = self.htanh4.eval_forward(x)
        x = self.maxpooling4.eval_forward(x)

        x = self.conv5.eval_forward(x)
        x = self.bn5.eval_forward(x)
        x = self.htanh5.eval_forward(x)

        x = self.conv6.eval_forward(x)
        x = self.bn6.eval_forward(x)
        x = self.htanh6.eval_forward(x)

        x = self.conv7.eval_forward(x)
        x = self.bn7.eval_forward(x)
        x = self.htanh7.eval_forward(x)
        x = self.maxpooling7.eval_forward(x)

        x = x.reshape(x.size(0), -1)

        x = self.fc14.eval_forward(x)
        x = self.bn14.eval_forward(x)
        x = self.htanh14.eval_forward(x)

        x = self.fc15.eval_forward(x)
        x = self.bn15.eval_forward(x)
        x = self.htanh15.eval_forward(x)

        x = self.fc16.eval_forward(x)

        self.output = x

    def update(self, lr):
        self.conv1.update(self.da, lr)
        self.conv2.update(self.da, lr)
        self.conv3.update(self.da, lr)
        self.conv4.update(self.da, lr)
        self.conv5.update(self.da, lr)
        self.conv6.update(self.da, lr)
        self.conv7.update(self.da, lr)

        self.fc14.update(self.da, lr)
        self.fc15.update(self.da, lr)
        self.fc16.update(self.da, lr)

        self.bn1.update(self.da, lr)
        self.bn2.update(self.da, lr)
        self.bn3.update(self.da, lr)
        self.bn4.update(self.da, lr)
        self.bn5.update(self.da, lr)
        self.bn6.update(self.da, lr)
        self.bn7.update(self.da, lr)
        self.bn14.update(self.da, lr)
        self.bn15.update(self.da, lr)

In [52]:
def run_experiment(train_dl, valid_dl, test_dl, num_dir, bda_flag, sw_flag, bits_storage, max_epochs, lr):
    model = self_BVGG(num_dir=num_dir, bda_flag=bda_flag, sw_flag=sw_flag, bits_storage=bits_storage)
    itr = -1
    stats = {'train-loss' : [], 'valid-acc' : []}

    model.train()
    with torch.no_grad():
        for epoch in range(max_epochs):
            for x, y in train_dl:
                itr += 1
                model.forward(x, y)
                model.update(lr)
                stats['train-loss'].append((itr, model.loss.item()))
                if itr != 0 and itr % 500 == 0:
                    model.change_bits(1)
                if itr % 100 == 0:
                    #print(model.output)
                    valid_acc = get_acc(model, valid_dl)
                    stats['valid-acc'].append((itr, valid_acc))
                    s = f"{epoch}:{itr} [train] loss:{model.loss.item():.3f}, [valid] acc:{valid_acc:.3f}"
                    print(s)
    
    test_acc = get_acc(model, test_dl)
    print(f"[test] acc:{test_acc:.3f}")
    return stats

In [53]:
train_batch = 128
valid_batch = 1024

num_dir = 20
bda_flag = False
sw_flag = True
bits_storage = 8

max_epochs = 20
lr = 1e-7

In [54]:
train_dl, valid_dl, test_dl = get_Cifar10_dl(batch_size_train=train_batch, batch_size_eval=valid_batch, device=DEVICE)
stats = run_experiment(train_dl, valid_dl, test_dl, num_dir, bda_flag, sw_flag, bits_storage, max_epochs, lr)
print_stats(stats)

Files already downloaded and verified
Files already downloaded and verified
0:0 [train] loss:51.292, [valid] acc:0.106
0:100 [train] loss:49.285, [valid] acc:0.108
0:200 [train] loss:55.468, [valid] acc:0.107
0:300 [train] loss:53.122, [valid] acc:0.106
1:400 [train] loss:55.038, [valid] acc:0.107
1:500 [train] loss:54.910, [valid] acc:0.099
1:600 [train] loss:50.841, [valid] acc:0.109
1:700 [train] loss:54.805, [valid] acc:0.099
2:800 [train] loss:51.041, [valid] acc:0.101
2:900 [train] loss:56.170, [valid] acc:0.113
