In [None]:
!pip install -q git+https://github.com/albumentations-team/albumentations.git
!pip install -q opencv-python-headless==4.5.2.52

In [None]:
!unrar x /content/drive/MyDrive/train_data.rar

In [None]:
import torch.nn as nn
import torch

import pandas as pd
import numpy as np

import cv2, os
import torchvision

from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pad_sequence

import random

In [None]:
from sklearn.model_selection import train_test_split
# Hold out 10% of the labelled rows; the fixed random_state keeps the split
# identical across runs.
# NOTE(review): test_df is presumably the validation split behind valid_loader,
# but no loader is built from it anywhere in this notebook — confirm.
train_df, test_df = train_test_split(pd.read_csv("./train_data/labels.csv"), test_size=0.1, random_state=45)

### Dataset

In [None]:
def collate_fn(batch):
    """Merge a list of ``(image, text, encoded_text)`` samples into one batch.

    Args:
        batch: iterable of 3-tuples; images must be stackable tensors and
            encoded texts 1-D tensors of possibly different lengths.

    Returns:
        ``(images, texts, enc_pad_texts, text_lens)`` where ``images`` is the
        samples stacked along a new first axis, ``texts`` is the tuple of raw
        texts, ``enc_pad_texts`` is the encoded texts right-padded with 0
        (batch first) and ``text_lens`` is a LongTensor of ``len(text)``.
    """
    image_seq, texts, token_seq = zip(*batch)

    # Lengths come from the raw texts, not from the encoded tensors.
    lengths = torch.LongTensor(list(map(len, texts)))
    stacked = torch.stack(image_seq, 0)
    padded = pad_sequence(token_seq, batch_first=True, padding_value=0)

    return stacked, texts, padded, lengths

In [None]:
class Data(torch.utils.data.Dataset):
    """Dataset yielding ``(image, text, encoded_text)`` from a labels table.

    Rows are read positionally: column 1 is the image path (relative to
    ``images_path``) and column 2 is the target text.
    """

    def __init__(self, table, images_path, tokenizer, transforms=None, is_valid=False):
        """
        Args:
            table: object exposing ``to_numpy()`` (e.g. a pandas DataFrame);
                converted once to an array and indexed by row.
            images_path: directory that each row's image path is joined to.
            tokenizer: object whose ``encode(list_of_str)`` returns a list of
                id lists.
            transforms: optional callable applied to the loaded RGB image.
            is_valid: stored on the instance; not read by this class itself.
        """
        self.tokenizer = tokenizer
        self.transforms = transforms

        self.table = table.to_numpy()
        self.images_path = images_path

        self.is_valid = is_valid

    def __getitem__(self, idx):
        """Return ``(image, text, LongTensor of token ids)`` for row ``idx``."""
        row = self.table[idx]

        text = row[2]
        image = self.load_image(row[1])

        # encode() works on a list of words, so wrap and unwrap the single text.
        enc_text = self.tokenizer.encode([text])

        if self.transforms is not None:
            image = self.transforms(image)

        return image, text, torch.LongTensor(enc_text[0])

    def __len__(self):
        """Number of rows in the labels table."""
        return self.table.shape[0]

    def load_image(self, path):
        """Read ``images_path/path`` with OpenCV and return it in RGB order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
                signals failure by returning ``None`` rather than raising,
                which otherwise surfaces as an opaque ``cvtColor`` error.
        """
        full_path = os.path.join(self.images_path, path)
        image = cv2.imread(full_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {full_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [None]:
class AverageMeter:
    """Running weighted mean of a scalar (e.g. a per-batch loss or accuracy).

    Attributes are ``sum`` (weighted total), ``count`` (total weight) and
    ``avg`` (``sum / count``, refreshed on every update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all accumulators."""
        for field in ("avg", "sum", "count"):
            setattr(self, field, 0)

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and recompute the running mean."""
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [None]:
OOV_TOKEN = '<OOV>'
CTC_BLANK = '<BLANK>'


def get_char_map(alphabet):
    """Build the symbol -> integer id mapping.

    Ids 0 and 1 are reserved for ``CTC_BLANK`` and ``OOV_TOKEN``; alphabet
    symbols are numbered from 2 in iteration order.
    """
    mapping = {}
    for code, symbol in enumerate(alphabet, start=2):
        mapping[symbol] = code
    mapping[CTC_BLANK] = 0
    mapping[OOV_TOKEN] = 1
    return mapping


class Tokenizer:
    """Character-level encoder/decoder for CTC targets and predictions."""

    def __init__(self, alphabet):
        """Create the forward (symbol -> id) and reverse (id -> symbol) maps."""
        self.char_map = get_char_map(alphabet)
        self.rev_char_map = {code: symbol for symbol, code in self.char_map.items()}

    def encode(self, word_list):
        """Turn each word into a list of ids; unknown characters map to OOV."""
        oov_id = self.char_map[OOV_TOKEN]
        return [
            [self.char_map.get(symbol, oov_id) for symbol in word]
            for word in word_list
        ]

    def get_num_chars(self):
        """Size of the vocabulary, including the blank and OOV entries."""
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Turn id sequences back into strings.

        Blank and OOV ids are dropped, and an id equal to the id immediately
        before it in the raw sequence is emitted only once.
        """
        oov_id = self.char_map[OOV_TOKEN]
        blank_id = self.char_map[CTC_BLANK]

        decoded = []
        for word in enc_word_list:
            pieces = []
            for pos, code in enumerate(word):
                if code == oov_id or code == blank_id:
                    continue
                # Compare against the raw previous id, so a blank between two
                # equal ids keeps both of them.
                if pos > 0 and code == word[pos - 1]:
                    continue
                pieces.append(self.rev_char_map[code])
            decoded.append(''.join(pieces))
        return decoded

### CRNN

In [2]:
from resnet import resnet18, resnet34, resnet50
import torch.nn.functional as F

class BiLSTM(nn.Module):
    """Bidirectional LSTM that returns only the output sequence."""

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.1):
        """
        Args:
            input_size: feature size of each input step.
            hidden_size: hidden size per direction.
            num_layers: number of stacked LSTM layers.
            dropout: dropout between stacked layers; ignored for a single
                layer, where there is no inter-layer connection to drop.
        """
        super().__init__()
        # nn.LSTM only applies dropout between layers and warns when a non-zero
        # value is passed with num_layers == 1, so pass it only when it is used.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers,
            dropout=inter_layer_dropout, bidirectional=True)

    def forward(self, x):
        """Run the LSTM and return its outputs, discarding the final states."""
        self.lstm.flatten_parameters()
        return self.lstm(x)[0]

class CRNN(nn.Module):
    """ResNet-18 feature extractor followed by stacked BiLSTM + Linear heads.

    The recurrent hidden size equals ``number_class_symbols``, so every
    Linear maps the concatenated two directions back to the class count.
    """

    def __init__(self, number_class_symbols, fe_d_out=512, rnn_size=2):
        """
        Args:
            number_class_symbols: number of output classes (also the LSTM
                hidden size per direction).
            fe_d_out: input size of the first BiLSTM — presumably the channel
                count produced by the feature extractor; confirm against
                ``resnet18``.
            rnn_size: total number of BiLSTM + Linear pairs (at least one is
                always created).
        """
        super().__init__()

        n_cls = number_class_symbols

        # First pair consumes the extractor features, later pairs consume the
        # previous Linear output.
        recurrent = [BiLSTM(fe_d_out, n_cls), nn.Linear(n_cls * 2, n_cls)]
        for _ in range(1, rnn_size):
            recurrent.append(BiLSTM(n_cls, n_cls))
            recurrent.append(nn.Linear(n_cls * 2, n_cls))

        self.feature_extractor = resnet18(pretrained=False)
        self.rnn = nn.Sequential(*recurrent)

    def forward(self, x):
        """Return log-probabilities with the sequence axis first.

        The 4-D feature map (unpacked as n, c, h, w) is flattened so that the
        ``w * h`` positions form the sequence; softmax is taken over the last
        (class) axis.
        """
        features = self.feature_extractor(x)
        batch, channels, height, width = features.shape

        # Swap the last two axes before flattening so positions are ordered
        # width-major, then move the flattened axis to the front.
        flat = features.permute(0, 1, 3, 2).reshape(batch, channels, width * height)
        sequence = flat.permute(2, 0, 1)

        scores = self.rnn(sequence)
        return F.log_softmax(scores, dim=2)

In [None]:
def get_accuracy(y_true, y_pred):
    """Fraction of positions where the prediction equals the reference exactly.

    Pairs are formed with ``zip``, so extra items in the longer input are
    ignored.
    """
    matches = [expected == predicted for expected, predicted in zip(y_true, y_pred)]
    return np.mean(matches)

def val_loop(data_loader, model, tokenizer, device):
    """Evaluate exact-match accuracy over ``data_loader``.

    Each batch is a 4-tuple whose first two items are the images and the
    reference texts; the rest is ignored. Per-batch accuracy is averaged with
    the batch size as weight, printed, and returned.
    """
    meter = AverageMeter()
    progress = tqdm(data_loader, total=len(data_loader))

    for images, texts, _enc, _lens in progress:
        predictions = predict(images, model, tokenizer, device)
        meter.update(get_accuracy(texts, predictions), len(texts))

    accuracy = meter.avg
    print(f'Validation, acc: {accuracy:.4f}')
    return accuracy


def train_loop(data_loader, model, criterion, optimizer, epoch, device=None):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        data_loader: iterable of ``(images, texts, enc_pad_texts, text_lens)``
            batches that also supports ``len()``.
        model: network whose output is indexed as ``(steps, batch, ...)`` —
            dimension 0 is used as the per-sample input length and dimension 1
            as the batch size.
        criterion: loss called as
            ``criterion(output, targets, input_lengths, target_lengths)``.
        optimizer: optimizer over the model parameters.
        epoch: epoch number, used only for progress/log output.
        device: where to move the images. Defaults to the device of the
            model's parameters, so the function does not depend on a
            notebook-level ``DEVICE`` global.

    Returns:
        Mean loss over the epoch, weighted by batch size.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    loss_avg = AverageMeter()
    model.train()

    tq = tqdm(data_loader, total=len(data_loader), desc=f"Epoch #{epoch}")

    for images, texts, enc_pad_texts, text_lens in tq:
        model.zero_grad()
        images = images.to(device)
        batch_size = len(texts)
        output = model(images)

        # Every sample uses the full output length as its input length.
        output_lengths = torch.full(
            size=(output.size(1),),
            fill_value=output.size(0),
            dtype=torch.long
        )
        loss = criterion(output, enc_pad_texts, output_lengths, text_lens)
        item = loss.item()

        tq.set_postfix({
            "loss": item,
        })

        loss_avg.update(item, batch_size)
        loss.backward()
        # Clip the global gradient norm to 2 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        optimizer.step()

    # Report the learning rate of the last parameter group.
    lr = optimizer.param_groups[-1]['lr']
    print(f'\nEpoch {epoch}, Loss: {loss_avg.avg:.5f}, LR: {lr:.7f}')
    return loss_avg.avg


def predict(images, model, tokenizer, device):
    """Greedy-decode a batch of images into texts.

    Puts ``model`` in eval mode (and leaves it there), runs it without
    gradients, takes the argmax over the last axis, swaps the first two axes
    so rows correspond to samples, and hands the resulting NumPy array to
    ``tokenizer.decode``.
    """
    model.eval()
    images = images.to(device)

    with torch.no_grad():
        scores = model(images)

    best_ids = scores.detach().cpu().argmax(dim=-1)
    per_sample = best_ids.permute(1, 0).numpy()
    return tokenizer.decode(per_sample)

### Train model

In [None]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [None]:
from albumentations.augmentations.transforms import ToGray
from albumentations.augmentations.geometric.rotate import Rotate

# Training-time augmentation pipeline. Transforms run in the order listed, so
# the random augmentations are applied before Normalize and ToTensorV2.
# Use it through train_transform_fn, which passes the image as `image=`.
# NOTE(review): JpegCompression is, as far as I know, deprecated in newer
# albumentations releases in favour of ImageCompression, and the install cell
# tracks the repository head — confirm this name still exists.
train_transform = A.Compose([
    A.Rotate(4),
    A.Resize(128, 356),  # same target size as valid_transform
    A.ChannelShuffle(p=0.2),
    A.ColorJitter(p=1),
    A.RandomShadow(p=0.3),
    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.25, always_apply=False),
    A.JpegCompression(quality_lower=75, p=0.5),
    A.ToGray(p=0.3),

    A.Normalize(p=1),
    ToTensorV2()
])

def train_transform_fn(x):
    """Run ``train_transform`` on one image and return the ``"image"`` entry."""
    augmented = train_transform(image=x)
    return augmented["image"]

# Validation-time pipeline: resize to the training size, then Normalize and
# ToTensorV2. Use it through valid_transform_fn.
# NOTE(review): CLAHE (p=0.25) and JpegCompression (p=0.5) are applied at
# random here too, so validation accuracy — and therefore best-checkpoint
# selection — varies between runs on identical weights. Confirm this is
# intentional.
valid_transform = A.Compose([
    A.Resize(128, 356),
    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.25, always_apply=False),
    A.JpegCompression(quality_lower=75, p=0.5),

    A.Normalize(p=1),
    ToTensorV2()
])

def valid_transform_fn(x):
    """Run ``valid_transform`` on one image and return the ``"image"`` entry."""
    processed = valid_transform(image=x)
    return processed["image"]

In [None]:
# Symbols the tokenizer knows. get_char_map numbers them from 2 in this exact
# order (0 is the CTC blank, 1 is OOV), so editing or reordering this string
# changes the class ids and invalidates previously saved checkpoints.
alphabet = """ !"%\'()*+,-./0123456789:;<=>?ABCDEFGHIJKLMNOPRSTUVWXY[]_abcdefghijklmnopqrstuvwxyz|}ЁАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШЩЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№"""

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = Tokenizer(alphabet)

In [None]:
model = CRNN(tokenizer.get_num_chars(), 512, rnn_size=2)

# Resume from the best saved weights when they exist. map_location lets a
# checkpoint written on a GPU load on a CPU-only machine; without the existence
# check a fresh run would stop here before any checkpoint had been written.
best_checkpoint = "checkpoints/best.pth"
if os.path.exists(best_checkpoint):
    model.load_state_dict(torch.load(best_checkpoint, map_location=DEVICE))
else:
    print(f"No checkpoint at {best_checkpoint}; training starts from random weights.")

model.to(DEVICE)

# blank=0 matches the id get_char_map reserves for CTC_BLANK.
criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001,
                              weight_decay=0.01)
# mode='max': the scheduler is meant to be stepped with validation accuracy.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode='max', factor=0.5, patience=15)
# NOTE(review): best_acc starts at -inf even when weights were resumed, so the
# first epoch always overwrites checkpoints/best.pth, even with a worse model.
best_acc = -np.inf

In [None]:
# NOTE(review): train_loader and valid_loader are not defined anywhere in this
# notebook; they must be built (e.g. from Data + collate_fn) before this cell.

# torch.save does not create missing directories.
os.makedirs("checkpoints", exist_ok=True)

for epoch in range(1000):
    loss = train_loop(train_loader, model, criterion, optimizer, epoch)
    acc = val_loop(valid_loader, model, tokenizer, DEVICE)

    # ReduceLROnPlateau only lowers the learning rate when it is stepped with
    # the monitored metric; without this call the LR never changes.
    scheduler.step(acc)

    # Keep only the weights with the best validation accuracy so far.
    if acc > best_acc:
        torch.save(model.state_dict(), "checkpoints/best.pth")
        print("model saved!")
        best_acc = acc

In [None]:
# Re-evaluate the model currently in memory; these are whatever weights the
# last executed cell left there, not necessarily those in checkpoints/best.pth.
val_loop(valid_loader, model, tokenizer, DEVICE)

In [None]:
# Save the in-memory weights under a separate name so checkpoints/best.pth is
# left untouched.
torch.save(model.state_dict(), "checkpoints/final.pth")