In [1]:
import copy
import random
import catboost
import numpy as np
from tqdm.notebook import tqdm

In [2]:
np.argsort([3, 2, 1, 5])[::-1]

array([3, 0, 1, 2])

### Dataset preprocessing

In [9]:
def preprocess_dataset(path="data"):
    print("Preprocessing songs.csv ...")
    with open(path + "/songs.csv") as f:
        first_readed = False
        song2id = {}
        artist_id2name = {}
        artist2id = {}
        composer2id = {}
        lyrics2id = {}
        song_id2data = {}
        genre2id = {}
        genre2ctn =[0 for _ in range(191)]
        for i, l in enumerate(f):
            if not first_readed:
                first_readed = True
                continue
            parts = l[:-1].split(',')
            song_id, song_length, genre_ids, artist_name, _, _, language = parts[:7]
            if song_id not in song2id:
                song2id[song_id] = len(song2id)
                
            if genre_ids == "":
                genre_ids = []
            else:
                genre_ids = list(map(int, genre_ids.split("|")))
            for g in genre_ids:
                if g not in genre2id:
                    genre2id[g] = len(genre2id)
                genre2ctn[genre2id[g]] += 1
            genre_ids = [genre2id[g] for g in genre_ids]
            if not artist_name in artist2id:
                artist_id2name[len(artist2id)] = artist_name
                artist2id[artist_name] = len(artist2id) 
                
            song_id2data[song2id[song_id]] = [int(song_length), artist2id[artist_name], int(float(language)), genre_ids]
        retain_genres = (np.argsort(genre2ctn)[::-1])[:10]
        retain_genres = dict([(g, i) for i, g in enumerate(retain_genres)])
        
        for k in song_id2data:
            genre_ids = song_id2data[k][-1]
            genre_ids = [retain_genres[g] for g in genre_ids if g in retain_genres]
            song_id2data[k][-1] = genre_ids
        
    print("Total genres: ", len(genre2id))
    print("Preprocessing song_extra_info.csv ...")
    with open(path + "/song_extra_info.csv") as f:
        first_readed = False
        song_id2name = {}
        for l in f:
            if not first_readed:
                first_readed = True
                continue
            song_id, name, isrc = l[:-1].split(',')
            
            if song_id not in song2id:
                continue
            song_id2name[song2id[song_id]] = name
            if song2id[song_id] in song_id2data:
                song_id2data[song2id[song_id]].append(isrc)
    print("Preprocessing members.csv ...")
    with open(path + "/members.csv", "r") as f:
        first_readed = False
        user2id = {}
        member_id2data = {}
        gender2id = {"male": -1, "": 0, "female": 1}
        for l in f:
            if not first_readed:
                first_readed = True
                continue
            msno, city, bd, gender, registered_via, registration_init_time, expiration_date = l[:-1].split(',')
            if msno not in user2id:
                user2id[msno] = len(user2id)
            member_id2data[user2id[msno]] = (
                int(city), int(bd), gender2id[gender], int(registered_via), 
                int(registration_init_time), int(expiration_date)
            )
    print("Preprocessing train.csv ...")
    with open(path + "/train.csv", "r") as f:
        first_readed = False
        user2songs2ctn = {}
        user2artist2ctn = {}
        sst2id = {}
        ssn2id = {}
        st2id = {}
        train = []
        for l in f:
            if not first_readed:
                first_readed = True
                continue
            msno, song_id, source_system_tab, source_screen_name, source_type, target = l[:-1].split(',')
            if msno not in user2id:
                continue
            if song_id not in song2id:
                continue
            
            if user2id[msno] not in user2songs2ctn:
                user2songs2ctn[user2id[msno]] = {}
                user2artist2ctn[user2id[msno]] = {}
            if song2id[song_id] not in user2songs2ctn[user2id[msno]]:
                user2songs2ctn[user2id[msno]][song2id[song_id]] = 0
            artist_id = song_id2data[song2id[song_id]][1]
            if artist_id not in user2artist2ctn[user2id[msno]]:
                user2artist2ctn[user2id[msno]][artist_id] = 0
                
            if source_system_tab not in sst2id:
                sst2id[source_system_tab] = len(sst2id)
            if source_type not in st2id:
                st2id[source_type] = len(st2id)
            train.append((
                user2id[msno], song2id[song_id], user2songs2ctn[user2id[msno]][song2id[song_id]], user2artist2ctn[user2id[msno]][artist_id], 
                sst2id[source_system_tab], st2id[source_type], int(target)
            ))

            if int(target) > 0:
                user2songs2ctn[user2id[msno]][song2id[song_id]] += 1
                user2artist2ctn[user2id[msno]][artist_id] += 1
    return train, member_id2data, song_id2data, song_id2name, artist_id2name

In [10]:
train, members, songs, song_id2name, artist_id2name = preprocess_dataset()

Preprocessing songs.csv ...
Total genres:  191
Preprocessing song_extra_info.csv ...
Preprocessing members.csv ...
Preprocessing train.csv ...


In [11]:
len(train)

7377416

### Feature extraction

In [20]:
def isrc_to_year(isrc): # https://www.kaggle.com/kamilkk/i-have-to-say-this
    if isrc != "":
        if int(isrc[5:7]) > 17:
            return 1900 + int(isrc[5:7])
        else:
            return 2000 + int(isrc[5:7])
    else:
        return np.nan

def one_hot_genres(genres):
    result = [0 for _ in range(10)]
    for g in genres:
        result[g] = 1
    return result
    
def get_song_features(song_id):
    song_length, artist_id, language, genre_ids, isrc = songs[song_id]
    year = isrc_to_year(isrc)
    genre = one_hot_genres(genre_ids)
    return [year, language, song_length] + genre    

def get_user_features(user_id):
    return list(members[user_id])

def extract_features(dataset):
    X = []
    y = []
    for user_id, song_id, u2s2c, u2a2c, sst_id, st_id, tgt in dataset:
        user_features = get_user_features(user_id)
        song_features = get_song_features(song_id)
        song_year = song_features[0]
        user_register, user_expired = user_features[-2]//10000, user_features[-1]//10000
        register_song_diff = song_year - user_register
        expired_song_diff = song_year - user_expired
        X.append(user_features + song_features + [u2s2c, u2a2c, register_song_diff, expired_song_diff])
        y.append(tgt)
    return X, y

### Task 1

In [24]:
def cross_validation(train_function, parts=5):
    train_shuffled = copy.deepcopy(train)
    random.shuffle(train_shuffled)
    X, y = extract_features(train_shuffled)
    avg_stat = {}
    for i in tqdm(range(parts)):
        X_train = X[:i * len(train_shuffled) // parts] + X[(i+1) * len(train_shuffled) // parts:]
        y_train = y[:i * len(train_shuffled) // parts] + y[(i+1) * len(train_shuffled) // parts:]
        X_test = X[i * len(train_shuffled) // parts:(i+1) * len(train_shuffled) // parts]
        y_test = y[i * len(train_shuffled) // parts:(i+1) * len(train_shuffled) // parts]
        stat = train_function((X_train, y_train), (X_test, y_test))
        for key in stat:
            avg_stat[key] = avg_stat.get(key, 0) + stat[key] / parts
    for key in avg_stat:
        print(f"{key}: {avg_stat[key]}")
    print()
    return avg_stat

In [25]:
def build_train_func(depth=16, min_data_in_leaf=1, l2_leaf_reg=3.0):
    def train_catboost(train, test, verbose=False):
        X_train, y_train = train
        X_test, y_test = test
        model = catboost.CatBoostClassifier(iterations=100, verbose=verbose, 
                                            depth=depth, min_data_in_leaf=min_data_in_leaf)
        pool = catboost.Pool(X_train, y_train)
        model.fit(pool)
        pool = catboost.Pool(X_test, y_test)
        metrics = model.eval_metrics(pool, ['Logloss', 'AUC'])
        for k in metrics:
            metrics[k] = metrics[k][-1]
        return metrics
    return train_catboost

In [26]:
cross_validation(build_train_func())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Logloss: 0.6154619547768531
AUC: 0.7192393070791001



{'Logloss': 0.6154619547768531, 'AUC': 0.7192393070791001}

#### ROC AUC ~ 0.719 which is (maybe) good enough

In [27]:
cross_validation(build_train_func(16, 100, 3.0))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Logloss: 0.6162607792109148
AUC: 0.7181217113168242



{'Logloss': 0.6162607792109148, 'AUC': 0.7181217113168242}

In [28]:
cross_validation(build_train_func(8, 1, 3.0))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Logloss: 0.6494050988788231
AUC: 0.665325081489575



{'Logloss': 0.6494050988788231, 'AUC': 0.665325081489575}

In [29]:
cross_validation(build_train_func(8, 100, 3.0))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Logloss: 0.6496599736865354
AUC: 0.6647784457873176



{'Logloss': 0.6496599736865354, 'AUC': 0.6647784457873176}

In [30]:
cross_validation(build_train_func(4, 100, 3.0))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Logloss: 0.6608828142896415
AUC: 0.64160666475336



{'Logloss': 0.6608828142896415, 'AUC': 0.64160666475336}

In [31]:
cross_validation(build_train_func(16, 100, 10.0))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Logloss: 0.6156102557384175
AUC: 0.719118310403183



{'Logloss': 0.6156102557384175, 'AUC': 0.719118310403183}

In [32]:
cross_validation(build_train_func(16, 100, 1.0))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Logloss: 0.615380495326499
AUC: 0.7193730667178964



{'Logloss': 0.615380495326499, 'AUC': 0.7193730667178964}

In [33]:
cross_validation(build_train_func(16, 100, 0.1))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Logloss: 0.615474907009779
AUC: 0.7191859170531917



{'Logloss': 0.615474907009779, 'AUC': 0.7191859170531917}