In [1]:
import os
import glob
import random

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset

import pytorch_lightning as pl
from pytorch_lightning import loggers
import numpy as np

from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw

In [2]:
fonts_folder = "D:/projects/computer_vision/sudoku_solver/fonts/"
batch_sz = 100
n_iters = 2500
features_train = 60000
#num_epochs = int(n_iters / (features_train / batch_sz))
num_epochs = 5
loss_fn = nn.CrossEntropyLoss()
transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor(),
                                transforms.Normalize((0.5, ), (0.5, ))])

In [3]:
class PrintedMNIST(Dataset):
    """Generates images containing a single digit from font"""

    def __init__(self, N, random_state, transform=None):
        """"""
        self.N = N
        self.random_state = random_state
        self.transform = transform

        #fonts_folder = "fonts"

        self.fonts = [fonts_folder + "helvetica_bold1.ttf", fonts_folder + 'AovelSansRounded-rdDL.ttf']
        #self.fonts = glob.glob(fonts_folder + "/*.ttf")

        random.seed(random_state)

    def __len__(self):
        return self.N

    def __getitem__(self, idx):

        target = random.randint(0, 9)
        color = 0
        # Generate image
        img = Image.new("L", (256, 256))
        img = np.array(img)
        img[img == 0] = 225
        img = Image.fromarray(img)
        
        target = random.randint(0, 9)
        size = 200
        x = 20
        y = 20

        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype(random.choice(self.fonts), size)
        draw.text((x, y), str(target), color, font=font)
        shape = [(0, 0), (256 - 10, 256 - 10)]
        draw.rectangle(shape, outline ="black", width=4)

        img = img.resize((28, 28), Image.BILINEAR)

        if self.transform:
            img = self.transform(img)

        return img, target

In [4]:
class CNNModel(pl.LightningModule):

    def __init__(self, classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(32 * 4 * 4, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = self.cnn1(x)
        x = F.relu(x)
        x = self.maxpool1(x)
        x = self.cnn2(x)
        x = F.relu(x)
        x = self.maxpool2(x)
        x = x.view(batch_size, -1)
        x = self.fc1(x)
        out = self.softmax(x)
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = loss_fn(y_hat, y)
        correct = (y == y_hat.argmax(axis=1)).sum()
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs, 'correct': correct, 'total': len(y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = loss_fn(y_hat, y)
        correct = (y == y_hat.argmax(axis=1)).sum()
        logs = {'val_loss': loss}
        return {'loss': loss, 'log': logs, 'correct': correct, 'total': len(y)}
    
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        correct = sum([x["correct"] for  x in outputs])
        total = sum([x["total"] for  x in outputs])
        self.log("train_loss", avg_loss, prog_bar=True, logger=True)
        self.log("train_acc", correct/total, prog_bar=True, logger=True)
    
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        correct = sum([x["correct"] for  x in outputs])
        total = sum([x["total"] for  x in outputs])
        self.log("val_loss", avg_loss, prog_bar=True, logger=True)
        self.log("val_acc", correct/total, prog_bar=True, logger=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        train_set = PrintedMNIST(50000, -666, transform)
        train_loader = DataLoader(train_set, batch_size=batch_sz, shuffle=False, num_workers=0, drop_last=True)
        return train_loader

    def val_dataloader(self):
        val_set = PrintedMNIST(5000, 33, transform)
        val_loader = DataLoader(val_set, batch_size=batch_sz, num_workers=0)
        return val_loader

In [5]:
tb_logger = loggers.TensorBoardLogger('logs/')
model = CNNModel()
trainer = pl.Trainer(gpus=1, max_epochs=num_epochs, logger=tb_logger, checkpoint_callback=False)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [6]:
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type      | Params
---------------------------------------
0 | cnn1     | Conv2d    | 416   
1 | maxpool1 | MaxPool2d | 0     
2 | cnn2     | Conv2d    | 12.8 K
3 | maxpool2 | MaxPool2d | 0     
4 | fc1      | Linear    | 5.1 K 
5 | softmax  | Softmax   | 0     
---------------------------------------
18.4 K    Trainable params
0         Non-trainable params
18.4 K    Total params
0.074     Total estimated model params size (MB)


Validation sanity check:   0%|                                                                                                                                                            | 0/2 [00:00<?, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


                                                                                                                                                                                                                

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Epoch 0:   0%|                                                                                                                                                                | 0/550 [00:00<00:00, 1002.70it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Epoch 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎            | 500/550 [00:47<00:04, 10.48it/s, loss=1.46, v_num=20]
Validating: 0it [00:00, ?it/s][A
Epoch 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊            | 502/550 [00:47<00:04, 10.50it/s, loss=1.46, v_num=20][A
Validating:   4%|██████▍                                                                                                                                                         | 2/50 [00:00<00:04, 10.84it/s][A
Epoch 0:  92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋           | 505/550 [00:48<00:04, 10.50it/s, loss=1.46, v_num=20][A
Epoch 0:  92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

Epoch 1:  95%|█████████████████████████████████████████████████████████████████████▎   | 522/550 [00:49<00:02, 10.55it/s, loss=1.46, v_num=20, val_loss=1.460, val_acc=1.000, train_loss=1.460, train_acc=1.000][A
Validating:  44%|█████████████████████████████████████████████████████████████████████▉                                                                                         | 22/50 [00:02<00:02, 10.68it/s][A
Epoch 1:  95%|█████████████████████████████████████████████████████████████████████▋   | 525/550 [00:49<00:02, 10.55it/s, loss=1.46, v_num=20, val_loss=1.460, val_acc=1.000, train_loss=1.460, train_acc=1.000][A
Epoch 1:  96%|██████████████████████████████████████████████████████████████████████   | 528/550 [00:50<00:02, 10.55it/s, loss=1.46, v_num=20, val_loss=1.460, val_acc=1.000, train_loss=1.460, train_acc=1.000][A
Validating:  56%|█████████████████████████████████████████████████████████████████████████████████████████                                              

Epoch 2:  99%|████████████████████████████████████████████████████████████████████████ | 543/550 [00:51<00:00, 10.56it/s, loss=1.46, v_num=20, val_loss=1.460, val_acc=1.000, train_loss=1.460, train_acc=1.000][A
Epoch 2:  99%|████████████████████████████████████████████████████████████████████████▍| 546/550 [00:51<00:00, 10.56it/s, loss=1.46, v_num=20, val_loss=1.460, val_acc=1.000, train_loss=1.460, train_acc=1.000][A
Validating:  92%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎            | 46/50 [00:04<00:00, 10.78it/s][A
Epoch 2: 100%|████████████████████████████████████████████████████████████████████████▊| 549/550 [00:52<00:00, 10.56it/s, loss=1.46, v_num=20, val_loss=1.460, val_acc=1.000, train_loss=1.460, train_acc=1.000][A
Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████| 550/550 [00:52<00:00, 10.54it/s, loss=1.46, v_num=20, val_loss=

Validating:  20%|███████████████████████████████▊                                                                                                                               | 10/50 [00:00<00:03, 10.87it/s][A
Epoch 4:  93%|████████████████████████████████████████████████████████████████████     | 513/550 [00:49<00:03, 10.48it/s, loss=1.46, v_num=20, val_loss=1.460, val_acc=1.000, train_loss=1.460, train_acc=1.000][A
Epoch 4:  94%|████████████████████████████████████████████████████████████████████▍    | 516/550 [00:49<00:03, 10.48it/s, loss=1.46, v_num=20, val_loss=1.460, val_acc=1.000, train_loss=1.460, train_acc=1.000][A
Validating:  32%|██████████████████████████████████████████████████▉                                                                                                            | 16/50 [00:01<00:03, 10.85it/s][A
Epoch 4:  94%|████████████████████████████████████████████████████████████████████▉    | 519/550 [00:49<00:02, 10.48it/s, loss=1.46, v_num=20, val_loss=

In [7]:
import cv2
import numpy as np
img = cv2.imread('cell.png')
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = cv2.resize(img, (28, 28))
img = Image.fromarray(img)
img = transform(img)
img = img.reshape(1, 1, 28, 28)
predictions = model(img.float())
predictions = predictions.detach().numpy()
print(np.argmax(predictions))

3
