In [1]:
import polars as pl
from datetime import timedelta
from concurrent.futures import ProcessPoolExecutor
from itertools import chain
from functools import partial
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
import numpy as np

df_clickstream = pl.read_parquet(f'clickstream.pq')
df_event = pl.read_parquet(f'events.pq')
df_cat = pl.read_parquet(f'cat_features.pq')

In [2]:
df_clickstream = df_clickstream.join(df_cat.select('item', 'category', 'location'), on='item', how='inner')
df_clickstream = df_clickstream.join(df_event, on='event', how='inner')

In [4]:
df_clickstream = df_clickstream.sort('event_date').group_by(['cookie', 'node']).first()

In [5]:
cut = df_clickstream["event_date"].max() - timedelta(days=28)
df_clickstream = df_clickstream.filter(pl.col("event_date") <= cut)
df_clickstream

cookie,node,item,event,event_date,platform,surface,category,location,is_contact
i64,u32,i64,i64,datetime[ns],i64,i64,i64,i64,i64
113587,313781,20516204,17,2025-01-13 11:30:03,2,2,35,9508,0
145917,287332,5816253,17,2025-01-15 08:56:52,3,11,35,2348,0
139139,155650,6179547,17,2025-01-24 17:43:09,3,11,50,1234,0
10630,264108,18274016,17,2025-01-17 06:35:03,2,11,51,2903,0
82332,51164,28268984,17,2025-01-10 00:11:51,2,2,49,1503,0
…,…,…,…,…,…,…,…,…,…
3507,71511,14917618,17,2025-01-16 22:28:38,2,2,7,4577,0
7156,348051,15880766,17,2025-01-16 09:49:44,2,8,51,2354,0
65535,200418,12471077,17,2025-01-12 22:26:40,3,11,57,2348,0
5564,229339,7479809,17,2025-01-15 12:40:15,2,2,51,4351,0


In [6]:
max_seq_length = 20
embed_dim = 96
n_heads = 3 
n_layers = 2
batch_size = 512
n_epochs = 5
top_k = 300

def prepare_data(df):
    # create sessions based on 30min inactivity
    df = df.with_columns(
        pl.col("event_date").cast(pl.Int64) // 1e9
    ).sort(["cookie", "event_date"])
    
    df = df.with_columns(
        pl.col("event_date").diff().over("cookie").alias("time_diff")
    ).with_columns(
        pl.when(pl.col("time_diff") > 1800).then(1).otherwise(0).alias("session_start")
    ).with_columns(
        pl.col("session_start").cum_sum().over("cookie").alias("session_id")
    )
    
    contact_df = df.filter(pl.col("is_contact")==1)
    
    sequences = (
        df.join(
            contact_df.select(["cookie", pl.col("session_id").alias("target_session")]),
            on="cookie",
            how="inner"
        )
        .filter(pl.col("session_id") < pl.col("target_session"))
    )
    
    return sequences.group_by(["cookie", "target_session"]).agg(
        pl.col("node").tail(max_seq_length).alias("node_seq"),
        pl.col("event").tail(max_seq_length).alias("event_seq"),
        pl.col("category").tail(max_seq_length).alias("category_seq"),
        pl.col("platform").last().alias("platform"),
        pl.col("surface").last().alias("surface"),
        pl.col("location").last().alias("location"),
        pl.col("node").last().alias("target_node")
    )

def create_vocab_mappings(df):
    node_list = df["node"].unique().to_list()
    node_vocab = {v: i+1 for i, v in enumerate(node_list)}
    node_reverse = {i+1: v for i, v in enumerate(node_list)}
    node_reverse[0] = "PAD"
    
    return {
        "node": defaultdict(lambda: 0, node_vocab),
        "node_reverse": node_reverse,
        "event": defaultdict(lambda: 0, {v: i+1 for i, v in enumerate(df["event"].unique().to_list())}),
        "category": defaultdict(lambda: 0, {v: i+1 for i, v in enumerate(df["category"].unique().to_list())}),
        "platform": defaultdict(lambda: 0, {v: i+1 for i, v in enumerate(df["platform"].unique().to_list())}),
        "surface": defaultdict(lambda: 0, {v: i+1 for i, v in enumerate(df["surface"].unique().to_list())}),
        "location": defaultdict(lambda: 0, {v: i+1 for i, v in enumerate(df["location"].unique().to_list())}),
    }

class ClickDataset(Dataset):
    def __init__(self, data, vocabs, is_training):
        self.data = data
        self.vocabs = vocabs
        self.is_training = is_training
        
    def __len__(self):
        return self.data.height
    
    def __getitem__(self, idx):
        row = self.data.row(idx, named=True)
        item = {
            "nodes": torch.LongTensor(self.pad_sequence([self.vocabs["node"][v] for v in row["node_seq"]])),
            "events": torch.LongTensor(self.pad_sequence([self.vocabs["event"][v] for v in row["event_seq"]])),
            "categories": torch.LongTensor(self.pad_sequence([self.vocabs["category"][v] for v in row["category_seq"]])),
            "platform": torch.LongTensor([self.vocabs["platform"][row["platform"]]]),
            "surface": torch.LongTensor([self.vocabs["surface"][row["surface"]]]),
            "location": torch.LongTensor([self.vocabs["location"][row["location"]]]),
            # "target": torch.tensor(self.vocabs["node"][row["target_node"]], dtype=torch.long),
            "cookie": row["cookie"]  # Keep cookie for final predictions
        }
        if self.is_training:
            item["target"] = torch.tensor(self.vocabs["node"][row["target_node"]], dtype=torch.long)
        else:
            item["cookie"] = row["cookie"]
        return item
    
    def pad_sequence(self, seq):
        return seq[-max_seq_length:] + [0]*(max_seq_length - len(seq))

class ContactPredictor(nn.Module):
    def __init__(self, vocabs):
        super().__init__()
        
        self.node_emb = nn.Embedding(len(vocabs["node"])+1, embed_dim, padding_idx=0)
        self.event_emb = nn.Embedding(len(vocabs["event"])+1, embed_dim//4, padding_idx=0)
        self.cat_emb = nn.Embedding(len(vocabs["category"])+1, embed_dim//4, padding_idx=0)
        
        # transformer
        self.pos_embedding = nn.Embedding(max_seq_length, embed_dim + embed_dim//4*2)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim + embed_dim//4*2,
            nhead=n_heads,
            dim_feedforward=embed_dim*2,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        
        self.context_emb = nn.ModuleDict({
            'platform': nn.Embedding(len(vocabs["platform"])+1, embed_dim//4),
            'surface': nn.Embedding(len(vocabs["surface"])+1, embed_dim//4),
            'location': nn.Embedding(len(vocabs["location"])+1, embed_dim//4)
        })
        
        seq_dim     = embed_dim + 2*(embed_dim//4)
        context_dim = 3*(embed_dim//4)
        self.context_proj = nn.Sequential(
            nn.Linear(seq_dim + context_dim, embed_dim),
            nn.ReLU()
        )
        
        self.fc = nn.Linear(embed_dim, len(vocabs["node"])+1)
        
    def forward(self, x):
        nodes = x["nodes"].to(self.node_emb.weight.device)
        events = x["events"].to(self.event_emb.weight.device)
        categories = x["categories"].to(self.cat_emb.weight.device)
        
        node_emb = self.node_emb(nodes)
        event_emb = self.event_emb(events)
        cat_emb = self.cat_emb(categories)
        
        seq_emb = torch.cat([node_emb, event_emb, cat_emb], dim=-1)
        
        positions = self.pos_embedding(torch.arange(max_seq_length, device=nodes.device))
        seq_emb += positions.unsqueeze(0)
        
        transformer_out = self.transformer(seq_emb)
        seq_encoded = transformer_out.mean(dim=1) 
        
        platform_emb = self.context_emb['platform'](x["platform"].squeeze().to(nodes.device))
        surface_emb = self.context_emb['surface'](x["surface"].squeeze().to(nodes.device))
        location_emb = self.context_emb['location'](x["location"].squeeze().to(nodes.device))
        
        combined = torch.cat([
            seq_encoded,
            platform_emb,
            surface_emb,
            location_emb
        ], dim=-1)
        
        return self.fc(self.context_proj(combined))

def generate_efficient_predictions(df, model, vocabs):
    latest_seq = df.sort(["cookie", "event_date"]).group_by("cookie").agg(
        pl.col("node").tail(20).alias("node_seq"),
        pl.col("event").tail(20).alias("event_seq"),
        pl.col("category").tail(20).alias("category_seq"),
        pl.col("platform").last(),
        pl.col("surface").last(),
        pl.col("location").last()
    )

    user_histories = df.group_by("cookie").agg(pl.col("node").unique().alias("hist_nodes"))

    pred_ds = ClickDataset(latest_seq, vocabs, is_training=False)

    model.eval()
    all_preds = []
    with torch.no_grad():
        for batch in DataLoader(pred_ds, batch_size=batch_size):
            outputs = model(batch)
            probs = torch.softmax(outputs, dim=-1)
            top_probs, top_idxs = torch.topk(probs, k=300)

            for cookie, idxs, p in zip(batch["cookie"], top_idxs.cpu().numpy(), top_probs.cpu().numpy()):
                recommended = [vocabs["node_reverse"].get(idx, "PAD") for idx in idxs if idx in vocabs["node_reverse"]]
                user_history = set(user_histories.filter(pl.col("cookie") == cookie)["hist_nodes"].to_list()[0])

                new_recs = [(rank+1, node, prob) for rank, (node, prob) in enumerate(zip(recommended, p)) if node not in user_history][:300]

                for rank, node, prob in new_recs:
                    all_preds.append({
                        "cookie": cookie,
                        "node": node,
                        "rank": rank,
                        "proba": prob
                    })

    return pl.DataFrame(all_preds)

def train_model(df):
    seq_df = prepare_data(df)
    vocabs = create_vocab_mappings(df)
    
    train_df = seq_df.filter(pl.col("target_session") % 10 < 18)
    val_df = seq_df.filter(pl.col("target_session") % 10 >= 8)
    
    train_ds = ClickDataset(train_df, vocabs, is_training=True)
    val_ds = ClickDataset(val_df, vocabs, is_training=True)
    
    model = ContactPredictor(vocabs)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for batch in tqdm(DataLoader(train_ds, batch_size=batch_size, shuffle=True)):
            optimizer.zero_grad()
            outputs = model(batch)
            loss = criterion(outputs, batch["target"])
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in DataLoader(val_ds, batch_size=batch_size):
                outputs = model(batch)
                val_loss += criterion(outputs, batch["target"]).item()
        
        print(f"Epoch {epoch+1}: Train Loss {train_loss/len(train_ds):.4f}, Val Loss {val_loss/len(val_ds):.4f}")
    
    return model, vocabs

In [8]:
if __name__ == "__main__":
    model, vocabs = train_model(df_clickstream)
    
    predictions = generate_efficient_predictions(df_clickstream, model, vocabs)
    
    predictions.write_parquet("retrieval_data/transformer2_28d.parquet")

100%|██████████| 16/16 [00:07<00:00,  2.25it/s]


Epoch 1: Train Loss 0.0252, Val Loss 0.0233


100%|██████████| 16/16 [00:07<00:00,  2.09it/s]


Epoch 2: Train Loss 0.0202, Val Loss 0.0176


100%|██████████| 16/16 [00:08<00:00,  1.87it/s]


Epoch 3: Train Loss 0.0169, Val Loss 0.0164


100%|██████████| 16/16 [00:07<00:00,  2.13it/s]


Epoch 4: Train Loss 0.0162, Val Loss 0.0160


100%|██████████| 16/16 [00:07<00:00,  2.14it/s]


Epoch 5: Train Loss 0.0158, Val Loss 0.0157


In [9]:
df_clickstream2= pl.read_parquet(f'clickstream.pq')

df_eval = df_clickstream2.filter(df_clickstream2['event_date']> cut)[['cookie', 'node', 'event']]
df_eval = df_eval.join(df_clickstream, on=['cookie', 'node'], how='anti')
df_eval = df_eval.filter(
    pl.col('event').is_in(
        df_event.filter(pl.col('is_contact')==1)['event'].unique()
    )
)

df_eval = df_eval.unique(['cookie', 'node'])

In [10]:
from utils import create_features, recall_at, fit_lgb_ranker
recall_at(df_eval, predictions, k=300) #0.1273 0.2258 0.22945 0.18153

0.2476348211618965