In [1]:
from IPython.display import clear_output

In [2]:
# installing the needed libraries: the packages imported by the cells below
# (easyocr, pyspellchecker, sentence-transformers, annoy) plus
# pytorch-lightning and pytorch_metric_learning
# NOTE(review): the last two are presumably needed to unpickle the saved
# visual model - confirm; versions are not pinned, so a re-run may install
# different releases
!pip install easyocr
!pip install pyspellchecker
!pip install sentence-transformers
!pip install annoy
!pip install pytorch-lightning
!pip install pytorch_metric_learning
# wiping the pip logs from the cell output
clear_output()

In [3]:
import cv2
import torch
import pickle
import numpy as np

from re import findall
from os import listdir
from pathlib import Path
from easyocr import Reader
from annoy import AnnoyIndex
from PIL import Image, ImageOps
from torchvision import transforms
from torch.cuda import is_available
from spellchecker import SpellChecker
from sentence_transformers import SentenceTransformer

In [4]:
from visual_model import EmbeddingModel, ImageRetreivalDataModule, ImagesDataset

In [5]:
from google.colab import drive
# mounting Google Drive at /content/drive; later cells read the data archive
# and the saved visual model through the relative path drive/MyDrive
# NOTE(review): that relative path assumes the working directory is /content - confirm
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
# getting the manga data: extracting the archive from the mounted drive
# into the current working directory
# NOTE(review): presumably this creates the 'data' folder that the feature
# extraction cells read from - confirm against the archive layout
!unzip drive/MyDrive/data.zip
# wiping the unzip listing from the cell output
clear_output()

In [7]:
class MangaTextExtractor:
    '''Simple class to extract spell-corrected words from a manga page'''

    def __init__(self, lang='ru', max_distance=1, confidence=0.1, GPU=False):
        '''
        Set up the OCR reader and the spellchecker

        lang         - language code handed to both the easyocr Reader and the SpellChecker
        max_distance - value passed to the SpellChecker as its `distance` argument
        confidence   - detections whose confidence is not above this value are dropped
        GPU          - whether to run the reader on GPU; forced to False without CUDA
        '''

        # setting up how confident we should be in the text extraction
        self.confidence = confidence

        # deciding where to infer the model (GPU only if requested and available)
        self.GPU = GPU if is_available() else False

        # initializing the reader
        self.reader = Reader([lang], gpu=self.GPU)

        # initializing the spellchecker
        self.checker = SpellChecker(language=lang, distance=max_distance)

    def get_text_from_page(self, page_file_name):
        '''
        Return the list of lowercase, spell-corrected words found on one page

        page_file_name - the page image, passed as is to the reader's readtext
        '''

        # getting the raw detection from easyocr; the code below reads
        # item[1] as the recognized text and item[2] as its confidence
        detection = self.reader.readtext(page_file_name)

        # lowercased text of the detections we are confident enough about
        fragments = [det[1].lower() for det in detection if det[2] > self.confidence]

        # detecting the words presented in lowercase
        words = findall(r'\w+', " ".join(fragments))

        # building the correction table, one lookup per distinct misspelled word
        corrections = {}
        for word in self.checker.unknown(words):
            # correction() may return None when the checker has no candidate
            # (pyspellchecker >= 0.7); keeping the original word in that case
            # so that the result always is a list of strings
            corrections[word] = self.checker.correction(word) or word

        # returning the words with misspelled ones replaced by their corrections
        return [corrections.get(word, word) for word in words]

In [8]:
class MangaFeatureExtractor:
    '''Class to extract the vectorized features from manga directory'''

    def __init__(self, visual_model_path='model.pt'):
        '''
        Initialize the modules that do the vectorizing

        visual_model_path - path of the saved visual model, loaded with torch.load
        '''

        self.device = 'cuda' if is_available() else 'cpu'

        # modules that do text embeddings
        self.text_extractor = MangaTextExtractor(GPU=True)
        self.sbert = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2') # embedding

        # model which is in charge of visual features;
        # map_location lets a model that was saved on GPU load on a CPU-only runtime
        # NOTE: torch.load unpickles arbitrary objects - only load trusted files
        # NOTE(review): recent PyTorch releases default to weights_only=True, which
        # rejects fully pickled models - confirm against the runtime's version
        self.model = torch.load(visual_model_path, map_location=self.device).to(self.device).eval()

    def get_features(self, folder):
        '''
        Calculate features for all manga pages in specified directory
        and return them in two separate dictionaries

        folder - directory that is searched recursively for *.jpg and *.png files

        Returns (text_embeddings, visual_embeddings), both keyed by the page
        file name; a text embedding is None when no words were extracted
        from the page
        '''

        # recursively searching all the jpg and png manga pages in the specified folder
        pages_file_names_jpg = list(map(str, Path(folder).rglob("*.jpg")))
        pages_file_names_png = list(map(str, Path(folder).rglob("*.png")))
        pages_file_names = pages_file_names_jpg + pages_file_names_png

        ###################
        # TEXT EMBEDDINGS #
        ###################

        text_embeddings = {}
        for page_file_name in pages_file_names:

            # getting the words from the text extractor for that page
            words = self.text_extractor.get_text_from_page(page_file_name)

            # embedding the page text, pages without any words get None
            text_embeddings[page_file_name] = self.sbert.encode(" ".join(words)) if words else None

        ###########################
        # VISUAL OBJECTS FEATURES #
        ###########################

        visual_embeddings = {}
        for page_file_name in pages_file_names:
            visual_embeddings[page_file_name] = self._embed_page(page_file_name)

        return text_embeddings, visual_embeddings

    def _embed_page(self, page_file_name):
        '''Return the visual model output for a single page as a numpy array'''

        # the context manager closes the opened file itself; convert() hands
        # back a separate in-memory image that stays usable afterwards
        with Image.open(page_file_name) as raw_page:
            page = raw_page.convert('RGB')

        # Tensorifying the image and adding dimension
        page_tensor = transforms.ToTensor()(page).unsqueeze(0).to(self.device)

        # Getting embedding (first element of the model output)
        with torch.no_grad():
            embedding = self.model(page_tensor)[0].detach().cpu().numpy()

        # dropping the page tensor first so that clearing the cuda cache
        # is able to release its memory
        del page_tensor
        if self.device == 'cuda':
            torch.cuda.empty_cache()

        return embedding

# initializing the manga feature extractor with the visual model stored on the mounted drive
# (building it produced the detection / recognition / transformer download logs shown below)
MFE = MangaFeatureExtractor(visual_model_path='drive/MyDrive/model_visual.pt')

Downloading detection model, please wait. This may take several minutes depending upon your network connection.




Downloading recognition model, please wait. This may take several minutes depending upon your network connection.




Downloading:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.77k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/723 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/402 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [20]:
# generating text and visual embeddings for query data
# (all jpg/png pages found under the folder of the single query title)
# NOTE(review): this cell's execution count (20) is higher than the indexing
# cell's (9) although it comes first; it only depends on MFE - confirm the
# notebook still runs top to bottom on a fresh kernel
query_text_embs, query_visual_embs = MFE.get_features('data/детектив/Тетрадьсмерти')

In [9]:
# textual and visual embeddings for indexing data
# (all jpg/png pages found recursively under 'data')
# NOTE(review): the query title's folder lives under 'data' too, so its pages
# are part of the indexed set as well
index_text_embs, index_visual_embs = MFE.get_features('data')

In [29]:
class AnnoyIdx:
    '''Pair of Annoy indexes (text and visual features) used to rank similar manga titles'''

    def __init__(self, text_idx, img_idx):
        '''
        Build one index per feature kind

        text_idx - dict {page file name: text feature vector or None}
        img_idx  - dict {page file name: visual feature vector or None}
        '''
        self.word_index, self.word_reverse = self.build_index(text_idx)
        self.img_index,  self.img_reverse  = self.build_index(img_idx)

    def build_index(self, index, trees=10, dist='angular'):
        '''
        Build an AnnoyIndex out of a {key: vector} dictionary

        index - dict of vectors, entries whose value is None are skipped
        trees - number of trees passed to AnnoyIndex.build
        dist  - metric name passed to AnnoyIndex

        Returns (tree, reverse) where reverse maps the integer item id
        used inside the tree back to the dictionary key

        Raises ValueError when no entry has a vector
        '''

        # Cleaning index: dropping the entries without features
        index = {key: vector for key, vector in index.items() if vector is not None}

        # nothing to infer the dimensionality from, failing with a clear message
        if not index:
            raise ValueError('cannot build an index: no entry has a feature vector')

        # Extracting dimensionality of data from the first vector
        index_dim = len(next(iter(index.values())))

        # Initializing trees
        result_tree = AnnoyIndex(index_dim, dist)
        result_reverse = {}

        # Inserting items
        for idx, (key, vector) in enumerate(index.items()):
            result_reverse[idx] = key
            result_tree.add_item(idx, vector)

        # Building trees
        result_tree.build(trees)

        return result_tree, result_reverse

    def find_similar(self, text_dict, img_dict, top_n=5):
        '''
        Rank the indexed manga titles by their distance to the query pages

        text_dict - dict {query page: text feature vector or None}
        img_dict  - dict {query page: visual feature vector or None}
        top_n     - how many titles to return

        Returns a list of at most top_n (manga_name, distance) tuples sorted
        by ascending distance; titles whose combined distance is not above
        zero are left out (we don't want to recommend the same manga)
        '''

        # smallest distance to every title, separately in the space of
        # textual and visual features; top_n*5 neighbours are requested per page
        text_predictions = self._closest_per_title(self.word_index, self.word_reverse, text_dict, top_n * 5)
        img_predictions  = self._closest_per_title(self.img_index, self.img_reverse, img_dict, top_n * 5)

        # going through image predictions and populating overall score;
        # titles that were only reached through text are not ranked
        predictions = dict()
        for manga_name, dist in img_predictions.items():

            # the overall score is the product of the text and visual distances;
            # when no text distance exists for the title the visual distance
            # stands in for it (only thing we can trust here), i.e. dist * dist
            text_dist = text_predictions.get(manga_name, dist)
            predictions[manga_name] = text_dist * dist

        # getting the list of items (manga_name, distance)
        # and filtering out all zeros in distances
        predictions_list = [name_dist for name_dist in predictions.items() if name_dist[1] > 0]

        # sorting the list of predictions by the distance
        predictions_list.sort(key=lambda pred: pred[1])

        # retrieving the top n predictions
        return predictions_list[:top_n]

    def _closest_per_title(self, tree, reverse, features, n_neighbours):
        '''Return {manga_name: smallest distance} over the neighbours of all query vectors'''

        closest = dict()
        for feature in features.values():

            # skipping entries that do not contain features
            if feature is None:
                continue

            # Getting the nearest items together with their distances
            items, dists = tree.get_nns_by_vector(feature, n_neighbours, include_distances=True)

            for item, dist in zip(items, dists):

                # getting the name of manga
                manga_name = self.extract_name(reverse[item])

                # keeping the smallest distance seen so far for that title
                if manga_name not in closest or closest[manga_name] > dist:
                    closest[manga_name] = dist

        return closest

    def extract_name(self, name):
        '''Return the second and third '/'-separated components of a path joined by '/' '''
        return '/'.join(name.split('/')[1:3])

In [30]:
# build an index over the text and visual embeddings of the indexing data
index = AnnoyIdx(index_text_embs, index_visual_embs)

# generate top picks: up to ten (manga_name, distance) pairs for the query
# pages, printed one per line in ascending order of distance
print('Top picks for DeathNote')
print(*index.find_similar(query_text_embs, query_visual_embs, top_n=10), sep='\n')

Top picks for DeathNote
('ужасы/ТетрадьсмертиЭкспериментальная', 0.1533953723794319)
('научная фантастика/Стальнойалхимик', 0.16293483838236966)
('приключения/Блич', 0.1889845417768221)
('научная фантастика/Отнимайилиотнимутутебя', 0.2084996269359749)
('ужасы/Игралжецов', 0.2106762259842725)
('детектив/Тёмныйдворецкий', 0.2154640968627053)
('боевик/Клинокрассекающийдемонов', 0.21556596992114407)
('махо-сёдзё/ЧараХранители', 0.22230610087026648)
('сверхъестественное/Токийскаястоличнаямагическотехническаяшкола', 0.22382569921371065)
('детектив/НевероятныеПриключенияДжоДжоЧасть7SteelBallRun', 0.23139027879023022)
