In [1]:
import os
import random

from joblib import Parallel, delayed
import numpy as np
import pandas as pd
# from tqdm.auto import tqdm

In [2]:
# Project root: the "atmacup16" folder inside the current user's home directory.
_home_dir = os.path.expanduser("~")
DIRECTORY = os.path.join(_home_dir, "atmacup16")

In [6]:
%%time
# 都道府県ごとのランキング
label = pd.read_csv(os.path.join(DIRECTORY, "input", "train_label.csv"))
yado = pd.read_csv(os.path.join(DIRECTORY, "input", "yado.csv"))
count_by_prefecture = pd.merge(label, yado).groupby("ken_cd")["yad_no"].value_counts().sort_values(ascending=False)
count_by_prefecture

CPU times: total: 344 ms
Wall time: 376 ms


ken_cd                            yad_no
572d60f0f5212aacda515ebf81fb0a3a  10095     605
                                  3338      576
                                  12350     533
                                  8553      427
107c7305a74c8dcc4f143de208bf7ec2  385       347
                                           ... 
84efa50e52f9b471c95bfc3b21b854ad  9367        1
                                  9870        1
                                  10130       1
                                  11057       1
fec19ba0016c012f3a06360cfff6da32  4506        1
Name: count, Length: 12381, dtype: int64

In [7]:
%%time
# Test-session browsing logs: one row per (session_id, seq_no) view of a yad_no.
test_log_path = os.path.join(DIRECTORY, "input", "test_log.csv")
log_test = pd.read_csv(test_log_path)
log_test

CPU times: total: 141 ms
Wall time: 131 ms


Unnamed: 0,session_id,seq_no,yad_no
0,00001149e9c73985425197104712478c,0,3560
1,00001149e9c73985425197104712478c,1,1959
2,0000e02747d749a52b7736dfa751e258,0,11984
3,0000f17ae2628237d78d3a38b009d3be,0,757
4,0000f17ae2628237d78d3a38b009d3be,1,8922
...,...,...,...
250300,fffee3199ef94b92283239cd5e3534fa,1,8336
250301,ffff62c6bb49bc9c0fbcf08494a4869c,0,12062
250302,ffff9a7dcc892875c7a8b821fa436228,0,8989
250303,ffffb1d30300fe17f661941fd085b04b,0,6030


In [4]:
# Precomputed yado co-occurrence matrix; elsewhere in this notebook row/column i
# is treated as yad_no i + 1 (it is indexed with `yad_no - 1`).
cooccurrence_path = os.path.join(DIRECTORY, "features", "cooccurance_rate.npy")
co_occurance_rate = np.load(cooccurrence_path)
co_occurance_rate.shape

(13806, 13806)

In [8]:
# Every yad_no (1-based), one per row of the co-occurrence matrix.
yad_numbers = list(range(1, len(co_occurance_rate) + 1))

In [19]:
K = 10  # number of predictions produced per session (predict_0 ... predict_{K-1})


def get_prediction(
    session_id: str,
    session_df: pd.DataFrame,
    k: int = K,
    cooccurrence=None,
    yado_df=None,
    ranking_by_prefecture=None,
) -> dict:
    """Build up to ``k`` ranked yado predictions for a single session.

    Candidates are collected in priority order until ``k`` are found:

    1. yados viewed in the session (in first-view order), except the last one;
    2. yados with the highest ``cooccurrence`` value against any session yado;
    3. yados from ``ranking_by_prefecture`` for the prefecture (``ken_cd``)
       that appears most often among the session's log rows;
    4. random yados, drawn from an RNG seeded with ``session_id`` so the output
       is reproducible across runs and across joblib worker processes.

    The last viewed yado and already-chosen yados are never added again.

    Args:
        session_id: Session identifier; stored in the result and used as RNG seed.
        session_df: Non-empty log rows of one session with ``seq_no`` and
            ``yad_no`` columns. It is not modified.
        k: Number of predictions to produce.
        cooccurrence: Square array whose row/column ``i`` corresponds to
            ``yad_no`` ``i + 1``. Defaults to the notebook's ``co_occurance_rate``.
        yado_df: Table with ``yad_no`` and ``ken_cd`` columns. Defaults to the
            notebook's ``yado``.
        ranking_by_prefecture: Series indexed by (``ken_cd``, ``yad_no``) and
            sorted by count, highest first. Defaults to ``count_by_prefecture``.

    Returns:
        ``{"session_id": session_id, "predict_0": ..., ...}``. Fewer than ``k``
        ``predict_*`` keys are returned only when there are not enough distinct
        yados left to choose from.
    """
    # Resolve notebook-level tables at call time so existing callers keep working.
    if cooccurrence is None:
        cooccurrence = co_occurance_rate
    if yado_df is None:
        yado_df = yado
    if ranking_by_prefecture is None:
        ranking_by_prefecture = count_by_prefecture

    n_yado = cooccurrence.shape[0]

    # Sort a copy: sorting the caller's frame in place would mutate the groupby slice.
    session_sorted = session_df.sort_values("seq_no")
    session_yads = session_sorted["yad_no"].unique().tolist()
    yad_no_last = session_sorted["yad_no"].iloc[-1]

    # `seen` holds every yado that must not be added again (includes the last one).
    seen = set(session_yads)

    # 1. Yados viewed during the session are candidates, except the last one.
    candidates = [no for no in session_yads if no != yad_no_last]

    # 2. Yados with the highest co-occurrence against any session yado.
    if len(candidates) < k:
        session_idx = np.asarray(session_yads) - 1
        rows = cooccurrence[session_idx]
        # Best value per yado across all session rows; NaN counts as "no co-occurrence".
        best = np.where(np.isnan(rows), 0.0, rows).max(axis=0)
        eligible = best > 0
        eligible[session_idx] = False
        idx = np.flatnonzero(eligible)
        idx = idx[np.argsort(-best[idx], kind="stable")]
        extra = (idx[: k - len(candidates)] + 1).tolist()
        candidates += extra
        seen.update(extra)

    # 3. Ranked yados of the session's most frequently viewed prefecture.
    if len(candidates) < k:
        prefectures = session_sorted.merge(
            yado_df[["yad_no", "ken_cd"]], on="yad_no"
        )["ken_cd"].mode()
        # Guard: no matching yado rows, or a prefecture absent from the ranking.
        if not prefectures.empty and prefectures.iloc[0] in ranking_by_prefecture.index.get_level_values(0):
            for no in ranking_by_prefecture.loc[prefectures.iloc[0]].index.tolist():
                if len(candidates) >= k:
                    break
                if no not in seen:
                    candidates.append(no)
                    seen.add(no)

    # 4. Random fill from the remaining yados; sampling without replacement
    #    cannot loop forever when fewer than k yados are left.
    if len(candidates) < k:
        rng = random.Random(session_id)
        pool = [no for no in range(1, n_yado + 1) if no not in seen]
        candidates += rng.sample(pool, min(k - len(candidates), len(pool)))

    prediction = {"session_id": session_id}
    for i, no in enumerate(candidates[:k]):
        prediction[f"predict_{i}"] = no
    return prediction

In [20]:
# One prediction task per test session, executed across 7 worker processes.
jobs = (
    delayed(get_prediction)(session_id, session_log)
    for session_id, session_log in log_test.groupby("session_id")
)
raw_predictions = Parallel(n_jobs=7, verbose=1)(jobs)
predictions = pd.DataFrame(raw_predictions).set_index("session_id").sort_index()
predictions

[Parallel(n_jobs=7)]: Using backend LokyBackend with 7 concurrent workers.
[Parallel(n_jobs=7)]: Done  36 tasks      | elapsed:    2.0s
[Parallel(n_jobs=7)]: Done 1054 tasks      | elapsed:    4.7s
[Parallel(n_jobs=7)]: Done 3054 tasks      | elapsed:    9.9s
[Parallel(n_jobs=7)]: Done 5854 tasks      | elapsed:   17.6s
[Parallel(n_jobs=7)]: Done 9454 tasks      | elapsed:   27.0s
[Parallel(n_jobs=7)]: Done 13854 tasks      | elapsed:   38.0s
[Parallel(n_jobs=7)]: Done 19054 tasks      | elapsed:   50.9s
[Parallel(n_jobs=7)]: Done 25054 tasks      | elapsed:  1.1min
[Parallel(n_jobs=7)]: Done 31854 tasks      | elapsed:  1.4min
[Parallel(n_jobs=7)]: Done 39454 tasks      | elapsed:  1.7min
[Parallel(n_jobs=7)]: Done 47854 tasks      | elapsed:  2.0min
[Parallel(n_jobs=7)]: Done 57054 tasks      | elapsed:  2.4min
[Parallel(n_jobs=7)]: Done 67054 tasks      | elapsed:  2.8min
[Parallel(n_jobs=7)]: Done 77854 tasks      | elapsed:  3.4min
[Parallel(n_jobs=7)]: Done 89454 tasks      | ela

Unnamed: 0_level_0,predict_0,predict_1,predict_2,predict_3,predict_4,predict_5,predict_6,predict_7,predict_8,predict_9
session_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
00001149e9c73985425197104712478c,3560,9534,4545,11561,2680,4714,4420,5785,5466,6488
0000e02747d749a52b7736dfa751e258,143,4066,6555,11923,7014,8108,12862,613,6129,7913
0000f17ae2628237d78d3a38b009d3be,757,9190,7710,9910,410,1774,10485,3400,6721,13570
000174a6f7a569b84c5575760d2e9664,12341,6991,3359,13521,1542,5080,10861,2795,4180,10746
00017e2a527901c9c41b1acef525d016,9020,2862,10826,12029,9623,3854,3476,3844,13235,6161
...,...,...,...,...,...,...,...,...,...,...
fffee3199ef94b92283239cd5e3534fa,1997,7888,1885,11123,8771,5744,9743,10997,7062,7641
ffff62c6bb49bc9c0fbcf08494a4869c,13220,12432,899,4014,1227,3802,3644,2164,2232,9723
ffff9a7dcc892875c7a8b821fa436228,13241,13797,12939,11037,2087,7308,8143,13719,6300,844
ffffb1d30300fe17f661941fd085b04b,3100,3002,2373,12281,10287,13672,4976,5513,2692,1687


In [21]:
%%time
# Test sessions in submission order; predictions are attached to this frame below.
test_session_path = os.path.join(DIRECTORY, "input", "test_session.csv")
test_session = pd.read_csv(test_session_path)
test_session

CPU times: total: 78.1 ms
Wall time: 85.6 ms


Unnamed: 0,session_id
0,00001149e9c73985425197104712478c
1,0000e02747d749a52b7736dfa751e258
2,0000f17ae2628237d78d3a38b009d3be
3,000174a6f7a569b84c5575760d2e9664
4,00017e2a527901c9c41b1acef525d016
...,...
174695,fffee3199ef94b92283239cd5e3534fa
174696,ffff62c6bb49bc9c0fbcf08494a4869c
174697,ffff9a7dcc892875c7a8b821fa436228
174698,ffffb1d30300fe17f661941fd085b04b


In [22]:
# Align predictions to the row order of test_session, then copy them in positionally.
prediction_columns = predictions.columns
aligned = predictions.loc[test_session["session_id"], prediction_columns]
test_session[prediction_columns] = aligned.to_numpy()
test_session

Unnamed: 0,session_id,predict_0,predict_1,predict_2,predict_3,predict_4,predict_5,predict_6,predict_7,predict_8,predict_9
0,00001149e9c73985425197104712478c,3560,9534,4545,11561,2680,4714,4420,5785,5466,6488
1,0000e02747d749a52b7736dfa751e258,143,4066,6555,11923,7014,8108,12862,613,6129,7913
2,0000f17ae2628237d78d3a38b009d3be,757,9190,7710,9910,410,1774,10485,3400,6721,13570
3,000174a6f7a569b84c5575760d2e9664,12341,6991,3359,13521,1542,5080,10861,2795,4180,10746
4,00017e2a527901c9c41b1acef525d016,9020,2862,10826,12029,9623,3854,3476,3844,13235,6161
...,...,...,...,...,...,...,...,...,...,...,...
174695,fffee3199ef94b92283239cd5e3534fa,1997,7888,1885,11123,8771,5744,9743,10997,7062,7641
174696,ffff62c6bb49bc9c0fbcf08494a4869c,13220,12432,899,4014,1227,3802,3644,2164,2232,9723
174697,ffff9a7dcc892875c7a8b821fa436228,13241,13797,12939,11037,2087,7308,8143,13719,6300,844
174698,ffffb1d30300fe17f661941fd085b04b,3100,3002,2373,12281,10287,13672,4976,5513,2692,1687


In [23]:
# The submission file keeps only the predict_* columns, in test_session row order.
submission_path = os.path.join(DIRECTORY, "submissions", "exp003.csv")
submission = test_session.drop(columns=["session_id"])
submission.to_csv(submission_path, index=False)