In [1]:
# Author: Jpeng Liao
# Date: 2023-11-20
# Contributor: 
from pathlib import Path
data_root = Path('../kkdata3')
for x in data_root.glob('*'):
    print(x)

..\kkdata3\label_test_source.parquet
..\kkdata3\label_train_source.parquet
..\kkdata3\label_train_target.parquet
..\kkdata3\meta_song.parquet
..\kkdata3\meta_song_composer.parquet
..\kkdata3\meta_song_genre.parquet
..\kkdata3\meta_song_lyricist.parquet
..\kkdata3\meta_song_producer.parquet
..\kkdata3\meta_song_titletext.parquet
..\kkdata3\sample.csv


In [2]:
import pandas as pd
train_source = pd.read_parquet(data_root / 'label_train_source.parquet')
train_target = pd.read_parquet(data_root / 'label_train_target.parquet')
test_source = pd.read_parquet(data_root / 'label_test_source.parquet')
meta_song = pd.read_parquet(data_root / 'meta_song.parquet')

In [3]:
# sorting data by session_id and listening_order
train_source.sort_values(['session_id', 'listening_order'], inplace=True)
test_source.sort_values(['session_id', 'listening_order'], inplace=True)

In [4]:
def shift_data(df, n):
    for i in range(1, 1+n):
        df[f'next{i}_song_id'] = df['song_id'].shift(-i)
    df[f'next{n}_session_id'] = df['session_id'].shift(-n)
    df = df.query(f'session_id == next{n}_session_id')[['song_id']+[f'next{_}_song_id' for _ in range(1, 1+n)]]
    return df

In [5]:
def get_freq(train, test, n):
    train_shift = shift_data(train.copy(), n)
    test_shift = shift_data(test.copy(), n)
    df_shift = pd.concat([train_shift, test_shift], axis=0, ignore_index=True)
    df_freq = df_shift.groupby(['song_id']+[f'next{_}_song_id' for _ in range(1, n)]).value_counts(sort=True, normalize=True).reset_index(name='freq')
    df_freq = df_freq.sort_values(['song_id', 'next1_song_id', 'freq'], ascending=False).groupby(['song_id', 'next1_song_id']).head(1)
    return df_freq

In [34]:
def dataX(df, n):
    df_l = df.query(f'listening_order == {21-n}')[['session_id', 'song_id']].set_index('session_id')
    if n == 1:
        # df_r = df.query('listening_order == 21')[['session_id', 'song_id']].set_index('session_id')
        # df_l = df_l.join(df_r, lsuffix='_0', rsuffix='_1')
        return df_l.rename(columns={'song_id': 'song_id_0'}).reset_index(inplace=False)
    
    for i in range(1, n):
        df_r = df.query(f'listening_order == {21-n+i}')[['session_id', 'song_id']].set_index('session_id')
        df_l = df_l.join(df_r, lsuffix=f'_{i-1}', rsuffix=f'_{i}')
        df_l.rename(columns={'song_id': f'song_id_{i}'}, inplace=True)
    return df_l.reset_index(inplace=False)

def dataY(df):
    df_target = df.pivot_table(index='session_id', columns='listening_order', values='song_id', aggfunc='first')
    df_target.rename(columns={i: f'song_id_{i}' for i in range(21, 26)}, inplace=True)
    return df_target

def data_process(train, target, n):
    X = dataX(train.copy(), n)
    y = dataY(target.copy())
    return X.join(y, how='inner', on='session_id').reset_index(inplace=False)

In [35]:
def predict(df_data, df_predict, n):
    df_result = pd.merge(df_data, df_predict, how='left', left_on=[f'song_id_{i}' for i in range(n)], right_on=['song_id']+[f'next{_}_song_id' for _ in range(1, n)])
    del df_result['song_id']
    for i in range(1, n):
        del df_result[f'next{i}_song_id']
    df_result.rename(columns={f'next{n}_song_id': f'next_song_id'}, inplace=True)
    return df_result

In [42]:
N = 1
freq_table = get_freq(train_source, test_source, N)
data = dataX(test_source.copy(), N)
result1 = predict(data, freq_table, N).groupby(['session_id', 'song_id_0']).head(1)

In [16]:
N = 2
freq_table = get_freq(train_source, test_source, N)
data = dataX(test_source.copy(), N)
result2 = predict(data, freq_table, N)

In [22]:
N = 3
freq_table = get_freq(train_source, test_source, N)
data = dataX(test_source.copy(), N)
result3 = predict(data, freq_table, N)

In [56]:
N = 4
freq_table = get_freq(train_source, test_source, N)
data = dataX(test_source.copy(), N)
result4 = predict(data, freq_table, N)

In [57]:
N = 5
freq_table = get_freq(train_source, test_source, N)
data = dataX(test_source.copy(), N)
result5 = predict(data, freq_table, N)

In [64]:
N = 6
freq_table = get_freq(train_source, test_source, N)
data = dataX(test_source.copy(), N)
result6 = predict(data, freq_table, N)

In [65]:
N = 7
freq_table = get_freq(train_source, test_source, N)
data = dataX(test_source.copy(), N)
result7 = predict(data, freq_table, N)

In [66]:
N = 8
freq_table = get_freq(train_source, test_source, N)
data = dataX(test_source.copy(), N)
result8 = predict(data, freq_table, N)

In [58]:
# result_tmp = result2.copy()
# result_tmp.fillna(result1, inplace=True)
# result_tmp['next_song_id'].fillna(result_tmp['song_id_1'], inplace=True)
# final_result = result_tmp[['session_id', 'next_song_id', 'next_song_id', 'next_song_id', 'next_song_id', 'next_song_id']]
# final_result.columns = ['session_id', 'top1', 'top2', 'top3', 'top4', 'top5']
# final_result

In [59]:
# result3.fillna(result2, inplace=True)
# # result3.fillna(result1, inplace=True)
# result3['next_song_id'].fillna(result3['song_id_2'], inplace=True)

# final_result = result3[['session_id', 'next_song_id', 'next_song_id', 'next_song_id', 'next_song_id', 'next_song_id']]
# final_result.columns = ['session_id', 'top1', 'top2', 'top3', 'top4', 'top5']
# final_result

In [71]:
result8.fillna(result7, inplace=True)
result8.fillna(result6, inplace=True)
result8.fillna(result5, inplace=True)
result8.fillna(result4, inplace=True)
result8.fillna(result3, inplace=True)
result8.fillna(result2, inplace=True)
result8['next_song_id'].fillna(result8['song_id_2'], inplace=True)

final_result = result8[['session_id', 'next_song_id', 'next_song_id', 'next_song_id', 'next_song_id', 'next_song_id']]
final_result.columns = ['session_id', 'top1', 'top2', 'top3', 'top4', 'top5']
final_result

Unnamed: 0,session_id,top1,top2,top3,top4,top5
0,8,76723b980445c7c7b1350ca038348bbb,76723b980445c7c7b1350ca038348bbb,76723b980445c7c7b1350ca038348bbb,76723b980445c7c7b1350ca038348bbb,76723b980445c7c7b1350ca038348bbb
1,9,7653de936f26c9b71c728e88cdd29c1a,7653de936f26c9b71c728e88cdd29c1a,7653de936f26c9b71c728e88cdd29c1a,7653de936f26c9b71c728e88cdd29c1a,7653de936f26c9b71c728e88cdd29c1a
2,18,9cb10772a82000f45de4a950883df945,9cb10772a82000f45de4a950883df945,9cb10772a82000f45de4a950883df945,9cb10772a82000f45de4a950883df945,9cb10772a82000f45de4a950883df945
3,19,d6813a67f316942de54f967fda79abb2,d6813a67f316942de54f967fda79abb2,d6813a67f316942de54f967fda79abb2,d6813a67f316942de54f967fda79abb2,d6813a67f316942de54f967fda79abb2
4,28,20171b381e293bda502851778947ce57,20171b381e293bda502851778947ce57,20171b381e293bda502851778947ce57,20171b381e293bda502851778947ce57,20171b381e293bda502851778947ce57
...,...,...,...,...,...,...
143059,715299,e2ff3605e7ef9864e57496db028bdcff,e2ff3605e7ef9864e57496db028bdcff,e2ff3605e7ef9864e57496db028bdcff,e2ff3605e7ef9864e57496db028bdcff,e2ff3605e7ef9864e57496db028bdcff
143060,715308,1b23dec79509ed41a743ac7336998db5,1b23dec79509ed41a743ac7336998db5,1b23dec79509ed41a743ac7336998db5,1b23dec79509ed41a743ac7336998db5,1b23dec79509ed41a743ac7336998db5
143061,715309,b83d7912b15c97dea1ec15d9ef7acf6e,b83d7912b15c97dea1ec15d9ef7acf6e,b83d7912b15c97dea1ec15d9ef7acf6e,b83d7912b15c97dea1ec15d9ef7acf6e,b83d7912b15c97dea1ec15d9ef7acf6e
143062,715318,3752912c9a7012dbd8b7f1cb75cebeb9,3752912c9a7012dbd8b7f1cb75cebeb9,3752912c9a7012dbd8b7f1cb75cebeb9,3752912c9a7012dbd8b7f1cb75cebeb9,3752912c9a7012dbd8b7f1cb75cebeb9


In [72]:
# 輸出檔案
final_result.to_csv('submission.csv', index=False)

## 訓練驗證

In [187]:
N = 1
freq_table = get_freq(train_source, test_source, N)
data = data_process(train_source.copy(), train_target.copy(), N)
result1 = predict(data, freq_table, N)

In [189]:
N = 2
freq_table = get_freq(train_source, test_source, N)
data = data_process(train_source.copy(), train_target.copy(), N)
result2 = predict(data, freq_table, N)

In [190]:
N = 3
freq_table = get_freq(train_source, test_source, N)
data = data_process(train_source.copy(), train_target.copy(), N)
result3 = predict(data, freq_table, N)

In [200]:
data

Unnamed: 0,session_id,song_id_0,song_id_1,song_id_2,song_id_21,song_id_22,song_id_23,song_id_24,song_id_25
0,1,2ad3043e1a7e459ddb09c5ba27e475f8,7bb8fadfc8f2bf145f4b29a0325fe79a,824c159701c8553b0e38f0d36ddd6197,b186d853cb06ceba2bd56bcdc701b8aa,1071e128ee5f9e2a0f0be0aade025b39,0100af68f3477f6bf724664d0b303f29,22a696cede3609c20ba1bceb8de28032,d6476892c926c88c7f232fd800b75845
1,2,ae75e2846669037aa91cac098e5009c3,3030938d53f52426ac30f213cf9915a1,6309716e08a58c64871c823f44749686,cd6dade76334f76d993b913d10bb2ac7,822486e25be33911a62b3f12e549d24d,ed37d73c29e3696f65817916316bf05d,197a95a09a55390eb435e35f7dab9f9f,59844ef9a5e469cd13b1208c4f1142d1
2,3,e57c39b5735364b56ed8f743fa948697,e904a955c0350ca5b1b6bb84174ee5be,3b050c1502af19554bb9ef8efc67c00d,a17aab4730e49f8e2972bf26227b890c,868bbeb92bd33a728dcf77d0073cd44d,c4d98c14bd48ec0495239c1a65a5178c,4ea79c3619515a39423851b5ceca965e,d76c868bfb392d89b3ef222907de69fc
3,4,a288350a1332f6716cde4eedbbf8443f,6650e87fecc3e969c0d48f7f5e9b2c81,e29a8f57507d7e4a926a6bd3d51841d3,f6c0c5501b42f208e094de2414ca1167,e29a8f57507d7e4a926a6bd3d51841d3,f6c0c5501b42f208e094de2414ca1167,e29a8f57507d7e4a926a6bd3d51841d3,f6c0c5501b42f208e094de2414ca1167
4,5,aaf15c1b9400651add7925c09b609461,c10efc962e6b703053c2b44b7d48e8da,fd96fd2df80abfae763c3693932f8b26,f5606954ce38e9948ed7099251d7e339,995e60ea01e4f371dabd785e76aa5e2f,f5606954ce38e9948ed7099251d7e339,995e60ea01e4f371dabd785e76aa5e2f,b3ac505ffec601de591ce2ce2e010d30
...,...,...,...,...,...,...,...,...,...
572254,715317,78c792b3afc368829431ef50a3d871dd,0f59db4259ec62fcf00399b69786a3c0,cb6e848eff5f656c47a4541e6be72079,3c66c29e5c26efc705c261732b977950,a0c8441cbe902320c846da76329a38e6,5d043bbef0ca3cf5385432478f3d3751,3b48c6883b4d16ddc07cb64b7be38d1c,546da46c864f38bb193940a7b8704f44
572255,715320,9702af9ea15361bf6b178e073514660d,477ecc7fb10644c0d31408030cd464c6,aa7ce52650c7e4f7979e087cedb89020,b4b1c6c0e7395c50f2cf0702dd1f85d6,dd45bfa90e9186681877c15eb0166a35,c366d245501f007629813a28c9267d11,f4aabaca763955330cdf57f763d2a72c,f4aabaca763955330cdf57f763d2a72c
572256,715321,b784c2c425ed31d70f7007e949f6b177,9e65aa27a1c3060b5909dda36859f39f,fba1bc1cf12692b1dd65c98e79985e52,261bceb1e3440067b85f5c54444d91be,65961a80125282df5c95b4c0a3606cc5,a91ec5b2b51524c53ce866f621208087,ab8683d279c095064693c282b8353b98,01c48d6a895a6efed0d5afba83b87902
572257,715322,826cf1a6e383fdbd900c3865d2edb010,940ed872f8ae21d392495623c29df127,ceee2a16d7fb7531cdc633ac478fe46a,c31a95cdd5635c6edfbde9143286b9b7,aa41008b812435306a6fe831f40b0c11,d1492ba4a2fdf555510934e9740ea8c3,f1bc21ff88fe47bac896fbe591fa2f5f,d1492ba4a2fdf555510934e9740ea8c3


In [191]:
result1['next_song_id'].notna().sum() / len(result1)

0.9999515763119805

In [192]:
result2['next_song_id'].notna().sum() / len(result2)

0.4307822157449686

In [193]:
result3['next_song_id'].notna().sum() / len(result3)

0.12841038760421417

In [150]:
result3.fillna(result2, inplace=True)
result3.fillna(result1, inplace=True)
result3['next_song_id'].notna().sum() / len(result3)

0.9999353439613882

In [194]:
result3['next_song_id'].fillna(result3['song_id_2'], inplace=True)
result3['next_song_id'].notna().sum() / len(result3)

1.0

In [201]:
result3

Unnamed: 0,session_id,song_id_0,song_id_1,song_id_2,song_id_21,song_id_22,song_id_23,song_id_24,song_id_25,next_song_id,freq
0,1,2ad3043e1a7e459ddb09c5ba27e475f8,7bb8fadfc8f2bf145f4b29a0325fe79a,824c159701c8553b0e38f0d36ddd6197,b186d853cb06ceba2bd56bcdc701b8aa,1071e128ee5f9e2a0f0be0aade025b39,0100af68f3477f6bf724664d0b303f29,22a696cede3609c20ba1bceb8de28032,d6476892c926c88c7f232fd800b75845,824c159701c8553b0e38f0d36ddd6197,
1,2,ae75e2846669037aa91cac098e5009c3,3030938d53f52426ac30f213cf9915a1,6309716e08a58c64871c823f44749686,cd6dade76334f76d993b913d10bb2ac7,822486e25be33911a62b3f12e549d24d,ed37d73c29e3696f65817916316bf05d,197a95a09a55390eb435e35f7dab9f9f,59844ef9a5e469cd13b1208c4f1142d1,6309716e08a58c64871c823f44749686,
2,3,e57c39b5735364b56ed8f743fa948697,e904a955c0350ca5b1b6bb84174ee5be,3b050c1502af19554bb9ef8efc67c00d,a17aab4730e49f8e2972bf26227b890c,868bbeb92bd33a728dcf77d0073cd44d,c4d98c14bd48ec0495239c1a65a5178c,4ea79c3619515a39423851b5ceca965e,d76c868bfb392d89b3ef222907de69fc,3b050c1502af19554bb9ef8efc67c00d,
3,4,a288350a1332f6716cde4eedbbf8443f,6650e87fecc3e969c0d48f7f5e9b2c81,e29a8f57507d7e4a926a6bd3d51841d3,f6c0c5501b42f208e094de2414ca1167,e29a8f57507d7e4a926a6bd3d51841d3,f6c0c5501b42f208e094de2414ca1167,e29a8f57507d7e4a926a6bd3d51841d3,f6c0c5501b42f208e094de2414ca1167,e29a8f57507d7e4a926a6bd3d51841d3,
4,5,aaf15c1b9400651add7925c09b609461,c10efc962e6b703053c2b44b7d48e8da,fd96fd2df80abfae763c3693932f8b26,f5606954ce38e9948ed7099251d7e339,995e60ea01e4f371dabd785e76aa5e2f,f5606954ce38e9948ed7099251d7e339,995e60ea01e4f371dabd785e76aa5e2f,b3ac505ffec601de591ce2ce2e010d30,fd96fd2df80abfae763c3693932f8b26,
...,...,...,...,...,...,...,...,...,...,...,...
572254,715317,78c792b3afc368829431ef50a3d871dd,0f59db4259ec62fcf00399b69786a3c0,cb6e848eff5f656c47a4541e6be72079,3c66c29e5c26efc705c261732b977950,a0c8441cbe902320c846da76329a38e6,5d043bbef0ca3cf5385432478f3d3751,3b48c6883b4d16ddc07cb64b7be38d1c,546da46c864f38bb193940a7b8704f44,3c66c29e5c26efc705c261732b977950,1.0
572255,715320,9702af9ea15361bf6b178e073514660d,477ecc7fb10644c0d31408030cd464c6,aa7ce52650c7e4f7979e087cedb89020,b4b1c6c0e7395c50f2cf0702dd1f85d6,dd45bfa90e9186681877c15eb0166a35,c366d245501f007629813a28c9267d11,f4aabaca763955330cdf57f763d2a72c,f4aabaca763955330cdf57f763d2a72c,aa7ce52650c7e4f7979e087cedb89020,
572256,715321,b784c2c425ed31d70f7007e949f6b177,9e65aa27a1c3060b5909dda36859f39f,fba1bc1cf12692b1dd65c98e79985e52,261bceb1e3440067b85f5c54444d91be,65961a80125282df5c95b4c0a3606cc5,a91ec5b2b51524c53ce866f621208087,ab8683d279c095064693c282b8353b98,01c48d6a895a6efed0d5afba83b87902,fba1bc1cf12692b1dd65c98e79985e52,
572257,715322,826cf1a6e383fdbd900c3865d2edb010,940ed872f8ae21d392495623c29df127,ceee2a16d7fb7531cdc633ac478fe46a,c31a95cdd5635c6edfbde9143286b9b7,aa41008b812435306a6fe831f40b0c11,d1492ba4a2fdf555510934e9740ea8c3,f1bc21ff88fe47bac896fbe591fa2f5f,d1492ba4a2fdf555510934e9740ea8c3,c31a95cdd5635c6edfbde9143286b9b7,1.0


In [198]:
final_result = result3[['session_id', 'next_song_id', 'next_song_id', 'next_song_id', 'next_song_id', 'next_song_id']]
final_result.columns = ['session_id', 'top1', 'top2', 'top3', 'top4', 'top5']
final_result

Unnamed: 0,session_id,top1,top2,top3,top4,top5
0,1,824c159701c8553b0e38f0d36ddd6197,824c159701c8553b0e38f0d36ddd6197,824c159701c8553b0e38f0d36ddd6197,824c159701c8553b0e38f0d36ddd6197,824c159701c8553b0e38f0d36ddd6197
1,2,6309716e08a58c64871c823f44749686,6309716e08a58c64871c823f44749686,6309716e08a58c64871c823f44749686,6309716e08a58c64871c823f44749686,6309716e08a58c64871c823f44749686
2,3,3b050c1502af19554bb9ef8efc67c00d,3b050c1502af19554bb9ef8efc67c00d,3b050c1502af19554bb9ef8efc67c00d,3b050c1502af19554bb9ef8efc67c00d,3b050c1502af19554bb9ef8efc67c00d
3,4,e29a8f57507d7e4a926a6bd3d51841d3,e29a8f57507d7e4a926a6bd3d51841d3,e29a8f57507d7e4a926a6bd3d51841d3,e29a8f57507d7e4a926a6bd3d51841d3,e29a8f57507d7e4a926a6bd3d51841d3
4,5,fd96fd2df80abfae763c3693932f8b26,fd96fd2df80abfae763c3693932f8b26,fd96fd2df80abfae763c3693932f8b26,fd96fd2df80abfae763c3693932f8b26,fd96fd2df80abfae763c3693932f8b26
...,...,...,...,...,...,...
572254,715317,3c66c29e5c26efc705c261732b977950,3c66c29e5c26efc705c261732b977950,3c66c29e5c26efc705c261732b977950,3c66c29e5c26efc705c261732b977950,3c66c29e5c26efc705c261732b977950
572255,715320,aa7ce52650c7e4f7979e087cedb89020,aa7ce52650c7e4f7979e087cedb89020,aa7ce52650c7e4f7979e087cedb89020,aa7ce52650c7e4f7979e087cedb89020,aa7ce52650c7e4f7979e087cedb89020
572256,715321,fba1bc1cf12692b1dd65c98e79985e52,fba1bc1cf12692b1dd65c98e79985e52,fba1bc1cf12692b1dd65c98e79985e52,fba1bc1cf12692b1dd65c98e79985e52,fba1bc1cf12692b1dd65c98e79985e52
572257,715322,c31a95cdd5635c6edfbde9143286b9b7,c31a95cdd5635c6edfbde9143286b9b7,c31a95cdd5635c6edfbde9143286b9b7,c31a95cdd5635c6edfbde9143286b9b7,c31a95cdd5635c6edfbde9143286b9b7


In [202]:
w = [1.0, 0.63, 0.5, 0.43, 0.38]
s = 0
for i in range(5):
    m = (result3[f'song_id_2{i+1}'] == result3['next_song_id']).mean()
    print(i, m)
    s += m*w[i]
print(s*0.8)

0 0.16386111882906165
1 0.06854763315212167
2 0.04937274905243954
3 0.04990048212435278
4 0.04471052443037156
0.21614376707050476


In [203]:
# correct way to calculate score
import numpy as np
A = result3[[f'song_id_2{i+1}' for i in range(5)]].values == result3[['next_song_id']].values
s = (A * np.array(w).reshape(1, -1)).max(axis=1).mean()
s

0.1872690687258741

In [204]:
a1 = test_source['session_id'].nunique()/meta_song.shape[0] * 0.2
a1

0.027760227881309232

In [205]:
print(s*0.8+a1*4 + result3['next_song_id'].nunique()/meta_song.shape[0] * 0.2)

0.2894166179220447


In [206]:
w = np.array([1.0, 0.63, 0.5, 0.43, 0.38])
A = result3[[f'song_id_2{i+1}' for i in range(5)]].values == result3[['next_song_id']].values
s = (A * w.reshape(1, -1)).max(axis=1).mean()
s

0.1872690687258741

In [207]:
print(s*0.8+a1*4 + result3['next_song_id'].nunique()/meta_song.shape[0] * 0.2)

0.2894166179220447
