# Full pipeline применение

In [1]:
import pickle
from top_recommender import TopRecommender
from h3_index import H3Index
import pandas as pd
import numpy as np

In [2]:
# and later you can load it

with open('lightfm_moscow.pkl', 'rb') as f:
    lightfm_moscow = pickle.load(f)
with open('interactions_moscow.pkl', 'rb') as f:
    interactions_moscow = pickle.load(f)
with open('user_features_sparse_moscow.pkl', 'rb') as f:
    user_features_sparse_moscow = pickle.load(f)
    
# and later you can load it
with open('lightfm_piter.pkl', 'rb') as f:
    lightfm_piter = pickle.load(f)
with open('interactions_piter.pkl', 'rb') as f:
    interactions_piter = pickle.load(f)
with open('user_features_sparse_piter.pkl', 'rb') as f:
    user_features_sparse_piter = pickle.load(f)
    
# and later you can load it
with open('lightfm_other.pkl', 'rb') as f:
    lightfm_other = pickle.load(f)
with open('interactions_other.pkl', 'rb') as f:
    interactions_other = pickle.load(f)
with open('user_features_sparse_other.pkl', 'rb') as f:
    user_features_sparse_other = pickle.load(f)
    
    
# with open('top_rec.pkl', 'rb') as f:
#     top_rec = pickle.load(f)
with open('top_rec_moscow.pkl', 'rb') as f:
    top_rec_moscow = pickle.load(f)
with open('top_rec_piter.pkl', 'rb') as f:
    top_rec_piter = pickle.load(f)
with open('top_rec_other.pkl', 'rb') as f:
    top_rec_other = pickle.load(f)

In [3]:
#!pip install fastparquet
h3index = H3Index('../data/raw/h3_to_chains.pkl')

In [4]:
val_df = pd.read_pickle('../data/raw/test_VALID.pkl')
val_df = val_df.rename(columns={"customer_id": "user_id"})
val_df.user_id = val_df.user_id.astype(int)

print("Initial validation dataset size:", len(val_df))
# val_df = val_df[val_df["h3"].isin(h3index.valid)]
# print("Filter h3 indices that not in h3_to_chain dict", len(val_df))

val_df = pd.pivot_table(val_df,
                        values=['chain_id'],
                        index=['user_id', 'h3', 'city_id'],
                        aggfunc={'chain_id': set})
val_df = val_df.reset_index()
val_df.head()

Initial validation dataset size: 2300001


Unnamed: 0,user_id,h3,city_id,chain_id
0,0,890b0638003ffff,49,{34646}
1,0,890b0638007ffff,49,{34646}
2,0,890b063800fffff,49,{34646}
3,0,890b0638023ffff,49,{34646}
4,0,890b0638027ffff,49,{34646}


In [5]:
def predict_in_city(lightfm, top_rec, user_id, h3, interactions, user_features_sparse, top_k=10):
    if h3 in h3index.valid:
        valid_chains = h3index.h3_to_chains[h3]
        valid_chain_index = [v for k, v in interactions.chain_to_index.items() if k in valid_chains]
        if user_id in interactions.user_to_index and len(valid_chain_index) > 9:
            user_index = interactions.user_to_index[user_id]
            pred = lightfm.predict(user_index, valid_chain_index, user_features=user_features_sparse)
            top_chain_index = [x for _, x in sorted(zip(pred, valid_chain_index), reverse=True)][:top_k]
            top = [interactions.index_to_chain[k] for k in top_chain_index]
        else:
            pred = top_rec.predict(valid_chains)
            top = [x for _, x in sorted(zip(pred, valid_chains), reverse=True)][:top_k]
    else:
        top = [
            k for k, v in sorted(
                top_rec.chains_to_cnt.items(), key=lambda item: item[1], reverse=True
            )[:30]
        ]
    return top

def old_items_in_city(user_id, interactions):
    return set(interactions.interaction_df[interactions.interaction_df['user_id'] == user_id]['chain_id'].unique())

In [6]:
def predict(user_id, h3, city_id, top_k=10):
    if city_id == 1:
        top = predict_in_city(
            lightfm_moscow,
            top_rec_moscow,
            user_id,
            h3,
            interactions_moscow,
            user_features_sparse_moscow,
            top_k=top_k
        )
    elif city_id == 2:
        top = predict_in_city(
            lightfm_piter,
            top_rec_piter,
            user_id,
            h3,
            interactions_piter,
            user_features_sparse_piter,
            top_k=top_k
        )
    else:
        top = predict_in_city(
            lightfm_other,
            top_rec_other,
            user_id,
            h3,
            interactions_other,
            user_features_sparse_other,
            top_k=top_k
        ) 
    return top
        
def old_items(user_id, city_id):
    if city_id ==1:
        old = old_items_in_city(user_id, interactions_moscow)
    elif city_id == 2:
        old = old_items_in_city(user_id, interactions_piter)
    else:
        old = old_items_in_city(user_id, interactions_other)
    
    return old

In [7]:
def metric(y_true, y_pred, y_old, at1=10, at2=30, average=True):
    """
    new_prec@10 + new_prec@30 + 1/2 *(prec_@10 + prec@30)
    """
    scores_new = []
    scores_all = []
    scores_total = []
    for t, p, o in zip(y_true, y_pred, y_old):
        t = list(t)
        p = list(p)
        o = o if isinstance(o, (set, list)) else []
        
        prec1 = len(set(t[:at1]) & set(p[:at1])) / at1
        prec2 = len(set(t[:at2]) & set(p[:at2])) / at2
        new_prec1 = len((set(p[:at1]) - set(o)) & set(t[:at1])) / at1
        new_prec2 = len((set(p[:at2]) - set(o)) & set(t[:at2])) / at2

        scores_total.append(new_prec1 + new_prec2 + 0.5 * (prec1 + prec2))
        scores_new.append(new_prec1 + new_prec2)
        scores_all.append(prec1 + prec2)

    return (np.mean(scores_total) if average else scores_total,
            np.mean(scores_new) if average else scores_new,
            np.mean(scores_all) if average else scores_all)

In [8]:
def compute_score(val_df, frac=0.001, top_k=30):
    if frac:
        val = val_df.sample(frac=frac, random_state=42)
    else:
        val = val_df
    val['pred_chains'] = val.apply(lambda x: predict(x.user_id, x.h3, x.city_id, top_k), axis=1)
    val['old_chains'] = val.apply(lambda x: old_items(x.user_id, x.city_id), axis=1)
    scores = metric(val['chain_id'], val['pred_chains'], val['old_chains'])
    print('total, new, all = ', scores)
    print()

In [9]:
%%time
compute_score(val_df, frac=0.01)

total, new, all =  (0.13213649562013519, 0.0875704560427494, 0.08913207915477148)

CPU times: user 16.4 s, sys: 2.95 ms, total: 16.4 s
Wall time: 16.4 s


Лучший результат на lightfm срезе 100000  
total, new, all = (0.13231475904351955, 0.08761624234232909, 0.08939703340238092)  

total, new, all =  (0.13308236619166544, 0.08807600019587679, 0.09001273199157729)

Результат на данных без отбрасывания h3 один top_rec  
total, new, all =  (0.13208891491594077, 0.0875387355732865, 0.08910035868530852)

Результат на данных без отбрасывания h3, для каждого региона top_rec  
total, new, all =  (0.13212551545762877, 0.08756313593441181, 0.08912475904643388)

In [10]:
lb_df = pd.read_pickle('../data/raw/test_LB_del_chains.pkl')

In [11]:
lb_df = pd.read_pickle('../data/raw/test_LB_del_chains.pkl')

(1347676, 5)

In [12]:
lb_df.drop_duplicates(subset=["customer_id", "h3"]).shape

(1347581, 5)

In [13]:
lb_df.shape

(2252350, 5)

In [14]:
lb_df[lb_df["h3"].isin(h3index.valid)].shape

(2246224, 5)

In [15]:
print("Initial size", len(lb_df))
lb_df = lb_df.drop_duplicates(subset=["customer_id", "h3"])
print("Size after deduplication", len(lb_df))
lb_df = lb_df.rename(columns={"customer_id": "user_id"})
lb_df.user_id = lb_df.user_id.astype(int)

# lb_df = lb_df[lb_df["h3"].isin(h3index.valid)]
# print("Filter h3 indices that not in h3_to_chain dict", len(lb_df))

lb_df = lb_df[["user_id", "city_id", "h3"]]
lb_df.head()

Initial size 2252350
Size after deduplication 1347581


Unnamed: 0,user_id,city_id,h3
2300000,1013971,1,8911aa78dabffff
2300001,5490865,45,89119631a73ffff
2300002,61145453,1,8911aa6aea7ffff
2300003,27620522,1,8911aa79507ffff
2300004,583734,1,8911aa09b23ffff


In [16]:
lb_df['pred_chains'] = lb_df.apply(lambda x: predict(x.user_id, x.h3, x.city_id, 30), axis=1)

In [13]:
lb_df.pred_chains.apply(len).value_counts(normalize=True)

30    0.985017
25    0.001146
27    0.001088
26    0.000867
23    0.000842
18    0.000726
29    0.000715
20    0.000706
16    0.000675
28    0.000659
17    0.000657
22    0.000605
12    0.000600
19    0.000589
24    0.000564
14    0.000562
21    0.000518
15    0.000490
13    0.000375
10    0.000344
11    0.000335
9     0.000316
8     0.000290
6     0.000237
3     0.000229
1     0.000214
2     0.000171
7     0.000168
5     0.000153
4     0.000140
Name: pred_chains, dtype: float64

In [18]:
with open('LB_pred_dict.pkl', 'wb') as f:
    pickle.dump(lb_df.set_index(["user_id", "h3"])["pred_chains"].to_dict(), f)
    
lb_df.to_pickle("LB_pred_pandas.pkl")

In [5]:
lb_df = pd.read_pickle("LB_pred_pandas.pkl")

In [6]:
lb_df.head()

Unnamed: 0,user_id,city_id,h3,pred_chains
2300000,1013971,1,8911aa78dabffff,"[28720, 48274, 15275, 29454, 32049, 30112, 324..."
2300001,5490865,45,89119631a73ffff,"[31185, 6836, 32809, 31342, 26606, 34383, 3622..."
2300002,61145453,1,8911aa6aea7ffff,"[28720, 48274, 15275, 32049, 30112, 32449, 136..."
2300003,27620522,1,8911aa79507ffff,"[28720, 48274, 15275, 32049, 32449, 45822, 294..."
2300004,583734,1,8911aa09b23ffff,"[28720, 48274, 15275, 29454, 32049, 30112, 458..."


In [12]:
with open('LB_pred_dict_v2.pkl', 'wb') as f:
    pickle.dump(
        lb_df
        .groupby("user_id", sort=False)
        .apply(lambda x: dict(zip(x["h3"], x["pred_chains"])))
        .to_dict(),
        f
    )

In [13]:
dictt = (
    lb_df
    .groupby("user_id", sort=False)
    .apply(lambda x: dict(zip(x["h3"], x["pred_chains"])))
    .to_dict()
)

In [17]:
dictt

{1013971: {'8911aa78dabffff': [28720,
   48274,
   15275,
   29454,
   32049,
   30112,
   32449,
   13698,
   31776,
   1929,
   23376,
   31698,
   777,
   27490,
   16277,
   26878,
   1994,
   30513,
   28795,
   30499,
   649,
   1551,
   49344,
   26595,
   806,
   36198,
   2305,
   16302,
   36316,
   13],
  '8911aa6900bffff': [28720,
   30112,
   13698,
   31776,
   1929,
   23376,
   777,
   25352,
   26595,
   806,
   31584,
   26395,
   77,
   31106,
   376,
   155,
   29259,
   828,
   4174,
   23443,
   35180,
   67877,
   4364,
   30245,
   3490,
   11519,
   4183,
   60282,
   13657,
   26725]},
 5490865: {'89119631a73ffff': [31185,
   6836,
   32809,
   31342,
   26606,
   34383,
   36220,
   53399,
   56264,
   28536,
   30913,
   26575,
   69863,
   33302,
   26829,
   29371,
   38732,
   30604,
   29030,
   67298,
   36632,
   42487,
   34453,
   27789,
   65463,
   35787,
   50905,
   65203,
   41594,
   27960]},
 61145453: {'8911aa6aea7ffff': [28720,
   48274,
   