## Install required libraries

In [None]:
!pip install hub
!pip install matplotlib
!pip install torch

## Imports

In [5]:
import psutil
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
import torch
from torch.utils.data import random_split
import hub
from hub.compute.generic.ds_transforms import shift_scale_rotate, horizontal_flip
from hub.api.sharded_datasetview import ShardedDatasetView

## Load the dataset

In [3]:
# Fetch both STL-10 splits from Activeloop: `ds` (train) is augmented further down, while
# `ds_test` is only used for the final evaluation.
ds, ds_test = hub.load("activeloop/stl10_train"), hub.load("activeloop/stl10_test")

## Augment the images and append them to the original dataset

In [6]:
# Build an augmented copy of the training split: shift/scale with rotation disabled
# (rotate_limit=0), followed by horizontal flips (p=0.2 — presumably the flip probability).
# The result is stored under /tmp.
ds_augmented = horizontal_flip(
    shift_scale_rotate(ds, keys=['image'], rotate_limit=0, shift_limit=0.1),
    keys=['image'],
    p=0.2,
)
ds_augmented = ds_augmented.store("/tmp/stl10_aug")

# View original and augmented samples together as one logical dataset.
ds_sharded = ShardedDatasetView([ds, ds_augmented])

# psutil.cpu_count() returns None when the count cannot be determined, and `cpu_count() - 1`
# is 0 on a single-core machine; always request at least one worker.
N_TRANSFORM_WORKERS = max(1, (psutil.cpu_count() or 1) - 1)


@hub.transform(schema=ds_sharded.schema, scheduler="threaded", workers=N_TRANSFORM_WORKERS)
def transform_identity(sample):
    """Return each sample unchanged; used only to materialise `ds_sharded` as a single dataset."""
    return sample


# NOTE(review): this rebinds `ds` from the raw training split to the combined
# (original + augmented) dataset, so re-running this cell without re-running the load cell
# augments already-augmented data. The later random train/val split is taken over this combined
# dataset, so an image and its augmented copy can land on opposite sides of the split.
ds = transform_identity(ds_sharded).store('/tmp/stl10_all')

Computing the transformation in chunks of size 5000: 100%|██████████| 5.00k/5.00k [02:09<00:00, 38.7 items/s]
Computing the transformation in chunks of size 10000: 100%|██████████| 10.0k/10.0k [01:26<00:00, 116 items/s]


## Define a model

In [70]:
class Net(nn.Module):
    """Small LeNet-style CNN for square 3-channel images.

    Two ``conv(5x5, no padding) -> ReLU -> maxpool(2x2)`` stages are followed by four fully
    connected layers. The flattened feature size feeding ``fc1`` is derived from ``image_size``
    rather than hard-coded. With the defaults it is the original ``16 * 21 * 21``, which
    corresponds to 96x96 inputs.

    Args:
        num_classes: number of output logits (default 10, as in the original model).
        image_size: height/width of the square input images (default 96).

    Raises:
        ValueError: if ``image_size`` is too small to survive both conv/pool stages.
    """

    def __init__(self, num_classes: int = 10, image_size: int = 96):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Each 5x5 conv without padding removes 4 pixels per side length; each 2x2 pool halves
        # it (flooring). 96 -> 92 -> 46 -> 42 -> 21.
        feature_side = ((image_size - 4) // 2 - 4) // 2
        if feature_side < 1:
            raise ValueError(f"image_size={image_size} is too small for this network")
        self.fc1 = nn.Linear(16 * feature_side * feature_side, 1024)
        self.fc2 = nn.Linear(1024, 120)
        self.fc3 = nn.Linear(120, 84)
        self.fc4 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a (batch, 3, H, W) float tensor to (batch, num_classes) unnormalised logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten(x, 1) always keeps the batch dimension. A wrongly sized input then fails
        # loudly in fc1, instead of being reshaped into a different batch size by view(-1, ...).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)

## Training and validation

In [71]:
def train(
    trainloader: torch.utils.data.DataLoader,
    valloader: torch.utils.data.DataLoader,
    net: nn.Module,
    epochs: int = 5,
    lr: float = 0.001,
    momentum: float = 0.9,
    log_every: int = 100,
):
    """Train ``net`` in place with SGD + cross-entropy, calling ``validate`` after every epoch.

    Args:
        trainloader: yields ``(X, y)`` batches. ``X`` is permuted with ``(0, 3, 1, 2)`` before the
            forward pass, so it is expected channels-last — presumably (batch, H, W, 3); confirm
            against the dataset's image schema.
        valloader: passed unchanged to ``validate`` at the end of each epoch.
        net: the model to optimise.
        epochs: number of passes over ``trainloader`` (default 5).
        lr: SGD learning rate (default 0.001).
        momentum: SGD momentum (default 0.9).
        log_every: every ``log_every`` steps, print the mean loss of the steps since the last print.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    for epoch in range(epochs):
        print(f"Epoch {epoch}")
        # Make sure the model is in training mode, e.g. if an earlier evaluation pass
        # switched it to eval mode.
        net.train()
        running_loss, steps = 0.0, 0
        for i, (X, y) in enumerate(trainloader):
            X = X.permute(0, 3, 1, 2)
            optimizer.zero_grad()
            loss = criterion(net(X), y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            steps += 1
            if i % log_every == 0:
                # Mean over the logging window (just the first batch at i == 0) rather than
                # a single noisy batch loss.
                print(f"Loss {running_loss / steps}")
                running_loss, steps = 0.0, 0
        validate(net, valloader)
    print("Finished Training")

In [72]:
def validate(net, valloader):
    """Print and return the classification accuracy of ``net`` over ``valloader``.

    ``valloader`` may be a DataLoader yielding ``(X, y)`` batches or a dataset yielding single
    ``(X, y)`` samples. A non-4-D ``X`` is treated as one unbatched image and gets a leading batch
    dimension. ``X`` is permuted with ``(0, 3, 1, 2)``, so it is expected channels-last. ``y`` may be
    a tensor, a numpy value or a Python int.

    The model is switched to eval mode for the pass and restored to its previous mode afterwards.
    No gradients are tracked.

    Returns:
        The accuracy as a float, or ``nan`` when ``valloader`` yields no samples.
    """
    correct_count, all_count = 0, 0
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for X, y in valloader:
                if X.dim() != 4:
                    X = X.unsqueeze(0)
                pred_label = net(X.permute(0, 3, 1, 2)).argmax(1)
                # as_tensor + reshape(-1) accepts batched tensors as well as scalar/int labels,
                # and puts them on the same device as the predictions.
                labels = torch.as_tensor(y, device=pred_label.device).reshape(-1)
                correct_count += int((pred_label == labels).sum())
                all_count += len(pred_label)
    finally:
        net.train(was_training)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = correct_count / all_count if all_count else float("nan")
    print("Number Of Images Tested =", all_count)
    print("Model Accuracy =", accuracy)
    return accuracy

## Convert to PyTorch, split the data and train

In [73]:
def transform(data):
    """Convert one hub sample dict into an ``(image, label)`` pair.

    The image is divided by 255 (presumably uint8 pixels mapped into [0, 1] — confirm against the
    dataset schema); the label is passed through untouched.
    """
    return data['image'] / 255, data['label']

In [74]:
# Fix the RNG so that weight initialisation, the train/val split and DataLoader shuffling are
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

torch_ds = ds.to_pytorch(transform=transform)
net = Net()

# 80/20 train/validation split.
# NOTE(review): `ds` at this point is the combined original + augmented dataset built by the
# augmentation cell. An image and its augmented copy can therefore fall on opposite sides of
# this split, which inflates validation accuracy. Splitting before augmenting (and augmenting
# only the training part) would avoid the leak.
train_len = int(0.8 * len(torch_ds))
val_len = len(torch_ds) - train_len
train_ds, val_ds = random_split(
    torch_ds, [train_len, val_len], generator=torch.Generator().manual_seed(SEED)
)

BATCH_SIZE = 8
NUM_WORKERS = 8
train_dataloader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
val_dataloader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)
train(train_dataloader, val_dataloader, net)

Epoch 0
Loss 2.2967722415924072
Loss 2.3181405067443848
Loss 2.3172314167022705
Loss 2.302464246749878
Loss 2.3191092014312744
Loss 2.3067197799682617


KeyboardInterrupt: 

In [None]:
# Evaluate on the STL-10 test split in batches, the same way the validation split is evaluated.
# Passing the dataset itself to validate() would feed it one unbatched sample at a time.
torch_ds_test = ds_test.to_pytorch(transform=transform)
test_dataloader = torch.utils.data.DataLoader(
    torch_ds_test,
    batch_size=8,
    shuffle=False,
    num_workers=8,
)
validate(net, test_dataloader)

In [None]:
# Serialise the whole module object (not just its state_dict), so loading it back needs the
# `Net` class to be defined. NOTE(review): the file name says "mnist" although the model is
# trained on STL-10; the path is kept unchanged in case something loads it.
torch.save(obj=net, f="/tmp/stl10_mnist.pth")