In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install polars



In [3]:
from collections import defaultdict, Counter
from typing import List, Dict

from tqdm import tqdm
import pandas as pd
import polars as pl

In [4]:
LOCALES = ["FR", "ES", "IT"]

In [5]:
train = pd.read_csv("/content/drive/MyDrive/kddcup2023/data/raw/sessions_train.csv")
test = pd.concat([
    pd.read_csv("/content/drive/MyDrive/kddcup2023/data/raw/sessions_test_task2_phase1.csv"),
    pd.read_csv("/content/drive/MyDrive/kddcup2023/data/raw/sessions_test_task2.csv"),
    pd.read_csv("/content/drive/MyDrive/kddcup2023/data/raw/sessions_test_task3_phase1.csv"),
    pd.read_csv("/content/drive/MyDrive/kddcup2023/data/raw/sessions_test_task3.csv"),
])
train = train[train["locale"].isin(LOCALES)]
test = test[test["locale"].isin(LOCALES)]
train = pl.from_pandas(train)
test = pl.from_pandas(test)

In [6]:
# prev_items
def str2list(s):
    s = s.replace("[", "").replace("]", "").replace("'", "").replace("\n", " ").replace("\r", " ")
    s = s.split()
    return s

train = train.with_columns(pl.col("prev_items").apply(str2list).alias("prev_items"))
test = test.with_columns(pl.col("prev_items").apply(str2list).alias("prev_items"))

  train = train.with_columns(pl.col("prev_items").apply(str2list).alias("prev_items"))
  test = test.with_columns(pl.col("prev_items").apply(str2list).alias("prev_items"))


In [7]:
print(len(train))
print(len(test))

333533
122221


In [8]:
# Add test data to train

# Condition 1: Number of sessions is more than 3
# Condition 2: Last item is not interacted with.
test = test.with_columns(
    pl.col("prev_items").apply(len).alias("session_count")
)
prev_items_list = test["prev_items"].to_list()
next_item_list = []
prev_items_list_updated = []
for prev_items in prev_items_list:
    next_item_list.append(prev_items[-1])
    prev_items_list_updated.append(prev_items[:-1])
test = test.with_columns([
    pl.Series(name="next_item", values=next_item_list),
    pl.Series(name="prev_items_updated", values=prev_items_list_updated),
])

test_add_to_train = test.filter(
    (pl.col("session_count") >= 3) &
    (~pl.col("next_item").is_in(pl.col("prev_items_updated")))
)
test_add_to_train = test_add_to_train[["prev_items_updated", "next_item", "locale"]]
test_add_to_train = test_add_to_train.rename({"prev_items_updated":"prev_items"})

test_not_add_to_train = test.filter(
    (pl.col("session_count") < 3) |
    (pl.col("next_item").is_in(pl.col("prev_items_updated")))
)
test_not_add_to_train = test_not_add_to_train[["prev_items", "locale"]]

  pl.col("prev_items").apply(len).alias("session_count")


In [9]:
assert len(test_add_to_train) + len(test_not_add_to_train) == len(test)
print(len(test_add_to_train))
print(len(test_not_add_to_train))

49557
72664


In [10]:
print("test追加前", len(train))
train = pl.concat([train, test_add_to_train])
print("test追加后", len(train))

test追加前 333533
test追加后 383090


In [11]:
# session_id
train = train.with_columns(pl.Series(name="session_id", values=["train_" + str(i) for i in range(len(train))]))
test_not_add_to_train = test_not_add_to_train.with_columns(pl.Series(name="session_id", values=["test_leftover_" + str(i) for i in range(len(test_not_add_to_train))]))

In [12]:
train.head()

prev_items,next_item,locale,session_id
list[str],str,str,str
"[""B08MV5B53K"", ""B08MV4RCQR"", ""B08MV5B53K""]","""B012408XPC""","""ES""","""train_0"""
"[""B07JGW4QWX"", ""B085VCXHXL""]","""B07JFPYN5P""","""ES""","""train_1"""
"[""B08BFQ52PR"", ""B08LVSTZVF"", ""B08BFQ52PR""]","""B08NJP3KT6""","""ES""","""train_2"""
"[""B08PPBF9C6"", ""B08PPBF9C6"", … ""B08PPBF9C6""]","""B08PP6BLLK""","""ES""","""train_3"""
"[""B0B6W67XCR"", ""B0B712FY2M"", ""B0B6ZYJ3S2""]","""B09SL4MBM2""","""ES""","""train_4"""


In [13]:
test_add_to_train.head()

prev_items,next_item,locale
list[str],str,str
"[""B08GYKNCCP"", ""B08HCPTMJG""]","""B08HCHS64Y""","""ES"""
"[""B09YM11D4T"", ""B0B12QWP5G"", … ""B0B12QWP5G""]","""B07N8N6C85""","""ES"""
"[""B08D9PKL3W"", ""B09CQ72HCJ"", ""B09CQ7H87G""]","""B08D9PGC9P""","""ES"""
"[""B0B9ZW2RPV"", ""B08DKFQFJH""]","""B07ZKKZXCX""","""ES"""
"[""B08MFH1TTJ"", ""B08MFDT65P"", ""B0968HW8GY""]","""B0968HFSMH""","""ES"""


In [14]:
test_not_add_to_train.head()

prev_items,locale,session_id
list[str],str,str
"[""B08NYF9MBQ"", ""B085NGXGWM""]","""ES""","""test_leftover_…"
"[""B091FL1QFK"", ""B0B1DG29F4""]","""ES""","""test_leftover_…"
"[""B004APAHCW"", ""B07JMF49HN"", … ""B07JMF49HN""]","""ES""","""test_leftover_…"
"[""B07TX86KFZ"", ""B0882ZCHMW"", ""B07TX86KFZ""]","""ES""","""test_leftover_…"
"[""B08FJ3MR54"", ""B0BBM523JX""]","""ES""","""test_leftover_…"


In [15]:
train.write_parquet("/content/drive/MyDrive/kddcup2023/data/preprocessed/task2/train_task2_augmented.parquet")

In [16]:
test_not_add_to_train.write_parquet("/content/drive/MyDrive/kddcup2023/data/preprocessed/task2/test_task2_leftover.parquet")