In [1]:
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Всякое

In [2]:
!pip install -q torchmetrics torchinfo

In [3]:
import torch
from torch import nn
from itertools import groupby


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, pool_ksize=(2, 2)):
        super(ConvBlock, self).__init__()

        self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding='same'),
                                   nn.LeakyReLU(0.1),
                                   nn.BatchNorm2d(out_channels),
                                   nn.MaxPool2d(pool_ksize))

    def forward(self, x):
        return self.block(x)


class CRNN(nn.Module):
    def __init__(self, alphabet_len):
        super(CRNN, self).__init__()

        self.feature_extractor = nn.Sequential(ConvBlock(1, 32),
                                               ConvBlock(32, 64, (2, 1)),
                                               ConvBlock(64, 64),
                                               ConvBlock(64, 128),
                                               ConvBlock(128, 256, (2, 1)))
        self.lstm1 = nn.LSTM(258, 256, batch_first=True)
        self.lstm2 = nn.LSTM(256, 256, batch_first=True)

        self.fc = nn.Sequential(nn.Linear(256, alphabet_len+1),
                                nn.Softmax(dim=2))

    def forward(self, x1, x2):
        f1 = self.feature_extractor(x1).squeeze()
        f1 = torch.permute(f1, (0, 2, 1))

        x = torch.cat([f1, x2], dim=2)

        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        x = self.fc(x)

        return x


def decode_texts(logits, alphabet):
    """Decodes CRNN output with given alphabet and whitelist

    Args:
        logits: np.ndarray, CRNN output
        alphabet: str, alphabet CRNN was trained on
    Returns:
        list of predictions
    """
    best_path_indices = np.argmax(logits, axis=-1)
    best_chars_collapsed = [[alphabet[k-1] for k, _ in groupby(e) if k != 0] for e in best_path_indices]
    return [''.join(e) for e in best_chars_collapsed]

In [4]:
def ctc_loss_log_differentiable_torch(log_logits: torch.FloatTensor, targets: torch.LongTensor,
                                      input_lengths: torch.Tensor, target_lengths: torch.Tensor, device,
                                      blank_idx=0, dtype_to_use=torch.float32) -> torch.float32:
    """
    log_logits: np.ndarray of shape (B, T, C)
    targets: np.ndarray of shape (B, L,)
    """

    B, T = log_logits.shape[0], log_logits.shape[1]
    S = 2 * targets.shape[1] + 1

    zero = torch.finfo(dtype_to_use).min

    # insert blanks between every pair of labels and add them to start and end of the seq
    extended_targets = torch.stack([torch.full_like(targets, blank_idx), targets], dim=-1).flatten(start_dim=-2)
    extended_targets = torch.cat([extended_targets, torch.full((B, 1), blank_idx, device=device)], dim=-1)
    # due to the paper formula for alpha_t(s) we must know where labels repeat and where the blanks are
    # in the extended label seq
    targets_difference_mask = torch.cat([torch.full((B, 2), False, device=device), extended_targets[:, 2:] != extended_targets[:, :-2]], dim=-1)

    # initialize alphas array to keep track of previous alphas
    # (also add 2 to the second dim so our s-2 and s-1 vectorized calculations won't get IndexError)
    log_alphas = torch.full((B, T, S+2), zero, dtype=dtype_to_use, device=device)

    # every accountable prefix starts either with a blank or the first symbol of the target,
    # so we initialize alphas in the following way (remember about S+2)
    log_alphas[:, 0, 2] = log_logits[:, 0, blank_idx]
    log_alphas[:, 0, 3] = log_logits[torch.arange(B), 0, targets[:, 0]]

    for t in range(1, T):
        # remember we're in log space so log(a*b) = log(a) + log(b)
        # here formula must be mathematically reworked.

        log_alphas[:, t, 2:] = (torch.gather(log_logits[:, t], -1, extended_targets) +
                                torch.logsumexp(torch.stack([log_alphas[:, t-1, 2:], log_alphas[:, t-1, 1: -1],
                                                             torch.where(targets_difference_mask,
                                                                         log_alphas[:, t-1, :-2], zero)]), dim=0))

    temp = torch.gather(log_alphas[np.arange(B), input_lengths-1], -1,
                        torch.stack([2 + target_lengths * 2 - 1, 2 + target_lengths * 2], dim=-1))

    return -torch.mean(torch.logsumexp(temp, dim=-1))

In [5]:
from torch.utils.data import Dataset, DataLoader


class OCRDataset(Dataset):
    def __init__(self, images, abits, labels):
        super(OCRDataset, self).__init__()

        self.images = images
        self.abits = abits
        self.labels = labels

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        return (torch.FloatTensor(self.images[idx]).unsqueeze(0), torch.FloatTensor(self.abits[idx])), torch.IntTensor(self.labels[idx])

In [6]:
from torchmetrics.text import CharErrorRate
from itertools import groupby
from tqdm import tqdm
import time

def validate_model(model, dataloader, device='cpu'):
  model.eval()

  criterion = ctc_loss_log_differentiable_torch
  metric = CharErrorRate()
  loss = 0
  cer_value = 0
  cumtime = 0

  with torch.no_grad():
    for i, ((x1, x2), y) in tqdm(enumerate(dataloader)):
      x1 = x1.to(device)
      x2 = x2.to(device)
      y = y.to(device)

      start = time.time()
      y_pred = model(x1, x2)
      cumtime += time.time() - start

      input_lengths = torch.full((y_pred.shape[0],), y_pred.shape[1]).to(device)
      target_lengths = torch.sum(y != 0, axis=1)
      loss += criterion(torch.log(y_pred), y, input_lengths, target_lengths, device=device).item()
      cer_value += metric(decode_texts(y_pred.detach().cpu().numpy(), alphabet),
                        [''.join(alphabet[k-1] for k, _ in groupby(e) if k != 0) for e in y.cpu().numpy().astype(int)]).item()

  print()

  return cumtime / len(dataloader), loss / len(dataloader), cer_value / len(dataloader)

In [7]:
import h5py
import pandas as pd
import numpy as np

with h5py.File('/content/drive/MyDrive/CRNN for long fields/common_fields_images.h5') as f:
    images = f['images'][:]
    additional_bits = f['additional_bit'][:]

with open('/content/drive/MyDrive/CRNN for long fields/common_fields_labels.txt', encoding='cp1251') as f:
    markup = [e.strip() for e in f.readlines()]


def encode_texts(texts):
    def _label_to_num(label, alphabet):
        label_num = []
        for ch in label:
            label_num.append(alphabet.find(ch) + 1)
        return np.array(label_num)

    # alphabet = ''.join(sorted(pd.Series(texts).apply(list).apply(pd.Series).stack().unique()))
    alphabet = ''.join(sorted(set(''.join(texts))))

    nums = np.zeros([len(texts), max([len(text) for text in texts])], dtype='int64')
    for i, text in enumerate(texts):
        nums[i][:len(text)] = _label_to_num(text, alphabet)

    return nums, alphabet

labels_encoded, alphabet = encode_texts(markup)
images = images.astype('float64') / 255

additional_bits_expanded = np.zeros((len(images), 50, 2))
additional_bits_expanded[:, :, additional_bits] = 1

np.random.seed(42)

train_indices = np.random.choice(np.arange(images.shape[0]), int(images.shape[0]*0.8), replace=False)
val_indices = [e for e in np.arange(images.shape[0]) if e not in train_indices]

assert len(set(train_indices) & set(val_indices)) == 0
assert len(set(train_indices) | set(val_indices)) == images.shape[0]

train_imgs = images[train_indices]
val_imgs = images[val_indices]

train_abits = additional_bits_expanded[train_indices]
val_abits = additional_bits_expanded[val_indices]

train_labels = labels_encoded[train_indices]
val_labels = labels_encoded[val_indices]

train_dataset = OCRDataset(train_imgs, train_abits, train_labels)
val_dataset = OCRDataset(val_imgs, val_abits, val_labels)

train_loader = DataLoader(train_dataset, batch_size=128)
val_loader = DataLoader(val_dataset, batch_size=128)

# Загрузка модели-учителя

In [8]:
import torch
from torchinfo import summary

device = 'cuda' if torch.cuda.is_available() else 'cpu'

teacher_model = CRNN(len(alphabet))
teacher_model.load_state_dict(torch.load('/content/drive/MyDrive/Методы компрессии/crnn_common_fields_.pt', map_location=torch.device(device)))
summary(teacher_model, input_size=[(32, 1, 32, 400), (32, 50, 2)], device=device)

Layer (type:depth-idx)                   Output Shape              Param #
CRNN                                     [32, 50, 46]              --
├─Sequential: 1-1                        [32, 256, 1, 50]          --
│    └─ConvBlock: 2-1                    [32, 32, 16, 200]         --
│    │    └─Sequential: 3-1              [32, 32, 16, 200]         384
│    └─ConvBlock: 2-2                    [32, 64, 8, 200]          --
│    │    └─Sequential: 3-2              [32, 64, 8, 200]          18,624
│    └─ConvBlock: 2-3                    [32, 64, 4, 100]          --
│    │    └─Sequential: 3-3              [32, 64, 4, 100]          37,056
│    └─ConvBlock: 2-4                    [32, 128, 2, 50]          --
│    │    └─Sequential: 3-4              [32, 128, 2, 50]          74,112
│    └─ConvBlock: 2-5                    [32, 256, 1, 50]          --
│    │    └─Sequential: 3-5              [32, 256, 1, 50]          295,680
├─LSTM: 1-2                              [32, 50, 256]             

In [9]:
print(dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(teacher_model, val_loader, device=device)])))

26it [00:03,  8.31it/s]


{'batch_time': 0.007023, 'loss': 0.623181, 'metric': 0.049073}





# Построение модели-ученика

In [10]:
class CRNN_light(nn.Module):
    def __init__(self, alphabet_len):
        super(CRNN_light, self).__init__()

        self.feature_extractor = nn.Sequential(ConvBlock(1, 16),
                                               ConvBlock(16, 32, (2, 1)),
                                               ConvBlock(32, 64),
                                               ConvBlock(64, 128),
                                               ConvBlock(128, 192, (2, 1)))
        self.lstm1 = nn.LSTM(194, 192, batch_first=True)
        self.lstm2 = nn.LSTM(192, 192, batch_first=True)

        self.fc = nn.Sequential(nn.Linear(192, alphabet_len+1),
                                nn.Softmax(dim=2))

    def forward(self, x1, x2):
        f1 = self.feature_extractor(x1).squeeze()
        f1 = torch.permute(f1, (0, 2, 1))

        x = torch.cat([f1, x2], dim=2)

        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        x = self.fc(x)

        return x

In [51]:
student_model = CRNN_light(len(alphabet))
summary(student_model, input_size=[(32, 1, 32, 400), (32, 50, 2)], device=device)

Layer (type:depth-idx)                   Output Shape              Param #
CRNN_light                               [32, 50, 46]              --
├─Sequential: 1-1                        [32, 192, 1, 50]          --
│    └─ConvBlock: 2-1                    [32, 16, 16, 200]         --
│    │    └─Sequential: 3-1              [32, 16, 16, 200]         192
│    └─ConvBlock: 2-2                    [32, 32, 8, 200]          --
│    │    └─Sequential: 3-2              [32, 32, 8, 200]          4,704
│    └─ConvBlock: 2-3                    [32, 64, 4, 100]          --
│    │    └─Sequential: 3-3              [32, 64, 4, 100]          18,624
│    └─ConvBlock: 2-4                    [32, 128, 2, 50]          --
│    │    └─Sequential: 3-4              [32, 128, 2, 50]          74,112
│    └─ConvBlock: 2-5                    [32, 192, 1, 50]          --
│    │    └─Sequential: 3-5              [32, 192, 1, 50]          221,760
├─LSTM: 1-2                              [32, 50, 192]             2

# Distilation train loop v1 (с декодированием soft targets)

Тут пока не сильно ясно: CRNN возвращает логиты, для который при сравнении с таргетом CTCLoss вычисляет вероятность декодирования в верную последовательность, соответственно, логиты модели-учителя сперва нужно декодировать и только потом передавать в CTCLoss.

Возможно, получится корретно обучаться и сравнивая "сырые" логиты.

Ну и не стоит забывать, что мне хватило мозгов использовать свой CTCLoss вместо имплементации из PyTorch -_-

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

student_model.to(device)
student_model.train()

teacher_model.to(device)

optimizer = torch.optim.NAdam(student_model.parameters(), lr=1e-3)
lr_scheduler = ReduceLROnPlateau(optimizer, patience=4, min_lr=1e-5, factor=0.5)
metric = CharErrorRate()

criterion = ctc_loss_log_differentiable_torch

alpha = 0.8
epochs = 1000
early_stopping_patience = 10
val_loss_history = list()

for epoch in range(epochs):
  train_loss = 0
  val_loss = 0

  train_cer = 0
  val_cer = 0

  for i, ((x1, x2), y) in enumerate(train_loader):
    x1 = x1.to(device)
    x2 = x2.to(device)
    y = y.to(device)

    optimizer.zero_grad()
    y_pred = student_model(x1, x2)

    with torch.no_grad():
      soft_targets = teacher_model(x1, x2).argmax(dim=-1)
      soft_targets_ = torch.zeros_like(soft_targets)
      for j, s in enumerate(soft_targets):
        soft_targets_[j, :sum(s != 0)] = s[s != 0]
      soft_targets = soft_targets_

    input_lengths = torch.full((y_pred.shape[0],), y_pred.shape[1]).to(device)
    hard_target_lengths = torch.sum(y != 0, axis=1)
    soft_target_lengths = torch.sum(soft_targets != 0, axis=1)

    hard_loss = criterion(torch.log(y_pred), y, input_lengths, hard_target_lengths, device=device)
    soft_loss = criterion(torch.log(y_pred), soft_targets, input_lengths, soft_target_lengths, device=device)
    loss = alpha*hard_loss + (1-alpha)*soft_loss

    loss.backward()

    optimizer.step()

    train_loss += loss.item()
    train_cer += metric(decode_texts(y_pred.detach().cpu().numpy(), alphabet),
                        [''.join(alphabet[k-1] for k, _ in groupby(e) if k != 0) for e in y.cpu().numpy().astype(int)]).item()

    print(f'\rEpoch {epoch}, {i+1}/{len(train_loader)}, loss: {round(train_loss/(i+1), 6)}, cer: {round(train_cer/(i+1), 6)}', end='')

  with torch.no_grad():
    for i, ((x1, x2), y) in enumerate(val_loader):
      x1 = x1.to(device)
      x2 = x2.to(device)
      y = y.to(device)

      y_pred = student_model(x1, x2)

      with torch.no_grad():
        soft_targets = teacher_model(x1, x2).argmax(dim=-1)
        soft_targets_ = torch.zeros_like(soft_targets)
        for j, s in enumerate(soft_targets):
          soft_targets_[j, :sum(s != 0)] = s[s != 0]
        soft_targets = soft_targets_

      input_lengths = torch.full((y_pred.shape[0],), y_pred.shape[1]).to(device)
      hard_target_lengths = torch.sum(y != 0, axis=1)
      soft_target_lengths = torch.sum(soft_targets != 0, axis=1)

      hard_loss = criterion(torch.log(y_pred), y, input_lengths, hard_target_lengths, device=device)
      soft_loss = criterion(torch.log(y_pred), soft_targets, input_lengths, soft_target_lengths, device=device)
      loss = alpha*hard_loss + (1-alpha)*soft_loss

      val_loss += loss.item()
      val_cer += metric(decode_texts(y_pred.detach().cpu().numpy(), alphabet),
                          [''.join(alphabet[k-1] for k, _ in groupby(e) if k != 0) for e in y.cpu().numpy().astype(int)]).item()

  print(f' val_loss: {round(val_loss/len(val_loader), 6)}, val_cer: {round(val_cer/len(val_loader), 6)}')

  lr_scheduler.step(val_loss/len(val_loader))
  val_loss_history.append(val_loss/len(val_loader))

  if min(val_loss_history) < min(val_loss_history[-early_stopping_patience:]):
    break

# Сравнение моделей

In [53]:
print(dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(teacher_model, val_loader, device=device)])))

26it [00:01, 15.21it/s]


{'batch_time': 0.002948, 'loss': 0.623181, 'metric': 0.049073}





In [54]:
print(dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(student_model, val_loader, device=device)])))

26it [00:02, 11.52it/s]


{'batch_time': 0.002454, 'loss': 1.161216, 'metric': 0.054938}





In [55]:
from google.colab import runtime

runtime.unassign()