In [29]:
import os

import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import numpy as np

class lenet(nn.Module):
    """Small network: one 3x3 convolution, an element-wise square, one linear layer.

    The linear layer takes 7200 flattened features, which equals 8 channels of
    30 x 30 values (what the unpadded 3x3 convolution yields for 32 x 32 inputs).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
        self.fc = nn.Linear(in_features=7200, out_features=10)
        # Registered on the module but not applied anywhere in forward().
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the 10 linear-layer outputs for each sample in the batch ``x``."""
        conv_out = self.conv1(x)
        # The activation is a plain element-wise square rather than self.relu.
        squared = conv_out ** 2
        flat = squared.view(squared.size(0), -1)
        return self.fc(flat)
    
model = lenet()
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the rest of this notebook works on CPU tensors (it calls .numpy()).
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load("checkpoint/lenet_8/best.pth", map_location="cpu")
# Index the key directly: a checkpoint without "model" now fails with a clear
# KeyError instead of handing None to load_state_dict.
model.load_state_dict(checkpoint["model"])
# Inference only: switch the module to evaluation mode.
model.eval()
# Output directory for every exported file written by the cells below.
os.makedirs("fastest", exist_ok=True)

## Input

In [30]:
# Length of every packed row vector built in this notebook (image, mask,
# padded weights/bias, intermediate and final results).
# NOTE(review): presumably the slot count of the target packed vector — confirm.
num_slots = 8192

def load_image(image_path, tensor_len):
    """Load an RGB image and also pack it into a zero-padded flat row.

    Args:
        image_path: Path of the image file to open.
        tensor_len: Length of the flat, zero-padded row that is returned.

    Returns:
        A pair ``(batched, packed)``. ``batched`` is the ``ToTensor`` image with
        a leading batch dimension added. ``packed`` has shape
        ``(1, tensor_len)``: the flattened image fills its first ``3 * 1024``
        entries and the remainder stays zero.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    pixels = transforms.ToTensor()(rgb_image)

    # The destination slice is fixed at 3 * 1024 entries, so the image must
    # flatten to exactly that many values (e.g. 3 channels of 32 x 32).
    packed = torch.zeros((1, tensor_len))
    packed[0, :3 * 1024] = pixels.reshape(-1)

    return pixels.unsqueeze(0), packed


# `image_ori` is the batched image tensor; `image` is the zero-padded flat row
# of length num_slots.
image_ori, image = load_image(image_path='./images/test.png', tensor_len=num_slots)
image.shape

torch.Size([1, 8192])

## Conv1

### Convert weights

In [31]:
# Kernel taps per (output channel, input channel): shape (8, 3, 9).
weights_flatten = model.conv1.weight.data.view(8, 3, -1)

# weights[i, j] holds tap j of output channel i; the scalar belonging to input
# channel k is repeated across that channel's own 1024-entry segment.
weights = torch.zeros((8, 9, 1024 * 3))
weights.view(8, 9, 3, 1024)[:] = weights_flatten.permute(0, 2, 1).unsqueeze(-1)

# bias[i] is the bias of output channel i repeated over all 3072 entries.
bias = torch.zeros((8, 1024 * 3))
bias[:] = model.conv1.bias.data.unsqueeze(1)

# Mask with ones at positions row * 32 + col for row, col < 30 and zeros elsewhere.
mask = torch.zeros((1, num_slots))
mask[0, :32 * 32].view(32, 32)[:30, :30] = 1

# Text export: one file per channel bias and one per (channel, tap) weight vector.
for i, channel_taps in enumerate(weights):
    np.savetxt(f'fastest/conv1-ch{i}-bias.bin', bias[i], delimiter=',')
    for j, tap_vector in enumerate(channel_taps):
        np.savetxt(f'fastest/conv1-ch{i}-k{j}.bin', tap_vector, delimiter=',')

np.savetxt('fastest/conv1-mask.bin', mask[0], delimiter=',')


In [32]:
# Raw dumps of the packed conv1 tensors, next to the per-vector text files.
for packed_tensor, out_path in (
    (weights, 'fastest/conv1-weights.bin'),
    (bias, 'fastest/conv1-bias.bin'),
):
    packed_tensor.numpy().tofile(out_path)

In [33]:
# Sanity check: one 3072-long (1024 * 3) bias row per conv1 output channel.
bias.shape

torch.Size([8, 3072])

### Rotate input

In [34]:
# Offsets of the nine 3x3 kernel taps in a flat layout with 32 entries per
# image row: tap (row, col) lies 32 * row + col entries ahead.
rolls = [32 * tap_row + tap_col for tap_row in range(3) for tap_col in range(3)]
# A negative shift brings the entry `offset` positions ahead into the current position.
image_rotations = [torch.roll(image, -offset) for offset in rolls]

In [35]:
def _pad_to_slots(values):
    """Copy a 3072-long vector into the front of a zeroed (1, num_slots) row."""
    padded = torch.zeros((1, num_slots))
    padded[0, :3072] = values
    return padded


conv_res = None
for out_ch in range(8):
    # Multiply every rotated image by its tap weights, accumulating in tap order.
    tap_sum = torch.zeros((1, num_slots))
    for rotated, tap_weights in zip(image_rotations, weights[out_ch]):
        tap_sum = tap_sum + rotated * _pad_to_slots(tap_weights)

    # Fold the three 1024-entry input-channel segments onto the first one, add
    # the bias, then zero every position outside the 30 x 30 mask.
    channel_out = (
        tap_sum
        + torch.roll(tap_sum, -1024)
        + torch.roll(tap_sum, -2048)
        + _pad_to_slots(bias[out_ch, :])
    )
    channel_out = channel_out * mask

    # Each pass shifts the accumulator by one 1024-entry segment; after eight
    # passes over 8192 slots, channel i sits at offset i * 1024.
    conv_res = channel_out if conv_res is None else conv_res + channel_out
    conv_res = torch.roll(conv_res, -1024)


print(conv_res.shape)

torch.Size([1, 8192])


In [36]:
def check_correct(res, image_ori, model):
    """Print the summed absolute error between a packed conv result and ``model.conv1``.

    Args:
        res: Packed result; row 0 is read at offsets
            ``ch * 1024 + row * 32 + col`` for ``ch < 8`` and ``row, col < 30``.
        image_ori: Input passed to ``model.conv1``.
        model: Object exposing a ``conv1`` callable.
    """
    model_output = model.conv1(image_ori).squeeze()

    # Flat offsets of every (channel, row, col) entry inside the packed row.
    ch_base = torch.arange(8, device=res.device).view(8, 1, 1) * 1024
    row_base = torch.arange(30, device=res.device).view(1, 30, 1) * 32
    col = torch.arange(30, device=res.device).view(1, 1, 30)

    res_reshape = torch.zeros_like(model_output)
    res_reshape.copy_(res[0, ch_base + row_base + col])

    print((res_reshape - model_output).abs().sum())
    
# Compare the packed convolution result against the reference conv1 output.
check_correct(res=conv_res, image_ori=image_ori, model=model)

tensor(0.0010, grad_fn=<SumBackward0>)


## FC

### Convert weights

In [39]:
weights_flatten = model.fc.weight.data
weights = torch.zeros((16, num_slots))
bias = torch.zeros((1, num_slots))

# The ten FC biases occupy the first ten entries of the bias row.
bias[0, :10] = model.fc.bias.data

# Re-lay each FC row from dense (8, 30, 30) feature order into the packed
# layout: 1024 entries per channel and 32 per row, with the last two
# rows/columns of every 32 x 32 tile left at zero. Rows 10..15 stay all-zero.
packed_tiles = weights[:10, :8 * 1024].view(10, 8, 32, 32)
packed_tiles[:, :, :30, :30] = weights_flatten.reshape(10, 8, 30, 30)

In [40]:
# Raw dump of the packed (16, num_slots) FC weight matrix; only rows 0..9 are non-zero.
weights.numpy().tofile('fastest/fc-weights.bin')

In [41]:
# Diagonal re-layout of the packed FC matrix:
#     weights_store[i, j] = weights[j % n_rows, (i + j) % n_slots]
# Sizes come from `weights` itself rather than repeating 16 / 8192, so this
# cell stays consistent with num_slots. A single gather replaces the
# 131,072 per-element Python assignments of the nested loop.
n_rows, n_slots = weights.shape
slot_idx = torch.arange(n_slots)                # j, shape (n_slots,)
diag_idx = torch.arange(n_rows).unsqueeze(1)    # i, shape (n_rows, 1)
weights_store = weights[slot_idx % n_rows, (diag_idx + slot_idx) % n_slots]

In [42]:
# One text file per row of weights_store, plus the packed bias row.
for i, diagonal in enumerate(weights_store):
    np.savetxt(f'fastest/fc-c{i}.bin', diagonal, delimiter=',')

np.savetxt('fastest/fc-bias.bin', bias[0], delimiter=',')

In [43]:
# Square activation applied to the packed convolution output.
feature = conv_res ** 2

# Accumulate row i of weights_store times the feature row shifted by i, in row order.
final_res = torch.zeros((1, num_slots))
for shift, diagonal in enumerate(weights_store):
    final_res = final_res + diagonal * torch.roll(feature[0], -shift)

# Rotate-and-add with halving shifts 4096, 2048, ..., 16. With 8192 slots,
# every slot ends up holding the sum of all slots that share its index modulo 16.
rolls = [4096 >> step for step in range(9)]
for r in rolls:
    final_res = final_res + torch.roll(final_res, -r)

final_res = final_res + bias

In [44]:
# First ten slots of the packed result: one per FC output (the bias was written to slots 0..9).
final_res[0,:10]

tensor([ -4.8011,  -2.6713,   8.1740,  18.7194,   2.7581, -10.6273,  -1.4387,
        -10.3061,  -5.7208,   5.1434])

In [45]:
# Reference: ordinary forward pass of the PyTorch model on the same image.
model(image_ori)

tensor([[ -4.8011,  -2.6713,   8.1740,  18.7194,   2.7581, -10.6273,  -1.4387,
         -10.3061,  -5.7208,   5.1434]], grad_fn=<AddmmBackward0>)

In [46]:
# Shape of the packed FC weight matrix: 16 rows of num_slots entries (rows 10..15 are zero).
weights.shape

torch.Size([16, 8192])