In [1]:
import pandas as pd
import numpy as np

train = pd.read_csv("../data/train.csv")

In [2]:
interactions = (
    train.groupby(["msno", "song_id"])
         .size()
         .reset_index(name="play_count")
)

In [3]:
user_codes = interactions["msno"].astype("category").cat.codes
item_codes = interactions["song_id"].astype("category").cat.codes

user_id_map = dict(enumerate(interactions["msno"].astype("category").cat.categories))
item_id_map = dict(enumerate(interactions["song_id"].astype("category").cat.categories))


In [4]:
user_mapping = dict(enumerate(interactions["msno"].astype("category").cat.categories))
item_mapping = dict(enumerate(interactions["song_id"].astype("category").cat.categories))

In [5]:
import scipy
from scipy.sparse import csr_matrix

In [6]:
from scipy.sparse import csr_matrix

user_item_matrix = csr_matrix(
    (interactions["play_count"],
     (user_codes, item_codes))
)

In [7]:
user_item_matrix.shape

(30755, 359966)

In [8]:
from implicit.als import AlternatingLeastSquares

als_model = AlternatingLeastSquares(
    factors=64,
    regularization=0.1,
    iterations=20,
    random_state=42
)

als_model.fit(user_item_matrix)


  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 20/20 [00:45<00:00,  2.30s/it]


In [9]:
user_id = 0
user_items = user_item_matrix[user_id]

ids, scores = als_model.recommend(
    userid=user_id,
    user_items=user_items,
    N=10
)

ids, scores


(array([ 31486, 136455, 169515, 342271,  75292, 230572, 149281, 215054,
        225460, 243500], dtype=int32),
 array([0.72222495, 0.67358327, 0.61404586, 0.6048321 , 0.59560156,
        0.57939214, 0.56619847, 0.5528924 , 0.54900277, 0.52264076],
       dtype=float32))

In [10]:
print(user_item_matrix.shape)
print(als_model.item_factors.shape)


(30755, 359966)
(359966, 64)


In [11]:
import os
import joblib
from scipy.sparse import save_npz

In [12]:
MODEL_DIR = "../model"
os.makedirs(MODEL_DIR, exist_ok=True)

In [13]:
joblib.dump(als_model, f"{MODEL_DIR}/als_model.pkl")

['../model/als_model.pkl']

In [14]:
save_npz(f"{MODEL_DIR}/user_item_matrix.npz", user_item_matrix)

In [16]:
joblib.dump(user_id_mapping, f"{MODEL_DIR}/user_id_mapping.pkl")
joblib.dump(item_id_mapping, f"{MODEL_DIR}/item_id_mapping.pkl")
joblib.dump(id_to_song, f"{MODEL_DIR}/id_to_song.pkl")

['../model/id_to_song.pkl']

In [15]:
# user mapping: original msno -> internal id
user_id_mapping = dict(zip(interactions["msno"].unique(),
                           range(interactions["msno"].nunique())))

# item mapping: original song_id -> internal id
item_id_mapping = dict(zip(interactions["song_id"].unique(),
                           range(interactions["song_id"].nunique())))

# reverse mapping: internal item id -> song_id
id_to_song = {v: k for k, v in item_id_mapping.items()}

In [17]:
def recommend_songs(
    original_user_id,
    model,
    user_item_matrix,
    user_id_mapping,
    id_to_song,
    N=10
):
    if original_user_id not in user_id_mapping:
        raise ValueError("User not found in training data")

    internal_user_id = user_id_mapping[original_user_id]

    user_items = user_item_matrix[internal_user_id]

    item_ids, scores = model.recommend(
        userid=internal_user_id,
        user_items=user_items,
        N=N
    )

    recommended_songs = [id_to_song[i] for i in item_ids]

    return recommended_songs, scores


In [18]:
sample_user = interactions["msno"].iloc[0]

songs, scores = recommend_songs(
    original_user_id=sample_user,
    model=als_model,
    user_item_matrix=user_item_matrix,
    user_id_mapping=user_id_mapping,
    id_to_song=id_to_song,
    N=10
)

songs


['2PT4EDoGGKA4m5ad51BBr9OzWfI5FzKVtewYZCG/tyE=',
 'TyFQMxncfLEa4UklBnb5UGgx6Wx+5aAUSDVqO0dLSiI=',
 'xyPPK9fuIlwJr3fGlMbOIxWHyqHEJWB2f0H7fSclahM=',
 'ZmekVY4jkdaVVnZmIUDtKMh06JdOVqkea+pENjO8mp8=',
 'nSh2h2PHzjZXCVoxpwhPAEHsnHc3Tg4jxilhK8JreNw=',
 'QlYwuyGI4+foDYa6f/zEu0BOhfqNrEEhR2rlKEYmYlc=',
 '+mpSSL8w2FDjnXmY9g2tLxpvz0RnSSQU9qQfRQGnZRk=',
 'ok2Z+hc1x+Ee5nIdDCWq7mkZwEGe6l0StNWXriDV5JM=',
 'ATRH5zyug/64lQjRdV1BPmhq2BMg+QYNjt5aCAqAZIs=',
 'Nv/pvh6PInraIi9lE33DBFdmqrJrvl69NhuxExxKpXs=']