In [1]:
from torch.utils.data import random_split, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
import torch

In [2]:
torch.manual_seed(0)
torch.cuda.manual_seed(0)
path = 'dataset/test (копия)'

In [3]:
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
    transforms.RandomRotation(360, transforms.InterpolationMode.NEAREST, expand=False)
])
data = ImageFolder(path, transform=transform)
train, test = random_split(data, (int(0.8 * len(data)), len(data) - int(0.8 * len(data))))
train_batch = DataLoader(train, shuffle=True, batch_size=100, drop_last=True)
test_batch = DataLoader(test, shuffle=True, batch_size=100, drop_last=True)

In [4]:
class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3, padding=1),
            torch.nn.Conv2d(in_channels=10, out_channels=10,kernel_size=3, padding=1),
            torch.nn.Tanh(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=10, out_channels = 30, kernel_size=5),
            torch.nn.Tanh(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Flatten(),
            torch.nn.Linear(6*6*30, 320),
            torch.nn.Tanh(),
            torch.nn.Linear(320, 160),
            torch.nn.Tanh(),
            torch.nn.Linear(160, 33))


    def forward(self, x):
        x = self.conv(x)
        return x

In [5]:
torch.cuda.init()
torch.cuda.empty_cache()
cnn = CNN()
cnn = cnn.cuda()

loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)

In [6]:
loss_history = []
accuracy_history = []
test_batch = list(test_batch)

In [7]:
for epoch in range(1000):
    for x_train, y_train in train_batch:
        x_train = x_train.cuda()
        y_train = y_train.cuda()
        optimizer.zero_grad()
        preds = cnn.forward(x_train)
        targets = y_train
        loss_value = loss(preds, targets)
        loss_value.backward()
        optimizer.step()

    test_data = test_batch[epoch % len(test_batch)]
    x_data, y_data = test_data[0], test_data[1]
    x_data, y_data = x_data.cuda(), y_data.cuda()
    test_pred = cnn.forward(x_data)
    accuracy = (test_pred.argmax(dim=1) == y_data).float().mean().data.cpu()
    loss_test = loss(test_pred, y_data).data.cpu()
    loss_history.append(loss_test)
    accuracy_history.append(accuracy)
    print('epoch:', epoch, 'loss:', loss_history[-1], 'accuracy', accuracy)
    if epoch > 2 and accuracy_history[-1] > 0.9 and accuracy_history[-1] < accuracy_history[-2] < accuracy_history[-3]\
            or accuracy >  0.98:
        break

epoch: 0 loss: tensor(2.7419) accuracy tensor(0.2600)
epoch: 1 loss: tensor(2.2057) accuracy tensor(0.4000)
epoch: 2 loss: tensor(1.9861) accuracy tensor(0.3800)
epoch: 3 loss: tensor(1.5508) accuracy tensor(0.5500)
epoch: 4 loss: tensor(1.0885) accuracy tensor(0.6400)
epoch: 5 loss: tensor(1.1960) accuracy tensor(0.6200)
epoch: 6 loss: tensor(1.1380) accuracy tensor(0.6200)
epoch: 7 loss: tensor(0.7641) accuracy tensor(0.7500)
epoch: 8 loss: tensor(0.7714) accuracy tensor(0.7900)
epoch: 9 loss: tensor(0.6736) accuracy tensor(0.7700)
epoch: 10 loss: tensor(0.5971) accuracy tensor(0.8500)
epoch: 11 loss: tensor(0.7352) accuracy tensor(0.7400)
epoch: 12 loss: tensor(0.5001) accuracy tensor(0.8500)
epoch: 13 loss: tensor(0.6662) accuracy tensor(0.8100)
epoch: 14 loss: tensor(0.6830) accuracy tensor(0.7600)
epoch: 15 loss: tensor(0.5575) accuracy tensor(0.8300)
epoch: 16 loss: tensor(0.4558) accuracy tensor(0.8600)
epoch: 17 loss: tensor(0.5065) accuracy tensor(0.8200)
epoch: 18 loss: tens

KeyboardInterrupt: 

In [8]:
cnn = cnn.cpu()
torch.save(cnn.state_dict(), './cnn.pkl')

In [9]:
a = CNN()
a.load_state_dict(torch.load('cnn.pkl'))

<All keys matched successfully>

In [10]:
b = torch.ones(2, 1, 32, 32)
cnn.forward(b)

tensor([[ 5.0859, -5.1316, -4.4441, -2.3164, 11.0077, -7.0059,  2.6194, -2.4613,
          3.8318, -1.4732,  2.9478,  3.6090,  4.9873, -4.0167, -1.1845, -2.0827,
          0.1791,  5.2728,  7.1255,  8.2024,  6.4925, -7.1881,  1.6407, -0.6416,
          2.6904, -5.1322, -8.7798,  1.6708, -6.8035,  1.3163,  3.0429, -4.6906,
         -2.5766],
        [ 5.0859, -5.1316, -4.4441, -2.3164, 11.0077, -7.0059,  2.6194, -2.4613,
          3.8318, -1.4732,  2.9478,  3.6090,  4.9873, -4.0167, -1.1845, -2.0827,
          0.1791,  5.2728,  7.1255,  8.2024,  6.4925, -7.1881,  1.6407, -0.6416,
          2.6904, -5.1322, -8.7798,  1.6708, -6.8035,  1.3163,  3.0429, -4.6906,
         -2.5766]], grad_fn=<AddmmBackward0>)

In [13]:
a.forward(b).argmax(1)

tensor([4, 4])

In [9]:
print(data.classes)

['Ё', 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'Й', 'К', 'Л', 'М', 'Н', 'О', 'П', 'Р', 'С', 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Ъ', 'Ы', 'Ь', 'Э', 'Ю', 'Я']
