# Import Dependencies

In [None]:
import torch 
from PIL import Image
from torch import nn, save, load
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Get Data

In [None]:
# MNIST training split: each sample has shape (1, 28, 28) -> class: 0-9
train = datasets.MNIST(
    root="data",            # dataset files are kept under ./data
    download=True,          # download the files if they are not there yet
    train=True,             # use the training split
    transform=ToTensor(),   # convert each PIL image to a tensor
)

# Batches of 32 images. Reshuffle the samples at every epoch so the optimizer
# does not see them in the same fixed order on each pass over the data.
dataset = DataLoader(train, batch_size=32, shuffle=True)

# Model

In [None]:
class ImageClassifier(nn.Module):
    """Small CNN classifier: three unpadded 3x3 convolutions and a linear head.

    The defaults reproduce the original MNIST setup (1 input channel,
    28x28 images, 10 classes).

    Args:
        in_channels: Number of channels of the input images (1 = grayscale).
        image_size: Input size in pixels, either a single int for square
            images or a ``(height, width)`` pair.
        num_classes: Number of output scores (one per class).

    Raises:
        ValueError: If the image is too small to survive the three
            convolutions (each side must be larger than 6 pixels).
    """

    def __init__(self, in_channels=1, image_size=28, num_classes=10):
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        # A 3x3 convolution without padding trims 2 pixels from each spatial
        # dimension; three of them in a row trim 3 * 2 = 6 pixels.
        shrink = 3 * 2
        out_height = height - shrink
        out_width = width - shrink
        if out_height <= 0 or out_width <= 0:
            raise ValueError(
                f"image_size {image_size!r} is too small: each side must be "
                f"larger than {shrink} pixels"
            )

        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 32, (3, 3)),   # in_channels -> 32 filters of shape (3, 3)
            nn.ReLU(),                            # non-linearity
            nn.Conv2d(32, 64, (3, 3)),            # 32 input channels, 64 output channels
            nn.ReLU(),
            nn.Conv2d(64, 64, (3, 3)),            # 64 input channels, 64 output channels
            nn.ReLU(),
            nn.Flatten(),                         # (N, 64, H', W') -> (N, 64 * H' * W')
            # 64 feature maps of the shrunken size feed one score per class.
            nn.Linear(64 * out_height * out_width, num_classes),
        )

    def forward(self, x):
        """Return the class scores (logits) for a batch of images ``x``."""
        return self.model(x)

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so that
# creating the model does not fail on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Instance of the neural network, loss, optimizer
clf = ImageClassifier().to(device)
opt = Adam(clf.parameters(), lr=1e-3)    # Adam over all model parameters
loss_fn = nn.CrossEntropyLoss()          # classification loss on the raw class scores

# Train Model

In [None]:
# train for 10 epochs
# Batches are moved to whichever device the model's parameters live on, so the
# loop does not depend on a hard-coded 'cuda' device.
model_device = next(clf.parameters()).device

for epoch in range(10):
    for x, y in dataset:
        x, y = x.to(model_device), y.to(model_device)
        yhat = clf(x)
        loss = loss_fn(yhat, y)

        # apply backprop
        opt.zero_grad()     # clear gradients left over from the previous step
        loss.backward()     # compute gradients of the loss
        opt.step()          # update the parameters

    # Reports the loss of the last batch of the epoch, not an epoch average.
    print(f"Epoch: {epoch}, loss: {loss.item()}")

# Save and reload model

In [None]:
# Write the model's parameters (its state dict) to 'model_state.pt'.
state = clf.state_dict()
with open('model_state.pt', 'wb') as checkpoint_file:
    save(state, checkpoint_file)

In [None]:
# load model
# map_location places the stored tensors on the device the model currently
# lives on, so a checkpoint written on a GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
with open('model_state.pt', 'rb') as f:
    clf.load_state_dict(load(f, map_location=next(clf.parameters()).device))

# Test Model

In [1]:
# Run the prediction on the same device as the model instead of assuming CUDA.
model_device = next(clf.parameters()).device

# The first convolution takes a single input channel, so force the image to
# grayscale ('L'); this is a no-op for files that are already single-channel.
img = Image.open('test_image.jpg').convert('L')
img_tensor = ToTensor()(img).unsqueeze(0).to(model_device)  # unsqueeze adds the batch dimension

# Gradients are not needed for a prediction.
with torch.no_grad():
    print(torch.argmax(clf(img_tensor)))

tensor(2, device='cuda:0')
