In [1]:
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import loggers

from PIL import Image
import numpy as np
import cv2

In [2]:
batch_sz = 100
n_iters = 2500
features_train = 60000
#num_epochs = int(n_iters / (features_train / batch_sz))
num_epochs = 5
loss_fn = nn.CrossEntropyLoss()

In [3]:
transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor(),
                                transforms.Normalize((0.5, ), (0.5, ))])

In [4]:
class CNNModel(pl.LightningModule):

    def __init__(self, classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(32 * 4 * 4, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = self.cnn1(x)
        x = F.relu(x)
        x = self.maxpool1(x)
        x = self.cnn2(x)
        x = F.relu(x)
        x = self.maxpool2(x)
        x = x.view(batch_size, -1)
        x = self.fc1(x)
        out = self.softmax(x)
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = loss_fn(y_hat, y)
        correct = (y == y_hat.argmax(axis=1)).sum()
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs, 'correct': correct, 'total': len(y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = loss_fn(y_hat, y)
        correct = (y == y_hat.argmax(axis=1)).sum()
        logs = {'val_loss': loss}
        return {'loss': loss, 'log': logs, 'correct': correct, 'total': len(y)}
    
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        correct = sum([x["correct"] for  x in outputs])
        total = sum([x["total"] for  x in outputs])
        self.log("train_loss", avg_loss, prog_bar=True, logger=True)
        self.log("train_acc", correct/total, prog_bar=True, logger=True)
    
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        correct = sum([x["correct"] for  x in outputs])
        total = sum([x["total"] for  x in outputs])
        self.log("val_loss", avg_loss, prog_bar=True, logger=True)
        self.log("val_acc", correct/total, prog_bar=True, logger=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def prepare_data(self):
        #MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=True, download=True)

    def train_dataloader(self):
        mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        loader = DataLoader(mnist_train, batch_size=batch_sz, num_workers=4)
        return loader

    def val_dataloader(self):
        mnist_val = MNIST(os.getcwd(), train=False, download=False, transform=transform)
        return DataLoader(mnist_val, batch_size=batch_sz, num_workers=4)

In [5]:
tb_logger = loggers.TensorBoardLogger('logs/')
model = CNNModel()
trainer = pl.Trainer(gpus=1, max_epochs=num_epochs, logger=tb_logger, checkpoint_callback=False)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [6]:
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type      | Params
---------------------------------------
0 | cnn1     | Conv2d    | 416   
1 | maxpool1 | MaxPool2d | 0     
2 | cnn2     | Conv2d    | 12.8 K
3 | maxpool2 | MaxPool2d | 0     
4 | fc1      | Linear    | 5.1 K 
5 | softmax  | Softmax   | 0     
---------------------------------------
18.4 K    Trainable params
0         Non-trainable params
18.4 K    Total params
0.074     Total estimated model params size (MB)


Validation sanity check:   0%|                                                                                                                                                            | 0/2 [00:00<?, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch 0:   4%|█████▍                                                                                                                                       | 27/700 [00:01<00:40, 16.45it/s, loss=2.07, v_num=7]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Epoch 0:  86%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                   | 600/700 [00:05<00:00, 107.49it/s, loss=1.57, v_num=7]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                                       | 0/100 [00:00<?, ?it/s][A
Epoch 0:  87%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                 | 612/700 [00:07<00:01, 85.89it/s, loss=1.57, v_num=7][A
Epoch 0:  92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍           | 642/700 [00:07<00:00, 88.85it/s, loss=1.57, v_num=7][A
Epoch 0:  96%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

In [7]:
PATH = "D:/projects/computer_vision/sudoku_solver/model/base_recognizer.pth"
torch.save(model.state_dict(), PATH)

In [None]:
#img = cv2.imread('cell.png')
#img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
#img = cv2.resize(img, (28, 28))
#img = Image.fromarray(img)
#img = transform(img)
#img = img.reshape(1, 1, 28, 28)
#predictions = model(img.float())
#predictions = predictions.detach().numpy()
#print(np.argmax(predictions))