# Triplet creation

In [None]:
import pandas as pd
import os
from collections import defaultdict, deque
import numpy as np
from tqdm import tqdm
import copy

In [None]:
# Directory that holds the parquet parts. Built with os.path.join so the
# notebook is not tied to the Windows path separator.
PATH = os.path.join('avitotech_data', 'avitotech_data')
# Every file name used below is relative to this directory.
os.chdir(PATH)

In [None]:
# Only the pair ids and the duplicate label are used before the parts are
# reloaded further down, so read just those columns instead of every column.
PAIR_COLUMNS = ['base_item_id', 'cand_item_id', 'is_double']
df_train_1 = pd.read_parquet("train_part_0001.snappy.parquet", columns=PAIR_COLUMNS)
df_train_2 = pd.read_parquet("train_part_0002.snappy.parquet", columns=PAIR_COLUMNS)
df_train_3 = pd.read_parquet("train_part_0003.snappy.parquet", columns=PAIR_COLUMNS)
df_train_4 = pd.read_parquet("train_part_0004.snappy.parquet", columns=PAIR_COLUMNS)

In [None]:
# Keep only the two item ids of each pair and the duplicate label.
pair_columns = ['base_item_id', 'cand_item_id', 'is_double']
df_train_1 = df_train_1[pair_columns]
df_train_2 = df_train_2[pair_columns]
df_train_3 = df_train_3[pair_columns]
df_train_4 = df_train_4[pair_columns]

In [None]:
# Stack the four parts into a single frame of labelled pairs.
df_train = pd.concat([df_train_1, df_train_2, df_train_3, df_train_4])

## Mutual interaction

In [None]:
def build_pair_dict(df):
    """Build a symmetric adjacency mapping from labelled item pairs.

    Args:
        df: DataFrame whose rows are ``(base_item_id, cand_item_id, is_double)``
            in that column order; rows are unpacked positionally from
            ``df.values``.

    Returns:
        A nested ``defaultdict``. ``ids[item]['pair']`` lists the items that
        were seen together with ``item`` and ``ids[item]['is_double']`` holds
        the label of each of those pairs at the same position. Every pair is
        recorded for both of its items; when a pair occurs more than once,
        only the label of its first occurrence is kept.
    """
    ids = defaultdict(lambda: defaultdict(list))
    # Mirrors ids[item]['pair'] as a set so the duplicate check is O(1)
    # instead of a scan over the growing list.
    seen = defaultdict(set)

    for base, cand, is_double in df.values:
        if cand not in seen[base]:
            seen[base].add(cand)
            ids[base]['pair'].append(cand)
            ids[base]['is_double'].append(is_double)

        if base not in seen[cand]:
            seen[cand].add(base)
            ids[cand]['pair'].append(base)
            ids[cand]['is_double'].append(is_double)

    return ids

In [None]:
# Symmetric mapping: item id -> {'pair': [...], 'is_double': [...]} built from all training pairs.
ids = build_pair_dict(df_train)

In [None]:
def propagate_transitive_doubles(ids):
    """Extend the pair mapping along chains of duplicates, updating ``ids`` in place.

    For every item, a breadth-first walk over ``is_double == 1`` links collects
    the items reachable as direct or indirect duplicates. The non-duplicates
    of those items are then recorded as non-duplicates of the item (in both
    directions), and any duplicate link still missing between the item and a
    reachable duplicate is added (in both directions).

    An existing entry is never relabelled: a link is appended only when the
    other id is not yet present in the corresponding ``'pair'`` list. Because
    ``ids`` is modified while it is being processed, links added for earlier
    items are visible when later items are walked.

    Args:
        ids: Nested mapping where ``ids[item]['pair']`` and
            ``ids[item]['is_double']`` are parallel lists.
            NOTE(review): ``ids`` is indexed by neighbour ids while it is being
            iterated, so every id in a ``'pair'`` list presumably has to be a
            key already (as in the output of ``build_pair_dict``) - confirm
            for other callers.

    Returns:
        The same ``ids`` object that was passed in.
    """
    for item in tqdm(ids):
        visited = set()
        queue = deque()
        known_doubles = set()

        # Find all direct and indirect duplicates (is_double == 1)
        queue.append(item)
        while queue:
            current = queue.popleft()
            if current in visited:
                continue
            visited.add(current)
            for idx, neighbor in enumerate(ids[current]['pair']):
                if ids[current]['is_double'][idx] == 1 and neighbor not in visited:
                    known_doubles.add(neighbor)
                    queue.append(neighbor)

        # For each duplicate, find its non-duplicates and add them as is_double == 0
        for double_id in known_doubles:
            for idx, neighbor in enumerate(ids[double_id]['pair']):
                is_dbl = ids[double_id]['is_double'][idx]
                if is_dbl == 0 and neighbor != item:
                    if neighbor not in ids[item]['pair']:
                        ids[item]['pair'].append(neighbor)
                        ids[item]['is_double'].append(0)
                    if item not in ids[neighbor]['pair']:
                        ids[neighbor]['pair'].append(item)
                        ids[neighbor]['is_double'].append(0)

        # Add the missing links between duplicates
        for double_id in known_doubles:
            if double_id not in ids[item]['pair']:
                ids[item]['pair'].append(double_id)
                ids[item]['is_double'].append(1)
            if item not in ids[double_id]['pair']:
                ids[double_id]['pair'].append(item)
                ids[double_id]['is_double'].append(1)

    return ids


In [None]:
# Spread duplicate / non-duplicate links along duplicate chains (``ids`` is modified in place and returned).
ids = propagate_transitive_doubles(ids)

## Groups of each ID

In [None]:
# The frames above were reduced to the pair columns, so load the parts
# again to get access to ``group_id``.
part_files = [
    "train_part_0001.snappy.parquet",
    "train_part_0002.snappy.parquet",
    "train_part_0003.snappy.parquet",
    "train_part_0004.snappy.parquet",
]
df_train_1, df_train_2, df_train_3, df_train_4 = [
    pd.read_parquet(file_name) for file_name in part_files
]

# Stack the reloaded parts into one frame.
df_train = pd.concat([df_train_1, df_train_2, df_train_3, df_train_4])

In [None]:
# item id -> set of group ids, tracked separately for the base and the
# candidate side of each pair.
base_map = defaultdict(set)
cand_map = defaultdict(set)

id_group_triples = zip(
    df_train['base_item_id'],
    df_train['cand_item_id'],
    df_train['group_id'],
)
for base_id, cand_id, group_id in id_group_triples:
    base_map[base_id].add(group_id)
    cand_map[cand_id].add(group_id)

## Add groups to ids

In [None]:
# Store, for every item, the groups it occurs in on either side of a pair.
for key in tqdm(ids):
    groups_as_base = base_map.get(key, set())
    groups_as_cand = cand_map.get(key, set())
    ids[key]['groups'] = list(groups_as_base.union(groups_as_cand))

## Only doubles remain

In [None]:
# ids_trunc: independent copies of the entries that have at least one duplicate.
# ids_neg: ids of the entries that have none.
ids_neg = set()
ids_trunc = defaultdict(lambda: defaultdict(list))

for yan_id, data in tqdm(ids.items()):
    has_double = np.sum(data['is_double']) > 0
    if not has_double:
        ids_neg.add(yan_id)
        continue
    # Deep copy so later edits to ids_trunc leave ``ids`` untouched.
    ids_trunc[yan_id] = copy.deepcopy(data)

## Non-duplicates among duplicates

In [None]:
# Ids without any duplicate that already appear as a pair of some item
# that does have a duplicate.
used_list = set()

for yan_id, data in tqdm(ids.items()):
    if not (np.sum(data['is_double']) > 0):
        continue
    used_list.update(pair_id for pair_id in data['pair'] if pair_id in ids_neg)

In [None]:
# Ids that have no duplicate and are not paired with any item that has one.
ids_neg_trunc = ids_neg - used_list  # ids not used yet

In [None]:
# First line printed: total number of unused ids.
unused_total = len(ids_neg_trunc)
print(f"Всего не использованных ID: {unused_total}")
# Second line printed: unused ids available per item that has duplicates, rounded up.
negatives_per_id = np.ceil(unused_total / len(ids_trunc))
print(f"Негативов на ID: {negatives_per_id}")

## Maximum groups per id

In [None]:
# Largest number of groups attached to a single id (0 when ``ids`` is empty).
max_gp = max((len(entry['groups']) for entry in ids.values()), default=0)

print(max_gp)

In [None]:
def augment_with_negatives_dynamic(ids_trunc, ids, ids_neg_trunc, n):
    """Top up the items of ``ids_trunc`` with extra negatives, in place.

    For every item the target is ``n`` negatives per positive. Missing
    negatives are drawn from ``ids_neg_trunc``; a candidate is accepted only
    when its groups are disjoint from the item's groups, and each candidate
    is handed out at most once across all items.

    Args:
        ids_trunc: Mapping item id -> entry with parallel ``'pair'`` /
            ``'is_double'`` lists and an optional ``'groups'`` list. Accepted
            candidates are appended to these lists with label ``0``.
        ids: Mapping used to look up the ``'groups'`` of a candidate via
            ``ids[cand_id]``.
        ids_neg_trunc: Iterable of candidate ids; candidates are tried in its
            iteration order.
        n: Desired number of negatives per positive.

    Returns:
        Total number of negatives appended, including those added to items
        whose target could not be reached in full.
    """
    # Pool of still-unused candidates. It is reversed so that ``pop()`` hands
    # them out in the iteration order of ``ids_neg_trunc``. Consumed candidates
    # leave the pool, so they are never rescanned and never dropped by mistake.
    pool = list(ids_neg_trunc)
    pool.reverse()
    cnt_added = 0

    for data in tqdm(ids_trunc.values()):
        if not pool:
            # Nothing left to hand out.
            break

        labels = data.get('is_double', [])
        num_positives = sum(1 for val in labels if val == 1)
        num_negatives = sum(1 for val in labels if val == 0)

        remaining_to_add = n * num_positives - num_negatives
        if remaining_to_add <= 0:
            continue

        current_groups = set(data.get('groups', []))
        skipped = []
        added = 0

        while pool and added < remaining_to_add:
            cand_id = pool.pop()
            cand_groups = ids[cand_id].get('groups', [])

            if current_groups.isdisjoint(cand_groups):
                data['pair'].append(cand_id)
                data['is_double'].append(0)
                added += 1
            else:
                # Shares a group with this item; keep it for other items.
                skipped.append(cand_id)

        # Return rejected candidates to the front of the queue, original order.
        pool.extend(reversed(skipped))
        # Count every appended negative, even when the target was not met.
        cnt_added += added

    return cnt_added


In [None]:
# ``ids`` is the mapping that carries the 'groups' of every candidate id;
# the previously passed ``ids_1`` is not defined anywhere in this notebook.
cnt_added = augment_with_negatives_dynamic(ids_trunc, ids, ids_neg_trunc, n=6)

In [None]:
# Report how many of the unused ids were attached as extra negatives
# (the message reads "Out of <total>, <added> were added").
print(f"Из {len(ids_neg_trunc)} добавлено {cnt_added}")

In [None]:
# 'groups' was only needed to pick negatives; drop it before saving.
# The default makes the cell safe to re-run once the key is already gone.
for items in ids_trunc.values():
    items.pop('groups', None)

In [None]:
import json

# Save the augmented pair mapping to disk.
# NOTE(review): json can only encode built-in key/value types - confirm the
# ids and labels are plain str/int rather than numpy scalars.
with open('to_undergo.json', mode='w') as out_file:
    json.dump(ids_trunc, out_file)