In [1]:
import torch
import numpy as np

# Report whether a CUDA device is visible to PyTorch.
print(torch.cuda.is_available())

# Use the first CUDA device when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

False
cpu


In [2]:
import json
# Read the layer weights from disk. The context manager guarantees the file
# handle is closed as soon as parsing finishes (a bare open() leaves it to the GC).
with open('assets/params.json') as params_file:
    params = json.load(params_file)

# One tensor per layer; the depthwise/pointwise pairs are consumed in order by
# the inference code further down.
conv1_depthwise = torch.tensor(params["conv1_depthwise"])
conv1_pointwise = torch.tensor(params["conv1_pointwise"])
conv2_depthwise = torch.tensor(params["conv2_depthwise"])
conv2_pointwise = torch.tensor(params["conv2_pointwise"])
conv3_depthwise = torch.tensor(params["conv3_depthwise"])
conv3_pointwise = torch.tensor(params["conv3_pointwise"])

In [3]:
import json
# Load the evaluation inputs and their labels; `with` closes each file handle
# deterministically instead of leaving it open until garbage collection.
with open("assets/images_py.json") as images_file:
    images = json.load(images_file)
with open("assets/labels.json") as labels_file:
    labels = json.load(labels_file)

# Reshape every image to height x width x channel = 28 x 28 x 1, the layout
# the convolution helpers index as x[row, col, channel].
images_t = [torch.tensor(image).view(28, 28, 1) for image in images]

In [4]:
def pointwise(x, weight, size, in_ch, out_ch):
  """Per-pixel (1x1) channel-mixing convolution followed by ReLU.

  Args:
    x: input indexed as x[row, col, channel]; the first `size` rows/columns
      and the first `in_ch` channels are read.
    weight: mixing matrix indexed as weight[out_channel, in_channel].
    size: number of rows and of columns processed (square input).
    in_ch: number of input channels.
    out_ch: number of output channels.

  Returns:
    A new tensor of shape (size, size, out_ch) with negative values clamped to 0.
  """
  # Original note: the weights are meant to be stored in a local buffer.
  out = torch.zeros(size, size, out_ch)
  for row in range(size):
    for col in range(size):
      # View into `out` for this pixel; writes through it update `out` in place.
      acc = out[row, col]
      for src in range(in_ch):
        sample = x[row, col, src]  # stream in
        for dst in range(out_ch):
          acc[dst] += sample * weight[dst, src]
      # ReLU once the pixel's accumulation over all input channels is complete.
      for dst in range(out_ch):
        if acc[dst] < 0:
          acc[dst] = 0
      # stream out
  return out

def depthwise(x, weight, size, in_ch):
  """Per-channel 3x3 convolution with stride 3 and one pixel of zero padding.

  Args:
    x: input assigned into the interior of a (size+2, size+2, in_ch) buffer,
      so it must be broadcastable to (size, size, in_ch).
    weight: kernels indexed as weight[channel, ky, kx] with ky, kx in 0..2.
    size: number of rows and of columns of the (square) input.
    in_ch: number of channels; each channel is filtered independently.

  Returns:
    A new tensor of shape ((size+2)//3, (size+2)//3, in_ch). No activation
    is applied.
  """
  out_size = (size + 2) // 3

  # Zero border of width 1 around the input.
  padded = torch.zeros(size + 2, size + 2, in_ch)
  padded[1:-1, 1:-1, :] = x

  out = torch.zeros(out_size, out_size, in_ch)
  for oy in range(out_size):
    top = oy * 3
    for ox in range(out_size):
      left = ox * 3
      for ch in range(in_ch):
        acc = 0
        # Walk the 9 kernel taps in row-major order (ky outer, kx inner).
        for tap in range(9):
          dy, dx = divmod(tap, 3)
          acc += padded[top + dy, left + dx, ch] * weight[ch, dy, dx]
        out[oy, ox, ch] = acc
  return out

def depthwise_final(x, weight, size=4, in_ch=16):
  """Per-channel 4x4 convolution over non-overlapping 4x4 windows (stride 4, no padding).

  With the default size=4 this collapses each channel to a single value,
  exactly as before. For larger sizes every one of the (size//4)**2 output
  positions is now computed; previously only out[0, 0] was written and the
  remaining positions were silently left at zero.

  Args:
    x: input indexed as x[row, col, channel].
    weight: kernels indexed as weight[channel, ky, kx] with ky, kx in 0..3.
    size: number of rows and of columns of the (square) input. Trailing
      rows/columns beyond the last full 4x4 window are ignored.
    in_ch: number of channels; each channel is filtered independently.

  Returns:
    A new tensor of shape (size//4, size//4, in_ch). No activation is
    applied. When size < 4 the result has zero rows and columns.
  """
  # Original note: x and weight are meant to be stored in a local buffer.
  kernel = 4
  next_size = size // kernel
  out = torch.zeros(next_size, next_size, in_ch)
  for py in range(next_size):
    for px in range(next_size):
      for l in range(in_ch):
        val = 0
        for ky in range(kernel):
          for kx in range(kernel):
            val += x[py * kernel + ky, px * kernel + kx, l] * weight[l, ky, kx]
        out[py, px, l] = val
        # stream out
  return out

In [5]:
def inf(x):
  """Run one image through the three depthwise/pointwise stages.

  Args:
    x: a single input passed to the first depthwise stage with size 28 and
      1 channel.

  Returns:
    The output of the last pointwise stage: a (1, 1, 10) tensor. Note that
    `pointwise` clamps negatives to zero, so these scores are non-negative.
  """
  # Stage 1: 28x28x1 -> 10x10x1 -> 10x10x4
  feat = depthwise(x, conv1_depthwise, 28, 1)
  feat = pointwise(feat, conv1_pointwise, 10, 1, 4)

  # Stage 2: 10x10x4 -> 4x4x4 -> 4x4x12
  feat = depthwise(feat, conv2_depthwise, 10, 4)
  feat = pointwise(feat, conv2_pointwise, 4, 4, 12)

  # Stage 3: 4x4x12 -> 1x1x12 -> 1x1x10
  feat = depthwise_final(feat, conv3_depthwise, 4, 12)
  return pointwise(feat, conv3_pointwise, 1, 12, 10)

In [7]:
import time

# Evaluate every image, counting correct predictions and timing the whole run.
total_correct = 0
total_num = 0
# perf_counter is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() can jump if the system clock is adjusted.
t0 = time.perf_counter()
for image, label in zip(images_t, labels):
  res = inf(image)
  # argmax over the flattened output gives the index of the highest score.
  pred = torch.argmax(res)
  # int(...) keeps the tally a plain Python int; adding the comparison result
  # directly would turn `total_correct` into a tensor.
  total_correct += int(pred == label)
  total_num += 1
  if total_num % 100 == 0:
    print(total_num)
t1 = time.perf_counter()
print("Total elapsed time:", t1-t0, "s")
if total_num:
  print("Elapsed time per picture:", ((t1-t0) / total_num) * 1000, "ms")
  print('Accuracy:', total_correct / float(total_num))
else:
  # Avoid dividing by zero when there is nothing to evaluate.
  print("No images were evaluated.")

100
200
300
400
500
600
700
800
900
1000
Total elapsed time: 477.6563947200775 s
Elapsed time per picture: 466.4613229688257 ms
Accuracy: tensor(0.9053)
