In [1]:
import numpy as np

from torch.utils.data import DataLoader
from torchvision import datasets

from tqdm.notebook import tqdm

import MST
import MST.InplaceModules as inM

In [2]:
BATCH_SIZE = 128
NUM_WORKERS = 0

EPOCHS = 10

In [3]:
def to_np_arr(a):
    a_ = MST.MDT_ARRAY(a).astype(np.float32) / 255
    return a_.reshape(1, *a_.shape)

In [4]:
transform = to_np_arr

train_dataset = datasets.MNIST(
    root='datasets',
    train=True,
    transform=transform,
    download=True
)

test_dataset = datasets.MNIST(
    root='datasets',
    train=False,
    transform=transform,
    download=True
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [5]:
class MyNet(MST.BasicModule):
    def __init__(self):
        super().__init__()
        self.conv1 = MST.Conv2d(1, 6, 5)
        self.pool_1 = MST.MaxPool2d(2, 2)
        self.conv2 = MST.Conv2d(6, 16, 5)
        self.pool_2 = MST.MaxPool2d(2, 2)

        self.fc1 = MST.Linear(16 * 4 * 4, 120)
        self.fc2 = MST.Linear(120, 84)
        self.fc3 = MST.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)

        # with torch.no_grad():
        #     torchConv = nn.Conv2d(1, 6, 5)

        #     torchConv.weight = nn.Parameter(torch.tensor(self.conv1._w.astype(np.float32)))
        #     torchConv.bias = nn.Parameter(torch.tensor(self.conv1._bias.astype(np.float32)).flatten())
        #     nnres = torchConv(torch.Tensor(inX))
        #     nnres = nnres.numpy().round(3).flatten()
        #     myres = x.round(3).flatten()
        #     bads = np.where(abs(nnres - myres) > 0.001)
        #     print(f"bads: {len(bads[0])}")
        #     for pos in bads[0][:10]:
        #         print(f"[{pos}]: {myres[pos]:.3f} {nnres[pos]:.3f}")
                
        x = self.pool_1(inM.Relu(x))
        x = self.pool_2(inM.Relu(self.conv2(x)))
        x = inM.flatten(x) # flatten all dimensions except batch
        x = inM.Relu(self.fc1(x))
        x = inM.Relu(self.fc2(x))
        x = self.fc3(x)
        return x

CELoss = MST.CrossEntropyLoss() # MSE | 
net = MyNet()

optimizer = MST.SGD(net, lr=0.01)
print(net)

MyNet:
 └── conv1: Trainable(True)
 └── pool_1: Trainable(False)
 └── conv2: Trainable(True)
 └── pool_2: Trainable(False)
 └── fc1: Trainable(True)
 └── fc2: Trainable(True)
 └── fc3: Trainable(True)



In [6]:
def train(net : MST.BasicModule, optimizer : MST.SGD, criterion : MST.BasicModule):
    running_loss = 0
    for ind, (images, labels) in enumerate(pbar := tqdm(train_dataloader)):
        images = MST.MDT_ARRAY(images)
        labels = MST.MDT_ARRAY(labels)

        output = net(images)
        loss = criterion(output, labels)

        loss.backward()
        
        optimizer.step()
        pbar.set_description(f"loss: {loss.round(10)}")
        running_loss += loss
    train_loss = running_loss / len(train_dataloader)
    return train_loss


def valid(net : MST.BasicModule, criterion : MST.BasicModule):
    running_loss = 0
    correct_total = 0
    for images, labels in test_dataloader:
        images = MST.MDT_ARRAY(images)
        labels = MST.MDT_ARRAY(labels)
        output = net(images)
        loss = criterion(output, labels)
        running_loss += loss

        pred = np.argmax(output, axis=1, keepdims=True)
        correct_total += np.sum(pred[pred==labels.reshape(-1, 1)].astype(bool))
        
    rec = correct_total / len(test_dataloader.dataset)
    valid_loss = running_loss / len(test_dataloader)
    return valid_loss, rec

In [7]:
for epoch in (pbar := tqdm(range(EPOCHS))):
    train_loss = train(net, optimizer, CELoss)
    valid_loss, rec = valid(net, CELoss)

    print(f"[{epoch}] train/valid loss: {train_loss:.4f}/{valid_loss:.4f} rec: {rec:.4f}")
    pbar.set_description(f"train/valid loss: {train_loss:.4f}/{valid_loss:.4f} rec: {rec:.4f}")

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/469 [00:00<?, ?it/s]

[0] train/valid loss: 0.7119/0.4235 rec: 0.7678


  0%|          | 0/469 [00:00<?, ?it/s]