In [1]:
!pip install ipywidgets
!pip install wandb

Defaulting to user installation because normal site-packages is not writeable


In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import numpy as np
import tqdm
import wandb

from torchvision.datasets.mnist import MNIST
from torchvision import transforms

In [19]:
class MyCoolModel(nn.Module):
    """Two-layer MLP classifier for 28x28 inputs with 10 output classes.

    Each sample is flattened to a 784-vector, projected to ``hidden_dim``
    features, passed through a ReLU, and mapped to 10 unnormalised logits.

    Args:
        hidden_dim: width of the hidden layer (default 256).
    """

    def __init__(self, hidden_dim=256) -> None:
        super().__init__()
        self.l1 = nn.Linear(28 * 28, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10).

        ``x`` must hold 28*28 = 784 values per sample after flattening all
        non-batch dimensions (e.g. (batch, 1, 28, 28) from ``ToTensor()``).
        """
        # flatten keeps the batch dim and, unlike .view(), also accepts
        # non-contiguous tensors.
        x = torch.flatten(x, start_dim=1)
        # Without a non-linearity the two Linear layers compose into a single
        # affine map and the hidden layer adds no expressive power.
        x = F.relu(self.l1(x))
        return self.l2(x)

In [25]:
def train_step(model, opimizer, train_loader):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: module to train. Batches are moved to the device of its first
            parameter.
        opimizer: optimizer over ``model``'s parameters. The name keeps its
            original spelling so keyword callers keep working.
        train_loader: iterable of ``(x, y)`` batches, where ``y`` holds class
            indices suitable for ``F.cross_entropy``.

    Returns:
        Mean per-batch cross-entropy loss for the epoch, or ``float('nan')``
        if the loader yielded no batches.
    """
    model.train()
    # Follow the model instead of a module-level global so inputs always land
    # on the same device as the weights.
    model_device = next(model.parameters()).device

    total_loss = 0.0
    n_batches = 0
    for x, y in train_loader:
        x, y = x.to(model_device), y.to(model_device)

        opimizer.zero_grad()

        logits = model(x)
        loss = F.cross_entropy(logits, y)

        loss.backward()
        opimizer.step()

        # .item() yields a plain float, so the autograd graph is not kept alive.
        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches if n_batches else float("nan")


In [26]:
def test_step(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            y_hat = model(x)

            pred = y_hat.argmax(dim=1, keepdim=True)

            correct += pred.eq(y.view_as(pred)).sum().item()

        print(f'\ntesting: test_acc: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset)}%)')


In [31]:
def main(epochs=3, learning_rate=1e-3, batch_size=64, hidden_dim=256, seed=0):
    """Train ``MyCoolModel`` on MNIST and report accuracy before and after training.

    Steps:
        - Download MNIST into the current directory.
        - Start a Weights & Biases run (requires a W&B account and API key).
        - Train with Adam for ``epochs`` epochs.
        - Evaluate on the held-out test split before and after training.

    Args:
        epochs: number of passes over the training set.
        learning_rate: Adam learning rate.
        batch_size: mini-batch size for both loaders.
        hidden_dim: width of the model's hidden layer.
        seed: torch RNG seed (weight init and training-data shuffling).
    """
    # Seed before building the model and loaders so re-runs are reproducible.
    torch.manual_seed(seed)

    # Hyperparameters recorded with the W&B run.
    config = {
        "epochs": epochs,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "hidden_dim": hidden_dim,
        "seed": seed,
    }

    # Data loaders
    mnist_train = MNIST("", train=True, download=True, transform=transforms.ToTensor())
    mnist_test = MNIST("", train=False, download=True, transform=transforms.ToTensor())

    # Reshuffle training batches every epoch; keep evaluation order fixed.
    train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(mnist_test, batch_size=batch_size)

    model = MyCoolModel(hidden_dim=hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # W&B: you need an account and access to your API key.
    wandb.init(project="mnist_classifier", config=config)
    try:
        # Baseline performance of the untrained model.
        test_step(model, test_loader)

        for _ in tqdm.trange(epochs):
            train_step(model, optimizer, train_loader)

        # Final evaluation on the held-out test split, not the training data.
        test_step(model, test_loader)
    finally:
        # Close the run even if training is interrupted, so a re-run in the
        # same kernel starts a fresh run.
        wandb.finish()

In [32]:
# Run the full pipeline defined in main(): download MNIST, train the model, and print evaluation accuracy.
main()


testing: test_acc: 974/10000 (9.74%)



100%|██████████| 10/10 [00:56<00:00,  5.69s/it]
