# Полный запуск решения

## Импорт библиотек

In [4]:
import os
import pickle
import yaml
from typing import List, Dict
from pathlib import Path

import pandas as pd
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
from loguru import logger
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from ultralytics import YOLO
from sklearn.neighbors import NearestNeighbors
import cv2
from dataclasses import dataclass
import abc
from paddleocr import PaddleOCR
from sklearn.metrics import accuracy_score, pairwise_distances

# Location of the YAML config that the cells below open and parse.
config_path = "config.yaml"

E0000 00:00:1731189501.109606  842277 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1731189501.112678  842277 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Для дальнейшей работы нужно скачать веса и положить их в папку weights

https://drive.google.com/file/d/154jS1mS7ca43gm1eSu_DhzP7Y7f7eCHU/view?usp=sharing  
https://drive.google.com/file/d/1Rssq6iwe8ExxcSG7hnjz1UZiieUDkwVh/view?usp=sharing

## Получение эмбеддингов

In [5]:
def load_images_from_folder(output_folder: str) -> List[np.ndarray]:
    """Load every image file in ``output_folder`` as a numpy array.

    Entries are visited in sorted name order. Anything that is not a regular
    file with a .png/.jpg/.jpeg/.gif/.bmp extension is skipped with a warning;
    files that fail to open are logged as errors and skipped.

    Args:
        output_folder: Directory to read images from.

    Returns:
        The decoded images, in sorted filename order.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
    frames = []
    for frame_file in sorted(os.listdir(output_folder)):
        frame_path = os.path.join(output_folder, frame_file)
        if not (os.path.isfile(frame_path) and frame_file.lower().endswith(image_extensions)):
            logger.warning(f"Skipping directory or non-image file: {frame_path}")
            continue
        try:
            # The context manager closes the file handle; np.array forces the
            # pixel data to be read before the file is closed.
            with Image.open(frame_path) as img:
                frames.append(np.array(img))
        except Exception as e:
            # Best-effort loading: one unreadable file should not abort the batch.
            logger.error(f"Error opening {frame_path}: {e}")
    return frames

def save_embeddings(embeddings, filename, output_folder):
    """Pickle ``embeddings`` into ``<output_folder>/<filename>.pkl`` and log where it went."""
    destination = Path(output_folder).joinpath(f"{filename}.pkl")
    with destination.open('wb') as handle:
        pickle.dump(embeddings, handle)
    logger.info(f"Saved embeddings to {destination}")

def vectorize_images(images: List[np.ndarray], model: SentenceTransformer) -> List[np.ndarray]:
    """Encode each image array with ``model`` (one vector per image), with a tqdm progress bar."""
    vectors = []
    for frame in tqdm(images):
        pil_frame = Image.fromarray(frame)
        vectors.append(model.encode(pil_frame))
    return vectors

In [6]:
# Embed the test images with the configured model and pickle the vectors.
# Only test embeddings are written here; the neighbour search further down also
# reads a train embeddings pickle from the same folder.
# NOTE(review): presumably that train file comes with the downloaded weights; confirm.
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)
logger.info("Loaded configuration from {}", config_path)

test_images = load_images_from_folder(config['test_images_folder'])
logger.info("Loaded test: {}", config['test_images_folder'])
# NOTE(review): train_images is not used in this cell.
train_images = load_images_from_folder(config['train_images_folder'])
logger.info("Loaded train: {}", config['train_images_folder'])

model = SentenceTransformer(config['model_name'])
logger.info("Loaded model: {}", config['model_name'])

test_embeddings = vectorize_images(test_images, model)
save_embeddings(test_embeddings, 'test_emb', config['emb_output_folder'])

# Only the test embeddings were saved above; the message says exactly that.
logger.info("Saved embeddings for test images.")


[32m2024-11-10 00:58:26.459[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoaded configuration from config.yaml[0m
[32m2024-11-10 00:58:26.552[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mLoaded test: test/images[0m
[32m2024-11-10 00:58:26.552[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mLoaded train: train/images[0m
[32m2024-11-10 00:58:28.178[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mLoaded model: clip-ViT-B-16[0m
100%|██████████| 9/9 [00:03<00:00,  2.31it/s]
[32m2024-11-10 00:58:32.072[0m | [1mINFO    [0m | [36m__main__[0m:[36msave_embeddings[0m:[36m22[0m - [1mSaved embeddings to weights/embed/test_emb.pkl[0m
[32m2024-11-10 00:58:32.073[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m17[0m - [1mSaved embeddings for test and train images.[0m


## Поиск похожих

In [7]:
def load_embeddings_from_folder(folder: str) -> tuple[List[np.ndarray], List[np.ndarray]]:
    """Load pickled embeddings from ``folder``, grouped into test and train.

    Every regular file whose lowercase name contains ``"test"`` goes into the
    first list; otherwise, files containing ``"train"`` go into the second.
    Other files are ignored and never unpickled. Files are processed in sorted
    name order, so the grouping order is deterministic (``os.listdir`` order is
    arbitrary).

    Security: ``pickle.load`` can execute arbitrary code; only use this on
    trusted folders.

    Returns:
        ``(test_embeddings, train_embeddings)``. Each element is whatever one
        pickle file contained.
    """
    test_embeddings = []
    train_embeddings = []
    for filename in sorted(os.listdir(folder)):
        emb_path = os.path.join(folder, filename)
        if not os.path.isfile(emb_path):
            continue

        name = filename.lower()
        if 'test' in name:
            target = test_embeddings
        elif 'train' in name:
            target = train_embeddings
        else:
            # Unrelated file: skip it rather than trying to unpickle it.
            continue

        with open(emb_path, 'rb') as f:
            target.append(pickle.load(f))

    return test_embeddings, train_embeddings


def load_image_filenames(images_folder: str) -> List[str]:
    """Return the sorted names in ``images_folder`` that end with an accepted suffix.

    Accepted suffixes are png/jpg/jpeg/bmp/gif plus bbox/txt (presumably so the
    helper also works on label folders). The check is a case-insensitive
    ``endswith`` without a leading dot, and entries are not checked to be files.
    """
    accepted_suffixes = ('png', 'jpg', 'jpeg', 'bmp', 'gif', 'bbox', 'txt')
    return [
        entry
        for entry in sorted(os.listdir(images_folder))
        if entry.lower().endswith(accepted_suffixes)
    ]

def find_nearest_neighbors(test_embeddings: List[np.ndarray], 
                           train_embeddings: List[np.ndarray], 
                           n_neighbors: int, 
                           threshold: float) -> List[int | None]:
    """Map each test embedding to the index of its nearest train embedding.

    Both arguments are lists of embedding batches, as returned by
    ``load_embeddings_from_folder``. Only the FIRST batch of each is used.
    A test vector gets ``None`` when even its nearest train vector is not
    strictly closer than ``threshold`` (NearestNeighbors' default metric).

    Returns:
        One entry per test vector: the nearest train index (int) or ``None``.
    """
    test_matrix = np.asarray(test_embeddings[0])
    train_matrix = np.asarray(train_embeddings[0])
    if len(test_matrix) == 0:
        return []

    nn = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree')
    nn.fit(train_matrix)

    # A single batched query. Rows come back sorted by ascending distance, so
    # the first column is the closest candidate: if it misses the threshold,
    # every other candidate does too.
    distances, indices = nn.kneighbors(test_matrix)
    return [
        int(idx_row[0]) if dist_row[0] < threshold else None
        for dist_row, idx_row in zip(distances, indices)
    ]


def load_labels(labels_folder: str, file_extension: str, train_filenames: List[str]) -> List[str]:
    """Read the label files in ``labels_folder`` that belong to ``train_filenames``.

    A file qualifies when both of these hold:
    - its last dot-suffix equals ``file_extension`` (a leading dot is optional);
    - the part of its name before the first dot equals the part before the
      first dot of some entry in ``train_filenames``.

    Qualifying files are read in sorted filename order. Any line without a
    trailing newline (i.e. the last one) gets one appended.

    Returns:
        One string per qualifying file, in sorted filename order.
    """
    wanted_ext = file_extension.lstrip('.')
    # A set gives O(1) membership instead of scanning the list for every file.
    train_bases = {name.split('.')[0] for name in train_filenames}

    labels = []
    for filename in sorted(os.listdir(labels_folder)):
        parts = filename.split('.')
        if parts[-1] != wanted_ext or parts[0] not in train_bases:
            continue
        with open(os.path.join(labels_folder, filename), 'r') as file:
            labels.append(''.join(line if line.endswith('\n') else line + '\n' for line in file))
    return labels

In [10]:
# Match every test image to its nearest train image and collect that image's labels.
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)
logger.info("Loaded configuration from {}", config_path)

test_embeddings, train_embeddings = load_embeddings_from_folder(config['emb_output_folder'])
logger.info("Embeddings were read")

test_filenames = load_image_filenames(config['test_images_folder'])
train_filenames = load_image_filenames(config['train_images_folder'])

train_labels = load_labels(config['train_labels_folder'], '.txt', train_filenames)
train_labels_with_text = load_labels(config['train_labels_with_text_folder'], '.bbox', train_filenames)
logger.info("train_labels and train_labels_with_text were read")

logger.info("Test image filenames were read")

n_neighbors = config['n_neighbors']
threshold = config['threshold']

nearest_neighbors = find_nearest_neighbors(test_embeddings, train_embeddings, n_neighbors, threshold)
logger.info(f"Neighbours were found - {nearest_neighbors}")

# NOTE(review): rows are aligned purely by position. This assumes test_filenames
# is in the same order as the test embeddings, and the label lists in the same
# order as the train embeddings; confirm.
results = []
for test_idx, neighbor_idx in enumerate(nearest_neighbors):
    # `is not None`: train index 0 is a real match and must not be read as "no match".
    if neighbor_idx is not None:
        results.append([
            test_filenames[test_idx],
            train_labels[neighbor_idx],
            train_labels_with_text[neighbor_idx],
        ])
    else:
        results.append([test_filenames[test_idx], None, None])

[32m2024-11-10 01:03:38.407[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoaded configuration from config.yaml[0m
[32m2024-11-10 01:03:38.409[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mEmbeddings were read[0m
[32m2024-11-10 01:03:38.423[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mtrain_labels and train_labels_with_text were read[0m
[32m2024-11-10 01:03:38.424[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mTest image filenames were read[0m
[32m2024-11-10 01:03:38.440[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m21[0m - [1mNeighbours were found - [182, 186, 192, 223, 224, 225, 230, 232, 232][0m


# Сегментация

In [11]:
class Segmentation:
    """Run a YOLO segmentation model on ``self.image`` and store the polygon masks.

    This base class never assigns ``self.image``; subclasses must set it before
    calling :meth:`get_segmentation`.
    """

    def __init__(self, weights_yolo_path: str, conf: float = 0.7):
        """Load YOLO weights.

        Args:
            weights_yolo_path: Path to the YOLO weights file.
            conf: Minimum confidence passed to the model call. Defaults to 0.7,
                the previously hard-coded value.
        """
        self.model = YOLO(weights_yolo_path)
        self.conf = conf
        # Shared result store; get_segmentation fills the "masks" key.
        self.data = {}

    def get_segmentation(self) -> None:
        """Segment ``self.image`` and store one polygon per object in ``self.data["masks"]``.

        The polygons come from ``result[0].masks.xy``. An empty list is stored
        when nothing is detected or the model returned no masks.
        """
        result = self.model(self.image, conf=self.conf)
        masks = result[0].masks
        # `masks` may be None (e.g. detection-only weights); guard before `.xy`.
        if len(result[0]) and masks is not None:
            self.data["masks"] = np.array(masks.xy, dtype=object)
        else:
            self.data["masks"] = []

## OCR

In [12]:
class OCR(Segmentation):
    """Segment an image with YOLO, then run PaddleOCR on both the masked crop and the full image.

    All of the work happens in ``__init__``; call :meth:`get_text` afterwards
    to read the recognised lines.
    """

    def __init__(self, weights_yolo_path: str, image: Image.Image):
        super().__init__(weights_yolo_path)
        self.ocr = PaddleOCR(use_gpu=True, lang="en")
        self.image = image
        self.get_segmentation()
        self.crop_one_img()
        self.ocr_one_img()

    def get_mask(self) -> np.ndarray:
        """Rasterise all polygons in ``self.data["masks"]`` into one uint8 mask (255 = object)."""
        width, height = self.image.size
        mask = np.zeros((height, width), dtype=np.uint8)
        for polygon in self.data["masks"]:
            # Float coordinates are truncated to int32 pixel positions.
            points = np.asarray(polygon, dtype=np.int32).reshape(-1, 2)
            if len(points) == 0:
                # An empty polygon has nothing to fill.
                continue
            mask = cv2.fillPoly(mask, [points], color=255)
        return mask

    def crop_one_img(self) -> None:
        """Zero out pixels outside the masks and crop to the masks' bounding box.

        The result goes to ``self.data["crop_img"]``. Without any mask points,
        the stored image is the masked image as a whole, which is all zeros.
        NOTE(review): perhaps the unmasked image was intended in that case; confirm.
        """
        image = np.asarray(self.image)
        mask = self.get_mask() > 0
        # Broadcast the 2-D mask over the channel axis for multi-channel images.
        masked = image * (mask[..., None] if image.ndim == 3 else mask)

        polygons = [np.asarray(p, dtype=float).reshape(-1, 2) for p in self.data["masks"]]
        points = np.concatenate(polygons) if polygons else np.empty((0, 2))
        if len(points):
            x_min, y_min = points.min(axis=0).astype(int)
            x_max, y_max = points.max(axis=0).astype(int)
            # Slice ends are exclusive; +1 keeps the max row/column inside the crop.
            self.data["crop_img"] = masked[y_min:y_max + 1, x_min:x_max + 1]
        else:
            self.data["crop_img"] = masked

    def _recognise(self, image: np.ndarray) -> List[str]:
        """Run PaddleOCR on ``image``; return the recognised strings, or ``["None"]`` if none."""
        result = self.ocr.ocr(image, rec=True)
        if result[0]:
            return [line[1][0] for line in result[0]]
        return ["None"]

    def ocr_one_img(self) -> None:
        """Store OCR text of the crop in ``data["rec_crop"]`` and of the full image in ``data["rec_orig"]``."""
        self.data["rec_crop"] = self._recognise(np.array(self.data["crop_img"]))
        self.data["rec_orig"] = self._recognise(np.array(self.image))

    def get_text(self) -> Dict[str, List[str]]:
        """Return the recognised lines for the full image and for the crop."""
        dict_text = {
            "text_orig_img": self.data["rec_orig"],
            "text_crop_img": self.data["rec_crop"],
        }
        return dict_text

In [13]:
@dataclass
class PredictResult:
    raw_text: str = None
    # image in bytes with boxes and text on it
    pred_img: str = None
    # unknow data from excel, None if search_in_data is False
    attribute1: str | None = None
    attribute2: str | None = None
    attribute3: str | None = None

class BaseModel(abc.ABC):
    """Interface for OCR models that turn a single image into a ``PredictResult``."""

    @abc.abstractmethod
    def predict(
        self, image: Image.Image, search_in_data: bool, dist_threshold: float
    ) -> PredictResult:
        """Get a prediction from the ML OCR model.

        Parameters
        ----------
        image : Image.Image
            Single image to be predicted.
        search_in_data : bool
            Flag; if true, fetch the missing data from the Excel file.
        dist_threshold : float
            Distance threshold used to cut out unknown images.

        Returns
        -------
        PredictResult
            If search_in_data is True, the full data from Excel.
            If False, only the OCR result.
        """
        ...

In [14]:
class OcrBD():
    """Look up labels for an image by nearest-neighbour search over stored train embeddings.

    The image is embedded with a SentenceTransformer model and matched against
    the pickled train embeddings in ``config['emb_output_folder']``. The labels
    of the matched train image are returned as a DataFrame and written to
    ``config['output_excel']``.
    """

    def __init__(self) -> None:
        self.config_path = "config.yaml"
        with open(self.config_path, 'r') as file:
            self.config = yaml.safe_load(file)
        logger.info("Loaded configuration from {}", self.config_path)
        # Query embeddings must come from the same model as the stored ones, so
        # prefer the configured model. Fall back to the previously hard-coded name.
        self.model = SentenceTransformer((self.config or {}).get('model_name', "clip-ViT-B-16"))
        # NOTE(review): these attributes are not read by this class (predict uses
        # self.config); kept for backward compatibility.
        self.emb_output_folder = "embeddings_vit"
        self.test_images_folder = "test/images"
        self.train_labels_folder = "train/labels"
        self.train_labels_with_text_folder = "train/labels_with_text"

    def load_embeddings_from_folder(self, folder: str) -> List[np.ndarray]:
        """Return the contents of the train embeddings pickle in ``folder``.

        The train file is a regular file whose lowercase name contains
        ``"train"`` but not ``"test"``. If there are several, the last one in
        sorted order wins. Other files are not unpickled. ``pickle.load`` can
        execute arbitrary code, so only use this on trusted folders.

        Raises:
            FileNotFoundError: if ``folder`` holds no train embeddings file.
        """
        found = False
        train_embeddings = None
        for filename in sorted(os.listdir(folder)):
            emb_path = os.path.join(folder, filename)
            name = filename.lower()
            if not os.path.isfile(emb_path) or 'test' in name or 'train' not in name:
                continue
            with open(emb_path, 'rb') as f:
                train_embeddings = pickle.load(f)
            found = True
        if not found:
            raise FileNotFoundError(f"No train embeddings file found in {folder}")
        return train_embeddings

    def vectorize_img(self, image: Image.Image) -> List[np.ndarray]:
        """Encode one image and wrap the vector in a list (a batch of one)."""
        return [self.model.encode(image)]

    def load_image_filenames(self, images_folder: str) -> List[str]:
        """Return sorted names in ``images_folder`` ending with an image, bbox or txt suffix."""
        accepted = ('png', 'jpg', 'jpeg', 'bmp', 'gif', 'bbox', 'txt')
        return [name for name in sorted(os.listdir(images_folder)) if name.lower().endswith(accepted)]

    def find_nearest_neighbors(self,
                               test_embeddings: List[np.ndarray],
                               train_embeddings: List[np.ndarray],
                               n_neighbors: int,
                               threshold: float) -> List[int | None]:
        """Return, per test vector, the first neighbour index with ``0 < distance < threshold``, else None.

        Exact matches (distance 0) are skipped. Presumably this ignores the
        query image itself when it is part of the train set.
        NOTE(review): confirm that is intended.
        """
        test_matrix = np.asarray(test_embeddings)
        train_matrix = np.asarray(train_embeddings)
        nn = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree')
        nn.fit(train_matrix)

        neighbors_indices = []
        for test_emb in test_matrix:
            distances, indices = nn.kneighbors([test_emb])
            valid = [int(idx) for dist, idx in zip(distances[0], indices[0]) if 0 < dist < threshold]
            neighbors_indices.append(valid[0] if valid else None)
        return neighbors_indices

    def load_labels(self, labels_folder: str, file_extension: str, train_filenames: List[str]) -> List[str]:
        """Read label files whose extension and first-dot base name match ``train_filenames``.

        Files are read in sorted order. Every line is made to end with a newline.
        """
        wanted_ext = file_extension.lstrip('.')
        # A set gives O(1) membership instead of scanning a list for every file.
        train_bases = {name.split('.')[0] for name in train_filenames}
        labels = []
        for filename in sorted(os.listdir(labels_folder)):
            parts = filename.split('.')
            if parts[-1] != wanted_ext or parts[0] not in train_bases:
                continue
            with open(os.path.join(labels_folder, filename), 'r') as file:
                labels.append(''.join(line if line.endswith('\n') else line + '\n' for line in file))
        return labels

    def predict(self, image: Image.Image, search_in_data: bool, dist_threshold: float) -> pd.DataFrame:
        """Find the nearest train image for ``image`` and return its labels as a DataFrame.

        Columns are ``Test_Embedding`` (query index), ``Label``,
        ``Label_With_Text`` and ``Neighbour`` (train filename). The last three
        are None when no neighbour passes the threshold. The frame is also
        written to ``config['output_excel']``.

        NOTE(review): ``search_in_data`` and ``dist_threshold`` are currently
        ignored; the threshold comes from ``config['threshold']``.
        """
        config = self.config

        train_embeddings = self.load_embeddings_from_folder(config['emb_output_folder'])
        test_embeddings = self.vectorize_img(image)
        logger.info("Embeddings were read")

        train_filenames = self.load_image_filenames(config['train_images_folder'])
        train_labels = self.load_labels(config['train_labels_folder'], '.txt', train_filenames)
        train_labels_with_text = self.load_labels(config['train_labels_with_text_folder'], '.bbox', train_filenames)
        logger.info("train_labels and train_labels_with_text were read")

        nearest_neighbors = self.find_nearest_neighbors(
            test_embeddings, train_embeddings, config['n_neighbors'], config['threshold']
        )
        logger.info(f"Neighbours were found - {nearest_neighbors}")

        results = []
        for test_idx, neighbor_idx in enumerate(nearest_neighbors):
            # `is not None`: train index 0 is a real match, not "no match".
            if neighbor_idx is not None:
                results.append([
                    test_idx,
                    train_labels[neighbor_idx],
                    train_labels_with_text[neighbor_idx],
                    train_filenames[neighbor_idx],
                ])
            else:
                results.append([test_idx, None, None, None])

        df = pd.DataFrame(results, columns=['Test_Embedding', 'Label', 'Label_With_Text', 'Neighbour'])
        # Strip the trailing newline load_labels guarantees; leave None (no match) as is.
        df["Label_With_Text"] = df["Label_With_Text"].map(lambda x: x[:-1] if isinstance(x, str) else x)
        df.to_excel(config['output_excel'], index=False)
        logger.info("Saved results to Excel: {}", config['output_excel'])
        return df

In [15]:
class OcrPipeline(BaseModel):
    """Two-stage pipeline: YOLO segmentation + PaddleOCR, then a label lookup via OcrBD."""

    def __init__(self, weights: str = './weights/best.pt') -> None:
        # Path to the YOLO segmentation weights used by the OCR stage.
        self.weights = weights
        # OcrBD is created lazily and reused: building it loads a
        # SentenceTransformer model, which is too expensive to repeat per image.
        self._db = None

    def predict(self, image: Image.Image, search_in_data: bool, dist_threshold: float) -> PredictResult:
        """Run OCR on ``image``, then look up its label text in the embeddings DB.

        ``raw_text`` is the ``Label_With_Text`` of the matched train image, or
        None when no match passes the threshold.
        """
        # NOTE(review): the OCR text is computed but not yet merged into the
        # result; the call is kept so the segmentation+OCR stage still runs.
        ocr_text = OCR(self.weights, image).get_text()

        if getattr(self, "_db", None) is None:
            self._db = OcrBD()
        result = self._db.predict(image, search_in_data=search_in_data, dist_threshold=dist_threshold)

        labels = result["Label_With_Text"]
        # PredictResult.raw_text is a string: take the single query's value,
        # not the whole pandas Series.
        raw_text = labels.iloc[0] if len(labels) else None
        return PredictResult(raw_text=raw_text)

## Инференс 

In [17]:
# Run the pipeline over every file in the test folder, in sorted (stable) order.
model = OcrPipeline()
ans = pd.DataFrame()  # NOTE(review): not filled here; kept in case later cells use it.
predictions = {}
for img_path in sorted(os.listdir(config["test_images_folder"])):
    full_path = os.path.join(config["test_images_folder"], img_path)
    if not os.path.isfile(full_path):
        continue
    # The context manager closes the file handle once the prediction is done.
    with Image.open(full_path) as image:
        result = model.predict(image, search_in_data=False, dist_threshold=10.5)
    predictions[img_path] = result
    

FileNotFoundError: [Errno 2] No such file or directory: '../../sergey/runs/segment/train3/weights/best.pt'