"SetUp" and "Original Definition and Training Sections" are inspired form <https://mike-12.medium.com/depthwise-separable-convolutions-simple-image-classification-with-pytorch-7f7d2ba06af7>.

# SetUp

In [1]:
import torch
print(torch.cuda.is_available())

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

import matplotlib.pyplot as plt
import torchvision
import numpy as np

False
cpu


In [29]:
from torch import optim, nn
from collections import OrderedDict

hidden_units = [4, 12]
output_units = 10
model_d = nn.Sequential(OrderedDict([
    ('conv1_depthwise', nn.Conv2d(1, 1, 3, stride=3, padding=1, groups=1, bias=False)),
    ('conv1_pointwise', nn.Conv2d(1, hidden_units[0], 1, bias=False)),
    ('Relu1', nn.ReLU()),
    ('conv2_depthwise', nn.Conv2d(hidden_units[0], hidden_units[0], 3, stride=3, padding=1, groups=hidden_units[0], bias=False)),
    ('conv2_pointwise', nn.Conv2d(hidden_units[0], hidden_units[1], 1, bias=False)),
    ('Relu2', nn.ReLU()),
    ('conv3_depthwise', nn.Conv2d(hidden_units[1], hidden_units[1], 4, stride=4, padding=0, groups=hidden_units[1], bias=False)),
    ('conv3_pointwise', nn.Conv2d(hidden_units[1], output_units, 1, bias=False)),
    ('log_softmax', nn.LogSoftmax(dim = 1))
]))

# Original Definition and Training

In [30]:
from torch.utils.data.dataset import Subset
from torch.utils.data import DataLoader
import torchvision
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train = Subset(dataset, torch.arange(10000))
test = Subset(dataset, torch.arange(10000, 11024))
batch_size = 128
trainloader = DataLoader(train, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test, batch_size=1, shuffle=False)

In [31]:
from collections import OrderedDict
from torch import optim, nn


class Flatten(nn.Module):
  def forward(self, input):
    return input.view(input.size(0), -1)


total_params = 0
for parameter in model_d.parameters():
  if parameter.requires_grad:
    total_params += np.prod(parameter.size())

print('total_params:', total_params)

model_d.to(device)
optimizer_d = optim.Adam(model_d.parameters(), lr = 0.02)

criterion = nn.NLLLoss()
epochs = 20
for i in range(epochs):
  running_classification_loss = 0
  running_cycle_consistent_loss = 0
  running_loss = 0
  for images, labels in trainloader:
    images, labels = images.to(device), labels.to(device)
    optimizer_d.zero_grad()

    # Run classification model
    predicted_labels = model_d(images)
    classification_loss = criterion(Flatten()(predicted_labels), labels)

    # Optimize classification weights
    classification_loss.backward()
    optimizer_d.step()

    running_classification_loss += classification_loss.item()
    running_loss = running_classification_loss
  else:
    print(f"{i} Training loss: {running_loss/len(trainloader)}")


total_params: 409
0 Training loss: 1.418622894377648
1 Training loss: 0.7422349347343927
2 Training loss: 0.6520623006398165
3 Training loss: 0.6151688664774352
4 Training loss: 0.6015881809252727
5 Training loss: 0.5960249742375144
6 Training loss: 0.5809037153479419
7 Training loss: 0.5652959682518923
8 Training loss: 0.5657608173316038
9 Training loss: 0.558564608610129
10 Training loss: 0.5448345394828652
11 Training loss: 0.535114960957177
12 Training loss: 0.5429439325875873
13 Training loss: 0.5405502696580524
14 Training loss: 0.5373978999596608
15 Training loss: 0.5311610185647313
16 Training loss: 0.5293893180316007
17 Training loss: 0.5276448485217516
18 Training loss: 0.5344730536394482
19 Training loss: 0.5300060665305657


In [32]:
total_correct = 0
total_num = 0
for images, labels in testloader:
  images, labels = images.to(device), labels.to(device)
  ps = Flatten()(torch.exp(model_d(images)))
  predictions = ps.topk(1, 1, True, True)[1].t()
  correct = predictions.eq(labels.view(1, -1))

  total_correct += correct.sum().cpu().numpy()
  total_num += images.shape[0]

print('Accuracy:', total_correct / float(total_num))

Accuracy: 0.841796875


In [33]:
for k in model_d.state_dict():
  print(k, model_d.state_dict()[k].size())

conv1_depthwise.weight torch.Size([1, 1, 3, 3])
conv1_pointwise.weight torch.Size([4, 1, 1, 1])
conv2_depthwise.weight torch.Size([4, 1, 3, 3])
conv2_pointwise.weight torch.Size([12, 4, 1, 1])
conv3_depthwise.weight torch.Size([12, 1, 4, 4])
conv3_pointwise.weight torch.Size([10, 12, 1, 1])


# Save Model Params

In [None]:
import json
# import pickle
# params = {}
# for k in model_d.state_dict():
#   params[k] = model_d.state_dict()[k].cpu().numpy()
#   pickle.dump(params[k], open(f'{k}.pk', 'wb'))
# pickle.dump(params, open('params.pk', 'wb'))

In [None]:
conv1_depthwise = model_d.state_dict()["conv1_depthwise.weight"].view(1, 3, 3)
conv1_pointwise = model_d.state_dict()["conv1_pointwise.weight"].view(4, 1)
conv2_depthwise = model_d.state_dict()["conv2_depthwise.weight"].view(4, 3, 3)
conv2_pointwise = model_d.state_dict()["conv2_pointwise.weight"].view(12, 4)
conv3_depthwise = model_d.state_dict()["conv3_depthwise.weight"].view(12, 4, 4)
conv3_pointwise = model_d.state_dict()["conv3_pointwise.weight"].view(10, 12)

In [None]:
params = {}
params["conv1_depthwise"] = conv1_depthwise.cpu().numpy().tolist()
params["conv1_pointwise"] = conv1_pointwise.cpu().numpy().tolist()
params["conv2_depthwise"] = conv2_depthwise.cpu().numpy().tolist()
params["conv2_pointwise"] = conv2_pointwise.cpu().numpy().tolist()
params["conv3_depthwise"] = conv3_depthwise.cpu().numpy().tolist()
params["conv3_pointwise"] = conv3_pointwise.cpu().numpy().tolist()
json.dump(params, open('params.json', 'w'))

In [None]:
torch.save(model_d.state_dict(), "model.pth")

# Read Model Params

In [34]:
import json
params = json.load(open('params.json'))
conv1_depthwise = torch.tensor(params["conv1_depthwise"])
conv1_pointwise = torch.tensor(params["conv1_pointwise"])
conv2_depthwise = torch.tensor(params["conv2_depthwise"])
conv2_pointwise = torch.tensor(params["conv2_pointwise"])
conv3_depthwise = torch.tensor(params["conv3_depthwise"])
conv3_pointwise = torch.tensor(params["conv3_pointwise"])

In [35]:
model_d.load_state_dict(torch.load("model.pth"))

<All keys matched successfully>

# Vitis HLS Like Implementation

## Definition

In [36]:
def pointwise(x, weight, size, in_ch, out_ch):
  # store weight to local buffer
  out = torch.zeros(size, size, out_ch)
  for py in range(size):
    for px in range(size):
      for kp in range(in_ch):
        read = x[py, px, kp] # stream in
        for l in range(out_ch):
            out[py, px, l] += read * weight[l, kp]
      for l in range(out_ch):
        if out[py, px, l] < 0:
          out[py, px, l] = 0
        # stream out
  return out

def depthwise(x, weight, size, in_ch):
  next_size = (size+2)//3
  out = torch.zeros(next_size, next_size, in_ch)
  x_pad = torch.zeros(size+2, size+2, in_ch)
  x_pad[1:size+1, 1:size+1, :] = x
  for py in range(next_size):
    for px in range(next_size):
      for l in range(in_ch):
        val = 0
        for ky in range(3):
          for kx in range(3):
            val += x_pad[py * 3 + ky, px * 3 + kx, l] * weight[l, ky, kx]
        out[py, px, l] = val
  return out

def depthwise_final(x, weight, size=4, in_ch=16):
  # store x, weight to local buffer
  next_size = size // 4
  out = torch.zeros(next_size, next_size, in_ch)
  for l in range(in_ch):
    val = 0
    for ky in range(4):
      for kx in range(4):
        val += x[ky, kx, l] * weight[l, ky, kx]
    out[0, 0, l] = val
    # stream out
  return out

In [37]:
# depthwise implementation like Vitis HLS
def depthwise(x, weight, size, in_ch):
  next_size = (size+2)//3
  out = torch.zeros(next_size, next_size, in_ch)
  line_buf = torch.zeros(2, size+2, in_ch)
  window = torch.zeros(3, 3, in_ch)
  x = x.view(size * size * in_ch)
  x_index = 0
  for py in range(next_size):
    #### init line_buf[0,:,:] ####
    if py != 0:
      for l in range(in_ch): # x = 0
        line_buf[0, 0, l] = 0
      for px in range(1, size+1):
        for l in range(in_ch):
          line_buf[0, px, l] = x[x_index]
          x_index += 1
      for l in range(in_ch): # x = -1
        line_buf[0, size+1, l] = 0
    else:
      for px in range(size+2):
        for l in range(in_ch):
          line_buf[0, px, l] = 0
    #### init line_buf[1,:,:] ####
    for l in range(in_ch): # x = 0
      line_buf[1, 0, l] = 0
    for px in range(1, size+1):
      for l in range(in_ch):
        line_buf[1, px, l] = x[x_index]
        x_index += 1
    for l in range(in_ch): # x = -1
      line_buf[1, size+1, l] = 0
    #### iterate ####
    for px in range(next_size):
        #### set window ####
        for ky in range(2):
          for kx in range(3):
            for l in range(in_ch):
              window[ky, kx, l] = line_buf[ky, px * 3 + kx, l]
        for kx in range(3):
          for l in range(in_ch):
            if (px == 0 and kx == 0) or (px == next_size - 1 and kx == 2) or (py == next_size - 1):
              window[2, kx, l] = 0
            else:
              window[2, kx, l] = x[x_index]
              x_index += 1
        #### convolution ####
        for l in range(in_ch):
          val = 0
          for ky in range(3):
            for kx in range(3):
              val += window[ky, kx, l] * weight[l, ky, kx]
          out[py, px, l] = val # stream out
  return out

## Test

In [38]:
def inf(x):
  x1 = depthwise(x, conv1_depthwise, 28, 1)
  x2 = pointwise(x1, conv1_pointwise, 10, 1, 4)
  x3 = depthwise(x2, conv2_depthwise, 10, 4)
  x4 = pointwise(x3, conv2_pointwise, 4, 4, 12)
  x5 = depthwise_final(x4, conv3_depthwise, 4, 12)
  x6 = pointwise(x5, conv3_pointwise, 1, 12, 10)
  return x6

In [39]:
total_correct = 0
total_num = 0
for image, label in testloader:
  image = image.view(28, 28, 1)
  res = inf(image)
  pred = torch.argmax(res)
  if pred == label:
    correct = 1
  else:
    correct = 0
  total_correct += correct
  total_num += 1
  if total_num % 10 == 0:
    print(total_num)
  if total_num == 100:
    break
print('Accuracy:', total_correct / float(total_num))

10
20
30
40
50
60
70
80
90
100
Accuracy: 0.87


# Prepare Data For PYNQ

In [42]:
import struct
def float_to_int(f):
    f = torch.tensor(f, dtype=torch.float16).clone().detach()
    packed = struct.pack('>f', f)
    unpacked = struct.unpack('>I', packed)[0]
    return unpacked

# intのビット列をfloatのビット列として解釈する関数
def int_to_float(i):
    packed = struct.pack('>I', i)
    unpacked = struct.unpack('>f', packed)[0]
    return unpacked


In [44]:
# Store Image and Label
images = []     # bit
images_py = []  # float
labels = []
for image, label in testloader:
  images_py.append(image.view(1, 1, 28, 28).cpu().numpy().tolist())
  image = image.view(28 * 28)
  image = [float_to_int(i) for i in image]
  label = label.cpu().numpy()[0].item()
  images.append(image)
  labels.append(label)
json.dump(images, open('images.json', 'w'))
json.dump(images, open('images_py.json', 'w'))
json.dump(labels, open('labels.json', 'w'))

  f = torch.tensor(f, dtype=torch.float16).clone().detach()
