## Install required libraries

In [None]:
!pip install hub
!pip install matplotlib
!pip install torch

## Log in to hub


In [None]:
# Authenticate through the hub command-line tool (shell escape).
# NOTE(review): presumably prompts for credentials interactively -- confirm this
# works from a notebook cell in the target environment.
!hub login

## Imports

In [1]:
import psutil
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
import torch
from torch.utils.data import random_split
import hub
from hub.compute.generic.ds_transforms import shift_scale_rotate, horizontal_flip
from hub.api.sharded_datasetview import ShardedDatasetView

## Load the dataset

In [2]:
# Load the CIFAR-10 train and test datasets by their hub tags.
ds = hub.load("activeloop/cifar10_train")
ds_test = hub.load("activeloop/cifar10_test")

## Augment images and add to the original Dataset

In [3]:
# Build an augmented copy of the training set in two steps, both applied to the
# 'image' key: shift_scale_rotate (rotate_limit=0, shift_limit=0.1), then
# horizontal_flip with p=0.2.
ds_shifted = shift_scale_rotate(ds, keys=['image'], rotate_limit=0, shift_limit=0.1)
ds_flipped = horizontal_flip(ds_shifted, keys=['image'], p=0.2)
# Store the augmented dataset, then expose original + augmented as one view.
ds_augmented = ds_flipped.store("/tmp/cidar10_aug")
ds_sharded = ShardedDatasetView([ds, ds_augmented])

# Leave one CPU free for the threaded scheduler, but never request fewer than one
# worker: psutil.cpu_count() can return None when the count is undetermined, and
# on a single-CPU machine `count - 1` would be 0.
n_workers = max(1, (psutil.cpu_count() or 1) - 1)


@hub.transform(schema=ds_sharded.schema, scheduler="threaded", workers=n_workers)
def transform_identity(sample):
    """Identity transform: return each sample exactly as it was received."""
    return sample

# Run the identity transform over the combined view and store the result at
# /tmp/cifar10_all. NOTE: this rebinds `ds`, which the augmentation statements
# above read as their input, so re-running this cell without reloading `ds`
# would start from the already-combined dataset.
ds = transform_identity(ds_sharded).store('/tmp/cifar10_all')

  warn('ignoring keyword argument %r' % k)
Computing the transformation in chunks of size 50000: 100%|██████████| 50.0k/50.0k [06:43<00:00, 124 items/s]
Computing the transformation in chunks of size 100000: 100%|██████████| 100k/100k [11:33<00:00, 144 items/s]


## Define a model

In [9]:
class Net(nn.Module):
    """Small convolutional classifier: two conv + pool stages, then two linear layers.

    The first linear layer expects 16 feature maps of 5x5, which is what the two
    5x5 convolutions and 2x2 poolings produce from a 3-channel 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5). The output has 10 values per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        """Compute class scores for a channels-first image batch.

        Args:
            x: tensor of shape (batch, 3, H, W); a single un-batched (3, H, W)
                image is treated as a batch of one.

        Returns:
            Tensor of shape (batch, 10) with un-normalized class scores.

        Raises:
            RuntimeError: if the spatial size does not reduce to 5x5 before the
                first linear layer.
        """
        if x.dim() == 3:
            # Un-batched image: add the batch axis so flattening stays per-sample.
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch axis. Unlike reshape(-1, 400), this
        # can never regroup features across samples: a wrongly sized input fails
        # in fc1 instead of silently changing the batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

## Training and validation

In [14]:
def train(trainloader: torch.utils.data.DataLoader, valloader: torch.utils.data.DataLoader, net: nn.Module,
          epochs: int = 10, lr: float = 0.0001, momentum: float = 0.9, log_every: int = 100):
    """Train ``net`` with SGD and cross-entropy loss, validating after every epoch.

    Args:
        trainloader: iterable of ``(X, y)`` batches; ``X`` is permuted with
            ``(0, 3, 1, 2)`` before it is fed to the network, so it must be 4-D
            (presumably channels-last images -- confirm against the dataset).
        valloader: iterable handed to ``validate`` at the end of each epoch.
        net: the model to optimize in place; it must have at least one parameter.
        epochs: number of passes over ``trainloader``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_every: print the current batch loss every this many batches
            (0 disables the per-batch printing).
    """
    # Use the device the model already lives on instead of a notebook-level global,
    # so the function does not depend on which cells were run before it.
    device = next(net.parameters()).device
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    for epoch in range(epochs):
        print(f"Epoch {epoch}")
        # Make sure the model is in training mode at the start of every epoch.
        net.train()
        for i, (X, y) in enumerate(trainloader):
            X = X.permute(0, 3, 1, 2).to(device)
            optimizer.zero_grad()
            outputs = net(X)
            loss = criterion(outputs, y.to(device))
            loss.backward()
            optimizer.step()

            if log_every and i % log_every == 0:
                print(f"Loss {loss.item()}")
        validate(net, valloader)
    print("Finished Training")

In [15]:
def validate(net, valloader):
    """Measure and print the classification accuracy of ``net`` on ``valloader``.

    Args:
        net: the model to evaluate. It is switched to eval mode for the duration
            of the call and restored to its previous mode afterwards.
        valloader: iterable of ``(X, y)`` pairs. A 4-D ``X`` is treated as a batch;
            any other rank gets a leading batch axis added first. ``X`` is then
            permuted with ``(0, 3, 1, 2)`` before it is fed to the network.

    Returns:
        The fraction of correctly predicted samples as a float, or NaN when
        ``valloader`` yields no samples.
    """
    # Evaluate on the device the model lives on (CPU for parameter-free models)
    # rather than relying on a notebook-level global.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = net.training
    net.eval()
    correct_count, all_count = 0, 0
    try:
        with torch.no_grad():
            for X, y in valloader:
                if X.dim() != 4:
                    X = torch.unsqueeze(X, 0)
                X = X.permute(0, 3, 1, 2).to(device)
                pred_label = net(X).argmax(1)
                # Compare on the model's device: calling .numpy() on the predictions
                # fails for CUDA tensors. reshape(-1) also turns a scalar label
                # into a one-element vector matching a batch of one.
                target = torch.as_tensor(y).to(device).reshape(-1)
                correct_count += int((pred_label == target).sum().item())
                all_count += len(pred_label)
    finally:
        net.train(was_training)

    # Avoid ZeroDivisionError on an empty loader.
    accuracy = correct_count / all_count if all_count else float("nan")
    print("Number Of Images Tested =", all_count)
    print("Model Accuracy =", accuracy)
    return accuracy

## Convert to PyTorch, split the data and train

In [16]:
def transform(data):
    """Turn one dataset sample into an ``(image, label)`` pair.

    The 'image' entry is converted with ``.float()``; the 'label' entry is
    passed through unchanged.
    """
    return data['image'].float(), data['label']

In [17]:
# Wrap the hub dataset for PyTorch; `transform` turns each sample into (image, label).
torch_ds = ds.to_pytorch(transform=transform)

# Use the GPU when one is available and put a fresh model on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net()
net = net.to(device)

# Random 80% / 20% split into training and validation subsets.
train_len = int(0.8 * len(torch_ds))
test_len = len(torch_ds) - train_len
train_ds, val_ds = random_split(torch_ds, [train_len, test_len])

# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=8, num_workers=2)
train_dataloader = torch.utils.data.DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dataloader = torch.utils.data.DataLoader(val_ds, shuffle=False, **loader_kwargs)

train(train_dataloader, val_dataloader, net)

Epoch 0
Loss 8.593305587768555
Loss 2.2927889823913574
Loss 2.0732269287109375
Loss 2.2810120582580566
Loss 2.296006441116333
Loss 2.0966765880584717
Loss 2.217298746109009
Loss 1.8518821001052856
Loss 2.405604839324951
Loss 1.9361776113510132


KeyboardInterrupt: 

In [None]:
# Evaluate the trained model on the separate test dataset. The PyTorch dataset is
# handed to validate() directly rather than through a DataLoader, so it is
# presumably iterated one un-batched sample at a time; validate() adds a batch
# axis to any input that is not 4-D.
torch_ds_test = ds_test.to_pytorch(transform=transform)
validate(net, torch_ds_test)

In [None]:
# Save the trained model to disk. The module object itself is passed (not
# net.state_dict()), so the whole object is pickled and loading it later
# requires the Net class definition to be available.
torch.save(net, "/tmp/cifar10_model.pth")