## 0. Установка и подгрузка библиотек

Импорт библиотек, с которыми запускается данный бейзлайн.

In [1]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

import numpy as np
import cv2
import os
import json
from matplotlib import pyplot as plt

import pandas as pd

## 1. Разделим трейн датасет на обучающую и валидационную подвыборки


In [2]:
# Characters the Tokenizer can encode, in id order (see get_char_map: ids start at 2).
# Note: upper-case 'Q' and 'Z' are absent, so Tokenizer.encode maps them to the OOV token.
main_alphabet = " !\"%\\'()*+,-./0123456789:;<=>?ABCDEFGHIJKLMNOPRSTUVWXY[]_abcdefghijklmnopqrstuvwxyz|}№"
# Earlier, larger alphabet including Cyrillic letters (unused):
#main_alphabet = " !\"#$%&'()*+,-./0123456789:;<=>?@[\\]^_`{|}~«»ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№"

In [3]:
# Labels CSV: column '0' holds the image file name, column '1' its transcription.
# dtype=str + keep_default_na=False keep transcriptions such as "NA", "null" or
# "nan" as text instead of turning them into float NaN (json.dump would write
# NaN, and Tokenizer.encode cannot iterate over a float).
train_data_csv = pd.read_csv('files/eng_data.csv', dtype=str, keep_default_na=False)

train_data = [
    ('content/all_eng_data2/' + img_name, text)
    for img_name, text in zip(train_data_csv['0'], train_data_csv['1'])
]

print('train len', len(train_data))

split_coef = 0.9
# Fixed seed so every run of the notebook produces the same train/val split.
# With an unseeded split, resuming from a saved checkpoint evaluates on images
# that earlier runs trained on, which inflates validation accuracy.
split_seed = 42
split_rng = np.random.default_rng(split_seed)

train_data_splitted = []
val_data_splitted = []
for sample in train_data:
    if split_rng.random() < split_coef:
        train_data_splitted.append(sample)
    else:
        val_data_splitted.append(sample)

print('train len after split', len(train_data_splitted))
print('val len after split', len(val_data_splitted))

# Saved as {image_path: text}; if an image path occurs twice, dict() keeps the
# last label. json.dump keeps its default ensure_ascii=True, so the files are
# pure ASCII regardless of the encoding a reader opens them with.
with open('files/train_labels_splitted.json', 'w', encoding='utf-8') as f:
    json.dump(dict(train_data_splitted), f)

with open('files/val_labels_splitted.json', 'w', encoding='utf-8') as f:
    json.dump(dict(val_data_splitted), f)

train len 26641
train len after split 23981
val len after split 2660


## 2. Зададим параметры обучения

Здесь можно поправить конфиг обучения — задать размер батча, количество эпох, размер входных изображений, а также пути к датасетам.

In [4]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')


# Training configuration read by train() and get_loaders():
#   alphabet    - characters the Tokenizer can encode
#   save_dir    - directory checkpoints are written to
#   num_epochs  - number of training epochs
#   image       - target width/height every input image is resized to
#   train / val - per-split label JSON, image root directory and batch size
config_json = dict(
    alphabet=main_alphabet,
    save_dir="saves",
    num_epochs=500,
    image=dict(width=256, height=32),
    train=dict(
        root_path="files/",
        json_path="files/train_labels_splitted.json",
        batch_size=256,
    ),
    val=dict(
        root_path="files/",
        json_path="files/val_labels_splitted.json",
        batch_size=128,
    ),
)

## 3. Теперь определим класс датасета (torch.utils.data.Dataset) и другие вспомогательные функции

In [7]:
# Load the EMNIST "balanced" test split (the emnist package may download its
# archive on first use). get_images() groups these glyphs by label.
from emnist import extract_test_samples

(images_emnist,
 labels_emnist) = extract_test_samples('balanced')

Downloading emnist.zip: 536MB [00:54, 10.3MB/s]                             


In [8]:
def _build_emnist_balanced_vocab():
    """Return the character -> EMNIST "balanced" class-index table.

    Digits map to 0-9 and upper-case letters to 10-35. The eleven lower-case
    letters listed below have their own classes 36-46. Every other
    lower-case letter shares the class of its upper-case form. Keys are
    inserted as digits, then 'a'-'z', then 'A'-'Z'.
    """
    distinct_lower = 'abdefghnqrt'
    table = {str(digit): digit for digit in range(10)}
    for ch in 'abcdefghijklmnopqrstuvwxyz':
        if ch in distinct_lower:
            table[ch] = 36 + distinct_lower.index(ch)
        else:
            table[ch] = 10 + ord(ch) - ord('a')
    for offset, ch in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ'):
        table[ch] = 10 + offset
    return table


vocab = _build_emnist_balanced_vocab()

In [10]:
# функция которая помогает объединять картинки и таргет-текст в батч
def collate_fn(batch):
    """Merge ``(image, text, encoded_text)`` samples into one batch.

    Returns ``(images, texts, enc_pad_texts, text_lens)``:
      * ``images``        - the sample images stacked along a new dim 0;
      * ``texts``         - tuple of the raw strings;
      * ``enc_pad_texts`` - encoded texts right-padded with 0 to a common length,
                            shape (batch, max_len);
      * ``text_lens``     - LongTensor holding ``len`` of each raw string.
    """
    images, texts, enc_texts = zip(*batch)
    lengths = [len(text) for text in texts]
    return (
        torch.stack(images, dim=0),
        texts,
        pad_sequence(enc_texts, batch_first=True, padding_value=0),
        torch.LongTensor(lengths),
    )


def get_data_loader(
    transforms, json_path, root_path, tokenizer, batch_size, drop_last,
    shuffle=None, num_workers=8,
):
    """Build a DataLoader over an ``OCRDataset``, batched with ``collate_fn``.

    Args:
        transforms: callable applied to every image (given to OCRDataset).
        json_path, root_path, tokenizer: forwarded to OCRDataset.
        batch_size: samples per batch.
        drop_last: drop the final incomplete batch of each epoch.
        shuffle: reshuffle the sample order every epoch. ``None`` (default)
            means "same as ``drop_last``": in this notebook only the training
            loader passes ``drop_last=True``. The training loader is therefore
            shuffled, and the samples cut off by ``drop_last`` differ between
            epochs instead of the same tail never being trained on.
        num_workers: number of loader worker processes.
    """
    if shuffle is None:
        shuffle = drop_last
    dataset = OCRDataset(json_path, root_path, tokenizer, transforms)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,  # previously accepted but never forwarded
        num_workers=num_workers,
    )


def get_images(images_src=None, labels_src=None, num_classes=47):
    """Group glyph images into per-class lists, inverting each one.

    Args:
        images_src: sequence of glyph arrays. Defaults to the module-level
            ``images_emnist``, which is looked up at call time.
        labels_src: class label of each glyph. Defaults to ``labels_emnist``.
        num_classes: number of label buckets. The default of 47 matches the
            EMNIST "balanced" split loaded above.

    Returns:
        A list of ``num_classes`` lists. Entry ``k`` holds ``255 - img`` for
        every glyph labelled ``k``. Make_words expects this light-background,
        dark-ink form: it treats 255 as background and values <= 244 as ink.
    """
    if images_src is None:
        images_src = images_emnist
    if labels_src is None:
        labels_src = labels_emnist
    grouped = [[] for _ in range(num_classes)]
    for img, label in zip(images_src, labels_src):
        grouped[label].append(255 - img)
    return grouped

def get_words(path_to_folder=''):
    """Return the whitespace-separated words of ``texts.txt`` in ``path_to_folder``.

    The folder and file name are joined with ``os.path.join``, so the folder
    may be given with or without a trailing separator. ``'files'`` and
    ``'files/'`` both work, and ``''`` means the current directory. The file
    is decoded as UTF-8 instead of with the platform's locale encoding.
    """
    with open(os.path.join(path_to_folder, 'texts.txt'), encoding='utf-8') as f:
        return f.read().split()

class Make_words:
    """Render synthetic word images by stitching random per-character glyphs.

    OCRDataset uses this to mix synthetic samples into the real data. Rendering
    uses NumPy array operations rather than per-pixel Python loops, because it
    runs inside DataLoader workers for every synthetic sample.

    Args:
        vocab: maps each character to an index into ``images``.
        images: ``images[k]`` is a list of 2-D glyph arrays for class ``k``.
            Pixel values <= 244 count as ink; 255 is background (the form
            produced by ``get_images``).
        words: candidate words. A word is dropped if any of its characters is
            missing from ``vocab`` or has no glyphs. Rendering such a word would
            raise ``KeyError``/``ValueError``.

    Raises:
        ValueError: if none of ``words`` can be rendered.
    """

    def __init__(self, vocab, images, words):
        self.vocab = vocab
        self.images = images
        self.words = [
            w for w in words
            if all(c in vocab and len(images[vocab[c]]) > 0 for c in w)
        ]
        if not self.words:
            raise ValueError('none of the given words can be rendered with this vocab/images')

    def make_words(self):
        """Return ``(word, image)`` for one randomly chosen word.

        ``image`` is a float32 array of shape ``(36, W, 3)`` with values in
        [0, 255]. It is built in four steps:
          1. Glyphs are pasted left to right with a random vertical shift of
             -4..4 px. Each glyph overlaps the previous one by 7-10 px.
          2. Trailing background is trimmed to 4 columns.
          3. The background is darkened by 0..43 and dark ink (<= 44) is set
             to one random level in 0..143.
          4. Independent per-channel noise in [-5, 5] is added.
        """
        word = self.words[np.random.randint(len(self.words))]
        glyphs = []
        for ch in word:
            pool = self.images[self.vocab[ch]]
            glyphs.append(np.asarray(pool[np.random.randint(len(pool))]))

        canvas = np.full((28 + 8, 28 * len(glyphs) + 24, 3), 255, dtype=np.int64)
        overlap = 0  # running sum of hor_move: every glyph is pulled left onto the previous ones
        for ind, glyph in enumerate(glyphs):
            vert_move = np.random.randint(9) - 4   # -4 .. 4
            hor_move = np.random.randint(4) - 10   # -10 .. -7
            top = 4 + vert_move
            left = 10 + overlap + 28 * ind + hor_move
            rows, cols = glyph.shape[:2]
            ink = glyph <= 244
            # Only ink pixels are copied, so a glyph's background never erases
            # the previous glyph. Basic slicing gives a view, so this writes into canvas.
            canvas[top:top + rows, left:left + cols][ink] = glyph[ink][:, None]
            overlap += hor_move

        # Keep everything up to the last inked column plus 4 background columns.
        # An all-background canvas keeps its first 4 columns.
        inked_cols = np.flatnonzero((canvas != 255).any(axis=(0, 2)))
        last_inked = inked_cols[-1] if inked_cols.size else -1
        canvas = canvas[:, :last_inked + 1 + 4]

        mns_white = np.random.randint(44)   # background darkened by 0..43
        pls_black = np.random.randint(144)  # dark ink set to a level in 0..143
        is_white = (canvas == 255).all(axis=2)
        is_dark = (canvas <= 44).all(axis=2)
        canvas[is_white] = 255 - mns_white
        canvas[is_dark] = pls_black
        noise = np.random.randint(11, size=canvas.shape) - 5
        canvas = np.clip(canvas + noise, 0, 255)

        return word, canvas.astype('float32')


# Synthetic-word generator used by OCRDataset.__getitem__ for part of its
# samples: EMNIST glyphs grouped by class plus the word list from files/texts.txt.
images, words = get_images(), get_words(path_to_folder='files/')
creator = Make_words(vocab, images, words)


class OCRDataset(Dataset):
    """Handwritten-text dataset that mixes real images with synthetic words.

    Real samples come from a JSON file that maps image names (relative to
    ``root_path``) to their transcriptions. With probability ``synth_prob``,
    ``__getitem__`` ignores ``idx`` and instead returns a word rendered by
    the module-level ``creator`` (a ``Make_words`` instance).

    Args:
        json_path: path to the ``{image_name: text}`` JSON file.
        root_path: directory the image names are relative to.
        tokenizer: object with ``encode(list_of_str) -> list_of_id_lists``.
        transform: optional callable applied to every image array.
        synth_prob: chance of serving a synthetic word instead of the real
            sample (default 0.25). Use 0.0 when validation should measure
            real data only.

    Each item is ``(image, text, enc_text)``; ``enc_text`` is a LongTensor.
    """

    def __init__(self, json_path, root_path, tokenizer, transform=None,
                 synth_prob=0.25):
        super().__init__()
        self.transform = transform
        self.synth_prob = synth_prob
        with open(json_path, 'r') as f:
            data = json.load(f)
        self.data_len = len(data)

        self.img_paths = []
        self.texts = []
        for img_name, text in data.items():
            self.img_paths.append(os.path.join(root_path, img_name))
            self.texts.append(text)
        self.enc_texts = tokenizer.encode(self.texts)
        self.tokenizer = tokenizer

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        if np.random.rand() < self.synth_prob:
            # Synthetic sample: idx is intentionally ignored.
            text, image = creator.make_words()
            enc_text = torch.LongTensor(self.tokenizer.encode([text])[0])
        else:
            img_path = self.img_paths[idx]
            text = self.texts[idx]
            enc_text = torch.LongTensor(self.enc_texts[idx])
            image = cv2.imread(img_path)
            # cv2.imread returns None for a missing/unreadable file instead of
            # raising; fail here with the path rather than deep inside transforms.
            if image is None:
                raise FileNotFoundError(f'Could not read image: {img_path}')
        if self.transform is not None:
            image = self.transform(image)
        return image, text, enc_text


class AverageMeter:
    """Track the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # last value passed to update()
        self.avg = 0    # sum / count
        self.sum = 0    # weighted sum of values
        self.count = 0  # total weight seen

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` (e.g. a batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## 4. Здесь определён токенайзер — вспомогательный класс, который преобразует текст в числа

Текст разметки с картинок преобразуется в числовое представление, на котором модель может учиться. Также токенайзер умеет преобразовывать числовое предсказание модели обратно в текст.

In [11]:
OOV_TOKEN = '<OOV>'    # id 1: any character outside the alphabet
CTC_BLANK = '<BLANK>'  # id 0: the CTC blank symbol


def get_char_map(alphabet):
    """Build the character -> integer id mapping for ``alphabet``.

    Alphabet characters are numbered from 2 in order of appearance; a
    repeated character keeps its last position. Id 0 is reserved for the CTC
    blank and id 1 for out-of-vocabulary characters.
    """
    char_map = {}
    for position, symbol in enumerate(alphabet, start=2):
        char_map[symbol] = position
    char_map.update({CTC_BLANK: 0, OOV_TOKEN: 1})
    return char_map


class Tokenizer:
    """Convert strings to integer id sequences and CTC outputs back to strings.

    Ids come from ``get_char_map``: 0 is the CTC blank, 1 the OOV token, and
    the alphabet characters follow from 2.
    """

    def __init__(self, alphabet):
        self.char_map = get_char_map(alphabet)
        self.rev_char_map = dict((idx, ch) for ch, idx in self.char_map.items())

    def encode(self, word_list):
        """Return one list of ids per word; unknown characters become the OOV id."""
        oov_id = self.char_map[OOV_TOKEN]
        return [
            [self.char_map.get(ch, oov_id) for ch in word]
            for word in word_list
        ]

    def get_num_chars(self):
        """Number of entries in ``char_map`` (alphabet plus blank and OOV)."""
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Greedy CTC decoding of per-frame id sequences.

        A frame is emitted unless it is the blank or OOV id, or repeats the id
        of the frame right before it. Emitted ids are mapped back to
        characters through ``rev_char_map``.
        """
        ignored = (self.char_map[OOV_TOKEN], self.char_map[CTC_BLANK])
        result = []
        for frames in enc_word_list:
            out_chars = []
            for pos, code in enumerate(frames):
                if code in ignored:
                    continue
                # pos > 0 guards against comparing with frames[-1]
                if pos > 0 and code == frames[pos - 1]:
                    continue
                out_chars.append(self.rev_char_map[code])
            result.append(''.join(out_chars))
        return result

## 5. Accuracy в качестве метрики

Accuracy измеряет долю предсказанных строк текста, которые полностью совпадают с таргет-текстом.

In [12]:
def get_accuracy(y_true, y_pred):
    """Return the fraction of predictions that exactly equal their target string.

    Args:
        y_true: iterable of target strings.
        y_pred: iterable of predicted strings, one per target.

    Returns:
        Exact-match rate in [0, 1]; 0.0 for empty input. ``np.mean([])``
        would instead give nan and a RuntimeWarning.

    Raises:
        ValueError: if the two inputs differ in length. ``zip`` would silently
            drop the extra items and report a wrong accuracy.
    """
    y_true = list(y_true)
    y_pred = list(y_pred)
    if len(y_true) != len(y_pred):
        raise ValueError(
            f'y_true and y_pred differ in length: {len(y_true)} != {len(y_pred)}')
    if not y_true:
        return 0.0
    return np.mean([true == pred for true, pred in zip(y_true, y_pred)])

## 6. Аугментации

Здесь мы задаём базовые аугментации для модели. Вы можете написать свои или использовать готовые библиотеки, например albumentations.

In [13]:
class to_0:
    """Collapse a 3-channel CHW tensor to a single luminance channel.

    Channels 0/1/2 are weighted 0.299/0.587/0.114. NOTE(review): these are the
    usual RGB luma weights, but images from cv2.imread are BGR, so channel 0
    is blue for real images. Changing this would alter the inputs the saved
    checkpoints were trained on.
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width

    def __call__(self, x):
        c0, c1, c2 = x[0], x[1], x[2]
        luminance = c0 * 0.299 + c1 * 0.587 + c2 * 0.114
        return luminance.view(1, self.height, self.width)


class Normalize:
    """Convert a NumPy image to float32 and scale [0, 255] down to [0, 1]."""

    def __call__(self, img):
        as_float = img.astype(np.float32)
        return as_float / 255


class ToTensor:
    """Wrap a NumPy array as a torch tensor sharing the same memory (``torch.from_numpy``)."""

    def __call__(self, arr):
        return torch.from_numpy(arr)


class MoveChannels:
    """Move the channel axis between last (HWC) and first (CHW) position.

    With ``to_channels_first=True`` (default) the last axis becomes axis 0,
    which is the layout PyTorch convolutions use; otherwise axis 0 moves to
    the end.
    """

    def __init__(self, to_channels_first=True):
        self.to_channels_first = to_channels_first

    def __call__(self, image):
        source, destination = (-1, 0) if self.to_channels_first else (0, -1)
        return np.moveaxis(image, source, destination)


class ImageResize:
    """Resize an H x W x C image to exactly ``height`` x ``width``.

    The image is first scaled to ``height`` rows, with its width scaled by
    ``2 * height / H``: aspect ratio preserved, then stretched 2x
    horizontally. If that width fits, the image is right-padded with zeros up
    to ``width``. Otherwise it is resized straight to ``(height, width)``.

    NOTE(review): the zero padding is black, while Make_words renders light
    backgrounds. It is left unchanged because changing it would alter the
    input distribution the saved checkpoints were trained on.
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width

    def __call__(self, image):
        src_h, src_w = image.shape[:2]
        # At least one column: cv2.resize rejects a zero-area target size,
        # which very tall, narrow crops would otherwise produce.
        scaled_w = max(1, int(src_w * self.height / src_h * 2))
        if scaled_w > self.width:
            # Too wide: one resize straight to the final size instead of two
            # successive interpolations.
            return cv2.resize(image, (self.width, self.height),
                              interpolation=cv2.INTER_LINEAR)
        image = cv2.resize(image, (scaled_w, self.height),
                           interpolation=cv2.INTER_LINEAR)
        return np.pad(image, [(0, 0), (0, self.width - scaled_w), (0, 0)],
                      mode='constant', constant_values=0)



def get_train_transforms(height, width):
    """Build the preprocessing pipeline for training images.

    No random augmentation happens here. The steps produce a 1 x height x
    width float tensor from an H x W x 3 image.
    """
    pipeline = [
        ImageResize(height, width),            # fixed-size H x W x 3 image
        MoveChannels(to_channels_first=True),  # HWC -> CHW
        Normalize(),                           # float32 in [0, 1]
        ToTensor(),                            # numpy -> torch.Tensor
        to_0(height, width),                   # 3 channels -> 1 gray channel
    ]
    return torchvision.transforms.Compose(pipeline)


def get_val_transforms(height, width):
    """Build the preprocessing pipeline for validation images.

    The steps are the same as in training: resize/pad to ``height`` x
    ``width``, HWC -> CHW, scale to [0, 1], convert to a tensor, then collapse
    to one gray channel.
    """
    resize = ImageResize(height, width)
    to_chw = MoveChannels(to_channels_first=True)
    gray = to_0(height, width)
    return torchvision.transforms.Compose(
        [resize, to_chw, Normalize(), ToTensor(), gray])

## 7. Здесь определяем саму модель - CRNN

Подробнее об архитектуре можно почитать в статье https://arxiv.org/abs/1507.05717

In [14]:
def get_resnet34_backbone(pretrained=True):
    """Return a ResNet-34 trunk (stem + layer1..layer3) for 1-channel input.

    Args:
        pretrained: load torchvision's pretrained ResNet-34 weights for the
            reused layers. The argument was previously ignored and weights
            were always loaded.

    The input convolution is a freshly initialised
    ``Conv2d(1, 64, 7, stride=1, padding=3)``, so the model's own ``conv1``
    weights are not used. Layers after ``layer3`` are not included. The
    module order is unchanged, so saved state_dict keys still match.
    """
    m = torchvision.models.resnet34(pretrained=pretrained)
    input_conv = nn.Conv2d(1, 64, 7, 1, 3)
    blocks = [input_conv, m.bn1, m.relu,
              m.maxpool, m.layer1, m.layer2, m.layer3]
    return nn.Sequential(*blocks)


class BiLSTM(nn.Module):
    """Stacked bidirectional, batch-first LSTM that returns every time step's output.

    Output feature size is ``2 * hidden_size`` (forward and backward
    directions concatenated). ``dropout`` is applied between stacked layers.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        outputs = self.lstm(x)[0]  # drop the (h_n, c_n) state tuple
        return outputs


class CRNN(nn.Module):
    """Convolutional-recurrent text recogniser for CTC training (arXiv:1507.05717).

    Pipeline:
      1. ResNet-34 trunk.
      2. Channels x height flattened into features.
      3. Adaptive average pooling to a fixed grid.
      4. BiLSTM.
      5. Per-time-step MLP classifier.
      6. Log-softmax.

    Args:
        number_class_symbols: number of output classes
            (``Tokenizer.get_num_chars()`` in train()).
        time_feature_count: both the number of time steps and the per-step
            feature size fed to the LSTM.
        lstm_hidden: hidden size per LSTM direction.
        lstm_len: number of stacked LSTM layers.
    """

    def __init__(
        self, number_class_symbols, time_feature_count=256, lstm_hidden=256,
        lstm_len=2,
    ):
        super().__init__()
        # These attribute names are the state_dict key prefixes of saved
        # checkpoints; keep them stable.
        self.feature_extractor = get_resnet34_backbone(pretrained=True)
        pooled_shape = (time_feature_count, time_feature_count)
        self.avg_pool = nn.AdaptiveAvgPool2d(pooled_shape)
        self.bilstm = BiLSTM(time_feature_count, lstm_hidden, lstm_len)
        head = [
            nn.Linear(2 * lstm_hidden, time_feature_count),  # 2x: both LSTM directions
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(time_feature_count, number_class_symbols),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Map an image batch to ``(T, B, number_class_symbols)`` log-probabilities.

        Assumes ``x`` is (B, 1, H, W), since the backbone's first conv takes
        one input channel.
        """
        feats = self.feature_extractor(x)
        batch, channels, height, width = feats.size()
        # (B, C, H, W) -> (B, C*H, W). AdaptiveAvgPool2d treats this 3-D tensor
        # as an unbatched (C, H, W) input and resizes its last two dims, giving
        # (B, time_feature_count, time_feature_count). The transpose puts the
        # width-derived axis on dim 1 as the time axis of the batch-first LSTM.
        seq = self.avg_pool(feats.view(batch, channels * height, width)).transpose(1, 2)
        logits = self.classifier(self.bilstm(seq))
        # (B, T, C) -> (T, B, C), the layout nn.CTCLoss takes.
        return nn.functional.log_softmax(logits, dim=2).permute(1, 0, 2)

## 8. Переходим к самому скрипту обучения - циклы трейна и валидации

In [15]:
def val_loop(data_loader, model, tokenizer, device):
    """Compute exact-match accuracy over ``data_loader``, print it and return it.

    Per-batch accuracies are averaged with batch size as the weight.
    """
    meter = AverageMeter()
    for images, texts, _enc, _lens in data_loader:
        predictions = predict(images, model, tokenizer, device)
        meter.update(get_accuracy(texts, predictions), len(texts))
    print(f'Validation, acc: {meter.avg:.4f}')
    return meter.avg


def train_loop(data_loader, model, criterion, optimizer, epoch):
    """Run one training epoch with CTC loss and return the mean loss.

    Every output sequence is given the full model output length. Gradients
    are clipped to a total norm of 2 before each optimizer step. The epoch
    summary is printed with the learning rate of the last param group.
    """
    meter = AverageMeter()
    model.train()
    for images, texts, enc_pad_texts, text_lens in data_loader:
        model.zero_grad()
        log_probs = model(images.to(DEVICE))  # (T, B, C)
        seq_len, batch = log_probs.size(0), log_probs.size(1)
        output_lengths = torch.full((batch,), seq_len, dtype=torch.long)
        loss = criterion(log_probs, enc_pad_texts, output_lengths, text_lens)
        meter.update(loss.item(), len(texts))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        optimizer.step()
    lr = optimizer.param_groups[-1]['lr']
    print(f'\nEpoch {epoch}, Loss: {meter.avg:.5f}, LR: {lr:.7f}')
    return meter.avg


def predict(images, model, tokenizer, device):
    """Greedy (best-path) decode a batch of images into strings.

    Puts ``model`` in eval mode and runs it without gradients. It then takes
    the argmax class per time step and hands ``tokenizer.decode`` one id row
    per sample. Assumes the model returns ``(T, B, C)`` scores, as CRNN does.
    """
    model.eval()
    batch = images.to(device)
    with torch.no_grad():
        log_probs = model(batch)
    best_path = log_probs.detach().cpu().argmax(dim=-1)  # (T, B)
    return tokenizer.decode(best_path.permute(1, 0).numpy())


def get_loaders(tokenizer, config):
    """Create the (train_loader, val_loader) pair described by ``config``.

    Both use the ``config['image']`` size. The training loader is built with
    ``drop_last=True`` and the validation loader with ``drop_last=False``.
    """
    height = config['image']['height']
    width = config['image']['width']

    def _loader(split, transforms, drop_last):
        section = config[split]
        return get_data_loader(
            transforms=transforms,
            json_path=section['json_path'],
            root_path=section['root_path'],
            tokenizer=tokenizer,
            batch_size=section['batch_size'],
            drop_last=drop_last,
        )

    train_loader = _loader('train', get_train_transforms(height=height, width=width), True)
    val_loader = _loader('val', get_val_transforms(height=height, width=width), False)
    return train_loader, val_loader


def train(config):
    """Fine-tune a CRNN on the configured data, saving every improving checkpoint.

    Args:
        config: dict with the keys read by ``get_loaders``, plus ``alphabet``,
            ``save_dir`` and ``num_epochs``. Optional key ``resume_from``:
            state_dict path to start from. It defaults to
            ``'saves/model-5-0.7795a.ckpt'``; a falsy value starts from the
            freshly built CRNN.

    ``model-{epoch}-{acc:.4f}a.ckpt`` is written to ``save_dir`` whenever
    validation accuracy beats the best so far. "Best so far" starts at the
    accuracy of the initial weights, so no checkpoint worse than the starting
    one is saved.
    """
    tokenizer = Tokenizer(config['alphabet'])
    os.makedirs(config['save_dir'], exist_ok=True)
    train_loader, val_loader = get_loaders(tokenizer, config)

    # The checkpoint's classifier must have tokenizer.get_num_chars() outputs.
    # To reuse weights trained with another alphabet:
    #   1. build the CRNN with that alphabet's size;
    #   2. load the checkpoint;
    #   3. replace model.classifier with a head sized for the new alphabet.
    model = CRNN(number_class_symbols=tokenizer.get_num_chars())
    resume_path = config.get('resume_from', 'saves/model-5-0.7795a.ckpt')
    if resume_path:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(resume_path, map_location=DEVICE))
    model.to(DEVICE)

    criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001,
                                  weight_decay=0.01)
    # mode='max': the monitored quantity is validation accuracy.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer, mode='max', factor=0.5, patience=15)

    best_acc = val_loop(val_loader, model, tokenizer, DEVICE)
    for epoch in range(config['num_epochs']):
        train_loop(train_loader, model, criterion, optimizer, epoch)
        acc_avg = val_loop(val_loader, model, tokenizer, DEVICE)
        scheduler.step(acc_avg)
        if acc_avg > best_acc:
            best_acc = acc_avg
            model_save_path = os.path.join(
                config['save_dir'], f'model-{epoch}-{acc_avg:.4f}a.ckpt')
            torch.save(model.state_dict(), model_save_path)
            print('Model weights saved')

## 9. Запускаем обучение!

In [None]:
# Launch training with the configuration defined in section 2.
train(config=config_json)

Validation, acc: 0.6914

Epoch 0, Loss: 0.23858, LR: 0.0010000
Validation, acc: 0.7669
Model weights saved

Epoch 1, Loss: 0.14160, LR: 0.0010000
Validation, acc: 0.7989
Model weights saved

Epoch 2, Loss: 0.10535, LR: 0.0010000
Validation, acc: 0.8128
Model weights saved

Epoch 3, Loss: 0.10690, LR: 0.0010000
Validation, acc: 0.7289

Epoch 4, Loss: 0.10297, LR: 0.0010000
Validation, acc: 0.7895

Epoch 5, Loss: 0.09268, LR: 0.0010000
Validation, acc: 0.7906

Epoch 6, Loss: 0.08635, LR: 0.0010000
Validation, acc: 0.8008

Epoch 7, Loss: 0.07664, LR: 0.0010000
Validation, acc: 0.8060

Epoch 8, Loss: 0.08252, LR: 0.0010000
Validation, acc: 0.8038

Epoch 9, Loss: 0.07747, LR: 0.0010000
Validation, acc: 0.7752

Epoch 10, Loss: 0.06966, LR: 0.0010000
Validation, acc: 0.7914

Epoch 11, Loss: 0.06043, LR: 0.0010000
Validation, acc: 0.8165
Model weights saved

Epoch 12, Loss: 0.05465, LR: 0.0010000
Validation, acc: 0.8297
Model weights saved

Epoch 13, Loss: 0.06318, LR: 0.0010000
Validation, ac