## read and prepare data

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pandas as pd
import re
# import kaggle
# import opendatasets as od
# import kagglehub
import numpy as np
import json
import os
import cv2
from PIL import Image
# from kagglehub import KaggleDatasetAdapter
import sys

In [3]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [4]:
sys.path.append(".")

In [None]:
! curl -L -o ./rustitw-russian-language-visual-text-recognition.zip\
  https://www.kaggle.com/api/v1/datasets/download/hardtype/rustitw-russian-language-visual-text-recognition

In [None]:
path_train = './rustitw-russian-language-visual-text-recognition/train/real/images/'
path_test = './rustitw-russian-language-visual-text-recognition/test/real/images/'
groundtruth_test_path = "./test/groundtruth/"
output_train_easyocr_path = "./train/easyocr_train_data/"
path_train_annotations= './rustitw-russian-language-visual-text-recognition/train/real/'
path_test_annotations= './rustitw-russian-language-visual-text-recognition/test/real/'
groundtruth_train_path = "./train/groundtruth/"

In [7]:
def is_russian_text(text):
    """Checking the Russian text."""
    russian_pattern = re.compile(r'^[а-яА-ЯёЁ\s\n]+$')
    return bool(russian_pattern.match(text))

def prepare_annotations(info_csv_path, images_folder, max_samples=None):
    """Preparation of the annotation"""
    data = pd.read_csv(info_csv_path)
    annotations = []

    if max_samples is not None:
        data = data.head(max_samples)

    for idx, row in data.iterrows():
        image_path = images_folder + row['image_path']
        img_width = row['width']
        img_height = row['height']
        bboxes = json.loads(row['box_and_label'])[0]

        abs_bboxes = []
        labels = []
        for bbox in bboxes:
    
            label = bbox['label']
            # if is_russian_text(label):
            abs_bboxes.append(bbox)
            labels.append("_".join(label.split()))

        if abs_bboxes:  
            annotations.append({
                'image_path': image_path,
                'bboxes': abs_bboxes,
                'labels': labels
            })

    return annotations

In [8]:
def save_json_to_files(list_of_json, folder_path):
    """
    Saves a list of JSON objects to individual files in a specified folder.

    Args:
        list_of_json (list): A list of JSON objects (as Python dictionaries).
        folder_path (str): The path to the folder where files will be saved.
    """

    # Create the output directory if it doesn't exist
    try:
        os.makedirs(folder_path, exist_ok=True)
        print(f"Directory '{folder_path}' created or already exists.")
    except OSError as e:
        print(f"Error creating directory: {e}")
        return

    for json_obj in list_of_json:
        try:
            # Extract the filename from the 'image_path' value
            match = re.search(r'/(\d+)\.jpg$', json_obj.get("image_path"))

            if match:
                extracted_string = match.group(1)
            else:
                print("No match was found.")
            file_name = os.path.basename(extracted_string+".json")

            if not file_name:
                print("Skipping JSON object due to missing 'image_path' or empty value.")
                continue

            # Construct the full file path
            file_path = os.path.join(folder_path, file_name)

            # Save the JSON object to the file
            with open(file_path, 'w') as f:
                json.dump(json_obj, f, indent=4)

            print(f"Successfully saved {file_path}")

        except KeyError:
            print("Skipping a JSON object because it lacks the 'image_path' key.")
        except Exception as e:
            print(f"An error occurred: {e}")

In [9]:
def load_sample(image_path, groundtruth_path):
    """Loads an image and its corresponding JSON annotation."""
    if not os.path.exists(image_path) or not os.path.exists(groundtruth_path):
        return None, None
        
    # Load image
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    # Load and parse groundtruth
    with open(groundtruth_path, 'r', encoding='utf-8') as f:
        gt_data = json.load(f)
        
    annotations = []
    for item in gt_data['bboxes']:
        lang = "english"
        if is_russian_text(item['label']):
            lang = "russian"

        x = item['left'] * width
        y = item['top'] * height
        w = item['width'] * width
        h = item['height'] * height
        
        left = x
        top = y
        right = x + w
        bottom = y + h

        annotations.append({
            'text': item['label'],

            'language': lang, # Handle cases where language might be missing
            # The dataset uses [x1, y1, x2, y2, x3, y3, x4, y4] format
            'points': [left, top,  right, top, right, bottom, left, bottom] 
        })
        
    return image, annotations
    
def get_all_filepaths(image_dir, gt_dir):
    """Returns a list of corresponding image and groundtruth file paths."""
    filepaths = []
    for img_name in os.listdir(image_dir):
        if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            base_name = os.path.splitext(img_name)[0]
            gt_name = f"{base_name}.json"
            
            img_path = os.path.join(image_dir, img_name)
            gt_path = os.path.join(gt_dir, gt_name)
            if os.path.exists(gt_path):
                filepaths.append((img_path, gt_path))
    return filepaths

In [10]:

real_train_annotations = prepare_annotations(path_train_annotations + 'info.csv', path_train_annotations)
real_test_annotations = prepare_annotations(path_test_annotations + 'info.csv', path_test_annotations)

In [11]:
len(real_train_annotations)

24366

In [None]:
save_json_to_files(real_test_annotations, groundtruth_test_path)

In [None]:
save_json_to_files(real_train_annotations, groundtruth_train_path)

## Evaluate Qwen

In [15]:
# !pip install python-Levenshtein

In [16]:
import numpy as np
from Levenshtein import distance as levenshtein_distance
from shapely.geometry import Polygon

In [17]:
def get_reading_order_sort_key(box_data):
    """
    Creates a sort key for top-to-bottom, left-to-right reading order.
    Args:
        box_data (dict): A dictionary with a 'points' key, which is a list
                         of coordinates like [x1, y1, x2, y2, ...].
    Returns:
        A tuple (y, x) for sorting.
    """
    points = np.array(box_data['points']).reshape(-1, 2)
    # Use the y-coordinate of the top-left corner for primary sorting (top-to-bottom)
    top_y = np.min(points[:, 1])
    # Use the x-coordinate of the top-left corner for secondary sorting (left-to-right)
    left_x = np.min(points[:, 0])
    return (top_y, left_x)

def calculate_image_cer(predictions, ground_truths):
    """
    Calculates the Character Error Rate for the entire text content of an image.
    
    Args:
        predictions (list of dict): Model's output, e.g., [{'text': 'hello', 'points': [...]}, ...].
        ground_truths (list of dict): Ground truth annotations in the same format.
        
    Returns:
        float: The calculated CER for the full image text.
    """
    # --- Step 1: Sort both lists into reading order ---
    sorted_predictions = sorted(predictions, key=get_reading_order_sort_key)
    sorted_ground_truths = sorted(ground_truths, key=get_reading_order_sort_key)
    
    # --- Step 2: Concatenate text into single strings ---
    # We use a space to separate words, which is a reasonable approximation.
    predicted_text = " ".join([p['text'].strip() for p in sorted_predictions])
    ground_truth_text = " ".join([gt['text'].strip() for gt in sorted_ground_truths])

    # --- Step 3: Calculate Levenshtein distance and CER ---
    if not ground_truth_text:
        # If there is no ground truth text, CER is 0 if prediction is also empty,
        # and 1 (100% error) otherwise.
        return 0.0 if not predicted_text else 1.0

    distance = levenshtein_distance(predicted_text.lower(), ground_truth_text.lower())
    cer = distance / len(ground_truth_text)
    
    return cer

## QWEN

In [None]:
import torch, gc
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from transformers.generation import GenerationConfig
import re


  from .autonotebook import tqdm as notebook_tqdm


In [21]:
class QwenEvaluator:
    def __init__(self, model_name="Qwen/Qwen-VL-Chat"):
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, 
                torch_dtype=torch.float16, 
                # attn_implementation="flash_attention_2",
                device_map="auto")
        self.processor = AutoProcessor.from_pretrained(model_name)
    
    def _parse_response(self, response_text):
        """Parses the model's string output into structured data."""
        predictions = []
        pattern = r'\{\((\d+),(\d+),(\d+),(\d+)\)\}\s*--\s*(.*)'

        # Use re.findall to get a list of all matching groups
        matches = re.findall(pattern, response_text)
        
        for match in matches:
            try:
                # The first 4 elements are the coordinate strings
                coords = match[:4]
                # The 5th element is the text string
                text = match[4]
        
                # Convert coordinate strings to integers
                x1, y1, x2, y2 = map(int, coords)
                
                # Clean up the extracted text
                cleaned_text = text.strip().replace('"', '')
                
                # Convert to 4-point polygon format for your evaluator
                # (top-left, top-right, bottom-right, bottom-left)
                points = [x1, y1, x2, y1, x2, y2, x1, y2]
                
                predictions.append({'text': cleaned_text, 'points': points})
            except (ValueError, IndexError) as e:
                print(f"Skipping a malformed match: {match} due to error: {e}")
        return predictions
    def inference(self, image, image_path, prompt ='Read all the text in the image. For each section of text, print its bounding box and text in this bounding box in the format: {(x1,y1),(x2,y2)} -- text',
                  sys_prompt="You are a helpful assistant.", max_new_tokens=1024, return_input=False):
        image_local_path = image_path
        img_width, img_height = image.size
        # prompt = 'Read all the text in the image. For each section of text, print its bounding box and text in this bounding box in the format: {(x1,y1),(x2,y2)} -- text'
        # prompt = "Read all the text in the image."
        messages = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": [
                    {"type": "text", "text": prompt},
                    {"image": image_local_path},
                ]
            },
        ]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print("text:", text)
        # image_inputs, video_inputs = process_vision_info([messages])
        inputs = self.processor(text=[text], images=[image], padding=True, return_tensors="pt")
        inputs = inputs.to('cuda')
        with torch.no_grad():
            output_ids =self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
        
        response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        print(response)
        return self._parse_response(response[0])
    def predict(self, image):
        """Performs OCR on a single image."""
        # img_width, img_height = image.size
        
        # Qwen-VL prompt for structured OCR
        prompt = 'Detect all text in the image. For each text instance, provide its bounding box and content using the format <box>(x1,y1),(x2,y2)</box>text.'
        
        query = self.tokenizer.from_list_format([
            {'image': image},
            {'text': prompt},
        ])
        
        with torch.no_grad():
            response, _ = self.model.chat(self.tokenizer, query=query, history=None)
            
        return self._parse_response(response, img_width, img_height)
    

## prepare for EasyOCR

In [23]:
import os
import json
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import csv

In [24]:
import shutil # For safely creating/deleting directories
from sklearn.model_selection import train_test_split

In [25]:
def warp_and_crop(image: np.ndarray, box: np.ndarray):
    """
    Takes a 4-point bounding box and warps it into a straightened, cropped image.
    """
    rect = np.zeros((4, 2), dtype="float32")
    s = box.sum(axis=1); rect[0] = box[np.argmin(s)]; rect[2] = box[np.argmax(s)]
    diff = np.diff(box, axis=1); rect[1] = box[np.argmin(diff)]; rect[3] = box[np.argmax(diff)]
    (tl, tr, br, bl) = rect
    
    widthA = np.sqrt(((br[0] - bl[0])**2) + ((br[1] - bl[1])**2))
    widthB = np.sqrt(((tr[0] - tl[0])**2) + ((tr[1] - tl[1])**2))
    maxWidth = max(int(widthA), int(widthB))
    
    heightA = np.sqrt(((tr[0] - br[0])**2) + ((tr[1] - br[1])**2))
    heightB = np.sqrt(((tl[0] - bl[0])**2) + ((tl[1] - bl[1])**2))
    maxHeight = max(int(heightA), int(heightB))
    
    if maxWidth == 0 or maxHeight == 0: return None # Avoid errors on zero-size boxes
    
    dst = np.array([[0, 0], [maxWidth-1, 0], [maxWidth-1, maxHeight-1], [0, maxHeight-1]], dtype="float32")
    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
    return warped

# --- Part 2: Main Data Processing Function (combines logic from a.txt and b.txt) ---

def create_train_val_split_and_process(base_image_dir, base_gt_dir, output_parent_dir, train_ratio=0.85, random_seed=42):
    """
    Splits full images into train/val, processes them, and saves cropped words
    into the specified folder structure.
    """
    print("Starting data preparation...")
    
    # --- Step 1: Define output paths and create directories ---
    train_output_dir = os.path.join(output_parent_dir, 'en_train_filtered')
    val_output_dir = os.path.join(output_parent_dir, 'en_val')
    
    train_results_dir = os.path.join(train_output_dir, '__results___files')
    val_results_dir = os.path.join(val_output_dir, '__results___files')

    # Clean and create directories
    for d in [train_results_dir, val_results_dir]:
        if os.path.exists(d):
            shutil.rmtree(d) # Remove old data to ensure a clean slate
        os.makedirs(d)
        
    print(f"Created directories:\n - {train_results_dir}\n - {val_results_dir}")

    # --- Step 2: Get all file paths and split them into train and validation sets ---
    all_filepaths = get_all_filepaths(base_image_dir, base_gt_dir)
    
    if not all_filepaths:
        print("Error: No matching image and annotation files found.")
        return

    train_files, val_files = train_test_split(
        all_filepaths, 
        train_size=train_ratio, 
        random_state=random_seed
    )
    print(f"Data split complete: {len(train_files)} training files, {len(val_files)} validation files.")

    # --- Step 3: Process each split (train and val) ---
    
    splits_to_process = {
        "train": (train_files, train_results_dir),
        "validation": (val_files, val_results_dir),
    }

    for mode, (file_list, results_dir) in splits_to_process.items():
        print(f"\nProcessing {mode} data...")
        
        label_data = [] # To store [filename, text] for the CSV

        for img_path, gt_path in tqdm(file_list, desc=f"Processing {mode} images"):
            image_cv = cv2.imread(img_path)
            if image_cv is None: continue
            
            height, width, _ = image_cv.shape
            if width * height > 933120000:  
                continue
            
            if width > 65500 or height > 65500:  
                continue
              
            with open(gt_path, 'r', encoding='utf-8') as f:
                gt_data = json.load(f)
            
            # Assuming the annotation format is a list under the 'items' key
            # and each item has 'points' and 'text'. Adjust if your format differs.
            for i, item in enumerate(gt_data.get('items', gt_data.get('bboxes', []))):
                x = item['left'] * width
                y = item['top'] * height
                w = item['width'] * width
                h = item['height'] * height
                left = x
                top = y
                right = x + w
                bottom = y + h
                
                # --- FIX IS HERE ---
                # Create the points array with the correct shape (4, 2)
                points = np.array([
                    [left, top],
                    [right, top],
                    [right, bottom],
                    [left, bottom]
                ], dtype=np.float32)
                
                if right <= left or bottom <= top:
                    # print(f"Некорректные координаты bounding box: left={left}, top={top}, right={right}, bottom={bottom}")
                    continue
                if left < 0 or top < 0 or right > width or bottom > height:
                    # print(f"Bounding box выходит за пределы изображения: left={left}, top={top}, right={right}, bottom={bottom}")
                    continue 
                text = str(item.get('text', item.get('label', '')))

                # --- Validation from a.txt ---
                if not text: continue
                
                cropped_word_img = warp_and_crop(image_cv, points)
                
                if cropped_word_img is None or cropped_word_img.shape[0] == 0 or cropped_word_img.shape[1] == 0:
                    continue

                # --- Saving logic from b.txt ---
                base_name = os.path.splitext(os.path.basename(img_path))[0]
                word_img_name = f"{base_name}_word_{i}.png"
                
                # Save the cropped word image to the correct '__results___files' folder
                cv2.imwrite(os.path.join(results_dir, word_img_name), cropped_word_img)
                
                # Collect data for the CSV file
                label_data.append([word_img_name, text])
        
        # --- Step 4: Write the labels.csv for the current split ---
        if label_data:
            df = pd.DataFrame(label_data, columns=['filename', 'words'])
            output_label_path = os.path.join(results_dir, 'labels.csv')
            df.to_csv(output_label_path, index=False, header=True, encoding='utf-8')
            print(f"Successfully created {len(label_data)} cropped images and labels.csv for {mode} set at:\n {output_label_path}")

    print("\nAll data processing and splitting complete.")

In [None]:
create_train_val_split_and_process(path_train, groundtruth_train_path, './all_data')

Starting data preparation...
Created directories:
 - ./all_data/en_train_filtered/__results___files
 - ./all_data/en_val/__results___files
Data split complete: 20711 training files, 3655 validation files.

Processing train data...


Processing train images: 100%|██████████| 20711/20711 [14:50<00:00, 23.25it/s]


Successfully created 51429 cropped images and labels.csv for train set at:
 ./all_data/en_train_filtered/__results___files/labels.csv

Processing validation data...


Processing validation images: 100%|██████████| 3655/3655 [02:30<00:00, 24.30it/s]

Successfully created 9080 cropped images and labels.csv for validation set at:
 ./all_data/en_val/__results___files/labels.csv

All data processing and splitting complete.





## Fine-Tuning EasyOCR

### rebuild EasyOCR

In [21]:
! git clone https://github.com/JaidedAI/EasyOCR.git

Cloning into 'EasyOCR'...
remote: Enumerating objects: 2750, done.[K
remote: Counting objects: 100% (661/661), done.[K
remote: Compressing objects: 100% (86/86), done.[K
remote: Total 2750 (delta 594), reused 575 (delta 575), pack-reused 2089 (from 1)[K
Receiving objects: 100% (2750/2750), 157.82 MiB | 5.82 MiB/s, done.
Resolving deltas: 100% (1689/1689), done.
Updating files: 100% (313/313), done.


In [22]:
!cp -r ./EasyOCR/trainer ./

In [24]:
files = os.listdir("./trainer")
for file in files:
    !cp -r ./trainer/{file} ./

In [2]:
!rm -r -y ./EasyOCR
!rm -r -y ./trainer

rm: invalid option -- 'y'
Try 'rm --help' for more information.
rm: invalid option -- 'y'
Try 'rm --help' for more information.


### prepare pretrained model

In [25]:
! pip install gdown

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting gdown
  Downloading gdown-5.2.0-py3-none-any.whl.metadata (5.8 kB)
Collecting PySocks!=1.5.7,>=1.5.6 (from requests[socks]->gdown)
  Downloading PySocks-1.7.1-py3-none-any.whl.metadata (13 kB)
Downloading gdown-5.2.0-py3-none-any.whl (18 kB)
Downloading PySocks-1.7.1-py3-none-any.whl (16 kB)
Installing collected packages: PySocks, gdown
Successfully installed PySocks-1.7.1 gdown-5.2.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [26]:
import gdown; 
url = 'https://drive.google.com/uc?id=15D8mC8gwLtkPqV53R_mTJTDPRQpQs-bX' 
output = 'cyrillic_g2.pth' 
gdown.download(url, output, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=15D8mC8gwLtkPqV53R_mTJTDPRQpQs-bX
To: /app/proj/cyrillic_g2.pth
100%|██████████| 15.3M/15.3M [00:00<00:00, 22.4MB/s]


'cyrillic_g2.pth'

### write easyocr config

In [38]:
import yaml
import sys

In [42]:
config_yaml = {
    'number': '0123456789',
    'symbol': "!\"#$%&'()*+,-./:;<=>?@[\\]№_`{|}~ €₽",
    'lang_char': 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстуфхцчшщъыьэюяЂђЃѓЄєІіЇїЈјЉљЊњЋћЌќЎўЏџҐґҒғҚқҮүҲҳҶҷӀӏӢӣӨөӮӯ',
    # 'data_folder': '/kaggle/working/train/easyocr_train_data',
    'train_data': 'all_data',
    'valid_data': 'all_data/en_val',
    # 'workers': 6,
    # 'batch_size': 64,
    'arch': 'VGG',
    'saved_model': 'cyrillic_g2.pth',
    # 'new_model':  '/kaggle/working/models/easyocr_finetuned/rustitw_recognizer.pth',
    'select_data': 'train',
    'experiment_name': 'easyocr_finetuned',
    'batch_ratio': '1',
    'total_data_usage_ratio':  1.0,
    'batch_max_length': 2048,

    # --- Other important parameters ---
    'workers': 6,  # Start with a safe number like 2 for Kaggle
    'batch_size': 64, # Start with a safe batch size
    'imgH': 500,      # Standard height for OCR, do not use 65500
    'imgW': 600,     # Standard width for OCR, it will be handled dynamically
    
    # 'imgH': 65500,
    # 'imgW': 65500,
    'rgb': False,
    'FT': True,
    'optim': False,
    'lr': 0.0005,
    'beta1': 0.9,
    'rho': 0.95,
    'eps': 0.00000001,
    'grad_clip': 5,
    'contrast_adjust': False,
    'sensitive': True,
    'PAD': True,
    'contrast_adjust': 0.0,
    'data_filtering_off': False,
    'Transformation': 'None',
    'FeatureExtraction': 'VGG',
    'SequenceModeling': 'BiLSTM',
    'Prediction': 'CTC',
    'num_fiducial': 20,
    'input_channel': 1,
    'output_channel': 256,
    'hidden_size': 256,
    'decode': 'greedy',
    'freeze_FeatureFxtraction': False,
    'freeze_SequenceModeling': False,
    'new_prediction': False,
    'num_epoch': 5,
    'num_iter':5000,
    'valInterval': 500,
    'lang_list': ['ru', 'en'],
    'network_params' : {
                 'input_channel': 1,
                 'output_channel': 256,
                 'hidden_size': 256
                 },
                 
    'character_list': "0123456789!\#$%&'()*+,-./:;<=>?@[\\]№_`{|}~ €₽ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстуфхцчшщъыьэюяЂђЃѓЄєІіЇїЈјЉљЊњЋћЌќЎўЏџҐґҒғҚқҮүҲҳҶҷӀӏӢӣӨөӮӯ"
}
with open('./config.yaml', 'w') as file:
    yaml.dump(config_yaml, file, allow_unicode=True)

### overwrite EasyOCR files

In [None]:
%%writefile ./train.py

import os
import sys
import time
import random
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
from torch.cuda.amp import autocast, GradScaler
import numpy as np

# --- FIX: Explicit relative imports to avoid conflicts ---
from utils import CTCLabelConverter, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from test import validation
# --------------------------------------------------------

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def count_parameters(model):
    # ... (this function is fine) ...
    print("Modules, Parameters")
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        total_params+=param
        print(name, param)
    print(f"Total Trainable Params: {total_params}")
    return total_params

def train(opt, show_number = 2, amp=False):
    # ... (dataset preparation is fine) ...
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)
    
    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a', encoding="utf8")
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust=opt.contrast_adjust)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=min(32, opt.batch_size),
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True) # Removed prefetch_factor for broader compatibility
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    
    # ... (model configuration is fine) ...
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # Note: DataParallel is often not needed if device_map is used, but we'll keep it for compatibility with the script
        model = torch.nn.DataParallel(model).to(device)
        model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False)
    else:
        # weight initialization
        for name, param in model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception as e:
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        model = torch.nn.DataParallel(model).to(device)
    
    model.train() 
    print("Model:")
    # print(model) # Commented out to reduce log spam
    count_parameters(model)
    
    # ... (setup loss is fine) ...
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    loss_avg = Averager()
    
    # ... (optimizer setup is fine) ...
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    
    if opt.optim=='adam':
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    
    # ... (final options logging is fine) ...
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    
    # ... (start training logic is fine) ...
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    # --- FIX 1: Correct GradScaler initialization ---
    scaler = GradScaler()
    # ----------------------------------------------
    
    t1 = time.time()
        
    while(True):
        # ... (train part is mostly fine, but we'll ensure device placement) ...
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)
        
        optimizer.zero_grad(set_to_none=True)
        
        if amp:
            # ... (amp logic) ...
            pass
        else:
            if 'CTC' in opt.Prediction:
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                preds = preds.permute(1, 0, 2)
                torch.backends.cudnn.enabled = False
                # --- FIX 2: Ensure all tensors for loss are on the correct device ---
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
                # --------------------------------------------------------------------
                torch.backends.cudnn.enabled = True
            else:
                # ... (attention logic) ...
                pass
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) 
            optimizer.step()
        loss_avg.add(cost)

        # validation part
        if (i % opt.valInterval == 0) and (i!=0):
            print('training time: ', time.time()-t1)
            t1=time.time()
            elapsed_time = time.time() - start_time
            print(f"[{i}/{opt.num_iter}]... Running validation...")
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf8") as log:
                model.eval()
                with torch.no_grad():
                    # --- FIX 3: Pass the `device` object to the validation function ---
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,\
                    infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device)
                    # -----------------------------------------------------------------
                model.train()
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()
                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'
                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                
                #show_number = min(show_number, len(labels))
                
                start = random.randint(0,len(labels) - show_number )    
                for gt, pred, confidence in zip(labels[start:start+show_number], preds[start:start+show_number], confidence_score[start:start+show_number]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                print('validation time: ', time.time()-t1)
                t1=time.time()
        # save model per 1e+4 iter.
        if (i + 1) % 1000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
        
        if i == opt.num_iter:
            print('end the training')
            break
        i += 1

Overwriting ./train.py


In [31]:
%%writefile ./test.py

import os
import time
import string
import argparse
import re

import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
import numpy as np
from nltk.metrics.distance import edit_distance

# --- FIX: Explicit relative imports ---
from utils import CTCLabelConverter, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate
from model import Model

def validation(model, criterion, evaluation_loader, converter, opt, device): # <-- FIX: Accept `device`
    """ validation or evaluation """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    for i, (image_tensors, labels) in enumerate(evaluation_loader):
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        
        # --- FIX: Move image tensors to the correct device ---
        image = image_tensors.to(device)
        # -----------------------------------------------------

        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred).log_softmax(2)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for CTC deocder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' to use CTCloss format
            # --- FIX: Ensure all tensors for loss are on the correct device ---
            cost = criterion(preds.permute(1, 0, 2), text_for_loss.to(device), preds_size.to(device), length_for_loss.to(device))
            # -----------------------------------------------------------------

            if opt.decode == 'greedy':
                # Select max probabilty (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode_greedy(preds_index.data, preds_size.data)
            elif opt.decode == 'beamsearch':
                preds_str = converter.decode_beamsearch(preds, beamWidth=2)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

            # select max probabilty (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)


        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy & norm ED
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                gt = gt[:gt.find('[s]')]
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

            if pred == gt:
                n_correct += 1

            '''
            (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks
            "For each word we calculate the normalized edit distance to the length of the ground truth transcription." 
            if len(gt) == 0:
                norm_ED += 1
            else:
                norm_ED += edit_distance(pred, gt) / len(gt)
            '''
            
            # ICDAR2019 Normalized Edit Distance 
            if len(gt) == 0 or len(pred) ==0:
                norm_ED += 0
            elif len(gt) > len(pred):
                norm_ED += 1 - edit_distance(pred, gt) / len(gt)
            else:
                norm_ED += 1 - edit_distance(pred, gt) / len(pred)

            # calculate confidence score (= multiply of pred_max_prob)
            try:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            except:
                confidence_score = 0  # for empty pred case, when prune after "end of sentence" token ([s])
            confidence_score_list.append(confidence_score)
            # print(pred, gt, pred==gt, confidence_score)

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data

# The rest of the file can be included if you plan to run test.py directly
# but for the purpose of fixing the training loop, this is sufficient.

Overwriting ./test.py


In [33]:
%%writefile ./dataset.py


import os
import sys
import re
import six
import math
import torch
import pandas  as pd

from natsort import natsorted
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, ConcatDataset, Subset
from torch._utils import _accumulate
import torchvision.transforms as transforms

def contrast_grey(img):
    high = np.percentile(img, 90)
    low  = np.percentile(img, 10)
    return (high-low)/(high+low), high, low

def adjust_contrast_grey(img, target = 0.4):
    contrast, high, low = contrast_grey(img)
    if contrast < target:
        img = img.astype(int)
        ratio = 200./(high-low)
        img = (img - low + 25)*ratio
        img = np.maximum(np.full(img.shape, 0) ,np.minimum(np.full(img.shape, 255), img)).astype(np.uint8)
    return img


class Batch_Balanced_Dataset(object):

    def __init__(self, opt):
        """
        Modulate the data ratio in the batch.
        For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5",
        the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST.
        """
        log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
        dashed_line = '-' * 80
        print(dashed_line)
        log.write(dashed_line + '\n')
        print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
        log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
        assert len(opt.select_data) == len(opt.batch_ratio)

        _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust = opt.contrast_adjust)
        self.data_loader_list = []
        self.dataloader_iter_list = []
        batch_size_list = []
        Total_batch_size = 0
        for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
            _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
            print(dashed_line)
            log.write(dashed_line + '\n')
            _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
            total_number_dataset = len(_dataset)
            log.write(_dataset_log)

            """
            The total number of data can be modified with opt.total_data_usage_ratio.
            ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage.
            See 4.2 section in our paper.
            """
            number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio))
            dataset_split = [number_dataset, total_number_dataset - number_dataset]
            indices = range(total_number_dataset)
            _dataset, _ = [Subset(_dataset, indices[offset - length:offset])
                           for offset, length in zip(_accumulate(dataset_split), dataset_split)]
            selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
            selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
            print(selected_d_log)
            log.write(selected_d_log + '\n')
            batch_size_list.append(str(_batch_size))
            Total_batch_size += _batch_size

            _data_loader = torch.utils.data.DataLoader(
                _dataset, batch_size=_batch_size,
                shuffle=True,
                num_workers=int(opt.workers), #prefetch_factor=2,persistent_workers=True,
                collate_fn=_AlignCollate, pin_memory=True)
            self.data_loader_list.append(_data_loader)
            self.dataloader_iter_list.append(iter(_data_loader))

        Total_batch_size_log = f'{dashed_line}\n'
        batch_size_sum = '+'.join(batch_size_list)
        Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
        Total_batch_size_log += f'{dashed_line}'
        opt.batch_size = Total_batch_size

        print(Total_batch_size_log)
        log.write(Total_batch_size_log + '\n')
        log.close()

    def get_batch(self):
        balanced_batch_images = []
        balanced_batch_texts = []

        for i, data_loader_iter in enumerate(self.dataloader_iter_list):
            try:
                image,text = next(iter(data_loader_iter))
                balanced_batch_images.append(image)
                balanced_batch_texts += text
            except StopIteration:
                self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
                image, text = next(iter(self.dataloader_iter_list[i]))
                balanced_batch_images.append(image)
                balanced_batch_texts += text
            except ValueError:
                pass

        balanced_batch_images = torch.cat(balanced_batch_images, 0)

        return balanced_batch_images, balanced_batch_texts


def hierarchical_dataset(root, opt, select_data='/'):
    """ select_data='/' contains all sub-directory of root directory """
    dataset_list = []
    dataset_log = f'dataset_root:    {root}\t dataset: {select_data[0]}'
    print(dataset_log)
    dataset_log += '\n'
    for dirpath, dirnames, filenames in os.walk(root+'/'):
        if not dirnames:
            select_flag = False
            for selected_d in select_data:
                if selected_d in dirpath:
                    select_flag = True
                    break

            if select_flag:
                dataset = OCRDataset(dirpath, opt)
                sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}'
                print(sub_dataset_log)
                dataset_log += f'{sub_dataset_log}\n'
                dataset_list.append(dataset)

    concatenated_dataset = ConcatDataset(dataset_list)

    return concatenated_dataset, dataset_log

class OCRDataset(Dataset):

    def __init__(self, root, opt):

        self.root = root
        self.opt = opt
        print(root)
        self.df = pd.read_csv(os.path.join(root,'labels.csv'), engine='python', usecols=['filename', 'words'], keep_default_na=False)
        self.nSamples = len(self.df)

        if self.opt.data_filtering_off:
            self.filtered_index_list = [index + 1 for index in range(self.nSamples)]
        else:
            self.filtered_index_list = []
            for index in range(self.nSamples):
                label = self.df.at[index,'words']
                try:
                    if len(label) > self.opt.batch_max_length:
                        continue
                except:
                    print(label)
                out_of_char = f'[^{self.opt.character}]'
                if re.search(out_of_char, label.lower()):
                    continue
                self.filtered_index_list.append(index)
            self.nSamples = len(self.filtered_index_list)

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        index = self.filtered_index_list[index]
        img_fname = self.df.at[index,'filename']
        img_fpath = os.path.join(self.root, img_fname)
        label = self.df.at[index,'words']

        if self.opt.rgb:
            img = Image.open(img_fpath).convert('RGB')  # for color image
        else:
            img = Image.open(img_fpath).convert('L')

        if not self.opt.sensitive:
            label = label.lower()

        # We only train and evaluate on alphanumerics (or pre-defined character set in train.py)
        out_of_char = f'[^{self.opt.character}]'
        label = re.sub(out_of_char, '', label)

        return (img, label)

class ResizeNormalize(object):

    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)
        return img


class NormalizePAD(object):

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)
        c, h, w = img.size()
        Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
        Pad_img[:, :, :w] = img  # right pad
        if self.max_size[2] != w:  # add border Pad
            Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)

        return Pad_img


class AlignCollate(object):

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, contrast_adjust = 0.):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad
        self.contrast_adjust = contrast_adjust

    def __call__(self, batch):
        batch = filter(lambda x: x is not None, batch)
        images, labels = zip(*batch)

        if self.keep_ratio_with_pad:  # same concept with 'Rosetta' paper
            resized_max_w = self.imgW
            input_channel = 3 if images[0].mode == 'RGB' else 1
            transform = NormalizePAD((input_channel, self.imgH, resized_max_w))

            resized_images = []
            for image in images:
                w, h = image.size

                #### augmentation here - change contrast
                if self.contrast_adjust > 0:
                    image = np.array(image.convert("L"))
                    image = adjust_contrast_grey(image, target = self.contrast_adjust)
                    image = Image.fromarray(image, 'L')

                ratio = w / float(h)
                if math.ceil(self.imgH * ratio) > self.imgW:
                    resized_w = self.imgW
                else:
                    resized_w = math.ceil(self.imgH * ratio)

                resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
                resized_images.append(transform(resized_image))
                # resized_image.save('./image_test/%d_test.jpg' % w)

            image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)

        else:
            transform = ResizeNormalize((self.imgW, self.imgH))
            image_tensors = [transform(image) for image in images]
            image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)

        return image_tensors, labels


def tensor2im(image_tensor, imtype=np.uint8):
    image_numpy = image_tensor.cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)


def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)

Overwriting ./dataset.py


In [34]:
import fileinput
import sys

# Define the path to the problematic file
file_path = 'dataset.py'

# The code for the missing _accumulate function
accumulate_code = """
import operator
def _accumulate(iterable, fn=operator.add):
    'Return running totals'
    # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
    # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    it = iter(iterable)
    try:
        total = next(it)
    except StopIteration:
        return
    yield total
    for element in it:
        total = fn(total, element)
        yield total
"""

# Flag to check if we've added our code yet
code_added = False

print(f"Patching file: {file_path}")

# Use fileinput to modify the file in-place
for line in fileinput.input(file_path, inplace=True):
    # --- Part 1: Comment out the broken import ---
    if "from torch._utils import _accumulate" in line:
        sys.stdout.write("# " + line) # Comment out the line
        continue # Skip to the next line

    # --- Part 2: Add our function definition ---
    # We add our code block right after the last import statement
    if line.strip().startswith("import") and not code_added:
        sys.stdout.write(line) # Write the import line
        # After writing the last import, add our custom function
        sys.stdout.write(accumulate_code + "\n")
        code_added = True
        continue
        
    # Write all other lines back to the file as they were
    sys.stdout.write(line)

print("Patch complete. The file has been modified.")
print("\nYou can now re-run your training command.")

Patching file: dataset.py
Patch complete. The file has been modified.

You can now re-run your training command.


In [None]:
# ! pip install natsort

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting natsort
  Downloading natsort-8.4.0-py3-none-any.whl.metadata (21 kB)
Downloading natsort-8.4.0-py3-none-any.whl (38 kB)
Installing collected packages: natsort
Successfully installed natsort-8.4.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [None]:
# ! pip install nltk

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting nltk
  Downloading nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Downloading nltk-3.9.1-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hInstalling collected packages: nltk
Successfully installed nltk-3.9.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


## train EasyOCR

In [26]:
import sys
# sys.path.append('/kaggle/working/EasyOCR/')
from utils import AttrDict
from train import train
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
cudnn.deterministic = False

In [None]:
def get_config(file_path):
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)
    if opt.lang_char == 'None':
        characters = ''
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            df = pd.read_csv(csv_path, sep=';', engine='python', usecols=['filename', 'words'], keep_default_na=False)
            all_char = ''.join(df['words'])
            characters += ''.join(set(all_char))
        characters = sorted(set(characters))
        opt.character= ''.join(characters)
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    return opt

#Запускаем обучение
opt = get_config("./config.yaml")
train(opt, amp=False)

## EasyOCR 

In [28]:
os.makedirs('./models/easyocr_finetuned/', exist_ok=True)
! cp ./best_accuracy.pth ./models/easyocr_finetuned/best_accuracy.pth

In [43]:
! cp config.yaml ./models/easyocr_finetuned/best_accuracy.yaml

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [None]:
import easyocr
import numpy as np

class EasyOCREvaluator:
    def __init__(self, model_path='models/easyocr_finetuned/'):
        """
        Initializes the EasyOCR reader with our fine-tuned recognition model.
        """
        print("Initializing EasyOCR with fine-tuned model...")
        # We use the standard detector but specify our own recognizer
        self.reader = easyocr.Reader(lang_list=['ru', 'en'], # Languages
            gpu=True,
            model_storage_directory=model_path,
            user_network_directory=model_path,
            recog_network='best_accuracy' # Tells EasyOCR to look for a custom model
        )

    def predict(self, image_path: str):
        """
        Takes an image path and returns structured predictions.
        """
        # EasyOCR's readtext performs both detection and recognition
        results = self.reader.readtext(image_path)
        
        predictions = []
        for (bbox, text, prob) in results:
            # bbox is [[x1,y1], [x2,y1], [x2,y2], [x1,y2]]
            # Convert it to the 8-point format our evaluator expects
            points = np.array(bbox).flatten().tolist()
            predictions.append({
                'text': text,
                'points': points
            })
        return predictions

## Main

In [None]:
import os
import json
from tqdm import tqdm

In [29]:
import os
import json
from tqdm import tqdm


def run_cer_evaluation(evaluator, filepaths, batch_size=8):
    """
    Runs a full evaluation loop for a given model and calculates the average CER.
    """
    all_cer_scores = []
    
    # Determine if the evaluator supports batching
    supports_batching = hasattr(evaluator, 'predict_batch')

    if supports_batching:
        # Create batches of filepaths for batched evaluators (like Qwen)
        filepath_batches = [filepaths[i:i + batch_size] for i in range(0, len(filepaths), batch_size)]
        
        for batch in tqdm(filepath_batches, desc=f"Evaluating {evaluator.__class__.__name__} in batches"):
            image_paths_batch = [fp[0] for fp in batch]
            
            batch_predictions = evaluator.predict_batch(image_paths_batch)
            
            for i, predictions_for_one_image in enumerate(batch_predictions):
                img_path, gt_path = batch[i]
                _, annotations = load_sample(img_path, gt_path)
                if not annotations: continue
                
                image_cer = calculate_image_cer(predictions_for_one_image, annotations)
                all_cer_scores.append(image_cer)
    else:
        # Process one by one for non-batched evaluators (like EasyOCR)
        for img_path, gt_path in tqdm(filepaths, desc=f"Evaluating {evaluator.__class__.__name__}"):
            image, annotations = load_sample(img_path, gt_path)
            if not annotations: continue
            if hasattr(evaluator, 'inference'):
                # prompt = prompt ='Read all the text in the image. For each section of text, print its bounding box and text in this bounding box in the format: {(x1,y1),(x2,y2)} -- text. where x1'
                # predictions = evaluator.predict(img_path[0])
                predictions = evaluator.inference(image, img_path)
            else:
                predictions = evaluator.predict(img_path)
                
            image_cer = calculate_image_cer(predictions, annotations)
            all_cer_scores.append(image_cer)
            del image
            torch.cuda.empty_cache()
            gc.collect()
    # Calculate the final average CER across the whole dataset
    average_cer = sum(all_cer_scores) / len(all_cer_scores) if all_cer_scores else 0.0
    return average_cer


In [None]:
filepaths = get_all_filepaths(path_test, groundtruth_test_path)
filepaths = filepaths[:50] # Uncomment for a quick test run

# --- Qwen-VL Evaluation ---
print("--- Evaluating Qwen-VL (VLM Baseline) ---")
qwen_evaluator = QwenEvaluator("Qwen/Qwen2.5-VL-7B-Instruct")
qwen_avg_cer = run_cer_evaluation(qwen_evaluator, filepaths, batch_size=8)

# --- EasyOCR Fine-Tuned Evaluation ---
print("\n--- Evaluating EasyOCR (Fine-Tuned Pipeline) ---")
easyocr_evaluator = EasyOCREvaluator(model_path='models/easyocr_finetuned/')
easyocr_avg_cer = run_cer_evaluation(easyocr_evaluator, filepaths)

# --- Final Comparison Report ---
print("\n--- Final Comparison Report (Average Character Error Rate) ---")
print("Lower is better.")
print("-" * 50)
print(f"{'Model':<30} | {'Average CER':<15}")
print("-" * 50)
print(f"{'Qwen-VL (Zero-Shot VLM)':<30} | {qwen_avg_cer:<15.4f}")
print(f"{'EasyOCR (Fine-Tuned)':<30} | {easyocr_avg_cer:<15.4f}")
print("-" * 50)