In [1]:
import numpy as np

from torch.utils.data import DataLoader
from torchvision import datasets

from tqdm.notebook import tqdm

import MST
import MST.InplaceModules as inM
import torch.nn as nn
import torch
import torch.nn.functional as F

In [2]:
BATCH_SIZE = 100
NUM_WORKERS = 0

EPOCHS = 10

In [3]:
def to_np_arr(a):
    a_ = MST.MDT_ARRAY(a).astype(np.float32) / 255
    return a_.reshape(1, *a_.shape)

In [4]:
transform = to_np_arr

train_dataset = datasets.MNIST(
    root='datasets',
    train=True,
    transform=transform,
    download=True
)

test_dataset = datasets.MNIST(
    root='datasets',
    train=False,
    transform=transform,
    download=True
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [5]:
class MyNet(MST.BasicModule):
    def __init__(self):
        super().__init__()
        self.net = MST.Sequential(
            MST.Conv2d(1, 6, 5, 2, 2),
            MST.ReLU(),
            MST.Conv2d(6, 8, 3, 1, 1),
            MST.ReLU(),
            MST.Flatten(),
            MST.Linear(14*14*8, 280),
            MST.ReLU(),
            MST.Linear(280, 100),
            MST.ReLU(),
            MST.Linear(100, 50),
            MST.ReLU(),
            MST.Linear(50, 10),
        )

    def forward(self, x):
        x = self.net(x)
        return x


CELoss = MST.CrossEntropyLoss()
net = MyNet()

optimizer = MST.SGD(net, lr=0.01, momentum=0.9, weight_decay=0.0001)
print(net)

MyNet:
 └── net (Sequential): 
	 └── Conv2d: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Conv2d: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Flatten: Trainable(False) 
	 └── Linear: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Linear: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Linear: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Linear: Trainable(True) 




In [6]:
def train(net : MST.BasicModule, optimizer : MST.SGD, criterion : MST.BasicModule):
    running_loss = 0
    correct_total = 0
    for ind, (images, labels) in enumerate(pbar := tqdm(train_dataloader)):
        images = MST.MDT_ARRAY(images)
        la = labels
        labels = MST.MDT_ARRAY(labels)

        output = net(images)
        loss = criterion(inM.softmax(output), labels)
        loss_da = F.cross_entropy(torch.Tensor(output), la)
        
        loss.backward()
        optimizer.step()

        running_loss += loss

        pred = np.argmax(output, axis=1, keepdims=True)
        correct_total_bt = np.sum(pred[pred==labels.reshape(-1, 1)].astype(bool))
        correct_total += correct_total_bt
        pbar.set_description(f"my {loss:.4f} nn {loss_da:.4f} diff: {(loss_da-loss):.4f} acc: {(correct_total_bt/len(images)):.2f}")

    train_loss = running_loss / len(train_dataloader)
    return train_loss


def valid(net : MST.BasicModule, criterion : MST.BasicModule):
    running_loss = 0
    correct_total = 0
    for images, labels in test_dataloader:
        images = MST.MDT_ARRAY(images)
        labels = MST.MDT_ARRAY(labels)

        output = net(images)

        loss = criterion(inM.softmax(output), labels)
        running_loss += loss

        pred = np.argmax(output, axis=1, keepdims=True)
        correct_total += np.sum(pred[pred==labels.reshape(-1, 1)].astype(bool))
        
    rec = correct_total / len(test_dataloader.dataset)
    valid_loss = running_loss / len(test_dataloader)
    return valid_loss, rec

In [7]:
for epoch in (pbar := tqdm(range(EPOCHS))):
    train_loss = train(net, optimizer, CELoss)
    valid_loss, rec = valid(net, CELoss)

    print(f"[{epoch}] train/valid loss: {train_loss:.4f}/{valid_loss:.4f} acc: {rec:.4f}")
    pbar.set_description(f"train/valid loss: {train_loss:.4f}/{valid_loss:.4f} acc: {rec:.4f}")

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/600 [00:00<?, ?it/s]

[0] train/valid loss: 2.9942/1.8888 acc: 0.6580


  0%|          | 0/600 [00:00<?, ?it/s]

[1] train/valid loss: 0.4544/0.2671 acc: 0.8417


  0%|          | 0/600 [00:00<?, ?it/s]

[2] train/valid loss: 0.2408/0.2244 acc: 0.8489


  0%|          | 0/600 [00:00<?, ?it/s]

[3] train/valid loss: 0.2196/0.2255 acc: 0.8472


  0%|          | 0/600 [00:00<?, ?it/s]

[4] train/valid loss: 0.2308/0.2297 acc: 0.8492


  0%|          | 0/600 [00:00<?, ?it/s]

[5] train/valid loss: 0.2016/0.1876 acc: 0.8551


  0%|          | 0/600 [00:00<?, ?it/s]

[6] train/valid loss: 0.1565/0.1782 acc: 0.8546


  0%|          | 0/600 [00:00<?, ?it/s]

[7] train/valid loss: 0.1604/0.1563 acc: 0.8640


  0%|          | 0/600 [00:00<?, ?it/s]

[8] train/valid loss: 0.1593/0.1633 acc: 0.8614


  0%|          | 0/600 [00:00<?, ?it/s]

[9] train/valid loss: 0.1451/0.1609 acc: 0.8593
