<a href="https://colab.research.google.com/github/Ianneee/can_AI_solve_sodoku/blob/main/src/cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.0.3-py3-none-any.whl (731 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m731.6/731.6 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.7.0 (from torchmetrics)
  Downloading lightning_utilities-0.9.0-py3-none-any.whl (23 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.9.0 torchmetrics-1.0.3


In [14]:
import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch import nn
import torchmetrics

is_cuda = torch.cuda.is_available()
print(f'Is gpu enabled: {is_cuda}')

device = ('cuda' if is_cuda else 'cpu')

trainset = datasets.MNIST(root='./data', train=True, download=True, transform=ToTensor())
testset = datasets.MNIST(root='./data', train=False, download=True, transform=ToTensor())

print(f'Trainset length: {len(trainset)}' if len(trainset) > 0 else 'Trainset length is 0!')
print(f'Testset length: {len(testset)}' if len(testset) > 0 else 'Trainset length is 0!')

class DigitsCNN(nn.Module):
    def __init__(self):
        super(DigitsCNN, self).__init__()
        self.cnn = nn.Sequential(
            # (Num of channels in input, number of filters, kernel size)
            nn.Conv2d(1, 5, 3),
            nn.ReLU(),
            nn.Conv2d(5, 10, 3),
            nn.ReLU()
        )
        self.mlp = nn.Sequential(
            nn.Linear(24 * 24 *10, 10),
            nn.ReLU(),
            nn.Linear(10, 10)
        )

    def forward(self, x):
        x = self.cnn(x)
        #print(x.shape)
        x = torch.flatten(x, 1)
        #print(x.shape)
        x = self.mlp(x)
        return x

 # Ez print metrics
metric = torchmetrics.Accuracy(task='multiclass', num_classes=10).to(device)
#metric = torchmetrics.Accuracy(task='multiclass', num_classes=10)


Is gpu enabled: True
Trainset length: 60000
Testset length: 10000


In [29]:
def train_loop(dataloader, model, loss_fn, optimizer):
    for batch, (X, y) in enumerate(dataloader):
        X_gpu = X.to(device)
        y_gpu = y.to(device)

        pred = model(X_gpu)
        loss = loss_fn(pred, y_gpu)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Print stats
        if batch % 100 == 0:
            acc = metric(pred, y_gpu)
            print(f'Accuracy batch {batch}: {acc}')

    # Final accuracy for epoch
    acc = metric.compute()
    print(f'  Train accuracy: {acc}')
    metric.reset()


In [31]:
def test_loop(dataloader, model, loss_fn):
    # disable weight update on test
    with torch.no_grad():
        for X, y in dataloader:
            X_gpu = X.to(device)
            y_gpu = y.to(device)

            pred = model(X_gpu)
            acc = metric(pred, y_gpu)

    # print accuracy for epoch
    acc = metric.compute()
    print(f'  Test accuracy : {acc}')
    metric.reset()

In [27]:
model = DigitsCNN().to(device)

# Hyperparameters
batch_size = 64
epochs = 5
learning_rate = 0.05

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
#optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

train_dataloader = DataLoader(trainset, batch_size=batch_size)
test_dataloader = DataLoader(testset, batch_size=batch_size)

In [30]:
# run
print(f'Parameters: batch size: {batch_size}, epochs: {epochs}, learning rate: {learning_rate}\n')
for t in range(epochs):
    print(f'Epoch: {t}')
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
    print('\n')

Parameters: batch size: 64, epochs: 5, learning rate: 0.05

Epoch: 0
Accuracy batch 0: 0.984375
Accuracy batch 100: 0.984375
Accuracy batch 200: 0.984375
Accuracy batch 300: 0.9375
Accuracy batch 400: 1.0
Accuracy batch 500: 0.96875
Accuracy batch 600: 0.984375
Accuracy batch 700: 0.984375
Accuracy batch 800: 0.96875
Accuracy batch 900: 0.96875
  Train accuracy: 0.9765625
  Test accuracy: 0.9724000096321106


Epoch: 1
Accuracy batch 0: 1.0
Accuracy batch 100: 0.984375
Accuracy batch 200: 0.984375
Accuracy batch 300: 0.953125
Accuracy batch 400: 1.0
Accuracy batch 500: 0.96875
Accuracy batch 600: 0.984375
Accuracy batch 700: 0.984375
Accuracy batch 800: 0.953125
Accuracy batch 900: 0.984375
  Train accuracy: 0.979687511920929
  Test accuracy: 0.9732999801635742


Epoch: 2
Accuracy batch 0: 1.0
Accuracy batch 100: 1.0
Accuracy batch 200: 0.984375
Accuracy batch 300: 0.953125
Accuracy batch 400: 1.0
Accuracy batch 500: 0.96875
Accuracy batch 600: 0.984375
Accuracy batch 700: 0.96875
Accur