# LightFM

In [1]:
from lightfm import LightFM

In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from interaction_table import InteractionTable
from h3_index import H3Index

In [3]:
from process_data import preprocess_orders_and_clicks, additional_filtration_orders_and_clicks
from user_features import generate_user_features

In [4]:
# Load the pre-filtered orders and the precomputed user feature table.
# user_features is later row-selected with .loc by user id, so its index
# is expected to hold user ids.
orders = pd.read_parquet("../data/orders_filtered.parquet")
user_features = pd.read_parquet("../data/user_features.parquet")

In [5]:
# Build the interaction table from the orders frame; the second positional
# argument and test_slice are both None.
# NOTE(review): presumably the second argument is a clicks frame and alpha its
# weight — confirm against interaction_table.py.
interactions = InteractionTable(orders, None, alpha=0, test_slice=None)

Orders weighter: use user avg orders per chain as weight
            user_id      chain_id        weight
count  3.106486e+06  3.106486e+06  3.106486e+06
mean   3.666636e+07  3.212015e+04  1.755490e+00
std    2.148159e+07  1.517362e+04  8.203714e+01
min    0.000000e+00  9.000000e+00  1.000000e+00
25%    1.143635e+07  2.714700e+04  1.000000e+00
50%    3.991074e+07  3.007500e+04  1.000000e+00
75%    5.175972e+07  4.451900e+04  2.000000e+00
max    7.213893e+07  7.332400e+04  1.444470e+05
Orders df weighted: size=3106486, uniq_users=1394062, uniq_chains=7792


In [6]:
import scipy

In [7]:
# Take every user feature and compare it with its column mean.
def _above_mean_flags(features):
    """Return an int frame holding 1 where a value exceeds its column mean, else 0."""
    return ((features - features.mean()) > 0).astype(int)


# Rows follow the iteration order of interactions.user_to_index.
# NOTE(review): assumes that order matches the user indices used by the
# interaction matrix — confirm against InteractionTable.
user_features_sparse = scipy.sparse.csr_matrix(
    _above_mean_flags(user_features.loc[list(interactions.user_to_index.keys())])
)

In [8]:
# Show the matrix repr to check its shape, dtype and stored-element count.
user_features_sparse

<1394062x24 sparse matrix of type '<class 'numpy.int64'>'
	with 9180625 stored elements in Compressed Sparse Row format>

In [9]:
#!pip install fastparquet
# Load the H3 lookup from the pickled h3 -> chains file; later cells read
# h3index.valid (known h3 cells) and h3index.h3_to_chains (chains per cell).
h3index = H3Index('../data/raw/h3_to_chains.pkl')

In [10]:
# Load the validation interactions, keep only the columns used below, rename
# customer_id to user_id and cast it to int.
# NOTE(review): read_pickle can execute arbitrary code — only load trusted files.
val_df = (
    pd.read_pickle('../data/raw/test_VALID.pkl')[['customer_id', 'h3', 'chain_id']]
    .rename(columns={"customer_id": "user_id"})
    .astype({"user_id": int})
)
print("Initial validation dataset size:", len(val_df))

# Drop rows whose h3 cell, user or chain is unknown to the lookup tables,
# reporting the remaining size after each step.
val_df = val_df.loc[val_df["h3"].isin(h3index.valid)]
print("Filter h3 indices that not in h3_to_chain dict", len(val_df))
val_df = val_df.loc[val_df["user_id"].isin(interactions.user_to_index)]
print("Filter users", len(val_df))
val_df = val_df.loc[val_df["chain_id"].isin(interactions.chain_to_index)]
print("Filter chains", len(val_df))

# Collapse to one row per (user_id, h3) holding the set of chain ids.
val_df = pd.pivot_table(
    val_df,
    values=['chain_id'],
    index=['user_id', 'h3'],
    aggfunc={'chain_id': set},
).reset_index()
val_df.head()

Initial validation dataset size: 2300001
Filter h3 indices that not in h3_to_chain dict 2293762
Filter users 483567
Filter chains 341552


Unnamed: 0,user_id,h3,chain_id
0,0,89118108b43ffff,{28720}
1,0,89118134503ffff,{28720}
2,0,89118134513ffff,{28720}
3,0,89118134517ffff,{28720}
4,0,8911813456bffff,{28720}


### Если h3 пользователя неизвестен, то можно брать следующий в иерархии h3 (более крупный)

In [13]:
def predict(model, user_id, h3, top_k=10):
    """Return up to ``top_k`` chain ids for ``user_id`` among the chains of cell ``h3``.

    Parameters:
        model: fitted model exposing ``predict(user_index, item_indices, user_features=..., num_threads=...)``.
        user_id: key of ``interactions.user_to_index``.
        h3: key of ``h3index.h3_to_chains``.
        top_k: maximum number of chain ids to return.

    Returns:
        List of chain ids ordered by descending predicted score (ties broken by
        descending chain index). Empty when none of the cell's chains is present
        in ``interactions.chain_to_index``.

    Raises:
        KeyError: if ``user_id`` or ``h3`` is missing from its lookup table.
    """
    user_index = interactions.user_to_index[user_id]
    # A set makes each membership test O(1) regardless of how the chains are stored.
    valid_chains = set(h3index.h3_to_chains[h3])
    valid_chain_index = [
        index
        for chain, index in interactions.chain_to_index.items()
        if chain in valid_chains
    ]
    if not valid_chain_index:
        # Nothing to rank; also avoids handing the model an empty item array,
        # which it is not expected to accept.
        return []

    pred = model.predict(user_index, valid_chain_index, user_features=user_features_sparse, num_threads=4)
    ranked = sorted(zip(pred, valid_chain_index), reverse=True)[:top_k]
    return [interactions.index_to_chain[index] for _, index in ranked]

def old_items(user_id):
    """Return the set of distinct chain ids recorded for ``user_id`` in ``interactions.interaction_df``."""
    history = interactions.interaction_df
    belongs_to_user = history['user_id'] == user_id
    return set(history.loc[belongs_to_user, 'chain_id'].unique())

In [14]:
def metric(y_true, y_pred, y_old, at1=10, at2=30, average=True):
    """
    Score predictions with precision at two cutoffs, with and without known items.

    For every (true, predicted, old) triple and each cutoff k in (at1, at2):
      prec@k     = |true[:k] & pred[:k]| / k
      new_prec@k = |(pred[:k] - old) & true[:k]| / k

    Three values are produced per triple:
      total = new_prec@at1 + new_prec@at2 + 1/2 * (prec@at1 + prec@at2)
      new   = new_prec@at1 + new_prec@at2
      all   = prec@at1 + prec@at2

    An ``old`` entry that is neither a set nor a list is treated as empty.

    Returns a (total, new, all) tuple: means over all triples when ``average``
    is true, otherwise the per-triple lists.

    NOTE(review): the true items are also cut to the first k elements; when they
    arrive as a set that selection follows set iteration order — confirm this
    matches the intended metric definition.
    """
    totals, novel, overall = [], [], []

    for truth, predicted, seen in zip(y_true, y_pred, y_old):
        truth = list(truth)
        predicted = list(predicted)
        if not isinstance(seen, (set, list)):
            seen = []
        seen = set(seen)

        # (precision, new-item precision) for each of the two cutoffs.
        per_cutoff = []
        for k in (at1, at2):
            truth_k = set(truth[:k])
            pred_k = set(predicted[:k])
            per_cutoff.append(
                (len(truth_k & pred_k) / k, len((pred_k - seen) & truth_k) / k)
            )
        (prec1, new_prec1), (prec2, new_prec2) = per_cutoff

        totals.append(new_prec1 + new_prec2 + 0.5 * (prec1 + prec2))
        novel.append(new_prec1 + new_prec2)
        overall.append(prec1 + prec2)

    if average:
        return np.mean(totals), np.mean(novel), np.mean(overall)
    return totals, novel, overall

In [15]:
# !pip install implicit
import implicit

def hyper_params(val_df, epochs=60, top_k=30, random_state=None):
    """Train a LightFM model (WARP loss, user_alpha=0.1) and score it on ``val_df``.

    Parameters:
        val_df: frame with ``user_id``, ``h3`` and ``chain_id`` columns, where
            ``chain_id`` holds the true chains per row.
        epochs: number of training epochs.
        top_k: number of chains predicted per row. ``metric`` evaluates cutoffs
            10 and 30 by default, so a smaller ``top_k`` caps those precisions.
        random_state: seed forwarded to ``LightFM``; ``None`` keeps the
            library's default unseeded behaviour.
            NOTE(review): multi-threaded fitting may still vary between runs — confirm.

    Side effects:
        Writes the ``pred_chains`` and ``old_chains`` columns into ``val_df``
        in place (later cells read ``val_df['old_chains']``) and prints the scores.

    Returns:
        The ``(total, new, all)`` tuple produced by ``metric``.
    """
    model = LightFM(loss='warp', user_alpha=0.1, random_state=random_state)
    model.fit(
        interactions.sparse_interaction_matrix.T,
        user_features=user_features_sparse,
        epochs=epochs, num_threads=4
    )

    val_df['pred_chains'] = val_df.apply(
        lambda row: predict(model, row.user_id, row.h3, top_k), axis=1
    )
    # A user appears once per h3 cell; look the history up once per user
    # instead of rescanning the interaction table for every row.
    history_by_user = {user: old_items(user) for user in val_df['user_id'].unique()}
    val_df['old_chains'] = val_df['user_id'].map(history_by_user.get)

    scores = metric(val_df['chain_id'], val_df['pred_chains'], val_df['old_chains'])
    print('total, new, all = ', scores)
    print()
    return scores

In [16]:
# Train with the default settings (epochs=60, top_k=30) and print the validation scores.
hyper_params(val_df)

total, new, all =  (0.08375840480289043, 0.03254397569949924, 0.10242885820678242)



In [22]:
# Score the ground truth against itself to get a reference value for the metric.
# Relies on the 'old_chains' column that hyper_params writes into val_df in place,
# so hyper_params(val_df) must have run first.
scores = metric(val_df['chain_id'], val_df['chain_id'], val_df['old_chains'])
print('total, new, all = ', scores)

total, new, all =  (0.16267924153184724, 0.07506150749273093, 0.17523546807823256)


LightFM без фичей
epochs=60, top_k=30

total, new, all =  (0.08583693873624455, 0.03078052077852865, 0.11011283591543182)

LightFM с несколькими фичами

total, new, all =  (0.0822435945227267, 0.032637490991326824, 0.09921220706279976)

LightFM со сравнением фичей со средним  
user_alpha=0.1

total, new, all =  (0.08338800666020525, 0.03301026367454459, 0.10075548597132136)

In [None]:
# Grid-search the settings hyper_params actually accepts (epochs and top_k).
# The previous loop passed factors / thr / filter_liked, which hyper_params
# does not take, so every call raised TypeError.
# The epoch values are a starting grid — adjust to the available compute budget.
for epochs in [20, 40, 60, 80]:
    for top_k in [5, 10, 20, 30]:
        # Label each run so the printed scores can be told apart.
        print('epochs: ', epochs, ', top_k: ', top_k)
        hyper_params(val_df, epochs=epochs, top_k=top_k)