## По материалам
### CTC loss original paper by A.Graves et al.

https://github.com/dredwardhyde/crnn-ctc-loss-example-pytorch/blob/main/icml_2006.pdf
### CTC loss reference implementation

https://github.com/dredwardhyde/crnn-ctc-loss-example-pytorch/blob/main/ctc_loss_example.py

### MNIST sequence recognition

https://github.com/dredwardhyde/crnn-ctc-loss-pytorch

In [4]:
import sys
from itertools import groupby

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import torchvision.transforms.functional as TF
from colorama import Fore
from torchvision import datasets, transforms
from tqdm import tqdm
from MyDatasetRec import MyDatasetRec
from config_rec import all_alph

device = 'cuda' #'cuda' # torch.device('cuda')
# ============================================= PREPARING DATASET ======================================================

num_classes = 11
blank_label = 0
image_height = 28
gru_hidden_size = 128
gru_num_layers = 2
cnn_output_height = 4
cnn_output_width = 5 #количество максимальной длины, которую сеть может предсказать, зависит от разрмера картики
digits_per_sequence = 6
number_of_sequences = 10
#emnist_dataset = datasets.EMNIST('./EMNIST', split="digits", train=True, download=True)
import MyModelRec as my_model

conf = {
        'fonts':["example/TextBPNPlusPlus/dataset/MyGenerator/font.ttf"],
        'is_crop':[True],
        'texts':
            #list(map(str,np.random.random_integers(100,100000,1000)))
            #['12345678901234567890123456789012345678901234567890'] #14 epoch
            #['deletedresponsiblepersonid'] #477 epoch epoch=152
            #['isnull(cast(deletedresponsiblepersonid'] #407 epoch - epoch=98
            #['123'] #epoch=1203
            [
                '''}'''
            ] #epoch=250
        ,
        #'size_images':[((int(153/16)+1)*16 +2 ,18)] #for cnn_output_width = 19
        # 'size_images':[(cnn_output_width * 16 + 2 ,18)]
        'size_images':[(160,32)],
        'is_scale':[True],
        'scale_size':[(None,32)]
    }
#len(my_dataset.torch_text_dict.dict_chars)
# dig_str = '1234567890'
# spec_str = ''.join(set('-=~!@#$%^&*()_+!"№;%:?*()_+[];\',./{}|:"<>?\\.,'))
# eng_alph_str = 'qwertyuiopasdfghjklzxcvbnm'
# ru_alph_str = 'ёйцукенгшщзхъфывапролджэячсмитьбю'
# #num_classes = len(dig_str+spec_str+eng_alph_str+ru_alph_str)
# all_alph = '`'+ dig_str+spec_str+eng_alph_str+ru_alph_str

num_classes = len(all_alph) + 1
print(f'{num_classes=}')

my_dataset = MyDatasetRec([conf], all_alph, is_train = False)
for _ in my_dataset:
    pass

## создаем итераторы по данным
train_loader = torch.utils.data.DataLoader(my_dataset, batch_size=1, shuffle=True)
val_loader = torch.utils.data.DataLoader(my_dataset, batch_size=3, shuffle=True)

from torch.optim import lr_scheduler

# ================================================= MODEL ==============================================================

from MyModelRec import CRNN

## создаем модель
model = my_model.CRNN_v1(imgH=32,in_channels=3, nclass=num_classes, gru_size=256)
model.to(device)
#print(model)
## функция потерь
criterion = None
if model.version == 'v0':
    criterion = nn.CTCLoss(blank=blank_label, reduction='mean', zero_infinity=True)
elif model.version == 'v1':
    criterion = nn.CTCLoss(blank=blank_label, reduction='sum', zero_infinity=True)

assert criterion, 'criterion must be init'
    
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.9)

def testing(y_train, batch_size, y_pred):
    train_correct, train_total = 0, 0
    _, max_index = torch.max(y_pred, dim=2)  # max_index.shape == torch.Size([32, 64])
            ##print(f'{max_index.shape=}')
    for i in range(batch_size):
                #y_train_i = torch.Tensor([t1 for t1 in y_train[i] if t1 != 0])
        y_train_i = torch.IntTensor([c for c in y_train[i] if c != 0])
                
        raw_prediction = list(max_index[:, i].detach().cpu().numpy())  # len(raw_prediction) == 32            
        prediction = torch.IntTensor([c for c, _ in groupby(raw_prediction) if c != blank_label])
        txt = my_dataset.torch_text_dict.get_label(prediction)
        print(f'{txt=}')
        # print(f'{y_train_i=}')
        # print(f'{prediction=}')
        if len(prediction) == len(y_train_i):
            if torch.all(prediction.eq(y_train_i)):
                train_correct += 1
        train_total += 1
    return train_correct, train_total

min_loss = np.inf
train_iter = 0
writer = None
epochs = 2000
for epoch in range(epochs):
    #scheduler.step()
    test_correct, test_total = 0, 0

    for x_train, y_train in tqdm(train_loader, position=0, leave=True, file=sys.stdout):
        train_iter += 1
        print(f'{x_train.shape, y_train.shape=}')
        batch_size = x_train.shape[0]  # x_train.shape == torch.Size([64, 28, 140])
        #print(f'{batch_size=}')
        #x_train = x_train.view(x_train.shape[0], 1, x_train.shape[1], x_train.shape[2])
        optimizer.zero_grad()
        y_pred = model(x_train.to(device))
        #print(f'{y_pred1.shape=}') #y_pred.shape=torch.Size([3, 5, 12]) N, T, C: N - Batch size, T - Input sequence length, C - Number of classes (including blank)
        ##print(f'{y_pred1=}')
        #y_train = y_train[y_train>0]
        loss = None
        if model.version == 'v0':
            y_pred = y_pred.permute(1, 0, 2)  # y_pred.shape == torch.Size([5, 3, 12]) T, N, C
            input_lengths = torch.IntTensor(batch_size).fill_(cnn_output_width)
            target_lengths = torch.IntTensor([len([t1 for t1 in t if t1 != 0] ) for t in y_train])
            #print(f'{input_lengths=}')
            #print(f'{target_lengths=}')
            loss = criterion(y_pred, y_train, input_lengths, target_lengths)
        elif model.version == 'v1':
            preds_size = torch.IntTensor([y_pred.size(0)] * batch_size)  # seqLength x batchSize
            #print(f'{preds_size=}')
            target_lengths = torch.IntTensor([len([t1 for t1 in t if t1 != 0] ) for t in y_train])        
            loss = criterion(y_pred.log_softmax(2).cpu(), y_train, preds_size, target_lengths) / batch_size
        if writer != None:
            writer.add_scalar('train/loss',loss.item(), train_iter)

        loss.backward()
        optimizer.step()

        t_c, t_t = testing(y_train, batch_size, y_pred)
        
        test_correct += t_c
        test_total += t_t
    test_acc = test_correct / test_total
    print(f'TRAINING {epoch=} {scheduler.get_lr()=}. Correct: {test_correct=} / {test_total=} {test_acc:.3f}')
    if test_correct == test_total:
        print('finish train!')
        break


num_classes=102
self.all_param_list 1
get_word_formated_list...


1it [00:00, 1707.78it/s]


get_max_pix_size...
get_max_pix_size_from_param...
  0%|          | 0/1 [00:00<?, ?it/s]x_train.shape, y_train.shape=(torch.Size([1, 3, 32, 108]), torch.Size([1, 100]))
txt='!o!m!-а-!-!o-'
100%|██████████| 1/1 [00:00<00:00, 67.22it/s]
TRAINING epoch=0 scheduler.get_lr()=[0.001]. Correct: test_correct=0 / test_total=1 0.000
  0%|          | 0/1 [00:00<?, ?it/s]x_train.shape, y_train.shape=(torch.Size([1, 3, 32, 108]), torch.Size([1, 100]))
txt=''
100%|██████████| 1/1 [00:00<00:00, 74.70it/s]
TRAINING epoch=1 scheduler.get_lr()=[0.001]. Correct: test_correct=0 / test_total=1 0.000
  0%|          | 0/1 [00:00<?, ?it/s]x_train.shape, y_train.shape=(torch.Size([1, 3, 32, 108]), torch.Size([1, 100]))
txt=''
100%|██████████| 1/1 [00:00<00:00, 52.08it/s]
TRAINING epoch=2 scheduler.get_lr()=[0.001]. Correct: test_correct=0 / test_total=1 0.000
  0%|          | 0/1 [00:00<?, ?it/s]x_train.shape, y_train.shape=(torch.Size([1, 3, 32, 108]), torch.Size([1, 100]))
txt=''
100%|██████████| 1/1 [00:00<

In [33]:

[c for c, _ in groupby([76, 71, 76, 76, 76, 76, 71]) if c != blank_label]

[76, 71, 76, 71]

In [43]:
#y_pred1
print(max_index)
#raw_prediction = list(max_index[:, i].detach().cpu().numpy())
#raw_prediction
prediction

tensor([[5],
        [8],
        [8],
        [8],
        [8],
        [6],
        [0]])


tensor([5, 8, 6, 0], dtype=torch.int32)

In [44]:
y_train

tensor([[1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]], dtype=torch.int32)

In [20]:
# ============================================ VALIDATION ==========================================================
val_correct = 0
val_total = 0
for x_val, y_val in tqdm(val_loader,
                         position=0, leave=True,
                         file=sys.stdout, bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.BLUE, Fore.RESET)):
    batch_size = x_val.shape[0]
    #x_val = x_val.view(x_val.shape[0], 1, x_val.shape[1], x_val.shape[2])
    y_pred = model(x_val.to(gpu))
    y_pred = y_pred.permute(1, 0, 2)
    input_lengths = torch.IntTensor(batch_size).fill_(cnn_output_width)
    target_lengths = torch.IntTensor([len(t) for t in y_val])
    criterion(y_pred, y_val, input_lengths, target_lengths)
    _, max_index = torch.max(y_pred, dim=2)
    for i in range(batch_size):
        y_val_i = torch.Tensor([t1 for t1 in y_val[i] if t1 != 0])
        raw_prediction = list(max_index[:, i].detach().cpu().numpy())
        prediction = torch.IntTensor([c for c, _ in groupby(raw_prediction) if c != blank_label])
        if len(prediction) == len(y_val_i) and torch.all(prediction.eq(y_val_i)):
            val_correct += 1
        else:
            print(f'{prediction=}')
            print(f'{y_val_i=}')
        val_total += 1
print('TESTING. Correct: ', val_correct, '/', val_total, '=', val_correct / val_total)

 38%|[34m███▊      [39m| 128/334 [00:00<00:00, 213.12it/s]prediction=tensor([5, 5, 5, 7], dtype=torch.int32)
y_val_i=tensor([5., 5., 5., 5., 7.])
 45%|[34m████▍     [39m| 150/334 [00:00<00:00, 211.98it/s]prediction=tensor([8, 8, 8, 6], dtype=torch.int32)
y_val_i=tensor([8., 8., 8., 8., 6.])
 51%|[34m█████▏    [39m| 172/334 [00:00<00:00, 211.11it/s]prediction=tensor([1, 1, 9, 9], dtype=torch.int32)
y_val_i=tensor([1., 1., 9., 9., 9.])
 58%|[34m█████▊    [39m| 194/334 [00:00<00:00, 210.17it/s]prediction=tensor([1, 1, 7], dtype=torch.int32)
y_val_i=tensor([1., 1., 7., 7., 7.])
 78%|[34m███████▊  [39m| 259/334 [00:01<00:00, 206.41it/s]prediction=tensor([ 6,  6, 10], dtype=torch.int32)
y_val_i=tensor([ 6.,  6.,  6., 10., 10.])
 84%|[34m████████▍ [39m| 280/334 [00:01<00:00, 207.36it/s]prediction=tensor([2, 8, 8, 8], dtype=torch.int32)
y_val_i=tensor([2., 8., 8., 8., 8.])
 90%|[34m█████████ [39m| 301/334 [00:01<00:00, 207.45it/s]prediction=tensor([6, 6, 8], dtype=torch.int32)
y_

In [None]:

# ============================================ TESTING =================================================================
number_of_test_imgs = 1
test_loader = torch.utils.data.DataLoader(val_set, batch_size=number_of_test_imgs, shuffle=True)
test_preds = []
(x_test, y_test) = next(iter(test_loader))
y_pred = model(x_test.view(x_test.shape[0], 1, x_test.shape[1], x_test.shape[2]).to(gpu))
y_pred = y_pred.permute(1, 0, 2)
_, max_index = torch.max(y_pred, dim=2)
for i in range(x_test.shape[0]):
    raw_prediction = list(max_index[:, i].detach().cpu().numpy())
    prediction = torch.IntTensor([c for c, _ in groupby(raw_prediction) if c != blank_label])
    test_preds.append(prediction)

for j in range(len(x_test)):
    mpl.rcParams["font.size"] = 8
    plt.imshow(x_test[j], cmap='gray')
    mpl.rcParams["font.size"] = 18
    plt.gcf().text(x=0.1, y=0.1, s="Actual: " + str(y_test[j].numpy()))
    plt.gcf().text(x=0.1, y=0.2, s="Predicted: " + str(test_preds[j].numpy()))
    plt.show()



In [None]:

    # ============================================ VALIDATION ==========================================================
    val_correct = 0
    val_total = 0
    for x_val, y_val in tqdm(val_loader,
                             position=0, leave=True,
                             file=sys.stdout, bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.BLUE, Fore.RESET)):
        batch_size = x_val.shape[0]
        x_val = x_val.view(x_val.shape[0], 1, x_val.shape[1], x_val.shape[2])
        y_pred = model(x_val.to(gpu))
        y_pred = y_pred.permute(1, 0, 2)
        input_lengths = torch.IntTensor(batch_size).fill_(cnn_output_width)
        target_lengths = torch.IntTensor([len(t) for t in y_val])
        criterion(y_pred, y_val, input_lengths, target_lengths)
        _, max_index = torch.max(y_pred, dim=2)
        for i in range(batch_size):
            raw_prediction = list(max_index[:, i].detach().cpu().numpy())
            prediction = torch.IntTensor([c for c, _ in groupby(raw_prediction) if c != blank_label])
            if len(prediction) == len(y_val[i]) and torch.all(prediction.eq(y_val[i])):
                val_correct += 1
            val_total += 1
    print('TESTING. Correct: ', val_correct, '/', val_total, '=', val_correct / val_total)

# ============================================ TESTING =================================================================
number_of_test_imgs = 1
test_loader = torch.utils.data.DataLoader(val_set, batch_size=number_of_test_imgs, shuffle=True)
test_preds = []
(x_test, y_test) = next(iter(test_loader))
y_pred = model(x_test.view(x_test.shape[0], 1, x_test.shape[1], x_test.shape[2]).to(gpu))
y_pred = y_pred.permute(1, 0, 2)
_, max_index = torch.max(y_pred, dim=2)
for i in range(x_test.shape[0]):
    raw_prediction = list(max_index[:, i].detach().cpu().numpy())
    prediction = torch.IntTensor([c for c, _ in groupby(raw_prediction) if c != blank_label])
    test_preds.append(prediction)

for j in range(len(x_test)):
    mpl.rcParams["font.size"] = 8
    plt.imshow(x_test[j], cmap='gray')
    mpl.rcParams["font.size"] = 18
    plt.gcf().text(x=0.1, y=0.1, s="Actual: " + str(y_test[j].numpy()))
    plt.gcf().text(x=0.1, y=0.2, s="Predicted: " + str(test_preds[j].numpy()))
    plt.show()



In [31]:
import torch 

img_size = torch.FloatTensor(3, 3)

tt = torch.tensor([[1,2,3],[1,2,3],[1,2,3]])

tt1 = tt.resize_(img_size.size())

print(tt.shape)
print(tt1.shape)


torch.Size([3, 3])
torch.Size([3, 3])
