In [173]:

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import os
import pandas as pd
from time import time
import matplotlib.pyplot as plt
from collections import defaultdict
from models.binarized_modules import binarized
# from binarized_modules import  BinarizeLinear,BinarizeConv2d

In [174]:
import sys
sys.path.append('/home/earapidis/Fast-Crossbar-Sim/python')
from crossbar import VectorSim, ParallelSim, _task
from tqdm import tqdm

In [175]:
# Show the current working directory; relative paths such as 'data' resolve against it.
os.getcwd()

'/home/earapidis/BinarizedNN'

In [176]:
# Switch for GPU execution; it gates the DataLoader kwargs and the model.cuda() calls below.
cuda = False
# cuda = True

In [177]:
# Extra DataLoader keyword arguments: worker / pinned-memory settings are
# only supplied when CUDA is enabled, otherwise the defaults are used.
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

In [178]:
batch_size = 64

# Shuffled MNIST training batches (downloaded into ./data when missing);
# each image is converted to a tensor and normalized with mean 0.1307 / std 0.3081.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=batch_size,
    shuffle=True,
    **kwargs,
)

In [179]:
# Number of samples per batch for the test DataLoader.
test_batch_size=1000

In [180]:
# MNIST test batches in fixed order, with the same tensor conversion and
# normalization constants as the training loader.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=test_batch_size,
    shuffle=False,
    **kwargs,
)

In [181]:
# from mnist_bnn import Net
from models.lenet_5 import BinarizedLeNet5_BN as Net

# Construct the network (no checkpoint is loaded in this cell) and,
# when the cuda flag is set, select GPU 0 and move the model onto it.
model = Net()
if cuda:
    torch.cuda.set_device(0)
    model.cuda()


In [182]:
# model_path = os.path.join(models_path,f"epoch_7.pth")
# Select which saved run to load; the checkpoint directory is derived from model_idx.
model_idx = 1
models_path = os.path.abspath(f"/home/earapidis/BinarizedNN/saved_models/lenet_5/model_{model_idx}")
model_path = os.path.join(models_path, "epoch_15.pth")
# model_path = os.path.join(models_path,f"best.pth")

# Rebuild the network and restore its parameters from the checkpoint.
model = Net()
# map_location="cpu" remaps every stored tensor to the CPU while unpickling, so a
# checkpoint written from a GPU run still loads when cuda is False or no GPU is
# present. load_state_dict copies the values into the model's own parameters, and
# model.cuda() below moves them to the GPU afterwards if requested.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
if cuda:
    torch.cuda.set_device(0)
    model.cuda()

In [183]:
# Display the module layout of the loaded network.
model

BinarizedLeNet5_BN(
  (conv1): BinarizeConv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh1): Hardtanh(min_val=-1.0, max_val=1.0)
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): BinarizeConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh2): Hardtanh(min_val=-1.0, max_val=1.0)
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (fc1): BinarizeLinear(in_features=256, out_features=120, bias=True)
  (bn_fc1): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh3): Hardtanh(min_val=-1.0, max_val=1.0)
  (fc2): BinarizeLinear(in_features=120, out_features=84, bias=True)
  (bn_fc2): BatchNorm1d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh4): Hardtanh(min_val=-1.0, max_val=1.0)
  (fc3): BinarizeLinear(in_features=84

In [184]:
# model.conv2.weight

In [185]:
# Shape of the second convolution's weight tensor.
model.conv2.weight.data.shape

torch.Size([16, 6, 5, 5])

In [186]:
# Take the conv2 weights of the loaded model as the filter bank and show their shape.
# NOTE(review): `filters` is reassigned with random values in the next cell.
filters = model.conv2.weight.data
filters.shape

torch.Size([16, 6, 5, 5])

In [187]:
# Filter-bank geometry for the synthetic convolution experiment.
COUT, CIN = 50, 6
Kh, Kw = 5, 5

# Random filter bank passed through binarized()
# (presumably yielding +/-1 values -- confirm against models.binarized_modules).
filters = binarized(torch.randn(COUT, CIN, Kh, Kw))

# COUT, CIN, Kh, Kw = filters.shape

# Input geometry, batch size and crossbar dimensions.
Hi, Wi = 24, 24
padding = 0
N = 1
Num_rows, Num_Columns = 32, 32

# Output size of a stride-1 convolution.
Hout = Hi + 2*padding - Kh + 1
Wout = Wi + 2*padding - Kw + 1


# filters = torch.randn(COUT, CIN, Kh, Kw)

# Random input batch, binarized the same way as the filters.
inputs = binarized(torch.randn(N, CIN, Hi, Wi))
filters_b = binarized(filters)

In [188]:
# Shape check of the binarized filter bank.
filters_b.shape

torch.Size([50, 6, 5, 5])

In [203]:
def checkerboard_last_cols(arr: torch.Tensor, C: int) -> None:
    """
    Overwrite the last C columns of `arr` in-place with a checkerboard pattern of 0s and 1s.

    The value written at (row, col) is ``(row + col) % 2`` where ``col`` is the
    absolute column index in `arr`, so the pattern stays aligned with the full
    tensor regardless of how many columns are overwritten.

    Parameters
    ----------
    arr : torch.Tensor
        Input 2D tensor of shape (n, m). Modified in place.
    C : int
        Number of columns at the right edge to turn into a checkerboard.
        ``C == 0`` leaves `arr` untouched.

    Raises
    ------
    ValueError
        If `C` is negative or larger than the number of columns m.
    """
    n, m = arr.shape
    if C < 0:
        raise ValueError("C cannot be negative.")
    if C > m:
        raise ValueError("C cannot be larger than the number of columns m.")
    if C == 0:
        # Nothing to overwrite. Without this guard the slice ``-0:`` would
        # select *every* column and the assignment would fail on shape.
        return

    # Row / absolute-column indices, created on the same device as `arr`.
    rows = torch.arange(n, device=arr.device).view(-1, 1)          # shape: (n, 1)
    cols = torch.arange(m - C, m, device=arr.device).view(1, -1)   # shape: (1, C)

    # Broadcasting (n, 1) + (1, C) gives the (n, C) checkerboard.
    pattern = (rows + cols) % 2

    # Explicit start index instead of ``-C:`` so the slice is always C wide.
    arr[:, m - C:] = pattern

# Example: on a 6x8 tensor of ones, overwrite the last 5 columns and print the result.
A = torch.ones((6, 8), dtype=torch.int)
checkerboard_last_cols(A, C=5)
print(A)


tensor([[1, 1, 1, 1, 0, 1, 0, 1],
        [1, 1, 1, 0, 1, 0, 1, 0],
        [1, 1, 1, 1, 0, 1, 0, 1],
        [1, 1, 1, 0, 1, 0, 1, 0],
        [1, 1, 1, 1, 0, 1, 0, 1],
        [1, 1, 1, 0, 1, 0, 1, 0]], dtype=torch.int32)


In [190]:
def compliment(x):
    """Split `x` into a (pos, neg) pair of tensors.

    `pos` is a copy of `x` with every -1 replaced by 0; `neg` is ``-1 * x``
    with every -1 replaced by 0. For a tensor holding only +1 / -1 this
    yields two 0/1 masks marking the +1 and the -1 positions respectively.
    The argument itself is not modified.

    Parameters
    ----------
    x : torch.Tensor
        Tensor to split.

    Returns
    -------
    tuple of torch.Tensor
        ``(pos, neg)``, each with the shape of `x`.
    """
    flipped = -1 * x
    # masked_fill is out-of-place, so both results are fresh tensors.
    pos = x.masked_fill(x == -1, 0)
    neg = flipped.masked_fill(flipped == -1, 0)
    return pos, neg

In [191]:
# Split the binarized inputs and filters into the (pos, neg) pairs returned by compliment().
pos_inputs, neg_inputs = compliment(inputs)
pos_filters, neg_filters = compliment(filters_b)

print(pos_filters.shape)

torch.Size([50, 6, 5, 5])


In [192]:
kernel_size = Kh*Kw
new_weights = torch.empty((COUT, CIN, 2*kernel_size))

# Flatten each Kh x Kw kernel into a vector of length kernel_size.
pos = pos_filters.reshape(pos_filters.shape[0], pos_filters.shape[1], -1)
neg = neg_filters.reshape(neg_filters.shape[0], neg_filters.shape[1], -1)

# Interleave along the last axis: even slots take the pos entries,
# odd slots take the matching neg entries.
new_weights[:, :, 0::2] = pos
new_weights[:, :, 1::2] = neg
new_weights.shape

torch.Size([50, 6, 50])

In [193]:
new_inputs = torch.empty((N, Hout, Wout, CIN, 2*kernel_size))
for ii in range(Hout):
    for jj in range(Wout):
        # Kh x Kw window at output position (ii, jj), flattened per channel.
        pos_patch = pos_inputs[:, :, ii:ii+Kh, jj:jj+Kw]
        neg_patch = neg_inputs[:, :, ii:ii+Kh, jj:jj+Kw]
        pos_patch = pos_patch.reshape(pos_patch.shape[0], pos_patch.shape[1], -1)
        neg_patch = neg_patch.reshape(neg_patch.shape[0], neg_patch.shape[1], -1)
        # Same interleaving as the weights: pos in even slots, neg in odd slots.
        new_inputs[:, ii, jj, :, 0::2] = pos_patch
        new_inputs[:, ii, jj, :, 1::2] = neg_patch

In [194]:
# Shape check of the interleaved input patches.
new_inputs.shape

torch.Size([1, 20, 20, 6, 50])

In [195]:
def get_output(I, Kh, Kw, CIN):
    """Return ``2*I - Kh*Kw*CIN``.

    NOTE(review): in this notebook `I` is the accumulation over the
    interleaved pos/neg encoding, and this affine map is what makes the
    result equal the +/-1 convolution reference -- confirm the intended
    interpretation with the author.

    Parameters
    ----------
    I : number or torch.Tensor
        Accumulated value(s) to rescale.
    Kh, Kw : int
        Kernel height and width.
    CIN : int
        Number of input channels.
    """
    return 2*I - Kh*Kw*CIN

In [196]:
# Reference result: plain 2D convolution of the binarized inputs with the binarized filters.
ref = nn.functional.conv2d(inputs, filters_b, padding=0)
# ref

In [197]:
out_2d = torch.empty((N, COUT, Hout, Wout))

for ii in range(Hout):
    for jj in range(Wout):
        # Interleaved patch for this output position; the 1D kernel spans its
        # full length, so conv1d produces a single value per output channel.
        one_pixel = new_inputs[:, ii, jj, :, :]
        out_2d[:, :, ii, jj] = F.conv1d(one_pixel, new_weights).squeeze(-1)

# Rescale the accumulated values and compare against the conv2d reference.
out_2d = get_output(out_2d, Kh, Kw, CIN)
torch.equal(out_2d, ref)

True

In [198]:
# Shape check of the interleaved weights before tiling.
new_weights.shape

torch.Size([50, 6, 50])

In [199]:
# math.ceil is used by the tiling code in the cells below.
import math
math.ceil

<function math.ceil(x, /)>

In [206]:
def torch_conv_2d(crossbar_inputs,crossbar_weights,_COUT_,Total_dim):
    """Accumulate per-tile matrix products into a convolution output.

    Every row tile ``cy`` of the inputs is multiplied with every weight tile
    ``(cy, cx)``. Products sharing the same column tile ``cx`` are summed over
    ``cy`` and written to the output channels that this column tile serves.

    Column tile ``cx`` serves channels
    ``[cx * (_COUT_ // crossbar_x), (cx + 1) * (_COUT_ // crossbar_x))``;
    the last tile serves everything up to ``min(Total_dim, _COUT_)``. Only the
    leading columns of a tile that map to real channels are read.

    Side effect: the unused trailing columns of each weight tile are
    overwritten in place with a checkerboard via ``checkerboard_last_cols``.
    This is done once per tile, and skipped for tiles without unused columns.

    Parameters
    ----------
    crossbar_inputs : torch.Tensor
        Shape (N, HOUT, WOUT, crossbar_y, Num_rows).
    crossbar_weights : torch.Tensor
        Shape (crossbar_y, crossbar_x, Num_rows, Num_Columns). Modified in place.
    _COUT_ : int
        Number of output channels of the result.
    Total_dim : int
        End index (exclusive) of the channel range served by the last column tile.

    Returns
    -------
    torch.Tensor
        Shape (N, _COUT_, HOUT, WOUT).

    Raises
    ------
    ValueError
        If a column tile would have to serve more channels than it has columns.
    """
    _N_, _HOUT_, _WOUT_, crossbar_y, Num_rows = crossbar_inputs.shape
    _, crossbar_x, _, Num_Columns = crossbar_weights.shape
    output_conv_2d = torch.zeros(_N_, _COUT_, _HOUT_, _WOUT_)

    columns_per_crossbar = _COUT_ // crossbar_x

    # Channel span [start, end) served by each column tile. Taking the read
    # width from the span keeps source and destination slices the same size,
    # also when the last tile carries the remainder.
    column_spans = []
    for cx in range(crossbar_x):
        column_start_idx = cx * columns_per_crossbar
        if cx == crossbar_x - 1:
            column_end_idx = min(Total_dim, _COUT_)
        else:
            column_end_idx = column_start_idx + columns_per_crossbar
        if column_end_idx - column_start_idx > Num_Columns:
            raise ValueError(
                f"column tile {cx} must serve {column_end_idx - column_start_idx} "
                f"channels but only has {Num_Columns} columns"
            )
        column_spans.append((column_start_idx, column_end_idx))

    # The dummy-column fill does not depend on the output pixel, so do it once
    # per tile instead of once per pixel. crossbar_weights[cy, cx] is a view,
    # hence the caller's tensor is what gets modified.
    for cy in range(crossbar_y):
        for cx, (column_start_idx, column_end_idx) in enumerate(column_spans):
            unused_columns = Num_Columns - (column_end_idx - column_start_idx)
            if unused_columns > 0:
                checkerboard_last_cols(crossbar_weights[cy, cx], unused_columns)

    for ii in range(_HOUT_):
        for jj in range(_WOUT_):
            mac_out_columns = torch.zeros((_N_, _COUT_))
            for cy in range(crossbar_y):
                tmp_x = crossbar_inputs[:, ii, jj, cy, :]
                for cx, (column_start_idx, column_end_idx) in enumerate(column_spans):
                    _out_ = torch.matmul(tmp_x, crossbar_weights[cy, cx])
                    # Only the leading columns hold real weights; the rest is filler.
                    width = column_end_idx - column_start_idx
                    mac_out_columns[:, column_start_idx:column_end_idx] += _out_[:, :width]
            output_conv_2d[:, :, ii, jj] = mac_out_columns
    return output_conv_2d


In [207]:
def conv2d_tile(x,w,Num_rows,Num_Columns):
    """Tile flattened inputs/weights onto fixed-size crossbars and evaluate them.

    The last two axes of `x` and `w` are flattened into one "input" axis. That
    axis is cut into ``ceil(total_inputs / Num_rows)`` row tiles, and the
    output channels into ``ceil(COUT / Num_Columns)`` column tiles. Chunks are
    equally sized (ceil division) with a shorter final chunk when the split is
    not exact, so a chunk never exceeds the crossbar size. Unused rows and
    columns of a tile stay zero. The tiles are then passed to ``torch_conv_2d``.

    Parameters
    ----------
    x : torch.Tensor
        Shape (N, HOUT, WOUT, CIN, kernel_size).
    w : torch.Tensor
        Shape (COUT, CIN, kernel_size).
    Num_rows, Num_Columns : int
        Crossbar dimensions.

    Returns
    -------
    torch.Tensor
        Shape (N, COUT, HOUT, WOUT).
    """
    _N_, _HOUT_, _WOUT_, _CIN_, _kernel_size_ = x.shape
    _COUT_, _CIN_, _kernel_size_ = w.shape

    whole_input_size = _CIN_*_kernel_size_
    print(f"total inputs: {whole_input_size}")
    crossbar_y = math.ceil(whole_input_size/Num_rows)
    crossbar_x = math.ceil(_COUT_/Num_Columns)

    crossbar_weights = torch.zeros((crossbar_y, crossbar_x, Num_rows, Num_Columns))
    crossbar_inputs = torch.zeros((_N_, _HOUT_, _WOUT_, crossbar_y, Num_rows))

    flatten_x = x.reshape(*x.shape[:-2], -1)       # (N, HOUT, WOUT, total_inputs)
    print(flatten_x.shape)
    flatten_w = w.reshape(*w.shape[:-2], -1).T     # (total_inputs, COUT)
    print(flatten_w.shape)
    print(crossbar_inputs.shape)
    print(f"weights shape: {crossbar_weights.shape}")
    print(f"inputs shape: {crossbar_inputs.shape}")

    # Ceil-sized chunks: identical to the floor-sized split when the division
    # is exact, and otherwise guaranteed to fit (chunk <= crossbar dimension)
    # with a shorter, non-empty final chunk.
    rows_per_crossbar = math.ceil(whole_input_size/crossbar_y)
    columns_per_crossbar = math.ceil(_COUT_/crossbar_x)

    for ii in range(crossbar_y):
        row_start_idx = ii*rows_per_crossbar
        row_end_idx = min(row_start_idx + rows_per_crossbar, whole_input_size)
        rows_here = row_end_idx - row_start_idx
        crossbar_inputs[:, :, :, ii, :rows_here] = flatten_x[:, :, :, row_start_idx:row_end_idx]

        for jj in range(crossbar_x):
            column_start_idx = jj*columns_per_crossbar
            column_end_idx = min(column_start_idx + columns_per_crossbar, _COUT_)
            columns_here = column_end_idx - column_start_idx
            crossbar_weights[ii, jj, :rows_here, :columns_here] = (
                flatten_w[row_start_idx:row_end_idx, column_start_idx:column_end_idx]
            )

    # Present a channel count that divides evenly over the column tiles, so
    # every tile serves exactly columns_per_crossbar channels; the padded
    # channels come from all-zero weight columns and are dropped again below.
    padded_cout = crossbar_x*columns_per_crossbar
    output_conv_2d = torch_conv_2d(crossbar_inputs, crossbar_weights, padded_cout, padded_cout)
    if padded_cout != _COUT_:
        output_conv_2d = output_conv_2d[:, :_COUT_]
    return output_conv_2d
# Run the tiled path on the interleaved tensors, rescale with get_output,
# and compare with the conv2d reference.
output_conv_2d=conv2d_tile(new_inputs,new_weights,Num_rows,Num_Columns)
output_conv_2d = get_output(output_conv_2d,Kh,Kw,CIN)
torch.equal(output_conv_2d,ref)

total inputs: 300
torch.Size([1, 20, 20, 300])
torch.Size([300, 50])
torch.Size([1, 20, 20, 10, 32])
weights shape: torch.Size([10, 2, 32, 32])
inputs shape: torch.Size([1, 20, 20, 10, 32])


True