In [1]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from cer import calculate_cer
from tqdm import tqdm
import pandas as pd
import numpy as np
import cv2
import json
import os
from matplotlib import pyplot as plt
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

In [4]:
def get_resnet34_backbone(pretrained=True):
    model = torchvision.models.resnet34(pretrained=pretrained)
    model.conv1 = nn.Conv2d(3, 64, 7, 1, 3)
    model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, 
                          model.layer1, model.layer2, model.layer3)
    return model

class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, dropout=dropout, 
                            batch_first=True, bidirectional=True)
    def forward(self, x):
        out = self.lstm(x)[0]
        return out

class CRNN(nn.Module):
    def __init__(self, num_class_symbols, time_feature_count=256, lstm_hidden=256, lstm_len=3):
        super().__init__()
        self.feature_extractor = get_resnet34_backbone()
        self.avgpool = nn.AdaptiveAvgPool2d((time_feature_count, time_feature_count))
        self.bilstm = BiLSTM(time_feature_count, lstm_hidden, lstm_len)
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, time_feature_count),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(time_feature_count, num_class_symbols),
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        b, c, h, w = x.size()
        x = x.view(b, c * h, w)
        x = self.avgpool(x)
        x = x.transpose(1, 2)
        x = self.bilstm(x)
        x = self.classifier(x)
        x = nn.functional.log_softmax(x, dim=2).permute(1, 0, 2)
        return x

In [5]:
OOV_TOKEN = '<OOV>'
CTC_BLANK = '<BLANK>'

def get_char_map(alphabet):
    char_map = {value: idx + 2 for idx, value in enumerate(alphabet)}
    char_map[OOV_TOKEN] = 0
    char_map[CTC_BLANK] = 1
    return char_map

class Tokenizer:
    def __init__(self, alphabet):
        self.char_map = get_char_map(alphabet)
        self.rev_char_map = {val: key for key, val in self.char_map.items()}

    def encode(self, text):
        enc_text = []
        for let in text:
            if let in self.char_map.keys():
                enc_text.append(self.char_map[let])
            else:
                enc_text.append(self.char_map[OOV_TOKEN])
        return enc_text

    def decode(self, enc_text):
        dec_text = []
        for word in enc_text:
            dec_word = ''
            for idx, enc_let in enumerate(word):
                if (enc_let != self.char_map[OOV_TOKEN] and enc_let != self.char_map[CTC_BLANK]
                    and not(idx != 0 and word[idx - 1] == enc_let)):
                    dec_word += self.rev_char_map[enc_let]
            dec_text += [dec_word]
        return dec_text                           


a = Tokenizer('abcdefgh')
print(a.decode([[0,1,2,3,4,5,6,7]]))

['abcdef']


In [6]:
alphabet = 'уДсАа0з1шдbМтrъщuиеПнГЛхЖ/sУ9бШ»y]x)hНгыЧ;\':эОф8Вp(юйкм?ЮпрЗ«И+Э2R,вiБл№Тя%Р3[Х"!cч Ц=о5СФ7Й64ЕК.eьtЩжoЯ-цё'

In [7]:
class OCRDataset(Dataset):
    def __init__(self, root_img, root_text, tokenizer):
        super().__init__()
        self.images_paths = []
        self.texts = {}
        self.enc_texts = {}
        for name in os.listdir(root_img):
            self.images_paths.append(root_img + '/' + name)
        tmp_texts = open(root_text).read().split('\n')
        for text in tmp_texts:
            if '\t' in text:
                img_name, new_text = text.split('\t')
                self.texts[img_name] = new_text
                self.enc_texts[img_name] = tokenizer.encode(new_text)
            
    def __len__(self):
        return len(self.images_paths)
        
    def __getitem__(self, idx):
        img_path = self.images_paths[idx]
        img_name = img_path.split('/')[-1]
        text = self.texts[img_name]
        enc_text = torch.LongTensor(self.enc_texts[img_name])
        img = cv2.imread(img_path)
        transform = get_transforms(256, 64)
        img = transform(img)
        return img, text, enc_text

def collate_fn(batch):
    images, texts, enc_texts = zip(*batch)
    images = torch.stack(images, 0)
    text_lens = torch.LongTensor([len(text) for text in texts])
    enc_pad_texts = pad_sequence(enc_texts, batch_first=True, padding_value=0)
    return images, texts, enc_pad_texts, text_lens

def get_data_loader(root_img, root_text, tokenizer, batch_size, drop_last=False):
    dataset = OCRDataset(root_img, root_text, tokenizer)
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=drop_last,
        num_workers = 8
    )
    return data_loader

In [8]:
class AverageMeter:
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0
        
    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [9]:
def get_CER(y_true, y_pred):
    return calculate_cer(y_true, y_pred)

In [10]:
class Normalize:
    def __call__(self, img):
        img = img.astype(np.float32) / 255
        return img

class ToTensor:
    def __call__(self, arr):
        arr = torch.from_numpy(arr)
        return arr

class ImageResize:
    def __init__(self, width, height):
        self.width = width
        self.height = height

    def __call__(self, img):
        img = cv2.resize(img, (self.width, self.height), 
                        interpolation=cv2.INTER_LINEAR)
        return img

class MoveAxis:
    def __call__(self, img):
        return np.moveaxis(img, -1, 0)

def get_transforms(width, height):
    transforms = torchvision.transforms.Compose([
        ImageResize(width, height),
        MoveAxis(),
        Normalize(),
        ToTensor()
    ])
    return transforms

In [2]:
def val_loop(data_loader, model, tokenizer):
    CER_avg = AverageMeter()
    for images, texts, _, _ in data_loader:
        batch_size = len(texts)
        text_preds = predict(images, model, tokenizer)
        CER = get_CER(texts, text_preds) 
        CER_avg.update(CER, batch_size)
    print(f'Validation, CER: {CER_avg.avg:.5}')
    return CER_avg.avg
        
def train_loop(data_loader, model, optimizer, criterion, epoch):
    loss_avg = AverageMeter()
    model.train()
    for images, texts, enc_pad_texts, text_lens in tqdm(data_loader):
        model.zero_grad()
        images = images.to('cuda')
        batch_size = len(texts)
        out = model(images)
        out_lens = torch.full(
            size=(out.size(1),),
            fill_value=out.size(0),
            dtype=torch.long
        )
        loss = criterion(out, enc_pad_texts, out_lens, text_lens)
        loss_avg.update(loss.item(), batch_size)
        loss.backward()
        optimizer.step()
    for params in optimizer.param_groups:
        lr = params['lr']
    print(f'\nEpoch: {epoch}, Loss: {loss_avg.avg:.5}, lr: {lr:.7}')
    return loss_avg.avg

def predict(images, model, tokenizer):
    model.eval()
    images = images.to('cuda')
    with torch.no_grad():
        out = model(images)
    pred = torch.argmax(out.detach().cpu(), -1).permute(1, 0).numpy()
    text_preds = tokenizer.decode(pred)
    return text_preds

def train(alphabet, root_img_train, root_text_train, root_img_val, root_text_val, save_dir, batch_size, epochs):
    tokenizer = Tokenizer(alphabet)
    os.makedirs(save_dir, exist_ok=True)
    train_loader = get_data_loader(root_img_train, root_text_train, tokenizer, batch_size)
    val_loader = get_data_loader(root_img_val, root_text_val, tokenizer, batch_size)
    model = CRNN(len(alphabet))
    model.to('cuda')
    criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr = 0.001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='max',
                                                          factor=0.5, patience=5)
    best_CER = -np.inf
    CER_avg = val_loop(val_loader, model, tokenizer)
    lst = []
    if len(os.listdir(save_dir)) != 0:
        num = max(list(map(int, os.listdir(save_dir)))) + 1
    else:
        num = 1
    path = f'{save_dir}/{num}'
    os.makedirs(path)
    os.makedirs(f'{path}/checkpoints')
    for epoch in range(epochs):
        loss_avg = train_loop(train_loader, model, optimizer, criterion, epoch)
        CER_avg = val_loop(val_loader, model, tokenizer)
        lst += [(loss_avg, CER_avg)]
        scheduler.step(CER_avg)
        model_save_path = f'{path}/checkpoints/model_epoch{epoch}.ckpt'
        torch.save(model.state_dict(), model_save_path)
        print('Model weights saved')
    file = open(f'{path}/metrics.txt', 'a+')
    counter = 0
    for loss, CER in lst:
        counter += 1
        file.write(f'Epoch: {counter}\tCTCLoss: {loss}\tCER: {CER}\n')
    model_save_path = path + '/' + 'model.ckpt'
    torch.save(model.state_dict(), model_save_path)
    print('Model weights saved')

In [12]:
import shutup
shutup.please()
root_img_train = '/home/misha/Загрузки/archive/train/'
root_img_val = '/home/misha/Загрузки/archive/test/'
root_text_train = '/home/misha/Загрузки/archive/train.tsv'
root_text_val = '/home/misha/Загрузки/archive/test.tsv'
save_dir = '/home/misha/HandwrittenRecognationModel'
batch_size = 16
epochs = 30
train(alphabet, root_img_train, root_text_train, root_img_val, root_text_val, save_dir, batch_size, epochs)





Validation, CER: 0.91004


100%|████████████████████████████████████████████████████| 4518/4518 [08:16<00:00,  9.10it/s]


Epoch: 0, Loss: 3.5933, lr: 0.001






Validation, CER: 0.81219
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.16it/s]


Epoch: 1, Loss: 2.9833, lr: 0.001





Validation, CER: 0.68689
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:12<00:00,  9.17it/s]


Epoch: 2, Loss: 1.8638, lr: 0.001





Validation, CER: 0.42034
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.16it/s]


Epoch: 3, Loss: 1.3059, lr: 0.001






Validation, CER: 0.39939
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.16it/s]


Epoch: 4, Loss: 1.2909, lr: 0.001





Validation, CER: 0.36923
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:12<00:00,  9.17it/s]


Epoch: 5, Loss: 0.99418, lr: 0.001





Validation, CER: 0.35039
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:12<00:00,  9.18it/s]


Epoch: 6, Loss: 1.1878, lr: 0.001





Validation, CER: 0.38014
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.16it/s]


Epoch: 7, Loss: 0.70966, lr: 0.0005





Validation, CER: 0.28762
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.16it/s]


Epoch: 8, Loss: 0.68758, lr: 0.0005






Validation, CER: 0.27033
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:11<00:00,  9.19it/s]


Epoch: 9, Loss: 0.59789, lr: 0.0005






Validation, CER: 0.24352
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:12<00:00,  9.18it/s]


Epoch: 10, Loss: 0.58657, lr: 0.0005







Validation, CER: 0.24957
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.16it/s]


Epoch: 11, Loss: 0.5125, lr: 0.0005








Validation, CER: 0.24478
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:12<00:00,  9.17it/s]


Epoch: 12, Loss: 0.56324, lr: 0.0005





Validation, CER: 0.23454
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.15it/s]


Epoch: 13, Loss: 0.42639, lr: 0.00025







Validation, CER: 0.20515
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.14it/s]


Epoch: 14, Loss: 0.38138, lr: 0.00025





Validation, CER: 0.2097
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.14it/s]


Epoch: 15, Loss: 0.35874, lr: 0.00025






Validation, CER: 0.19778
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.13it/s]


Epoch: 16, Loss: 0.33752, lr: 0.00025





Validation, CER: 0.2017
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.13it/s]


Epoch: 17, Loss: 0.31047, lr: 0.00025





Validation, CER: 0.19164
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.14it/s]


Epoch: 18, Loss: 0.29212, lr: 0.00025






Validation, CER: 0.18339
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.14it/s]


Epoch: 19, Loss: 0.24609, lr: 0.000125






Validation, CER: 0.17876
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.13it/s]


Epoch: 20, Loss: 0.23111, lr: 0.000125







Validation, CER: 0.1737
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:15<00:00,  9.12it/s]


Epoch: 21, Loss: 0.22342, lr: 0.000125






Validation, CER: 0.1803
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:15<00:00,  9.13it/s]


Epoch: 22, Loss: 0.21474, lr: 0.000125







Validation, CER: 0.17166
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.13it/s]


Epoch: 23, Loss: 0.20016, lr: 0.000125






Validation, CER: 0.16977
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.15it/s]


Epoch: 24, Loss: 0.18947, lr: 0.000125





Validation, CER: 0.17267
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.16it/s]


Epoch: 25, Loss: 0.16533, lr: 6.25e-05





Validation, CER: 0.16603
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.14it/s]


Epoch: 26, Loss: 0.15466, lr: 6.25e-05





Validation, CER: 0.16324
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.14it/s]


Epoch: 27, Loss: 0.14514, lr: 6.25e-05





Validation, CER: 0.16313
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:14<00:00,  9.14it/s]


Epoch: 28, Loss: 0.14011, lr: 6.25e-05






Validation, CER: 0.16419
Model weights saved


100%|████████████████████████████████████████████████████| 4518/4518 [08:13<00:00,  9.15it/s]


Epoch: 29, Loss: 0.13572, lr: 6.25e-05






Validation, CER: 0.16193
Model weights saved
Model weights saved


In [19]:
import matplotlib.pyplot as plt
img = cv2.imread('/home/misha/Загрузки/archive/test/test741.png')
img = cv2.resize(img, (256, 64))
img = np.moveaxis(img, -1, 0)
img = img.astype(np.float32) / 255
img = torch.from_numpy(img)
img = torch.stack((img, img))
print(img.size())
model = CRNN(len(alphabet))
model.to('cuda')
model.load_state_dict(torch.load('/home/misha/HandwrittenRecognationModel/4/model.ckpt'))
tokenizer = Tokenizer(alphabet)
a = predict(img, model, tokenizer)
a

torch.Size([2, 3, 64, 256])


['разбит', 'разбит']

In [3]:
torch.cuda.is_available()

True

AttributeError: module 'torch.cuda' has no attribute 'TORCH_USE_CUDA_DSA'