In [1]:
%cd ../

/home/chervovn04/Programming/hackathons/2022/agrocode


In [26]:
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import string
import os
import matplotlib.pyplot as plt
from copy import deepcopy
from PIL import Image
from collections import Counter, OrderedDict
%matplotlib inline

from glob import glob
from sklearn.metrics import *
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.applications.efficientnet import EfficientNetB0, EfficientNetB1, preprocess_input
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from model_structure import *

In [34]:
class SiameseModel(nn.Module):
    def __init__(self, extractor):
        super().__init__()
        self.extractor = extractor
    def forward(self, image_0, image_1):
        return self.extractor(image_0), self.extractor(image_1)

In [3]:
import pymorphy2
morph = pymorphy2.MorphAnalyzer()

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [5]:
def AP(relevance):
    Ps = []
    count = 0
    for i, val in enumerate(relevance):
        i += 1
        if val:
            count += 1
            Ps.append(count / i)
    if not Ps:
        Ps = [0]
    return sum(Ps) / len(Ps)

def mAP(relevances):
    return sum([AP(relevance) for relevance in relevances]) / len(relevances)

In [6]:
def delete_double_spaces(nm):
    new_nm = ''
    for char in nm:
        if char != ' ' or (len(new_nm) and new_nm[-1] != ' '):
            new_nm += char
    return new_nm.strip()

def process(nm):
    # 1. to lower
    nm = nm.lower()
    
    # 2. delete useless sequences
    for todel in [';', '&quot']:
        nm = nm.replace(todel, '')
        
    # 3. delete text in parenthesis
    new_nm = ''
    balance = 0
    for char in nm:
        if char == '(':
            balance += 1
        elif char == ')':
            balance = max(0, balance - 1)
        elif balance == 0:
            new_nm += char
    nm = new_nm
    
    # 4. only russian symbols
    new_nm = ''
    for char in nm:
        if char in 'йцукенгшщзхъфывапролджэячсмитьбю ':
            new_nm += char
    nm = delete_double_spaces(new_nm)
    
    # 6. delete useless "words"
    # 7. convert every word to the origin form 
    
    new_nm = ''
    black_list = ['х', 'хх', 'ххх', 'мм', 'с', 'км', 'от', 'в', 'т']
    for word in nm.split(' '):
        if word not in black_list and len(word) > 2:
            word = morph.parse(word)[0].normal_form 
            new_nm += word + ' '
    nm = delete_double_spaces(new_nm)
    
    return nm

In [7]:
def extract_features(path, models_data):
    with torch.no_grad():
        features = []
        for model, tfm in models_data:
            image = Image.open(path).convert('RGB')
            image = tfm(image)

            image = image[None, :, :, :]
                        
            pred = np.array(model(image).detach()).squeeze()

            features.append(pred)
        features = np.concatenate(features, axis=0)
        features /= np.linalg.norm(features)
    return features

In [8]:
def soft_selection():
    pass

def hard_selection():
    pass

def get_results(db, que, db_path, que_path):
    db_emb = []
    for idx in tqdm(test.idx):
        db_emb.append(extract_features(f'{db_path}{idx}.png', models_data))
    db_emb = np.array(db_emb)

    que_emb = []
    for idx in tqdm(queries.idx):
        que_emb.append(extract_features(f'{que_path}{idx}.png', models_data))
    que_emb = np.array(que_emb)


In [9]:
def get_emb(data, models_data, base_dir='data/train/'):
    emb = []
    for idx in tqdm(data.idx):
        emb.append(extract_features(f'{base_dir}{idx}.png', models_data))
    emb = np.array(emb)
    return emb

In [11]:
data = pd.read_csv('data/train.csv')
db, que = train_test_split(data, train_size=0.8)

In [12]:
class ViT_H_Wrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        x = self.model(x)
        return x.pooler_output
    
def get_vit_transformer_model(path):
    vit = torch.load(path)
    vit.eval()
    model = ViT_H_Wrapper(vit)
    model.eval()
    return model

def get_vit_model(path):
    model = torch.load(path)
    model.eval()
    model.fc = Identical()
    return model

models_data = [
    (get_vit_model('weights/vit_b16.pt'), transforms.Compose([
        transforms.Resize(224), 
        transforms.CenterCrop(224), 
        transforms.ToTensor(), 
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])),
]

In [14]:
db_emb = get_emb(db, models_data)
que_emb = get_emb(que, models_data)

100%|███████████████████████████████████████| 3900/3900 [20:37<00:00,  3.15it/s]
100%|█████████████████████████████████████████| 976/976 [05:55<00:00,  2.74it/s]


In [87]:
image_similarity_model = torch.load('weights/image_similarity.pt', map_location=torch.device('cpu'))
image_similarity_model.eval()
None

In [83]:
def d(x, y):
    return nn.functional.pairwise_distance(x, y)

def evaluate(db, que, db_emb, que_emb, db_dir, que_dir, final_model, transform, need_processing=0):
    with torch.no_grad():
        first_selection = 20

        neigh = NearestNeighbors(n_neighbors=first_selection, metric='cosine')
        neigh.fit(db_emb)

        distances, idxs = neigh.kneighbors(que_emb, first_selection, return_distance=True)

        relevances = []
        for i in tqdm(range(idxs.shape[0])):
            image_0 = Image.open(f'{que_dir}{que.idx.iloc[i]}.png').convert("RGB")
            image_0 = transform(image_0)[None, :, :, :]

            for j in range(idxs.shape[1]):
                image_1 = Image.open(f'{db_dir}{db.idx.iloc[idxs[i][j]]}.png').convert("RGB")
                image_1 = transform(image_1)[None, :, :, :]

                distances[i][j] = torch.sigmoid(1 - d(*final_model(image_0, image_1)))

            order = np.argsort(distances[i])[::-1]
            order = order[:10]

            name = que.item_nm.iloc[i]
            que_rec = idxs[i][order]

            relevance = []
            for idx in que_rec:
                relevance.append(name == db.item_nm.iloc[idx])
            relevances.append(relevance)
        return mAP(relevances)

In [84]:
transform = transforms.Compose([
        transforms.Resize(224), 
        transforms.CenterCrop(224), 
        transforms.ToTensor(), 
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

evaluate(db, que, db_emb, que_emb, 'data/train/', 'data/train/', image_similarity_model, transform, image_similarity_model)

100%|█████████████████████████████████████████| 976/976 [22:28<00:00,  1.38s/it]


0.08787053493984114

In [85]:
def evaluate2(db, que, db_emb, que_emb, db_dir, que_dir, final_model, transform, need_processing=0):
    with torch.no_grad():
        first_selection = 10

        neigh = NearestNeighbors(n_neighbors=first_selection, metric='cosine')
        neigh.fit(db_emb)

        distances, idxs = neigh.kneighbors(que_emb, first_selection, return_distance=True)

        relevances = []
        for i in tqdm(range(idxs.shape[0])):
            name = que.item_nm.iloc[i]
            
            que_rec = idxs[i]
            relevance = []
            for idx in que_rec:
                relevance.append(name == db.item_nm.iloc[idx])
            relevances.append(relevance)
        return mAP(relevances)

In [86]:
evaluate2(db, que, db_emb, que_emb, 'data/train/', 'data/train/', image_similarity_model, transform, image_similarity_model)

100%|███████████████████████████████████████| 976/976 [00:00<00:00, 9006.69it/s]


0.07869382421115081