In [1]:
import numpy as np

from torch.utils.data import DataLoader
from torchvision import datasets

from tqdm.notebook import tqdm

import MST
import MST.InplaceModules as inM
import torch.nn as nn
import torch
import torch.nn.functional as F

In [2]:
BATCH_SIZE = 32
NUM_WORKERS = 0

EPOCHS = 10

In [3]:
def to_np_arr(a):
    return MST.MDT_ARRAY(a) / 255

In [4]:
transform = to_np_arr

train_dataset = datasets.MNIST(
    root='datasets',
    train=True,
    transform=transform,
    download=True
)

test_dataset = datasets.MNIST(
    root='datasets',
    train=False,
    transform=transform,
    download=True
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [5]:
class MyNet(MST.BasicModule):
    def __init__(self):
        super().__init__()

        self.fc1_1 = MST.Sequential(
            MST.Linear(28*28, 100),
            MST.Relu()
        )

        self.fc1_2 = MST.Sequential(
            MST.Linear(28*28, 100),
            MST.Relu()
        )
        
        self.fc2 = MST.Linear(100, 50)
        self.fc3 = MST.Linear(50, 10)

    def forward(self, x):
        x = inM.flatten(x)
        
        x = inM.sum(self.fc1_1(x), self.fc1_2(x))
        
        x = self.fc2(x)
        x = inM.relu(x)
        x = self.fc3(x)
        x = inM.relu(x)
        return x


CELoss = MST.CrossEntropyLoss()
net = MyNet()

optimizer = MST.SGD(net, lr=0.01, momentum=0.9)
print(net)

MyNet:
 └── fc1_1 (Sequential): 
	 └── Linear: Trainable(True) 
	 └── Relu: Trainable(False) 

 └── fc1_2 (Sequential): 
	 └── Linear: Trainable(True) 
	 └── Relu: Trainable(False) 

 └── fc2 (Linear): Trainable(True) 
 └── fc3 (Linear): Trainable(True) 



In [6]:
def train(net : MST.BasicModule, optimizer : MST.SGD, criterion : MST.BasicModule):
    running_loss = 0
    for ind, (images, labels) in enumerate(pbar := tqdm(train_dataloader)):
        images = MST.MDT_ARRAY(images)
        labels = MST.MDT_ARRAY(labels)

        output = net(images)
        loss = criterion(output, labels)
        
        loss.backward()
        optimizer.step()

        running_loss += loss
        pbar.set_description(f"loss: {loss.round(4)}")
    train_loss = running_loss / len(train_dataloader)
    return train_loss


def valid(net : MST.BasicModule, criterion : MST.BasicModule):
    running_loss = 0
    correct_total = 0
    for images, labels in test_dataloader:
        images = MST.MDT_ARRAY(images)
        labels = MST.MDT_ARRAY(labels)

        output = net(images)

        loss = criterion(output, labels)
        running_loss += loss

        pred = np.argmax(output, axis=1, keepdims=True)
        correct_total += np.sum(pred[pred==labels.reshape(-1, 1)].astype(bool))
        
    rec = correct_total / len(test_dataloader.dataset)
    valid_loss = running_loss / len(test_dataloader)
    return valid_loss, rec

In [7]:
for epoch in (pbar := tqdm(range(EPOCHS))):
    train_loss = train(net, optimizer, CELoss)
    valid_loss, rec = valid(net, CELoss)

    print(f"[{epoch}] train/valid loss: {train_loss:.4f}/{valid_loss:.4f} acc: {rec:.4f}")
    pbar.set_description(f"train/valid loss: {train_loss:.4f}/{valid_loss:.4f} acc: {rec:.4f}")

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1875 [00:00<?, ?it/s]

[0] train/valid loss: 0.3371/0.1662 acc: 0.8554


  0%|          | 0/1875 [00:00<?, ?it/s]

[1] train/valid loss: 0.1334/0.1392 acc: 0.8629


  0%|          | 0/1875 [00:00<?, ?it/s]

[2] train/valid loss: 0.0947/0.1258 acc: 0.8645


  0%|          | 0/1875 [00:00<?, ?it/s]

[3] train/valid loss: 0.0711/0.1178 acc: 0.8688


  0%|          | 0/1875 [00:00<?, ?it/s]

[4] train/valid loss: 0.0564/0.1043 acc: 0.8738


  0%|          | 0/1875 [00:00<?, ?it/s]

[5] train/valid loss: 0.0438/0.1044 acc: 0.8732


  0%|          | 0/1875 [00:00<?, ?it/s]

[6] train/valid loss: 0.0345/0.1067 acc: 0.8764


  0%|          | 0/1875 [00:00<?, ?it/s]

[7] train/valid loss: 0.0284/0.1147 acc: 0.8740


  0%|          | 0/1875 [00:00<?, ?it/s]

[8] train/valid loss: 0.0228/0.1103 acc: 0.8748


  0%|          | 0/1875 [00:00<?, ?it/s]

[9] train/valid loss: 0.0166/0.1095 acc: 0.8754
