In [1045]:

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import os
import pandas as pd
from time import time
import matplotlib.pyplot as plt
from collections import defaultdict
from models.binarized_modules import binarized
# from binarized_modules import  BinarizeLinear,BinarizeConv2d

In [1046]:
import sys
sys.path.append('/home/earapidis/Fast-Crossbar-Sim/python')
from crossbar import VectorSim, ParallelSim, _task
from tqdm import tqdm

In [1047]:
cuda = False
# cuda = True

In [1048]:
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

In [1049]:
# Mini-batch size used when iterating over the MNIST training split.
batch_size = 64
# Training loader: MNIST train split stored under ./data (downloaded on first use),
# converted to tensors and normalised with mean 0.1307 / std 0.3081.
# `kwargs` comes from the previous cell and is empty unless `cuda` is enabled.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)

In [1050]:
test_batch_size=1000

In [1051]:
# Evaluation loader: MNIST test split with the same tensor conversion and
# normalisation as the training loader; order is kept fixed (shuffle=False).
# NOTE: no `download=True` here, so the dataset must already exist under ./data.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size, shuffle=False, **kwargs)

In [1052]:
# from mnist_bnn import Net
# Use the batch-normalised binarized LeNet-5 as the network under test.
from models.lenet_5 import BinarizedLeNet5_BN as Net

# Freshly initialised (untrained) network; moved to GPU 0 only when `cuda` is set.
model = Net()
if cuda:
    torch.cuda.set_device(0)
    model.cuda()


In [1053]:
# Select which saved training run and which checkpoint file to load.
# model_path = os.path.join(models_path,f"epoch_7.pth")
model_idx = 1
models_path = os.path.abspath(f"/home/earapidis/BinarizedNN/saved_models/lenet_5/model_{model_idx}")
model_path = os.path.join(models_path,f"epoch_15.pth")
# model_path = os.path.join(models_path,f"best.pth")

# Rebuild the network and restore its parameters from the checkpoint.
model = Net()
# map_location="cpu" lets a checkpoint that was saved from a GPU run be loaded
# on a CPU-only session (cuda = False); the model is moved to the GPU below
# when requested, so the CUDA path ends up in the same state as before.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
if cuda:
    torch.cuda.set_device(0)
    model.cuda()

In [1054]:
model

BinarizedLeNet5_BN(
  (conv1): BinarizeConv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh1): Hardtanh(min_val=-1.0, max_val=1.0)
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): BinarizeConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh2): Hardtanh(min_val=-1.0, max_val=1.0)
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (fc1): BinarizeLinear(in_features=256, out_features=120, bias=True)
  (bn_fc1): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh3): Hardtanh(min_val=-1.0, max_val=1.0)
  (fc2): BinarizeLinear(in_features=120, out_features=84, bias=True)
  (bn_fc2): BatchNorm1d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (htanh4): Hardtanh(min_val=-1.0, max_val=1.0)
  (fc3): BinarizeLinear(in_features=84

In [1055]:
# model.conv2.weight

In [1056]:
# Geometry of the single-layer experiment (matches conv1 of the loaded model:
# 1 input channel, 6 output channels, 5x5 kernel) on a 28x28 input.
CIN = 1
COUT  = 6
Hi = 28
Wi = 28
Kh = 5
Kw = 5
padding = 0 
N = 1
# Crossbar tile dimensions passed to the tiled convolution routines.
Num_rows = 32
Num_Columns = 32                    

# Real-valued conv1 weights taken from the loaded checkpoint.
filters = model.conv1.weight.data

# filters = torch.randn(COUT, CIN, Kh, Kw)

# Random test input. NOTE: no RNG seed is set, so the values (and every printed
# output that depends on them) change between kernel restarts.
inputs = torch.randn(N, CIN, Hi, Wi)
# Binarize both operands with the project's `binarized` helper
# (the printed tensors below show values of -1 and +1).
inputs = binarized(inputs)
filters_b = binarized(filters)

In [1057]:
filters_b.shape

torch.Size([6, 1, 5, 5])

In [1058]:
# import torch
# import torch.nn.functional as F

# def conv2d_loops(x, w, padding=0):
#     N, Cin, H, W     = x.shape
#     Cout, _, Kh, Kw  = w.shape
#     Hout = H + 2*padding - Kh + 1
#     Wout = W + 2*padding - Kw + 1

#     # Zero-pad input
#     x_p = torch.zeros((N, Cin, H + 2*padding, W + 2*padding))
#     x_p[:, :, padding:padding+H, padding:padding+W] = x

#     y = torch.zeros((N, Cout, Hout, Wout))
#     for n in range(N):
#         for co in range(Cout):
#             for i in range(Hout):
#                 for j in range(Wout):
#                     acc = 0.0
#                     for ci in range(Cin):
#                         acc += torch.sum(x_p[n, ci, i:i+Kh, j:j+Kw] * w[co, ci])
#                     y[n, co, i, j] = acc
#     return y

# ref = F.conv2d(inputs, filters, padding=padding)
# out_loops  = conv2d_loops(inputs, filters, padding)
# print("Reference shape:", ref.shape)
# print("Max abs diff (loops):",  (ref - out_loops).abs().max().item())


In [1059]:
# Reference result: PyTorch convolution of the binarized input with the
# binarized conv1 filters, used later to validate the crossbar-based paths.
ref = F.conv2d(inputs, filters_b, padding=padding)
# Display the padding value currently in effect.
padding

0

In [1060]:
import numpy as np

def checkerboard_last_cols(arr: np.ndarray, C: int) -> None:
    """
    Overwrite the last C columns of `arr` in-place with a checkerboard pattern of 0s and 1s.

    The value written at (row, col) is ``(row + col) % 2`` where ``col`` is the
    absolute column index in `arr`, so the pattern alternates in both directions.

    Parameters
    ----------
    arr : np.ndarray
        Input array of shape (n, m). Modified in place.
    C : int
        Number of columns at the right edge to turn into a checkerboard.
        ``C == 0`` leaves `arr` untouched.

    Raises
    ------
    ValueError
        If `C` is negative or larger than the number of columns m.
    """
    n, m = arr.shape
    if C > m:
        raise ValueError("C cannot be larger than the number of columns m.")
    if C < 0:
        raise ValueError("C cannot be negative.")
    if C == 0:
        # Nothing to overwrite. (Slicing with ``-0:`` would select the whole
        # array and the assignment would fail to broadcast.)
        return

    # row indices: shape (n, 1)
    rows = np.arange(n)[:, None]
    # absolute column indices of the last C columns: shape (C,)
    cols = np.arange(m - C, m)

    # Checkerboard pattern: (row + col) mod 2, broadcast to shape (n, C).
    # Slice with an explicit start index so the span is always exactly C wide.
    arr[:, m - C:] = (rows + cols) % 2

# Example: a 6x8 array of ones whose 5 right-most columns are replaced by the
# checkerboard; the 3 left-most columns stay at 1.
A = np.ones((6, 8), dtype=int)
checkerboard_last_cols(A, C=5)
print(A)


[[1 1 1 1 0 1 0 1]
 [1 1 1 0 1 0 1 0]
 [1 1 1 1 0 1 0 1]
 [1 1 1 0 1 0 1 0]
 [1 1 1 1 0 1 0 1]
 [1 1 1 0 1 0 1 0]]


In [1061]:
# import numpy as np
# import torch
# from concurrent.futures import ProcessPoolExecutor
# from tqdm import tqdm

# def checkerboard_last_cols(arr: np.ndarray, C: int) -> None:
#     """
#     Overwrite the last C columns of `arr` in-place with a checkerboard pattern of 0s and 1s.
#     """
#     n, m = arr.shape
#     if C > m:
#         raise ValueError("C cannot be larger than the number of columns m.")
#     rows = np.arange(n)[:, None]
#     cols = np.arange(m - C, m)
#     pattern = (rows + cols) % 2
#     arr[:, -C:] = pattern

# def _process_one_pixel(args):
#     """
#     Compute the conv2d_tiles result for a single (n,i,j) position.
#     Returns (n, i, j, output_vector).
#     """
#     (n, i, j,
#      input_vec,
#      crossbar_weights,
#      Cout,
#      num_tiles_columns,
#      Num_rows,
#      Num_Columns,
#      mode,
#      checkboard) = args

#     inp = input_vec[n, i, j, :, :]  # shape (num_tiles_rows, Num_rows)
#     full_out = np.zeros(num_tiles_columns * Num_Columns, dtype=float)

#     for col_idx in range(num_tiles_columns):
#         weight_tiles = crossbar_weights[:, col_idx, :, :]  # (num_tiles_rows, Num_rows, Num_Columns)
#         accum = np.zeros((weight_tiles.shape[0], Num_Columns), dtype=float)

#         for t_idx, vec in enumerate(inp):
#             W = weight_tiles[t_idx]
#             if checkboard and col_idx == num_tiles_columns - 1:
#                 checkerboard_last_cols(W, Num_Columns - Cout)
#             _, out_vec = _task((t_idx, vec), W, Num_rows, Num_Columns, mode, False)
#             accum[t_idx, :] = out_vec

#         summed = accum.sum(axis=0)
#         start = col_idx * Num_Columns
#         full_out[start:start + Num_Columns] = summed

#     return n, i, j, full_out[:Cout]

# def conv2d_tiles(x, w, Num_rows, Num_Columns, padding=0, mode="gs", checkboard=False):
#     N, Cin, H, W     = x.shape
#     Cout, _, Kh, Kw  = w.shape
#     Hout = H + 2*padding - Kh + 1
#     Wout = W + 2*padding - Kw + 1

#     # Zero-pad input
#     x_p = torch.zeros((N, Cin, H + 2*padding, W + 2*padding), dtype=x.dtype, device=x.device)
#     x_p[:, :, padding:padding+H, padding:padding+W] = x

#     # Build crossbar_weights
#     kernel_size = Kh * Kw
#     cin_per_cross = Num_rows // kernel_size
#     num_tiles_rows = int(np.ceil((kernel_size * Cin) / (cin_per_cross * kernel_size)))
#     num_tiles_columns = int(np.ceil(Cout / Num_Columns))

#     crossbar_weights = np.zeros((num_tiles_rows, num_tiles_columns, Num_rows, Num_Columns), dtype=float)
#     for co in range(Cout):
#         tile_j = co // Num_Columns
#         col_idx = co % Num_Columns
#         for ci in range(Cin):
#             tile_i = ci // cin_per_cross
#             id_mod = ci % cin_per_cross
#             start = id_mod * kernel_size
#             end   = start + kernel_size
#             flat_w = w[co, ci].view(-1).cpu().numpy()
#             crossbar_weights[tile_i, tile_j, start:end, col_idx] = flat_w

#     # Build input_vec
#     input_vec = np.zeros((N, Hout, Wout, num_tiles_rows, Num_rows), dtype=float)
#     for n in range(N):
#         for ci in range(Cin):
#             tile_i = ci // cin_per_cross
#             id_mod = ci % cin_per_cross
#             start = id_mod * kernel_size
#             end   = start + kernel_size
#             for i in range(Hout):
#                 for j in range(Wout):
#                     patch = x_p[n, ci, i:i+Kh, j:j+Kw].contiguous().view(-1).cpu().numpy()
#                     input_vec[n, i, j, tile_i, start:end] = patch

#     # Allocate output
#     output = np.zeros((N, Cout, Hout, Wout), dtype=float)

#     # Parallelize over (i,j) for each batch n
#     for n in range(N):
#         tasks = [
#             (
#                 n, i, j,
#                 input_vec,
#                 crossbar_weights,
#                 Cout,
#                 num_tiles_columns,
#                 Num_rows,
#                 Num_Columns,
#                 mode,
#                 checkboard
#             )
#             for i in range(Hout) for j in range(Wout)
#         ]

#         with ProcessPoolExecutor() as executor:
#             for n_ret, i_ret, j_ret, vec_out in executor.map(_process_one_pixel, tasks):
#                 output[n_ret, :, i_ret, j_ret] = vec_out

#     return output


In [1062]:
import numpy as np
import torch
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

def checkerboard_last_cols(arr: np.ndarray, C: int) -> None:
    """
    Overwrite the last C columns of `arr` in-place with a checkerboard pattern of 0s and 1s.

    NOTE: this redefines the function of the same name from an earlier cell;
    this later definition is the one the tiled convolution/linear workers use.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (n, m). Modified in place.
    C : int
        Number of right-most columns to overwrite; ``C == 0`` is a no-op.

    Raises
    ------
    ValueError
        If `C` is negative or larger than the number of columns m.
    """
    n, m = arr.shape
    if C > m:
        raise ValueError("C cannot be larger than the number of columns m.")
    if C < 0:
        raise ValueError("C cannot be negative.")
    if C == 0:
        # ``arr[:, -0:]`` would address the whole array and the (n, 0) pattern
        # could not be broadcast into it, so handle "no padding" explicitly.
        return
    rows = np.arange(n)[:, None]      # (n, 1)
    cols = np.arange(m - C, m)        # (C,) absolute column indices
    # (row + col) % 2 alternates 0/1 along both axes.
    arr[:, m - C:] = (rows + cols) % 2


def _process_one_pixel(args):
    """
    Compute the conv2d_tiles result for a single (n,i,j) output position.

    Parameters
    ----------
    args : tuple
        ``(n, i, j, input_vec, crossbar_weights, Cout, num_tiles_columns,
        Num_rows, Num_Columns, mode, checkboard)`` where

        * ``input_vec[n, i, j]`` has shape (num_tiles_rows, Num_rows): one
          input vector per row tile,
        * ``crossbar_weights`` has shape
          (num_tiles_rows, num_tiles_columns, Num_rows, Num_Columns),
        * ``mode`` is forwarded unchanged to ``_task``
          (its meaning is defined by the crossbar simulator),
        * ``checkboard`` enables filling the unused (padding) columns of the
          last column tile with a checkerboard pattern before simulation.

    Returns
    -------
    tuple
        ``(n, i, j, output_vector)`` with ``output_vector`` of length ``Cout``.
    """
    (n, i, j,
     input_vec,
     crossbar_weights,
     Cout,
     num_tiles_columns,
     Num_rows,
     Num_Columns,
     mode,
     checkboard) = args

    inp = input_vec[n, i, j, :, :]  # (num_tiles_rows, Num_rows)
    full_out = np.zeros(num_tiles_columns * Num_Columns, dtype=float)

    # Number of unused columns in the LAST column tile. Computing it from the
    # total tiled width keeps it correct when Cout spans several column tiles
    # (where ``Num_Columns - Cout`` would be negative) and yields 0 when Cout
    # fills the tiles exactly.
    pad_cols = num_tiles_columns * Num_Columns - Cout
    last_col_tile = num_tiles_columns - 1

    for col_idx in range(num_tiles_columns):
        weight_tiles = crossbar_weights[:, col_idx, :, :]  # (num_tiles_rows, Num_rows, Num_Columns)
        accum = np.zeros((weight_tiles.shape[0], Num_Columns), dtype=float)
        fill_padding = checkboard and col_idx == last_col_tile and pad_cols > 0

        for t_idx, vec in enumerate(inp):
            W = weight_tiles[t_idx]
            if fill_padding:
                # Work on a copy so the caller's weight array is not modified
                # as a side effect of evaluating one pixel.
                W = W.copy()
                checkerboard_last_cols(W, pad_cols)
            # _task returns a pair; the second element is this tile's output
            # vector (one value per crossbar column).
            _, out_vec = _task((t_idx, vec), W, Num_rows, Num_Columns, mode, False)
            accum[t_idx, :] = out_vec

        # Partial sums from all row tiles of this column tile are added up.
        summed = accum.sum(axis=0)
        start = col_idx * Num_Columns
        full_out[start:start + Num_Columns] = summed

    # Drop the padding columns so only the real output channels remain.
    return n, i, j, full_out[:Cout]


def conv2d_tiles(x, w, Num_rows, Num_Columns, padding=0, mode="gs", checkboard=False,workers=8):
    """
    Stride-1 2-D convolution evaluated tile by tile on simulated crossbars.

    The flattened kernels are laid out in crossbar tiles of size
    ``Num_rows x Num_Columns``: each output channel occupies one column, and as
    many whole input-channel kernels as fit (``Num_rows // (Kh*Kw)``) are
    stacked along the rows of one tile. Every output pixel is then evaluated by
    ``_process_one_pixel`` in a process pool.

    Parameters
    ----------
    x : torch.Tensor
        Input of shape (N, Cin, H, W).
    w : torch.Tensor
        Filters of shape (Cout, Cin, Kh, Kw).
    Num_rows, Num_Columns : int
        Crossbar tile dimensions.
    padding : int
        Symmetric zero padding applied to both spatial dimensions.
    mode : str
        Forwarded unchanged to the crossbar simulator.
    checkboard : bool
        Forwarded to ``_process_one_pixel`` (checkerboard fill of unused columns).
    workers : int
        Number of worker processes.

    Returns
    -------
    torch.Tensor
        float64 CPU tensor of shape (N, Cout, Hout, Wout) with
        ``Hout = H + 2*padding - Kh + 1`` and ``Wout = W + 2*padding - Kw + 1``.

    Raises
    ------
    ValueError
        If a single flattened kernel (Kh*Kw values) does not fit into
        ``Num_rows`` crossbar rows.
    """
    N, Cin, H, W     = x.shape
    Cout, _, Kh, Kw  = w.shape
    Hout = H + 2*padding - Kh + 1
    Wout = W + 2*padding - Kw + 1

    # Tiling geometry
    kernel_size = Kh * Kw
    cin_per_cross = Num_rows // kernel_size          # whole kernels per row tile
    if cin_per_cross < 1:
        raise ValueError(
            f"Num_rows={Num_rows} is smaller than one flattened kernel "
            f"(Kh*Kw={kernel_size}); at least one kernel must fit in a crossbar tile."
        )
    num_tiles_rows = -(-Cin // cin_per_cross)        # ceil(Cin / cin_per_cross)
    num_tiles_columns = -(-Cout // Num_Columns)      # ceil(Cout / Num_Columns)

    # Zero-pad input
    x_p = torch.zeros((N, Cin, H + 2*padding, W + 2*padding), dtype=x.dtype, device=x.device)
    x_p[:, :, padding:padding+H, padding:padding+W] = x

    # Convert once to NumPy; .cpu() makes this work for CUDA tensors as well.
    x_np = x_p.detach().cpu().numpy()
    w_np = w.detach().cpu().numpy()

    # Build crossbar_weights: column = output channel, row block = input channel
    crossbar_weights = np.zeros((num_tiles_rows, num_tiles_columns, Num_rows, Num_Columns), dtype=float)
    for co in range(Cout):
        tile_j, col_idx = divmod(co, Num_Columns)
        for ci in range(Cin):
            tile_i, slot = divmod(ci, cin_per_cross)
            start = slot * kernel_size
            crossbar_weights[tile_i, tile_j, start:start + kernel_size, col_idx] = w_np[co, ci].reshape(-1)

    # Build input_vec: for every output pixel, the flattened patches placed in
    # the same row blocks as the matching kernels.
    input_vec = np.zeros((N, Hout, Wout, num_tiles_rows, Num_rows), dtype=float)
    for ci in range(Cin):
        tile_i, slot = divmod(ci, cin_per_cross)
        start = slot * kernel_size
        for i in range(Hout):
            for j in range(Wout):
                patches = x_np[:, ci, i:i+Kh, j:j+Kw].reshape(N, kernel_size)
                input_vec[:, i, j, tile_i, start:start + kernel_size] = patches

    # Allocate output
    output = np.zeros((N, Cout, Hout, Wout), dtype=float)

    disable_progress = True
    # One pool for the whole batch; pixels of one sample are dispatched together.
    with ProcessPoolExecutor(max_workers=workers) as executor:
        for n in range(N):
            pending = {}
            for i in range(Hout):
                for j in range(Wout):
                    # Ship only this pixel's vectors to the worker (as a
                    # (1, 1, 1, tiles, rows) view addressed with indices 0,0,0)
                    # instead of pickling the full input_vec for every task.
                    task = (
                        0, 0, 0,
                        input_vec[n:n+1, i:i+1, j:j+1],
                        crossbar_weights,
                        Cout,
                        num_tiles_columns,
                        Num_rows,
                        Num_Columns,
                        mode,
                        checkboard,
                    )
                    pending[executor.submit(_process_one_pixel, task)] = (i, j)
            for future in tqdm(as_completed(pending), total=len(pending), disable=disable_progress):
                i_ret, j_ret = pending[future]
                # The worker returns (n, i, j, output_vector); the position is
                # tracked locally, so only the vector is needed here.
                output[n, :, i_ret, j_ret] = future.result()[3]

    return torch.from_numpy(output)


In [1063]:
# def conv2d_tiles(x, w, Num_rows, Num_Columns, padding=0,mode="gs",checkboard=False):
#     N, Cin, H, W     = x.shape
#     Cout, _, Kh, Kw  = w.shape
#     Hout = H + 2*padding - Kh + 1
#     Wout = W + 2*padding - Kw + 1



#     # Zero-pad input
#     x_p = torch.zeros((N, Cin, H + 2*padding, W + 2*padding))
#     x_p[:, :, padding:padding+H, padding:padding+W] = x

#     kernel_size = Kh * Kw
#     cin_per_cross = Num_rows // kernel_size 
#     num_tiles_rows = int(np.ceil((Kw * Kh*Cin)/ (cin_per_cross*kernel_size)))
#     num_tiles_columns = int(np.ceil(Cout / Num_Columns))

#     crossbar_weights = np.zeros((num_tiles_rows,num_tiles_columns,Num_rows, Num_Columns))
#     tile_i_idx = 0
#     tile_j_idx = 0
#     # print(num_tiles_rows)
#     # print(crossbar_weights.shape)
#     for co in range(Cout):
#         tile_j_idx = co // Num_Columns
#         for ci in range(Cin):
#             id = (ci%cin_per_cross)
#             tile_row_start = id*kernel_size
#             tile_row_end = (id+1)*kernel_size
#             tile_i_idx = ci // cin_per_cross
#             flat_w = w[co, ci].view(-1).numpy()
#             # print(f"co: {co}, ci: {ci}, tile_j_idx: {tile_j_idx}, tile_i_idx: {tile_i_idx}, tile_row_start: {tile_row_start}, tile_row_end: {tile_row_end}")
#             column = co % Num_Columns
#             crossbar_weights[tile_i_idx, tile_j_idx, tile_row_start:tile_row_end, column] = flat_w
    
#     input_vec = np.zeros((N,Hout,Wout,num_tiles_rows,Num_rows))
#     # print(input_vec.shape)
#     for n in range(N):
#         for ci in range(Cin):
#             for i in range(Hout):
#                 for j in range(Wout):
#                     id = (ci%cin_per_cross)
#                     tile_row_start = id*kernel_size
#                     tile_row_end = (id+1)*kernel_size
#                     tile_i_idx = ci // cin_per_cross
#                     flat_input = torch.flatten(x_p[n, ci, i:i+Kh, j:j+Kw]).numpy()
#                     # print(flat_input.shape)
#                     input_vec[n,i,j,tile_i_idx, tile_row_start:tile_row_end] = flat_input

#     output = np.zeros((N,Cout,Hout,Wout))
#     # cim =  VectorSim(Num_rows,Num_Columns,mode=mode)
#     # cim =  VectorSim(Num_rows,Num_Columns,mode="cs")

#     for n in range(N):
#         for i in tqdm(range(Hout)):
#             for j in range(Wout):
#                 cout = np.zeros((num_tiles_columns * Num_Columns))
#                 for col_idx in range(num_tiles_columns):
#                     inp = input_vec[n,i,j,:,:]
#                     # print(inp.shape)
#                     w = crossbar_weights[:,col_idx,:,:]
#                     # w = w.reshape(num_tiles_rows, Num_rows, Num_Columns)
#                     # w = crossbar_weights.reshape(num_tiles_rows, Num_rows, Num_Columns)
#                     # print(w.shape)

#                     intermidiate_out = np.zeros((num_tiles_rows, Num_Columns))
#                     for idx, vec in enumerate(inp):
#                         weight = w[idx,:,:]
#                         if checkboard and col_idx == num_tiles_columns - 1:
#                             checkerboard_last_cols(weight, Num_Columns-Cout)
#                         # print(weight)
#                         _,out = _task((idx,vec),weight,Num_rows,Num_Columns,mode,False)
#                         # cim.set_weights(weight)
#                         # out = cim.run_vector(vec)
#                         # # out = np.dot(vec, w[idx,:,:])
#                         intermidiate_out[idx,:] = out
#                     # print(intermidiate_out.shape)
#                     cout_outs = np.sum(intermidiate_out, axis=0)
#                     cout[col_idx*Num_Columns:(col_idx+1)*Num_Columns] = cout_outs
#                     # if num_tiles_columns==1:
#                     #     cout = cout_outs[:Cout]
#                     # else
#                 final_cout = np.ravel(cout[:Cout])
#                 output[n,:,i,j] = final_cout
#     return output
# # out_loops  = conv2d_tiles(inputs, filters,Num_rows,Num_Columns, padding)
# # print(np.array_equal(out_loops, ref.numpy()))
# # print(out_loops-ref.numpy())
# # print("Output shape:", out_loops[0].shape, out_loops[1].shape)

In [1064]:
# crossbar_weights

In [1065]:
def alt_compliment(x):
    """
    Split a tensor into two int32 indicator masks.

    Returns ``(Vp, Vm)`` where ``Vp`` is 1 exactly where ``x == 1`` and ``Vm``
    is 1 exactly where ``x == -1``; all other positions are 0 in both masks.
    """
    plus_mask = x.eq(1).to(torch.int32)
    minus_mask = x.eq(-1).to(torch.int32)
    return plus_mask, minus_mask

In [1066]:
inputs

tensor([[[[-1., -1., -1., -1., -1., -1.,  1., -1.,  1., -1., -1., -1.,  1.,  1.,
            1.,  1., -1.,  1., -1.,  1.,  1., -1., -1.,  1.,  1., -1.,  1., -1.],
          [ 1.,  1., -1.,  1., -1., -1.,  1.,  1., -1.,  1.,  1., -1.,  1., -1.,
            1.,  1., -1., -1., -1., -1., -1., -1.,  1., -1., -1., -1.,  1., -1.],
          [ 1., -1.,  1.,  1., -1., -1., -1.,  1., -1.,  1., -1., -1., -1.,  1.,
           -1., -1., -1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1.],
          [-1., -1.,  1., -1., -1., -1.,  1., -1., -1., -1., -1., -1.,  1.,  1.,
           -1., -1.,  1., -1., -1., -1., -1.,  1.,  1., -1.,  1., -1.,  1., -1.],
          [-1.,  1.,  1., -1., -1., -1.,  1., -1., -1., -1., -1., -1., -1., -1.,
           -1.,  1., -1.,  1.,  1., -1.,  1.,  1., -1.,  1.,  1.,  1., -1.,  1.],
          [ 1., -1., -1.,  1., -1., -1.,  1., -1.,  1.,  1.,  1., -1., -1.,  1.,
           -1., -1., -1.,  1., -1.,  1.,  1., -1., -1.,  1.,  1.,  1., -1., -1.],
          [-1.,  1.,  

In [1067]:
padding=1

In [1068]:
a=nn.functional.pad(inputs, (padding, padding, padding, padding), mode='constant', value=1)

In [1069]:
def compliment(x):
    """
    Split a tensor into a "positive" and a "negative" part without modifying it.

    ``pos`` is a copy of ``x`` with every -1 replaced by 0; ``neg`` is ``-x``
    with every -1 replaced by 0. For inputs made only of -1/+1 this yields two
    0/1 maps: ``pos`` marks the +1 entries and ``neg`` marks the -1 entries.

    Returns
    -------
    (pos, neg) : tuple of tensors with the same shape and dtype as ``x``.
    """
    negated = -1 * x
    # masked_fill is out-of-place, so the caller's tensor stays untouched.
    pos = x.masked_fill(x == -1, 0)
    neg = negated.masked_fill(negated == -1, 0)
    return pos, neg

# Split the binarized input and filters into their positive / negative parts.
# pos_inputs, neg_inputs = alt_compliment(inputs)
# pos_filters, neg_filters = alt_compliment(filters)
pos_inputs, neg_inputs = compliment(inputs)
pos_filters, neg_filters = compliment(filters_b)
# padding=1
# pos_inputs = nn.functional.pad(pos_inputs, (padding, padding, padding, padding), mode='constant', value=1)
# neg_inputs = nn.functional.pad(neg_inputs, (padding, padding, padding, padding), mode='constant', value=1)
# Reset padding to 0 (an earlier cell set it to 1) before computing the references.
padding=0
# PyTorch reference for each half: positive-with-positive and negative-with-negative.
pos_ref = F.conv2d(pos_inputs, pos_filters, padding=padding)
neg_ref = F.conv2d(neg_inputs, neg_filters, padding=padding)
pos_ref.shape

torch.Size([1, 6, 24, 24])

In [1070]:
print(inputs)
print(pos_inputs)
print(neg_inputs)

tensor([[[[-1., -1., -1., -1., -1., -1.,  1., -1.,  1., -1., -1., -1.,  1.,  1.,
            1.,  1., -1.,  1., -1.,  1.,  1., -1., -1.,  1.,  1., -1.,  1., -1.],
          [ 1.,  1., -1.,  1., -1., -1.,  1.,  1., -1.,  1.,  1., -1.,  1., -1.,
            1.,  1., -1., -1., -1., -1., -1., -1.,  1., -1., -1., -1.,  1., -1.],
          [ 1., -1.,  1.,  1., -1., -1., -1.,  1., -1.,  1., -1., -1., -1.,  1.,
           -1., -1., -1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1.],
          [-1., -1.,  1., -1., -1., -1.,  1., -1., -1., -1., -1., -1.,  1.,  1.,
           -1., -1.,  1., -1., -1., -1., -1.,  1.,  1., -1.,  1., -1.,  1., -1.],
          [-1.,  1.,  1., -1., -1., -1.,  1., -1., -1., -1., -1., -1., -1., -1.,
           -1.,  1., -1.,  1.,  1., -1.,  1.,  1., -1.,  1.,  1.,  1., -1.,  1.],
          [ 1., -1., -1.,  1., -1., -1.,  1., -1.,  1.,  1.,  1., -1., -1.,  1.,
           -1., -1., -1.,  1., -1.,  1.,  1., -1., -1.,  1.,  1.,  1., -1., -1.],
          [-1.,  1.,  

In [1071]:
ref = nn.functional.conv2d(inputs, filters_b, padding=0)
ref

tensor([[[[ -3.,  -3.,   1.,  ...,   3.,   1.,   3.],
          [ -7.,  -5.,  -1.,  ...,   1.,   1.,   7.],
          [ -3.,  -3.,   3.,  ...,  -5.,  -5.,  -9.],
          ...,
          [  1.,   1.,   5.,  ...,   1.,  -3.,  -1.],
          [  5.,   7.,   3.,  ...,  -5.,  -7.,  -1.],
          [  3.,   3.,   1.,  ...,  -3.,  -1.,   3.]],

         [[  1.,  -3.,   1.,  ...,   3.,  -3.,  -1.],
          [ -7.,  -1.,  -5.,  ...,  -3.,   1.,  -5.],
          [  1.,  -3.,  -9.,  ...,   3.,   3.,   3.],
          ...,
          [ -7.,   5.,   1.,  ...,   9.,   5.,  -1.],
          [ -3.,  -5.,   3.,  ...,  11.,   5.,   3.],
          [ -1.,   3.,   5.,  ...,   1.,  -5.,   3.]],

         [[ -5.,  -1.,   3.,  ...,   5.,   3.,   1.],
          [ -1.,   1.,   5.,  ...,  -5.,  -9.,  -3.],
          [  7.,   7.,  -3.,  ...,  -7.,  -7.,  -3.],
          ...,
          [  3.,  -1.,  -1.,  ...,  -5.,  -1.,   1.],
          [  3.,   1.,   1.,  ...,  -3.,  -9.,  -3.],
          [  1.,  -3.,   3.,  ...

In [1072]:
# Recombine the two partial convolutions: I counts, per output position, the
# kernel entries where input and weight agree (+1/+1 or -1/-1), so the signed
# result is 2*I minus the number of kernel entries (Kh*Kw*CIN).
# The displayed tensor should match `ref` from the previous cell.
I  = pos_ref + neg_ref
out = 2*I - Kh*Kw*CIN
out

tensor([[[[ -3.,  -3.,   1.,  ...,   3.,   1.,   3.],
          [ -7.,  -5.,  -1.,  ...,   1.,   1.,   7.],
          [ -3.,  -3.,   3.,  ...,  -5.,  -5.,  -9.],
          ...,
          [  1.,   1.,   5.,  ...,   1.,  -3.,  -1.],
          [  5.,   7.,   3.,  ...,  -5.,  -7.,  -1.],
          [  3.,   3.,   1.,  ...,  -3.,  -1.,   3.]],

         [[  1.,  -3.,   1.,  ...,   3.,  -3.,  -1.],
          [ -7.,  -1.,  -5.,  ...,  -3.,   1.,  -5.],
          [  1.,  -3.,  -9.,  ...,   3.,   3.,   3.],
          ...,
          [ -7.,   5.,   1.,  ...,   9.,   5.,  -1.],
          [ -3.,  -5.,   3.,  ...,  11.,   5.,   3.],
          [ -1.,   3.,   5.,  ...,   1.,  -5.,   3.]],

         [[ -5.,  -1.,   3.,  ...,   5.,   3.,   1.],
          [ -1.,   1.,   5.,  ...,  -5.,  -9.,  -3.],
          [  7.,   7.,  -3.,  ...,  -7.,  -7.,  -3.],
          ...,
          [  3.,  -1.,  -1.,  ...,  -5.,  -1.,   1.],
          [  3.,   1.,   1.,  ...,  -3.,  -9.,  -3.],
          [  1.,  -3.,   3.,  ...

In [1073]:
def get_output(pos_ref, neg_ref,Kh, Kw, CIN):
    """
    Recombine the positive and negative partial convolutions.

    Computes ``2 * (pos_ref + neg_ref) - Kh * Kw * CIN``, i.e. twice the summed
    partial results minus the number of entries in one kernel.
    """
    kernel_entries = Kh*Kw*CIN
    return 2*(pos_ref + neg_ref) - kernel_entries

In [1074]:
def conv_inferenece(x,w, Num_rows, Num_Columns, padding=0, mode="gs", checkboard=False, workers=8):
    """
    Convolution of x (N, Cin, H, W) with w (Cout, Cin, Kh, Kw) on simulated crossbars.

    Both operands are split with ``compliment`` into positive and negative
    parts; the positive parts and the negative parts are each convolved with
    ``conv2d_tiles`` and the two results are merged by ``get_output``.

    NOTE(review): ``conv2d_tiles`` zero-pads both partial inputs, so with
    ``padding > 0`` padded positions add nothing to either partial sum while
    ``get_output`` still subtracts the full ``Kh*Kw*Cin`` -- border values then
    differ from a zero-padded ``F.conv2d``. Confirm this is intended.
    """
    _, Cin, _, _ = x.shape
    _, _, Kh, Kw = w.shape

    input_parts = compliment(x)     # (positive part, negative part)
    filter_parts = compliment(w)    # (positive part, negative part)

    tile_options = dict(padding=padding, mode=mode, checkboard=checkboard, workers=workers)
    # Positive pair first, then negative pair.
    partial_sums = [
        conv2d_tiles(inp, flt, Num_rows, Num_Columns, **tile_options)
        for inp, flt in zip(input_parts, filter_parts)
    ]
    return get_output(partial_sums[0], partial_sums[1], Kh, Kw, Cin)

In [1075]:
# out_cim_gs = conv_inferenece(inputs,filters_b, Num_rows, Num_Columns, padding=padding, mode="gs", checkboard=False, workers=8)
# out_cim_cs = conv_inferenece(inputs,filters_b, Num_rows, Num_Columns, padding=padding, mode="cs", checkboard=False, workers=8)
# out_cim_gs_check = conv_inferenece(inputs,filters_b, Num_rows, Num_Columns, padding=padding, mode="gs", checkboard=True, workers=8)
# out_cim_cs_check = conv_inferenece(inputs,filters_b, Num_rows, Num_Columns, padding=padding, mode="cs", checkboard=True, workers=8)

In [1076]:
# mode = "gs"
# pos_ref_cim_gs = conv2d_tiles(pos_inputs,pos_filters,Num_rows,Num_Columns,mode=mode)
# neg_ref_cim_gs = conv2d_tiles(neg_inputs,neg_filters,Num_rows,Num_Columns,mode=mode)


In [1077]:
# mode = "cs"
# pos_ref_cim_cs = conv2d_tiles(pos_inputs,pos_filters,Num_rows,Num_Columns,mode=mode)
# neg_ref_cim_cs = conv2d_tiles(neg_inputs,neg_filters,Num_rows,Num_Columns,mode=mode)


In [1078]:
# mode = "gs"
# pos_ref_cim_gs_check = conv2d_tiles(pos_inputs,pos_filters,Num_rows,Num_Columns,mode=mode,checkboard=True)
# neg_ref_cim_gs_check = conv2d_tiles(neg_inputs,neg_filters,Num_rows,Num_Columns,mode=mode,checkboard=True)


In [1079]:
# mode = "cs"
# pos_ref_cim_cs_check = conv2d_tiles(pos_inputs,pos_filters,Num_rows,Num_Columns,mode=mode,checkboard=True)
# neg_ref_cim_cs_check = conv2d_tiles(neg_inputs,neg_filters,Num_rows,Num_Columns,mode=mode,checkboard=True)


In [1080]:
# Convert the reference to NumPy for the comparisons below. The type check keeps
# the cell re-runnable: a second execution would otherwise call .numpy() on an
# ndarray and raise AttributeError.
if isinstance(ref, torch.Tensor):
    ref = ref.numpy()

In [1081]:
# out_cim_gs = get_output(pos_ref_cim_gs, neg_ref_cim_gs,Kh, Kw, CIN)
# out_cim_cs = get_output(pos_ref_cim_cs, neg_ref_cim_cs,Kh, Kw, CIN)
# out_cim_gs_check = get_output(pos_ref_cim_gs_check, neg_ref_cim_gs_check,Kh, Kw, CIN)
# out_cim_cs_check = get_output(pos_ref_cim_cs_check, neg_ref_cim_cs_check,Kh, Kw, CIN)

In [1082]:
# out_cim_gs = pos_ref_cim_gs+neg_ref_cim_gs
# out_cim_gs = 2*out_cim_gs - Kh*Kw*CIN
# out_cim_gs

In [1083]:

# out_cim_cs = pos_ref_cim_cs+neg_ref_cim_cs
# out_cim_cs = 2*out_cim_cs - Kh*Kw*CIN
# out_cim_cs

In [1084]:
# np.array_equal(out,ref)

In [1085]:
# np.array_equal(out_cim,ref)


In [1086]:
ref

array([[[[ -3.,  -3.,   1., ...,   3.,   1.,   3.],
         [ -7.,  -5.,  -1., ...,   1.,   1.,   7.],
         [ -3.,  -3.,   3., ...,  -5.,  -5.,  -9.],
         ...,
         [  1.,   1.,   5., ...,   1.,  -3.,  -1.],
         [  5.,   7.,   3., ...,  -5.,  -7.,  -1.],
         [  3.,   3.,   1., ...,  -3.,  -1.,   3.]],

        [[  1.,  -3.,   1., ...,   3.,  -3.,  -1.],
         [ -7.,  -1.,  -5., ...,  -3.,   1.,  -5.],
         [  1.,  -3.,  -9., ...,   3.,   3.,   3.],
         ...,
         [ -7.,   5.,   1., ...,   9.,   5.,  -1.],
         [ -3.,  -5.,   3., ...,  11.,   5.,   3.],
         [ -1.,   3.,   5., ...,   1.,  -5.,   3.]],

        [[ -5.,  -1.,   3., ...,   5.,   3.,   1.],
         [ -1.,   1.,   5., ...,  -5.,  -9.,  -3.],
         [  7.,   7.,  -3., ...,  -7.,  -7.,  -3.],
         ...,
         [  3.,  -1.,  -1., ...,  -5.,  -1.,   1.],
         [  3.,   1.,   1., ...,  -3.,  -9.,  -3.],
         [  1.,  -3.,   3., ...,   7.,   1.,  -3.]],

        [[ -3., 

In [1087]:
# ref = ref.numpy()

In [1088]:
# out_cim_gs.shape

In [1089]:
# with np.printoptions(threshold=np.inf):
#     out_cim_gs = out_cim_gs.numpy()
#     diff =out_cim_gs-ref 
#     print(diff)
#     print(np.mean(np.abs(diff)))

#     # print(ref)

In [1090]:
# with np.printoptions(threshold=np.inf):
#     out_cim_cs = out_cim_cs.numpy()
#     diff =out_cim_cs-ref
#     print(diff)
#     # print(np.mean(diff))
#     print(np.mean(np.abs(diff)))

#     # print(ref)

In [1091]:
# with np.printoptions(threshold=np.inf):
#     out_cim_gs_check = out_cim_gs_check.numpy()
#     diff =out_cim_gs_check-ref
#     print(diff)
#     print(np.mean(np.abs(diff)))
#     # print(out_cim_gs_check-ref)
#     # print(ref)

In [1092]:
# with np.printoptions(threshold=np.inf):
#     out_cim_cs_check = out_cim_cs_check.numpy()
#     diff =out_cim_cs_check-ref
#     print(diff)
#     print(np.mean(np.abs(diff)))
#     # print(ref)

In [1093]:
class BinarizeConv2dInference(nn.Conv2d):
    """
    Inference-only binarized convolution evaluated on simulated crossbars.

    Behaves like ``nn.Conv2d`` for parameter storage, but ``forward`` binarizes
    inputs and weights and delegates the computation to ``conv_inferenece``
    instead of ``F.conv2d``.

    Only the configuration that the tiled routine implements is accepted:
    stride 1, dilation 1, groups 1 and symmetric integer padding. Other
    settings raise ``NotImplementedError`` at construction time, because the
    tiled routine would otherwise ignore them and return a differently shaped
    or differently valued result.

    Extra parameters
    ----------------
    Num_rows, Num_Columns : int
        Crossbar tile dimensions.
    mode : str
        Forwarded unchanged to the crossbar simulator.
    checkboard : bool
        Fill unused crossbar columns with a checkerboard pattern.
    workers : int
        Number of worker processes used by the tiled convolution.
    """

    def __init__(self,
                 in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1,
                 bias=True,
                 Num_rows=4, Num_Columns=4,
                 mode="gs", checkboard=False, workers=8):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias
        )
        # The tiled routine always computes a dense stride-1 convolution.
        if any(s != 1 for s in self.stride):
            raise NotImplementedError("BinarizeConv2dInference only supports stride=1.")
        if any(d != 1 for d in self.dilation):
            raise NotImplementedError("BinarizeConv2dInference only supports dilation=1.")
        if self.groups != 1:
            raise NotImplementedError("BinarizeConv2dInference only supports groups=1.")
        # nn.Conv2d stores integer padding as a tuple and string padding as str.
        if not isinstance(self.padding, tuple) or len(set(self.padding)) != 1:
            raise NotImplementedError(
                "BinarizeConv2dInference only supports symmetric integer padding."
            )

        # parameters for custom tiled conv
        self.Num_rows    = Num_rows
        self.Num_Columns = Num_Columns
        self.mode        = mode
        self.checkboard  = checkboard
        self.workers     = workers

    def forward(self, input):
        """Return the crossbar-simulated convolution of ``input`` (N, C, H, W)."""
        # binarize inputs (but keep first 3-channel inputs full-precision)
        # NOTE(review): the positive/negative split used downstream assumes
        # -1/+1 operands; confirm the 3-channel full-precision path is valid.
        if input.size(1) != 3:
            input_b = binarized(input)
        else:
            input_b = input

        # binarize weights
        weight_b = binarized(self.weight)

        # use custom inference routine instead of F.conv2d; it works on NumPy
        # arrays internally, so hand it detached CPU tensors.
        out = conv_inferenece(
            input_b.detach().cpu(), weight_b.detach().cpu(),
            self.Num_rows, self.Num_Columns,
            padding=self.padding[0],
            mode=self.mode,
            checkboard=self.checkboard,
            workers=self.workers
        )

        # The simulator result is a float64 CPU tensor; bring it back to the
        # layer's dtype/device so following layers (e.g. BatchNorm) accept it.
        out = out.to(device=self.weight.device, dtype=self.weight.dtype)

        # add bias if present
        if self.bias is not None:
            # store original bias for potential gradient updates, etc.
            self.bias.org = self.bias.data.clone()
            out = out + self.bias.view(1, -1, 1, 1).expand_as(out)

        return out

In [1094]:
# Crossbar-simulated stand-in for conv1: same channel counts and kernel size as
# the experiment constants, no bias, "cs" simulator mode with checkerboard fill.
model_conv = BinarizeConv2dInference(CIN, COUT, kernel_size=Kh,Num_rows=Num_rows, Num_Columns=Num_Columns,bias=False, mode="cs", checkboard=True, workers=8)



In [1095]:
# Reuse the trained conv1 weights of the loaded model in the simulated layer.
model_conv.weight = nn.Parameter(model.conv1.weight)


In [1096]:
out = model_conv(inputs)

In [1097]:
type(out)

torch.Tensor

In [1098]:
# Mean absolute difference between the crossbar-simulated layer output and the
# F.conv2d reference (`ref` was converted to a NumPy array in an earlier cell).
np.mean(np.abs(out.detach().numpy() - ref))

np.float64(0.017361111111111112)

In [1099]:
# def one_tile(b,i,j,N,num_tiles_columns, Num_rows, Num_Columns,vec, weight, mode="gs", checkboard=False):
#     if checkboard and j == (num_tiles_columns - 1):
#         print(Num_Columns - N % Num_Columns)
#         checkerboard_last_cols(weight, Num_Columns - N % Num_Columns)
#     _, out_vec = _task(((i,j), vec), weight, Num_rows, Num_Columns, mode, False)
#     return ((b,i,j),out_vec)

# def linear(x,w, Num_rows, Num_Columns, mode="gs", checkboard=False, workers=8):
#     # print(x.shape, w.shape)
#     # single = x.ndim == 1
#     # if single:
#     #     x = x[None, :]
#     B, Μ = x.shape
#     M, N= w.shape
#     num_tiles_columns = int(np.ceil(N / Num_Columns)) 
#     num_tiles_rows = int(np.ceil(M / Num_rows))
#     print(num_tiles_rows,num_tiles_columns,  Num_rows, Num_Columns)
#     out = np.zeros((B, num_tiles_columns*Num_Columns), dtype=int)

#     crossbar_weights = np.zeros((num_tiles_rows, num_tiles_columns, Num_rows, Num_Columns), dtype=bool)
#     for i in range(num_tiles_rows):
#         for j in range(num_tiles_columns):
#             start_row = i * Num_rows
#             end_row = min(start_row + Num_rows, M)
#             start_col = j * Num_Columns
#             end_col = min(start_col + Num_Columns, N)
#             w_tile = w[start_row:end_row, start_col:end_col]
#             crossbar_weights[i, j, :w_tile.shape[0], :w_tile.shape[1]] = w_tile
    
#     input_vec = np.zeros((B, num_tiles_rows, Num_rows), dtype=bool)
#     for b in range(B):
#         for i in range(num_tiles_rows):
#             start_row = i * Num_rows
#             end_row = min(start_row + Num_rows, M)
#             # print(start_row, end_row)
#             input_vec[b, i, :end_row] = x[b, start_row:end_row]
    
#     for b in range(B):
#         for i in range(num_tiles_rows):
#             for j in range(num_tiles_columns):
#                 vec = input_vec[b, i, :]
#                 W = crossbar_weights[i, j, :, :]
#                 # if checkboard and j == num_tiles_columns - 1:
#                 #     # print(Num_Columns - N % Num_Columns)
#                 #     checkerboard_last_cols(W, Num_Columns - N % Num_Columns)
#                 _, out_vec = one_tile(b,i,j,N,num_tiles_columns, Num_rows, Num_Columns,vec, W, mode=mode, checkboard=checkboard)

#                 # _, out_vec = _task(((i,j), vec), W, Num_rows, Num_Columns, mode, False)
#                 out[b, j*Num_Columns:(j+1)*Num_Columns] += out_vec
#     output = torch.from_numpy(out)

#     return output[:, :N]





In [1100]:
import numpy as np
import torch
from concurrent.futures import ProcessPoolExecutor, as_completed

def _process_linear_tile(args):
    """
    Worker for one tile of the linear layer.

    ``args`` is the tuple packed by ``linear_parallel``:
    ``(b, i, j, N, num_tiles_columns, Num_rows, Num_Columns, vec, W, mode, checkboard)``
    where ``b`` is the batch index, ``(i, j)`` the row/column tile index,
    ``N`` the true (unpadded) number of output columns, ``vec`` the input
    slice for row tile ``i`` and ``W`` the weight tile ``(i, j)``.

    Returns ``(b, j, out_vec)`` so the caller can accumulate ``out_vec`` into
    the output columns of tile ``j`` for batch row ``b``.
    """
    (b, i, j, N, num_tiles_columns,
     Num_rows, Num_Columns,
     vec, W, mode, checkboard) = args

    if checkboard and j == (num_tiles_columns - 1):
        # Number of padding columns in the last column tile. Using
        # ``(-N) % Num_Columns`` (instead of ``Num_Columns - N % Num_Columns``)
        # yields 0 when N is an exact multiple of the tile width, so a fully
        # used tile is left untouched rather than being overwritten entirely.
        pad_cols = (-N) % Num_Columns
        if pad_cols:
            # checkerboard_last_cols appears to modify its argument in place
            # (its return value was never used); work on a copy so a caller
            # running this function in-process keeps its weights intact.
            W = W.copy()
            checkerboard_last_cols(W, pad_cols)

    # _task comes from the crossbar package; only the second element of its
    # result (the per-column output vector) is needed here.
    _, out_vec = _task(((i, j), vec), W, Num_rows, Num_Columns, mode, False)
    return b, j, out_vec

def linear_parallel(x, w,
                    Num_rows, Num_Columns,
                    mode="gs", checkboard=False,
                    workers=8):
    """
    Evaluate a linear layer tile by tile, one worker task per (batch, tile).

    x: torch tensor or numpy array of shape (B, M)
    w: numpy array of shape (M, N)
    Num_rows, Num_Columns: height and width of a single tile
    mode, checkboard: passed through unchanged to ``_process_linear_tile``
    workers: number of processes in the pool
    returns torch tensor of shape (B, N)
    """
    # Work on a NumPy view of the input regardless of how it was supplied.
    if isinstance(x, torch.Tensor):
        x_np = x.detach().cpu().numpy()
    else:
        x_np = x

    B, M = x_np.shape
    M2, N = w.shape
    assert M == M2, "Input/features mismatch"

    # Tile grid dimensions and the zero-padded extents they imply.
    num_tiles_columns = int(np.ceil(N / Num_Columns))
    num_tiles_rows    = int(np.ceil(M / Num_rows))
    padded_rows = num_tiles_rows * Num_rows
    padded_cols = num_tiles_columns * Num_Columns

    # Zero-pad the weights to whole tiles, then cut into a
    # (row_tile, col_tile, Num_rows, Num_Columns) array.
    w_padded = np.zeros((padded_rows, padded_cols), dtype=bool)
    w_padded[:M, :N] = w
    crossbar_weights = np.ascontiguousarray(
        w_padded
        .reshape(num_tiles_rows, Num_rows, num_tiles_columns, Num_Columns)
        .transpose(0, 2, 1, 3)
    )

    # Same treatment for the inputs: (B, row_tile, Num_rows).
    x_padded = np.zeros((B, padded_rows), dtype=bool)
    x_padded[:, :M] = x_np
    input_vec = x_padded.reshape(B, num_tiles_rows, Num_rows)

    # One argument tuple per (batch row, row tile, column tile).
    tasks = [
        (b, i, j, N, num_tiles_columns,
         Num_rows, Num_Columns,
         input_vec[b, i, :], crossbar_weights[i, j, :, :], mode, checkboard)
        for b in range(B)
        for i in range(num_tiles_rows)
        for j in range(num_tiles_columns)
    ]

    # Accumulator still includes the padding columns; they are dropped at the end.
    out = np.zeros((B, padded_cols), dtype=int)

    with ProcessPoolExecutor(max_workers=workers) as exe:
        pending = [exe.submit(_process_linear_tile, task) for task in tasks]
        # Partial sums from different row tiles land in the same column span,
        # so completion order does not matter.
        for fut in as_completed(pending):
            b, j, out_vec = fut.result()
            out[b, j * Num_Columns:(j + 1) * Num_Columns] += out_vec

    # Drop padding columns and hand back a torch tensor.
    return torch.from_numpy(out[:, :N])


In [1101]:
# Weight shape of the trained fc1 layer; the cells below use shape[1] as the
# input width and shape[0] as the output width of the test layer.
shape = model.fc1.weight.shape

In [1102]:
shape

torch.Size([120, 256])

In [1103]:
# Random binarized test batch sized to fc1's input width (shape[1]).
B = 1
fc_inputs = binarized(torch.randn(B, shape[1]))

# Binarized copy of the trained fc1 weights. The earlier random initialisation
# of fc_weights was overwritten on the very next line, so it has been dropped.
fc_weights = binarized(model.fc1.weight)

In [1104]:
# Reference result computed directly by PyTorch on the binarized tensors;
# the crossbar-simulated outputs further down are compared against this.
nn.functional.linear(fc_inputs, fc_weights)

tensor([[-16., -20.,  14.,  18.,  -4., -24.,  32.,   6.,  16.,  12.,  16., -10.,
         -24.,  -8.,  14.,  10.,  -8., -18.,  10., -18.,  -6., -20.,  20.,  22.,
          26.,  38.,   8.,  14.,   6.,  12., -10.,   6., -42., -16., -14.,  32.,
         -28.,  18., -14., -26.,   8., -10.,  24., -42., -14.,  18.,  12., -22.,
          30.,  -6.,  20., -12.,  14.,  26.,  16.,   8.,  -6.,  -4.,   8., -12.,
         -10.,   6., -22.,   4.,  12., -14.,  12.,   6.,   0.,   8.,   2.,   4.,
         -12.,   8., -40., -12.,  18., -10.,  22.,   6.,   0.,  -6.,   4.,  -6.,
          20.,  20., -18., -10.,  -6.,  -2.,  16., -14., -24.,  -4.,  -6., -16.,
         -10., -16.,  -8.,  -4.,  20.,  16., -16.,   0.,   4., -28.,  -4.,  16.,
           6.,  -8.,  22.,  20.,   8.,   0.,  16.,  16.,  26.,   2.,  -6.,  20.]],
       grad_fn=<MmBackward0>)

In [1105]:
def get_fc_output(pos, neg, M):
    """
    Merge the two partial products into the final layer output.

    Computes ``2 * (pos + neg) - M`` element-wise.

    pos, neg: partial results from the two `compliment` halves (presumably the
              counts of positions where both operands are +1 / both are -1,
              so their sum is the number of agreeing positions; confirm
              against `compliment`)
    M:        length of the reduced dimension (number of input features)
    """
    return 2 * (pos + neg) - M

In [1106]:
# Split inputs and weights into their two `compliment` halves (helper defined
# in an earlier cell; presumably the +1 and -1 indicator masks; confirm there).
fc_pos_inputs, fc_neg_inputs = compliment(fc_inputs)
fc_pos_filters, fc_neg_filters = compliment(fc_weights)

# Run each half through an ordinary linear op.
pos_rf_fc = nn.functional.linear(fc_pos_inputs, fc_pos_filters)
neg_rf_fc = nn.functional.linear(fc_neg_inputs, fc_neg_filters)

# Recombine the halves and check that the decomposition reproduces the direct
# result exactly; the cell displays the outcome of that comparison.
out_fc = get_fc_output(pos_rf_fc, neg_rf_fc, shape[1])
torch.equal(out_fc, nn.functional.linear(fc_inputs, fc_weights))

True

In [None]:
def linear_inferenece(x,w, Num_rows, Num_Columns, mode="gs", checkboard=False, workers=8):
    """
    Binarized linear layer evaluated through the tiled crossbar simulation.

    x: binarized input tensor of shape (B, M)
    w: binarized weight tensor of shape (N, M), i.e. nn.Linear layout; it is
       transposed to (M, N) before being handed to ``linear_parallel``
    Num_rows, Num_Columns, mode, checkboard, workers: forwarded unchanged to
       ``linear_parallel``
    returns the tensor produced by ``get_fc_output`` from the two partial
       simulations, shape (B, N)
    """
    in_features = w.shape[1]
    pos_inputs, neg_inputs = compliment(x)
    pos_filters, neg_filters = compliment(w)

    def _as_numpy(t):
        # detach() + cpu() so tensors that track gradients or live on a GPU
        # convert cleanly; a bare .numpy() raises for either of those.
        return t.detach().cpu().numpy()

    sim_kwargs = dict(mode=mode, checkboard=checkboard, workers=workers)
    pos_cim = linear_parallel(_as_numpy(pos_inputs), _as_numpy(pos_filters).T,
                              Num_rows, Num_Columns, **sim_kwargs)
    neg_cim = linear_parallel(_as_numpy(neg_inputs), _as_numpy(neg_filters).T,
                              Num_rows, Num_Columns, **sim_kwargs)
    return get_fc_output(pos_cim, neg_cim, in_features)

In [1108]:
# Run the crossbar-simulated FC layer in "cs" mode with the checkerboard option
# enabled. Num_rows / Num_Columns are expected to be set by an earlier cell.
out_cim_fc = linear_inferenece(fc_inputs, fc_weights, Num_rows, Num_Columns, mode="cs", checkboard=True, workers=8)

In [1109]:
# pos_cim = linear(fc_pos_inputs.numpy(),fc_pos_filters.detach().numpy().T, Num_rows, Num_Columns, mode="cs", checkboard=True, workers=8)
# neg_cim = linear(fc_neg_inputs.numpy(),fc_neg_filters.detach().numpy().T, Num_rows, Num_Columns, mode="cs", checkboard=True, workers=8)

# out_cim_fc = get_fc_output(pos_cim, neg_cim, shape[1])


In [1110]:
out_cim_fc

tensor([[-16, -20,  14,  18,  -4, -24,  32,   6,  14,  14,  16, -12, -24,  -8,
          14,   8,  -8, -18,   8, -18,  -6, -20,  20,  22,  26,  36,   8,  14,
           6,  14, -10,   8, -42, -16, -14,  32, -28,  18, -12, -26,   8, -10,
          24, -42, -14,  20,  14, -24,  30,  -6,  20, -12,  14,  26,  16,   8,
          -6,  -2,   8, -12, -10,   6, -20,   4,  12, -14,  12,   8,   0,   8,
           0,   4, -10,   8, -40, -12,  18, -10,  22,   6,   0,  -6,   4,  -6,
          20,  20, -18, -10,  -6,   0,  16, -14, -24,  -4,  -6, -14, -10, -16,
          -8,  -4,  20,  16, -14,   0,   4, -28,  -4,  16,   6,  -6,  22,  20,
           8,   0,  18,  16,  26,   2,  -6,  20]])

In [1111]:
out_cim_fc.shape

torch.Size([1, 120])

In [1112]:
# Element-wise deviation of the simulated output from the exact reference
# (0 means the simulation matched for that output neuron).
out_cim_fc - out_fc.detach()

tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -2.,  2.,  0., -2.,  0.,  0.,
          0., -2.,  0.,  0., -2.,  0.,  0.,  0.,  0.,  0.,  0., -2.,  0.,  0.,
          0.,  2.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,
          0.,  0.,  0.,  2.,  2., -2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  2.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,
         -2.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,
          0.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,
          0.,  0.,  2.,  0.,  0.,  0.,  0.,  0.]])

In [None]:
class BinarizeLinearInference(nn.Linear):
    """
    nn.Linear variant whose forward pass goes through the crossbar simulation.

    Inputs and weights are passed through ``binarized`` and multiplied by
    ``linear_inferenece`` instead of ``nn.functional.linear``.

    Args:
        in_features, out_features, bias, device, dtype: forwarded to nn.Linear.
        Num_rows, Num_Columns: tile size handed to ``linear_inferenece``.
        mode: simulation mode string handed to ``linear_inferenece``.
        checkboard: checkerboard flag handed to ``linear_inferenece``.
        workers: worker count handed to ``linear_inferenece``.
    """

    def __init__(self, in_features, out_features,Num_rows,Num_Columns,mode="gs",checkboard=False,workers=8, bias=True, device=None, dtype=None):
        # Pass device/dtype through; they were previously replaced by None,
        # which silently ignored the caller's request.
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.Num_rows    = Num_rows
        self.Num_Columns = Num_Columns
        self.mode        = mode
        self.checkboard  = checkboard
        self.workers     = workers

    def forward(self, input):
        """
        Binarize ``input`` and the layer weights, run the simulated product and
        add the bias when the layer has one.
        """
        input_b = binarized(input)
        weight_b = binarized(self.weight)
        out = linear_inferenece(input_b, weight_b, self.Num_rows, self.Num_Columns, mode=self.mode, checkboard=self.checkboard, workers=self.workers)
        if self.bias is not None:
            # Kept from the original layer: snapshot of the bias values.
            self.bias.org = self.bias.data.clone()
            # The simulation returns an integer tensor, so an in-place `+=`
            # with the floating-point bias cannot be cast back and raises.
            # Convert to the bias dtype/device and add out of place instead
            # (broadcasting replaces the explicit expand_as).
            out = out.to(device=self.bias.device, dtype=self.bias.dtype) + self.bias.view(1, -1)

        return out

In [1126]:
# Build the crossbar-backed layer with fc1's dimensions (shape is (out, in)),
# without bias, using the same "cs" + checkerboard settings as the function call above.
model_fc = BinarizeLinearInference(shape[1], shape[0],bias=False, Num_rows=Num_rows, Num_Columns=Num_Columns, mode="cs", checkboard=True, workers=8)

# Reuse the trained fc1 weights; the layer binarizes them itself in forward().
model_fc.weight = nn.Parameter(model.fc1.weight)

# Forward pass on the same binarized test inputs used for the reference result.
out_cim_fc = model_fc(fc_inputs)

In [1127]:
# Same deviation check as before, now for the nn.Module wrapper's output.
out_cim_fc - out_fc.detach()

tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -2.,  2.,  0., -2.,  0.,  0.,
          0., -2.,  0.,  0., -2.,  0.,  0.,  0.,  0.,  0.,  0., -2.,  0.,  0.,
          0.,  2.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,
          0.,  0.,  0.,  2.,  2., -2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  2.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,
         -2.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,
          0.,  0.,  0.,  0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  0.,  0.,
          0.,  0.,  2.,  0.,  0.,  0.,  0.,  0.]])

In [1128]:
# Deviation of the simulated layer output from the exact reference.
diff = out_cim_fc - out_fc.detach()

# Mean absolute error across all output neurons.
np.mean(np.abs(diff.numpy()))

np.float32(0.36666667)