In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from PIL import Image
import random

import os
import json
import cv2

import torch
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision import datasets, transforms

from transformers import TrOCRProcessor, VisionEncoderDecoderModel, default_data_collator
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [2]:
basketball_df = pd.read_csv('for_test\\test_bascetball.csv')
streetball_df = pd.read_csv('for_test\\test_stritball.csv')
volleyball_df = pd.read_csv('for_test\\test_volleyball.csv')

basketball_df.reset_index(drop=True, inplace=True)
streetball_df.reset_index(drop=True, inplace=True)
volleyball_df.reset_index(drop=True, inplace=True)

In [3]:
store_pathes = []

class IAMDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=3):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
      
        file_name = self.df['file_name'][idx]
        text = str(self.df['text'][idx])
        store_pathes.append(file_name)
   
        image = Image.open(file_name).convert("RGB")
        image = image.resize((64, 64))
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text, 
                                          padding="max_length", 
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

In [4]:
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
basketball_dataset = IAMDataset(root_dir='C:\\Users\\Mytre\\OneDrive\\Документы\\Data\\Work\\',
                           df=basketball_df,
                           processor=processor)
streetball_dataset = IAMDataset(root_dir='C:\\Users\\Mytre\\OneDrive\\Документы\\Data\\Work\\',
                           df=streetball_df,
                           processor=processor)
volleyball_dataset = IAMDataset(root_dir='C:\\Users\\Mytre\\OneDrive\\Документы\\Data\\Work\\',
                           df=volleyball_df,
                           processor=processor)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [5]:
print("Number of basketball_dataset examples:", len(basketball_dataset))
print("Number of streetball_dataset examples:", len(streetball_dataset))
print("Number of volleyball_dataset examples:", len(volleyball_dataset))

Number of basketball_dataset examples: 7551
Number of streetball_dataset examples: 22593
Number of volleyball_dataset examples: 7515


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained("model_best1")
model.to(device)
print(device)

cuda


In [7]:
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 4
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 10

In [8]:
import evaluate

cer_metric = evaluate.load("cer")
acc_metric = evaluate.load("accuracy")

In [9]:
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions    

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)    

    label_int = [int(x) for x in label_str]
    x = [] 
    for j in pred_str:       
        if j.isdigit():
            x.append(int(j))
        else:
            x.append(1000)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    acc = acc_metric.compute(predictions=x, references=label_int)

    return {"cer": cer,
            "accuracy": acc}

In [10]:
basketball_dataloader = DataLoader(basketball_dataset, batch_size=32)
streetball_dataloader = DataLoader(streetball_dataset, batch_size=32)
volleyball_dataloader = DataLoader(volleyball_dataset, batch_size=32)

# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 4
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 10

In [11]:
from sklearn.metrics import accuracy_score

def calc_acc(dataloader, model):

    torch.cuda.empty_cache()
    # Define the device to run the evaluation on
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set the model to evaluation mode
    model.eval()
    model.to(device)
    # Evaluate the model on the eval dataset
    diff_acc = []
    acc = []
    i = 0

    with torch.no_grad():
        for batch in tqdm(dataloader):
        
            target_text = processor.batch_decode(batch['labels'], skip_special_tokens=True)
            target_text = [int(x) for x in target_text]
            
            pixel_values = batch['pixel_values']
            generated_ids = model.generate(pixel_values.cuda(), max_new_tokens=4)       
        
            # Make a prediction
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

            x = [] 
            for j in generated_text:       
                if j.isdigit():
                    x.append(int(j))
                else:
                    x.append(1000)        
        
            bach_acc = accuracy_score(target_text, x)
            # Save the true and predicted labels
            diff_acc.append(bach_acc)  
            acc.append(bach_acc)

            if (i % 50  == 0) & (i > 2):
                accuracy = np.mean(diff_acc)
                print("Accuracy:", accuracy) 
            i += 1

        print(f"Total accuracy: {np.mean(acc)}")    

Посчитаем точность на basketball_dataset

In [12]:
calc_acc(basketball_dataloader, model)

  0%|          | 0/236 [00:00<?, ?it/s]

Accuracy: 0.9209558823529411
Accuracy: 0.9232673267326733
Accuracy: 0.9215645695364238
Accuracy: 0.9216417910447762
Total accuracy: 0.9226609486057955


Посчитаем точность на streetball_dataset

In [13]:
calc_acc(streetball_dataloader, model)

  0%|          | 0/707 [00:00<?, ?it/s]

Accuracy: 0.9761029411764706
Accuracy: 0.9771039603960396
Accuracy: 0.9749586092715232
Accuracy: 0.9760572139303483
Accuracy: 0.9764691235059761
Accuracy: 0.9757059800664452
Accuracy: 0.9762286324786325
Accuracy: 0.975997506234414
Accuracy: 0.9751940133037694
Accuracy: 0.9752994011976048
Accuracy: 0.9757259528130672
Accuracy: 0.9755615640599001
Accuracy: 0.9761424731182796
Accuracy: 0.9761055634807418
Total accuracy: 0.9761757425742574


Посчитаем точность на volleyball_dataset

In [14]:
calc_acc(volleyball_dataloader, model)

  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy: 0.9399509803921569
Accuracy: 0.9337871287128713
Accuracy: 0.9323261589403974
Accuracy: 0.9333022388059702
Total accuracy: 0.9337273443656422
