In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, RGCNConv
from torch_geometric.utils import negative_sampling, to_undirected
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
from datetime import datetime, timedelta
import random
from tqdm import tqdm
import json
import pickle

In [21]:
class TemporalGCNRecommender(nn.Module):
    """
    RE-GCN style temporal recommender for product recommendations
    
    This model:
    1. Uses recurrent GCN layers to capture temporal dynamics
    2. Predicts user-product interactions and their weights (quantities)
    3. Handles bipartite user-product graphs with temporal evolution
    """
    
    def __init__(self, num_users, num_products, node_features_dim, hidden_dim=64, num_layers=2, dropout=0.1):
        super().__init__()
        
        self.num_users = num_users
        self.num_products = num_products
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # Node embeddings
        self.user_embedding = nn.Embedding(num_users, hidden_dim)
        self.product_embedding = nn.Embedding(num_products, hidden_dim)
        
        # Feature projection
        self.feature_proj = nn.Linear(node_features_dim, hidden_dim)
        
        # Temporal GCN layers (recurrent style)
        self.gcn_layers = nn.ModuleList([
            GCNConv(hidden_dim, hidden_dim) for _ in range(num_layers)
        ])
        
        # Recurrent cells for temporal modeling
        self.gru_cells = nn.ModuleList([
            nn.GRUCell(hidden_dim, hidden_dim) for _ in range(num_layers)
        ])
        
        # Edge weight predictor (for quantity prediction)
        self.edge_weight_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            #nn.Sigmoid()  # Normalize weights
        )
        
        # Link prediction (binary classification)
        self.link_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize hidden states
        self.reset_hidden_states()
    
    def reset_hidden_states(self):
        """Reset hidden states for new sequence"""
        self.hidden_states = [None] * self.num_layers
    
    def forward(self, x, edge_index, edge_attr, user_indices, product_indices):
        """
        Forward pass for temporal link prediction
        
        Args:
            x: Node features [num_nodes, feature_dim]
            edge_index: Edge indices [2, num_edges]
            edge_attr: Edge attributes [num_edges, 2] (quantity, timestamp)
            user_indices: User node indices for prediction
            product_indices: Product node indices for prediction
        """
        batch_size = len(user_indices)
        
        # Project features to hidden dimension
        h = self.feature_proj(x)
        
        # Add positional embeddings
        user_emb = self.user_embedding(torch.arange(self.num_users, device=x.device))
        product_emb = self.product_embedding(torch.arange(self.num_products, device=x.device))
        
        h[:self.num_users] += user_emb
        h[self.num_users:] += product_emb
        
        # Temporal GCN with recurrent connections
        for layer_idx, (gcn, gru) in enumerate(zip(self.gcn_layers, self.gru_cells)):
            # Graph convolution
            h_new = gcn(h, edge_index)
            h_new = F.relu(h_new)
            h_new = self.dropout(h_new)
            
            # Recurrent update (temporal modeling)
            if self.hidden_states[layer_idx] is not None:
                h_new = gru(h_new, self.hidden_states[layer_idx])
            else:
                h_new = gru(h_new, torch.zeros_like(h_new))
            
            # Update hidden state
            self.hidden_states[layer_idx] = h_new.detach()
            h = h_new
        
        # Extract user and product embeddings for prediction
        user_embeddings = h[user_indices]
        product_embeddings = h[product_indices]
        
        # Concatenate user and product embeddings
        edge_embeddings = torch.cat([user_embeddings, product_embeddings], dim=1)
        
        # Predict link existence and edge weight
        link_scores = self.link_predictor(edge_embeddings).squeeze()
        edge_weights = self.edge_weight_predictor(edge_embeddings).squeeze()
        
        return link_scores, edge_weights, h

class TemporalRecommenderTrainer:
    """Training and evaluation pipeline for temporal recommendations"""
    
    def __init__(self, model, tkg_data, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.tkg_data = tkg_data
        
        # Prepare temporal snapshots
        self.temporal_snapshots = self.create_temporal_snapshots()
        
    def create_temporal_snapshots(self, time_window_hours=24):
        """
        Create temporal snapshots from TKG data
        
        Args:
            time_window_hours: Hours per snapshot
        """
        quadruples = self.tkg_data['quadruples']
        
        # Group quadruples by time windows
        snapshots = {}
        min_time = min(q[3] for q in quadruples)
        
        for user_id, quantity, product_id, timestamp in quadruples:
            # Calculate time window
            hours_since_start = (timestamp - min_time).total_seconds() / 3600
            window_id = int(hours_since_start // time_window_hours)
            
            if window_id not in snapshots:
                snapshots[window_id] = []
            
            snapshots[window_id].append((user_id, quantity, product_id, timestamp))
        
        # Convert to sorted list
        sorted_snapshots = [snapshots[i] for i in sorted(snapshots.keys())]
        
        print(f"Created {len(sorted_snapshots)} temporal snapshots")
        return sorted_snapshots
    
    def prepare_snapshot_data(self, snapshot_edges):
        """Prepare PyG data for a single snapshot"""
        
        # Extract edges and attributes
        edge_list = []
        edge_weights = []
        
        for user_id, quantity, product_id, timestamp in snapshot_edges:
            edge_list.append([user_id, product_id])
            edge_weights.append(quantity)
        
        if len(edge_list) == 0:
            return None, None, None
        
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_weights, dtype=torch.float).unsqueeze(1)
        
        # Make undirected for better message passing
        # edge_index = to_undirected(edge_index)
        
        return edge_index, edge_attr, edge_list
    
    def create_negative_samples(self, positive_edges, num_negative=None):
        """Create negative samples for link prediction"""
        
        if num_negative is None:
            num_negative = len(positive_edges)
        
        negative_edges = []
        user_ids = set(edge[0] for edge in positive_edges)
        product_ids = set(edge[1] for edge in positive_edges)
        positive_set = set((edge[0], edge[1]) for edge in positive_edges)
        
        while len(negative_edges) < num_negative:
            user_id = random.choice(list(user_ids))
            product_id = random.choice(list(product_ids))
            
            if (user_id, product_id) not in positive_set:
                negative_edges.append([user_id, product_id])
                positive_set.add((user_id, product_id))  # Avoid duplicates
        
        return negative_edges
    
    def train_epoch(self, optimizer, train_snapshots):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        num_batches = 0
        
        # Reset hidden states for new epoch
        self.model.reset_hidden_states()
        
        for snapshot_idx, snapshot_edges in enumerate(tqdm(train_snapshots, desc="Training")):
            
            edge_index, edge_attr, edge_list = self.prepare_snapshot_data(snapshot_edges)
            
            if edge_index is None:
                continue
            
            # Move to device
            x = self.tkg_data['x'].to(self.device)
            edge_index = edge_index.to(self.device)
            edge_attr = edge_attr.to(self.device)
            
            # Create positive and negative samples
            positive_edges = edge_list
            negative_edges = self.create_negative_samples(positive_edges)
            
            # Prepare training data
            all_edges = positive_edges + negative_edges
            labels = torch.cat([
                torch.ones(len(positive_edges)),
                torch.zeros(len(negative_edges))
            ]).to(self.device)
            
            # Extract edge weights (only for positive edges) - NO NORMALIZATION
            true_weights = torch.tensor([edge[1] for edge in snapshot_edges], dtype=torch.float).to(self.device)
            
            # Get user and product indices
            user_indices = torch.tensor([edge[0] for edge in all_edges], dtype=torch.long).to(self.device)
            product_indices = torch.tensor([edge[1] for edge in all_edges], dtype=torch.long).to(self.device)
            
            optimizer.zero_grad()
            
            # Forward pass
            link_scores, pred_weights, _ = self.model(x, edge_index, edge_attr, user_indices, product_indices)
            
            # Link prediction loss
            link_loss = F.binary_cross_entropy(link_scores, labels)
            
            # Edge weight loss (only for positive edges)
            if len(true_weights) > 0:
                weight_loss = F.mse_loss(pred_weights[:len(positive_edges)], true_weights)
            else:
                weight_loss = torch.tensor(0.0, device=self.device)
            
            # Combined loss
            total_loss_batch = link_loss + 0.1 * weight_loss  # Reduced weight loss coefficient
            
            total_loss_batch.backward()
            optimizer.step()
            
            total_loss += total_loss_batch.item()
            num_batches += 1
        
        return total_loss / max(num_batches, 1)

    def evaluate(self, test_snapshots):
        """Evaluate model performance"""
        self.model.eval()
        
        all_link_preds = []
        all_link_labels = []
        all_weight_preds = []
        all_weight_labels = []
        
        with torch.no_grad():
            # Reset hidden states
            self.model.reset_hidden_states()
            
            for snapshot_edges in tqdm(test_snapshots, desc="Evaluating"):
                
                edge_index, edge_attr, edge_list = self.prepare_snapshot_data(snapshot_edges)
                
                if edge_index is None:
                    continue
                
                # Move to device
                x = self.tkg_data['x'].to(self.device)
                edge_index = edge_index.to(self.device)
                edge_attr = edge_attr.to(self.device)
                
                # Create test samples
                positive_edges = edge_list
                negative_edges = self.create_negative_samples(positive_edges)
                
                all_edges = positive_edges + negative_edges
                labels = torch.cat([
                    torch.ones(len(positive_edges)),
                    torch.zeros(len(negative_edges))
                ])
                
                # True weights - NO NORMALIZATION
                true_weights = torch.tensor([q for _, q, _, _ in snapshot_edges], dtype=torch.float)
                
                # Get indices
                user_indices = torch.tensor([edge[0] for edge in all_edges], dtype=torch.long).to(self.device)
                product_indices = torch.tensor([edge[1] for edge in all_edges], dtype=torch.long).to(self.device)
                
                # Forward pass
                link_scores, pred_weights, _ = self.model(x, edge_index, edge_attr, user_indices, product_indices)
                
                # Collect predictions
                all_link_preds.extend(link_scores.cpu().numpy())
                all_link_labels.extend(labels.numpy())
                
                if len(true_weights) > 0:
                    all_weight_preds.extend(pred_weights[:len(positive_edges)].cpu().numpy())
                    all_weight_labels.extend(true_weights.numpy())
        
        # Calculate metrics
        if len(all_link_preds) > 0:
            link_auc = roc_auc_score(all_link_labels, all_link_preds)
            link_ap = average_precision_score(all_link_labels, all_link_preds)
        else:
            link_auc = link_ap = 0.0
        
        if len(all_weight_preds) > 0:
            weight_mse = np.mean((np.array(all_weight_preds) - np.array(all_weight_labels)) ** 2)
            weight_mae = np.mean(np.abs(np.array(all_weight_preds) - np.array(all_weight_labels)))
        else:
            weight_mse = weight_mae = 0.0
        
        return {
            'link_auc': link_auc,
            'link_ap': link_ap,
            'weight_mse': weight_mse,
            'weight_mae': weight_mae
        }
    
    def train(self, num_epochs=50, lr=0.01, train_ratio=0.8):
        """Full training pipeline"""
        
        # Split temporal snapshots
        split_idx = int(len(self.temporal_snapshots) * train_ratio)
        train_snapshots = self.temporal_snapshots[:split_idx]
        test_snapshots = self.temporal_snapshots[split_idx:]
        
        print(f"Training on {len(train_snapshots)} snapshots, testing on {len(test_snapshots)} snapshots")
        
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
        
        best_auc = 0
        for epoch in range(num_epochs):
            
            # Train
            train_loss = self.train_epoch(optimizer, train_snapshots)
            
            # Evaluate
            if epoch % 5 == 0:
                metrics = self.evaluate(test_snapshots)
                
                print(f"Epoch {epoch:3d}: Loss={train_loss:.4f}, "
                      f"AUC={metrics['link_auc']:.4f}, "
                      f"AP={metrics['link_ap']:.4f}, "
                      f"Weight MSE={metrics['weight_mse']:.4f}")
                
                if metrics['link_auc'] > best_auc:
                    best_auc = metrics['link_auc']
                    torch.save(self.model.state_dict(), 'best_temporal_recommender.pt')
            
            scheduler.step()
        
        # Load best model
        self.model.load_state_dict(torch.load('best_temporal_recommender.pt'))
        
        # Final evaluation
        final_metrics = self.evaluate(test_snapshots)
        print(f"\nFinal Results:")
        print(f"Link AUC: {final_metrics['link_auc']:.4f}")
        print(f"Link AP: {final_metrics['link_ap']:.4f}")
        print(f"Weight MSE: {final_metrics['weight_mse']:.4f}")
        
        return final_metrics
    
    def recommend_products(self, user_id, top_k=10, exclude_existing=True):
        """
        Generate product recommendations for a user
        
        Args:
            user_id: Target user ID
            top_k: Number of recommendations
            exclude_existing: Whether to exclude products user already bought
        """
        self.model.eval()
        
        with torch.no_grad():
            x = self.tkg_data['x'].to(self.device)
            
            # Get all products
            all_product_ids = list(range(self.tkg_data['num_users'], 
                                       self.tkg_data['num_users'] + self.tkg_data['num_products']))
            
            # Exclude existing products if requested
            if exclude_existing:
                existing_products = set()
                for quad in self.tkg_data['quadruples']:
                    if quad[0] == user_id:  # user_id matches
                        existing_products.add(quad[2])  # product_id
                
                candidate_products = [pid for pid in all_product_ids if pid not in existing_products]
            else:
                candidate_products = all_product_ids
            
            if len(candidate_products) == 0:
                return []
            
            # Create dummy edge index (we need this for the forward pass)
            dummy_edges = [[user_id, pid] for pid in candidate_products[:min(100, len(candidate_products))]]
            edge_index = torch.tensor(dummy_edges, dtype=torch.long).t().contiguous().to(self.device)
            edge_attr = torch.zeros((len(dummy_edges), 1), dtype=torch.float).to(self.device)
            
            # Prepare indices for prediction
            user_indices = torch.tensor([user_id] * len(dummy_edges), dtype=torch.long).to(self.device)
            product_indices = torch.tensor([edge[1] for edge in dummy_edges], dtype=torch.long).to(self.device)
            
            # Get predictions
            link_scores, pred_weights, _ = self.model(x, edge_index, edge_attr, user_indices, product_indices)
            
            # Combine scores (link probability * predicted quantity)
            combined_scores = link_scores * pred_weights
            
            # Get top-k recommendations
            top_indices = torch.argsort(combined_scores, descending=True)[:top_k]
            
            recommendations = []
            for idx in top_indices:
                product_id = dummy_edges[idx][1]
                score = combined_scores[idx].item()
                link_prob = link_scores[idx].item()
                pred_quantity = pred_weights[idx].item()
                
                recommendations.append({
                    'product_id': product_id,
                    'product_name': self.tkg_data.get('id_to_product', {}).get(product_id, f'Product_{product_id}'),
                    'score': score,
                    'link_probability': link_prob,
                    'predicted_quantity': pred_quantity
                })
            
            return recommendations


In [15]:
# Load your TKG data
tkg_data = torch.load('graph/my_retail_tkg_pyg.pt', weights_only=False)
"""
# Initialize model
model = TemporalGCNRecommender(
    num_users=tkg_data['num_users'],
    num_products=tkg_data['num_products'],
    node_features_dim=tkg_data['x'].shape[1],
    hidden_dim=64
)

# Initialize trainer
trainer = TemporalRecommenderTrainer(model, tkg_data, device='cuda' if torch.cuda.is_available() else 'cpu')

# Train the model
results = trainer.train(num_epochs=30, lr=0.01)

# Get recommendations for user
user_id = 0  # First user
recommendations = trainer.recommend_products(user_id, top_k=10)

print(f"Top recommendations for user {user_id}:")
for i, rec in enumerate(recommendations, 1):
    print(f"{i}. {rec['product_name']}")
    print(f"   Score: {rec['score']:.3f}, Link Prob: {rec['link_probability']:.3f}, Pred Qty: {rec['predicted_quantity']:.2f}")
"""

'\n# Initialize model\nmodel = TemporalGCNRecommender(\n    num_users=tkg_data[\'num_users\'],\n    num_products=tkg_data[\'num_products\'],\n    node_features_dim=tkg_data[\'x\'].shape[1],\n    hidden_dim=64\n)\n\n# Initialize trainer\ntrainer = TemporalRecommenderTrainer(model, tkg_data, device=\'cuda\' if torch.cuda.is_available() else \'cpu\')\n\n# Train the model\nresults = trainer.train(num_epochs=30, lr=0.01)\n\n# Get recommendations for user\nuser_id = 0  # First user\nrecommendations = trainer.recommend_products(user_id, top_k=10)\n\nprint(f"Top recommendations for user {user_id}:")\nfor i, rec in enumerate(recommendations, 1):\n    print(f"{i}. {rec[\'product_name\']}")\n    print(f"   Score: {rec[\'score\']:.3f}, Link Prob: {rec[\'link_probability\']:.3f}, Pred Qty: {rec[\'predicted_quantity\']:.2f}")\n'

In [16]:
# Initialize model
model = TemporalGCNRecommender(
    num_users=tkg_data['num_users'],
    num_products=tkg_data['num_products'],
    node_features_dim=tkg_data['x'].shape[1],
    hidden_dim=64
)

# Initialize trainer
trainer = TemporalRecommenderTrainer(model, tkg_data, device='cuda' if torch.cuda.is_available() else 'cpu')

# Train the model
results = trainer.train(num_epochs=30, lr=0.01)

# Get recommendations for user
user_id = 0  # First user
recommendations = trainer.recommend_products(user_id, top_k=10)

print(f"Top recommendations for user {user_id}:")
for i, rec in enumerate(recommendations, 1):
    print(f"{i}. {rec['product_name']}")
    print(f"   Score: {rec['score']:.3f}, Link Prob: {rec['link_probability']:.3f}, Pred Qty: {rec['predicted_quantity']:.2f}")

Created 305 temporal snapshots
Training on 244 snapshots, testing on 61 snapshots


Training: 100%|██████████| 244/244 [00:23<00:00, 10.43it/s]
Evaluating: 100%|██████████| 61/61 [00:03<00:00, 17.88it/s]


Epoch   0: Loss=0.6997, AUC=0.7929, AP=0.7284, Weight MSE=0.8309


Training: 100%|██████████| 244/244 [00:30<00:00,  7.95it/s]
Training: 100%|██████████| 244/244 [00:30<00:00,  8.00it/s]
Training: 100%|██████████| 244/244 [00:29<00:00,  8.14it/s]
Training: 100%|██████████| 244/244 [00:41<00:00,  5.83it/s]
Training: 100%|██████████| 244/244 [00:41<00:00,  5.89it/s]
Evaluating: 100%|██████████| 61/61 [00:05<00:00, 12.01it/s]


Epoch   5: Loss=0.4038, AUC=0.8680, AP=0.8225, Weight MSE=0.6761


Training: 100%|██████████| 244/244 [00:48<00:00,  5.04it/s]
Training: 100%|██████████| 244/244 [00:39<00:00,  6.15it/s]
Training: 100%|██████████| 244/244 [00:46<00:00,  5.22it/s]
Training: 100%|██████████| 244/244 [00:41<00:00,  5.88it/s]
Training: 100%|██████████| 244/244 [00:45<00:00,  5.38it/s]
Evaluating: 100%|██████████| 61/61 [00:06<00:00,  9.09it/s]


Epoch  10: Loss=0.3791, AUC=0.8734, AP=0.8255, Weight MSE=0.6745


Training: 100%|██████████| 244/244 [00:39<00:00,  6.11it/s]
Training: 100%|██████████| 244/244 [00:30<00:00,  8.13it/s]
Training: 100%|██████████| 244/244 [00:26<00:00,  9.16it/s]
Training: 100%|██████████| 244/244 [00:28<00:00,  8.42it/s]
Training: 100%|██████████| 244/244 [00:29<00:00,  8.26it/s]
Evaluating: 100%|██████████| 61/61 [00:03<00:00, 20.14it/s]


Epoch  15: Loss=0.3675, AUC=0.8732, AP=0.8209, Weight MSE=0.6798


Training: 100%|██████████| 244/244 [00:30<00:00,  8.03it/s]
Training: 100%|██████████| 244/244 [00:30<00:00,  7.92it/s]
Training: 100%|██████████| 244/244 [00:34<00:00,  7.05it/s]
Training: 100%|██████████| 244/244 [00:33<00:00,  7.29it/s]
Training: 100%|██████████| 244/244 [00:32<00:00,  7.40it/s]
Evaluating: 100%|██████████| 61/61 [00:03<00:00, 18.16it/s]


Epoch  20: Loss=0.3425, AUC=0.8784, AP=0.8363, Weight MSE=0.6887


Training: 100%|██████████| 244/244 [00:36<00:00,  6.66it/s]
Training: 100%|██████████| 244/244 [00:34<00:00,  7.05it/s]
Training: 100%|██████████| 244/244 [00:36<00:00,  6.64it/s]
Training: 100%|██████████| 244/244 [00:34<00:00,  7.10it/s]
Training: 100%|██████████| 244/244 [00:35<00:00,  6.89it/s]
Evaluating: 100%|██████████| 61/61 [00:04<00:00, 14.78it/s]


Epoch  25: Loss=0.3260, AUC=0.9133, AP=0.8785, Weight MSE=0.6840


Training: 100%|██████████| 244/244 [00:41<00:00,  5.93it/s]
Training: 100%|██████████| 244/244 [00:39<00:00,  6.24it/s]
Training: 100%|██████████| 244/244 [00:36<00:00,  6.68it/s]
Training: 100%|██████████| 244/244 [00:36<00:00,  6.76it/s]
Evaluating: 100%|██████████| 61/61 [00:03<00:00, 18.24it/s]



Final Results:
Link AUC: 0.9137
Link AP: 0.8795
Weight MSE: 0.6840
Top recommendations for user 0:
1. Product_4025
   Score: 1.825, Link Prob: 0.903, Pred Qty: 2.02
2. Product_3942
   Score: 1.810, Link Prob: 0.905, Pred Qty: 2.00
3. Product_3966
   Score: 1.809, Link Prob: 0.904, Pred Qty: 2.00
4. Product_3971
   Score: 1.808, Link Prob: 0.921, Pred Qty: 1.96
5. Product_3961
   Score: 1.807, Link Prob: 0.904, Pred Qty: 2.00
6. Product_3970
   Score: 1.806, Link Prob: 0.901, Pred Qty: 2.00
7. Product_4037
   Score: 1.805, Link Prob: 0.898, Pred Qty: 2.01
8. Product_4016
   Score: 1.804, Link Prob: 0.904, Pred Qty: 2.00
9. Product_4024
   Score: 1.803, Link Prob: 0.898, Pred Qty: 2.01
10. Product_3931
   Score: 1.800, Link Prob: 0.904, Pred Qty: 1.99


In [18]:
# Load trained model - takes seconds
model = TemporalGCNRecommender(
    num_users=tkg_data['num_users'],
    num_products=tkg_data['num_products'],
    node_features_dim=tkg_data['x'].shape[1],
    hidden_dim=64
)
model.load_state_dict(torch.load('best_temporal_recommender.pt'))

# Get recommendations instantly (milliseconds)
recommendations = trainer.recommend_products(user_id=1002, top_k=10)

In [19]:
recommendations 

[{'product_id': 3957,
  'product_name': 'Product_3957',
  'score': 1.9419025182724,
  'link_probability': 0.8873864412307739,
  'predicted_quantity': 2.1883392333984375},
 {'product_id': 3988,
  'product_name': 'Product_3988',
  'score': 1.9356331825256348,
  'link_probability': 0.8947839140892029,
  'predicted_quantity': 2.163240909576416},
 {'product_id': 3942,
  'product_name': 'Product_3942',
  'score': 1.9340131282806396,
  'link_probability': 0.895521342754364,
  'predicted_quantity': 2.1596505641937256},
 {'product_id': 3931,
  'product_name': 'Product_3931',
  'score': 1.931741714477539,
  'link_probability': 0.8929761648178101,
  'predicted_quantity': 2.163262367248535},
 {'product_id': 3932,
  'product_name': 'Product_3932',
  'score': 1.9310879707336426,
  'link_probability': 0.8947371244430542,
  'predicted_quantity': 2.1582741737365723},
 {'product_id': 3970,
  'product_name': 'Product_3970',
  'score': 1.9257551431655884,
  'link_probability': 0.8819207549095154,
  'predi

In [None]:
def load_tkg_mappings(base_path='graph/my_retail_tkg'):
    """
    Load TKG mappings and data from saved files
    
    Args:
        base_path: Base path to the saved TKG files (without extension)
    
    Returns:
        Dictionary containing all mappings and metadata
    """
    # Load the main pickle file with all data
    with open(f'{base_path}.pkl', 'rb') as f:
        tkg_data = pickle.load(f)
    
    # Load the JSON mappings (for easy inspection)
    with open(f'{base_path}_mappings.json', 'r') as f:
        mappings_json = json.load(f)
    
    # Load PyTorch geometric data
    pyg_data = torch.load(f'{base_path}_pyg.pt', weights_only=False)
    
    return {
        'user_to_id': tkg_data['user_to_id'],
        'product_to_id': tkg_data['product_to_id'],
        'id_to_user': tkg_data['id_to_user'],
        'id_to_product': tkg_data['id_to_product'],
        'user_features': tkg_data['user_features'],
        'product_features': tkg_data['product_features'],
        'quadruples': tkg_data['quadruples'],
        'quantity_scaler': tkg_data.get('quantity_scaler'),  # May not exist in older saves
        'log_transform_used': tkg_data.get('log_transform_used', True),
        'metadata': tkg_data['metadata'],
        'pyg_data': pyg_data
    }

def reverse_quantity_scaling(normalized_quantities, mappings_data):
    """
    Reverse the quantity scaling applied during TKG creation
    
    Args:
        normalized_quantities: Array or single value of normalized quantities
        mappings_data: Data loaded from load_tkg_mappings()
    
    Returns:
        Actual quantities
    """
    # Convert single value to array for processing
    is_single = not isinstance(normalized_quantities, (list, np.ndarray))
    if is_single:
        normalized_quantities = [normalized_quantities]
    
    # If we have the original scaler, use it
    if mappings_data.get('quantity_scaler') is not None:
        quantity_scaler = mappings_data['quantity_scaler']
        log_transform_used = mappings_data.get('log_transform_used', True)
        
        # Inverse min-max scaling
        log_quantities = quantity_scaler.inverse_transform(
            np.array(normalized_quantities).reshape(-1, 1)
        ).flatten()
        
        # Inverse log transform if it was used
        if log_transform_used:
            actual_quantities = np.expm1(log_quantities)  # exp(x) - 1
        else:
            actual_quantities = log_quantities
    else:
        # Fallback: estimate reverse scaling based on typical range
        # This is less accurate but works if scaler wasn't saved
        print("Warning: Original scaler not found, using estimated reverse scaling")
        
        # Assume normalized range was 1.0-10.0, map back to reasonable quantities
        normalized_array = np.array(normalized_quantities)
        # Simple linear mapping back to 1-50 range (adjust based on your data)
        actual_quantities = 1 + (normalized_array - 1) * 49 / 9  # (10-1) = 9 is the norm range
    
    # Ensure minimum quantity of 1 and round to reasonable precision
    actual_quantities = np.maximum(1, actual_quantities)
    
    return actual_quantities[0] if is_single else actual_quantities

def map_recommendations_to_actual(recommendations, mappings_data):
    """
    Map internal IDs and normalized values back to actual customer IDs, product descriptions, and quantities
    
    Args:
        recommendations: List of recommendation dictionaries from the model
        mappings_data: Data loaded from load_tkg_mappings()
    
    Returns:
        List of recommendations with actual values
    """
    mapped_recommendations = []
    
    for rec in recommendations:
        # Map product ID to actual description
        product_id = rec['product_id']
        actual_description = mappings_data['id_to_product'].get(product_id, 'Unknown Product')
        
        # Map predicted quantity back to actual scale
        predicted_qty_normalized = rec['predicted_quantity']
        actual_qty = reverse_quantity_scaling(predicted_qty_normalized, mappings_data)
        
        mapped_rec = {
            'product_id': product_id,
            'product_description': actual_description,
            'score': rec['score'],
            'link_probability': rec['link_probability'],
            'predicted_quantity_normalized': predicted_qty_normalized,
            'predicted_quantity_actual': round(actual_qty, 2)
        }
        
        mapped_recommendations.append(mapped_rec)
    
    return mapped_recommendations

def map_user_id_to_actual(user_id, mappings_data):
    """
    Map internal user ID back to actual customer ID
    
    Args:
        user_id: Internal user ID used in the model
        mappings_data: Data loaded from load_tkg_mappings()
    
    Returns:
        Actual customer ID
    """
    return mappings_data['id_to_user'].get(user_id, 'Unknown Customer')

def get_recommendations_with_actual_values(user_id, trainer, mappings_data, top_k=10):
    """
    Get recommendations and map all values back to actual identifiers
    
    Args:
        user_id: Internal user ID
        trainer: Your trained model/trainer instance
        mappings_data: Data loaded from load_tkg_mappings()
        top_k: Number of recommendations to get
    
    Returns:
        Dictionary with actual customer ID and mapped recommendations
    """
    # Get raw recommendations
    raw_recommendations = trainer.recommend_products(user_id=user_id, top_k=top_k)
    
    # Map to actual values
    mapped_recommendations = map_recommendations_to_actual(raw_recommendations, mappings_data)
    
    # Get actual customer ID
    actual_customer_id = map_user_id_to_actual(user_id, mappings_data)
    
    return {
        'customer_id_internal': user_id,
        'customer_id_actual': actual_customer_id,
        'recommendations': mapped_recommendations
    }

def create_recommendation_report(user_id, trainer, mappings_data, top_k=10):
    """
    Create a formatted recommendation report with actual values
    """
    result = get_recommendations_with_actual_values(user_id, trainer, mappings_data, top_k)
    
    print("=" * 80)
    print(f"PRODUCT RECOMMENDATIONS")
    print("=" * 80)
    print(f"Customer ID: {result['customer_id_actual']}")
    print(f"Internal ID: {result['customer_id_internal']}")
    print(f"Top {top_k} Recommendations:")
    print("-" * 80)
    
    for i, rec in enumerate(result['recommendations'], 1):
        print(f"{i:2d}. Product: {rec['product_description']}")
        print(f"     Confidence Score: {rec['score']:.3f}")
        print(f"     Purchase Probability: {rec['link_probability']:.1%}")
        print(f"     Expected Quantity: {rec['predicted_quantity_actual']}")
        print(f"     Product ID: {rec['product_id']}")
        print()
    
    return result

def get_batch_recommendations(user_ids, trainer, mappings_data, top_k=10):
    """
    Get recommendations for multiple users at once
    
    Args:
        user_ids: List of internal user IDs
        trainer: Your trained model/trainer instance  
        mappings_data: Data loaded from load_tkg_mappings()
        top_k: Number of recommendations per user
    
    Returns:
        Dictionary mapping actual customer IDs to their recommendations
    """
    batch_results = {}
    
    for user_id in user_ids:
        result = get_recommendations_with_actual_values(user_id, trainer, mappings_data, top_k)
        batch_results[result['customer_id_actual']] = result
    
    return batch_results

def find_user_by_customer_id(customer_id, mappings_data):
    """
    Find internal user ID by actual customer ID
    
    Args:
        customer_id: Actual customer ID from your original data
        mappings_data: Data loaded from load_tkg_mappings()
        
    Returns:
        Internal user ID or None if not found
    """
    return mappings_data['user_to_id'].get(customer_id)

def find_product_by_description(description, mappings_data):
    """
    Find internal product ID by product description
    
    Args:
        description: Product description from your original data
        mappings_data: Data loaded from load_tkg_mappings()
        
    Returns:
        Internal product ID or None if not found
    """
    return mappings_data['product_to_id'].get(description)



Customer: 12346.0
- FELT EGG COSY CHICKEN: 7.33 units
- JAM MAKING SET WITH JARS: 7.31 units
- POPPY'S PLAYHOUSE BEDROOM : 7.33 units
- POPPY'S PLAYHOUSE KITCHEN: 7.31 units
- PACK OF 72 RETROSPOT CAKE CASES: 7.45 units
- HOT WATER BOTTLE TEA AND SYMPATHY: 7.27 units
- COOK WITH WINE METAL SIGN : 7.35 units
- HAND WARMER UNION JACK: 7.21 units
- JUMBO BAG PINK POLKADOT: 7.29 units
- VICTORIAN SEWING BOX LARGE: 7.4 units
PRODUCT RECOMMENDATIONS
Customer ID: 12346.0
Internal ID: 1002
Top 10 Recommendations:
--------------------------------------------------------------------------------
 1. Product: VICTORIAN SEWING BOX LARGE
     Confidence Score: 1.942
     Purchase Probability: 88.7%
     Expected Quantity: 7.47
     Product ID: 3957

 2. Product: FELT EGG COSY CHICKEN
     Confidence Score: 1.935
     Purchase Probability: 89.5%
     Expected Quantity: 7.33
     Product ID: 3988

 3. Product: JAM MAKING SET WITH JARS
     Confidence Score: 1.934
     Purchase Probability: 89.5%
     

"\n# Find internal ID by customer ID\ncustomer_id = '12345'  # Your actual customer ID\ninternal_id = find_user_by_customer_id(customer_id, mappings_data)\nif internal_id is not None:\n    result = get_recommendations_with_actual_values(internal_id, trainer, mappings_data)\n"

In [23]:
# Usage example:

# Load all mappings and data
mappings_data = load_tkg_mappings('graph/my_retail_tkg')

# Get recommendations with actual values
result = get_recommendations_with_actual_values(
    user_id=43, 
    trainer=trainer, 
    mappings_data=mappings_data, 
    top_k=10
)

print(f"Customer: {result['customer_id_actual']}")
for rec in result['recommendations']:
    print(f"- {rec['product_description']}: {rec['predicted_quantity_actual']} units")

# Or create a formatted report
create_recommendation_report(user_id=1002, trainer=trainer, mappings_data=mappings_data)
"""
# Find internal ID by customer ID
customer_id = '12345'  # Your actual customer ID
internal_id = find_user_by_customer_id(customer_id, mappings_data)
if internal_id is not None:
    result = get_recommendations_with_actual_values(internal_id, trainer, mappings_data)
"""

Customer: 12748.0
- 5 STRAND GLASS NECKLACE CRYSTAL: 5.87 units
- BLUE NEW BAROQUE CANDLESTICK CANDLE: 5.87 units
- SET OF 3 COLOURED  FLYING DUCKS: 5.87 units
- ROUND SNACK BOXES SET OF4 WOODLAND : 5.97 units
- DISCO BALL CHRISTMAS DECORATION: 5.98 units
- RETROSPOT LAMP: 5.97 units
- SMALL HEART FLOWERS HOOK : 5.95 units
- BLUE COAT RACK PARIS FASHION: 5.87 units
- VICTORIAN SEWING BOX LARGE: 6.01 units
- DOORMAT FAIRY CAKE: 5.88 units
PRODUCT RECOMMENDATIONS
Customer ID: 12346.0
Internal ID: 1002
Top 10 Recommendations:
--------------------------------------------------------------------------------
 1. Product: POPPY'S PLAYHOUSE KITCHEN
     Confidence Score: 1.938
     Purchase Probability: 89.5%
     Expected Quantity: 7.35
     Product ID: 3932

 2. Product: JAM MAKING SET WITH JARS
     Confidence Score: 1.927
     Purchase Probability: 89.2%
     Expected Quantity: 7.32
     Product ID: 3942

 3. Product: POPPY'S PLAYHOUSE BEDROOM 
     Confidence Score: 1.927
     Purchase Pr

"\n# Find internal ID by customer ID\ncustomer_id = '12345'  # Your actual customer ID\ninternal_id = find_user_by_customer_id(customer_id, mappings_data)\nif internal_id is not None:\n    result = get_recommendations_with_actual_values(internal_id, trainer, mappings_data)\n"