In [58]:

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import os
import pandas as pd
from time import time
import matplotlib.pyplot as plt
from collections import defaultdict
from models.binarized_modules import binarized
from concurrent.futures import ProcessPoolExecutor

# from binarized_modules import  BinarizeLinear,BinarizeConv2d

In [59]:
import sys
sys.path.append('/home/earapidis/Fast-Crossbar-Sim/python')
from crossbar import VectorSim, ParallelSim, _task
from tqdm import tqdm

In [60]:
os.getcwd()

'/home/earapidis/BinarizedNN'

In [61]:
cuda = False

In [62]:
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

In [63]:
test_batch_size=1

In [64]:
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size, shuffle=False, **kwargs)

In [65]:
# from mnist_bnn import Net
from models.lenet_5 import BinarizedLeNet5_BN as Net

model = Net()
if cuda:
    torch.cuda.set_device(0)
    model.cuda()


In [66]:
# model_path = os.path.join(models_path,f"epoch_7.pth")
model_idx = 1
models_path = os.path.abspath(f"/home/earapidis/BinarizedNN/saved_models/lenet_5/model_{model_idx}")
model_path = os.path.join(models_path,f"epoch_15.pth")
# model_path = os.path.join(models_path,f"best.pth")
model = Net()
model.load_state_dict(torch.load(model_path))
if cuda:
    torch.cuda.set_device(0)
    model.cuda()

In [67]:
model.eval()

BinarizedLeNet5_BN(
  (conv1): BinarizeConv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh1): Hardtanh(min_val=-1.0, max_val=1.0)
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): BinarizeConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh2): Hardtanh(min_val=-1.0, max_val=1.0)
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (fc1): BinarizeLinear(in_features=256, out_features=120, bias=True)
  (bn_fc1): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh3): Hardtanh(min_val=-1.0, max_val=1.0)
  (fc2): BinarizeLinear(in_features=120, out_features=84, bias=True)
  (bn_fc2): BatchNorm1d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh4): Hardtanh(min_val=-1.0, max_val=1.0)
  (fc3): BinarizeLinear(in_features=84

In [99]:
images, labels = next(iter(test_loader))
print(f"image shape: {images.shape}")

image shape: torch.Size([1, 1, 28, 28])


In [68]:

x=model.conv1(images)
x=model.bn1(x)
x=model.htanh1(x)
x=model.pool1(x)
x = x.detach()

image shape: torch.Size([1, 1, 28, 28])


In [100]:
inputs = images
# inputs = x
inputs = binarized(inputs)
inputs.shape

torch.Size([1, 1, 28, 28])

In [101]:
filters = model.conv1.weight.data
filters_b = binarized(filters)
filters_b.shape

torch.Size([6, 1, 5, 5])

In [102]:
# COUT = 50
# CIN = 6
# Kh = 5
# Kw = 5
# filters = torch.randn(COUT,CIN,Kh,Kw)
# filters = binarized(filters)

COUT, CIN, Kh, Kw = filters.shape
N, CIN, Hi,Wi = inputs.shape
# Hi = 24
# Wi = 24
# N = 1
padding = 0 
Num_rows = 32
Num_Columns = 32                    
Hout = Hi + 2*padding - Kh + 1
Wout = Wi + 2*padding - Kw + 1


# filters = torch.randn(COUT, CIN, Kh, Kw)


In [103]:
def checkerboard_last_cols(arr: torch.Tensor, C: int) -> None:
    """
    Overwrite the last C columns of `arr` in-place with a checkerboard pattern of 0s and 1s.

    Parameters
    ----------
    arr : torch.Tensor
        Input 2D tensor of shape (n, m).
    C : int
        Number of columns at the right edge to turn into a checkerboard.
    """
    n, m = arr.shape
    if C > m:
        raise ValueError("C cannot be larger than the number of columns m.")
    
    # Create row and column indices
    rows = torch.arange(n).view(-1, 1)                # shape: (n, 1)
    cols = torch.arange(m - C, m).view(1, -1)         # shape: (1, C)

    # Generate checkerboard pattern
    pattern = (rows + cols) % 2

    # Apply the pattern to the last C columns in-place
    arr[:, -C:] = pattern

# Example
A = torch.ones((6, 8), dtype=torch.int)
checkerboard_last_cols(A, C=5)
print(A)


tensor([[1, 1, 1, 1, 0, 1, 0, 1],
        [1, 1, 1, 0, 1, 0, 1, 0],
        [1, 1, 1, 1, 0, 1, 0, 1],
        [1, 1, 1, 0, 1, 0, 1, 0],
        [1, 1, 1, 1, 0, 1, 0, 1],
        [1, 1, 1, 0, 1, 0, 1, 0]], dtype=torch.int32)


In [104]:
def compliment(x):
    x = x.clone()
    neg = -1*x
    pos = x

    pos[pos==-1] = 0
    neg[neg==-1] = 0
    return pos, neg

In [105]:
pos_inputs, neg_inputs = compliment(inputs)
pos_filters, neg_filters = compliment(filters_b)

print(pos_filters.shape)
print(pos_inputs.shape)

torch.Size([6, 1, 5, 5])
torch.Size([1, 1, 28, 28])


In [106]:
new_weights = torch.empty((COUT,CIN,2*(Kh*Kw)))

pos = pos_filters.reshape(pos_filters.shape[0],pos_filters.shape[1],-1)
neg = neg_filters.reshape(neg_filters.shape[0],neg_filters.shape[1],-1)
kernel_size = Kh*Kw

for i in range(kernel_size):

    new_weights[:,:,2*i] = pos[:,:,i]
    new_weights[:,:,2*i+1] = neg[:,:,i]
new_weights.shape

torch.Size([6, 1, 50])

In [107]:
print(Hout,Wout)

24 24


In [108]:
new_inputs = torch.empty((N,Hout,Wout,CIN,2*kernel_size))
for ii in range(Hout):
    for jj in range(Wout):
        _pos_ = pos_inputs[:,:,ii:ii+Kh,jj:jj+Kw]
        _neg_ = neg_inputs[:,:,ii:ii+Kh,jj:jj+Kw]
        # print(_pos_.shape)
        _pos_ = _pos_.reshape(_pos_.shape[0],_pos_.shape[1],-1)
        _neg_ = _neg_.reshape(_neg_.shape[0],_neg_.shape[1],-1)
        one_pixel = torch.empty(_pos_.shape[0],_pos_.shape[1],2*kernel_size)
        for z in range(kernel_size):
            one_pixel[:,:,2*z]=_pos_[:,:,z]
            one_pixel[:,:,2*z+1]=_neg_[:,:,z]
        new_inputs[:,ii,jj,:,:] = one_pixel

In [109]:
new_inputs.shape

torch.Size([1, 24, 24, 1, 50])

In [110]:
def get_output(I,Kh, Kw, CIN):
    out = 2*I - Kh*Kw*CIN
    return out

In [111]:
ref = nn.functional.conv2d(inputs, filters_b, padding=0)
# ref

In [112]:
out_2d = torch.empty((N,COUT,Hout,Wout))

for ii in range(Hout):
    for jj in range(Wout):
        one_pixel = new_inputs[:,ii,jj,:,:]
        # print(one_pixel.shape)
        out_1d = F.conv1d(one_pixel,new_weights)
        out_1d = out_1d.squeeze(-1)
        # print(out_1d.shape)
        out_2d[:,:,ii,jj] = out_1d
out_2d = get_output(out_2d,Kh,Kw,CIN)
torch.equal(out_2d,ref)

True

In [113]:
new_weights.shape

torch.Size([6, 1, 50])

In [114]:
import math
math.ceil

<function math.ceil(x, /)>

In [115]:
def torch_conv_2d(crossbar_inputs,crossbar_weights,_COUT_,Total_dim):
    _N_,_HOUT_,_WOUT_,crossbar_y,Num_rows = crossbar_inputs.shape
    _,crossbar_x,_,Num_Columns = crossbar_weights.shape
    output_conv_2d = torch.zeros(_N_,_COUT_,_HOUT_,_WOUT_)

    columns_per_crossbar = math.floor(_COUT_/crossbar_x)

    for ii in range(_HOUT_):
        for jj in range(_WOUT_):
            mac_out_columns = torch.zeros((_N_,_COUT_))
            for cy in range(crossbar_y):
                for cx in range(crossbar_x):
                    tmp_x=crossbar_inputs[:,ii,jj,cy,:]
                    tmp_w=crossbar_weights[cy,cx,:,:]
                    checkerboard_last_cols(tmp_w,Num_Columns-columns_per_crossbar)
                    # tmp_w=crossbar_weights
                    # print(tmp_x.shape)
                    # print(tmp_w.shape)
                    # _out_ = torch.matmul(tmp_w,tmp_x)
                    _out_ = torch.matmul(tmp_x,tmp_w)

                    column_start_idx = cx*columns_per_crossbar
                    if cx==crossbar_x-1:
                        column_end_idx=Total_dim
                    else:
                        column_end_idx = (cx+1)*columns_per_crossbar
                    
                    mac_out_columns[:,column_start_idx:column_end_idx] += _out_[:,:columns_per_crossbar]
            output_conv_2d[:,:,ii,jj]=mac_out_columns
    return output_conv_2d


In [116]:
# def cim_conv_2d(crossbar_inputs,crossbar_weights,_COUT_,Total_dim):
#     _N_,_HOUT_,_WOUT_,crossbar_y,Num_rows = crossbar_inputs.shape
#     _,crossbar_x,_,Num_Columns = crossbar_weights.shape
#     output_conv_2d = torch.zeros(_N_,_COUT_,_HOUT_,_WOUT_)

#     columns_per_crossbar = math.floor(_COUT_/crossbar_x)

#     for ii in tqdm(range(_HOUT_)):
#         for jj in tqdm(range(_WOUT_)):
#             mac_out_columns = torch.zeros((_N_,_COUT_))
#             for cy in range(crossbar_y):
#                 for cx in range(crossbar_x):
#                     tmp_x=crossbar_inputs[:,ii,jj,cy,:]
#                     tmp_w=crossbar_weights[cy,cx,:,:]
#                     checkerboard_last_cols(tmp_w,Num_Columns-columns_per_crossbar)
#                     # tmp_w=crossbar_weights
#                     # print(tmp_x.shape)
#                     # print(tmp_w.shape)
#                     # _out_ = torch.matmul(tmp_w,tmp_x)
#                     # _out_ = torch.matmul(tmp_x,tmp_w)
#                     _,_out_ = _task(((cy,cx),tmp_x),tmp_w,Num_rows,Num_rows,"cs",False)
#                     _out_ = torch.from_numpy(_out_)

#                     column_start_idx = cx*columns_per_crossbar
#                     if cx==crossbar_x-1:
#                         column_end_idx=Total_dim
#                     else:
#                         column_end_idx = (cx+1)*columns_per_crossbar
                    
#                     mac_out_columns[:,column_start_idx:column_end_idx] += _out_[:,:columns_per_crossbar]
#             output_conv_2d[:,:,ii,jj]=mac_out_columns
#     return output_conv_2d


In [226]:
from concurrent.futures import ProcessPoolExecutor, as_completed
import torch
import math
from tqdm import tqdm

def run_tile(cy, cx, tmp_x_np, tmp_w_np, Num_rows, Num_Columns, columns_per_crossbar, Total_dim, mode):
    checkerboard_last_cols(tmp_w_np, tmp_w_np.shape[1] - columns_per_crossbar)
    _, _out_np = _task(((cy, cx), tmp_x_np), tmp_w_np, Num_rows, Num_Columns, mode, False)

    column_start_idx = cx * columns_per_crossbar
    column_end_idx = Total_dim if tmp_w_np.shape[1] - columns_per_crossbar < columns_per_crossbar else (cx + 1) * columns_per_crossbar
    return (cy, cx, column_start_idx, column_end_idx, _out_np[:, :columns_per_crossbar])

def cim_conv_2d(crossbar_inputs, crossbar_weights, _COUT_, mode, Total_dim, max_workers=None):
    _N_, _HOUT_, _WOUT_, crossbar_y, Num_rows = crossbar_inputs.shape
    _, crossbar_x, _, Num_Columns = crossbar_weights.shape
    output_conv_2d = torch.zeros(_N_, _COUT_, _HOUT_, _WOUT_)

    columns_per_crossbar = math.floor(_COUT_ / crossbar_x)

    # Preconvert weights to NumPy once (this is reused for each (ii, jj))
    weights_np = [[crossbar_weights[cy, cx].numpy() for cx in range(crossbar_x)] for cy in range(crossbar_y)]

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        for ii in tqdm(range(_HOUT_)):
            for jj in range(_WOUT_):
                mac_out_columns = torch.zeros((_N_, _COUT_))
                tasks = []

                for cy in range(crossbar_y):
                    for cx in range(crossbar_x):
                        tmp_x = crossbar_inputs[:, ii, jj, cy, :].numpy()
                        tmp_w = weights_np[cy][cx]
                        tasks.append((cy, cx, tmp_x, tmp_w, Num_rows, Num_Columns, columns_per_crossbar, Total_dim, mode))

                futures = [executor.submit(run_tile, *t) for t in tasks]

                for f in as_completed(futures):
                    cy, cx, col_start, col_end, out_np = f.result()
                    out = torch.from_numpy(out_np)
                    mac_out_columns[:, col_start:col_end] += out

                output_conv_2d[:, :, ii, jj] = mac_out_columns

    return output_conv_2d


In [227]:
def conv2d_tile(x,w,Num_rows,Num_Columns,mode):
    _N_,_HOUT_, _WOUT_,_CIN_, _kernel_size_ = x.shape
    _COUT_, _CIN_, _kernel_size_ = w.shape
    output_conv_2d = torch.zeros(_N_,_COUT_,_HOUT_,_WOUT_)

    whole_input_size = _CIN_*_kernel_size_
    print(f"total inputs: {whole_input_size}")
    crossbar_y = math.ceil((_CIN_*_kernel_size_)/Num_rows)
    crossbar_x = math.ceil(_COUT_/Num_Columns)

    crossbar_weights = torch.zeros((crossbar_y,crossbar_x,Num_rows,Num_Columns))
    crossbar_inputs = torch.zeros((_N_,_HOUT_,_WOUT_,crossbar_y,Num_rows))

    rows_per_crossbar = math.floor(whole_input_size/crossbar_y)

    flatten_x = x.reshape(*x.shape[:-2],-1)
    print(flatten_x.shape)
    flatten_w = w.reshape(*w.shape[:-2],-1).T
    print(flatten_w.shape)
    print(crossbar_inputs.shape)
    print(f"weights shape: {crossbar_weights.shape}")
    print(f"inputs shape: {crossbar_inputs.shape}")

    columns_per_crossbar = math.floor(_COUT_/crossbar_x)

    for ii in range(crossbar_y):
        row_start_idx = ii*rows_per_crossbar
        # if ii==crossbar_y-1:
        #     row_end_idx = flatten_x.shape[-1]
        # else:
        row_end_idx = (ii+1)*rows_per_crossbar
        # print(start_idx,end_idx)
        crossbar_inputs[:,:,:,ii,:rows_per_crossbar] = flatten_x[:,:,:,row_start_idx:row_end_idx]
    # for ii in range(0,whole_input_size,step=inpu)

        for jj in range(crossbar_x):
            column_start_idx = jj*columns_per_crossbar
            # if jj==crossbar_x-1:
            #     column_end_idx=flatten_w.shape[-1]
            # else:
            column_end_idx = (jj+1)*columns_per_crossbar
            
            crossbar_weights[ii,jj,:rows_per_crossbar,:columns_per_crossbar] = flatten_w[row_start_idx:row_end_idx,column_start_idx:column_end_idx]


    # for n in range(_N_):
    output_conv_2d = cim_conv_2d(crossbar_inputs,crossbar_weights,_COUT_,mode,flatten_w.shape[-1],max_workers=8)
    # output_conv_2d = torch_conv_2d(crossbar_inputs,crossbar_weights,_COUT_,flatten_w.shape[-1])
    return output_conv_2d
output_conv_2d_cs=conv2d_tile(new_inputs,new_weights,Num_rows,Num_Columns,mode="cs")
output_conv_2d_cs = get_output(output_conv_2d_cs,Kh,Kw,CIN)
torch.equal(output_conv_2d_cs,ref)

total inputs: 50
torch.Size([1, 24, 24, 50])
torch.Size([50, 6])
torch.Size([1, 24, 24, 2, 32])
weights shape: torch.Size([2, 1, 32, 32])
inputs shape: torch.Size([1, 24, 24, 2, 32])


  0%|          | 0/24 [00:00<?, ?it/s]

100%|██████████| 24/24 [00:19<00:00,  1.22it/s]


True

In [119]:
output_conv_2d_gs=conv2d_tile(new_inputs,new_weights,Num_rows,Num_Columns,mode="gs")
output_conv_2d_gs = get_output(output_conv_2d_gs,Kh,Kw,CIN)
output_conv_2d_gs_b = binarized(output_conv_2d_gs)

total inputs: 50
torch.Size([1, 24, 24, 50])
torch.Size([50, 6])
torch.Size([1, 24, 24, 2, 32])
weights shape: torch.Size([2, 1, 32, 32])
inputs shape: torch.Size([1, 24, 24, 2, 32])


100%|██████████| 24/24 [00:16<00:00,  1.46it/s]


In [120]:
output_conv_2d_cs_b = binarized(output_conv_2d_cs) 
ref_b = binarized(ref)

In [124]:
ref.shape

torch.Size([1, 6, 24, 24])

In [123]:
with np.printoptions(threshold=float('inf')):
    print(ref_b - output_conv_2d_gs_b)
    # print(ref_b - output_conv_2d_cs_b)

tensor([[[[-2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,
           -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
          [-2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,
           -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
          [-2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,
           -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
          [-2., -2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -2., -2.,
           -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
          [-2., -2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -2., -2.],
          [-2., -2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -2., -2.],
          [-2., -2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -2., -2.],

In [129]:
fc_weight = model.fc1.weight.data.T
fc_weight_b = binarized(fc_weight)
fc_weight_b.shape


torch.Size([256, 120])

In [None]:
fc_input = torch.randn(10,fc_weight_b.shape[0])
fc_input_b = binarized(fc_input)
fc_input_b.shape

torch.Size([1, 256])

In [131]:
fc_weight_pos, fc_weight_neg = compliment(fc_weight_b) 
fc_input_pos, fc_input_neg = compliment(fc_input_b)

print(fc_weight_pos.shape)

torch.Size([256, 120])


In [159]:
new_fc_weights = torch.empty((2*fc_weight_pos.shape[0],fc_weight_pos.shape[1]))
new_fc_inputs = torch.empty((fc_input_pos.shape[0],2*fc_input_pos.shape[1]))

for idx in range(fc_weight.shape[0]):
    _pos_w_ = fc_weight_pos[idx]
    _neg_w_ = fc_weight_neg[idx]
    new_fc_weights[2*idx]=_pos_w_
    new_fc_weights[2*idx+1]=_neg_w_

    _pos_i_ = fc_input_pos[:,idx]
    _neg_i_ = fc_input_neg[:,idx]

    new_fc_inputs[:,2*idx] = _pos_i_
    new_fc_inputs[:,2*idx+1] = _neg_i_

print(new_fc_inputs.shape)
print(new_fc_weights.shape)




torch.Size([1, 512])
torch.Size([512, 120])


In [168]:
def get_fc_output(I, M):
    out = 2*I - M
    return out

In [169]:
out_2d = F.linear(new_fc_inputs,new_fc_weights.T)
out_2d = get_fc_output(out_2d,fc_input.shape[-1])
out_2d

tensor([[ 18.,  -2.,   8., -32., -10., -18.,   6., -28., -10., -10.,  10.,   0.,
         -34.,  30.,  12.,  24., -22., -12.,  -4.,   8.,  12.,  -2.,   2., -32.,
          -8.,  24., -10.,  -4.,   4.,   6., -16.,   4.,  24., -26.,   0., -18.,
          -2.,   8.,   4.,  -8.,   6., -28., -22.,  -8., -16.,  -8., -18.,  20.,
         -16.,  -4.,   6.,  18., -12., -16.,  -2.,  14.,  -8.,  10., -14., -10.,
          -4.,  12., -20.,   2., -18.,   8.,   2., -32.,  -6.,   6., -12., -14.,
          -2.,  -6.,  10.,  -6.,   4.,   4.,  -4., -16.,  14.,   4.,  10.,   0.,
           6., -18.,  -8.,   0.,   4.,  -4., -14.,  16., -14.,  26.,  32., -14.,
         -24.,  18.,  -2.,  14., -30., -14.,   6.,  42.,   6.,  14.,   2.,   2.,
          -8.,  18.,   8.,   6.,  30., -18.,  22., -18., -20.,  -8.,   8.,  38.]])

In [175]:
ref_fc = F.linear(fc_input_b,fc_weight_b.T)
torch.equal(ref_fc,out_2d)

True

In [None]:
def  fc_linear(crossbar_inputs,crossbar_weights,N,mode,max_workers):
    crossbar_y,crossbar_x,Num_rows,Num_Columns = crossbar_weights.shape
    _N_, crossbar_y, Num_rows = crossbar_inputs.shape
    output = torch.zeros(_N_,crossbar_x)
    columns_per_crossbar = math.floor(N/crossbar_x)
    print()
    print(output.shape)
    tasks = []

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        for ii in range(crossbar_y):
            for jj in range(crossbar_x):
                column_start_idx = jj*columns_per_crossbar
                column_end_idx = (jj+1)*columns_per_crossbar
                tmp_x = crossbar_inputs[:,ii]
                tmp_w = crossbar_weights[ii][jj]
                checkerboard_last_cols(tmp_w,Num_Columns-columns_per_crossbar)
                # out_matmul = torch.matmul(tmp_x,tmp_w)
                args = (((ii,jj),tmp_x),tmp_w,Num_rows,Num_Columns,mode,False)
                tasks.append(args)
        futures = [executor.submit(_task, *t) for t in tasks]
        for f in as_completed(futures):
            (ii,jj), out_matmul = f.result()
            print(jj)
            column_start_idx = jj*columns_per_crossbar
            column_end_idx = (jj+1)*columns_per_crossbar
            out_matmul = torch.from_numpy(out_matmul)
            print(column_start_idx,column_end_idx)
            print(out_matmul.shape)
            print(columns_per_crossbar)
            output[:,column_start_idx:column_end_idx] += out_matmul[:,:columns_per_crossbar]

    return output
        
# _ , out_matmul = _task(((ii,jj),tmp_x),tmp_w,Num_rows,Num_Columns,mode,False)
# out_matmul = torch.from_numpy(out_matmul)
# # print(out_matmul.shape)
# output[:,column_start_idx:column_end_idx] += out_matmul[:,:columns_per_crossbar]

In [245]:
def fc_tile(x,w, Num_rows,Num_Columns,mode):
    _N_ , M = x.shape
    M , N = w.shape

    crossbar_y = math.ceil(M/Num_rows)
    crossbar_x = math.ceil(N/Num_Columns)
    print("crossbar grid",crossbar_y,crossbar_x)

    crossbar_inputs = torch.zeros((_N_,crossbar_y,Num_rows))
    crossbar_weights = torch.zeros((crossbar_y,crossbar_x,Num_rows,Num_Columns))

    rows_per_crossbar = math.floor(M/crossbar_y)
    columns_per_crossbar = math.floor(N/crossbar_x)

    print(rows_per_crossbar, columns_per_crossbar)

    for ii in range(crossbar_y):
        row_start_idx = ii*rows_per_crossbar
        # if ii==crossbar_y-1:
        #     row_end_idx = M
        # else:
        row_end_idx = (ii+1)*rows_per_crossbar
        crossbar_inputs[:,ii,:rows_per_crossbar] = x[:,row_start_idx:row_end_idx]
    # for ii in range(0,whole_input_size,step=inpu)
        # print(row_start_idx,row_end_idx)

        for jj in range(crossbar_x):
            column_start_idx = jj*columns_per_crossbar
            # if jj==crossbar_x-1:
            #     column_end_idx=N
            # else:
            column_end_idx = (jj+1)*columns_per_crossbar
            # print(column_start_idx,column_end_idx)
            
            crossbar_weights[ii,jj,:rows_per_crossbar,:columns_per_crossbar] = w[row_start_idx:row_end_idx,column_start_idx:column_end_idx]
        
    print(f"crossbar weigths : {crossbar_weights.shape}")
    print(f"crossbar inputs : {crossbar_inputs.shape}")
    return fc_linear(crossbar_inputs,crossbar_weights,N,mode,8)
    # output = torch.zeros(_N_,N)
    # for ii in range(crossbar_y):
    #     # row_start_idx = ii*rows_per_crossbar
    #     # row_end_idx = (ii+1)*rows_per_crossbar
    #     for jj in range(crossbar_x):
    #         column_start_idx = jj*columns_per_crossbar
    #         # if jj==crossbar_x-1:
    #         #     column_end_idx=N
    #         # else:
    #         column_end_idx = (jj+1)*columns_per_crossbar
    #         tmp_x = crossbar_inputs[:,ii]
    #         tmp_w = crossbar_weights[ii][jj]
    #         checkerboard_last_cols(tmp_w,Num_Columns-columns_per_crossbar)
    #         _ , out_matmul = _task(((ii,jj),tmp_x),tmp_w,Num_rows,Num_Columns,mode,False)
    #         out_matmul = torch.from_numpy(out_matmul)
    #         output[:,column_start_idx:column_end_idx] += out_matmul[:,:columns_per_crossbar]

    # return output


fc_output = fc_tile(new_fc_inputs,new_fc_weights,Num_rows,Num_Columns,mode="gs")
fc_output=get_fc_output(fc_output,256)
# print(fc_output)
torch.equal(fc_output,ref_fc)

crossbar grid 16 4
32 30
crossbar weigths : torch.Size([16, 4, 32, 32])
crossbar inputs : torch.Size([1, 16, 32])
torch.Size([1, 4])


0
0 30
torch.Size([1, 32])
30


RuntimeError: The size of tensor a (4) must match the size of tensor b (30) at non-singleton dimension 1

In [224]:
ref_fc

tensor([[ 18.,  -2.,   8., -32., -10., -18.,   6., -28., -10., -10.,  10.,   0.,
         -34.,  30.,  12.,  24., -22., -12.,  -4.,   8.,  12.,  -2.,   2., -32.,
          -8.,  24., -10.,  -4.,   4.,   6., -16.,   4.,  24., -26.,   0., -18.,
          -2.,   8.,   4.,  -8.,   6., -28., -22.,  -8., -16.,  -8., -18.,  20.,
         -16.,  -4.,   6.,  18., -12., -16.,  -2.,  14.,  -8.,  10., -14., -10.,
          -4.,  12., -20.,   2., -18.,   8.,   2., -32.,  -6.,   6., -12., -14.,
          -2.,  -6.,  10.,  -6.,   4.,   4.,  -4., -16.,  14.,   4.,  10.,   0.,
           6., -18.,  -8.,   0.,   4.,  -4., -14.,  16., -14.,  26.,  32., -14.,
         -24.,  18.,  -2.,  14., -30., -14.,   6.,  42.,   6.,  14.,   2.,   2.,
          -8.,  18.,   8.,   6.,  30., -18.,  22., -18., -20.,  -8.,   8.,  38.]])

In [225]:
fc_output - ref_fc

tensor([[ 32.,  30.,  32.,  24.,  26.,  16.,  14.,   4.,   4.,   2.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,  -2.,   0.,  -2.,   0.,
           0.,  -2.,  -2.,  -2.,  -4.,  -6.,  32.,  32.,  32.,  30.,  24.,  26.,
          16.,  12.,   4.,   0.,   2.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,  -2.,   0.,  -2.,   0.,  -6.,  -8.,  -4.,  -2.,   0.,  -4.,
          32.,  32.,  32.,  32.,  26.,  30.,  14.,  14.,   2.,   4.,   2.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,  -2.,  -2.,  -2.,
          -2.,  -2.,  -4.,  -6.,  -4.,  -4.,  32.,  30.,  28.,  32.,  32.,  18.,
          18.,   8.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,  -2.,  -8.,  -4.,  -6.,  -4.,  -2.,  -6.,  -6., -12.]])