In [2]:
import torch
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.nn.functional import normalize

In [None]:
# loading custom trained bert model embeddings
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load checkpoints that come from a trusted source.
with open('../upsell/checkpoint/custom_trained_item_embedding.pkl', 'rb') as fh:
    # The context manager closes the file handle as soon as the payload is read.
    product_embeddings = pickle.load(fh)

# Entry 1 is used as the embedding matrix and entry 2 as the matching item ids
# (presumably one id per row - TODO confirm against the code that wrote the pickle).
embs = normalize(torch.tensor(product_embeddings[1]))
item_id = product_embeddings[2]

# item id -> its own (cloned) embedding row, so entries do not share storage with `embs`.
item_emb = {iid: embs[row].clone() for row, iid in enumerate(item_id)}

In [5]:
# Transactions without an article id cannot be matched to an embedding, so drop them.
data = pd.read_csv('../../data/transaction_data_sample.csv').dropna(subset=['article_id'])
# Keep only articles that have a trained embedding. The copy is taken AFTER the
# boolean filter so that the column assignment below writes to an independent
# frame rather than to a filtered slice (and no discarded rows are copied).
data = data[data['article_id'].isin(product_embeddings[2])].copy()
# Reduce the timestamp column to calendar dates.
data['t_dat'] = pd.to_datetime(data['t_dat']).dt.date

data.head(2)

Unnamed: 0,t_dat,customer_id,article_id,price
0,2018-09-20,001ea4e9c54f7e9c88811260d954edc059d596147e1cf8...,652075001,0.011847
1,2018-09-20,001ea4e9c54f7e9c88811260d954edc059d596147e1cf8...,670295001,0.010153


In [6]:
def user_emb_using_entity(data, entity_emb: dict):
    """Build one L2-normalised embedding per customer from weighted entity embeddings.

    Every (customer_id, article_id) pair gets a raw rating
    ``count / (days_since + 1)`` where

    * ``count`` is the number of distinct ``t_dat`` values on which the pair occurs, and
    * ``days_since`` is the number of whole days between the pair's latest ``t_dat``
      and the latest ``t_dat`` found anywhere in ``data``.

    Ratings are then rescaled so that they sum to 1 for each customer, and the
    customer embedding is the normalised, rating-weighted sum of the embeddings
    of the articles that customer bought.

    Args:
        data: DataFrame with columns ``t_dat``, ``customer_id`` and ``article_id``.
            ``t_dat`` may hold anything ``pd.to_datetime`` can parse (date objects,
            datetime64 values, date strings). The frame is not modified.
        entity_emb: Mapping from article id to a 1-D torch tensor; all tensors must
            have the same shape because they are stacked per customer.

    Returns:
        A tuple ``(user_embeddings, ratings)`` where ``user_embeddings`` maps each
        customer id to its embedding tensor and ``ratings`` is a DataFrame with the
        columns ``customer_id``, ``article_id`` and ``weight``.

    Raises:
        KeyError: If ``data`` contains an article id that is missing from ``entity_emb``.
    """
    key_cols = ["customer_id", "article_id"]

    # Work on a narrow private copy with a real datetime column, so the date
    # arithmetic below does not depend on how the caller stored `t_dat`.
    purchases = data[["t_dat", *key_cols]].copy()
    purchases["t_dat"] = pd.to_datetime(purchases["t_dat"])

    # number of times an item was purchased by a customer
    # (several rows with the same t_dat count as a single purchase)
    count = (
        purchases.drop_duplicates(subset=["t_dat", *key_cols])
        .groupby(key_cols)
        .size()
        .reset_index(name="count")
    )

    # number of days since last order, measured from the newest t_dat in the data
    days_since = purchases.groupby(key_cols)["t_dat"].max().reset_index(name="days_since")
    days_since["days_since"] = (purchases["t_dat"].max() - days_since["days_since"]).dt.days

    # weights or rating of each item; the +1 keeps the denominator non-zero for
    # purchases made on the most recent day
    weight = count.merge(days_since, on=key_cols)
    weight["weight"] = weight["count"] / (weight["days_since"] + 1)
    # The string "sum" selects the built-in groupby reduction (passing the Python
    # builtin `sum` is deprecated in recent pandas releases).
    weight["weight"] = weight["weight"] / weight.groupby("customer_id")["weight"].transform("sum")

    user_embeddings = {}
    for cid, rated in tqdm(weight.groupby("customer_id")):
        # One weighted embedding per purchased article, summed over the articles.
        weighted = torch.stack([entity_emb[entity] * weight_
                                for entity, weight_ in zip(rated["article_id"], rated["weight"])])
        user_embeddings[cid] = normalize(weighted.sum(0), dim=0)

    return user_embeddings, weight[["customer_id", "article_id", "weight"]]

In [7]:
# Per-customer embeddings, plus the (customer, article) ratings they were built from.
user_emb_using_item, item_rating = user_emb_using_entity(data, item_emb)

100%|██████████| 10000/10000 [00:03<00:00, 2992.33it/s]


In [8]:
# user metadata
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move every user vector to the chosen device, then stack them in dict order.
user_emb = {uid: vec.to(device) for uid, vec in user_emb_using_item.items()}
all_users = torch.stack([user_emb[uid] for uid in user_emb])

# Row position in `all_users` -> user key.
idx2user = dict(enumerate(user_emb))

In [9]:
def similar_users(count: int=20):
    """Return the `count` highest-scoring users for every user.

    Reads three module-level objects: `user_emb` (user -> embedding tensor),
    `all_users` (the same embeddings stacked into one tensor) and `idx2user`
    (row position in `all_users` -> user). The score of a pair is the dot
    product of the two embeddings.

    Args:
        count: Maximum number of rows to keep per user. If fewer users exist,
            all of them are returned; a non-positive value yields no rows.

    Returns:
        DataFrame with columns ``user``, ``similar_user`` and ``u2u_score``,
        holding for each user its best matches in descending score order.
        A user is not excluded from its own list.
    """
    # Clamp k so torch.topk never receives more than the available rows
    # (or a negative value).
    k = max(0, min(count, all_users.shape[0]))
    output = []

    for user, emb in tqdm(user_emb.items()):
        # topk selects the k best scores (returned in descending order)
        # without sorting the full score vector.
        scores, idx = torch.topk(torch.matmul(all_users, emb), k)
        neighbours = [idx2user[i] for i in idx.tolist()]
        output.extend(zip([user] * k, neighbours, scores.tolist()))

    return pd.DataFrame(output, columns=["user", "similar_user", "u2u_score"])

In [10]:
# Top-20 matches per user, best first within each user.
superset = similar_users(20)
superset = superset.sort_values(by=["user", "u2u_score"], ascending=[True, False])
# Drop rows where a user is matched with itself; this happens after the top-20
# cut, so a user can end up with fewer than 20 neighbours.
superset = superset.loc[superset['user'].ne(superset['similar_user'])]

100%|██████████| 10000/10000 [00:19<00:00, 514.09it/s]


In [17]:
# read item metadata for reference
# Only the descriptive columns used for eyeballing results are kept.
items = pd.read_csv('../../data/articles.csv').loc[:, [
    'article_id',
    'prod_name',
    'product_type_name',
    'product_group_name',
    'graphical_appearance_name',
    'index_group_name',
    'section_name',
    'colour_group_name',
    'perceived_colour_value_name',
]].copy()
items.head(2)

Unnamed: 0,article_id,prod_name,product_type_name,product_group_name,graphical_appearance_name,index_group_name,section_name,colour_group_name,perceived_colour_value_name
0,108775015,Strap top,Vest top,Garment Upper body,Solid,Ladieswear,Womens Everyday Basics,Black,Dark
1,108775044,Strap top,Vest top,Garment Upper body,Solid,Ladieswear,Womens Everyday Basics,White,Light


In [74]:
# Attach to every (user, neighbour) pair the items that neighbour rated:
# the neighbour column is renamed so it lines up with `item_rating.customer_id`.
user_rec = pd.merge(
    superset.rename(columns={'similar_user': 'customer_id'}),
    item_rating,
    on="customer_id",
)
user_rec.head(2)

Unnamed: 0,user,customer_id,u2u_score,article_id,weight
0,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,957a2cd74207d9dd01b801c7740354b40f7b28e292adf6...,0.970002,200182001,0.006213
1,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,957a2cd74207d9dd01b801c7740354b40f7b28e292adf6...,0.970002,509134001,0.00475


In [75]:
# Mean neighbour rating per user, joined back onto every row of that user.
user_rec = user_rec.merge(
    user_rec.groupby("user", as_index=False)["weight"].mean().rename(columns={"weight": "mean_rating"}),
    on="user",
)
# Scale each neighbour rating by the user-to-user similarity, then shift it by
# the user's mean rating.
user_rec["weight"] = user_rec["weight"] * user_rec["u2u_score"] + user_rec["mean_rating"]
user_rec.head(2)

Unnamed: 0,user,customer_id,u2u_score,article_id,weight,mean_rating
0,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,957a2cd74207d9dd01b801c7740354b40f7b28e292adf6...,0.970002,200182001,0.030169,0.024142
1,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,957a2cd74207d9dd01b801c7740354b40f7b28e292adf6...,0.970002,509134001,0.02875,0.024142


In [76]:
# Collapse to one score per (user, article) by summing the contributions of all neighbours.
user_rec = (
    user_rec.groupby(["user", "article_id"], as_index=False)["weight"]
    .sum()
    .rename(columns={"weight": "u2i_score"})
)
user_rec.head(2)

Unnamed: 0,user,article_id,u2i_score
0,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,158340001,0.049992
1,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,188183015,0.037688


In [77]:
# Best-scoring articles first within each user, then keep at most 300 per user.
user_rec = (
    user_rec.sort_values(by=["user", "u2i_score"], ascending=[True, False])
    .rename(columns={"user": "customer_id"})
    .groupby('customer_id', sort=False)
    .head(300)
)

In [78]:
# Preview the five highest-ranked rows.
user_rec.head(5)

Unnamed: 0,customer_id,article_id,u2i_score
635,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,806388003,0.423594
420,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,717490064,0.230421
381,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,706016001,0.208791
328,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,685816002,0.203095
181,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,610776001,0.186463


In [73]:
# Persist the recommendations next to the notebook, without the index column.
# NOTE(review): this cell's execution count (73) is lower than that of the
# cells above that build `user_rec` (74-78), so the file on disk may predate
# the frame shown above - re-run this cell last before relying on the CSV.
user_rec.to_csv('./user_user_recommendation.csv', index=False)

### Qualitative Analysis
#### Comparing past purchases of similar users

In [47]:
# Closest neighbours of one example customer.
superset.loc[
    superset['user'].eq('00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e395764563743f8835e1c')
].head()

Unnamed: 0,user,similar_user,u2u_score
601,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,0.985445
602,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,156fd825c26b01391efc155941b93020516a3191a28dff...,0.982485
603,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,d2fe3394303a305eb1506df8ad2f24833c283979bce6b7...,0.981813
604,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,8e5264f13953dfa3e2291af05c50bfeeafb4b40f6ac0ad...,0.981196
605,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,607f025e64e96162e6b07844db6654641becdcce33f93d...,0.98111


In [61]:
# Example customer whose neighbours are listed in the table above.
user = '00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e395764563743f8835e1c'
# Presumably the first neighbour listed for `user` above - the table output
# truncates ids, so only the prefix can be checked there (TODO confirm).
similar_user = '05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd4dc7ff383c5cdee3c1'
# Columns displayed when comparing the two purchase histories.
cols = ['customer_id', 'prod_name', 'product_type_name']

In [62]:
# First ten purchases of the example customer, with readable product names.
(
    data.loc[data['customer_id'] == user]
    .merge(items.loc[:, ['article_id', 'prod_name', 'product_type_name']])
    .loc[:, cols]
    .head(10)
)

Unnamed: 0,customer_id,prod_name,product_type_name
0,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,OP Brazilian 2p Low (Acacia),Underwear bottom
1,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,Yen,Blouse
2,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,Love Espadrille,Sandals
3,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,SIGNE BOAT NECK,Sweater
4,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,SIGNE BOAT NECK,Sweater
5,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,Claudine (1),T-shirt
6,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,Asta (1),Dress
7,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,Lumiere Bralette Soft,Bra
8,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,Casper highwaist thong,Underwear bottom
9,00c615fdd1a42adaccd4b0f314f5369bf38d6ea1428e39...,Bebe Sl-set (J),Pyjama set


In [63]:
# First ten purchases of the chosen neighbour, with readable product names.
(
    data.loc[data['customer_id'] == similar_user]
    .merge(items.loc[:, ['article_id', 'prod_name', 'product_type_name']])
    .loc[:, cols]
    .head(10)
)

Unnamed: 0,customer_id,prod_name,product_type_name
0,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Sascha pullover hoodie,Hoodie
1,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Tina leggings,Leggings/Tights
2,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Olivia top 2p,Bra
3,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Eden Push Lace Valencia,Bra
4,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Cardamom Push Valencia,Bra
5,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Lumiere Bralette Soft,Bra
6,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Dahlia strappy bikini,Underwear bottom
7,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Lash brazilian,Underwear bottom
8,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Hazelnut Brazilian Azalea Low,Underwear bottom
9,05869a2f91df2737c2a75b05d82dd13d38ad3e79b47ccd...,Cactus mynta thong,Underwear bottom
