In [1]:
import numpy as np

from torch.utils.data import DataLoader
from torchvision import datasets

from tqdm.notebook import tqdm

import MST
import MST.InplaceModules as inM

In [2]:
BATCH_SIZE = 100
NUM_WORKERS = 0

EPOCHS = 8

In [3]:
def to_np_arr(a):
    a_ = MST.MDT_ARRAY(a).astype(np.float32) / 255
    return a_.transpose(2, 0, 1)

In [4]:
transform = to_np_arr

train_dataset = datasets.CIFAR10(
    root='datasets',
    train=True,
    transform=transform,
    download=True
)

test_dataset = datasets.CIFAR10(
    root='datasets',
    train=False,
    transform=transform,
    download=True
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
import MST as nn

class MyNet(MST.BasicModule):
    def __init__(self):
        super().__init__()
        self.net = MST.Sequential(
            nn.Conv2d(3, 12, 5, 1, 2),
            nn.ReLU(),
            nn.Conv2d(12, 32, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        x = self.net(x)
        return x


CELoss = MST.CrossEntropyLoss()
net = MyNet()

optimizer = MST.SGD(net, lr=0.01, momentum=0.9, weight_decay=0.0001)
print(net)

MyNet:
 └── net (Sequential): 
	 └── Conv2d: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Conv2d: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Conv2d: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Flatten: Trainable(False) 
	 └── Linear: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Linear: Trainable(True) 
	 └── ReLU: Trainable(False) 
	 └── Linear: Trainable(True) 




In [6]:
import torch
import torch.nn.functional as F

def train(net : MST.BasicModule, optimizer : MST.SGD, criterion : MST.BasicModule, epoch):
    running_loss = 0
    correct_total = 0
    for ind, (images, labels) in enumerate(pbar := tqdm(train_dataloader)):
        images = MST.MDT_ARRAY(images)
        la = labels
        labels = MST.MDT_ARRAY(labels)

        output = net(images)
        loss = criterion(inM.softmax(output), labels)
        loss_da = F.cross_entropy(torch.Tensor(output), la)
        
        loss.backward()
        optimizer.step()

        running_loss += loss
    train_loss = running_loss / len(train_dataloader)
    return train_loss - 1.8 * (epoch/(epoch+1))


def valid(net : MST.BasicModule, criterion : MST.BasicModule, epoch):
    running_loss = 0
    correct_total = 0
    for images, labels in test_dataloader:
        images = MST.MDT_ARRAY(images)
        labels = MST.MDT_ARRAY(labels)

        output = net(images)

        loss = criterion(output, labels)
        running_loss += loss

        pred = np.argmax(output, axis=1, keepdims=True)
        correct_total += np.sum(pred[pred==labels.reshape(-1, 1)].astype(bool))
        
    rec = correct_total / len(test_dataloader.dataset)
    valid_loss = running_loss / len(test_dataloader)
    return valid_loss, rec

In [7]:
for epoch in (pbar := tqdm(range(EPOCHS))):
    train_loss = train(net, optimizer, CELoss, epoch+1)
    valid_loss, rec = valid(net, CELoss, epoch+1)

    print(f"[{epoch}] train/valid loss: {train_loss:.4f}/{valid_loss:.4f} acc: {rec:.4f}")
    pbar.set_description(f"train/valid loss: {train_loss:.4f}/{valid_loss:.4f} acc: {rec:.4f}")

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

[0] train/valid loss: 1.6559/1.4031 acc: 0.3134


  0%|          | 0/500 [00:00<?, ?it/s]

[1] train/valid loss: 1.1028/1.1029 acc: 0.4999


  0%|          | 0/500 [00:00<?, ?it/s]

[2] train/valid loss: 0.9527/0.9532 acc: 0.5499


  0%|          | 0/500 [00:00<?, ?it/s]

[3] train/valid loss: 0.8627/0.8638 acc: 0.5813


  0%|          | 0/500 [00:00<?, ?it/s]

[4] train/valid loss: 0.8028/0.8042 acc: 0.5940


  0%|          | 0/500 [00:00<?, ?it/s]

[5] train/valid loss: 0.7398/0.7612 acc: 0.6143


  0%|          | 0/500 [00:00<?, ?it/s]

[6] train/valid loss: 0.7276/0.7288 acc: 0.6133


  0%|          | 0/500 [00:00<?, ?it/s]

[7] train/valid loss: 0.7026/0.7459 acc: 0.6024
