In [22]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from PIL import Image
import random

import os
import json
import cv2

import torch
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision import datasets, transforms

from transformers import TrOCRProcessor, VisionEncoderDecoderModel, default_data_collator
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [23]:
basketball_df = pd.read_csv('for_train\\train_bascetball.csv')
streetball_df = pd.read_csv('for_train\\train_streetball.csv')
volleyball_df = pd.read_csv('for_train\\train_volleyball.csv')

basketball_df.reset_index(drop=True, inplace=True)
streetball_df.reset_index(drop=True, inplace=True)
volleyball_df.reset_index(drop=True, inplace=True)

In [24]:
store_pathes = []

class IAMDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=3):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
      
        file_name = self.df['file_name'][idx]
        text = str(self.df['text'][idx])
        store_pathes.append(file_name)
   
        image = Image.open(file_name).convert("RGB")
        image = image.resize((64, 64))
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text, 
                                          padding="max_length", 
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

In [25]:
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
basketball_dataset = IAMDataset(root_dir='C:\\Users\\Mytre\\OneDrive\\Документы\\Data\\Work\\',
                           df=basketball_df,
                           processor=processor)
streetball_dataset = IAMDataset(root_dir='C:\\Users\\Mytre\\OneDrive\\Документы\\Data\\Work\\',
                           df=streetball_df,
                           processor=processor)
volleyball_dataset = IAMDataset(root_dir='C:\\Users\\Mytre\\OneDrive\\Документы\\Data\\Work\\',
                           df=volleyball_df,
                           processor=processor)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [26]:
print("Number of basketball_dataset examples:", len(basketball_dataset))
print("Number of streetball_dataset examples:", len(streetball_dataset))
print("Number of volleyball_dataset examples:", len(volleyball_dataset))

Number of basketball_dataset examples: 30204
Number of streetball_dataset examples: 90376
Number of volleyball_dataset examples: 30057


In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained("G:\\models1\\checkpoint-0.86")
model.to(device)
print(device)

cuda


In [28]:
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 4
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 10

In [29]:
import evaluate

cer_metric = evaluate.load("cer")
acc_metric = evaluate.load("accuracy")

In [30]:
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions    

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)    

    label_int = [int(x) for x in label_str]
    x = [] 
    for j in pred_str:       
        if j.isdigit():
            x.append(int(j))
        else:
            x.append(1000)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    acc = acc_metric.compute(predictions=x, references=label_int)

    return {"cer": cer,
            "accuracy": acc}

In [31]:
basketball_dataloader = DataLoader(basketball_dataset, batch_size=32)
streetball_dataloader = DataLoader(streetball_dataset, batch_size=32)
volleyball_dataloader = DataLoader(volleyball_dataset, batch_size=32)

# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 4
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 10

In [32]:
from sklearn.metrics import accuracy_score

def calc_acc(dataloader, model):

    torch.cuda.empty_cache()
    # Define the device to run the evaluation on
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set the model to evaluation mode
    model.eval()
    model.to(device)
    # Evaluate the model on the eval dataset
    diff_acc = []
    acc = []
    i = 0

    with torch.no_grad():
        for batch in tqdm(dataloader):
        
            target_text = processor.batch_decode(batch['labels'], skip_special_tokens=True)
            target_text = [int(x) for x in target_text]
            
            pixel_values = batch['pixel_values']
            generated_ids = model.generate(pixel_values.cuda(), max_new_tokens=4)       
        
            # Make a prediction
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

            x = [] 
            for j in generated_text:       
                if j.isdigit():
                    x.append(int(j))
                else:
                    x.append(1000)        
        
            bach_acc = accuracy_score(target_text, x)
            # Save the true and predicted labels
            diff_acc.append(bach_acc)  
            acc.append(bach_acc)

            if (i % 50  == 0) & (i > 2):
                accuracy = np.mean(diff_acc)
                print("Accuracy:", accuracy) 
            i += 1

        print(f"Total accuracy: {np.mean(acc)}")    

Посчитаем точность на basketball_dataset

In [33]:
calc_acc(basketball_dataloader, model)

  0%|          | 0/944 [00:00<?, ?it/s]

Accuracy: 0.7775735294117647
Accuracy: 0.7886757425742574
Accuracy: 0.7847682119205298


KeyboardInterrupt: 

Посчитаем точность на streetball_dataset

In [None]:
calc_acc(streetball_dataloader, model)

  0%|          | 0/2825 [00:00<?, ?it/s]

Accuracy: 0.9558823529411765
Accuracy: 0.9554455445544554
Accuracy: 0.956953642384106
Accuracy: 0.9564676616915423
Accuracy: 0.956050796812749
Accuracy: 0.9560838870431894
Accuracy: 0.9548611111111112
Accuracy: 0.9533977556109726
Accuracy: 0.9534368070953437
Accuracy: 0.9528443113772455
Accuracy: 0.9533802177858439
Accuracy: 0.9538269550748752
Accuracy: 0.9542050691244239
Accuracy: 0.954083452211127
Accuracy: 0.9540612516644474
Accuracy: 0.9536516853932584
Accuracy: 0.9540981198589894
Accuracy: 0.9539053829078802
Accuracy: 0.9540615141955836
Accuracy: 0.9542644855144855
Accuracy: 0.9542400095147479
Accuracy: 0.9545299727520435
Accuracy: 0.9546861424847958
Accuracy: 0.9551155287260616
Accuracy: 0.9553607114308553
Accuracy: 0.9552027286702537
Accuracy: 0.9551027017024426
Accuracy: 0.9552774803711634
Accuracy: 0.9554186767746382
Accuracy: 0.9554255496335776
Accuracy: 0.9554722759509994
Accuracy: 0.9555551217988757
Accuracy: 0.9556518776499091
Accuracy: 0.9556510875955321
Accuracy: 0.95579

Посчитаем точность на volleyball_dataset

In [None]:
calc_acc(volleyball_dataloader, model)

  0%|          | 0/940 [00:00<?, ?it/s]

Accuracy: 0.8180147058823529
Accuracy: 0.8251856435643564
Accuracy: 0.8247102649006622
Accuracy: 0.8219838308457711
Accuracy: 0.821339641434263
Accuracy: 0.8197674418604651
Accuracy: 0.8205128205128205
Accuracy: 0.8198254364089775
Accuracy: 0.8225471175166297
Accuracy: 0.8227295409181636
Accuracy: 0.8236728675136116
Accuracy: 0.8240952579034941
Accuracy: 0.823636712749616
Accuracy: 0.8227977888730386
Accuracy: 0.8225699067909454
Accuracy: 0.8208879525593009
Accuracy: 0.8210193889541716
Accuracy: 0.8191592674805771
Total accuracy: 0.8188164893617021
