In [None]:
import os
import sys
sys.path.insert(0, '../')
import time
import json
import pandas as pd

from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Queries, Collection
from colbert import Indexer, Searcher

In [None]:
# Semicolon-separated source tables: the model catalogue and the test offers.
pth_models = "/home/sondors/Documents/price/ColBERT_data/18_categories/test/models_18_categories.csv"
pth_offers = "/home/sondors/Documents/price/ColBERT_data/18_categories/test/triplets_test_18_categories.csv"

# Read each table and immediately discard the columns that are dropped here.
df_models = pd.read_csv(pth_models, sep=";").drop(columns=['average_price', 'comment'])
df_offers = pd.read_csv(pth_offers, sep=";").drop(columns=['true_match', 'false_match'])

# Last expression of the cell: rendered as the cell output.
df_offers

In [None]:
# Keep at most `thresh_upper` randomly chosen offers per category_id.
thresh_upper = 10
# Fixed seed so that "Restart Kernel -> Run All" selects the same offers every time.
sample_seed = 42

# One frame per category_id (groupby iterates the groups in sorted key order).
# Sampling min(len, cap) rows without replacement is equivalent to shuffling the
# group and slicing the first `thresh_upper` rows; original index labels are kept.
# Optional de-duplication before sampling:
#   group.drop_duplicates(subset='name', keep="last")
ans = [
    group.sample(n=min(len(group), thresh_upper), random_state=sample_seed)
    for _, group in df_offers.groupby('category_id')
]
df_new = pd.concat(ans, axis=0)
df_new

In [None]:
# Display the model table loaded from the models CSV (after 'average_price' and 'comment' were dropped).
df_models

In [None]:
# Category names to evaluate. Each string is compared for equality against the
# `category_name` column of df_models and df_new, so it must match the data exactly.
# NOTE(review): 14 names are listed while the data paths mention 18 categories —
# confirm the remaining ones are left out on purpose.
categories = [
    "диктофоны, портативные рекордеры",
    "электронные книги",
    "автомобильные телевизоры, мониторы",
    "смарт-часы и браслеты",
    "портативные медиаплееры",
    "чехлы, обложки для гаджетов (телефонов, планшетов etc)",
    "портативная акустика",
    "мобильные телефоны",
    "VR-гарнитуры (VR-очки, шлемы, очки виртуальной реальности, FPV очки для квадрокоптеров)",
    "планшетные компьютеры и мини-планшеты",
    "наушники, гарнитуры, наушники c микрофоном",
    "радиоприемники, радиобудильники, радиочасы",
    "магнитолы",
    "GPS-навигаторы"
    ]

def search(checkpoint, offers, models, nbits, doc_maxlen, k=5):
    """Index the model collection with ColBERT and retrieve the top-k models per offer.

    Args:
        checkpoint: ColBERT checkpoint handed to ``Indexer``.
        offers: Path to the queries TSV file, loaded through ``Queries``.
        models: Path to the collection TSV file, loaded through ``Collection``.
        nbits: Forwarded to ``ColBERTConfig(nbits=...)``; also part of the index name.
        doc_maxlen: Forwarded to ``ColBERTConfig(doc_maxlen=...)``.
        k: Number of results requested per query from ``search_all``.
            Defaults to 5, the value that used to be hard-coded.

    Returns:
        Whatever ``searcher.search_all(...).todict()`` returns for the loaded queries.
    """
    index_name = f'models.18_categories.{nbits}bits'

    # Separate local names so the path arguments are not rebound to loader objects.
    queries = Queries(path=offers)
    collection = Collection(path=models)
    # Printed explicitly: a bare f-string expression inside a function is evaluated
    # and thrown away, so the message would never be shown.
    print(f'Loaded {len(queries)} queries and {len(collection):,} passages')

    with Run().context(RunConfig(nranks=1, experiment='notebook')):  # nranks specifies the number of GPUs to use.
        config = ColBERTConfig(doc_maxlen=doc_maxlen, nbits=nbits)

        indexer = Indexer(checkpoint=checkpoint, config=config)
        # overwrite=True: an existing index with the same name is rebuilt on every call.
        indexer.index(name=index_name, collection=collection, overwrite=True)
    indexer.get_index()  # You can get the absolute path of the index, if needed (return value unused here).

    with Run().context(RunConfig(experiment='notebook')):
        searcher = Searcher(index=index_name)

    # perf_counter is monotonic, so the measured duration cannot be skewed by clock adjustments.
    start_time = time.perf_counter()
    rankings = searcher.search_all(queries, k=k).todict()
    print(f"time_spent = {time.perf_counter() - start_time}\n")
    return rankings

def ranking_index(rankings, category_rankings, df, index_of_first):
    """Store per-category results in ``rankings`` under the original offer index labels.

    The result at position 0, 1, 2, ... of ``category_rankings`` is assigned to the
    index label of the row at the same position in ``df``.

    Args:
        rankings: Dict that is updated in place and also returned.
        category_rankings: Results addressable by 0-based position
            (``category_rankings[0]``, ``category_rankings[1]``, ...).
        df: Frame whose index labels become the keys in ``rankings``.
        index_of_first: Not used by this function.

    Returns:
        The same ``rankings`` object that was passed in.
    """
    for position, original_label in enumerate(df.index):
        rankings[original_label] = category_rankings[position]
    return rankings

def df_split(df, col="name"):
    """Split a frame into an ``(id, col)`` frame and an ``(id, model_id)`` frame.

    The input index is discarded and replaced by consecutive ids 0..len(df)-1,
    which are stored in the ``id`` column of both results.

    Args:
        df: Frame that contains ``col`` and ``model_id`` columns.
        col: Name of the text column to put next to ``id`` in the first frame.

    Returns:
        Tuple ``(text_frame, model_frame)`` with columns ``["id", col]`` and
        ``["id", "model_id"]`` respectively.
    """
    renumbered = df.reset_index(drop=True)
    row_ids = list(range(len(renumbered)))
    text_values = renumbered[col]
    model_values = renumbered['model_id']

    text_frame = pd.DataFrame()
    text_frame["id"] = row_ids
    text_frame[col] = text_values

    model_frame = pd.DataFrame()
    model_frame["id"] = row_ids
    model_frame["model_id"] = model_values
    return text_frame, model_frame

def prepare_tsv(category_offers, category_models, pth_offers, pth_models):
    """Write the offers and models of one category as header-less tab-separated files.

    Each output row is ``id<TAB>text``: ids are the 0-based positions produced by
    ``df_split``, and the text is the ``name`` column for offers and the
    ``full_name`` column for models.

    Args:
        category_offers: Offers frame with ``name`` and ``model_id`` columns.
        category_models: Models frame with ``full_name`` and ``model_id`` columns.
        pth_offers: Destination path of the offers (queries) file.
        pth_models: Destination path of the models (documents) file.
    """
    exports = (
        (category_offers, "name", pth_offers),
        (category_models, "full_name", pth_models),
    )
    for frame, text_column, destination in exports:
        # The second frame returned by df_split (id -> model_id) is not needed here.
        text_frame, _ = df_split(frame, col=text_column)
        text_frame.to_csv(destination, sep='\t', header=False, index=False)
    
# Run configuration: scratch folder, ColBERT checkpoint, output file and index settings.
tmp_fld = "/home/sondors/Documents/price/ColBERT/tmp"
ckpt_pth = "/home/sondors/HYPERPARAM/none/2024-01/10/08.58.24/checkpoints/colbert-9555"
pth_dst_json = "/mnt/vdb1/Datasets/18_categories/metrics_data/triples_X1_5epochs.json"
doc_maxlen = 300
nbits = 2   # encode each dimension with 2 bits

# The TSV exports are written into tmp_fld; create it up front so that to_csv
# does not fail when the folder does not exist yet.
os.makedirs(tmp_fld, exist_ok=True)

# The same two scratch files are overwritten for every category, so the paths are
# built once. NOTE: from here on these names no longer hold the source CSV paths.
pth_models = f'{tmp_fld}/models.tsv'
pth_offers = f'{tmp_fld}/offers.tsv'

# Maps the index label of every evaluated offer row to its search results.
rankings = {}
for category in categories[:1]:  # [:1] -> only the first category is processed
    print(category)

    category_models = df_models[df_models['category_name'] == category]
    category_offers = df_new[df_new['category_name'] == category]

    # Without models there is nothing to index, and without offers nothing to search.
    if category_models.empty or category_offers.empty:
        print(f"skipped: {len(category_models)} models, {len(category_offers)} offers")
        continue

    # Index label of the first model row of this category.
    index_of_first = category_models.index[0]

    prepare_tsv(category_offers, category_models, pth_offers, pth_models)

    category_rankings = search(ckpt_pth, pth_offers, pth_models, nbits, doc_maxlen)
    rankings = ranking_index(rankings, category_rankings, category_offers, index_of_first)

# NOTE(review): the keys of `rankings` come from a pandas index and may be numpy
# integers, which json.dump does not accept as dict keys — verify before enabling.
# with open(pth_dst_json, 'w') as fp:
#     json.dump(rankings, fp)