In [625]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import cmath

In [626]:
def get_default_device():
    """Pick GPU if available, else CPU.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` reports a
        usable GPU, otherwise ``cpu``.
    """
    # The previous version returned 'cpu' from both branches, so the GPU was
    # never selected even when one was present.
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
DEVICE = get_default_device()
print(DEVICE)

cpu


In [627]:
def switch_to_device(dataset, device = None):
    """Materialise a dataset of ``(x, y)`` samples as one in-memory TensorDataset.

    Every sample is read once, the inputs are stacked into a single tensor and
    the labels are collected into a second tensor, so later batches can be
    sliced without re-running any per-sample transform.

    Args:
        dataset: iterable yielding ``(x, y)`` pairs, where every ``x`` is a
            tensor of the same shape and ``y`` is a scalar label.
        device: target device for both tensors. When ``None`` the tensors are
            left on the device they were created on.

    Returns:
        torch.utils.data.TensorDataset wrapping the stacked inputs and labels.
    """
    samples_x, samples_y = [], []

    for x, y in dataset:
        samples_x.append(x)
        samples_y.append(y)

    X = torch.stack(samples_x)
    Y = torch.tensor(samples_y)

    # Honour the caller's `device` argument; the previous version ignored it
    # and always read the module-level DEVICE instead.
    if device is not None:
        X = X.to(device)
        Y = Y.to(device)

    return torch.utils.data.TensorDataset(X, Y)

In [628]:
def get_mnist_dl(batch_size_train = 256, batch_size_valid = 1024, device = None):
    """Build the train / validation / test DataLoaders for MNIST.

    The training set is downloaded to ``./mnist`` if needed, converted with
    ``ToTensor``, passed through ``switch_to_device`` and then randomly split
    into 55000 training and 5000 validation samples. The test set goes
    through the same conversion without a split.

    Args:
        batch_size_train: batch size of the (shuffled) training loader.
        batch_size_valid: batch size of the validation and test loaders.
        device: forwarded to ``switch_to_device``.

    Returns:
        Tuple ``(train_dl, valid_dl, test_dl)``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    full_train = switch_to_device(
        torchvision.datasets.MNIST('./mnist', train = True, download = True, transform = to_tensor),
        device,
    )
    train_split, valid_split = torch.utils.data.random_split(full_train, [55000, 5000])

    test_set = switch_to_device(
        torchvision.datasets.MNIST('./mnist', train = False, download = True, transform = to_tensor),
        device,
    )

    # Only the training loader shuffles; evaluation order stays fixed.
    return (
        DataLoader(train_split, batch_size=batch_size_train, shuffle=True),
        DataLoader(valid_split, batch_size=batch_size_valid, shuffle=False),
        DataLoader(test_set, batch_size=batch_size_valid, shuffle=False),
    )

In [629]:
def symmtric_quantization(x, q_bits):
    """Symmetric quantization of ``x`` onto an integer grid (values stay floats).

    The tensor is scaled by ``(2**q_bits - 1) / (2 * max|x|)`` and floored, so
    the largest-magnitude element maps to roughly half of the ``q_bits`` range
    on either side of zero.

    Args:
        x: tensor to quantize.
        q_bits: number of quantization bits.

    Returns:
        Tensor with the same shape as ``x`` holding the floored, scaled values.
    """
    q_max = 2**q_bits - 1
    abs_max = torch.max(torch.abs(x))
    scale = q_max / (2 * abs_max)
    # An all-zero tensor would give scale = inf and 0 * inf = NaN everywhere;
    # any finite scale maps zeros to zeros, so fall back to 1 in that case.
    scale = torch.where(abs_max == 0, torch.ones_like(scale), scale)
    return torch.floor(x * scale)

In [630]:
class Symmtric_quantizer():
    """Symmetric quantizer with a manual forward / backward pair.

    ``forward`` scales its input by ``(2**q_bits - 1) / (2 * max|x|)`` and
    floors the result; the scale is remembered in ``self.coef``. ``backward``
    multiplies an incoming gradient by that stored scale.
    """

    def __init__(self, q_bits):
        """Store the bit width; no scale exists until ``forward`` has run."""
        self.q_bits = q_bits
        self.coef = None

    def forward(self, x):
        """Quantize ``x`` and remember the scale used.

        Args:
            x: tensor to quantize.

        Returns:
            Tensor with the same shape as ``x`` holding the floored, scaled values.
        """
        q_max = 2**self.q_bits - 1
        abs_max = torch.max(torch.abs(x))
        scale = q_max / (2 * abs_max)
        # An all-zero tensor would give scale = inf and NaN outputs/gradients;
        # any finite scale maps zeros to zeros, so fall back to 1 in that case.
        self.coef = torch.where(abs_max == 0, torch.ones_like(scale), scale)
        return torch.floor(x * self.coef)

    def backward(self, next_grad):
        """Scale ``next_grad`` by the factor used in the latest ``forward``.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.coef is None:
            raise RuntimeError("Symmtric_quantizer.backward called before forward")
        return next_grad*self.coef


In [631]:
def asymmtric_quantization(x, q_bits):
    """Asymmetric (min/max) quantization of ``x`` onto an integer grid.

    ``x`` is shifted by its minimum, scaled so that its range spans
    ``2**q_bits - 1`` steps, rounded, and finally shifted down by
    ``round((2**q_bits - 1) / 2)`` to centre the grid around zero.

    Args:
        x: tensor to quantize.
        q_bits: number of quantization bits.

    Returns:
        Tensor with the same shape as ``x`` holding the rounded, re-centred values.
    """
    q_max = 2**q_bits - 1
    x_min = torch.min(x)
    span = torch.max(x) - x_min
    scale = q_max / span
    # A constant tensor has span 0, which would give scale = inf and
    # 0 * inf = NaN; (x - x_min) is all zeros then, so any finite scale works.
    scale = torch.where(span == 0, torch.ones_like(scale), scale)
    q_x = torch.round((x - x_min) * scale)
    return q_x - torch.round(torch.tensor(q_max/2))

In [632]:
class Asymmtric_quantizer():
    """Asymmetric (min/max) quantizer with a manual forward / backward pair.

    ``forward`` maps the input range ``[min(x), max(x)]`` onto
    ``2**q_bits - 1`` rounded steps, re-centres the result by subtracting
    ``round((2**q_bits - 1) / 2)`` and remembers the scale in ``self.coef``.
    ``backward`` multiplies an incoming gradient by that stored scale.
    """

    def __init__(self, q_bits):
        """Store the bit width; no scale exists until ``forward`` has run."""
        self.q_bits = q_bits
        self.coef = None

    def forward(self, x):
        """Quantize ``x`` and remember the scale used.

        Args:
            x: tensor to quantize.

        Returns:
            Tensor with the same shape as ``x`` holding the rounded, re-centred values.
        """
        q_max = 2**self.q_bits - 1
        x_min = torch.min(x)
        span = torch.max(x) - x_min
        scale = q_max / span
        # A constant tensor has span 0, which would give scale = inf and NaN
        # outputs/gradients; (x - x_min) is all zeros then, so use scale 1.
        self.coef = torch.where(span == 0, torch.ones_like(scale), scale)
        q_x = torch.round((x - x_min) * self.coef)
        return q_x - torch.round(torch.tensor(q_max/2))

    def backward(self, next_grad):
        """Scale ``next_grad`` by the factor used in the latest ``forward``.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.coef is None:
            raise RuntimeError("Asymmtric_quantizer.backward called before forward")
        return next_grad*self.coef

In [633]:
class QLinear:
    """Fully connected layer with optional quantization and manual backprop.

    Full-precision weights ``fp_weights`` of shape ``(input_num, output_num)``
    and bias ``fp_bias`` of shape ``(1, output_num)`` are always kept and are
    the tensors that ``update`` modifies. When the matching flag is set, the
    input activations (asymmetric quantizer), weights and bias (symmetric
    quantizers) are quantized before being used in the matrix product, and
    the corresponding gradients are passed through the quantizer's
    ``backward``.
    """

    @torch.no_grad()
    def __init__(self, input_num, output_num, qa_flag = True, qw_flag = True, qb_flag = True, qa_bits = 2, qw_bits = 2, qb_bits = 2, device = None):
        """Create the layer and draw its parameters from N(0, 0.1).

        Args:
            input_num: number of input features.
            output_num: number of output features.
            qa_flag / qw_flag / qb_flag: quantize activations / weights / bias.
            qa_bits / qw_bits / qb_bits: bit widths for the three quantizers.
            device: device the parameters are moved to; required.

        Raises:
            ValueError: if ``device`` is ``None``.
        """
        if device is None:
            # Fail loudly: returning here would hand back an object without
            # any parameters that only breaks later inside forward().
            raise ValueError("Must have device")
        self.device = device
        self.input_num, self.output_num = input_num, output_num
        self.qa_flag, self.qw_flag, self.qb_flag = qa_flag, qw_flag, qb_flag
        self.qa_bits, self.qw_bits, self.qb_bits = qa_bits, qw_bits, qb_bits

        if qa_flag:
            self.a_quantizer = Asymmtric_quantizer(self.qa_bits)

        self.fp_weights = torch.normal(0, 0.1, size = (self.input_num, self.output_num)).to(device)
        if qw_flag:
            self.w_quantizer = Symmtric_quantizer(self.qw_bits)
            self.q_weights = self.w_quantizer.forward(self.fp_weights)

        self.fp_bias = torch.normal(0, 0.1, size = (1, self.output_num)).to(device)
        if qb_flag:
            self.b_quantizer = Symmtric_quantizer(self.qb_bits)
            self.q_bias = self.b_quantizer.forward(self.fp_bias)

        self.w_grad = None
        self.b_grad = None
        self.input_grad = None

    @torch.no_grad()
    def forward(self, input):
        """Compute ``input @ W + b`` using quantized operands where enabled.

        The (possibly quantized) input and the weights/bias actually used are
        cached on the instance because ``backward`` reads them.

        Args:
            input: 2-D tensor whose last dimension is ``input_num``
                (presumably ``(batch, input_num)``).

        Returns:
            The layer output, also stored in ``self.output``.
        """
        self.input = input

        if self.qa_flag:
            self.input = self.a_quantizer.forward(self.input)

        self.use_weights = self.q_weights if self.qw_flag else self.fp_weights
        self.use_bias = self.q_bias if self.qb_flag else self.fp_bias

        self.output = torch.matmul(self.input, self.use_weights) + self.use_bias

        return self.output

    @torch.no_grad()
    def backward(self, next_grad):
        """Back-propagate ``next_grad`` through the layer.

        Stores the weight and bias gradients in ``self.w_grad`` /
        ``self.b_grad`` and returns the gradient with respect to the layer
        input. Each gradient is additionally passed through the matching
        quantizer's ``backward`` when that quantizer is enabled. Must be
        called after ``forward``.

        Args:
            next_grad: gradient of the loss with respect to this layer's output.

        Returns:
            Gradient with respect to the layer input (``self.input_grad``).
        """
        self.w_grad = torch.matmul(self.input.T, next_grad)
        if self.qw_flag:
            self.w_grad = self.w_quantizer.backward(self.w_grad)

        # Summing over dim 0 collapses the batch; the result broadcasts
        # against the (1, output_num) bias in update().
        self.b_grad = torch.sum(next_grad, dim = 0)
        if self.qb_flag:
            self.b_grad = self.b_quantizer.backward(self.b_grad)

        self.input_grad = torch.matmul(next_grad, self.use_weights.T)
        if self.qa_flag:
            self.input_grad = self.a_quantizer.backward(self.input_grad)

        return self.input_grad

    @torch.no_grad()
    def update(self, lr):
        """Apply one gradient-descent step to the full-precision parameters.

        The quantized copies are refreshed afterwards so the next ``forward``
        sees the updated values.

        Args:
            lr: learning rate.
        """
        self.fp_weights = self.fp_weights - lr*self.w_grad
        if self.qw_flag:
            self.q_weights = self.w_quantizer.forward(self.fp_weights)
        self.fp_bias = self.fp_bias - lr*self.b_grad
        if self.qb_flag:
            self.q_bias = self.b_quantizer.forward(self.fp_bias)

        

In [634]:
class Relu():
    """ReLU activation with optional input quantization and manual backprop."""

    @torch.no_grad()
    def __init__(self, qa_flag, qa_bits):
        """Remember the quantization settings and build the quantizer if enabled."""
        self.qa_flag, self.qa_bits = qa_flag, qa_bits
        if qa_flag:
            self.a_quantizer = Asymmtric_quantizer(qa_bits)
        self.input_grad = None

    @torch.no_grad()
    def forward(self, input):
        """Apply ReLU to ``input`` (quantized first when ``qa_flag`` is set).

        Both the tensor fed to the ReLU and its result are cached on the
        instance; ``backward`` uses the cached output as its mask.
        """
        activations = self.a_quantizer.forward(input) if self.qa_flag else input
        self.input = activations
        # (x + |x|) / 2 keeps positive entries and zeroes the rest.
        self.output = (activations + torch.abs(activations)) / 2.0
        return self.output

    @torch.no_grad()
    def backward(self, next_grad):
        """Zero the gradient where the cached output is not positive.

        ``next_grad`` itself is left untouched; the masked copy is stored in
        ``self.input_grad`` (after the quantizer's ``backward`` when enabled)
        and returned.
        """
        grad = next_grad.masked_fill(self.output <= 0, 0)
        if self.qa_flag:
            grad = self.a_quantizer.backward(grad)
        self.input_grad = grad
        return grad



In [635]:
class Softmax_CrossEntropy():
    """Softmax followed by cross-entropy loss, with manual backprop.

    The logits may optionally be quantized with an asymmetric quantizer
    before the softmax is taken.
    """

    @torch.no_grad()
    def __init__(self, qa_flag, qa_bits):
        """Remember the quantization settings and build the quantizer if enabled."""
        self.qa_flag = qa_flag
        self.qa_bits = qa_bits
        if self.qa_flag:
            self.a_quantizer = Asymmtric_quantizer(qa_bits)
        self.input_grad = None

    @torch.no_grad()
    def forward(self, input, labels):
        """Return the cross-entropy averaged over the rows of ``input``.

        Args:
            input: 2-D tensor of logits, one row per sample.
            labels: tensor of the same shape as ``input``; the training loop
                in this notebook passes one-hot rows.

        Returns:
            Scalar tensor with the summed cross-entropy divided by the number
            of rows. The softmax probabilities are cached in
            ``self.softmax_z`` for ``backward``.
        """
        self.labels = labels
        self.input = input
        if self.qa_flag:
            self.input = self.a_quantizer.forward(self.input)

        # log_softmax is computed in a numerically stable way; the previous
        # exp(z) / sum(exp(z)) overflowed to inf/inf = NaN for large logits
        # and produced log(0) = -inf for very unlikely classes.
        log_probs = torch.log_softmax(self.input, dim = 1)
        self.softmax_z = torch.exp(log_probs)
        loss = torch.sum(-(labels*log_probs)) / self.input.shape[0]

        return loss

    @torch.no_grad()
    def backward(self):
        """Return the gradient of the loss with respect to the logits.

        The returned value is ``(softmax - labels)`` (passed through the
        quantizer's ``backward`` when enabled) divided by the number of rows.
        Note that ``self.input_grad`` keeps the value *before* that division.
        """
        self.input_grad = self.softmax_z - self.labels
        if self.qa_flag:
            self.input_grad = self.a_quantizer.backward(self.input_grad)
        return self.input_grad/self.labels.shape[0]


In [636]:
from math import nextafter


class Quantized_backprop_BNN():
    """Four-layer MLP (784 -> 1024 -> 1024 -> 1024 -> 10) with manual backprop.

    The first and last linear layers and their neighbouring activation / loss
    run in full precision (all quantization flags off). The two hidden
    linear layers ``fc_2`` / ``fc_3`` and the ReLUs that follow them use
    4-bit quantization for activations, weights and biases.
    """

    def __init__(self, lr = 0.1, device = None):
        """Build all layers.

        Args:
            lr: learning rate used by ``update``.
            device: device handed to every linear layer; required.

        Raises:
            ValueError: if ``device`` is ``None``.
        """
        if device is None:
            # Fail loudly: returning here would hand back a model without any
            # layers that only breaks later inside forward().
            raise ValueError("Must have device")
        self.device = device
        self.lr = lr

        self.fc_1 = QLinear(input_num=28*28, output_num=1024, qa_flag=False, qw_flag=False, qb_flag=False, qa_bits=8, qw_bits=8, qb_bits=8, device = self.device)
        self.Relu_1 = Relu(qa_flag=False, qa_bits=8)
        self.fc_2 = QLinear(input_num=1024, output_num=1024, qa_flag=True, qw_flag=True, qb_flag=True, qa_bits=4, qw_bits=4, qb_bits=4, device = self.device)
        self.Relu_2 = Relu(qa_flag=True, qa_bits=4)
        self.fc_3 = QLinear(input_num=1024, output_num=1024, qa_flag=True, qw_flag=True, qb_flag=True, qa_bits=4, qw_bits=4, qb_bits=4, device = self.device)
        self.Relu_3 = Relu(qa_flag=True, qa_bits=4)
        self.fc_4 = QLinear(input_num=1024, output_num=10, qa_flag=False, qw_flag=False, qb_flag=False, qa_bits=8, qw_bits=8, qb_bits=8, device = self.device)
        self.Softmax_CrossEntropy = Softmax_CrossEntropy(qa_flag=False, qa_bits=8)
        self.output = None
        self.loss = None

    def forward(self, input, labels):
        """Run a forward pass and store the logits and the loss.

        Args:
            input: batch of images; reshaped to ``(batch, 784)``.
            labels: targets handed to the loss layer (one-hot rows in this
                notebook's training loop).

        Nothing is returned: the logits end up in ``self.output`` and the
        loss in ``self.loss``. Every call overwrites both.
        """
        self.input = torch.reshape(input, (input.shape[0], 28*28))
        output = self.fc_1.forward(self.input)
        output = self.Relu_1.forward(output)
        output = self.fc_2.forward(output)
        output = self.Relu_2.forward(output)
        output = self.fc_3.forward(output)
        output = self.Relu_3.forward(output)
        output = self.fc_4.forward(output)
        self.output = output
        self.loss = self.Softmax_CrossEntropy.forward(output, labels)

    def backward(self):
        """Back-propagate from the loss through every layer in reverse order.

        Each linear layer stores its own parameter gradients; the gradient
        with respect to the network input is discarded.
        """
        next_grad = self.Softmax_CrossEntropy.backward()
        next_grad = self.fc_4.backward(next_grad)
        next_grad = self.Relu_3.backward(next_grad)
        next_grad = self.fc_3.backward(next_grad)
        next_grad = self.Relu_2.backward(next_grad)
        next_grad = self.fc_2.backward(next_grad)
        next_grad = self.Relu_1.backward(next_grad)
        next_grad = self.fc_1.backward(next_grad)

    def update(self):
        """Apply one gradient step with ``self.lr`` to every linear layer."""
        self.fc_1.update(self.lr)
        self.fc_2.update(self.lr)
        self.fc_3.update(self.lr)
        self.fc_4.update(self.lr)


In [637]:
def print_stats(stats):
    """Plot the training loss and the validation accuracy side by side.

    Args:
        stats: dict with the keys ``'train-loss'`` and ``'valid-acc'``, each
            holding a list of ``(iteration, value)`` pairs.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize = (7, 3), dpi = 110)

    # (axis, title, key into stats, upper y-limit) for each panel.
    panels = (
        (loss_ax, "ERM loss", 'train-loss', 4.05),
        (acc_ax, "Valid Acc", 'valid-acc', 1.05),
    )

    for axis, title, key, y_top in panels:
        series = stats[key]
        axis.grid()
        axis.set_title(title)
        axis.set_xlabel("iterations")
        axis.plot([point[0] for point in series], [point[1] for point in series])
        axis.set_ylim(0.0, y_top)

In [638]:
@torch.no_grad()
def get_acc(model, dl, device = DEVICE):
    """Return the classification accuracy of ``model`` over all batches of ``dl``.

    Note that this calls ``model.forward`` for every batch, so
    ``model.output`` and ``model.loss`` are overwritten as a side effect.

    Args:
        model: object exposing ``forward(X, one_hot_labels)`` and ``output``.
        dl: iterable of ``(X, y)`` batches with integer class labels ``y``.
        device: device on which the one-hot label matrix is created.

    Returns:
        float: fraction of samples whose arg-max prediction equals the label.
    """
    hits = []

    for X, y in dl:
        n_samples = X.shape[0]
        one_hot_y = torch.zeros(n_samples, 10, device = device)
        # Vectorised one-hot fill: one indexed assignment instead of a Python
        # loop calling .item() on every label.
        rows = torch.arange(n_samples, device = one_hot_y.device)
        one_hot_y[rows, y.to(one_hot_y.device).long()] = 1
        model.forward(X, one_hot_y)
        hits.append(torch.argmax(model.output, dim = 1) == y)

    hits = torch.cat(hits)
    acc = torch.sum(hits)/len(hits)

    return acc.item()

In [639]:
def run_experiment(model, train_dl, valid_dl, test_dl, max_epochs = 20, device = DEVICE):
    """Train ``model`` and report validation / test accuracy.

    For every training batch the model runs forward, backward and update.
    Every 20th iteration (except iteration 0) the validation accuracy is
    measured and a progress line is printed. After the last epoch the test
    accuracy is printed.

    Args:
        model: object exposing ``forward``, ``backward``, ``update`` and ``loss``.
        train_dl, valid_dl, test_dl: iterables of ``(X, y)`` batches.
        max_epochs: number of passes over ``train_dl``.
        device: device on which the one-hot label matrices are created.

    Returns:
        dict with ``'train-loss'`` and ``'valid-acc'``, each a list of
        ``(iteration, value)`` pairs.
    """
    itr = -1
    stats = {'train-loss' : [], 'valid-acc' : []}
    for epoch in range(max_epochs):
        for X, y in train_dl:
            itr += 1
            n_samples = X.shape[0]
            one_hot_y = torch.zeros(n_samples, 10, device = device)
            # Vectorised one-hot fill: one indexed assignment instead of a
            # Python loop calling .item() on every label.
            rows = torch.arange(n_samples, device = one_hot_y.device)
            one_hot_y[rows, y.to(one_hot_y.device).long()] = 1
            model.forward(X, one_hot_y)
            model.backward()
            model.update()
            # Read the loss now: get_acc() below runs model.forward on the
            # validation batches and overwrites model.loss, which previously
            # made the "[train] loss" in the log a validation-batch loss.
            train_loss = model.loss.item()
            stats['train-loss'].append((itr, train_loss))

            if itr != 0 and itr % 20 == 0:
                valid_acc = get_acc(model, valid_dl, device = device)
                stats['valid-acc'].append((itr, valid_acc))
                s = f"{epoch}:{itr} [train] loss:{train_loss:.3f}, [valid] acc:{valid_acc:.3f}"
                print(s)

    test_acc = get_acc(model, test_dl, device=device)
    print(f"[test] acc:{test_acc:.3f}")
    return stats

In [640]:
# Hyperparameters for the training run in the next cell.
# Number of passes over the training loader.
max_epochs = 50
# Batch size of the training loader.
train_batch = 256
# Batch size shared by the validation and test loaders.
valid_batch = 1024
# Learning rate handed to the model and used in its update step.
lr = 1e-3

In [641]:
# Build the MNIST loaders, create the model, train it and plot the curves.
train_dl, valid_dl, test_dl = get_mnist_dl(batch_size_train=train_batch, batch_size_valid=valid_batch, device = DEVICE)
model = Quantized_backprop_BNN(lr = lr, device=DEVICE)
# run_experiment prints a progress line every 20 iterations and returns the
# recorded (iteration, value) series for the plots.
stats = run_experiment(model, train_dl, valid_dl, test_dl, max_epochs=max_epochs, device = DEVICE)
print_stats(stats)

0:20 [train] loss:2.892, [valid] acc:0.107
0:40 [train] loss:2.671, [valid] acc:0.134
0:60 [train] loss:2.501, [valid] acc:0.164
0:80 [train] loss:2.356, [valid] acc:0.194
0:100 [train] loss:2.232, [valid] acc:0.228
0:120 [train] loss:2.119, [valid] acc:0.267
0:140 [train] loss:2.019, [valid] acc:0.299
0:160 [train] loss:1.927, [valid] acc:0.329
0:180 [train] loss:1.845, [valid] acc:0.361
0:200 [train] loss:1.770, [valid] acc:0.392
1:220 [train] loss:1.700, [valid] acc:0.420
1:240 [train] loss:1.634, [valid] acc:0.447
1:260 [train] loss:1.576, [valid] acc:0.467
1:280 [train] loss:1.521, [valid] acc:0.491
1:300 [train] loss:1.471, [valid] acc:0.509
1:320 [train] loss:1.425, [valid] acc:0.527
1:340 [train] loss:1.382, [valid] acc:0.542
1:360 [train] loss:1.342, [valid] acc:0.557
1:380 [train] loss:1.304, [valid] acc:0.569
1:400 [train] loss:1.270, [valid] acc:0.578
1:420 [train] loss:1.238, [valid] acc:0.588
2:440 [train] loss:1.208, [valid] acc:0.601
2:460 [train] loss:1.180, [valid] ac

KeyboardInterrupt: 