In [2]:
import torch
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.nn.functional import normalize

In [4]:
# loading custom trained bert model embeddings
product_embeddings = pickle.load(
    open("../upsell/checkpoint/custom_trained_item_embedding.pkl", "rb")
)

embs = normalize(torch.tensor(product_embeddings[1]))
item_id = product_embeddings[2]

item_emb = {item_id[idx]: embs[idx].clone() for idx in range(len(item_id))}

  embs = normalize(torch.tensor(product_embeddings[1]))


In [5]:
data = pd.read_csv("../../data/transactions.csv").dropna(subset=["article_id"]).copy()
data = data[data["article_id"].isin(product_embeddings[2])]
data["t_dat"] = pd.to_datetime(data.t_dat).dt.date

data.head(2)

Unnamed: 0,t_dat,customer_id,article_id,price
0,2018-09-20,001ea4e9c54f7e9c88811260d954edc059d596147e1cf8...,652075001,0.011847
1,2018-09-20,001ea4e9c54f7e9c88811260d954edc059d596147e1cf8...,670295001,0.010153


In [6]:
def user_emb_using_entity(data, entity_emb: dict):
    # number of times an item was purchased by a customer
    count = (
        data.drop_duplicates(subset=["t_dat", "customer_id", "article_id"])
        .groupby(["customer_id", "article_id"])
        .size()
        .reset_index(name="count")
    )

    # number of days since last order
    days_since = (
        data.groupby(["customer_id", "article_id"])["t_dat"]
        .max()
        .reset_index(name="days_since")
    )
    days_since["days_since"] = (data.t_dat.max() - days_since["days_since"]).dt.days

    # weights or rating of each item
    weight = count.merge(days_since)
    weight["weight"] = weight["count"] / (weight["days_since"] + 1)
    weight["weight"] = weight["weight"] / weight.groupby("customer_id")[
        "weight"
    ].transform(sum)

    weight["entity_weight"] = list(zip(weight["article_id"], weight.weight))
    conjoint = (
        weight.groupby("customer_id")["entity_weight"]
        .apply(list)
        .reset_index(name="entity_weight")
    )

    user_emb_using_entity = {
        cid: normalize(
            torch.stack(
                [entity_emb[entity] * weight_ for entity, weight_ in entity_weight]
            )
            .transpose(1, 0)
            .sum(-1),
            dim=0,
        )
        for cid, entity_weight in tqdm(conjoint.values)
    }

    return user_emb_using_entity, weight[["customer_id", "article_id", "weight"]]

In [7]:
user_emb_using_item, item_rating = user_emb_using_entity(data=data, entity_emb=item_emb)

100%|██████████| 10000/10000 [00:03<00:00, 2715.62it/s]


In [8]:
# user metadata
device = "cuda" if torch.cuda.is_available() else "cpu"

user_emb = {user: emb.to(device) for user, emb in user_emb_using_item.items()}
all_users = torch.stack(list(user_emb.values()))

idx2user = {idx: user for idx, user in enumerate(user_emb.keys())}

In [9]:
def similar_users(count: int = 20):
    output = []

    for user, emb in tqdm(user_emb.items()):
        scores = torch.matmul(all_users, emb)
        scores, idx = torch.sort(scores, descending=True)

        scores = scores[:count].tolist()
        idx = [idx2user[i] for i in idx[:count].tolist()]
        lusers = [user] * count

        output.extend(list(zip(*[lusers, idx, scores])))

    output = pd.DataFrame(output, columns=["user", "similar_user", "u2u_score"])
    return output

In [10]:
superset = similar_users(count=20)
superset.sort_values(by=["user", "u2u_score"], ascending=[True, False], inplace=True)
superset = superset[superset["user"] != superset["similar_user"]]

100%|██████████| 10000/10000 [00:22<00:00, 449.00it/s]


In [11]:
# read item metadata for reference
items = pd.read_csv("../../data/articles.csv")
items = items[
    [
        "article_id",
        "prod_name",
        "product_type_name",
        "product_group_name",
        "graphical_appearance_name",
        "index_group_name",
        "section_name",
        "colour_group_name",
        "perceived_colour_value_name",
    ]
].copy()
items.head(2)

Unnamed: 0,article_id,prod_name,product_type_name,product_group_name,graphical_appearance_name,index_group_name,section_name,colour_group_name,perceived_colour_value_name
0,108775015,Strap top,Vest top,Garment Upper body,Solid,Ladieswear,Womens Everyday Basics,Black,Dark
1,108775044,Strap top,Vest top,Garment Upper body,Solid,Ladieswear,Womens Everyday Basics,White,Light


In [None]:
user_rec = superset.rename(columns={"similar_user": "customer_id"}).merge(
    item_rating, on=["customer_id"]
)
user_rec.head(2)

Unnamed: 0,user,customer_id,u2u_score,article_id,weight
0,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,957a2cd74207d9dd01b801c7740354b40f7b28e292adf6...,0.970002,200182001,0.006213
1,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,957a2cd74207d9dd01b801c7740354b40f7b28e292adf6...,0.970002,509134001,0.00475


In [None]:
user_rec = user_rec.merge(
    user_rec.groupby("user")["weight"].mean().reset_index(name="mean_rating")
)
user_rec["weight"] *= user_rec["u2u_score"]
user_rec["weight"] += user_rec["mean_rating"]
user_rec.head(2)

Unnamed: 0,user,customer_id,u2u_score,article_id,weight,mean_rating
0,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,957a2cd74207d9dd01b801c7740354b40f7b28e292adf6...,0.970002,200182001,0.030169,0.024142
1,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,957a2cd74207d9dd01b801c7740354b40f7b28e292adf6...,0.970002,509134001,0.02875,0.024142


In [None]:
user_rec = (
    user_rec.groupby(["user", "article_id"])["weight"]
    .sum()
    .reset_index(name="u2i_score")
)
user_rec.head(2)

Unnamed: 0,user,article_id,u2i_score
0,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,158340001,0.049992
1,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,188183015,0.037688


In [None]:
user_rec.sort_values(by=["user", "u2i_score"], ascending=[True, False], inplace=True)
user_rec.rename(columns={"user": "customer_id"}, inplace=True)
user_rec = user_rec.groupby("customer_id", sort=False).head(300)

In [None]:
user_rec.head()

Unnamed: 0,customer_id,article_id,u2i_score
635,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,806388003,0.423594
420,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,717490064,0.230421
381,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,706016001,0.208791
328,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,685816002,0.203095
181,000fa62c9e64d11bc25c530736949fd8dfc9a39d50c453...,610776001,0.186463


In [None]:
user_rec.to_csv("./user_user_recommendation.csv", index=False)

### Qualitative Analysis
##### comparing past purchases of similar users

In [42]:
user = "8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e5528847e41acec4225"
similar_user = "3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba44b773a6994cedd9387"
cols = ["customer_id", "prod_name", "product_type_name"]

In [46]:
data[data["customer_id"] == user].merge(
    items[["article_id", "prod_name", "product_type_name"]]
)[cols].tail(10)

Unnamed: 0,customer_id,prod_name,product_type_name
3,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,Box 4p Tights,Underwear Tights
4,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,Box 4p Tights,Underwear Tights
5,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,Shake it in Balconette,Bikini top
6,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,EDC Flossa skirt.,Skirt
7,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,missy sneaker(1),Sneakers
8,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,Tonia shorts,Shorts
9,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,Asa top,Top
10,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,Kendall Denim TRS,Trousers
11,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,C Antibes Tie Tanga,Swimwear bottom
12,8d6abe71d67e27769c182b3a2ca7677664e5a507d4223e...,Swish Super Push,Bikini top


In [45]:
data[data["customer_id"] == similar_user].merge(
    items[["article_id", "prod_name", "product_type_name"]]
)[cols].tail(10)

Unnamed: 0,customer_id,prod_name,product_type_name
60,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,Darcy PQ sandal,Sandals
61,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,NORA RW shorts innerbriefs,Shorts
62,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,Panorama mid support bra,Bra
63,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,3p Sneaker Socks,Socks
64,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,Push it Push Bra.,Bikini top
65,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,Calvin Clean Banana brief,Swimwear bottom
66,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,Frizz skirt,Skirt
67,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,Tropicana Tie Tanga,Swimwear bottom
68,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,Calvin Clean wire bra structur,Bikini top
69,3813cfbf2d3ea02d5b8838f7725a09f5617f003abaaba4...,Tropicana Triangle,Bikini top
