In [49]:

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import os
import pandas as pd
from time import time
import matplotlib.pyplot as plt
from collections import defaultdict
from models.binarized_modules import binarized
# from binarized_modules import  BinarizeLinear,BinarizeConv2d

In [4]:
cuda = False
# cuda = True

In [5]:
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

In [6]:
batch_size = 64
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)

In [7]:
test_batch_size=1000

In [8]:
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size, shuffle=False, **kwargs)

In [9]:
# from mnist_bnn import Net
from models.lenet_5 import BinarizedLeNet5_BN as Net

model = Net()
if cuda:
    torch.cuda.set_device(0)
    model.cuda()


In [15]:
# model_path = os.path.join(models_path,f"epoch_7.pth")
model_idx = 1
models_path = os.path.abspath(f"/home/earapidis/BinarizedNN/saved_models/lenet_5/model_{model_idx}")
model_path = os.path.join(models_path,f"epoch_15.pth")
# model_path = os.path.join(models_path,f"best.pth")
model = Net()
model.load_state_dict(torch.load(model_path))
if cuda:
    torch.cuda.set_device(0)
    model.cuda()

In [16]:
model

BinarizedLeNet5_BN(
  (conv1): BinarizeConv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh1): Hardtanh(min_val=-1.0, max_val=1.0)
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): BinarizeConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh2): Hardtanh(min_val=-1.0, max_val=1.0)
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (fc1): BinarizeLinear(in_features=256, out_features=120, bias=True)
  (bn_fc1): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh3): Hardtanh(min_val=-1.0, max_val=1.0)
  (fc2): BinarizeLinear(in_features=120, out_features=84, bias=True)
  (bn_fc2): BatchNorm1d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh4): Hardtanh(min_val=-1.0, max_val=1.0)
  (fc3): BinarizeLinear(in_features=84

In [None]:
model.conv1.weight.shape

torch.Size([16, 6, 5, 5])

In [116]:
filters = torch.randn(8, 4, 3, 3)
inputs = torch.randn(10, 4, 5, 5)
inputs = binarized(inputs)
filters = binarized(filters)

In [117]:
# With square kernels and equal stride

padding = 1



output = F.conv2d(inputs, filters, padding=padding)
output.shape

torch.Size([10, 8, 5, 5])

In [118]:
import torch
import torch.nn.functional as F

def conv2d_loops(x, w, padding=0):
    N, Cin, H, W     = x.shape
    Cout, _, Kh, Kw  = w.shape
    Hout = H + 2*padding - Kh + 1
    Wout = W + 2*padding - Kw + 1

    # Zero-pad input
    x_p = torch.zeros((N, Cin, H + 2*padding, W + 2*padding))
    x_p[:, :, padding:padding+H, padding:padding+W] = x

    y = torch.zeros((N, Cout, Hout, Wout))
    for n in range(N):
        for co in range(Cout):
            for i in range(Hout):
                for j in range(Wout):
                    acc = 0.0
                    for ci in range(Cin):
                        acc += torch.sum(x_p[n, ci, i:i+Kh, j:j+Kw] * w[co, ci])
                    y[n, co, i, j] = acc
    return y

ref = F.conv2d(inputs, filters, padding=padding)
out_loops  = conv2d_loops(inputs, filters, padding)
print("Reference shape:", ref.shape)
print("Max abs diff (loops):",  (ref - out_loops).abs().max().item())


Reference shape: torch.Size([10, 8, 5, 5])
Max abs diff (loops): 0.0


In [119]:
def conv2d_tiles(x, w, Num_rows, Num_Columns, padding=0):
    N, Cin, H, W     = x.shape
    Cout, _, Kh, Kw  = w.shape
    Hout = H + 2*padding - Kh + 1
    Wout = W + 2*padding - Kw + 1



    # Zero-pad input
    x_p = torch.zeros((N, Cin, H + 2*padding, W + 2*padding))
    x_p[:, :, padding:padding+H, padding:padding+W] = x

    num_tiles_rows = int(np.ceil((Kw * Kh*Cin)/ Num_rows))
    num_tiles_columns = int(np.ceil(Cout / Num_Columns))

    crossbar_weights = np.zeros((num_tiles_rows,num_tiles_columns,Num_rows, Num_Columns))
    kernel_size = Kh * Kw
    tile_i_idx = 0
    tile_j_idx = 0
    cin_per_cross = Num_rows // kernel_size 

    print(crossbar_weights.shape)
    for co in range(Cout):
        tile_j_idx = co // Num_Columns
        for ci in range(Cin):
            id = (ci%cin_per_cross)
            tile_row_start = id*kernel_size
            tile_row_end = (id+1)*kernel_size
            tile_i_idx = ci // cin_per_cross
            flat_w = w[co, ci].view(-1).numpy()
            crossbar_weights[tile_i_idx, tile_j_idx, tile_row_start:tile_row_end, co] = flat_w
    
    input_vec = np.zeros((N,Hout,Wout,num_tiles_rows,Num_rows))
    print(input_vec.shape)
    for n in range(N):
        for ci in range(Cin):
            for i in range(Hout):
                for j in range(Wout):
                    id = (ci%cin_per_cross)
                    tile_row_start = id*kernel_size
                    tile_row_end = (id+1)*kernel_size
                    tile_i_idx = ci // cin_per_cross
                    flat_input = torch.flatten(x_p[n, ci, i:i+Kh, j:j+Kw]).numpy()
                    # print(flat_input.shape)
                    input_vec[n,i,j,tile_i_idx, tile_row_start:tile_row_end] = flat_input

    output = np.zeros((N,Cout,Hout,Wout))

    for n in range(N):
        for i in range(Hout):
            for j in range(Wout):
                inp = input_vec[n,i,j,:,:]
                # print(inp.shape)
                w = crossbar_weights.reshape(num_tiles_rows, Num_rows, Num_Columns)
                # print(w.shape)

                intermidiate_out = np.zeros((num_tiles_rows, Num_Columns))
                for idx, vec in enumerate(inp):
                    out = np.dot(vec, w[idx,:,:])
                    intermidiate_out[idx,:] = out
                # print(intermidiate_out.shape)
                cout_outs = np.sum(intermidiate_out, axis=0)
                cout = cout_outs[:Cout]
                output[n,:,i,j] = cout
    return output
Num_rows = 32
Num_Columns = 32                    
out_loops  = conv2d_tiles(inputs, filters,Num_rows,Num_Columns, padding)
print(out_loops-ref.numpy())
# print("Output shape:", out_loops[0].shape, out_loops[1].shape)

(2, 1, 32, 32)
(10, 5, 5, 2, 32)
[[[[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  ...

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]]


 [[[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]
   [0. 0. 0. 0. 0.]]

  ...

  [[0. 0. 0. 0. 0.]
   [0. 0. 0. 

In [120]:
inputs.size()

torch.Size([10, 4, 5, 5])