In [3]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

from fxpmath import Fxp


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 256
learning_rate = 1e-3
epochs = 5

torch.manual_seed(1024)


<torch._C.Generator at 0x2f5e9a670d0>

In [5]:
training_data = datasets.MNIST(
    root="../data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="../data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loaders.
train_dl = DataLoader(training_data, batch_size=batch_size)
test_dl = DataLoader(test_data, batch_size=batch_size)


In [6]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU())
        self.bn1 = nn.BatchNorm2d(6)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, kernel_size=5), nn.ReLU())
        self.bn2 = nn.BatchNorm2d(16)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 10, kernel_size=5),
            #    nn.ReLU(),
            nn.Flatten(),
        )
        
        self.checkpoint = {}
        

    def forward(self, x):
        x = self.conv1(x)
        self.checkpoint['conv1_out'] = x.detach()
        
        x = self.bn1(x)
        self.checkpoint['bn1_out'] = x.detach()
        
        x = self.maxpool1(x)
        self.checkpoint['pool1_out'] = x.detach()
        
        x = self.conv2(x)
        self.checkpoint['conv2_out'] = x.detach()
        
        x = self.bn2(x)
        self.checkpoint['bn2_out'] = x.detach()
        
        x = self.maxpool2(x)
        self.checkpoint['pool2_out'] = x.detach()
        
        x = self.conv3(x)
        self.checkpoint['conv3_out'] = x.detach()
        # out = F.softmax(x, dim=1)
        return x


# net = LeNet().to(device)


## Train

In [6]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)


In [7]:
def train_loop(dataloader, net, loss_fn, optim, device="cuda"):
    for batch_i, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = net(X)
        loss = loss_fn(pred, y)

        optim.zero_grad()
        loss.backward()
        optim.step()

        if batch_i % 100 == 0:
            loss, current = loss.item(), batch_i * len(X)
            print(f"Batch: {batch_i}, loss: {loss}")


def test_loop(dataloader, model, loss_fn, device="cuda"):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad()():
        for data, label in dataloader:
            data = data.to(device)
            label = label.to(device)
            pred = model(data)
            test_loss += loss_fn(pred, label).item()
            correct += (pred.argmax(1) == label).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )

    return correct * 100


In [8]:
for i in range(epochs):
    print(f"==== Epoch: {i + 1} ====")
    train_loop(train_dl, net, loss_fn, optimizer)
    test_loop(test_dl, net, loss_fn)


==== Epoch: 1 ====
Batch: 0, loss: 2.705130100250244
Batch: 100, loss: 0.17798461019992828
Batch: 200, loss: 0.12561091780662537
Test Error: 
 Accuracy: 97.3%, Avg loss: 0.087469 

==== Epoch: 2 ====
Batch: 0, loss: 0.13762414455413818
Batch: 100, loss: 0.08944994956254959
Batch: 200, loss: 0.07882432639598846
Test Error: 
 Accuracy: 98.1%, Avg loss: 0.055568 

==== Epoch: 3 ====
Batch: 0, loss: 0.08388431370258331
Batch: 100, loss: 0.06753334403038025
Batch: 200, loss: 0.06397119164466858
Test Error: 
 Accuracy: 98.5%, Avg loss: 0.046443 

==== Epoch: 4 ====
Batch: 0, loss: 0.05930690839886665
Batch: 100, loss: 0.054147519171237946
Batch: 200, loss: 0.054667338728904724
Test Error: 
 Accuracy: 98.7%, Avg loss: 0.040283 

==== Epoch: 5 ====
Batch: 0, loss: 0.04563910886645317
Batch: 100, loss: 0.04726320505142212
Batch: 200, loss: 0.04988543689250946
Test Error: 
 Accuracy: 98.7%, Avg loss: 0.037014 



In [9]:
net.state_dict()

OrderedDict([('conv1.0.weight',
              tensor([[[[-0.1324,  0.0917,  0.0890,  0.0653,  0.0357],
                        [ 0.0262, -0.1214, -0.1947,  0.0655, -0.1628],
                        [ 0.1376, -0.2074, -0.2030, -0.0333, -0.0549],
                        [ 0.0862,  0.1592,  0.1292,  0.1540, -0.0691],
                        [ 0.0040,  0.1715,  0.1500, -0.1723, -0.1452]]],
              
              
                      [[[ 0.1745, -0.0229,  0.1721, -0.0853, -0.0319],
                        [-0.0356,  0.1490,  0.2821, -0.0607, -0.0117],
                        [-0.1086,  0.0381, -0.0355,  0.2341,  0.0456],
                        [-0.0961, -0.2282,  0.0220, -0.1647,  0.0810],
                        [-0.0652, -0.0898, -0.0577,  0.0924, -0.0399]]],
              
              
                      [[[ 0.0894,  0.1697,  0.1733,  0.0118, -0.1179],
                        [ 0.0818,  0.1925, -0.0965, -0.0310, -0.2178],
                        [ 0.1812,  0.0299, -0.1077, 

In [10]:
# torch.save(net.state_dict(), "new_params.pth")


## Process

In [7]:
device = "cpu"
net = LeNet()


In [8]:
net.load_state_dict(torch.load("new_params.pth"))


<All keys matched successfully>

In [9]:
test_data = torch.tensor(np.load("test_data.npy")[2:30, 2:30]).reshape(1, 1, 28, 28)
net(test_data)

tensor([[ -3.2612,  -1.4909,   1.2582,  -0.2114,  -5.8418,  -4.3883, -12.3146,
          16.0718,  -5.1825,   2.0446]], grad_fn=<ReshapeAliasBackward0>)

In [10]:
net.checkpoint

{'conv1_out': tensor([[[[0.1943, 0.1943, 0.1943,  ..., 0.1943, 0.1943, 0.1943],
           [0.1943, 0.1943, 0.1943,  ..., 0.1943, 0.1943, 0.1943],
           [0.1943, 0.1943, 0.1943,  ..., 0.1943, 0.1943, 0.1943],
           ...,
           [0.1943, 0.1943, 0.1943,  ..., 0.1943, 0.1943, 0.1943],
           [0.1943, 0.1943, 0.1943,  ..., 0.1943, 0.1943, 0.1943],
           [0.1943, 0.1943, 0.1943,  ..., 0.1943, 0.1943, 0.1943]],
 
          [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
          [[0.1438, 0.1438, 0.1438,  ..., 0.1438, 0.1438, 0.1438],
           [0.1438, 0.1438, 0.1438,  ..., 0.1438, 0.1438, 0.1438],
           