In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from PIL import Image
import random

import torch
from torch.optim import AdamW
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

from transformers import TrOCRProcessor, VisionEncoderDecoderModel, default_data_collator
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import evaluate


Завиксируем всю случайность!

In [2]:
seed = 23
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.determenistic = True

Соберем наш датасет из полученных файлов

In [3]:
df1 = pd.read_csv('anno_first.csv')
df2 = pd.read_csv('anno_second.csv')
df3 = pd.read_csv('anno_basketball.csv')
df4 = pd.read_csv('anno_fiba.csv')
#df5 = pd.read_csv('anno_ncaa.csv')
df = pd.concat([df1, df2, df4], axis=0)
df.shape

(88984, 3)

Разобьем датасет на train test и eval части

In [4]:
diff_df, test_df = train_test_split(df, test_size=0.2)
train_df, eval_df = train_test_split(diff_df, test_size=0.2)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)
eval_df.reset_index(drop=True, inplace=True)

In [5]:
class IAMDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=3):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
      
        file_name = self.df['file_name'][idx]
        text = str(self.df['text'][idx])
   
        image = Image.open(file_name).convert("RGB")
        image = image.resize((64, 64))
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text, 
                                          padding="max_length", 
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

Создадим наши датасеты

In [6]:
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
train_dataset = IAMDataset(root_dir='C:\\Users\\Mytre\\OneDrive\\Документы\\Data\\Work\\',
                           df=train_df,
                           processor=processor)
test_dataset = IAMDataset(root_dir='C:\\Users\\Mytre\\OneDrive\\Документы\\Data\\Work\\',
                           df=test_df,
                           processor=processor)
eval_dataset = IAMDataset(root_dir='C:\\Users\\Mytre\\OneDrive\\Документы\\Data\\Work\\',
                           df=test_df,
                           processor=processor)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [7]:
print("Number of training examples:", len(train_dataset))
print("Number of testing examples:", len(test_dataset))
print("Number of validation examples:", len(eval_dataset))

Number of training examples: 56949
Number of testing examples: 17797
Number of validation examples: 17797


In [8]:
train_dataloader = DataLoader(train_dataset, batch_size=3, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=3)
test_dataloader = DataLoader(test_dataset, batch_size=32)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tensor_board = SummaryWriter()

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
model.to(device)

optimizer = AdamW(model.parameters(), lr=0.0001)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-stage1 and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Сконфигурируем нашу модель

In [10]:
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 4
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 10

Оопределим метрику

In [11]:
acc_metric = evaluate.load("accuracy")
cer_metric = evaluate.load("cer")

In [12]:
def compute_acc(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    
    x = [] 
    for j in pred_str:       
        if j.isdigit():
            x.append(int(j))
        else:
            x.append(1000)
    acc = acc_metric.compute(predictions=x, references=label_str)

    return acc

def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return cer

In [14]:
step = 0
step_val = 0
accuracy = 0.
counter = 0

for epoch in range(100):  # loop over the dataset multiple times
    # train
    model.train()
    train_loss = []
    acc_batch = []
    train_loss_batch = []
    i, j = 0, 0
    
    print(f'Epoch {epoch} start.')

    for batch in tqdm(train_dataloader):
        # get the inputs
        for k,v in batch.items():
            batch[k] = v.to(device)
        break
        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        #acc_train = compute_acc(pred_ids=outputs, label_ids=batch["labels"])
        #acc_batch.append(acc_train['accuracy'])
        train_loss.append(loss.item())
        train_loss_batch.append(loss.item())

        if (i % 100 == 0) & (i > 2 ):
            print(f'Train loss: {np.mean(train_loss_batch)}')
            tensor_board.add_scalar('Train loss:', np.mean(train_loss_batch), global_step=step)   
            #tensor_board.add_scalar('Validation accuracy:', np.mean(acc_batch), global_step=step_val)   
            #acc_batch = []      
            step += 1
        i += 1
        
    print(f"Loss after epoch {epoch}:", np.mean(train_loss))
    
    # evaluate
    model.eval()

    valid_acc = []
    valid_cer = []
    valid_acc_batch = []
    valid_cer_batch = []

    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            # run batch generation
            outputs = model.generate(batch["pixel_values"].to(device), max_new_tokens=4)
            # compute metrics
            
            acc = compute_acc(pred_ids=outputs, label_ids=batch["labels"])
            cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])           
            
            valid_cer.append(cer)       
            valid_acc.append(acc['accuracy'])  
            valid_acc_batch.append(acc['accuracy']) 
            valid_cer_batch.append(cer)

            if (j % 100 == 0) & (j > 2 ):
                print(f'Validation accuracy: {np.mean(valid_acc_batch)}')
                print(f'Validation cer:: {np.mean(valid_cer_batch)}')

                tensor_board.add_scalar('Validation accuracy:', np.mean(valid_acc_batch), global_step=step_val)
                tensor_board.add_scalar('Validation cer:', np.mean(valid_cer_batch), global_step=step_val)
                
                valid_acc_batch = []
                valid_cer_batch
                step_val += 1
            j += 1 

        counter += 1    
        if np.mean(valid_acc) > accuracy:
            model.save_pretrained('model_best')    
            accuracy = np.mean(valid_acc)
            counter = 0

        if counter > 5:
            print('Early stopping!!!')
            print(f'Result validation accuracy: {accuracy}')
            break

        print("Validation accuracy:", np.mean(valid_acc))

Epoch 0 start.


  0%|          | 0/18983 [00:00<?, ?it/s]

Loss after epoch 0: nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


  0%|          | 0/5933 [00:00<?, ?it/s]



Validation accuracy: 0.12871287128712872
Validation cer:: 0.8485148514851485
Validation accuracy: 0.16999999999999996
Validation cer:: 0.8269485903814262
Validation accuracy: 0.12666666666666668
Validation cer:: 0.8391472868217055
Validation accuracy: 0.16666666666666663
Validation cer:: 0.8305664410402565
Validation accuracy: 0.14666666666666667
Validation cer:: 0.8330149225358806
Validation accuracy: 0.13333333333333333
Validation cer:: 0.8355914745265828
Validation accuracy: 0.11333333333333333
Validation cer:: 0.8400720059778548
Validation accuracy: 0.12
Validation cer:: 0.8430800784733369
Validation accuracy: 0.14666666666666664
Validation cer:: 0.842460757888061
Validation accuracy: 0.11333333333333333
Validation cer:: 0.8447290804433661
Validation accuracy: 0.16
Validation cer:: 0.8421197180052765
Validation accuracy: 0.1
Validation cer:: 0.8435529915546569
Validation accuracy: 0.12666666666666665
Validation cer:: 0.844650635042641
Validation accuracy: 0.15333333333333335
Valida

  0%|          | 0/18983 [00:00<?, ?it/s]

Loss after epoch 1: nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


  0%|          | 0/5933 [00:00<?, ?it/s]

Validation accuracy: 0.12871287128712872
Validation cer:: 0.8485148514851485
Validation accuracy: 0.16999999999999996
Validation cer:: 0.8269485903814262
Validation accuracy: 0.12666666666666668
Validation cer:: 0.8391472868217055
Validation accuracy: 0.16666666666666663
Validation cer:: 0.8305664410402565
Validation accuracy: 0.14666666666666667
Validation cer:: 0.8330149225358806
Validation accuracy: 0.13333333333333333
Validation cer:: 0.8355914745265828
Validation accuracy: 0.11333333333333333
Validation cer:: 0.8400720059778548


KeyboardInterrupt: 