In [1]:
import os
import sys
from functools import lru_cache

import numpy as np
import pandas as pd
from tqdm import tqdm
# Seed numpy's global RNG. split_valid_data shuffles with its own random_state (SEED),
# so this call only affects code that draws from np.random's global state.
np.random.seed(2022)

In [2]:
# Train and test CSVs are read from the same raw-data folder.
train_data_dir = test_data_dir = '../../raw_data/'
# NOTE(review): the export cells further down write to '../../data_for_recstudio/', one level
# above this path; only read_all_task1_data reads from here — confirm which location is right.
recstudio_data_dir = '../data_for_recstudio/'
task = 'task1'
PREDS_PER_SESSION = 100
SEED = 2022         # default random_state for the per-locale shuffle in split_valid_data
VALID_RATIO = 0.08  # fraction of each locale's training sessions held out for validation

In [3]:
@lru_cache(maxsize=1)
def read_product_data():
    """Return products_train.csv from ``train_data_dir`` as a DataFrame.

    Cached with lru_cache: every call after the first returns the *same* DataFrame
    object, so mutating the result in place also changes what later calls see.
    """
    products_csv = os.path.join(train_data_dir, 'products_train.csv')
    return pd.read_csv(products_csv)

@lru_cache(maxsize=1)
def read_train_data():
    """Return sessions_train.csv from ``train_data_dir`` (cached; callers share the object)."""
    sessions_csv = os.path.join(train_data_dir, 'sessions_train.csv')
    return pd.read_csv(sessions_csv)

@lru_cache(maxsize=3)
def read_test_data_phase1(task):
    """Return ``sessions_test_<task>_phase1.csv`` from ``test_data_dir`` (cached per task)."""
    file_name = f'sessions_test_{task}_phase1.csv'
    return pd.read_csv(os.path.join(test_data_dir, file_name))

@lru_cache(maxsize=1)
def read_all_task1_data():
    """Return ``products_train.csv`` read from ``recstudio_data_dir`` (cached).

    NOTE(review): despite its name this reads the products file, not task-1 session
    data, and none of the visible cells call it — confirm the intended file.
    """
    products_csv = os.path.join(recstudio_data_dir, 'products_train.csv')
    return pd.read_csv(products_csv)

In [4]:
@lru_cache(maxsize=3)
def read_test_data_phase2(task):
    """Return ``sessions_test_<task>.csv`` (phase-2 test file, no phase suffix) from ``test_data_dir``.

    Cached per ``task`` value; callers share the returned DataFrame object.
    """
    file_name = f'sessions_test_{task}.csv'
    return pd.read_csv(os.path.join(test_data_dir, file_name))

In [5]:
def split_valid_data(locale, train_sessions, valid_ratio, seed=None):
    """Shuffle one locale's sessions and carve off a validation slice.

    Args:
        locale: Locale value to keep (compared against the ``locale`` column).
        train_sessions: Sessions DataFrame with a ``locale`` column.
        valid_ratio: Fraction of the locale's sessions to hold out; the count is floored.
        seed: ``random_state`` for the shuffle; ``None`` falls back to the global ``SEED``.

    Returns:
        ``(train_locale, valid_locale)``: disjoint DataFrames that keep the original
        index labels. The validation part is the first ``valid_size`` shuffled rows.
    """
    random_state = SEED if seed is None else seed
    # Boolean mask rather than an f-string query: no quoting pitfalls for the locale value.
    shuffled = train_sessions[train_sessions['locale'] == locale].sample(frac=1, random_state=random_state)
    valid_size = int(len(shuffled) * valid_ratio)
    # iloc makes the positional (not label-based) slicing explicit.
    return shuffled.iloc[valid_size:], shuffled.iloc[:valid_size]

In [6]:
def _parse_items(items_str):
    """Split a numpy-style id list such as "['B0A' 'B0B']" into ['B0A', 'B0B'].

    Replaces ``eval(s.replace(' ', ','))``: nothing read from the CSV is executed, and
    whitespace splitting also copes with the line wraps numpy puts into long arrays.
    """
    return [token.strip("'") for token in items_str.strip('[]').split()]


def _format_items(items):
    """Render ids in the numpy-style string format used by the session CSVs, unabridged.

    Identical to ``str(np.array(items))`` except that arrays longer than numpy's print
    threshold (1000 by default) are not shortened to "first ... last".
    """
    return np.array2string(np.array(items), threshold=sys.maxsize)


def merge_train_valid_test_4_train(locale_name, train_locale, valid_locale:pd.DataFrame, test_sessions_list:list[pd.DataFrame]):
    """Build one locale's training table from train, validation and test sessions.

    - Train rows are copied as-is (``prev_items`` string and ``next_item``).
    - Validation rows, and test rows whose ``locale`` equals ``locale_name``, contribute
      only their ``prev_items``. The last item becomes the target and the rest the history.
      Sessions with fewer than two items are skipped. The validation ``next_item`` column
      is never read (presumably so validation targets do not leak into training).

    Args:
        locale_name: Locale written for train rows and used to filter test sessions.
        train_locale: Training sessions with ``prev_items`` and ``next_item`` columns.
        valid_locale: Validation sessions with ``prev_items`` and ``locale`` columns.
        test_sessions_list: Test-session DataFrames (any locales) with ``prev_items`` and ``locale``.

    Returns:
        DataFrame with columns ``prev_items``, ``locale``, ``next_item``. Train rows come
        first, then validation-derived rows, then test-derived rows in list order.
    """
    task1_prev_items = list(train_locale['prev_items'])
    task1_locales = [locale_name] * len(train_locale)
    task1_next_items = list(train_locale['next_item'])

    def add_last_item_as_target(sessions):
        # Shift each session by one: history = items[:-1], target = items[-1].
        for row in sessions.itertuples():
            items = _parse_items(row.prev_items)
            if len(items) <= 1:
                continue  # nothing would be left as history
            task1_locales.append(row.locale)
            task1_prev_items.append(_format_items(items[:-1]))
            task1_next_items.append(items[-1])

    add_last_item_as_target(valid_locale)
    for test_sessions in test_sessions_list:
        add_last_item_as_target(test_sessions[test_sessions['locale'] == locale_name])

    return pd.DataFrame({'prev_items' : task1_prev_items, 'locale' : task1_locales, 'next_item' : task1_next_items})

In [7]:
def session_2_inter_feat(sessions_df, save_path, test=False):
    """Explode sessions into one interaction row per product and write them as CSV.

    Output columns are ``sess_id,product_id,timestamp,locale``. ``sess_id`` is the
    session's 0-based position in ``sessions_df`` and ``timestamp`` is the product's
    0-based position inside its session. An existing file at ``save_path`` is overwritten.

    Args:
        sessions_df: DataFrame with ``prev_items`` (numpy-style string such as
            "['B0A' 'B0B']") and ``locale`` columns, plus ``next_item`` when ``test`` is False.
        save_path: Output CSV path.
        test: When False, each session's ``next_item`` is appended as its last product.
    """
    with open(save_path, 'w') as f:
        f.write('sess_id,product_id,timestamp,locale\n')
        for sess_id, sess in enumerate(tqdm(sessions_df.itertuples(), total=len(sessions_df))):
            # split() with no argument skips the line-wrap whitespace numpy inserts into long
            # arrays and yields no empty token for an empty "[]" session, which the old
            # split(' ') turned into a row with an empty product_id.
            product_list = [token.strip("'") for token in sess.prev_items.strip('[]').split()]
            if not test:
                product_list.append(sess.next_item)
            f.writelines(
                f'{sess_id},{product_id},{position},{sess.locale}\n'
                for position, product_id in enumerate(product_list)
            )

# Task 1: build the RecStudio train/valid data

In [8]:
# Raw training sessions; read_train_data is lru-cached, so repeated calls reuse this frame.
train_sessions = read_train_data()

In [9]:
# Task 1 and task 3 test sessions from both phases; all four are folded into the per-locale training merge.
test_sessions_phase1, test_sessions_phase2 = read_test_data_phase1('task1'), read_test_data_phase2('task1')
task3_test_sessions_phase1, task3_test_sessions_phase2 = read_test_data_phase1('task3'), read_test_data_phase2('task3')

In [10]:
# Locales to process, taken from the task-1 phase-1 test set in order of first appearance.
locale_names = pd.unique(test_sessions_phase1['locale'])

In [11]:
# Inspect which locales the split/merge/export cells will iterate over.
locale_names

array(['DE', 'JP', 'UK'], dtype=object)

In [12]:
# Per-locale shuffled train/valid split; trains[i] and valids[i] line up with locale_names[i].
# (Replaces appends inside a throw-away tuple expression.)
splits = [split_valid_data(locale_name, train_sessions, VALID_RATIO) for locale_name in locale_names]
trains = [train_locale for train_locale, _ in splits]
valids = [valid_locale for _, valid_locale in splits]

In [13]:
# Spot-check the training split at index 2 of locale_names ('UK' in the run shown).
trains[2]

Unnamed: 0,prev_items,next_item,locale
3182859,['B06XKPJZH2' 'B06XKM7M7J' 'B06XKV87YQ' 'B06XK...,B08SQWJ8KD,UK
2413861,['B09CPRN7NL' 'B07GCKWK3B' 'B08CXBTC3X' 'B09CP...,B07ZHCLSKR,UK
3146198,['B00K71MN5W' 'B00K71MMXK' 'B0B12W38X3' 'B09NQ...,B09NQ89GZ5,UK
2825235,['B00OKC9YLA' 'B00XWMOOB4' 'B00OKC9Z6Y' 'B00PI...,B00R96QJQW,UK
2402944,['B09BVC9Y3M' 'B08J6XHT7Z'],B0999Q3NXJ,UK
...,...,...,...
2530520,['B07NBVMD8B' 'B09S192YFG' 'B0BFPLN6NL' 'B09R8...,B098T4S62X,UK
2193559,['B08DMVBX4C' 'B08DMV8XPK' 'B08L1X3RJ1' 'B09W8...,B01EG533I0,UK
2238164,['B00UVOE51A' 'B00UVOE51A' 'B00UVOE6TG'],B082Q3LXS2,UK
2878211,['B075D87KJL' 'B08DTWJ8BX' 'B00KDL108M' 'B01CX...,B07V59FPRV,UK


In [14]:
# Spot-check the validation split at index 2 of locale_names ('UK' in the run shown).
valids[2]

Unnamed: 0,prev_items,next_item,locale
2188849,['B0B42XZNCW' 'B07ND7MKVZ' 'B0B42XZNCW' 'B07ND...,B07S43KJL7,UK
2155495,['B09BW9Y1YR' 'B0B6Q3H47G'],B01EUNA3VS,UK
2496258,['B013KV0FS2' 'B06VW8Y81Y' 'B013KV0FS2'],B014WYVD1O,UK
2116552,['B09TB336FC' 'B083P3LGFY' 'B09TB3KFNV'],B09N3M5JNJ,UK
2497713,['B0B77H8NYS' 'B07HQ7P9BQ' 'B07MWYRVPG'],B009ARMB0Q,UK
...,...,...,...
2577128,['B08SJWM6L2' 'B08ZKP7Z4T' 'B08SJWM6L2' 'B08CD...,B0B63T163B,UK
2428621,['B07Z1V3WGM' 'B07QQC48FX' 'B07W92D6W5' 'B07KP...,B09CYG11VV,UK
2411626,['B082DGJQCL' 'B09JV3Q5FX' 'B082DGJQCL' 'B082D...,B09PBFBR4L,UK
2516632,['B003KU6GAU' 'B074JDHYCF' 'B003KU6GAU'],B096RVWZCV,UK


In [15]:
# Test sessions from both tasks and both phases; built once instead of on every loop iteration.
all_test_sessions = [test_sessions_phase1, test_sessions_phase2, task3_test_sessions_phase1, task3_test_sessions_phase2]
# One augmented training table per locale, in locale_names order.
merged_trains = [
    merge_train_valid_test_4_train(locale_name, train_locale, valid_locale, all_test_sessions)
    for locale_name, train_locale, valid_locale in zip(locale_names, trains, valids)
]

In [16]:
# Stack the per-locale splits with a fresh RangeIndex. These are the raw `trains`, not `merged_trains`.
# NOTE(review): the combined task13 export is built from these, so it lacks the valid/test-derived
# sessions that merged_trains adds — confirm that is intended.
all_train, all_valid = (pd.concat(parts, ignore_index=True) for parts in (trains, valids))

In [33]:
# Per-locale phase-2 exports for RecStudio: session CSVs plus the exploded interaction files.
# NOTE(review): this base path is one level above the config cell's recstudio_data_dir.
export_root = '../../data_for_recstudio'
for locale_name, merged_train, valid in zip(locale_names, merged_trains, valids):
    out_dir = os.path.join(export_root, f'{locale_name}_data_phase2')
    os.makedirs(out_dir, exist_ok=True)  # to_csv/open fail if the folder does not exist yet
    for split_name, frame in (('train', merged_train), ('valid', valid)):
        frame.to_csv(os.path.join(out_dir, f'{locale_name}_{split_name}_sessions.csv'), index=False)
        session_2_inter_feat(frame, os.path.join(out_dir, f'{locale_name}_{split_name}_inter_feat.csv'), test=False)
    

100%|██████████| 1340552/1340552 [00:46<00:00, 28975.23it/s]
100%|██████████| 88913/88913 [00:02<00:00, 37955.81it/s]
100%|██████████| 1192053/1192053 [00:40<00:00, 29479.70it/s]
100%|██████████| 78329/78329 [00:03<00:00, 23786.22it/s]
100%|██████████| 1434054/1434054 [00:48<00:00, 29500.17it/s]
100%|██████████| 94574/94574 [00:01<00:00, 54751.59it/s]


In [17]:
# Combined (all-locale) task-1 phase-2 exports built from all_train / all_valid.
task1_out_dir = '../../data_for_recstudio/task1_data'
os.makedirs(task1_out_dir, exist_ok=True)  # writes below fail if the folder is missing
session_2_inter_feat(all_train, os.path.join(task1_out_dir, 'task13_4_task1_train_inter_feat_phase2.csv'))
session_2_inter_feat(all_valid, os.path.join(task1_out_dir, 'task13_4_task1_valid_inter_feat_phase2.csv'))
all_train.to_csv(os.path.join(task1_out_dir, 'task13_4_task1_train_sessions_phase2.csv'), index=False)
all_valid.to_csv(os.path.join(task1_out_dir, 'task13_4_task1_valid_sessions_phase2.csv'), index=False)

100%|██████████| 3010900/3010900 [00:38<00:00, 79104.57it/s] 
100%|██████████| 261816/261816 [00:05<00:00, 44277.32it/s]


In [34]:
# Spot-check the training split at index 0 of locale_names ('DE' in the run shown).
trains[0]

Unnamed: 0,prev_items,next_item,locale
706840,['3949568239' 'B09CLBRV16' 'B0B7237CF5' 'B09FJ...,B0B1MWMKVZ,DE
898117,['B09QZWPX6T' 'B09CLH9TWB'],B00175X9QE,DE
820132,['B08LMNLBCS' 'B08LMNLBCS'],B0045XA8J6,DE
391472,['B0B6HFFLHB' 'B0B6C8MZDV'],B07L1BH4Q9,DE
995063,['B089W594LC' 'B09W2G8NLH'],B093GB3CCW,DE
...,...,...,...
439985,['B0B7VGY212' 'B0BD7TC87Y'],B0BCQD39BD,DE
103024,['B09MKH7HDX' 'B08HMWZBXC' 'B093SLWMS7'],B093SJD7GS,DE
147629,['B0046VSVYG' 'B095X8S3R6'],B09VGSD3T3,DE
787676,['B0BBRDJN6D' 'B09Z62DRLW' 'B0972L1RPN' 'B0B1W...,B09QLYBS2S,DE


In [35]:
# Spot-check the merged (train + valid/test-derived) table at index 1 ('JP' in the run shown).
merged_trains[1]

Unnamed: 0,prev_items,locale,next_item
0,['B09X16YXD2' 'B09SPCKPV1' 'B09BZ58TZR' 'B08BY...,JP,B098WRRCGF
1,['B08L5JB53S' 'B08L5H869W'],JP,B08L5H7T2Q
2,['B01FDILIPY' 'B01FDILIPY'],JP,B003OZTFJK
3,['B07HB34YS9' 'B07HB5TYCX'],JP,B07H9ZML2V
4,['B07MQC71M4' 'B08JGCH578' 'B07MQC71M4' 'B07MQ...,JP,B0BB1F646V
...,...,...,...
1192048,['B008JGUVOC' 'B09P3R3ND1'],JP,B004225TZS
1192049,['B00F4L8QZQ'],JP,B00F4L8QZQ
1192050,['B09G62TNN4'],JP,4413232720
1192051,['B09XZZZKZY' 'B09Y1GCP6P' 'B09XZZZKZY' 'B09CY...,JP,B09XZZZKZY


In [20]:
# Spot-check the merged (train + valid/test-derived) table at index 2 ('UK' in the run shown).
merged_trains[2]

Unnamed: 0,prev_items,locale,next_item
0,['B06XKPJZH2' 'B06XKM7M7J' 'B06XKV87YQ' 'B06XK...,UK,B08SQWJ8KD
1,['B09CPRN7NL' 'B07GCKWK3B' 'B08CXBTC3X' 'B09CP...,UK,B07ZHCLSKR
2,['B00K71MN5W' 'B00K71MMXK' 'B0B12W38X3' 'B09NQ...,UK,B09NQ89GZ5
3,['B00OKC9YLA' 'B00XWMOOB4' 'B00OKC9Z6Y' 'B00PI...,UK,B00R96QJQW
4,['B09BVC9Y3M' 'B08J6XHT7Z'],UK,B0999Q3NXJ
...,...,...,...
1434049,['B08S8N3YG4'],UK,B00HRRANWY
1434050,['B09F5L55QV' 'B0B8BRLGY4' 'B00IKGH2TI' 'B00IK...,UK,B0BB7JDWB9
1434051,['B09FY2WLZ3'],UK,B09FY4F65Y
1434052,['B001G4L9KO' 'B0BG373773' 'B001G4L9KO'],UK,B07Z286QPY
