Key Enhancements:
Batch Normalization: Added batch normalization after each GAT layer to stabilize training and improve performance.

Deeper Architecture: Added a second GAT layer (the first concatenates its attention heads, the second averages them) to increase the model's capacity.

Regularization: Added weight decay (1e-4) to the Adam optimizer to reduce overfitting.

Learning Rate: Reduced the learning rate to 0.001 for more stable training.

In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch_geometric.data import Data
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from sklearn.metrics import accuracy_score

# Load the dataset
data = pd.read_csv('wholesale_banking_synthetic_data.csv')

# Display the first few rows of the dataset to understand its structure
print(data.head())

# Using appropriate columns for user_id, item_id, and interactions
data['user_id'] = data['Customer ID']
data['item_id'] = data['Product ID']
data['interaction'] = data['Transaction Amount']  # or 'Transaction Frequency' if more appropriate

# Encode raw customer / product IDs as contiguous integer indices (in order of
# first appearance). These are the row indices used for the embedding tables.
user_ids = data['user_id'].unique()
item_ids = data['item_id'].unique()
user_map = {raw_id: idx for idx, raw_id in enumerate(user_ids)}
item_map = {raw_id: idx for idx, raw_id in enumerate(item_ids)}

data['user_id'] = data['user_id'].map(user_map)
data['item_id'] = data['item_id'].map(item_map)

# Split the rows into training and testing sets *before* normalizing. The
# scaler's mean/std then come from training rows only, and test statistics do
# not leak into preprocessing. The split depends only on the row count and
# random_state, so the same rows land in each set.
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Normalize interaction values using training-set statistics only.
scaler = StandardScaler()
scaler.fit(train_data[['interaction']])
# transform() returns an (n, 1) array; flatten it to fill a single column.
data['interaction'] = scaler.transform(data[['interaction']]).ravel()

# Re-select the split rows so they carry the normalized interaction values.
train_data = data.loc[train_data.index]
test_data = data.loc[test_data.index]

# Convert data to PyTorch tensors
users = torch.tensor(data['user_id'].values, dtype=torch.long)
items = torch.tensor(data['item_id'].values, dtype=torch.long)
interactions = torch.tensor(data['interaction'].values, dtype=torch.float)

# Define the enhanced GAT model
class EnhancedDotProductGAT(nn.Module):
    """Two-layer graph attention network that embeds users and items jointly.

    Node indices ``0 .. num_users - 1`` are users, and the following
    ``num_items`` indices are items. This is because ``forward`` stacks the user
    embedding table on top of the item embedding table. A (user, item) pair is
    meant to be scored by the dot product of the corresponding output rows.

    Args:
        num_users: Number of user nodes (rows of the user embedding table).
        num_items: Number of item nodes (rows of the item embedding table).
        embedding_dim: Width of the learnable input embeddings and of the output.
        num_heads: Attention heads per GAT layer. Layer 1 concatenates its heads
            (output width ``embedding_dim * num_heads``). Layer 2 averages them
            (output width ``embedding_dim``).
        dropout: Dropout probability applied after each GAT layer.
    """

    def __init__(self, num_users, num_items, embedding_dim, num_heads, dropout=0.3):
        super().__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        # edge_dim=1 lets each layer use the scalar edge weight when computing
        # attention coefficients. Without edge_dim, GATConv ignores edge_attr.
        self.gat_conv1 = GATConv(
            embedding_dim, embedding_dim, heads=num_heads, concat=True, edge_dim=1
        )
        self.gat_conv2 = GATConv(
            embedding_dim * num_heads, embedding_dim, heads=num_heads, concat=False, edge_dim=1
        )
        self.dropout = nn.Dropout(dropout)
        self.batch_norm1 = nn.BatchNorm1d(embedding_dim * num_heads)
        self.batch_norm2 = nn.BatchNorm1d(embedding_dim)

    def forward(self, edge_index, edge_weight=None):
        """Return embeddings for all ``num_users + num_items`` nodes.

        Args:
            edge_index: ``(2, E)`` long tensor of node indices in the joint
                user-then-item numbering.
            edge_weight: Optional per-edge scalar of shape ``(E,)`` or ``(E, 1)``.

        Returns:
            A ``(num_users + num_items, embedding_dim)`` tensor whose rows are the
            users first, then the items.
        """
        x = torch.cat([self.user_embedding.weight, self.item_embedding.weight], dim=0)

        # Present the scalar weights as a single edge-feature column (edge_dim=1).
        edge_attr = None
        if edge_weight is not None:
            edge_attr = edge_weight.unsqueeze(-1) if edge_weight.dim() == 1 else edge_weight

        x = self.gat_conv1(x, edge_index, edge_attr)
        x = self.batch_norm1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.gat_conv2(x, edge_index, edge_attr)
        x = self.batch_norm2(x)
        x = self.dropout(x)
        return x

# --- Graph construction ------------------------------------------------------
# Build the message-passing graph from the TRAINING split only, so held-out test
# interactions never influence the learned embeddings.
num_users = len(user_map)
num_items = len(item_map)
train_df = train_data  # DataFrame split; `train_data` is rebound to a graph below

train_users = torch.tensor(train_df['user_id'].values, dtype=torch.long)
train_items = torch.tensor(train_df['item_id'].values, dtype=torch.long)
train_interactions = torch.tensor(train_df['interaction'].values, dtype=torch.float)

# Item nodes are numbered after all user nodes, matching the order in which the
# model concatenates its user and item embedding tables. GATConv passes messages
# from edge_index[0] to edge_index[1], so both directions are added. With only
# user->item edges, user nodes would receive nothing but their own self-loop.
user_nodes = train_users
item_nodes = train_items + num_users
edge_index = torch.cat(
    [torch.stack([user_nodes, item_nodes], dim=0),
     torch.stack([item_nodes, user_nodes], dim=0)],
    dim=1,
)
edge_weight = torch.cat([train_interactions, train_interactions])

# Create a Data object for PyTorch Geometric
train_data = Data(
    x=None,
    edge_index=edge_index,
    edge_attr=edge_weight,
    num_nodes=num_users + num_items,
)

# Define the training parameters
embedding_dim = 64
num_heads = 4
num_epochs = 100
learning_rate = 0.001
weight_decay = 1e-4
dropout = 0.3

# Initialize the model, loss function, and optimizer
model = EnhancedDotProductGAT(num_users, num_items, embedding_dim, num_heads, dropout)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_fn = nn.MSELoss()


def score_pairs(node_embeddings, user_idx, item_idx):
    """Return the dot-product score for each (user index, item index) pair.

    ``node_embeddings`` holds users first and items after, so item indices are
    shifted by ``num_users`` to find their rows.
    """
    return (node_embeddings[user_idx] * node_embeddings[item_idx + num_users]).sum(dim=-1)


# Training loop: regress each training edge's user·item dot product onto its
# (standardized) interaction value, so the scores are learned from the data.
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()
    out = model(train_data.edge_index, train_data.edge_attr)
    pred = score_pairs(out, train_users, train_items)
    loss = loss_fn(pred, train_interactions)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Evaluation and recommendation generation
model.eval()
with torch.no_grad():
    # Get the embeddings for users and items
    embeddings = model(train_data.edge_index, train_data.edge_attr)
    user_embeddings = embeddings[:num_users]
    item_embeddings = embeddings[num_users:]

    # scores[u, i] is the dot product of user u's and item i's embeddings.
    scores = torch.matmul(user_embeddings, item_embeddings.t())

    # Generate predictions for the test set
    test_users = torch.tensor(test_data['user_id'].values, dtype=torch.long)
    test_items = torch.tensor(test_data['item_id'].values, dtype=torch.long)
    test_interactions = torch.tensor(test_data['interaction'].values, dtype=torch.float)
    test_scores = scores[test_users, test_items]

    # The model regresses onto the same scale as the targets, so RMSE is direct.
    rmse = torch.sqrt(F.mse_loss(test_scores, test_interactions)).item()
    print(f'Test RMSE: {rmse}')

    # Interaction values were z-scored by StandardScaler, so 0 separates
    # above-average from below-average transactions. Predictions live on the
    # same scale, so one threshold is applied to both sides.
    threshold = 0.0
    test_predictions = (test_scores >= threshold).float()
    test_interactions_binary = (test_interactions >= threshold).float()
    accuracy = accuracy_score(test_interactions_binary.cpu().numpy(), test_predictions.cpu().numpy())

    print(f'Accuracy: {accuracy}')

    # Get the top N recommendations for a specific user (internal index).
    user_id = 0  # Example user index
    top_n = min(10, num_items)  # topk fails if k exceeds the number of items
    recommendations = torch.topk(scores[user_id], top_n).indices

    # Map internal indices back to the raw Customer / Product IDs.
    index_to_customer = {idx: raw for raw, idx in user_map.items()}
    index_to_product = {idx: raw for raw, idx in item_map.items()}
    recommended_products = [index_to_product[i] for i in recommendations.tolist()]

    print(
        f'Top {top_n} recommendations for user {user_id} '
        f'({index_to_customer[user_id]}): {recommended_products}'
    )


  Customer ID  Company Size Product ID Transaction ID  \
0     CUS-400          9802   PROD-924       TRAN-818   
1     CUS-406          6641   PROD-101       TRAN-842   
2     CUS-025          1017   PROD-696       TRAN-126   
3     CUS-100          9661   PROD-223       TRAN-600   
4     CUS-319          7524   PROD-928       TRAN-760   

             Transaction Date  Transaction Amount  Transaction Frequency  \
0  2024-04-04 14:57:40.199522        38674.321084                      7   
1  2024-04-24 14:57:40.199522         5185.624659                      9   
2  2023-09-06 14:57:40.199522        18715.523229                      9   
3  2023-09-20 14:57:40.199522        16288.497862                      9   
4  2023-10-28 14:57:40.199522        10391.581761                     10   

   Transaction Value  Product Adoption  Product Usage  ...  \
0       55246.378535                 4              7  ...   
1       80132.003483                 4              8  ...   
2       94653.