In [460]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import torchvision
from torchvision import transforms
import matplotlib as plt
import cmath

In [461]:
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cpu')
    else:
        return torch.device('cpu')
DEVICE = get_default_device()
print(DEVICE)

cpu


In [462]:
def switch_to_device(dataset, device = None):
    tensor_list_x, tensor_list_y = [], []

    for x, y in dataset:
        tensor_list_x.append(x)
        tensor_list_y.append(y)
    
    X = torch.stack(tensor_list_x)
    Y = torch.tensor(tensor_list_y)

    if DEVICE is not None:
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)
    
    return torch.utils.data.TensorDataset(X, Y)

In [463]:
def get_mnist_dl(batch_size_train = 256, batch_size_valid = 1024, device = None):
    transform = transforms.Compose([transforms.ToTensor()])

    data_train = torchvision.datasets.MNIST('./mnist', train = True, download = True, transform = transform)
    data_train = switch_to_device(data_train, device)
    data_train, data_valid = torch.utils.data.random_split(data_train, [55000, 5000])

    data_test = torchvision.datasets.MNIST('./mnist', train = False, download = True, transform = transform)
    data_test = switch_to_device(data_test, device)

    train_dl = DataLoader(data_train, batch_size=batch_size_train, shuffle=True)
    valid_dl = DataLoader(data_valid, batch_size=batch_size_valid, shuffle=False)
    test_dl = DataLoader(data_test, batch_size=batch_size_valid, shuffle=False)

    return train_dl, valid_dl, test_dl

In [464]:
def possibility_normalize(possi, non_zero_perc, slimness):
    mean = torch.mean(possi)
    std = torch.std(possi)
    if 999 <= mean <= 1001:
        new_possi = torch.full(possi.shape, non_zero_perc).to(possi.device)
    else:
        new_possi = ((possi - mean)/std*slimness) + non_zero_perc
        new_possi = torch.clip(new_possi, 0, 1)
    return new_possi

In [465]:
def delta_generation(possi):
    non_zero = torch.bernoulli(possi)
    possi_ = non_zero/2
    neg_pos = torch.bernoulli(possi_)
    res = non_zero - 2*neg_pos
    return res

In [466]:
def quantization(x, q_bits):
    q_max = 2 ** q_bits - 1
    x_max = torch.max(x)
    x_min = torch.min(x)
    q_x = q_max*(x - x_min)/(x_max - x_min)
    q_x = torch.round(q_x)
    return q_x

In [467]:
def q_f_multiplication(q_a, q_w, b_a, b_w):
    term_1 = 2*(torch.matmul(q_a, q_w))/((2 ** b_a - 1)*(2 ** b_w - 1))
    term_2 = torch.matmul(q_a, torch.ones(q_w.shape)) / (2 ** b_a - 1)
    return term_1 - term_2

In [468]:
class Relu():
    @torch.no_grad()
    def __init__(self, qa_flag = True, qa_bits = 2):
        self.qa_flag = qa_flag
        self.qa_bits = qa_bits

    @torch.no_grad()
    def forward(self, input):
        if self.qa_flag:
            q_input = quantization(input, self.qa_bits)
            res = (torch.abs(q_input) + q_input) / 2.0
        else:
            res = (torch.abs(input) + input) / 2.0
        return res

In [469]:
class Softmax_CrossEntropy():
    @torch.no_grad()
    def __init__(self, qa_flag, qo_flag, qa_bits, qo_bits):
        self.qa_flag = qa_flag
        self.qo_flag = qo_flag
        self.qa_bits = qa_bits
        self.qo_bits = qo_bits

    @torch.no_grad()
    def forward(self, input, labels):
        if self.qa_flag:
            input = quantization(input, self.qa_bits)
        
        max = input.max()
        min = input.min()
        input = (input-min)/(max-min)

        exp_z = torch.exp(input)
        sum_exp_z = torch.sum(exp_z, dim = 1).reshape(input.shape[0], 1)
        softmax_z = exp_z/sum_exp_z
        softmax_z += 1e-6
        
        if self.qo_flag:
            softmax_z = quantization(softmax_z, self.qo_bits)

        loss = torch.sum(-(labels*torch.log(softmax_z))) / input.shape[0]
        
        return loss

            

In [470]:
from functools import total_ordering
from torch import logical_not


class QLinear:
    @torch.no_grad()
    def __init__(self, input_num, output_num, qa_flag = True, qa_bits = 2, qw_bits = 2, qb_bits = 2, device = None):
        if device is None:
            print("Must have device")
            return
        self.device = device
        self.input_num, self.output_num = input_num, output_num
        self.qa_bits, self.qw_bits, self.qb_bits = qa_bits, qw_bits, qb_bits
        self.weights = torch.randint(0, 2**self.qw_bits, (self.input_num, self.output_num)).to(device)
        self.bias = torch.randint(0, 2**self.qb_bits, (1, self.output_num)).to(device)
        self.weights = self.weights.float()
        self.bias = self.bias.float()
        self.w_u = torch.zeros(self.input_num, self.output_num).to(device)
        self.b_u = torch.zeros(1, self.output_num).to(device)
        self.w_pickflag = torch.zeros(self.input_num, self.output_num).to(device)
        self.b_pickflag = torch.zeros(1, self.output_num).to(device)
        self.w_p = torch.full((self.input_num, self.output_num), 1000.).to(device)
        self.b_p = torch.full((1, self.output_num), 1000.).to(device)
        self.qa_flag = qa_flag
    
    @torch.no_grad()
    def forward(self, input):
        self.w_u = torch.zeros(self.input_num, self.output_num).to(self.device)
        self.b_u = torch.zeros(1, self.output_num).to(self.device)
        self.w_p = torch.full((self.input_num, self.output_num), 1000.).to(self.device)
        self.b_p = torch.full((1, self.output_num), 1000.).to(self.device)
        self.w_pickflag = torch.zeros(self.input_num, self.output_num).to(self.device)
        self.b_pickflag = torch.zeros(1, self.output_num).to(self.device)
        
        '''
        if self.qa_flag:
            q_input = quantization(input, self.qa_bits)
            res = torch.matmul(q_input, self.weights) + self.bias
        else:
            res = torch.matmul(input, self.weights) + self.bias
        '''
        if self.qa_flag:
            q_input = quantization(input, self.qa_bits)
            res = q_f_multiplication(q_input, self.weights, self.qa_bits, self.qw_bits) + (2*self.bias)/(2**self.qb_bits - 1) - 1

        else:
            res = torch.matmul(input, ((2*self.weights)/(2**self.qw_bits - 1) - 1)) + (2*self.bias)/(2**self.qb_bits - 1) - 1

        return res

    @torch.no_grad()
    def random_search(self, input, non_zero_percent):
        if non_zero_percent>0.5:
            slimness = 1-non_zero_percent
        else:
            slimness = non_zero_percent
            
        temp_wp = possibility_normalize(torch.abs(self.w_p), non_zero_perc=non_zero_percent, slimness=slimness)
        temp_bp = possibility_normalize(torch.abs(self.b_p), non_zero_perc=non_zero_percent, slimness=slimness)
        self.delta_matrix_w = delta_generation(temp_wp)
        self.delta_matrix_b = delta_generation(temp_bp)
        

        temp_weights = torch.clip(self.weights + self.delta_matrix_w, 0, 2**(self.qw_bits) - 1)
        temp_bias = torch.clip(self.bias + self.delta_matrix_b, 0, 2**(self.qb_bits) - 1)

        self.delta_matrix_w = temp_weights - self.weights
        self.delta_matrix_b = temp_bias - self.bias

        #For debug purpose
        '''
        print("This is the delta matrix")
        print(self.delta_matrix_w)
        '''

        '''
        if self.qa_flag:
            q_input = quantization(input, self.qa_bits)
            res = torch.matmul(q_input, temp_weights) + temp_bias
        else:
            res = torch.matmul(input, temp_weights) + temp_bias
        '''
        if self.qa_flag:
            q_input = quantization(input, self.qa_bits)
            res = q_f_multiplication(q_input, temp_weights, self.qa_bits, self.qw_bits) + (2*temp_bias)/(2**self.qb_bits - 1) - 1
        else:
            res = torch.matmul(input, ((2*temp_weights)/(2**self.qw_bits - 1) - 1)) + (2*temp_bias)/(2**self.qb_bits - 1) - 1
        
        #For debug purpose
        '''
        print("The activations are:")
        print(res)
        '''

        return res

    @torch.no_grad()
    def mid_update(self, delta_loss):
        self.w_u = self.w_u + (delta_loss * self.delta_matrix_w)
        self.b_u = self.b_u + (delta_loss * self.delta_matrix_b)
        '''
        w_ptemp = torch.mul(self.w_pickflag, self.w_p)/(torch.ones(self.w_pickflag.shape).to(self.device) + torch.abs(self.delta_matrix_w)) + (delta_loss) * (self.delta_matrix_w)/(torch.ones(self.w_pickflag.shape).to(self.device) + self.w_pickflag)
        b_ptemp = torch.mul(self.b_pickflag, self.b_p)/(torch.ones(self.b_pickflag.shape).to(self.device) + torch.abs(self.delta_matrix_b)) + (delta_loss) * (self.delta_matrix_b)/(torch.ones(self.b_pickflag.shape).to(self.device) + self.b_pickflag)
        '''
        w_ptemp = torch.mul(self.w_pickflag, self.w_p) + delta_loss * self.delta_matrix_w
        b_ptemp = torch.mul(self.b_pickflag, self.b_p) + delta_loss * self.delta_matrix_b
        self.w_pickflag = (self.w_pickflag + torch.abs(self.delta_matrix_w)) > 0
        self.b_pickflag = (self.b_pickflag + torch.abs(self.delta_matrix_b)) > 0
        self.w_p = w_ptemp + 1000*logical_not(self.w_pickflag)
        self.b_p = b_ptemp + 1000*logical_not(self.b_pickflag)

        #For debug purpose
        '''
        print("The update matrix for w")
        print(self.w_u)
        print("The probability matrix for w:")
        print(self.w_p)
        '''
    
    @torch.no_grad()
    def final_update(self, total_delta_loss_abs, total_delta_loss,c):
        plus_matrix_w = torch.zeros(self.weights.shape).to(self.device)
        minus_matrix_w = torch.zeros(self.weights.shape).to(self.device)
        plus_matrix_w[self.w_u >= (c*total_delta_loss_abs)] = 1
        minus_matrix_w[self.w_u <= -(c*total_delta_loss_abs)] = -1
        '''
        print(self.w_u)
        print(total_delta_loss)
        print(plus_matrix_w)
        print(minus_matrix_w)
        print(self.weights)
        print(plus_matrix_w + minus_matrix_w)
        print(torch.count_nonzero(plus_matrix_w + minus_matrix_w))
        print(total_delta_loss)
        print(self.w_u.max(), self.w_u.min())
        '''
        #For debug purpose
        '''
        print("The total delta loss: ", total_delta_loss)
        print("The number of update parameter: ", torch.count_nonzero(plus_matrix_w + minus_matrix_w))
        print("The maxium and minimum of the update matrix: ")
        print(self.w_u)
        print(self.w_u.max(), self.w_u.min())
        '''
        self.weights = self.weights + plus_matrix_w + minus_matrix_w
        plus_matrix_b = torch.zeros(self.bias.shape).to(self.device)
        minus_matrix_b = torch.zeros(self.bias.shape).to(self.device)
        plus_matrix_b[self.b_u >= (c*total_delta_loss_abs)] = 1
        minus_matrix_b[self.b_u <= -(c*total_delta_loss_abs)] = -1
        self.bias = self.bias + plus_matrix_b + minus_matrix_b

In [471]:
class MLP_BNN():
    def __init__(self, device = None):
        if device is None:
            print("Must have device")
            return
        else:
            self.device = device
        self.fc_1 = QLinear(input_num=28*28, output_num=1024, qa_flag=False, qa_bits=1, qw_bits=1, qb_bits=1, device = device)
        self.Relu_1 = Relu(qa_flag=True, qa_bits=1)
        self.fc_2 = QLinear(input_num=1024, output_num=1024, qa_flag=True, qa_bits=1, qw_bits=1, qb_bits=1, device = device)
        self.Relu_2 = Relu(qa_flag=True, qa_bits=1)
        self.fc_3 = QLinear(input_num=1024, output_num=1024, qa_flag=True, qa_bits=1, qw_bits=1, qb_bits=1, device = device)
        self.Relu_3 = Relu(qa_flag=True, qa_bits=1)
        self.fc_4 = QLinear(input_num=1024, output_num=10, qa_flag=True, qa_bits=1, qw_bits=1, qb_bits=1, device = device)
        self.CrossEntropy = Softmax_CrossEntropy(qa_flag=False, qa_bits=0, qo_flag=False, qo_bits=0)
        self.output = None
        self.loss_base = None
        self.total_delta_loss = 0
    
    def forward(self, input, labels, non_zero_percent=0, update_itr = 0):
        self.total_delta_loss_abs = 0
        self.total_delta_loss = 0
        self.input = torch.reshape(input, (input.shape[0], 28*28))
        output = self.fc_1.forward(self.input)
        output = self.Relu_1.forward(output)
        output = self.fc_2.forward(output)
        output = self.Relu_2.forward(output)
        output = self.fc_3.forward(output)
        output = self.Relu_3.forward(output)
        output = self.fc_4.forward(output)
        self.output = output
        self.loss_base = self.CrossEntropy.forward(output, labels)

        for i in range(update_itr):
            #print("The random searching of the first layer, round: ", i)
            output = self.fc_1.random_search(self.input, non_zero_percent=non_zero_percent)
            output = self.Relu_1.forward(output)
            #print("The random searchig of the second layer, round: ", i)
            output = self.fc_2.random_search(output, non_zero_percent=non_zero_percent)
            output = self.Relu_2.forward(output)

            output = self.fc_3.random_search(output, non_zero_percent=non_zero_percent)
            output = self.Relu_3.forward(output)
            #print("The random searchig of the third layer, round: ", i)
            output = self.fc_4.random_search(output, non_zero_percent=non_zero_percent)
            loss = self.CrossEntropy.forward(output, labels)
            delta_loss = self.loss_base - loss
            self.total_delta_loss_abs = self.total_delta_loss_abs + torch.abs(delta_loss)
            self.total_delta_loss = self.total_delta_loss + delta_loss
            #print("The midupdate of the first layer, round: ", i)
            self.fc_1.mid_update(delta_loss)
            #print("The midupdate of the second layer, round: ", i)
            self.fc_2.mid_update(delta_loss)
            #print("The midupdate of the third layer, round: ", i)
            self.fc_3.mid_update(delta_loss)
            self.fc_4.mid_update(delta_loss)

        

    def update(self, c):
        #print(self.total_delta_loss_abs, self.total_delta_loss)
        #print("The final update of the first layer")
        self.fc_1.final_update(self.total_delta_loss_abs, self.total_delta_loss, c)
        #print("The final update of the second layer")
        self.fc_2.final_update(self.total_delta_loss_abs, self.total_delta_loss, c)
        #print("The final update of the third layer")
        self.fc_3.final_update(self.total_delta_loss_abs, self.total_delta_loss, c)
        self.fc_4.final_update(self.total_delta_loss_abs, self.total_delta_loss, c)

In [472]:
def print_stats(stats):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (7, 3), dpi = 110)
    ax1.grid()
    ax2.grid()

    ax1.set_title("ERM loss")
    ax2.set_title("Valid Acc")

    ax1.setxlabel("iterations")
    ax2.setxlabel("iterations")

    itrs = [x[0] for x in stats['train-loss']]
    loss = [x[1] for x in stats['train-loss']]
    ax1.plot(itrs, loss)

    itrs = [x[0] for x in stats['valid-acc']]
    acc = [x[1] for x in stats['valid-acc']]
    ax2.plot(itrs, acc)

    ax1.set_ylim(0.0, 4.05)
    ax2.set_ylim(0.0, 1.05)


In [473]:
@torch.no_grad()
def get_acc(model, dl, device = DEVICE):
    acc = []

    for X, y in dl:
        one_hot_y = torch.zeros(X.shape[0], 10).to(device)
        one_hot_y[[i for i in range(X.shape[0])], [k.item() for k in y]] = 1
        model.forward(X, one_hot_y)
        acc.append(torch.argmax(model.output, dim = 1) == y)

    acc = torch.cat(acc)
    acc = torch.sum(acc)/len(acc)

    return acc.item()



In [474]:
def run_experiment(model, train_dl, valid_dl, test_dl, max_epochs=20, device=DEVICE):
    itr = -1
    stats = {'train-loss' : [], 'valid-acc' : []}
    for epoch in range(max_epochs):
        for X, y in train_dl:
            itr += 1
            one_hot_y = torch.zeros(X.shape[0], 10).to(device)
            one_hot_y[[i for i in range(X.shape[0])], [k.item() for k in y]] = 1
            model.forward(X, one_hot_y, 0.3, 10)
            #print("Now is the iteration :", itr)
            '''
            print(model.output)
            '''
            model.update(0.5)
            stats['train-loss'].append((itr, model.loss_base.item()))

            if itr != 0 and itr % 20 == 0:
                valid_acc = get_acc(model, valid_dl, device = device)
                stats['valid-acc'].append((itr, valid_acc))
                s = f"{epoch}:{itr} [train] loss:{model.loss_base.item():.3f}, [valid] acc:{valid_acc:.3f}"
                print(s)
    
    test_acc = get_acc(model, test_dl, device=device)
    print(f"[test] acc:{test_acc:.3f}")
    return stats

            

In [475]:
max_epochs = 10
train_batch = 256
valid_batch = 1024

In [476]:
train_dl, valid_dl, test_dl = get_mnist_dl(batch_size_train=train_batch, batch_size_valid=valid_batch, device = DEVICE)
model = MLP_BNN(device=DEVICE)
stats = run_experiment(model, train_dl, valid_dl, test_dl, max_epochs=max_epochs, device = DEVICE)
print_stats(stats)

0:20 [train] loss:2.294, [valid] acc:0.111
0:40 [train] loss:2.304, [valid] acc:0.101
0:60 [train] loss:2.310, [valid] acc:0.100
0:80 [train] loss:2.304, [valid] acc:0.089
0:100 [train] loss:2.311, [valid] acc:0.086
0:120 [train] loss:2.312, [valid] acc:0.082
0:140 [train] loss:2.310, [valid] acc:0.085
0:160 [train] loss:2.310, [valid] acc:0.084
0:180 [train] loss:2.311, [valid] acc:0.086
0:200 [train] loss:2.313, [valid] acc:0.085
1:220 [train] loss:2.308, [valid] acc:0.098
1:240 [train] loss:2.303, [valid] acc:0.101
1:260 [train] loss:2.300, [valid] acc:0.094
1:280 [train] loss:2.301, [valid] acc:0.092
1:300 [train] loss:2.299, [valid] acc:0.092
1:320 [train] loss:2.300, [valid] acc:0.098
1:340 [train] loss:2.298, [valid] acc:0.101
1:360 [train] loss:2.300, [valid] acc:0.102
1:380 [train] loss:2.298, [valid] acc:0.104
1:400 [train] loss:2.296, [valid] acc:0.100
1:420 [train] loss:2.297, [valid] acc:0.096
2:440 [train] loss:2.302, [valid] acc:0.099
2:460 [train] loss:2.300, [valid] ac