## Install required libraries

In [None]:
!pip install hub
!pip install matplotlib
!pip install torch

## Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
import torch
from torch.utils.data import random_split
import hub
from hub.compute.generic.ds_transforms import shift_scale_rotate, transpose
from hub.api.sharded_datasetview import ShardedDatasetView

## Load the dataset

In [None]:
# Load the MNIST dataset by its hub path. `ds` is rebound later by the
# augmentation cell, so cells after that one see the merged dataset instead.
# NOTE(review): presumably a remote dataset fetched on demand — confirm caching.
ds = hub.load("activeloop/mnist")

## Visualize a sample image

In [None]:
# Fetch sample #5 and display it.
# Images are stored channels-last with a single channel (the training code
# permutes (N, H, W, C) -> (N, C, H, W) and the model expects 1 input channel),
# so squeeze the trailing axis to get a plain (H, W) array: some matplotlib
# versions reject (H, W, 1) input. A grayscale colormap shows the digit as it
# really looks instead of matplotlib's default false-color map.
img = ds["image", 5].compute()
plt.imshow(np.squeeze(img), cmap="gray")

## Augment the images and merge them with the original dataset

In [None]:
# Build an augmented copy of the dataset and merge it with the original.
# The transform is applied to the 'image' key only. Going by the keyword names,
# rotate_limit=0 disables rotation and shift_limit=0.1 allows shifts of up to
# 10% — NOTE(review): confirm against the hub 1.x ds_transforms implementation.
ds_augmented = shift_scale_rotate(ds, keys=['image'], rotate_limit=0, shift_limit=0.1)
# Materialize the augmented samples under /tmp/mnist_aug; the returned dataset
# handle replaces the lazy one.
ds_augmented = ds_augmented.store("/tmp/mnist_aug")
# One view over the original and augmented datasets (presumably their
# concatenation, original first — confirm ordering if it matters).
ds_sharded = ShardedDatasetView([ds, ds_augmented])

# Pass-through transform: running the sharded view through hub.transform with
# the view's schema and storing the result copies it into one standalone
# dataset (threaded scheduler, 8 workers).
@hub.transform(schema=ds_sharded.schema, scheduler="threaded", workers=8)
def transform_identity(sample):
    """Return ``sample`` unchanged; used only to copy the view into one dataset."""
    return sample

# NOTE: this rebinds `ds`. Later cells train on the merged dataset
# (original + augmented) stored at /tmp/mnist_all, not on the original MNIST.
ds = transform_identity(ds_sharded).store('/tmp/mnist_all')

## Define a model

In [None]:
class Net(nn.Module):
    """Small single-conv-layer classifier for square images (MNIST by default).

    Architecture: 3x3 conv (32 filters, stride 1, no padding) -> ReLU ->
    2x2 max-pool -> dropout(0.25) -> flatten -> linear -> log-softmax.

    Args:
        in_channels: Number of input image channels (1 for MNIST).
        num_classes: Number of output classes.
        image_size: Height/width of the square input image. The input size of
            the final linear layer is derived from it. The defaults reproduce
            the original ``nn.Linear(5408, 10)``: 32 * 13 * 13 for 28x28 input.
    """

    def __init__(self, in_channels=1, num_classes=10, image_size=28):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 32, 3, 1)
        self.dropout = nn.Dropout(0.25)
        # A 3x3 stride-1 conv without padding shrinks each side by 2; the 2x2
        # max-pool then halves it (rounding down).
        pooled_side = (image_size - 2) // 2
        self.fc = nn.Linear(32 * pooled_side * pooled_side, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape (batch, num_classes).

        ``x`` is expected as (batch, channels, H, W). It is cast to float so
        integer-typed images can be fed directly.
        NOTE(review): pixel values are not rescaled (e.g. to [0, 1]) — confirm
        this is intended.
        """
        x = self.conv(x.float())
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        output = F.log_softmax(x, dim=1)
        return output


## Training and validation

In [None]:
def train(trainloader: torch.utils.data.DataLoader, valloader: torch.utils.data.DataLoader, net: nn.Module,
          epochs: int = 2, lr: float = 0.0001, momentum: float = 0.9, log_every: int = 1000):
    """Train ``net`` in place with SGD, running ``validate`` after every epoch.

    Args:
        trainloader: Yields ``(X, y)`` batches. Images are permuted with
            ``(0, 3, 1, 2)``, i.e. expected channels-last (N, H, W, C).
        valloader: Validation batches in the same format, passed to ``validate``.
        net: Model mapping (N, C, H, W) images to per-class scores.
        epochs: Number of passes over ``trainloader``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_every: Print the current batch loss every ``log_every`` batches
            (including the first batch of each epoch).
    """
    # CrossEntropyLoss applies log-softmax itself. Applying log-softmax to
    # log-probabilities leaves them unchanged, so this is correct both for Net
    # (which returns log-probabilities) and for models returning raw logits.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    for epoch in range(epochs):
        print(f"Epoch {epoch}")
        # Ensure dropout etc. are active; evaluation code may have switched
        # the model to eval mode.
        net.train()
        running_loss, num_batches = 0.0, 0
        for i, (X, y) in enumerate(trainloader):
            # Channels-last -> channels-first, as Conv2d expects.
            X = X.permute(0, 3, 1, 2)
            # Class-index targets must be int64 for CrossEntropyLoss
            # (a no-op when the labels already are).
            y = y.long()
            optimizer.zero_grad()
            outputs = net(X)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            if i % log_every == 0:
                print(f"Loss {loss.item()}")
        if num_batches:
            print(f"Mean epoch loss {running_loss / num_batches}")
        validate(net, valloader)
    print("Finished Training")

In [None]:
def validate(net, valloader):
    """Measure, print and return the classification accuracy of ``net``.

    The model is put in eval mode for the pass, so dropout is disabled and
    results are deterministic. It is restored to its previous mode afterwards,
    so calling this between training epochs does not change training behavior.

    Args:
        net: Model mapping (N, C, H, W) images to per-class scores.
        valloader: Iterable of ``(X, y)`` batches. Images are channels-last
            (N, H, W, C); ``y`` holds integer class labels.

    Returns:
        The fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: If ``valloader`` yields no samples.
    """
    was_training = net.training
    net.eval()
    correct_count, all_count = 0, 0
    try:
        with torch.no_grad():
            for X, y in valloader:
                # Channels-last -> channels-first, as the model expects.
                outputs = net(X.permute(0, 3, 1, 2))
                pred_label = outputs.argmax(1)
                correct_count += (pred_label == y.reshape(-1)).sum().item()
                all_count += len(pred_label)
    finally:
        net.train(was_training)

    if all_count == 0:
        raise ValueError("valloader yielded no samples; cannot compute accuracy")

    accuracy = correct_count / all_count
    print("Number Of Images Tested =", all_count)
    print("\nModel Accuracy =", accuracy)
    return accuracy

## Convert to PyTorch, split the data and train

In [None]:
def transform(data):
    """Turn a hub sample mapping into an ``(image, label)`` tuple for PyTorch.

    Reads the ``'image'`` key first, then ``'label'``; a missing key raises
    ``KeyError``.
    """
    return data['image'], data['label']

In [None]:
# Expose the hub dataset to PyTorch as (image, label) pairs.
torch_ds = ds.to_pytorch(transform=transform)
net = Net()

# 80/20 train/validation split. A fixed-seed generator makes the split
# reproducible across runs; random_split otherwise draws from the global RNG.
# NOTE(review): `ds` here is presumably the merged dataset from the
# augmentation cell (original images plus shifted copies). A random split can
# then put an image in validation and its shifted copy in training, which
# inflates validation accuracy. Splitting before augmenting would avoid this.
train_len = int(0.8 * len(torch_ds))
test_len = len(torch_ds) - train_len
train_ds, val_ds = random_split(
    torch_ds, [train_len, test_len], generator=torch.Generator().manual_seed(42)
)
train_dataloader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    num_workers=8,
)
val_dataloader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=8,
    shuffle=False,
    num_workers=8,
)
train(train_dataloader, val_dataloader, net)
# Pickles the whole module, so loading it back requires the `Net` class to be
# importable under the same name. torch.save(net.state_dict(), ...) is the
# more portable alternative.
torch.save(net, "/tmp/model_mnist.pth")