# NARM with dwell time

In [None]:
import pandas as pd

df = pd.read_csv(
    "../nov_reduced.csv",
    usecols=["event_time", "user_id", "product_id", "user_session"]
)

df["event_time"] = pd.to_datetime(df["event_time"])

print(df.head())
print(df.dtypes)

                        event_time  product_id    user_id  \
48061404 2019-11-19 08:35:46+00:00    30200005  512412397   
59981191 2019-11-26 14:16:08+00:00     1005115  568675496   
17020977 2019-11-10 17:50:50+00:00    15700275  513262731   
5666944  2019-11-04 14:23:52+00:00     1004589  562973725   
65419992 2019-11-29 17:11:17+00:00     5300157  560750791   

                                  user_session  
48061404  f62be3c5-18af-4ab1-bdce-f1a1119a3df4  
59981191  c857db53-cd0a-480d-a93f-dd738be33126  
17020977  c637d18a-6fc5-4c1c-9044-b537d1f9d8bb  
5666944   e41d3c3f-830e-48df-97a5-ff1de86c3c5d  
65419992  0538a90a-6395-4134-b032-232e81b17397  
event_time      datetime64[ns, UTC]
product_id                    int64
user_id                       int64
user_session                 object
dtype: object


In [None]:
import numpy as np
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
from collections import Counter

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

MAX_LEN = 50
MIN_SESSION_LEN = 5
TOP_N_ITEMS = 2000 
PAD_IDX = 0


print("\n\n================ GRU4REC CON DWELL ================")

DWELL_THRESHOLD = 75   

def build_sessions_with_dwell_repetitions(df):
    sessions = []

    df = df.sort_values(["user_session", "event_time"])

    for session_id, group in df.groupby("user_session"):
        group = group.sort_values("event_time")

        items = group["product_id"].tolist()
        times = group["event_time"].values

        dwells = np.diff(times).astype("timedelta64[s]").astype(float)
        dwells = np.append(dwells, 1.0)

        seq = []
        for item, d in zip(items, dwells):

            reps = int(max(d, 1) // DWELL_THRESHOLD) + 1

            seq.extend([item] * reps)

        if len(seq) >= 2:
            sessions.append(seq)

    return sessions


sessions = build_sessions_with_dwell_repetitions(df)
print("Sesiones creadas:", len(sessions))
print("Ejemplo sesión aumentada:", sessions[0][:25])


split = int(0.8 * len(sessions))
train_sessions = sessions[:split]
test_sessions = sessions[split:]

print("Train:", len(train_sessions), " Test:", len(test_sessions))


counter = Counter([item for sess in train_sessions for item in sess])
top_items = [item for item, _ in counter.most_common(TOP_N_ITEMS)]

item2idx = {item: i + 1 for i, item in enumerate(top_items)}
idx2item = {v: k for k, v in item2idx.items()}
N_ITEMS = len(item2idx) + 1

def encode_session(sess):
    return [item2idx[item] for item in sess if item in item2idx]

train_encoded = [encode_session(s) for s in train_sessions]
test_encoded = [encode_session(s) for s in test_sessions]

train_encoded = [s for s in train_encoded if len(s) >= MIN_SESSION_LEN]
test_encoded = [s for s in test_encoded if len(s) >= MIN_SESSION_LEN]

print("Train encoded:", len(train_encoded))
print("Test encoded:", len(test_encoded))
print("N_ITEMS:", N_ITEMS)


Device: cuda


Sesiones creadas: 5859994
Ejemplo sesión aumentada: [5100816, 5100816, 5100816, 1005107, 1005107, 11200402]
Train: 4687995  Test: 1171999
Train encoded: 1309888
Test encoded: 327080
N_ITEMS: 2001


In [None]:
class GRUDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]

        # limitar a MAX_LEN
        if len(seq) > MAX_LEN:
            seq = seq[-MAX_LEN:]

        items = torch.tensor(seq[:-1], dtype=torch.long)
        targets = torch.tensor(seq[1:], dtype=torch.long)

        return items, targets


def collate_fn(batch):
    items_batch, targets_batch = zip(*batch)
    max_len = max(len(x) for x in items_batch)

    def pad(x, pad_value=0):
        return torch.cat([x, torch.full((max_len - len(x),), pad_value, dtype=x.dtype)])

    items = torch.stack([pad(x, PAD_IDX) for x in items_batch])
    targets = torch.stack([pad(x, PAD_IDX) for x in targets_batch])

    return items, targets


train_loader = DataLoader(
    GRUDataset(train_encoded),
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    GRUDataset(test_encoded),
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn
)

class NARM(nn.Module):
    def __init__(
        self,
        n_items,
        emb_size=128,
        hidden_size=128,
        dropout=0.2
    ):
        super().__init__()

        self.embedding = nn.Embedding(n_items, emb_size, padding_idx=PAD_IDX)

        self.gru = nn.GRU(
            input_size=emb_size,
            hidden_size=hidden_size,
            batch_first=True
        )

        self.linear_one = nn.Linear(hidden_size, hidden_size)
        self.linear_two = nn.Linear(hidden_size, hidden_size)

        self.fc = nn.Linear(hidden_size * 2, n_items)

        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size

    def forward(self, items):
        """
        items: (B, T)
        Retorna: logits (B, T, n_items)
        Igual que SASRec, para que tu training loop siga funcionando tal cual.
        """
        emb = self.dropout(self.embedding(items))      

        gru_out, h_last = self.gru(emb)                 
        h_last = h_last.squeeze(0)                   

        q1 = self.linear_one(gru_out)              
        q2 = self.linear_two(h_last).unsqueeze(1)     

        attn_scores = torch.sum(torch.tanh(q1 + q2), dim=-1)  
        attn_weights = torch.softmax(attn_scores, dim=-1)     

        context = torch.bmm(attn_weights.unsqueeze(1), gru_out)  
        context = context.squeeze(1)                            

        final_rep = torch.cat([context, h_last], dim=-1)          

        B, T = items.size()
        logits = self.fc(final_rep).unsqueeze(1).repeat(1, T, 1)

        return logits


def train_epoch_sasrec(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0

    for items, targets in loader:
        items, targets = items.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(items)

        B, T, C = logits.shape
        loss = criterion(logits.reshape(B*T, C), targets.reshape(B*T))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


def evaluate_sasrec(model, loader):
    model.eval()

    recall_k = 10
    recall_sum = 0
    mrr_sum = 0
    ndcg_sum = 0
    total = 0

    for items, targets in loader:
        items, targets = items.to(device), targets.to(device)

        logits = model(items)[0, -1]
        topk = torch.topk(logits, recall_k).indices.tolist()

        target = targets[0, -1].item()
        if target == PAD_IDX:
            continue

        total += 1

        if target in topk:
            rank = topk.index(target) + 1
            recall_sum += 1
            mrr_sum += 1 / rank
            ndcg_sum += 1 / math.log2(rank + 1)

    return (
        recall_sum / total,
        mrr_sum / total,
        ndcg_sum / total
    )


print("\n\n================ NARM CON DWELL ================\n")

narm_dwell = NARM(N_ITEMS).to(device)
criterion_narm = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer_narm = torch.optim.Adam(narm_dwell.parameters(), lr=1e-3)

import time

start = time.time()

EPOCHS = 10

for ep in range(1, EPOCHS + 1):
    loss = train_epoch_sasrec(narm_dwell, train_loader, optimizer_narm, criterion_narm)
    

    print(f"\nEpoch {ep}/{EPOCHS} (NARM CON DWELL)")
    print(f"Loss: {loss:.4f}")
    if (ep == 0) or (ep == 5) or (ep == EPOCHS):
        recall, mrr, ndcg = evaluate_sasrec(narm_dwell, test_loader)
        print(f"Recall@10: {recall:.4f}  MRR@10: {mrr:.4f}  NDCG@10: {ndcg:.4f}")

end = time.time()

print(f"\n⏱ Tiempo total: {end - start:.2f} segundos")





Epoch 1/10 (NARM CON DWELL)
Loss: 1.4834

Epoch 2/10 (NARM CON DWELL)
Loss: 1.1807

Epoch 3/10 (NARM CON DWELL)
Loss: 1.1381

Epoch 4/10 (NARM CON DWELL)
Loss: 1.1173

Epoch 5/10 (NARM CON DWELL)
Loss: 1.1056
Recall@10: 0.7055  MRR@10: 0.5111  NDCG@10: 0.5588

Epoch 6/10 (NARM CON DWELL)
Loss: 1.0972

Epoch 7/10 (NARM CON DWELL)
Loss: 1.0911

Epoch 8/10 (NARM CON DWELL)
Loss: 1.0872

Epoch 9/10 (NARM CON DWELL)
Loss: 1.0833

Epoch 10/10 (NARM CON DWELL)
Loss: 1.0807
Recall@10: 0.7095  MRR@10: 0.5095  NDCG@10: 0.5585

⏱ Tiempo total: 2194.72 segundos
