In [1]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from tqdm.auto import tqdm

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Load the training split; DATA[0] is used as the images and DATA[1] as the labels.
# NOTE(review): torch.load unpickles the file, so only load dataset files you trust.
DATA = torch.load("./dataset/training.pt")
# Scale the stored pixel values by 1/255 (presumably 8-bit intensities -- TODO confirm).
imgs = DATA[0]/255
# Flatten each image into a 784-element vector. Using -1 lets reshape infer the
# number of samples instead of hard-coding 60000, so a differently sized split
# still loads.
imgs = imgs.reshape(-1, 784)

In [4]:
# Load the held-out split; TEST_DATA[0] is used as the images and TEST_DATA[1] as the labels.
# NOTE(review): torch.load unpickles the file, so only load dataset files you trust.
TEST_DATA = torch.load("./dataset/test.pt")
# Scale the stored pixel values by 1/255 (presumably 8-bit intensities -- TODO confirm).
test_imgs = TEST_DATA[0]/255
# Flatten each image into a 784-element vector. Using -1 lets reshape infer the
# number of samples instead of hard-coding 10000.
test_imgs = test_imgs.reshape(-1, 784)

In [5]:
# Pair every flattened training image with its label (second entry of DATA) so
# that indexing the dataset yields an (image, label) tuple.
train_labels = DATA[1]
ds = TensorDataset(imgs, train_labels)

In [6]:
# Pair every flattened test image with its label (second entry of TEST_DATA) so
# that indexing the dataset yields an (image, label) tuple.
test_labels = TEST_DATA[1]
test_ds = TensorDataset(test_imgs, test_labels)

In [7]:
# Serve the training set in reshuffled mini-batches of 128 examples.
trainloader = DataLoader(dataset=ds, batch_size=128, shuffle=True)

In [8]:
# Fully connected classifier: 784 inputs, four ReLU-activated layers of 256
# units each, and a final linear layer producing 10 outputs.
layer_widths = [784, 256, 256, 256, 256]
layers = []
for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    # Layers are created in the same order as a hand-written stack, so the
    # module indices inside the Sequential container are unchanged.
    layers.append(nn.Linear(n_in, n_out))
    layers.append(nn.ReLU())
layers.append(nn.Linear(layer_widths[-1], 10))

model = nn.Sequential(*layers)
model = model.to(device)

In [9]:
# Cross-entropy loss on the network outputs, minimised with Adam at its
# default hyper-parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [10]:
# Number of full passes over the training loader made by the training loop.
n_epochs = 5

In [11]:
def _count_correct(net, dataset, device, batch_size=1024):
    """Count the examples of ``dataset`` that ``net`` classifies correctly.

    Scoring is done in mini-batches rather than one example per forward pass,
    which is what makes per-epoch evaluation cheap.

    Args:
        net: Module mapping a ``(batch, features)`` input to per-class scores.
        dataset: Dataset yielding ``(input, label)`` pairs.
        device: Device the inputs and labels are moved to before scoring.
        batch_size: Number of examples scored per forward pass.

    Returns:
        Number of examples whose highest-scoring class equals the label.
    """
    was_training = net.training
    net.eval()
    hits = 0
    try:
        with torch.no_grad():
            for inputs, labels in DataLoader(dataset, batch_size=batch_size):
                scores = net(inputs.to(device))
                hits += (scores.argmax(dim=1) == labels.to(device)).sum().item()
    finally:
        # Leave the module in whatever mode the caller had it in.
        net.train(was_training)
    return hits


for epoch in tqdm(range(n_epochs)):

    # Per-batch training losses, averaged for the end-of-epoch report.
    losses = []

    for batch in tqdm(trainloader):

        optimizer.zero_grad()

        preds = model(batch[0].to(device))
        target = batch[1].to(device)

        loss = criterion(preds, target)
        loss.backward()

        losses.append(loss.item())

        optimizer.step()

    # Accuracy on both splits, measured after this epoch's updates.
    correct_preds = _count_correct(model, ds, device)
    test_correct_preds = _count_correct(model, test_ds, device)

    # Divide by the actual dataset sizes instead of hard-coded 60000 / 10000.
    print(f"Epoch: {epoch}; Avg loss: {sum(losses)/len(losses):.2f}; Train accuracy: {correct_preds/len(ds)*100:.2f}; Test accuracy: {test_correct_preds/len(test_ds)*100:.2f}")


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/469 [00:00<?, ?it/s]

Epoch: 0; Avg loss: 0.34; Train accuracy: 95.84; Test accuracy: 95.53


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch: 1; Avg loss: 0.11; Train accuracy: 97.73; Test accuracy: 97.06


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch: 2; Avg loss: 0.08; Train accuracy: 98.53; Test accuracy: 97.46


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch: 3; Avg loss: 0.06; Train accuracy: 98.57; Test accuracy: 97.48


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch: 4; Avg loss: 0.05; Train accuracy: 98.91; Test accuracy: 97.49


In [12]:
# One random 784-feature row, used as a dummy input for the cells that follow.
sample_input = torch.randn(size=(1, 784))

In [17]:
# Smoke-test the network on the random sample: move it to the model's device,
# drop the leading size-1 dimension, and display the resulting 10 outputs.
flat_sample = sample_input.to(device).squeeze()
model(flat_sample)

tensor([-7.8383, -0.2208,  4.3192,  0.7563, -3.5091,  3.5511, -9.7928,  3.7425,
        -4.6480, -8.4031], device='cuda:0', grad_fn=<ViewBackward0>)

In [21]:
# Export the trained network to ONNX, tracing it with the (1, 784) random sample
# as the example input (passed as an explicit one-element argument tuple).
export_args = (sample_input.to(device),)
onnx_program = torch.onnx.export(model, export_args, "256x4_scratch.onnx")

In [None]:
# NOTE(review): unexecuted scratch cell containing only the literal 1; nothing
# visible in the notebook depends on it -- candidate for deletion.
1