In [1]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Add the parent directory to the import path; presumably this is where the
# `csi` package imported below lives (verify against the repo layout).
sys.path.append("../")

In [3]:
def init_cfg(cfg_path, device="cuda:2"):
    """Compose a saved Hydra config and patch it for fold-wise inference.

    Args:
        cfg_path: Path to a ``config.yaml`` file. Its parent directory is
            passed to Hydra as the config directory, and the name of the
            grandparent directory is used to build the checkpoints folder.
        device: Value stored in ``cfg["environment"]["device"]``. Defaults to
            ``"cuda:2"``, the value that was previously hard-coded.

    Returns:
        The composed config with ``path_to_fold_checkpoints`` set to
        ``"artifacts_<run name>/model_checkpoints"``,
        ``read_filtered_clique2versions`` set to ``None`` and the device
        overridden.
    """
    from hydra import compose, initialize

    config_file = Path(cfg_path)

    with initialize(version_base=None, config_path=str(config_file.parent)):
        try:
            # `+key=value` adds a key that is not in the config yet.
            cfg = compose(config_name="config", overrides=["+read_filtered_clique2versions=null"])
        except Exception:
            # Retry with a plain override, used when the key already exists.
            # `Exception` (not a bare `except`) so Ctrl-C / SystemExit still
            # propagate instead of triggering a second compose.
            cfg = compose(config_name="config", overrides=["read_filtered_clique2versions=null"])

    run_name = config_file.parent.parent.name
    cfg["path_to_fold_checkpoints"] = f"artifacts_{run_name}/model_checkpoints"
    cfg["read_filtered_clique2versions"] = None
    cfg["environment"]["device"] = device
    return cfg

In [4]:
import gc
import logging

import hydra
import numpy as np
import torch
import torchinfo
from dotenv import load_dotenv
from hydra.utils import call, instantiate
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import GradScaler

from csi.base.utils import init_model, seed_everything
from csi.submission import make_submission
from csi.training.data.dataset import filter_tracks
from csi.training.loop.loop import train_one_epoch
from csi.training.loop.utils import (
    clean_old_content,
    freeze_layers,
    load_fold_checkpoint,
    save_checkpoint,
)
from csi.training.metrics.ndcg import compute_ndcg
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from csi.base.model.predict import predict
from csi.base.utils import batch_to_device
from csi.training.loop.utils import split_by_batch_size
from tqdm import tqdm
from collections import defaultdict

logger = logging.getLogger(__name__)
# Replace the logger's `info` method with `print`, so `logger.info(...)`
# calls in this notebook write straight to the cell output.
logger.info = print

from tqdm import tqdm
import pandas as pd

In [5]:
def make_fold_submission(cfg, model, test_loader, candidates_from_folds=500, mini_batch_size=5000):
    """Embed every test track with ``model`` and collect its nearest neighbours.

    Args:
        cfg: Config object; reads ``cfg.environment.device`` and
            ``cfg.nearest_neighbors_search.normalize_embeddings``.
        model: Model passed to ``predict``.
        test_loader: Iterable of batches. ``predict`` must return a mapping
            with ``"embedding"`` and ``"track_id"`` tensors for each batch.
        candidates_from_folds: Number of neighbours kept per track
            (default 500, the previously hard-coded value). It is capped at
            ``number of tracks - 1``.
        mini_batch_size: Number of query rows per similarity matrix chunk
            (default 5000, the previously hard-coded value).

    Returns:
        ``{track_id: [(neighbour_track_id, similarity), ...]}`` with the
        neighbours ordered by decreasing dot-product similarity and the query
        track itself excluded.
    """
    embeddings = []
    track_ids = []
    # Inference only: avoid keeping an autograd graph alive for every stored
    # batch of embeddings (harmless if `predict` already disables gradients).
    with torch.no_grad():
        for batch in test_loader:
            batch = batch_to_device(batch, cfg.environment.device)
            outs = predict(model, batch)
            track_ids.append(outs["track_id"].reshape(-1, 1))
            embeddings.append(outs["embedding"])

    embeddings = torch.vstack(embeddings).detach().cpu().numpy()
    track_ids = torch.vstack(track_ids).detach().cpu().numpy().flatten()

    if cfg.nearest_neighbors_search.normalize_embeddings:
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    n_tracks = len(embeddings)
    # A track can have at most n_tracks - 1 neighbours other than itself.
    k = max(0, min(candidates_from_folds, n_tracks - 1))

    res = {}
    embeddings_t = embeddings.T
    for ind in tqdm(split_by_batch_size(np.arange(n_tracks), mini_batch_size)):
        ind = np.asarray(ind)
        similarities = np.dot(embeddings[ind], embeddings_t)
        # Exclude each query from its own neighbour list by forcing its
        # self-similarity to the bottom of the ranking. The previous approach
        # (take K+1, then drop entries equal to the query index through a
        # flattened boolean mask) only worked when the query was guaranteed to
        # be among its own top K+1; ties, duplicates or un-normalised
        # embeddings broke the reshape or misaligned rows.
        similarities[np.arange(len(ind)), ind] = -np.inf
        top_k_indices = np.argsort(-similarities, axis=1)[:, :k]
        top_tracks_similarities = np.take_along_axis(similarities, top_k_indices, axis=1)
        top_tracks = track_ids[top_k_indices]
        for track_id, tracks, sims in zip(track_ids[ind], top_tracks, top_tracks_similarities):
            res[int(track_id)] = list(zip(tracks, sims))
    return res

In [7]:
# Load the Hydra config saved with the 6-fold run; init_cfg also derives the
# checkpoints folder from this path's grandparent directory name.
cfg = init_cfg("final_artifacts/hgnetv2_b5_metric_learning_big_margin_drop_cliques_test_0_6_6folds02_19_23/hydra/config.yaml")

In [8]:
# Run every fold's model over the test set and collect, per fold, the
# neighbour candidates returned by make_fold_submission.
seed_everything(cfg.environment.seed)

logger.info("Reading clique2tracks")
clique2tracks = call(
    cfg.read_clique2versions,
    _convert_="partial",
)

# NOTE(review): init_cfg sets cfg.read_filtered_clique2versions to None, so
# this presumably yields None and the filter branch below is skipped — confirm.
filtered_clique2tracks = call(
    cfg.read_filtered_clique2versions,
    _convert_="partial",
)

cliques_splits = call(cfg.split_cliques, clique2tracks, _convert_="partial")

# The test dataset is built without clique labels (both mappings are None).
test_dataset = call(
    cfg.test_data.dataset,
    tracks_ids=np.load(cfg.test_data.test_ids_path),
    track2clique=None,
    clique2tracks=None,
    _convert_="partial",
)
test_loader = instantiate(cfg.test_data.dataloader, test_dataset, _convert_="partial")


fold_results = []
# The train/val mappings are unpacked only to enumerate the folds; apart from
# the optional filter_tracks call they are not used in this cell.
for fold, (
    (train_track2clique, train_clique2tracks),
    (val_track2clique, val_clique2tracks),
) in enumerate(cliques_splits):
    if filtered_clique2tracks is not None:
        train_track2clique, train_clique2tracks = filter_tracks(
            filtered_clique2tracks, train_clique2tracks
        )

    # A fresh model is created for every fold and moved to the configured device.
    model = init_model(cfg).to(cfg.environment.device)

    if cfg.path_to_fold_checkpoints is not None:
        model = load_fold_checkpoint(model, cfg.path_to_fold_checkpoints, fold)

    if cfg.freeze_backbone_num_layers is not None:
        model = freeze_layers(model, cfg.freeze_backbone_num_layers)

    # One dict per fold: track id -> list of (candidate track id, similarity).
    fold_results.append(
        make_fold_submission(cfg, model, test_loader)
    )

Reading clique2tracks


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:41<00:00,  8.47s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:38<00:00,  8.21s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:30<00:00,  7.55s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:39<00:00,  8.26s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:3

In [9]:
def reduce_by_folds(fold_top_tracks, top_k=100):
    """Merge per-fold neighbour lists into one ranking per query track.

    For every query track the similarities of each candidate are summed over
    all folds and divided by the total number of folds, so a candidate that
    is missing from a fold contributes 0 for that fold.

    Args:
        fold_top_tracks: Sequence with one mapping per fold of the form
            ``{track_id: [(candidate_track_id, similarity), ...]}``.
        top_k: Number of best candidates kept per query in the first return
            value (default 100, the previously hard-coded value). ``None``
            keeps all candidates.

    Returns:
        A tuple ``(res, candidates_with_adjusted_score)`` where ``res`` maps
        each query track id to its candidate ids ordered by decreasing
        averaged score, and ``candidates_with_adjusted_score`` maps each
        query track id to ``{candidate_track_id: averaged_score}``.
    """
    n_folds = len(fold_top_tracks)

    # Sum the similarity each fold assigned to every (query, candidate) pair.
    scores = {}
    for fold in fold_top_tracks:
        for track_id, top_tracks in fold.items():
            per_query = scores.setdefault(int(track_id), defaultdict(float))
            for recommended_track, score in top_tracks:
                per_query[int(recommended_track)] += float(score)

    res = {}
    candidates_with_adjusted_score = {}
    for track_id, candidates_with_score in scores.items():
        # Kept as a defaultdict so lookups of unseen candidates still give 0.
        averaged = defaultdict(int)
        for recommended_track, total_score in candidates_with_score.items():
            averaged[recommended_track] = total_score / n_folds
        candidates_with_adjusted_score[track_id] = averaged

        # sorted() is stable, so ties keep their first-seen order.
        ranked = sorted(averaged.items(), key=lambda item: -item[1])
        res[track_id] = [candidate for candidate, _ in ranked[:top_k]]

    return res, candidates_with_adjusted_score

In [10]:
# final_res: ranked candidate ids per track; sims: fold-averaged similarity
# per (track, candidate) pair.
final_res, sims = reduce_by_folds(fold_results)

In [50]:
# Count every (track, candidate) pair held in `sims`, then keep the pairs
# whose averaged similarity exceeds the 0.58 threshold.
total_pairs = sum(len(top_tracks) for top_tracks in sims.values())
good_pairs = [
    (track_i, track_j)
    for track_i, top_tracks in sims.items()
    for track_j, score in top_tracks.items()
    if score > 0.58
]

In [51]:
# Number of (track, candidate) pairs inspected while thresholding.
total_pairs

79375131

In [52]:
import networkx as nx

# Treat every above-threshold pair as an edge; each connected component of
# the resulting graph becomes one new clique.
G = nx.Graph()
G.add_edges_from(good_pairs)
connected_components = list(nx.connected_components(G))

# Hard-coded offsets: added to every track id / used as the first new clique id.
track_offset = 370110
first_new_clique = 41617

new_cliques = {
    first_new_clique + position: [track + track_offset for track in component]
    for position, component in enumerate(connected_components)
}
# Keep `clique_offset` pointing one past the last id that was assigned.
clique_offset = first_new_clique + len(connected_components)

In [53]:
# One row per new clique: its id and the list of track ids assigned to it.
add_df = pd.DataFrame(list(new_cliques.items()), columns=["clique", "versions"])

In [54]:
# Number of new cliques produced by the connected-components step.
len(add_df)

5006

In [55]:
# Size of the largest new clique (length of the longest `versions` list).
add_df["versions"].str.len().max()

np.int64(1292)

In [57]:
import ast

# Load the existing clique table; the `versions` column is stored as a Python
# list literal, so it is parsed back into a list with ast.literal_eval.
old_df = pd.read_csv(
    "/home/yskhnykov/yandex_cup/data/raw/cliques2versions_drop_cliques.tsv", 
    sep="\t", 
    converters={"versions": ast.literal_eval}
)

In [58]:
# Preview the existing clique table.
old_df

Unnamed: 0,clique,versions
0,39475,"[343223, 361210, 114472, 134744, 271362, 30747..."
1,20077,"[343224, 350590, 170706, 266043, 314556, 30764..."
2,22290,"[343225, 343986, 344624, 345116, 345312, 33796..."
3,17098,"[343226, 220430]"
4,41075,"[343228, 182973]"
...,...,...
37492,7139,"[102983, 103700]"
37493,20120,"[103390, 71338]"
37494,16898,"[70624, 76088]"
37495,31616,"[70632, 76025]"


In [59]:
# Append the new cliques to the existing table and write the result as a TSV
# without the index column.
pd.concat((old_df, add_df), axis=0).to_csv(
    "/home/yskhnykov/yandex_cup/data/raw/cliques2versions_drop_cliques_test_0_58.tsv", sep="\t", index=False
)

In [60]:
# Sanity check: re-read the written file and show its shape after dropping
# rows with a repeated clique id.
pd.read_csv("/home/yskhnykov/yandex_cup/data/raw/cliques2versions_drop_cliques_test_0_58.tsv", sep="\t").drop_duplicates(
    subset=["clique"]
).shape

(42503, 2)

In [61]:
import pandas as pd
# Largest clique id in the written file.
pd.read_csv("/home/yskhnykov/yandex_cup/data/raw/cliques2versions_drop_cliques_test_0_58.tsv", sep="\t")["clique"].max()

np.int64(46622)

In [108]:
# Preview a different table (`..._test_0_6.tsv`); no visible cell writes this
# file, presumably it came from an earlier run with another threshold.
pd.read_csv("/home/yskhnykov/yandex_cup/data/raw/cliques2versions_drop_cliques_test_0_6.tsv", sep="\t")

Unnamed: 0,clique,versions
0,39475,"[343223, 361210, 114472, 134744, 271362, 30747..."
1,20077,"[343224, 350590, 170706, 266043, 314556, 30764..."
2,22290,"[343225, 343986, 344624, 345116, 345312, 33796..."
3,17098,"[343226, 220430]"
4,41075,"[343228, 182973]"
...,...,...
42616,46736,"[390897, 411394]"
42617,46737,"[390585, 411445]"
42618,46738,"[388892, 412029]"
42619,46739,"[389184, 390987]"


In [87]:
# Preview the new cliques.
# NOTE(review): this cell's execution count (87) is out of order with the
# cells around it, so the shown output may come from a different run.
add_df

Unnamed: 0,clique,versions
0,41617,"[418609, 414365]"
1,41618,"[373602, 400323]"
2,41619,"[376896, 421368]"
3,41620,"[412774, 376930]"
4,41621,"[380761, 407682]"
...,...,...
898,42515,"[390742, 397576]"
899,42516,"[393905, 396050]"
900,42517,"[397307, 395533]"
901,42518,"[401789, 398877]"


In [31]:
import pandas as pd
import ast

# Load the `cleaned_axis_0_1_3` clique table, parsing the `versions` column
# from its list-literal text form.
df = pd.read_csv(
    "/home/yskhnykov/yandex_cup/data/raw/cliques2versions_cleaned_axis_0_1_3.tsv", sep="\t", converters={"versions": ast.literal_eval}
)

In [32]:
# (rows, columns) of the cleaned clique table.
df.shape

(41596, 2)

In [10]:
def save_tracks_to_file(data, output_path):
    """Write a ranking to ``output_path``, one query per line.

    Each line holds the query track id followed by its recommended track ids,
    all separated by single spaces. Lines are ordered by query track id.

    Args:
        data: Mapping ``{query_track_id: iterable of track ids}``. Every
            recommended id is converted with ``int`` before being written.
        output_path: Destination passed to ``DataFrame.to_csv``.
    """
    import pandas as pd

    rows = [
        (query, ' '.join(str(int(track)) for track in tracks))
        for query, tracks in data.items()
    ]
    frame = pd.DataFrame(rows, columns=['query_trackid', 'track_ids'])
    frame = frame.sort_values("query_trackid")

    lines = frame['query_trackid'].astype(str) + ' ' + frame['track_ids'].astype(str)
    lines.to_frame('output').to_csv(output_path, index=False, header=False)

In [11]:
# Write the fold-averaged ranking to the submission file.
save_tracks_to_file(final_res, "submission_distances.csv")

Unnamed: 0,clique,versions
0,39475,"[343223, 361210, 114472, 134744, 271362, 30747..."
1,20077,"[343224, 350590, 170706, 266043, 314556, 30764..."
2,22290,"[343225, 343986, 344624, 345116, 345312, 33796..."
3,17098,"[343226, 220430]"
4,41075,"[343228, 182973]"
...,...,...
37492,7139,"[102983, 103700]"
37493,20120,"[103390, 71338]"
37494,16898,"[70624, 76088]"
37495,31616,"[70632, 76025]"


In [87]:
import pandas as pd
import ast
def read_clique2versions(clique2versions_path: str) -> dict[int, list[int]]:
    """Read a tab-separated clique table into a ``{clique: versions}`` dict.

    Args:
        clique2versions_path: Path to a TSV file with a ``clique`` column and
            a ``versions`` column that holds Python list literals.

    Returns:
        Mapping from each clique id to its parsed ``versions`` list. If a
        clique id occurs more than once, the last row wins.
    """
    from ast import literal_eval

    table = pd.read_csv(
        clique2versions_path,
        sep="\t",
        converters={"versions": literal_eval},
    )
    return dict(zip(table["clique"].tolist(), table["versions"].tolist()))

In [88]:
# Reload the `cleaned_axis_0_1_3` clique table with `versions` parsed into lists.
df = pd.read_csv("/home/yskhnykov/yandex_cup/data/raw/cliques2versions_cleaned_axis_0_1_3.tsv", sep="\t", converters={"versions": ast.literal_eval})

In [89]:
# Keep only the rows whose clique id is not in `bad_cliques` and save them.
# NOTE(review): `bad_cliques` is not defined in any visible cell — this fails
# on a fresh kernel unless it is created elsewhere.
df[~df["clique"].isin(bad_cliques)].to_csv("/home/yskhnykov/yandex_cup/data/raw/cliques2versions_drop_cliques_2_5k.tsv", sep="\t", index=False)

In [34]:
# Load the `drop_cliques` table (the same file that is read into `old_df`).
f = pd.read_csv("/home/yskhnykov/yandex_cup/data/raw/cliques2versions_drop_cliques.tsv", sep="\t", converters={"versions": ast.literal_eval})

In [36]:
# Largest existing clique id; the new-clique numbering starts at 41617,
# one above the value shown here.
f["clique"].max()

np.int64(41616)

In [51]:
# Preview the `drop_cliques` table.
f

Unnamed: 0,clique,versions
0,39475,"[343223, 361210, 114472, 134744, 271362, 30747..."
1,20077,"[343224, 350590, 170706, 266043, 314556, 30764..."
2,22290,"[343225, 343986, 344624, 345116, 345312, 33796..."
3,17098,"[343226, 220430]"
4,41075,"[343228, 182973]"
...,...,...
37492,7139,"[102983, 103700]"
37493,20120,"[103390, 71338]"
37494,16898,"[70624, 76088]"
37495,31616,"[70632, 76025]"


In [60]:
# NOTE(review): `clique2min_sim` is not defined in any visible cell; this
# relies on state from a cell that is not shown here.
len(clique2min_sim)

41616