<a href="https://colab.research.google.com/github/Ianneee/can_AI_solve_sodoku/blob/main/src/cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchmetrics > /dev/null && echo "$(pip list | grep torchmetrics | awk '{print $1" "$2}') successfully installed"

torchmetrics 1.0.3 successfully installed


In [2]:
import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch import nn
import torchmetrics

trainset = datasets.MNIST(root='./data', train=True, download=True, transform=ToTensor())
testset = datasets.MNIST(root='./data', train=False, download=True, transform=ToTensor())
print(f'Trainset length: {len(trainset)}' if len(trainset) > 0 else 'Trainset length is 0!')
print(f'Testset length: {len(testset)}' if len(testset) > 0 else 'Trainset length is 0!')

is_cuda = torch.cuda.is_available()
print(f'Gpu enabled: {is_cuda}')

device = ('cuda' if is_cuda else 'cpu')


class DigitsCNN(nn.Module):
    def __init__(self):
        super(DigitsCNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 5, 3),
            nn.ReLU(),
            nn.Conv2d(5, 10, 3),
            nn.ReLU()
        )
        self.mlp = nn.Sequential(
            nn.Linear(24 * 24 *10, 10),
            nn.ReLU(),
            nn.Linear(10, 10)
        )

    def forward(self, x):
        x = self.cnn(x)
        #print(x.shape)
        x = torch.flatten(x, 1)
        #print(x.shape)
        x = self.mlp(x)
        return x

# Ez print metrics
metric = torchmetrics.Accuracy(task='multiclass', num_classes=10).to(device)
#metric = torchmetrics.Accuracy(task='multiclass', num_classes=10)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 96262355.28it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 119699302.20it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 22837763.46it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 20911667.14it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Trainset length: 60000
Testset length: 10000
Gpu enabled: True


In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    for batch, (X, y) in enumerate(dataloader):
        X_gpu = X.to(device)
        y_gpu = y.to(device)

        pred = model(X_gpu)
        loss = loss_fn(pred, y_gpu)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Print stats
        if batch % 100 == 0:
            acc = metric(pred, y_gpu)
            print(f'Accuracy batch {batch}: {acc}')

    # Final accuracy for epoch
    acc = metric.compute()
    print(f'  Train accuracy: {acc}')
    metric.reset()

In [None]:
def test_loop(dataloader, model, loss_fn):
    # disable weight update on test
    with torch.no_grad():
        for X, y in dataloader:
            X_gpu = X.to(device)
            y_gpu = y.to(device)

            pred = model(X_gpu)
            acc = metric(pred, y_gpu)

    # print accuracy for epoch
    acc = metric.compute()
    print(f'  Test  accuracy : {acc}')
    metric.reset()

In [None]:
model = DigitsCNN().to(device)

# Hyperparameters
batch_size = 64
epochs = 5
learning_rate = 0.05

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
#optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

train_dataloader = DataLoader(trainset, batch_size=batch_size)
test_dataloader = DataLoader(testset, batch_size=batch_size)

In [None]:
# run
print(f'Parameters: batch size: {batch_size}, epochs: {epochs}, learning rate: {learning_rate}\n')
for t in range(epochs):
    print(f'Epoch: {t}')
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
    print('\n')

Parameters: batch size: 64, epochs: 5, learning rate: 0.05

Epoch: 0
Accuracy batch 0: 0.140625
Accuracy batch 100: 0.484375
Accuracy batch 200: 0.828125
Accuracy batch 300: 0.875
Accuracy batch 400: 0.828125
Accuracy batch 500: 0.890625
Accuracy batch 600: 0.90625
Accuracy batch 700: 0.90625
Accuracy batch 800: 0.875
Accuracy batch 900: 0.921875
  Train accuracy: 0.765625
  Test accuracy : 0.9140999913215637


Epoch: 1
Accuracy batch 0: 0.921875
Accuracy batch 100: 0.9375
Accuracy batch 200: 0.96875
Accuracy batch 300: 0.90625
Accuracy batch 400: 0.953125
Accuracy batch 500: 0.96875
Accuracy batch 600: 0.953125
Accuracy batch 700: 0.9375
Accuracy batch 800: 0.953125
Accuracy batch 900: 0.9375
  Train accuracy: 0.9437500238418579
  Test accuracy : 0.9574000239372253


Epoch: 2
Accuracy batch 0: 0.96875
Accuracy batch 100: 0.953125
Accuracy batch 200: 0.9375
Accuracy batch 300: 0.9375
Accuracy batch 400: 1.0
Accuracy batch 500: 0.984375
Accuracy batch 600: 0.984375
Accuracy batch 700: 0

## Save and download the model

In [None]:
model_name = 'model.pt'

torch.save(model.state_dict(), model_name)
print(f'Model {model_name} saved.')

from google.colab import files
files.download(model_name)
print(f'Model {model_name} downloaded.')

Model model.pt saved.


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Model model.pt downloaded.
