In [None]:
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from torch.nn import functional as F
from torch_geometric.data import Data
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score
import os
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

from torch_geometric.utils import to_undirected, negative_sampling, add_self_loops

import torch
from torch_geometric.utils import to_undirected, negative_sampling, structured_negative_sampling
from torch_geometric.data import Data
import random
from collections import defaultdict, deque
from torch_geometric.nn import GraphSAGE, GAT, GIN


In [None]:
# Read the co-purchase table from disk and preview its first rows.
# The commented-out line below switches to the 'kcore70new.csv' file instead.
# processed_co_purchase_df = pd.read_csv('kcore70new.csv')
processed_co_purchase_df = pd.read_csv(filepath_or_buffer='kcore5new.csv')
processed_co_purchase_df.head()

In [None]:
# Distinct values of the 'asin' column; its length is used later as the product count.
asin_list = [*processed_co_purchase_df['asin'].unique()]

In [None]:
# Pre-computed per-product embeddings and the edge list, loaded from .npy files.
image_embeddings = np.load("image_embeddings_k_5.npy")
text_embeddings = np.load("text_embeddings_k_5.npy")
edge_index = np.load("k5_edge_index.npy")

# Alternative input files (kept for switching datasets):
# image_embeddings = np.load("image_embeddings.npy")
# text_embeddings = np.load("text_embeddings.npy")
# edge_index = np.load("kcore70newedges.npy")

# Gaussian noise scaled by 1/10 with the same shape as the text embeddings.
random_embeddings = torch.randn_like(torch.tensor(text_embeddings.astype(np.float32))) / 10

# Node features: image and text embeddings concatenated column-wise.
combined_embedings = np.concatenate((image_embeddings, text_embeddings), axis = 1)
# Derive the model input width from the actual feature matrix instead of
# hard-coding 1024, so the encoders built later always match the data.
product_feature_dim = combined_embedings.shape[1]
num_products = len(asin_list)

In [None]:
import torch_geometric.transforms as T
from torch_geometric.utils import to_undirected


# Build the graph object: float32 node features plus a symmetrised edge index
# (to_undirected is applied to the edge tensor before it is stored on `data`).
# data = Data(x=combined_embedings, edge_index=torch.tensor(edge_index))
data = Data(
    x=torch.tensor(combined_embedings.astype(np.float32)),
    edge_index=to_undirected(torch.tensor(edge_index)),
)



In [None]:
def convert_edge_index_to_adj_list(edge_index):
    """Convert a two-row edge index into an adjacency list.

    Args:
        edge_index: Indexable with two rows (e.g. a ``torch.Tensor`` or NumPy
            array); row 0 holds source nodes and row 1 the matching destinations.

    Returns:
        ``defaultdict(list)`` mapping each source node (Python int) to the list
        of its destination nodes, in the order the edges appear.
    """
    # Two bulk conversions replace one `.item()` call per tensor element,
    # which dominates the runtime on large graphs.
    sources = edge_index[0].tolist()
    destinations = edge_index[1].tolist()

    adj_list = defaultdict(list)
    for src, dest in zip(sources, destinations):
        adj_list[src].append(dest)
    return adj_list


In [None]:
# Neighbourhood sampling: every loader below samples at most 10 neighbours per
# node at each of three hops (num_neighbors=[10, 10, 10]).
# Mini-batches of supervision edges come from PyG's `loader.LinkNeighborLoader`:
from torch_geometric.loader import LinkNeighborLoader

# Edge split via PyG's `RandomLinkSplit()` transform:
#   * 5% of the edges go to validation and 5% to testing (num_val / num_test),
#     the remainder is used for training.
#   * disjoint_train_ratio=0.2 reserves 20% of the training edges as supervision
#     labels; the other training edges are used for message passing.
#   * No negative edges are created by the split (neg_sampling_ratio=0.0,
#     add_negative_train_samples=False); the training loop in this notebook adds
#     its own negatives per batch.
transform = T.RandomLinkSplit(
    num_val=0.05,
    num_test=0.05,
    is_undirected=True,
    add_negative_train_samples = False,
    neg_sampling_ratio=0.0,
    disjoint_train_ratio=0.2)

train_data, val_data, test_data = transform(data)


def _make_link_loader(split, batch_size):
    """Build a shuffled LinkNeighborLoader seeded with ``split``'s labelled edges."""
    return LinkNeighborLoader(
        data=split,
        num_neighbors=[10, 10, 10],
        edge_label_index=split.edge_label_index,
        edge_label=split.edge_label,
        batch_size=batch_size,
        shuffle=True,
    )


# Training uses batches of 64 seed edges; validation and test yield one seed
# edge per batch because the recall@k evaluation reads only edge 0 of each batch.
train_loader = _make_link_loader(train_data, batch_size=64)
val_loader = _make_link_loader(val_data, batch_size=1)
test_loader = _make_link_loader(test_data, batch_size=1)


In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
def convert_edge_index_to_adj_list(edge_index):
    """Convert a two-row edge index into an adjacency list.

    NOTE(review): a function with this same name is already defined in an
    earlier cell of this notebook; this definition rebinds it.

    Args:
        edge_index: Indexable with two rows (e.g. a ``torch.Tensor`` or NumPy
            array); row 0 holds source nodes and row 1 the matching destinations.

    Returns:
        ``defaultdict(list)`` mapping each source node (Python int) to the list
        of its destination nodes, in the order the edges appear.
    """
    # Two bulk conversions replace one `.item()` call per tensor element,
    # which dominates the runtime on large graphs.
    sources = edge_index[0].tolist()
    destinations = edge_index[1].tolist()

    adj_list = defaultdict(list)
    for src, dest in zip(sources, destinations):
        adj_list[src].append(dest)
    return adj_list

def bfs_sample(adj_list, start_node, max_depth):
    """Sample a node whose BFS distance from ``start_node`` is exactly ``max_depth``.

    The search expands level by level and picks uniformly at random among all
    nodes first reached at the requested depth. (Returning the first node
    dequeued at that depth would make repeated calls with the same arguments
    always yield the same node.)

    Args:
        adj_list: Mapping from node to an iterable of neighbour nodes. Nodes
            missing from the mapping are treated as having no neighbours, and
            the mapping is never modified.
        start_node: Node the search starts from.
        max_depth: Required number of hops; 0 returns ``start_node`` itself.

    Returns:
        A node at distance ``max_depth``, or ``None`` when no such node exists
        (including when ``max_depth`` is negative).
    """
    if max_depth < 0:
        return None

    visited = {start_node}
    frontier = [start_node]
    for _ in range(max_depth):
        next_frontier = []
        for node in frontier:
            # .get() avoids KeyError on plain dicts and avoids inserting empty
            # entries into a defaultdict as a side effect of the lookup.
            for neighbor in adj_list.get(node, ()):
                if neighbor not in visited:
                    visited.add(neighbor)
                    next_frontier.append(neighbor)
        if not next_frontier:
            # The reachable component is exhausted before the requested depth.
            return None
        frontier = next_frontier

    return random.choice(frontier)

def add_bfs_negatives(data, m=3, k=3):
    """Append BFS-based negative pairs to ``data``'s labelled edges.

    For both endpoints of every edge in ``data.edge_label_index``, ``k``
    attempts are made to draw a node at a random BFS depth in ``[2, m]`` over
    ``data.edge_index``. Each successful draw that is not a direct neighbour of
    the endpoint is appended to ``data.edge_label_index`` with label 0.

    The negatives are added only to the supervision targets. They are
    deliberately NOT appended to ``data.edge_index``: doing so would turn the
    label-0 pairs into real edges used for message passing.

    Args:
        data: Object exposing tensor attributes ``edge_index``,
            ``edge_label_index`` (2 x E) and ``edge_label`` (E). Modified in place.
        m: Largest BFS depth to sample from; must be >= 2 because depths are
            drawn with ``random.randint(2, m)``.
        k: Number of sampling attempts per endpoint.

    Returns:
        The same ``data`` object; unchanged if no negative pair was found.
    """
    adj_list = convert_edge_index_to_adj_list(data.edge_index)
    negative_pairs = []

    # One bulk conversion instead of two `.item()` calls per labelled edge.
    labelled_pairs = data.edge_label_index.t().tolist()
    for head, tail in tqdm(labelled_pairs):
        for start_node in (head, tail):
            for _ in range(k):
                depth = random.randint(2, m)
                neg_node = bfs_sample(adj_list, start_node, depth)
                # .get() keeps the lookup from inserting keys into the defaultdict.
                if neg_node is not None and neg_node not in adj_list.get(start_node, ()):
                    negative_pairs.append([start_node, neg_node])

    if not negative_pairs:
        # Nothing to append; an empty list would also produce a 1-D tensor
        # that cannot be concatenated along dim=1.
        return data

    # Create the new tensors with the dtype/device of the ones they extend.
    neg_edge_label_index = torch.tensor(
        negative_pairs, dtype=torch.long, device=data.edge_label_index.device
    ).t()
    neg_labels = torch.zeros(
        len(negative_pairs), dtype=data.edge_label.dtype, device=data.edge_label.device
    )

    data.edge_label = torch.cat([data.edge_label, neg_labels])
    data.edge_label_index = torch.cat([data.edge_label_index, neg_edge_label_index], dim=1)

    return data


In [None]:
def append_negative_samples(data, negative_samples):
    """Extend ``data``'s labelled edges with zero-labelled negative pairs.

    Args:
        data: Object with tensor attributes ``edge_label`` and
            ``edge_label_index``; both are replaced in place.
        negative_samples: Pair ``(i, k)`` of equally sized 1-D index tensors;
            column ``n`` of the appended block is the pair ``(i[n], k[n])``.

    Returns:
        The same ``data`` object, with ``i.size(0)`` zeros appended to
        ``edge_label`` and the stacked pairs appended to ``edge_label_index``.
    """
    src_nodes, neg_nodes = negative_samples

    # Zeros in the dtype of the existing labels, moved to the labels' device.
    zero_labels = torch.zeros(src_nodes.size(0), dtype=data.edge_label.dtype)
    zero_labels = zero_labels.to(data.edge_label.device)
    data.edge_label = torch.cat((data.edge_label, zero_labels), dim=0)

    # Row 0 holds the source nodes and row 1 the sampled non-neighbours.
    neg_pairs = torch.stack((src_nodes, neg_nodes), dim=0)
    data.edge_label_index = torch.cat((data.edge_label_index, neg_pairs), dim=1)

    return data


In [None]:
def generate_negative_samples(data, num_neg_samples):
    """Draw negatives for every labelled edge and append them to ``data``.

    ``structured_negative_sampling`` is called ``num_neg_samples`` times on the
    unmodified ``data.edge_label_index``; the results of all rounds are
    concatenated round after round and handed to ``append_negative_samples``.
    With E labelled edges on entry, the negative from round ``r`` for edge
    ``j`` therefore ends up at position ``E + r * E + j``.

    Args:
        data: Graph object providing ``edge_label_index``, ``edge_label`` and
            ``num_nodes``.
        num_neg_samples (int): Number of sampling rounds, i.e. negatives drawn
            per labelled edge.

    Returns:
        The ``data`` object returned by ``append_negative_samples``.

    Raises:
        ValueError: If ``num_neg_samples`` is smaller than 1.
    """
    if num_neg_samples < 1:
        raise ValueError("Number of negative samples must be at least 1.")

    source_rounds, negative_rounds = [], []
    for _round in range(num_neg_samples):
        # The sampler returns a triple; only its first and last entries are used here.
        sources, _positives, negatives = structured_negative_sampling(
            data.edge_label_index, data.num_nodes
        )
        source_rounds.append(sources)
        negative_rounds.append(negatives)

    stacked = (torch.cat(source_rounds, dim=0), torch.cat(negative_rounds, dim=0))
    return append_negative_samples(data, stacked)


In [None]:
def old_generate_negative_samples(data, num_neg_samples):
    """Rebuild the labelled edges so each one is followed by its own negatives.

    For every column of ``data.edge_label_index`` the edge and its label are
    kept, followed by ``num_neg_samples`` pairs ``(i, k)`` drawn with
    ``structured_negative_sampling`` on that single edge, each labelled 0.

    Args:
        data: Graph object providing ``edge_label_index`` (2 x E),
            ``edge_label`` (E) and ``num_nodes``. Modified in place.
        num_neg_samples (int): Negatives to draw per labelled edge.

    Returns:
        The same ``data`` object with ``edge_label_index`` and ``edge_label``
        replaced by the interleaved version; unchanged when there are no
        labelled edges.
    """
    num_edges = data.edge_label_index.size(1)
    if num_edges == 0:
        # torch.cat on an empty list would raise.
        return data

    new_edge_label_index = []
    new_edge_labels = []
    # A single zero label, created on the labels' device, reused for every negative.
    zero_label = torch.zeros(1, dtype=data.edge_label.dtype, device=data.edge_label.device)

    for pos_edge in range(num_edges):
        single_edge = data.edge_label_index[:, pos_edge:pos_edge + 1]  # shape (2, 1)

        # Keep the original edge and its label
        new_edge_label_index.append(single_edge)
        new_edge_labels.append(data.edge_label[pos_edge].view(1))

        for _ in range(num_neg_samples):
            # The sampler returns (i, j, k) with one entry per input edge. The
            # negative pair is (i, k); keeping only k (a single element) and
            # reshaping it to (2, 1) cannot work.
            head, _tail, negative = structured_negative_sampling(single_edge, data.num_nodes)
            new_edge_label_index.append(torch.stack([head, negative], dim=0))
            new_edge_labels.append(zero_label)

    # Concatenate all new edges and labels
    data.edge_label_index = torch.cat(new_edge_label_index, dim=1)
    data.edge_label = torch.cat(new_edge_labels, dim=0)

    return data


In [None]:

# GraphSAGE encoder: 3 layers, 256 hidden and 256 output channels.
# NOTE(review): later cells in this notebook rebind `model` and `optimizer`
# (GAT, then GIN), so under Run All this definition is overwritten.
model = GraphSAGE(
    in_channels=product_feature_dim,
    hidden_channels=256,
    num_layers=3,
    out_channels=256,
    dropout=0.1,
    project=True,
)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)


In [None]:

# GAT encoder: 3 layers, 4 attention heads, 256 hidden and 256 output channels.
# NOTE(review): this rebinds the `model` / `optimizer` created by the GraphSAGE
# cell, and the GIN cell that follows rebinds them again.
model = GAT(
    in_channels=product_feature_dim,
    hidden_channels=256,
    num_layers=3,
    out_channels=256,
    heads=4,
    dropout=0.02,
)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [None]:
# GIN encoder: 3 layers, 256 hidden and 256 output channels.
# NOTE(review): this is the last of three cells that bind `model` and
# `optimizer`, so it is the one in effect after Run All.
model = GIN(
    in_channels=product_feature_dim,
    hidden_channels=256,
    num_layers=3,
    out_channels=256,
    dropout=0.1,
)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)



In [None]:
# --- Experiment configuration -------------------------------------------------
# The experiment name doubles as the directory for TensorBoard logs and checkpoints.
experiment_name = "iter_SAGE_final_2"  # replace with your experiment name
num_epochs = 1              # passes over train_loader
num_iters = 500             # batch index at which the inner training loop stops
logging_freq = 50           # validate + checkpoint whenever the batch index is a multiple of this
num_val_batch_sample = 30   # validation batches scored per validation round
num_neg_sample = 4          # negatives generated per positive edge
top_k = 3                   # k used for recall@k

# Make sure the output directory exists before the writer and checkpoints use it.
os.makedirs(name=experiment_name, exist_ok=True)
writer = SummaryWriter(experiment_name)


def recover_logits(model, batch):
    """Score each labelled edge of ``batch`` by the cosine similarity of its endpoints.

    Args:
        model: Callable invoked as ``model(batch.x, batch.edge_index)`` that
            returns one embedding row per node.
        batch: Object exposing ``x``, ``edge_index`` and ``edge_label_index``
            (row 0 = head nodes, row 1 = tail nodes).

    Returns:
        1-D tensor with one score per column of ``batch.edge_label_index``.
    """
    z = model(batch.x, batch.edge_index)

    # F.normalize clamps the norm with a small epsilon, so an all-zero embedding
    # yields a zero vector (score 0) instead of the NaN from a 0/0 division.
    head_embeddings = F.normalize(z[batch.edge_label_index[0]], p=2, dim=1)
    tail_embeddings = F.normalize(z[batch.edge_label_index[1]], p=2, dim=1)

    # Row-wise dot product of unit vectors.
    link_logits = (head_embeddings * tail_embeddings).sum(dim=1)

    return link_logits


def get_embeddings(model, batch):
    """Return L2-normalised node embeddings for ``batch``.

    Args:
        model: Callable invoked as ``model(batch.x, batch.edge_index)``.
        batch: Object exposing ``x`` and ``edge_index``.

    Returns:
        The model output with every row scaled to unit L2 norm. All-zero rows
        stay zero rather than becoming NaN, because F.normalize clamps the
        norm with a small epsilon before dividing.
    """
    z = model(batch.x, batch.edge_index)
    return F.normalize(z, p=2, dim=1)

def compute_dot_products(embeddings, head_node_index):
    """Dot product of one embedding row against every row of ``embeddings``.

    Args:
        embeddings: Matrix with one embedding per row.
        head_node_index: Index of the row used as the query vector.

    Returns:
        1-D tensor whose entry ``n`` is ``embeddings[n] . embeddings[head_node_index]``.
    """
    # Turn the query row into a column so the product is (rows x 1), then drop
    # the trailing singleton dimension.
    query_column = embeddings[head_node_index][:, None]
    return (embeddings @ query_column).squeeze(-1)


def recall_at_k(dot_products, actual_tail_index, k=10):
    """Return 1 if ``actual_tail_index`` is among the ``k`` highest scores, else 0.

    Args:
        dot_products: 1-D tensor of scores, one per candidate.
        actual_tail_index: Index (int or 0-dim tensor) of the true candidate.
        k: Number of top-scoring candidates to inspect. Values larger than the
            number of candidates are clamped, since ``torch.topk`` raises when
            asked for more elements than exist.

    Returns:
        ``1`` on a hit, ``0`` otherwise (also ``0`` when there are no candidates).
    """
    k = min(k, dot_products.size(-1))
    if k <= 0:
        return 0
    _, top_k_indices = torch.topk(dot_products, k)
    return 1 if bool((top_k_indices == actual_tail_index).any()) else 0



def bpr_loss(pos_score, neg_scores):
    """Bayesian Personalised Ranking loss: mean of ``-log(sigmoid(pos - neg))``.

    Args:
        pos_score: Score(s) of the positive pair(s); must broadcast against
            ``neg_scores``.
        neg_scores: Scores of the negative pairs.

    Returns:
        Scalar tensor holding the mean loss over all (broadcast) pairs.
    """
    # logsigmoid is evaluated in a numerically stable way; log(sigmoid(x))
    # turns into log(0) = -inf once sigmoid underflows for very negative x.
    return -F.logsigmoid(pos_score - neg_scores).mean()

# Training loop.
# Every `logging_freq` batches the model is validated (recall@k on a few
# validation batches) and checkpointed; a BPR optimisation step is then run on
# EVERY batch. (The optimisation step used to sit inside the logging branch, so
# only one batch in `logging_freq` actually trained the model.)
total_loss = 0.0
for epoch in range(num_epochs):
    for i, batch in enumerate(tqdm(train_loader)):
        if i % logging_freq == 0:
            # ---- validation: no dropout, no autograd graph ----
            model.eval()
            recall_at_k_scores = []
            with torch.no_grad():
                # One pass over the shuffled loader instead of building a new
                # iterator per sample. `range` comes first so zip stops without
                # pulling an extra, unused batch from the loader.
                for _, val_sample in zip(range(num_val_batch_sample), val_loader):
                    val_sample = val_sample.to(device)

                    embeddings = get_embeddings(model, val_sample)

                    # val_loader uses batch_size=1, so edge 0 is the only labelled edge.
                    head_node_index = val_sample.edge_label_index[0, 0]
                    actual_tail_index = val_sample.edge_label_index[1, 0]

                    dot_products = compute_dot_products(embeddings, head_node_index)
                    # The head's similarity with itself is maximal and would
                    # always take one of the top-k slots; exclude it from the ranking.
                    dot_products[head_node_index] = float('-inf')

                    recall_at_k_scores.append(
                        recall_at_k(dot_products, actual_tail_index, k=top_k)
                    )

            # Average the scores (skipped if the validation loader was empty).
            if recall_at_k_scores:
                recallk = np.mean(recall_at_k_scores)
                writer.add_scalar('Recall at K', recallk, i)

            torch.save(model.state_dict(), os.path.join(experiment_name, f'checkpoint_{i}.pth'))

        # ---- optimisation step (every batch) ----
        model.train()
        # Append `num_neg_sample` negatives per positive edge to the batch labels.
        batch = generate_negative_samples(batch, num_neg_sample)
        batch = batch.to(device)
        optimizer.zero_grad()

        logits = recover_logits(model, batch)

        # BPR loss. The labelled edges are laid out as
        # [P positives | round 0 negatives | round 1 negatives | ...], each
        # round holding one negative per positive in the same order, so the
        # negatives reshape to (num_neg_sample, P) with column j belonging to
        # positive j. Every positive has the same number of negatives, hence the
        # mean over the whole matrix equals the average of the per-positive means.
        num_pos_samples = len(batch.edge_label) // (num_neg_sample + 1)  # Total number of positive samples
        pos_logits = logits[:num_pos_samples]
        neg_logits = logits[num_pos_samples:num_pos_samples * (num_neg_sample + 1)]
        neg_logits = neg_logits.reshape(num_neg_sample, num_pos_samples)
        loss = bpr_loss(pos_logits.unsqueeze(0), neg_logits)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        writer.add_scalar('Loss', loss.item(), i)

        # Stops only the current epoch's batch loop.
        if i == num_iters:
            break



writer.close()

In [None]:
# Recall@k on the held-out test edges.
total_recall_at_k = 0
num_test_batches = 0
max_test_batches = 200  # evaluate at most this many test batches

# Evaluation mode (no dropout) and no autograd bookkeeping while scoring.
model.eval()
with torch.no_grad():
    for test_batch in test_loader:
        num_test_batches += 1
        test_batch = test_batch.to(device)

        # Normalised embeddings for every node in the sampled subgraph
        embeddings = get_embeddings(model, test_batch)

        # test_loader uses batch_size=1, so edge 0 is the only labelled edge.
        head_node_index = test_batch.edge_label_index[0, 0]
        actual_tail_index = test_batch.edge_label_index[1, 0]

        # Score every node against the head node.
        dot_products = compute_dot_products(embeddings, head_node_index)
        # The head's similarity with itself is maximal and would always take
        # one of the top-k slots; exclude it from the ranking.
        dot_products[head_node_index] = float('-inf')
        recall_at_k_score = recall_at_k(dot_products, actual_tail_index, k=top_k)  # You can adjust k as needed

        total_recall_at_k += recall_at_k_score

        if num_test_batches == max_test_batches:
            break

# Average over the evaluated batches; NaN when the loader yielded nothing
# (avoids a ZeroDivisionError).
average_recall_at_k = total_recall_at_k / num_test_batches if num_test_batches else float('nan')
print(f'Average Recall at K: {average_recall_at_k}')
