In [29]:
# !pip install ipywidgets
# !pip install wandb

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

device = torch.device('cpu')

import numpy as np

from torchvision.datasets.mnist import MNIST
from torchvision.transforms import transforms

import wandb

In [31]:
# class Block(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()

#     def forward(self, x):
#         return x

#Model Creation
class MyCoolModel(nn.Module):
    def __init__(self, hidden_dim=256) -> None:
        super().__init__()
        self.l1 = nn.Linear(28*28, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.l1(x)
        x = self.l2(x)
        return x

In [32]:
#training steps
def train_step(model, optimizer, train_loader):
    model.train()
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        #zero grad
        optimizer.zero_grad()

        #forward / compute loss
        y_hat = model(x)
        loss = F.cross_entropy(y_hat, y)


        #backward / opim
        loss.backward()

        optimizer.step()

        # if batch_idx%2000 == 0:
        #     wandb.log({"train loss":loss})
        #     pass


In [33]:
def test_step(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            y_hat = model(x)

            pred = y_hat.argmax(dim=1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()

        print(f"\n Testing: test_acc: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset)})")

In [34]:
def main():
    wandb.config = {
    "learning_rate": 0.001,
    "epochs": 100,
    "batch_size": 128
    }
    #hparms
    epochs = 3
    learning_rate = 1e-3 # >> 0.001
    batch_size = 64

    # data loader
    mnist_train = MNIST("", train = True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
    mnist_test = MNIST("", train = False, download=True, transform=transforms.ToTensor)

    train_loader = DataLoader(mnist_train, batch_size = batch_size)
    test_loader = DataLoader(mnist_test, batch_size = batch_size)

    #model
    model= MyCoolModel().to(device)

    #optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

    wandb.init(project="MyCoolModel")

    test_step(model, train_loader)

    for epoch in range(epochs):
        train_step(model, optimizer, train_loader)

    test_step(model, train_loader)

In [35]:
if __name__ == '__main__':
    main()


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.988659…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666859361666866, max=1.0)…


 Testing: test_acc: 6627/60000 (11.045)

 Testing: test_acc: 54858/60000 (91.43)
