# MNIST sample code using neurocircuit

## Prerequisites

In [1]:
%%capture
!git clone https://github.com/Umikan/neurocircuit.git
%cd neurocircuit
!python setup.py install
!pip install torchmetrics

## Imports

In [1]:
from neurockt.monitor import Recorder, RecorderHelper, EarlyStopping
from neurockt.data import DataPipe, Multiclass, Multilabel, Image
from neurockt.torch import forward, Stack

## Data Pipeline

In [2]:
import pandas as pd
from torchvision.datasets import MNIST
import torchvision.transforms as T


# Download MNIST under /mnist/ (absolute path; adjust when running outside a
# disposable environment). No `train=` argument is passed, so torchvision's
# default split is used, and `transform=None` leaves the stored tensors untouched.
dataset = MNIST(root='/mnist/', download=True, transform=None)
# One DataFrame row per sample. `unsqueeze(0)` prepends a length-1 axis to each
# image tensor (presumably the channel axis expected by the conv layers -- confirm).
df = pd.DataFrame({
    "image": [data.unsqueeze(0) for data in dataset.data],
    "target": dataset.targets
})

# (column name, neurockt field type) pairs, unpacked into the pipe calls below.
image = ("image", Image)
target = ("target", Multiclass)
# Normalize(mean=0, std=255) computes (x - 0) / 255, i.e. rescales pixel values.
# NOTE(review): Normalize needs a float tensor; presumably `Image` converts the
# raw data before this transform runs -- confirm against neurockt.
transform = T.Compose([T.Normalize(0, 255.)])

# Declare two bunches, "train" and "valid", each with the image column as input
# and the target column as label.
# NOTE(review): only the "train" target receives the "num_classes" argument;
# the "valid" target does not. Confirm whether this asymmetry is intentional.
pipe = (
    DataPipe(df)
    .X(*image, transform)
    .Y(*target).arg("num_classes", Multiclass.num_classes)
    .bunch("train")
    .X(*image, transform)
    .Y(*target)
    .bunch("valid")
)

# Fixed split by row position: rows 0..49999 for training, 50000..59999 for validation.
train_idx, valid_idx = range(50000), range(50000, 60000)
dls = pipe.select(bunch=("train", "valid"), idx=(train_idx, valid_idx))
# `dls` is rebound from the selection object to the result of calling it; the
# result is later unpacked as (train_dl, valid_dl) by train().
# NOTE(review): the keyword arguments look like DataLoader options and
# shuffle=True appears to apply to both loaders -- confirm against neurockt.
dls = dls(bs=64, shuffle=True, num_workers=8, drop_last=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /mnist/MNIST/raw/train-images-idx3-ubyte.gz to /mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /mnist/MNIST/raw





## Define model

In [3]:
import torch.nn as nn

class SimpleConvNet(nn.Module):
    """Small two-stage convolutional classifier.

    The network is two identical stages (3x3 convolution -> ReLU -> 2x2 max
    pooling) with 32 and 64 output channels, followed by a head that flattens
    the feature map, applies dropout and projects to ``num_classes`` logits.

    Lazy layers are used throughout, so the number of input channels and the
    flattened feature size are inferred on the first forward pass.

    Args:
        num_classes: Number of output logits. Defaults to 10.
        dropout: Dropout probability applied before the final linear layer.
            Defaults to 0.5.
    """

    def __init__(self, num_classes=10, dropout=0.5):
        super().__init__()
        # Attribute names and layer positions inside each Sequential are kept
        # stable so that state_dict keys (e.g. "conv1.0.weight") do not change.
        self.conv1 = self._conv_stage(32)
        self.conv2 = self._conv_stage(64)
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.LazyLinear(num_classes),
        )

    @staticmethod
    def _conv_stage(out_channels):
        """Build one conv stage: 3x3 conv -> ReLU -> 2x2 max pool."""
        return nn.Sequential(
            nn.LazyConv2d(out_channels, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Input batch accepted by ``nn.Conv2d`` (channels-first images).

        Returns:
            Tensor with one row of ``num_classes`` logits per input sample.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        return self.head(x)


# Default configuration: 10 output classes, dropout 0.5.
model = SimpleConvNet()



## Train/Eval model

In [6]:
# Set to True to log the run to Weights & Biases instead of the plain recorder.
use_wandb = False
# Run metadata; only consumed when use_wandb is enabled.
project = dict(
    project='NeuroCKT_Example',
    tags=['MNIST'],
    notes='MNIST example using NeuroCKT',
)
if not use_wandb:
    recorder = Recorder()
else:
    # Imported lazily so the W&B integration is only required when requested.
    from neurockt.monitor.wandb import WandbRecorder
    recorder = WandbRecorder(**project)

In [8]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torchmetrics.classification import MulticlassAccuracy

def torch_mean(x):
    """Average a sequence of equally-shaped tensors into one Python number.

    Args:
        x: Sequence of tensors accepted by ``torch.stack``.

    Returns:
        The mean over all elements of the stacked tensors, as a Python scalar.
    """
    stacked = torch.stack(x)
    return torch.mean(stacked).item()


class TotalLoss(nn.Module):
    """Cross-entropy criterion whose forward pass is wrapped by the recorder.

    The module holds a single ``nn.CrossEntropyLoss`` and delegates to it.
    The decorators on ``forward`` come from the module-level ``recorder``.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    # Both decorators receive a name and `torch_mean`.
    # NOTE(review): presumably the returned loss is logged as "train_loss" or
    # "valid_loss" depending on the active recorder(train=...) context, and
    # reduced with torch_mean -- confirm against the neurockt Recorder API.
    @recorder.track_scalar("train_loss", torch_mean)
    @recorder.metric("valid_loss", torch_mean)
    def forward(self, inp, targ):
        """Return the cross-entropy loss of predictions `inp` against `targ`."""
        return self.loss(inp, targ)


def train(model, dls, n_epochs):
    """Train `model` with Adam and validate after every epoch.

    Args:
        model: The network to optimise; it is moved to CUDA when available,
            otherwise to the CPU.
        dls: A pair that unpacks into (training loader, validation loader).
        n_epochs: Upper bound on the number of epochs; the loop ends sooner
            when the early-stopping object returns a truthy value.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    model = model.to(device)

    helper = RecorderHelper(recorder)
    criterion = TotalLoss()
    metric = helper.metric(torch_mean, acc=MulticlassAccuracy(10))
    optimizer = Adam(model.parameters())
    train_dl, valid_dl = dls
    early_stopping = EarlyStopping(patience=1)

    for _epoch in range(n_epochs):
        # ---- optimisation pass ----
        model.train()
        for logits, targets in forward(train_dl, model):
            with recorder(train=True):
                batch_loss = criterion(logits, targets[0])
                batch_loss.backward()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        # ---- validation pass ----
        model.eval()
        with recorder(train=False):
            with torch.inference_mode():
                collected = Stack()
                for logits, targets in forward(valid_dl, model):
                    # Result unused: the call is kept for the recorder
                    # decorators attached to TotalLoss.forward.
                    criterion(logits, targets[0])
                    collected.add(preds=logits, labels=targets[0])
                accuracy = metric["acc"](collected('preds'), collected('labels'))
                # Decide inside the recorder context, exactly where the
                # accuracy was computed.
                should_stop = early_stopping(accuracy)
                if should_stop:
                    break


# Train for at most 30 epochs; train() leaves its epoch loop early when its
# EarlyStopping(patience=1) check on validation accuracy returns a truthy value.
train(model, dls, n_epochs=30)



Better model found at epoch 1: 0.991016149520874
	 0.0336        0.9910        




	 0.0360        0.9907        




Better model found at epoch 3: 0.9915316104888916
	 0.0322        0.9915        




	 0.0340        0.9914        




Better model found at epoch 5: 0.9928864240646362
	 0.0299        0.9928        




	 0.0311        0.9927        


                                                  

EarlyStopping: Exceeded the maximum count of patience
	 0.0327        0.9913        


