In [1]:
import shutil

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.optim import Adam
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torchvision.transforms import ToTensor

from shiba import Trainer
from shiba.callbacks import TensorBoard, Metric
from shiba.metrics import accuracy
from shiba.vis import classification_snapshot
%matplotlib inline



## Config

In [2]:
# Clear TensorBoard logs left by previous runs. Pure Python so the cell also works
# where `rm` is unavailable (e.g. Windows); ignore_errors mirrors `rm -f` when the
# directory does not exist.
shutil.rmtree('runs', ignore_errors=True)

In [3]:
# Seed torch's RNG before any model is built so weight initialisation (and,
# presumably, the Trainer's data shuffling) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Directory where CIFAR-10 is downloaded/cached, relative to the working directory.
data_path = 'cifar'

## Load Data

In [4]:
# Build the train split, then the test split (used here for validation), each
# converted to tensors; download=True is a no-op once the files are cached.
train_dataset, val_dataset = [
    CIFAR10(data_path, train=is_train_split, download=True, transform=ToTensor())
    for is_train_split in (True, False)
]

Files already downloaded and verified
Files already downloaded and verified


## Simple VGG-like Network

In [5]:
class SimpleNet(nn.Module):
    """Small CNN classifier: two conv stages followed by three fully connected layers.

    Each conv stage is a 3x3 convolution (padding 1), ReLU, then 2x2 max-pooling.
    ``fc1`` expects 16 channels at 8x8 after the two pools, i.e. 32x32 inputs
    such as CIFAR-10 images.

    Args:
        in_channels: number of channels in the input images.
        out_channels: number of output scores (one per class).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registration order fixes both the order parameters are initialised in
        # and the order of state_dict keys, so keep it stable.
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)  # parameter-free, shared by every stage
        self.pool = nn.MaxPool2d(2, 2)     # halves height and width
        self.conv2 = nn.Conv2d(6, 16, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(16 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_channels)

    def forward(self, x):
        """Return raw (un-normalised) class scores of shape (batch, out_channels)."""
        for conv_layer in (self.conv1, self.conv2):
            x = self.pool(self.relu(conv_layer(x)))
        # Flatten everything except the batch dimension for the dense head.
        x = x.view(x.size(0), -1)
        for hidden_layer in (self.fc1, self.fc2):
            x = self.relu(hidden_layer(x))
        return self.fc3(x)


In [6]:
# Experiment wiring: a fresh SimpleNet for RGB CIFAR-10 images (10 classes),
# cross-entropy on its raw scores, and Adam with its default learning rate
# written out explicitly so the value is visible in the notebook.
experiment_name = 'cifar-test'
model = SimpleNet(in_channels=3, out_channels=10)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=1e-3)

In [7]:
# Bundle model, loss, optimizer and the train/validation datasets into a shiba Trainer.
# NOTE(review): arguments are positional; this order is assumed to match
# shiba.Trainer's signature — confirm against its documentation.
trainer = Trainer(model, criterion, optimizer, train_dataset, val_dataset)

## Callbacks

In [8]:
# Record this run's hyperparameters in TensorBoard. lr is read back from the
# optimizer so the logged value always matches the one actually used for training.
hyperparams = {'lr': optimizer.param_groups[0]['lr']}
tensorboard = TensorBoard(experiment_name, snapshot_func=classification_snapshot, hyperparams=hyperparams)

# Use a distinct name for the Metric: rebinding `accuracy` would shadow the imported
# metric function, and re-running this cell would then wrap a Metric in a Metric.
# output_transform maps per-class scores to predicted class indices (assumes the
# outputs are (batch, classes), as SimpleNet's final Linear layer produces).
accuracy_metric = Metric(accuracy, 'accuracy', output_transform=lambda x: x.argmax(dim=1))
callbacks = [tensorboard, accuracy_metric]

## Train

In [9]:
# Short 1-epoch run to check that the training loop and callbacks are wired correctly.
trainer.fit(callbacks=callbacks, max_epochs=1)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))




In [10]:
# Train for 5 more epochs with the same trainer, model and callbacks.
# NOTE(review): this presumably resumes from the weights left by the previous cell
# (6 epochs in total); confirm against shiba.Trainer. Running the cells out of
# order changes the final model.
trainer.fit(callbacks=callbacks, max_epochs=5)

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))


