In [1]:
%cd ../

/home/chervovn04/Programming/hackathons/2022/agrocode


In [2]:
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import string
import os
import matplotlib.pyplot as plt
from copy import deepcopy
from PIL import Image
from collections import Counter, OrderedDict
%matplotlib inline
%matplotlib widget

from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors

import timm
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from model_structure import *

In [3]:
import pymorphy2
morph = pymorphy2.MorphAnalyzer()

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Using {device} for inference')

Using cpu for inference


In [5]:
def AP(relevance):
    Ps = []
    count = 0
    for i, val in enumerate(relevance):
        i += 1
        if val:
            count += 1
            Ps.append(count / i)
    if not Ps:
        Ps = [0]
    return sum(Ps) / len(Ps)

def mAP(relevances):
    return sum([AP(relevance) for relevance in relevances]) / len(relevances)

In [16]:
def extract_features(path, models_data):    
    with torch.no_grad():
        features = []
        for model, tfm, coef in models_data:
            image = Image.open(path).convert('RGB')
            image = tfm(image)

            image = image[None, :, :, :].to(device)
                        
            pred = np.array(model(image).detach().cpu()).squeeze() * coef
            # pred = np.random.rand(100) * coef
            features.append(pred)
            
        features = np.concatenate(features, axis=0)
        features /= np.linalg.norm(features)
    return features

In [7]:
 def get_emb(data, models_data, base_dir):
    emb = []
    for idx in tqdm(data.idx):
        emb.append(extract_features(f'{base_dir}{idx}.png', models_data))
    emb = np.array(emb)
    return emb

In [8]:
data = pd.read_csv('data/train.csv')
db, que = train_test_split(data, train_size=0.8, random_state=42)

In [9]:
def model_head_fc(path):
    model = torch.load(path, map_location=torch.device(device))
    model.head.fc = Identical()
    return model

def model_head(path):
    model = torch.load(path, map_location=torch.device(device))
    model.head = Identical()
    return model

def model_fc(path):
    model = torch.load(path, map_location=torch.device(device))
    model.fc = Identical()
    return model

In [10]:
models_data = []

model = model_head('weights/beit_finetuned_9.pth')
transform = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ])
models_data.append((model, transform, 1))

for model, transform, coef in models_data:
    model.eval()
    model = model.to(device)

In [18]:
def delete_double_spaces(nm):
    new_nm = ''
    for char in nm:
        if char != ' ' or (len(new_nm) and new_nm[-1] != ' '):
            new_nm += char
    return new_nm.strip()

def label_process(nm):
    # 1. to lower
    nm = nm.lower()
    
    # 2. delete useless sequences
    for todel in [';', '&quot']:
        nm = nm.replace(todel, '')
        
    # 3. delete text in parenthesis
    new_nm = ''
    balance = 0
    for char in nm:
        if char == '(':
            balance += 1
        elif char == ')':
            balance = max(0, balance - 1)
        elif balance == 0:
            new_nm += char
    nm = new_nm
    
    # 4. only russian symbols
    new_nm = ''
    for char in nm:
        if char in 'йцукенгшщзхъфывапролджэячсмитьбю ':
            new_nm += char
    nm = delete_double_spaces(new_nm)
    
    # 6. delete useless "words"
    # 7. convert every word to the origin form 
    
    new_nm = ''
    black_list = ['х', 'хх', 'ххх', 'мм', 'с', 'км', 'от', 'в', 'т']
    for word in nm.split(' '):
        if word not in black_list and len(word) > 2:
            word = morph.parse(word)[0].normal_form 
            new_nm += word + ' '
    nm = delete_double_spaces(new_nm)
    
    return nm


In [19]:
class LabelInfo:
    def __init__(self, label):
        self.label = label
        self.ids = []
        self.extra_embs = []
        self.embs = []
        self.mean_emb = None
        self.neigh = None
    
    def add(self, idx, emb):
        self.ids.append(idx)
        self.embs.append(emb)
    
    def process(self):
        for i in range(len(self.embs)):
            self.embs[i] = self.embs[i][None, :]
        self.embs = np.concatenate(self.embs, axis=0)       
        
        self.mean_emb = np.median(self.embs, axis=0)
        
        self.neigh = NearestNeighbors(n_neighbors=min(10, self.embs.shape[0]), metric='cosine')        
        self.neigh.fit(self.embs)
        
        return self.mean_emb
        
    def get_best_ids(self, emb, k):
        if not k:
            return np.array([]), np.array([], dtype=int)
        k = min(k, self.embs.shape[0])
        distances, idxs = self.neigh.kneighbors(emb[None, :], k, return_distance=True)
        distances = distances[0]
        idxs = idxs[0]
        for i in range(k):
            idxs[i] = self.ids[idxs[i]]
            distances[i] += k
        return distances, idxs

        
def get_result(db, que, models_data, db_dir='data/test/', que_dir='data/queries'):
    label2id = {}
    id2label = []
    
    label_info = []
    
    db_emb = get_emb(db, models_data, db_dir)
    que_emb = get_emb(que, models_data, que_dir)
    
    for (_, row), emb in zip(db.iterrows(), db_emb):
        idx = row.idx
        label = row.item_nm
        
        label = label_process(label)
        
        if label not in label2id:
            label2id[label] = len(id2label)
            id2label.append(label)
            label_info.append(LabelInfo(label))
        label_info[label2id[label]].add(idx, emb)
    
    new_label2id = {}
    new_id2label = []
    new_label_info = []
    
    for i in range(len(label_info)):
        if len(label_info[i].embs) > 3:
            new_id2label.append(label_info[i].label)
            new_label2id[label_info[i].label] = len(new_label2id)
            new_label_info.append(label_info[i])
    
    label2id = new_label2id
    id2label = new_id2label
    label_info = new_label_info
    
    general_emb = []
    for i in range(len(label_info)):
        general_emb.append(label_info[i].process())
    general_emb = np.array(general_emb)
        
    general_neigh = NearestNeighbors(n_neighbors=10, metric='cosine')
    general_neigh.fit(general_emb)
    
    folder_ids = general_neigh.kneighbors(que_emb, 10, return_distance=False)
    
    result = []
    for i, que_idx in enumerate(que.idx):
        ids = folder_ids[i]
        distances = np.array([])
        db_ids = np.array([], dtype=int)
        for label_id in ids:
            res = label_info[label_id].get_best_ids(que_emb[i], 10 - distances.shape[0])
            distances = np.concatenate([distances, res[0]])
            db_ids = np.concatenate([db_ids, res[1]])
        for j in range(10):
            result.append((que_idx, db_ids[j], distances[j]))
    
    return result

In [20]:
def evaluate(result, db, que):
    get_label_db = {}
    get_label_que = {}
    for _, row in db.iterrows():
        idx = row.idx
        label = row.item_nm
        get_label_db[idx] = label
    for _, row in que.iterrows():
        idx = row.idx
        label = row.item_nm
        get_label_que[idx] = label
    
    relevant_db = {}
    for que_idx, db_idx, distance in result:
        if que_idx not in relevant_db:
            relevant_db[que_idx] = []
        relevant_db[que_idx].append((db_idx, distance))
    relevances = []
        
    for key, value in relevant_db.items():
        value = sorted(value, key=lambda x : x[1], reverse=True)
        value = [item[0] for item in value]
        que_label = get_label_que[key]
        relevance = [que_label == get_label_db[db_idx] for db_idx in value]
        relevances.append(relevance)
    
    return mAP(relevances)

In [17]:
result = get_result(db, que, models_data, 'data/train/', 'data/train/')

  0%|▏                                        | 12/3900 [00:07<41:05,  1.58it/s]


KeyboardInterrupt: 

In [48]:
evaluate(result, db, que)

0.005441294326625612