# NN Training

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, MessagePassing
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
import math
import copy
from collections import defaultdict, deque
from sklearn.metrics import r2_score

DEVICE = torch.device('cpu')

GRAPH_PATH = "filtered_results/hetero_graph_small_1_0.pt"

HIDDEN_DIM = 64
NUM_LAYERS = 2
AGGREGATION = 'mean'
DROPOUT = 0.3
L2_WEIGHT_DECAY = 1e-4
LEARNING_RATE = 1e-3
LOSS_FUNCTION = 'RMSE'
NUM_EPOCHS = 50
BATCH_SIZE = 1024
EARLY_STOPPING_PATIENCE = 10
K = 10


data: HeteroData = torch.load(GRAPH_PATH)
num_users = data['user'].num_nodes
num_items = data['item'].num_nodes

ui_edge_index = data[('user', 'reviews', 'item')].edge_index
ui_edge_attr = data[('user', 'reviews', 'item')].edge_attr

if ('item', 'similar', 'item') in data.edge_index_dict:
    ii_edge_index = data[('item', 'similar', 'item')].edge_index
    ii_edge_attr = data[('item', 'similar', 'item')].edge_attr
    item_adj = defaultdict(list)
    for src, dst in zip(ii_edge_index[0], ii_edge_index[1]):
        s, d = src.item(), dst.item()
        item_adj[s].append(d)
        item_adj[d].append(s)

    all_items_set = set(ii_edge_index[0].cpu().tolist() + ii_edge_index[1].cpu().tolist())
    mask_ii = torch.tensor(
        [(src.item() in all_items_set) and (dst.item() in all_items_set)
         for src, dst in zip(ii_edge_index[0], ii_edge_index[1])],
        dtype=torch.bool
    )
    final_ii_edge_index = ii_edge_index[:, mask_ii]
    final_ii_edge_attr = ii_edge_attr[mask_ii]
else:
    final_ii_edge_index = None
    final_ii_edge_attr = None

data[('user', 'reviews', 'item')].edge_index = ui_edge_index
data[('user', 'reviews', 'item')].edge_attr = ui_edge_attr

if final_ii_edge_index is not None:
    data[('item', 'similar', 'item')].edge_index = final_ii_edge_index
    data[('item', 'similar', 'item')].edge_attr = final_ii_edge_attr

user_edges = defaultdict(list)
for i, (u, v) in enumerate(zip(ui_edge_index[0], ui_edge_index[1])):
    user_edges[u.item()].append(i)

train_idx, val_idx, test_idx = [], [], []

for user, edge_indices in user_edges.items():
    num_edges = len(edge_indices)
    torch.manual_seed(42)
    shuffled_indices = torch.randperm(num_edges)

    train_size = int(0.8 * num_edges)
    val_size = max(1, int(0.1 * num_edges))
    test_size = num_edges - train_size - val_size

    train_idx.extend([edge_indices[i] for i in shuffled_indices[:train_size]])
    val_idx.extend([edge_indices[i] for i in shuffled_indices[train_size:train_size + val_size]])
    test_idx.extend([edge_indices[i] for i in shuffled_indices[train_size + val_size:]])

train_idx = torch.tensor(train_idx, dtype=torch.long)
val_idx = torch.tensor(val_idx, dtype=torch.long)
test_idx = torch.tensor(test_idx, dtype=torch.long)

train_edge_index = ui_edge_index[:, train_idx]
train_edge_attr = ui_edge_attr[train_idx]
val_edge_index = ui_edge_index[:, val_idx]
val_edge_attr = ui_edge_attr[val_idx]
test_edge_index = ui_edge_index[:, test_idx]
test_edge_attr = ui_edge_attr[test_idx]

data[('user', 'reviews', 'item')].train_edge_index = train_edge_index
data[('user', 'reviews', 'item')].train_edge_attr = train_edge_attr
data[('user', 'reviews', 'item')].val_edge_index = val_edge_index
data[('user', 'reviews', 'item')].val_edge_attr = val_edge_attr
data[('user', 'reviews', 'item')].test_edge_index = test_edge_index
data[('user', 'reviews', 'item')].test_edge_attr = test_edge_attr

data = ToUndirected()(data)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import os

split_dir = "splits"
os.makedirs(split_dir, exist_ok=True)

torch.save({
    'train_idx': train_idx,
    'val_idx': val_idx,
    'test_idx': test_idx
}, os.path.join(split_dir, "splits.pt"))


In [3]:
graph_files = [
    "filtered_results/hetero_graph_small_1_0.pt",
    "filtered_results/hetero_graph_small_0_1.pt",
    "filtered_results/hetero_graph_small_02_08.pt",
    "filtered_results/hetero_graph_small_08_02.pt",
    "filtered_results/hetero_graph_small_05_05.pt"
]

split_file = os.path.join("splits", "splits.pt")
splits = torch.load(split_file)
train_idx = splits['train_idx']
val_idx = splits['val_idx']
test_idx = splits['test_idx']

def apply_splits(data: HeteroData, train_idx, val_idx, test_idx) -> HeteroData:
    """Given a data and precomputed splits, apply them to user-item edges."""
    ui_edge_index = data[('user', 'reviews', 'item')].edge_index
    ui_edge_attr = data[('user', 'reviews', 'item')].edge_attr

    train_edge_index = ui_edge_index[:, train_idx]
    train_edge_attr = ui_edge_attr[train_idx]
    val_edge_index = ui_edge_index[:, val_idx]
    val_edge_attr = ui_edge_attr[val_idx]
    test_edge_index = ui_edge_index[:, test_idx]
    test_edge_attr = ui_edge_attr[test_idx]

    data[('user', 'reviews', 'item')].train_edge_index = train_edge_index
    data[('user', 'reviews', 'item')].train_edge_attr = train_edge_attr
    data[('user', 'reviews', 'item')].val_edge_index = val_edge_index
    data[('user', 'reviews', 'item')].val_edge_attr = val_edge_attr
    data[('user', 'reviews', 'item')].test_edge_index = test_edge_index
    data[('user', 'reviews', 'item')].test_edge_attr = test_edge_attr

    data = ToUndirected()(data)
    return data

all_datasets = {}
for gf in graph_files:
    gdata: HeteroData = torch.load(gf)

    gdata = apply_splits(gdata, train_idx, val_idx, test_idx)
    all_datasets[gf] = gdata


In [4]:
class WeightedSAGEConv(MessagePassing):
    def __init__(self, in_channels_src, in_channels_dst, out_channels, aggr='mean'):
        super().__init__(aggr=None)
        self.in_channels_src = in_channels_src
        self.in_channels_dst = in_channels_dst
        self.out_channels = out_channels
        self.aggr = aggr

        self.lin_l = nn.Linear(in_channels_src, out_channels)
        self.lin_r = nn.Linear(in_channels_dst, out_channels)
        self.lin_update = nn.Linear(out_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        self.lin_update.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        if isinstance(x, tuple):
            x_src, x_dst = x
        else:
            x_src = x_dst = x

        src, dst = edge_index

        x_src_trans = self.lin_l(x_src)  
        x_dst_trans = self.lin_r(x_dst)  

        if edge_weight is not None:
            messages = x_src_trans[src] * edge_weight.view(-1, 1)
        else:
            messages = x_src_trans[src]

        out = torch.zeros_like(x_dst_trans)
        out.scatter_add_(0, dst.unsqueeze(-1).expand(-1, messages.size(-1)), messages)

        if self.aggr == 'mean':
            count = torch.zeros(x_dst_trans.size(0), device=x_dst_trans.device, dtype=torch.long)
            count.scatter_add_(0, dst, torch.ones(messages.size(0), device=x_dst_trans.device, dtype=torch.long))
            count = torch.clamp(count, min=1)
            out = out / count.unsqueeze(-1)

        out = out + x_dst_trans
        out = self.lin_update(out)

        return out


def rmse_loss(preds, targets):
    return torch.sqrt(torch.mean((preds - targets)**2))

def mae_loss(preds, targets):
    return torch.mean(torch.abs(preds - targets))

def calculate_r2_score(preds, targets):
    preds = preds.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    return r2_score(targets, preds)

def evaluate(model, data, edge_index, edge_attr):
    model.eval()
    with torch.no_grad():
        emb_dict = model.get_embeddings(data)
        user_ids = edge_index[0]
        item_ids = edge_index[1]
        preds = model.predict(user_ids, item_ids, emb_dict)
        targets = edge_attr
        rmse = rmse_loss(preds, targets)
        mae = mae_loss(preds, targets)
    return rmse.item(), mae.item()

def weighted_rmse_loss(preds, targets, item_ids, data):
    similar_item_edge_index = data[('item', 'similar', 'item')].edge_index
    similar_items_set = set(similar_item_edge_index[0].tolist())

    has_similar_items = torch.tensor(
        [1.0 if item_id in similar_items_set else 0.0 for item_id in item_ids],
        dtype=torch.float32, device=DEVICE
    )

    weights = has_similar_items + 1.0
    loss = torch.sqrt(torch.mean(((preds - targets) ** 2) * weights))
    return loss


class GraphSAGEModel(nn.Module):
    def __init__(self, num_users, num_items,
                 hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS, aggr=AGGREGATION, dropout=DROPOUT):
        super().__init__()
        self.user_emb = nn.Embedding(num_users, hidden_dim)
        self.item_emb = nn.Embedding(num_items, hidden_dim)
        nn.init.xavier_uniform_(self.user_emb.weight)
        nn.init.xavier_uniform_(self.item_emb.weight)

        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    ('user', 'reviews', 'item'): WeightedSAGEConv(hidden_dim, hidden_dim, hidden_dim, aggr=aggr),
                    ('item', 'rev_reviews', 'user'): WeightedSAGEConv(hidden_dim, hidden_dim, hidden_dim, aggr=aggr),
                    ('item', 'similar', 'item'): WeightedSAGEConv(hidden_dim, hidden_dim, hidden_dim, aggr=aggr)
                },
                aggr='mean'
            )
            self.convs.append(conv)
        self.dropout = dropout
        self.pred_layer = nn.Linear(2 * hidden_dim, 1)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict, edge_weight_dict=edge_attr_dict)
            for node_type in x_dict:
                x_dict[node_type] = F.relu(x_dict[node_type])
                x_dict[node_type] = F.dropout(x_dict[node_type], p=self.dropout, training=self.training)
        return x_dict

    def get_embeddings(self, data: HeteroData):
        x_dict = {
            'user': self.user_emb.weight,
            'item': self.item_emb.weight
        }

        edge_attr_dict = {}
        if ('user', 'reviews', 'item') in data.edge_index_dict:
            ui_edge_attr = data[('user', 'reviews', 'item')].edge_attr
            edge_attr_dict[('user', 'reviews', 'item')] = ui_edge_attr.view(-1)

        if ('item', 'rev_reviews', 'user') in data.edge_index_dict:
            if 'edge_attr' in data[('item', 'rev_reviews', 'user')]:
                iu_edge_attr = data[('item', 'rev_reviews', 'user')].edge_attr
            else:
                iu_edge_attr = torch.ones(data[('item', 'rev_reviews', 'user')].edge_index.size(1), device=DEVICE)
            edge_attr_dict[('item', 'rev_reviews', 'user')] = iu_edge_attr.view(-1)

        if ('item', 'similar', 'item') in data.edge_index_dict:
            ii_edge_attr = data[('item', 'similar', 'item')].edge_attr
            edge_attr_dict[('item', 'similar', 'item')] = ii_edge_attr.view(-1)

        x_dict = {k: v.to(DEVICE) for k,v in x_dict.items()}
        for k in data.edge_index_dict:
            data.edge_index_dict[k] = data.edge_index_dict[k].to(DEVICE)
        for k in edge_attr_dict:
            edge_attr_dict[k] = edge_attr_dict[k].to(DEVICE)

        with torch.no_grad():
            self.eval()
            emb_dict = self.forward(x_dict, data.edge_index_dict, edge_attr_dict)

        return emb_dict

    def predict(self, user_ids, item_ids, emb_dict):
        user_emb = emb_dict['user'][user_ids]
        item_emb = emb_dict['item'][item_ids]
        x = torch.cat([user_emb, item_emb], dim=-1)
        preds = self.pred_layer(x)
        return preds.view(-1)

In [5]:
def train_model(model, data, num_epochs=NUM_EPOCHS):
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=L2_WEIGHT_DECAY)
    
    model.to(DEVICE)
    for store in data.edge_index_dict.values():
        store.to(DEVICE)
    
    data[('user', 'reviews', 'item')].train_edge_index = data[('user', 'reviews', 'item')].train_edge_index.to(DEVICE)
    data[('user', 'reviews', 'item')].train_edge_attr = data[('user', 'reviews', 'item')].train_edge_attr.to(DEVICE)
    data[('user', 'reviews', 'item')].val_edge_index = data[('user', 'reviews', 'item')].val_edge_index.to(DEVICE)
    data[('user', 'reviews', 'item')].val_edge_attr = data[('user', 'reviews', 'item')].val_edge_attr.to(DEVICE)
    data[('user', 'reviews', 'item')].test_edge_index = data[('user', 'reviews', 'item')].test_edge_index.to(DEVICE)
    data[('user', 'reviews', 'item')].test_edge_attr = data[('user', 'reviews', 'item')].test_edge_attr.to(DEVICE)
    
    best_val_rmse = float('inf')
    best_model = None
    patience_counter = 0
    
    for epoch in range(1, num_epochs+1):
        model.train()
        optimizer.zero_grad()
        
        emb_dict = {
            'user': model.user_emb.weight,
            'item': model.item_emb.weight
        }
        
        edge_attr_dict = {}
        
        if ('user', 'reviews', 'item') in data.edge_index_dict:
            edge_attr_dict[('user', 'reviews', 'item')] = data[('user', 'reviews', 'item')].edge_attr.view(-1)
        
        if ('item', 'rev_reviews', 'user') in data.edge_index_dict:
            if 'edge_attr' in data[('item', 'rev_reviews', 'user')]:
                edge_attr_dict[('item', 'rev_reviews', 'user')] = data[('item', 'rev_reviews', 'user')].edge_attr.view(-1)
            else:
                edge_attr_dict[('item', 'rev_reviews', 'user')] = torch.ones(data[('item', 'rev_reviews', 'user')].edge_index.size(1), device=DEVICE)
        
        if ('item', 'similar', 'item') in data.edge_index_dict:
            edge_attr_dict[('item', 'similar', 'item')] = data[('item', 'similar', 'item')].edge_attr.view(-1)
        
        x_dict = model.forward(emb_dict, data.edge_index_dict, edge_attr_dict)
        
        train_users = data[('user', 'reviews', 'item')].train_edge_index[0]
        train_items = data[('user', 'reviews', 'item')].train_edge_index[1]
        train_preds = model.predict(train_users, train_items, x_dict)
        train_targets = data[('user', 'reviews', 'item')].train_edge_attr
        
        train_min_score = train_preds.min().item()
        train_max_score = train_preds.max().item()


        train_rmse = rmse_loss(train_preds, train_targets)
        train_r2 = calculate_r2_score(train_preds, train_targets)

        train_rmse.backward()
        optimizer.step()
        

        train_loss = train_rmse.item() if isinstance(train_rmse, torch.Tensor) else float(train_rmse)

        val_rmse, val_mae = evaluate(
            model, data,
            data[('user', 'reviews', 'item')].val_edge_index,
            data[('user', 'reviews', 'item')].val_edge_attr
        )


        val_preds = model.predict(
            data[('user', 'reviews', 'item')].val_edge_index[0],
            data[('user', 'reviews', 'item')].val_edge_index[1],
            x_dict
        )
        val_targets = data[('user', 'reviews', 'item')].val_edge_attr
        val_r2 = calculate_r2_score(val_preds, val_targets)

        val_min_score = val_preds.min().item()
        val_max_score = val_preds.max().item()


        val_rmse = float(val_rmse)
        val_mae = float(val_mae)
        val_r2 = float(val_r2)
        print(
            f"Epoch: {epoch} | "
            f"Train RMSE: {train_rmse.item():.4f}, Train R²: {train_r2:.4f}, Train Min: {train_min_score:.4f}, Train Max: {train_max_score:.4f} | "
            f"Val RMSE: {val_rmse:.4f}, Val R²: {val_r2:.4f}, Val Min: {val_min_score:.4f}, Val Max: {val_max_score:.4f}"
        )

        if val_rmse < best_val_rmse:
            best_val_rmse = val_rmse
            best_model = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= EARLY_STOPPING_PATIENCE:
                print("Early stopping triggered.")
                break
    
    if best_model is not None:
        model.load_state_dict(best_model)
    return model



In [6]:
results = {}
all_models = {}

for graph_name, gdata in all_datasets.items():
    print(f"\n=== Training on {graph_name} ===")
    
    num_users = gdata['user'].num_nodes
    num_items = gdata['item'].num_nodes
    
    model = GraphSAGEModel(num_users, num_items)
    model.to(DEVICE)
    
    model = train_model(model, gdata, num_epochs=NUM_EPOCHS)
    
    all_models[graph_name] = model

    test_rmse, test_mae = evaluate(
        model, gdata,
        gdata[('user', 'reviews', 'item')].test_edge_index,
        gdata[('user', 'reviews', 'item')].test_edge_attr
    )
    
    print(f"For {graph_name}, Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}")
    
    results[graph_name] = (test_rmse, test_mae)

print("\n=== Summary of Results ===")
for graph_name, (rmse, mae) in results.items():
    print(f"{graph_name} -> Test RMSE: {rmse:.4f}, Test MAE: {mae:.4f}")



=== Training on filtered_results/hetero_graph_small_1_0.pt ===
Epoch: 1 | Train RMSE: 0.7775, Train R²: -20.5886, Train Min: -0.1590, Train Max: 0.0018 | Val RMSE: 0.7475, Val R²: -19.6352, Val Min: -0.1527, Val Max: -0.0036
Epoch: 2 | Train RMSE: 0.7414, Train R²: -18.6109, Train Min: -0.1271, Train Max: 0.0313 | Val RMSE: 0.7142, Val R²: -17.8327, Val Min: -0.1131, Val Max: 0.0391
Epoch: 3 | Train RMSE: 0.7094, Train R²: -16.9320, Train Min: -0.0911, Train Max: 0.0561 | Val RMSE: 0.6806, Val R²: -16.1551, Val Min: -0.1016, Val Max: 0.0672
Epoch: 4 | Train RMSE: 0.6745, Train R²: -15.2094, Train Min: -0.0794, Train Max: 0.0979 | Val RMSE: 0.6457, Val R²: -14.5174, Val Min: -0.0662, Val Max: 0.1113
Epoch: 5 | Train RMSE: 0.6395, Train R²: -13.5585, Train Min: -0.0340, Train Max: 0.1471 | Val RMSE: 0.6094, Val R²: -12.9258, Val Min: -0.0392, Val Max: 0.1662
Epoch: 6 | Train RMSE: 0.6036, Train R²: -11.9452, Train Min: -0.0134, Train Max: 0.1915 | Val RMSE: 0.5724, Val R²: -11.3776, Val

In [7]:
DEVICE = torch.device('cpu')

reference_graph_name = "filtered_results/hetero_graph_small_1_0.pt"
ref_data = all_datasets[reference_graph_name]

ref_val_user_ids = ref_data[('user', 'reviews', 'item')].val_edge_index[0].cpu()
ref_val_item_ids = ref_data[('user', 'reviews', 'item')].val_edge_index[1].cpu()
ref_val_targets = ref_data[('user', 'reviews', 'item')].val_edge_attr.cpu()

user_val_counts = {}
for u in ref_val_user_ids.tolist():
    user_val_counts[u] = user_val_counts.get(u, 0) + 1

selected_users = [u for u, count in user_val_counts.items() if count > 3]

user_item_pairs = {}
for u in selected_users:
    user_mask = (ref_val_user_ids == u)
    user_val_items = ref_val_item_ids[user_mask]
    user_val_targets = ref_val_targets[user_mask]

    user_item_pairs[u] = (user_val_items, user_val_targets)

###########################################################
# Now evaluate on all graphs using the SAME user/item pairs
###########################################################

for graph_name, gdata in all_datasets.items():
    model = all_models[graph_name]

    print(f"\nEvaluating on {graph_name} using the same users and items:")

    model.eval()
    with torch.no_grad():
        emb_dict = model.get_embeddings(gdata)

        for u in selected_users:
            user_val_items, user_val_targets = user_item_pairs[u]

            user_ids_for_prediction = torch.tensor([u]*len(user_val_items), device=DEVICE)
            user_val_items = user_val_items.to(DEVICE)
            user_preds = model.predict(user_ids_for_prediction, user_val_items, emb_dict)

            print(f"\nValidation set predictions for user {u} in {graph_name}:")
            for it, actual_s, pred_s in zip(user_val_items.cpu().tolist(), user_val_targets.tolist(), user_preds.cpu().tolist()):
                if isinstance(actual_s, list) and len(actual_s) == 1:
                    actual_s = float(actual_s[0])
                elif isinstance(actual_s, list):
                    print(f"  - Item {it}, Actual: {actual_s}, Predicted: {pred_s:.4f}")
                    continue

                print(f"  - Item {it}, Actual: {actual_s:.4f}, Predicted: {pred_s:.4f}")


Evaluating on filtered_results/hetero_graph_small_1_0.pt using the same users and items:

Validation set predictions for user 265 in filtered_results/hetero_graph_small_1_0.pt:
  - Item 34647, Actual: 0.3383, Predicted: 0.6275
  - Item 35714, Actual: 0.8333, Predicted: 0.7058
  - Item 24654, Actual: 0.7000, Predicted: 0.7475
  - Item 45943, Actual: 0.7500, Predicted: 0.6550
  - Item 6262, Actual: 0.7000, Predicted: 0.6517

Validation set predictions for user 708 in filtered_results/hetero_graph_small_1_0.pt:
  - Item 20012, Actual: 0.8479, Predicted: 0.6270
  - Item 44175, Actual: 0.8232, Predicted: 0.6423
  - Item 15872, Actual: 0.7198, Predicted: 0.6935
  - Item 9675, Actual: 0.6176, Predicted: 0.6298
  - Item 13294, Actual: 0.6846, Predicted: 0.6684
  - Item 36142, Actual: 0.7540, Predicted: 0.6361
  - Item 27952, Actual: 0.5807, Predicted: 0.6884
  - Item 8874, Actual: 0.5783, Predicted: 0.6668

Validation set predictions for user 1370 in filtered_results/hetero_graph_small_1_0.pt

In [15]:
import json
import pandas as pd
import torch

metadata_csv_path = 'datasets_filtered/filtered_metadata.csv'
metadata_jsonl_path = 'amazon_db/metadata.jsonl'
reviews_csv_path = 'datasets_filtered/filtered_reviews.csv'

metadata = pd.read_csv(metadata_csv_path)
item_asins = metadata['parent_asin'].tolist()

asin_to_info = {}
with open(metadata_jsonl_path, 'r', encoding='utf-8') as file:
    for line in file:
        record = json.loads(line)
        parent_asin = record.get('parent_asin')
        title = record.get('title', 'No title available')
        images = record.get('images', [])
        image_url = images[0].get('large') if images else 'No image available'
        asin_to_info[parent_asin] = {'title': title, 'image_url': image_url}

reviews = pd.read_csv(reviews_csv_path)

def get_user_id(user_index):
    user_ids = reviews['user_id'].unique()
    return user_ids[user_index] if user_index < len(user_ids) else None

def get_previous_purchases(user_id):
    user_reviews = reviews[reviews['user_id'] == user_id]
    purchased_asins = user_reviews['parent_asin'].unique()
    purchased_items = []

    for asin in purchased_asins:
        item_info = asin_to_info.get(asin, {'title': 'Not found', 'image_url': 'Not found'})
        purchased_items.append({'asin': asin, 'title': item_info['title'], 'image_url': item_info['image_url']})
    
    return purchased_items

def display_item_predictions(graph_name, user_index, user_item_pairs, model, gdata, display_purchases_once=True):
    print(f"\nEvaluating on {graph_name} with item details:")

    model.eval()
    with torch.no_grad():
        emb_dict = model.get_embeddings(gdata)

        user_id = get_user_id(user_index)
        if not user_id:
            print(f"User index {user_index} does not have a corresponding user_id.")
            return

        if display_purchases_once:
            previous_purchases = get_previous_purchases(user_id)
            print(f"\nUser {user_id} (Index {user_index}) Previous Purchases:")
            for item in previous_purchases:
                print(f"  - Item {item['asin']}, Title: {item['title']}")
                print(f"    Image URL: {item['image_url']}")
                print("-" * 50)
            display_purchases_once = False

        user_val_items, user_val_targets = user_item_pairs[user_index]

        user_ids_for_prediction = torch.tensor([user_index] * len(user_val_items), device=DEVICE)
        user_val_items = user_val_items.to(DEVICE)
        user_preds = model.predict(user_ids_for_prediction, user_val_items, emb_dict)

        print(f"\nRecommended Items for User {user_id} (Index {user_index}):")
        for item_idx, actual_s, pred_s in zip(user_val_items.cpu().tolist(), user_val_targets, user_preds.cpu().tolist()):
            asin = item_asins[item_idx] if item_idx < len(item_asins) else 'Unknown ASIN'
            item_info = asin_to_info.get(asin, {'title': 'Not found', 'image_url': 'Not found'})

            if isinstance(actual_s, list):
                actual_s_str = ', '.join([f"{float(val):.4f}" for val in actual_s])
            else:
                actual_s_str = f"{float(actual_s):.4f}"

            print(f"  - Item {asin}, Title: {item_info['title']}")
            print(f"    Actual: {actual_s_str}, Predicted: {pred_s:.4f}")
            print(f"    Image URL: {item_info['image_url']}")
            print("-" * 50)


single_user = selected_users[1]

display_purchases_once = True

for graph_name, gdata in all_datasets.items():
    model = all_models[graph_name]
    display_item_predictions(graph_name, single_user, user_item_pairs, model, gdata, display_purchases_once)
    display_purchases_once = False



Evaluating on filtered_results/hetero_graph_small_1_0.pt with item details:

User AGGGJOYI5L5LXQKIEF6WIJYQIE7Q (Index 708) Previous Purchases:
  - Item B07HGT7JC8, Title: Final Fantasy X & X-2 HD Remaster - Xbox One
    Image URL: https://m.media-amazon.com/images/I/51zvayrOyFL.jpg
--------------------------------------------------
  - Item B0BNP57X61, Title: Mato Anomalies - PlayStation 5
    Image URL: https://m.media-amazon.com/images/I/519MyWf79iL.jpg
--------------------------------------------------
  - Item B0BSQNTBSY, Title: Resident Evil 4 - Xbox Series X
    Image URL: https://m.media-amazon.com/images/I/51MmxHL3MsL._AC_.jpg
--------------------------------------------------
  - Item B0BNWF36TD, Title: Record of Agarest War
    Image URL: https://m.media-amazon.com/images/I/51xJyPf5qwL.jpg
--------------------------------------------------
  - Item B0BNQS48Q4, Title: Atomic Heart PS5
    Image URL: https://m.media-amazon.com/images/I/515dBYImfgL.jpg
-------------------------