# CNN for tg bot that solve MNIST problem

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from torchvision import transforms as tfs
from torchvision import datasets

## 1. Preparing a dataset

In [2]:
# download MNIST Dataset

tr = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize((0.5), (0.5))
])

path = '/home/ubuntu/jupyter_notebooks/mnist_dataset'

train_dataset = datasets.MNIST(
    root=path,
    train=True,
    download=True,
    transform=tr
)

valid_dataset = datasets.MNIST(
    root=path,
    train=False,
    download=True,
    transform=tr
)

In [3]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4,
                                          shuffle=True, num_workers=4)

valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=4,
                                         shuffle=False, num_workers=4)

## 2. Train CNN

In [10]:
class MnistCnn(nn.Module):
    def __init__(self):
        super(MnistCnn, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        
        LIN_IN = 5 * 5 * 16
        self.lin = nn.Sequential(
            nn.Linear(LIN_IN, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.out = nn.Linear(84, 10)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        
        x = x.view(-1, 5 * 5 * 16)
        
        x = self.lin(x)
        logits = self.out(x)
        
        return logits

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [12]:
model = MnistCnn()
model = model.to(device)

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

loaders = {
    'train': train_dataloader,
    'valid': valid_dataloader
}

In [20]:
from tqdm import tqdm, trange

In [22]:
max_epochs = 10
accuracy = { 'train': [], 'valid': [] }

for epoch in trange(max_epochs):
    epoch_correct = 0
    epoch_all = 0
    
    for k, dataloader in loaders.items():
        for x_batch, y_batch in dataloader:
            if k == 'train':
                model.train()
                model.zero_grad()
                output = model.forward(x_batch)
            else:
                model.eval()
                with torch.no_grad():
                    output = model.forward(x_batch)
            
            preds = output.argmax(-1)
            epoch_correct += len(y_batch) - (y_batch - preds).count_nonzero()
            epoch_all += len(y_batch)
            
            if k == 'train': 
                loss = criterion(output, y_batch)
                loss.backward()
                optimizer.step()
        
        if k == 'train':
            print('Epoch:', epoch + 1)
        
        epoch_accuracy = epoch_correct / epoch_all
        print(f'Loader: {k}. Accuracy: {epoch_accuracy}')
        accuracy[k].append(epoch_accuracy)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1
Loader: train. Accuracy: 0.9818000197410583


 10%|█         | 1/10 [01:31<13:42, 91.42s/it]

Loader: valid. Accuracy: 0.9821571707725525
Epoch: 2
Loader: train. Accuracy: 0.986133337020874


 20%|██        | 2/10 [03:02<12:09, 91.19s/it]

Loader: valid. Accuracy: 0.9861142635345459
Epoch: 3
Loader: train. Accuracy: 0.9887999892234802


 30%|███       | 3/10 [04:35<10:44, 92.11s/it]

Loader: valid. Accuracy: 0.988099992275238
Epoch: 4
Loader: train. Accuracy: 0.9897666573524475


 40%|████      | 4/10 [06:18<09:39, 96.51s/it]

Loader: valid. Accuracy: 0.9896000027656555
Epoch: 5
Loader: train. Accuracy: 0.9912999868392944


 50%|█████     | 5/10 [08:09<08:27, 101.52s/it]

Loader: valid. Accuracy: 0.990842878818512
Epoch: 6
Loader: train. Accuracy: 0.9923499822616577


 60%|██████    | 6/10 [09:59<06:58, 104.62s/it]

Loader: valid. Accuracy: 0.9919142723083496
Epoch: 7
Loader: train. Accuracy: 0.992900013923645


 70%|███████   | 7/10 [11:50<05:19, 106.62s/it]

Loader: valid. Accuracy: 0.9916285872459412
Epoch: 8
Loader: train. Accuracy: 0.9925500154495239


 80%|████████  | 8/10 [13:41<03:36, 108.05s/it]

Loader: valid. Accuracy: 0.9916428327560425
Epoch: 9
Loader: train. Accuracy: 0.9933333396911621


 90%|█████████ | 9/10 [15:34<01:49, 109.46s/it]

Loader: valid. Accuracy: 0.9924142956733704
Epoch: 10
Loader: train. Accuracy: 0.9942833185195923


100%|██████████| 10/10 [17:26<00:00, 104.68s/it]

Loader: valid. Accuracy: 0.9933428764343262





### Saving trained model to file

In [27]:
path_to_save_model = 'mnist_cnn.saved'

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, path_to_save_model)