# Get Co-Occurrence dict

In [2]:
import os
import random
import numpy as np
import pandas as pd
from functools import lru_cache
from tqdm import tqdm
from collections import Counter, defaultdict

In [3]:
# Directory that read_product_data() loads 'products_train.csv' from.
train_data_dir = '../raw_data/'
# Directory that read_test_sessions() loads 'sessions_test_<task>.csv' from
# (currently the same location as train_data_dir).
test_data_dir = '../raw_data/'
# Directory that read_train_sessions() / read_valid_sessions() load the
# 'all_task_1_*_sessions.csv' files from.
recstudio_data_dir = '../data_for_recstudio/'

@lru_cache(maxsize=1)
def read_product_data():
    """Load ``products_train.csv`` from ``train_data_dir`` into a DataFrame.

    Before reading, the current working directory and the joined CSV path are
    printed (presumably to help diagnose relative-path problems). The result
    is memoised by ``lru_cache``, so later calls return the same DataFrame
    object without touching the disk again.
    """
    print(os.getcwd())
    products_csv = os.path.join(train_data_dir, 'products_train.csv')
    print(products_csv)
    return pd.read_csv(products_csv)

@lru_cache(maxsize=1)
def read_train_sessions():
    """Load ``all_task_1_train_sessions.csv`` from ``recstudio_data_dir``.

    The DataFrame is memoised by ``lru_cache``; every call after the first
    returns the same object.
    """
    train_csv = os.path.join(recstudio_data_dir, 'all_task_1_train_sessions.csv')
    train_sessions = pd.read_csv(train_csv)
    return train_sessions

@lru_cache(maxsize=1)
def read_valid_sessions():
    """Load ``all_task_1_valid_sessions.csv`` from ``recstudio_data_dir``.

    The DataFrame is memoised by ``lru_cache``; every call after the first
    returns the same object.
    """
    valid_csv = os.path.join(recstudio_data_dir, 'all_task_1_valid_sessions.csv')
    valid_sessions = pd.read_csv(valid_csv)
    return valid_sessions

@lru_cache(maxsize=3)
def read_test_sessions(task):
    """Load ``sessions_test_<task>.csv`` from ``test_data_dir``.

    Args:
        task: Value interpolated into the file name (the notebook passes
            ``'task1'``).

    Returns:
        The CSV contents as a DataFrame. Results are memoised per ``task``;
        ``lru_cache(maxsize=3)`` keeps the three most recently used ones.
    """
    file_name = f'sessions_test_{task}.csv'
    test_csv = os.path.join(test_data_dir, file_name)
    return pd.read_csv(test_csv)

In [4]:
def get_sessions(df: pd.DataFrame, test=False, list_item=False) -> list:
    """Convert a sessions DataFrame into a list of sessions.

    Args:
        df: DataFrame with a ``prev_items`` column and, optionally, a
            ``next_item`` column.
        test: When true, ``next_item`` is ignored even if the column exists.
        list_item: When true, ``prev_items`` is returned as-is (no parsing).
            NOTE(review): in this mode ``next_item`` is never appended, even
            for training data -- confirm that this is intended.

    Returns:
        One entry per row of ``df``. With ``list_item=False`` each entry is
        the list parsed from the row's ``prev_items`` text (a bracketed,
        whitespace-separated sequence of quoted ids such as ``"['A' 'B']"``),
        with the row's ``next_item`` appended when that column is present
        and ``test`` is false.

    Raises:
        ValueError, SyntaxError: If a ``prev_items`` string is not a plain
            list literal after the separators are rewritten.
    """
    # Local import keeps the block self-contained; `ast` is standard library.
    import ast

    prev_items = df['prev_items']
    if list_item:
        return prev_items.to_list()

    def _parse(text):
        # The ids are separated by spaces, so turn every space into a comma
        # to obtain a regular list literal. ast.literal_eval is used instead
        # of eval: it only accepts literals, so text coming from the CSV can
        # never execute code.
        return ast.literal_eval(text.replace(" ", ","))

    # Iterating over the columns (rather than df.apply(axis=1)) avoids
    # building a Series per row and also works for an empty DataFrame, where
    # apply() returns a DataFrame that has no .to_list().
    if ('next_item' in df) and (not test):
        return [
            # Drop the closing bracket, add the target item, close the list.
            _parse(f"{prev[:-1]} '{nxt}']")
            for prev, nxt in zip(prev_items, df['next_item'])
        ]
    return [_parse(prev) for prev in prev_items]

In [8]:
def update_co_occurrence_dict(co_occurrence_dict: dict, sessions: list, bidirection: bool=False, weighted: bool=False,
                             max_span: int=2) -> None:
    """Accumulate windowed co-occurrence counts from ``sessions`` in place.

    For every item in every session, the up to ``max_span`` items that
    directly follow it are counted as its neighbours:
    ``co_occurrence_dict[item][neighbour]`` is increased for each such pair.
    Every item seen gets an entry, even if it never has a following item
    (its Counter then stays empty).

    Args:
        co_occurrence_dict: Mapping ``item -> Counter`` that is updated in
            place; pass the same dict to several calls to merge data sets.
        sessions: Iterable of sessions, each an ordered sequence of items.
        bidirection: If true, each pair is also counted in the reverse
            direction (``co_occurrence_dict[neighbour][item]``).
        weighted: If true, a pair adds ``1 / distance`` instead of ``1``,
            where ``distance`` is how many positions the neighbour is after
            the item (1 for the directly following item).
        max_span: Number of following items considered for each item.

    Returns:
        None. The result is the mutated ``co_occurrence_dict``. (Returning
        nothing also keeps a notebook cell ending in this call from echoing
        the whole mapping.)

    Raises:
        ValueError: If ``max_span`` is negative. A negative value would make
            the slice end wrap around and silently count the wrong items.
    """
    if max_span < 0:
        raise ValueError(f"max_span must be non-negative, got {max_span}")

    for sess in tqdm(sessions):
        for i, cur_product in enumerate(sess):
            # Fetch (or create) the counter once instead of on every update.
            cur_counter = co_occurrence_dict.get(cur_product)
            if cur_counter is None:
                cur_counter = co_occurrence_dict[cur_product] = Counter()

            neighbor_products = sess[i + 1 : i + max_span + 1]
            for distance, product in enumerate(neighbor_products, start=1):
                weight = 1 / distance if weighted else 1
                cur_counter[product] += weight
                if bidirection:
                    reverse_counter = co_occurrence_dict.get(product)
                    if reverse_counter is None:
                        reverse_counter = co_occurrence_dict[product] = Counter()
                    reverse_counter[cur_product] += weight

In [6]:
# Load the three session tables. Each loader is wrapped in lru_cache, so
# calling it again in the same process does not re-read the CSV.
df_train_sessions = read_train_sessions()
df_valid_sessions = read_valid_sessions()
# 'task1' selects the file 'sessions_test_task1.csv'.
df_test_sessions = read_test_sessions('task1')

In [12]:
# Turn each DataFrame into a list of sessions (one list of items per row).
# Training rows get their 'next_item' appended to the session; the valid and
# test frames are passed test=True, so only 'prev_items' is used for them.
train_sessions_list = get_sessions(df_train_sessions)
valid_sessions_list = get_sessions(df_valid_sessions, test=True)
test_sessions_list = get_sessions(df_test_sessions, test=True)

In [15]:
# Display the first parsed test session as a sanity check.
test_sessions_list[0]

['B08V12CT4C',
 'B08V1KXBQD',
 'B01BVG1XJS',
 'B09VC5PKN5',
 'B09V7KG931',
 'B09PY75FWM',
 'B09PXYT6BT',
 'B08V12CT4C',
 'B08V1KXBQD',
 'B08496TCCQ',
 'B01BVG1XJS',
 'B099NQFMG7']

In [13]:
# Number of sessions in the train, valid and test lists.
len(train_sessions_list), len(valid_sessions_list), len(test_sessions_list)

(3557898, 361581, 316971)

In [18]:
# Start from an empty mapping and fold in every split, in the order
# train -> valid -> test. Settings: forward direction only (bidirection=False),
# plain counts (weighted=False), window of the next two items (max_span=2).
co_occurrence_dict = {}
for split_sessions in (train_sessions_list, valid_sessions_list, test_sessions_list):
    update_co_occurrence_dict(co_occurrence_dict, split_sessions, False, False, 2)

100%|██████████| 3557898/3557898 [00:41<00:00, 84782.45it/s] 
100%|██████████| 361581/361581 [00:02<00:00, 130575.88it/s]
100%|██████████| 316971/316971 [00:02<00:00, 153646.77it/s]


In [17]:
# Display the accumulated mapping (item -> Counter of the items that follow it).
co_occurrence_dict

{'p': Counter({'r': 1, 'e': 1}),
 'r': Counter({'e': 1, 'v': 1}),
 'e': Counter({'v': 1, '_': 1, 'm': 2, 's': 1, 'x': 1, 't': 1}),
 'v': Counter({'_': 1, 'i': 1}),
 '_': Counter({'i': 2, 't': 2}),
 'i': Counter({'t': 2, 'e': 2}),
 't': Counter({'e': 2, 'm': 2, '_': 1, 'i': 1}),
 'm': Counter({'s': 1}),
 's': Counter(),
 'n': Counter({'e': 1, 'x': 1}),
 'x': Counter({'t': 1, '_': 1}),
 'l': Counter({'o': 1, 'c': 1, 'e': 1}),
 'o': Counter({'c': 1, 'a': 1}),
 'c': Counter({'a': 1, 'l': 1}),
 'a': Counter({'l': 1, 'e': 1}),
 'B09VSN9GLS': Counter({'B09VSG9DCG': 2,
          'B0BJ5L1ZPH': 1,
          'B0BJ6V797Y': 1,
          'B077VXNL67': 1}),
 'B09VSG9DCG': Counter({'B0BJ5L1ZPH': 1,
          'B09VSN9GLS': 2,
          'B077XGDMD2': 1,
          'B077VXNL67': 1,
          'B07V213ND4': 1,
          'B0BJ66G5LN': 4,
          'B092R34S33': 1,
          'B09VSG9DCG': 1}),
 'B0BJ5L1ZPH': Counter({'B09VSN9GLS': 1, 'B0BJ6V797Y': 1}),
 'B0BJ6V797Y': Counter({'B09VSG9DCG': 1, 'B077XGDMD2': 1}

In [None]:
# Average number of distinct neighbours stored per item.
num_pairs = [len(neighbor_counter) for neighbor_counter in co_occurrence_dict.values()]
np.mean(num_pairs)