In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, RGCNConv
from torch_geometric.utils import to_undirected
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from tqdm import tqdm
import random

In [2]:
class EvolutionUnit(nn.Module):
    """
    RE-GCN Evolution Unit that combines GCN and GRU with pooling
    """
    
    def __init__(self, hidden_dim, num_relations=1, dropout=0.1):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        
        # Relation-aware GCN (or regular GCN if num_relations=1)
        if num_relations > 1:
            self.gcn = RGCNConv(hidden_dim, hidden_dim, num_relations)
        else:
            self.gcn = GCNConv(hidden_dim, hidden_dim)
            
        # GRU for temporal evolution
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)
        
        # Pooling mechanism for historical states
        self.pooling_weights = nn.Linear(hidden_dim, 1)
        
        # Static graph constraint (attention mechanism)
        self.static_attention = nn.MultiheadAttention(hidden_dim, num_heads=4, dropout=dropout, batch_first=True)
        
        # Time gate for controlling temporal influence
        self.time_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )
        
        self.dropout = nn.Dropout(dropout)
        
        # Store historical states for pooling
        self.historical_states = []
        self.max_history = 10  # Limit memory usage
        
    def apply_pooling(self, current_state):
        """
        Apply pooling operation to aggregate historical information
        """
        if len(self.historical_states) == 0:
            return current_state
            
        # Stack historical states: [history_len, num_nodes, hidden_dim]
        historical_tensor = torch.stack(self.historical_states, dim=0)
        
        # Compute attention weights for pooling
        pooling_scores = self.pooling_weights(historical_tensor)  # [history_len, num_nodes, 1]
        pooling_weights = F.softmax(pooling_scores, dim=0)  # Normalize over history
        
        # Weighted sum of historical states
        pooled_history = torch.sum(pooling_weights * historical_tensor, dim=0)  # [num_nodes, hidden_dim]
        
        return pooled_history
        
    def forward(self, x, edge_index, edge_type=None, hidden_state=None):
        """
        Forward pass through evolution unit
        
        Args:
            x: Node features [num_nodes, hidden_dim]
            edge_index: Edge indices [2, num_edges]
            edge_type: Edge types [num_edges] (optional)
            hidden_state: Previous hidden state [num_nodes, hidden_dim]
        """
        # 1. Graph Convolution
        if edge_type is not None and hasattr(self.gcn, 'num_relations'):
            h_gcn = self.gcn(x, edge_index, edge_type)
        else:
            h_gcn = self.gcn(x, edge_index)
            
        h_gcn = F.relu(h_gcn)
        h_gcn = self.dropout(h_gcn)
        
        # 2. Apply pooling with historical information
        h_pooled = self.apply_pooling(h_gcn)
        
        # 3. Static graph constraint via self-attention
        h_static, _ = self.static_attention(h_pooled.unsqueeze(0), h_pooled.unsqueeze(0), h_pooled.unsqueeze(0))
        h_static = h_static.squeeze(0)
        
        # 4. Temporal evolution with GRU
        if hidden_state is not None:
            # Time gate to control influence of previous state
            gate_input = torch.cat([h_static, hidden_state], dim=1)
            time_weight = self.time_gate(gate_input)
            gated_hidden = time_weight * hidden_state
            
            h_new = self.gru(h_static, gated_hidden)
        else:
            h_new = self.gru(h_static, torch.zeros_like(h_static))
        
        # 5. Update historical states for pooling
        self.historical_states.append(h_new.detach().clone())
        if len(self.historical_states) > self.max_history:
            self.historical_states.pop(0)  # Remove oldest state
            
        return h_new

class TemporalGCNRecommender(nn.Module):
    """
    RE-GCN style temporal recommender for product recommendations
    
    This model:
    1. Uses Evolution Units that combine GCN + GRU + Pooling
    2. Predicts user-product interactions and their weights (quantities)
    3. Handles bipartite user-product graphs with temporal evolution
    """
    
    def __init__(self, num_users, num_products, node_features_dim, hidden_dim=64, num_layers=2, dropout=0.1, num_relations=1):
        super().__init__()
        
        self.num_users = num_users
        self.num_products = num_products
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # Node embeddings
        self.user_embedding = nn.Embedding(num_users, hidden_dim)
        self.product_embedding = nn.Embedding(num_products, hidden_dim)
        
        # Feature projection
        self.feature_proj = nn.Linear(node_features_dim, hidden_dim)
        
        # Evolution Units (RE-GCN style)
        self.evolution_units = nn.ModuleList([
            EvolutionUnit(hidden_dim, num_relations, dropout) for _ in range(num_layers)
        ])
        
        # Score functions for entity and relation prediction
        self.entity_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        self.relation_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)  # Predicts quantity/weight
        )
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize hidden states for each evolution unit
        self.reset_hidden_states()
    
    def reset_hidden_states(self):
        """Reset hidden states and historical information for new sequence"""
        self.hidden_states = [None] * self.num_layers
        for unit in self.evolution_units:
            unit.historical_states = []
    
    def forward(self, x, edge_index, edge_attr=None, user_indices=None, product_indices=None, edge_type=None):
        """
        Forward pass for temporal link prediction
        
        Args:
            x: Node features [num_nodes, feature_dim]
            edge_index: Edge indices [2, num_edges]
            edge_attr: Edge attributes [num_edges, 2] (quantity, timestamp)
            user_indices: User node indices for prediction
            product_indices: Product node indices for prediction
            edge_type: Edge types [num_edges] (optional)
        """
        # Project features to hidden dimension
        h = self.feature_proj(x)
        
        # Add positional embeddings
        user_emb = self.user_embedding(torch.arange(self.num_users, device=x.device))
        product_emb = self.product_embedding(torch.arange(self.num_products, device=x.device))
        
        h[:self.num_users] += user_emb
        h[self.num_users:] += product_emb
        
        # Pass through Evolution Units
        for layer_idx, evolution_unit in enumerate(self.evolution_units):
            h = evolution_unit(h, edge_index, edge_type, self.hidden_states[layer_idx])
            self.hidden_states[layer_idx] = h.detach()
        
        # If no specific indices provided, return node embeddings
        if user_indices is None or product_indices is None:
            return h
        
        # Extract user and product embeddings for prediction
        user_embeddings = h[user_indices]
        product_embeddings = h[product_indices]
        
        # Concatenate user and product embeddings
        edge_embeddings = torch.cat([user_embeddings, product_embeddings], dim=1)
        
        # Predict using score functions (following RE-GCN paper)
        entity_scores = self.entity_predictor(edge_embeddings).squeeze()  # Link existence
        relation_scores = self.relation_predictor(edge_embeddings).squeeze()  # Edge weights
        
        return entity_scores, relation_scores, h

class TemporalRecommenderTrainer:
    """Training and evaluation pipeline for temporal recommendations"""
    
    def __init__(self, model, tkg_data, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.tkg_data = tkg_data
        
        # Prepare temporal snapshots
        self.temporal_snapshots = self.create_temporal_snapshots()
        
    def create_temporal_snapshots(self, time_window_hours=24):
        """
        Create temporal snapshots from TKG data
        """
        quadruples = self.tkg_data['quadruples']
        
        # Group quadruples by time windows
        snapshots = {}
        min_time = min(q[3] for q in quadruples)
        
        for user_id, quantity, product_id, timestamp in quadruples:
            # Calculate time window
            hours_since_start = (timestamp - min_time).total_seconds() / 3600
            window_id = int(hours_since_start // time_window_hours)
            
            if window_id not in snapshots:
                snapshots[window_id] = []
            
            snapshots[window_id].append((user_id, quantity, product_id, timestamp))
        
        # Convert to sorted list
        sorted_snapshots = [snapshots[i] for i in sorted(snapshots.keys())]
        
        print(f"Created {len(sorted_snapshots)} temporal snapshots")
        return sorted_snapshots
    
    def prepare_snapshot_data(self, snapshot_edges):
        """Prepare PyG data for a single snapshot"""
        
        # Extract edges and attributes
        edge_list = []
        edge_weights = []
        
        for user_id, quantity, product_id, timestamp in snapshot_edges:
            edge_list.append([user_id, product_id])
            edge_weights.append(quantity)
        
        if len(edge_list) == 0:
            return None, None, None
        
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_weights, dtype=torch.float).unsqueeze(1)
        
        return edge_index, edge_attr, edge_list
    
    def create_negative_samples(self, positive_edges, num_negative=None):
        """Create negative samples for link prediction"""
        
        if num_negative is None:
            num_negative = len(positive_edges)
        
        negative_edges = []
        user_ids = set(edge[0] for edge in positive_edges)
        product_ids = set(edge[1] for edge in positive_edges)
        positive_set = set((edge[0], edge[1]) for edge in positive_edges)
        
        while len(negative_edges) < num_negative:
            user_id = random.choice(list(user_ids))
            product_id = random.choice(list(product_ids))
            
            if (user_id, product_id) not in positive_set:
                negative_edges.append([user_id, product_id])
                positive_set.add((user_id, product_id))  # Avoid duplicates
        
        return negative_edges
    
    def train_epoch(self, optimizer, train_snapshots):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        num_batches = 0
        
        # Reset hidden states for new epoch
        self.model.reset_hidden_states()
        
        for snapshot_idx, snapshot_edges in enumerate(tqdm(train_snapshots, desc="Training")):
            
            edge_index, edge_attr, edge_list = self.prepare_snapshot_data(snapshot_edges)
            
            if edge_index is None:
                continue
            
            # Move to device
            x = self.tkg_data['x'].to(self.device)
            edge_index = edge_index.to(self.device)
            edge_attr = edge_attr.to(self.device)
            
            # Create positive and negative samples
            positive_edges = edge_list
            negative_edges = self.create_negative_samples(positive_edges)
            
            # Prepare training data
            all_edges = positive_edges + negative_edges
            labels = torch.cat([
                torch.ones(len(positive_edges)),
                torch.zeros(len(negative_edges))
            ]).to(self.device)
            
            # Extract edge weights (only for positive edges)
            true_weights = torch.tensor([edge[1] for edge in snapshot_edges], dtype=torch.float).to(self.device)
            # Normalize weights to [0, 1] range
            if len(true_weights) > 0:
                true_weights = (true_weights - true_weights.min()) / (true_weights.max() - true_weights.min() + 1e-8)
            
            # Get user and product indices
            user_indices = torch.tensor([edge[0] for edge in all_edges], dtype=torch.long).to(self.device)
            product_indices = torch.tensor([edge[1] for edge in all_edges], dtype=torch.long).to(self.device)
            
            optimizer.zero_grad()
            
            # Forward pass
            entity_scores, relation_scores, _ = self.model(x, edge_index, edge_attr, user_indices, product_indices)
            
            # Entity prediction loss (link existence)
            entity_loss = F.binary_cross_entropy(entity_scores, labels)
            
            # Relation prediction loss (edge weights, only for positive edges)
            if len(true_weights) > 0:
                relation_loss = F.mse_loss(relation_scores[:len(positive_edges)], true_weights)
            else:
                relation_loss = torch.tensor(0.0, device=self.device)
            
            # Combined loss (following RE-GCN formulation)
            total_loss_batch = entity_loss + 0.5 * relation_loss
            
            total_loss_batch.backward()
            optimizer.step()
            
            total_loss += total_loss_batch.item()
            num_batches += 1
        
        return total_loss / max(num_batches, 1)
    
    def evaluate(self, test_snapshots):
        """Evaluate model performance"""
        self.model.eval()
        
        all_entity_preds = []
        all_entity_labels = []
        all_relation_preds = []
        all_relation_labels = []
        
        with torch.no_grad():
            # Reset hidden states
            self.model.reset_hidden_states()
            
            for snapshot_edges in tqdm(test_snapshots, desc="Evaluating"):
                
                edge_index, edge_attr, edge_list = self.prepare_snapshot_data(snapshot_edges)
                
                if edge_index is None:
                    continue
                
                # Move to device
                x = self.tkg_data['x'].to(self.device)
                edge_index = edge_index.to(self.device)
                edge_attr = edge_attr.to(self.device)
                
                # Create test samples
                positive_edges = edge_list
                negative_edges = self.create_negative_samples(positive_edges)
                
                all_edges = positive_edges + negative_edges
                labels = torch.cat([
                    torch.ones(len(positive_edges)),
                    torch.zeros(len(negative_edges))
                ])
                
                # True weights
                true_weights = torch.tensor([q for _, q, _, _ in snapshot_edges], dtype=torch.float)
                if len(true_weights) > 0:
                    true_weights = (true_weights - true_weights.min()) / (true_weights.max() - true_weights.min() + 1e-8)
                
                # Get indices
                user_indices = torch.tensor([edge[0] for edge in all_edges], dtype=torch.long).to(self.device)
                product_indices = torch.tensor([edge[1] for edge in all_edges], dtype=torch.long).to(self.device)
                
                # Forward pass
                entity_scores, relation_scores, _ = self.model(x, edge_index, edge_attr, user_indices, product_indices)
                
                # Collect predictions
                all_entity_preds.extend(entity_scores.cpu().numpy())
                all_entity_labels.extend(labels.numpy())
                
                if len(true_weights) > 0:
                    all_relation_preds.extend(relation_scores[:len(positive_edges)].cpu().numpy())
                    all_relation_labels.extend(true_weights.numpy())
        
        # Calculate metrics
        if len(all_entity_preds) > 0:
            entity_auc = roc_auc_score(all_entity_labels, all_entity_preds)
            entity_ap = average_precision_score(all_entity_labels, all_entity_preds)
        else:
            entity_auc = entity_ap = 0.0
        
        if len(all_relation_preds) > 0:
            relation_mse = np.mean((np.array(all_relation_preds) - np.array(all_relation_labels)) ** 2)
        else:
            relation_mse = 0.0
        
        return {
            'entity_auc': entity_auc,
            'entity_ap': entity_ap,
            'relation_mse': relation_mse
        }
    
    def train(self, num_epochs=50, lr=0.01, train_ratio=0.8):
        """Full training pipeline"""
        
        # Split temporal snapshots
        split_idx = int(len(self.temporal_snapshots) * train_ratio)
        train_snapshots = self.temporal_snapshots[:split_idx]
        test_snapshots = self.temporal_snapshots[split_idx:]
        
        print(f"Training on {len(train_snapshots)} snapshots, testing on {len(test_snapshots)} snapshots")
        
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
        
        best_auc = 0
        for epoch in range(num_epochs):
            
            # Train
            train_loss = self.train_epoch(optimizer, train_snapshots)
            
            # Evaluate
            if epoch % 5 == 0:
                metrics = self.evaluate(test_snapshots)
                
                print(f"Epoch {epoch:3d}: Loss={train_loss:.4f}, "
                      f"Entity AUC={metrics['entity_auc']:.4f}, "
                      f"Entity AP={metrics['entity_ap']:.4f}, "
                      f"Relation MSE={metrics['relation_mse']:.4f}")
                
                if metrics['entity_auc'] > best_auc:
                    best_auc = metrics['entity_auc']
                    torch.save(self.model.state_dict(), 'best_temporal_recommender.pt')
            
            scheduler.step()
        
        # Load best model
        self.model.load_state_dict(torch.load('best_temporal_recommender.pt'))
        
        # Final evaluation
        final_metrics = self.evaluate(test_snapshots)
        print(f"\nFinal Results:")
        print(f"Entity AUC: {final_metrics['entity_auc']:.4f}")
        print(f"Entity AP: {final_metrics['entity_ap']:.4f}")
        print(f"Relation MSE: {final_metrics['relation_mse']:.4f}")
        
        return final_metrics
    
    def recommend_products(self, user_id, top_k=10, exclude_existing=True):
        """
        Generate product recommendations for a user
        
        Args:
            user_id: Target user ID
            top_k: Number of recommendations
            exclude_existing: Whether to exclude products user already bought
        """
        self.model.eval()
        
        with torch.no_grad():
            x = self.tkg_data['x'].to(self.device)
            
            # Get all products
            all_product_ids = list(range(self.tkg_data['num_users'], 
                                       self.tkg_data['num_users'] + self.tkg_data['num_products']))
            
            # Exclude existing products if requested
            if exclude_existing:
                existing_products = set()
                for quad in self.tkg_data['quadruples']:
                    if quad[0] == user_id:  # user_id matches
                        existing_products.add(quad[2])  # product_id
                
                candidate_products = [pid for pid in all_product_ids if pid not in existing_products]
            else:
                candidate_products = all_product_ids
            
            if len(candidate_products) == 0:
                return []
            
            # Create dummy edge index (we need this for the forward pass)
            dummy_edges = [[user_id, pid] for pid in candidate_products[:min(100, len(candidate_products))]]
            edge_index = torch.tensor(dummy_edges, dtype=torch.long).t().contiguous().to(self.device)
            edge_attr = torch.zeros((len(dummy_edges), 1), dtype=torch.float).to(self.device)
            
            # Prepare indices for prediction
            user_indices = torch.tensor([user_id] * len(dummy_edges), dtype=torch.long).to(self.device)
            product_indices = torch.tensor([edge[1] for edge in dummy_edges], dtype=torch.long).to(self.device)
            
            # Get predictions
            entity_scores, relation_scores, _ = self.model(x, edge_index, edge_attr, user_indices, product_indices)
            
            # Combine scores (entity probability * predicted relation strength)
            combined_scores = entity_scores * torch.abs(relation_scores)  # Use absolute value for relation scores
            
            # Get top-k recommendations
            top_indices = torch.argsort(combined_scores, descending=True)[:top_k]
            
            recommendations = []
            for idx in top_indices:
                product_id = dummy_edges[idx][1]
                score = combined_scores[idx].item()
                entity_prob = entity_scores[idx].item()
                relation_strength = relation_scores[idx].item()
                
                recommendations.append({
                    'product_id': product_id,
                    'product_name': self.tkg_data.get('id_to_product', {}).get(product_id, f'Product_{product_id}'),
                    'score': score,
                    'entity_probability': entity_prob,
                    'relation_strength': relation_strength
                })
            
            return recommendations

In [3]:
# Load your TKG data
tkg_data = torch.load('graph/my_retail_tkg_pyg.pt', weights_only=False)
"""
# Initialize model
model = TemporalGCNRecommender(
    num_users=tkg_data['num_users'],
    num_products=tkg_data['num_products'],
    node_features_dim=tkg_data['x'].shape[1],
    hidden_dim=64
)

# Initialize trainer
trainer = TemporalRecommenderTrainer(model, tkg_data, device='cuda' if torch.cuda.is_available() else 'cpu')

# Train the model
results = trainer.train(num_epochs=30, lr=0.01)

# Get recommendations for user
user_id = 0  # First user
recommendations = trainer.recommend_products(user_id, top_k=10)

print(f"Top recommendations for user {user_id}:")
for i, rec in enumerate(recommendations, 1):
    print(f"{i}. {rec['product_name']}")
    print(f"   Score: {rec['score']:.3f}, Link Prob: {rec['link_probability']:.3f}, Pred Qty: {rec['predicted_quantity']:.2f}")
"""

'\n# Initialize model\nmodel = TemporalGCNRecommender(\n    num_users=tkg_data[\'num_users\'],\n    num_products=tkg_data[\'num_products\'],\n    node_features_dim=tkg_data[\'x\'].shape[1],\n    hidden_dim=64\n)\n\n# Initialize trainer\ntrainer = TemporalRecommenderTrainer(model, tkg_data, device=\'cuda\' if torch.cuda.is_available() else \'cpu\')\n\n# Train the model\nresults = trainer.train(num_epochs=30, lr=0.01)\n\n# Get recommendations for user\nuser_id = 0  # First user\nrecommendations = trainer.recommend_products(user_id, top_k=10)\n\nprint(f"Top recommendations for user {user_id}:")\nfor i, rec in enumerate(recommendations, 1):\n    print(f"{i}. {rec[\'product_name\']}")\n    print(f"   Score: {rec[\'score\']:.3f}, Link Prob: {rec[\'link_probability\']:.3f}, Pred Qty: {rec[\'predicted_quantity\']:.2f}")\n'

In [4]:
# Initialize model
model = TemporalGCNRecommender(
    num_users=tkg_data['num_users'],
    num_products=tkg_data['num_products'],
    node_features_dim=tkg_data['x'].shape[1],
    hidden_dim=64
)

# Initialize trainer
trainer = TemporalRecommenderTrainer(model, tkg_data, device='cuda' if torch.cuda.is_available() else 'cpu')

# Train the model
results = trainer.train(num_epochs=30, lr=0.01)

# Get recommendations for user
user_id = 0  # First user
recommendations = trainer.recommend_products(user_id, top_k=10)

print(f"Top recommendations for user {user_id}:")
for i, rec in enumerate(recommendations, 1):
    print(f"{i}. {rec['product_name']}")
    print(f"   Score: {rec['score']:.3f}, Link Prob: {rec['link_probability']:.3f}, Pred Qty: {rec['predicted_quantity']:.2f}")

Created 305 temporal snapshots
Training on 244 snapshots, testing on 61 snapshots


Training: 100%|██████████| 244/244 [35:09<00:00,  8.64s/it]
Evaluating: 100%|██████████| 61/61 [01:39<00:00,  1.63s/it]


Epoch   0: Loss=0.7111, Entity AUC=0.5000, Entity AP=0.4997, Relation MSE=0.0351


Training:  50%|████▉     | 121/244 [15:44<16:00,  7.81s/it]


KeyboardInterrupt: 

In [None]:
# Load trained model - takes seconds
model = TemporalGCNRecommender(
    num_users=tkg_data['num_users'],
    num_products=tkg_data['num_products'],
    node_features_dim=tkg_data['x'].shape[1],
    hidden_dim=64
)
model.load_state_dict(torch.load('best_temporal_recommender.pt'))

# Get recommendations instantly (milliseconds)
recommendations = trainer.recommend_products(user_id=1002, top_k=10)