In [None]:
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from torch.nn import functional as F
from torch_geometric.data import Data
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score
import os
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

from torch_geometric.utils import to_undirected, negative_sampling, add_self_loops
from torch_geometric.nn import GraphSAGE, GAT, GIN

import torch
from torch_geometric.utils import to_undirected
from torch_geometric.data import Data
import random
from collections import defaultdict, deque


In [None]:
# processed_co_purchase_df = pd.read_csv('kcore70new.csv')
processed_co_purchase_df = pd.read_csv('kcore5new.csv')
processed_co_purchase_df.head()

In [None]:
asin_list = list(processed_co_purchase_df['asin'].unique())

In [None]:
image_embeddings = np.load("image_embeddings_k_5.npy")
text_embeddings = np.load("text_embeddings_k_5.npy")
edge_index = np.load("k5_edge_index.npy")

# image_embeddings = np.load("image_embeddings.npy")
# text_embeddings = np.load("text_embeddings.npy")
# edge_index = np.load("kcore70newedges.npy")
random_embeddings = torch.randn_like(torch.tensor(text_embeddings.astype(np.float32))) / 10

combined_embedings = np.concatenate((image_embeddings, text_embeddings), axis = 1)
product_feature_dim = 1024
num_products = len(asin_list)

In [None]:
import torch_geometric.transforms as T
from torch_geometric.utils import to_undirected


data = Data(x=torch.tensor(combined_embedings.astype(np.float32)), edge_index=torch.tensor(edge_index))
# data = Data(x=combined_embedings, edge_index=torch.tensor(edge_index))

data.edge_index = to_undirected(data.edge_index)

# For this, we first split the set of edges into
# training (80%), validation (10%), and testing edges (10%).
# Across the training edges, we use 70% of edges for message passing,
# and 30% of edges for supervision.
# We further want to generate fixed negative edges for evaluation with a ratio of 2:1.
# Negative edges during training will be generated on-the-fly.
# We can leverage the `RandomLinkSplit()` transform for this from PyG:
transform = T.RandomLinkSplit(
    num_val=0.05,
    num_test=0.05,
    is_undirected=True,
    neg_sampling_ratio = 0.0,
    add_negative_train_samples=False,
    disjoint_train_ratio=0.2)
train_data, val_data, test_data = transform(data)


In [None]:
def convert_edge_index_to_adj_list(edge_index):
    adj_list = defaultdict(list)
    for src, dest in zip(edge_index[0], edge_index[1]):
        adj_list[src.item()].append(dest.item())
    return adj_list


In [None]:
# In the first hop, we sample at most 20 neighbors.
# In the second hop, we sample at most 10 neighbors.
# In addition, during training, we want to sample negative edges on-the-fly with
# a ratio of 2:1.
# We can make use of the `loader.LinkNeighborLoader` from PyG:
from torch_geometric.loader import LinkNeighborLoader

# Define seed edges:
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[10, 10, 10],
    edge_label_index=train_data.edge_label_index,
    edge_label=train_data.edge_label,
    neg_sampling_ratio=4,
    batch_size=256,
    shuffle=True,
)

val_loader = LinkNeighborLoader(
    data = val_data, 
    num_neighbors=[10, 10, 10],
    edge_label_index=val_data.edge_label_index,
    edge_label=val_data.edge_label,
    batch_size = 1,
    shuffle = True
)

test_loader = LinkNeighborLoader(
    data = test_data, 
    num_neighbors=[10, 10, 10],
    edge_label_index=test_data.edge_label_index,
    edge_label=test_data.edge_label,
    batch_size = 1,
    shuffle = True
)


In [None]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [None]:
def convert_edge_index_to_adj_list(edge_index):
    adj_list = defaultdict(list)
    for src, dest in zip(edge_index[0], edge_index[1]):
        adj_list[src.item()].append(dest.item())
    return adj_list

def bfs_sample(adj_list, start_node, max_depth):
    visited = set()
    queue = deque([(start_node, 0)])
    visited.add(start_node)
    
    while queue:
        current_node, depth = queue.popleft()
        if depth == max_depth:
            return current_node
        
        for neighbor in adj_list[current_node]:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, depth + 1))
    
    return None

def add_bfs_negatives(data, m=3, k=3):
    adj_list = convert_edge_index_to_adj_list(data.edge_index)
    new_edges = []
    new_labels = []
    new_edge_label_index = []

    for edge in tqdm(range(data.edge_label_index.size(1))):
        start_nodes = [data.edge_label_index[0, edge].item(), data.edge_label_index[1, edge].item()]
        for start_node in start_nodes:
            for _ in range(k):
                depth = random.randint(2, m)
                neg_node = bfs_sample(adj_list, start_node, depth)
                if neg_node is not None and neg_node not in adj_list[start_node]:
                    new_edges.append([start_node, neg_node])
                    new_labels.append(0)  # Label for negative samples
                    new_edge_label_index.append([start_node, neg_node])

    new_edges = torch.tensor(new_edges, dtype=torch.long).t()
    new_labels = torch.tensor(new_labels, dtype=torch.float)
    new_edge_label_index = torch.tensor(new_edge_label_index, dtype=torch.long).t()

    data.edge_index = torch.cat([data.edge_index, new_edges], dim=1)
    data.edge_label = torch.cat([data.edge_label, new_labels])
    data.edge_label_index = torch.cat([data.edge_label_index, new_edge_label_index], dim=1)

    return data


In [None]:

model = GraphSAGE(
    in_channels=product_feature_dim,
    hidden_channels=256,
    out_channels=256,
    dropout = 0.1,
    project = True,
    num_layers=3).to(device)
    

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


In [None]:

model = GAT(
    in_channels=product_feature_dim,
    hidden_channels=256,
    out_channels=256,
    num_layers = 3,
    dropout = 0.1,
    heads=2).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
model = GIN(
    in_channels = product_feature_dim, 
    hidden_channels=256,
    out_channels = 256,
    dropout = 0.1,
    num_layers = 3
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)



In [None]:
experiment_name = "gin_baseline"  # replace with your experiment name
num_iters = 500
logging_freq = 50
num_val_batch_sample = 20
num_epochs = 1
top_k = 4
os.makedirs(experiment_name, exist_ok=True)
writer = SummaryWriter(log_dir=experiment_name)

In [None]:
def recover_logits(model, batch):
    z = model(batch.x, batch.edge_index)

    head_embeddings = z[batch.edge_label_index[0]]
    head_embeddings = head_embeddings / head_embeddings.norm(dim = 1, keepdim = True)

    tail_embeddings = z[batch.edge_label_index[1]]
    tail_embeddings = tail_embeddings / tail_embeddings.norm(dim = 1, keepdim = True)


    link_logits = (head_embeddings * tail_embeddings).sum(dim=1)


    return link_logits

def get_embeddings(model, batch):
    z = model(batch.x, batch.edge_index)
    z = z / z.norm(dim=1, keepdim=True)  # Normalizing embeddings
    return z

def compute_dot_products(embeddings, head_node_index):
    head_embedding = embeddings[head_node_index]
    dot_products = torch.matmul(embeddings, head_embedding.unsqueeze(-1)).squeeze(-1)
    return dot_products


def recall_at_k(dot_products, actual_tail_index, k=10):
    print(dot_products.shape)
    top_k_scores, top_k_indices = torch.topk(dot_products, k)
    return 1 if actual_tail_index in top_k_indices else 0



def bpr_loss(pos_score, neg_scores):
    return -torch.log(torch.sigmoid(pos_score - neg_scores)).mean()


total_loss = 0.0
for epoch in range(num_epochs):
    for i, batch in tqdm(enumerate(train_loader)):

        print(batch)


        # batch = add_bfs_negatives(batch)
        if i % logging_freq == 0:
            recall_at_k_scores = []

            for _ in range(num_val_batch_sample):
                val_sample = next(iter(val_loader))
                val_sample.to(device)


                embeddings = get_embeddings(model, val_sample)

                head_node_index = val_sample.edge_label_index[0, 0]  # Assuming single label per batch
                actual_tail_index = val_sample.edge_label_index[1, 0]

                dot_products = compute_dot_products(embeddings, head_node_index)
                recall_at_k_score = recall_at_k(dot_products, actual_tail_index, k=top_k)

                recall_at_k_scores.append(recall_at_k_score)

            recallk = np.mean(recall_at_k_scores)

            
            writer.add_scalar('Recall at K', recallk, i)
            torch.save(model.state_dict(), os.path.join(experiment_name, f'checkpoint_{i}.pth'))

        batch.to(device)
        optimizer.zero_grad()

        train_logits = recover_logits(model, batch)
        train_labels = batch.edge_label.float()

        loss = F.binary_cross_entropy_with_logits(train_logits, train_labels)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        writer.add_scalar('Loss', loss.item(), i)


        if i == num_iters:
            break



writer.close()

In [17]:
# Initialize variables for calculating average recall at k
total_recall_at_k = 0
num_test_batches = 0

# Iterate through the test_loader
for test_batch in test_loader:
    num_test_batches += 1
    test_batch.to(device)

    # Get embeddings for the test batch
    embeddings = get_embeddings(model, test_batch)

    # Assuming you are interested in the first edge in each batch for recall calculation
    head_node_index = test_batch.edge_label_index[0, 0] 
    actual_tail_index = test_batch.edge_label_index[1, 0]

    # Compute dot products and recall at k for the test batch
    dot_products = compute_dot_products(embeddings, head_node_index)
    recall_at_k_score = recall_at_k(dot_products, actual_tail_index, k=top_k)  # You can adjust k as needed

    # Update total recall at k score
    total_recall_at_k += recall_at_k_score

    if num_test_batches == 200:
        break

# Calculate and print the average recall at k
average_recall_at_k = total_recall_at_k / num_test_batches
print(f'Average Recall at K: {average_recall_at_k}')


torch.Size([1157])
torch.Size([1020])
torch.Size([375])
torch.Size([992])
torch.Size([980])
torch.Size([801])
torch.Size([1430])
torch.Size([727])
torch.Size([1015])
torch.Size([1290])
torch.Size([473])
torch.Size([645])
torch.Size([859])
torch.Size([1156])
torch.Size([1001])
torch.Size([1068])
torch.Size([224])
torch.Size([1396])
torch.Size([508])
torch.Size([792])
torch.Size([147])
torch.Size([1470])
torch.Size([395])
torch.Size([592])
torch.Size([886])
torch.Size([1026])
torch.Size([966])
torch.Size([1402])
torch.Size([1100])
torch.Size([1178])
torch.Size([960])
torch.Size([868])
torch.Size([741])
torch.Size([903])
torch.Size([765])
torch.Size([57])
torch.Size([1170])
torch.Size([1347])
torch.Size([1146])
torch.Size([654])
torch.Size([815])
torch.Size([1056])
torch.Size([1106])
torch.Size([826])
torch.Size([1486])
torch.Size([900])
torch.Size([1036])
torch.Size([1557])
torch.Size([1198])
torch.Size([1205])
torch.Size([1452])
torch.Size([1144])
torch.Size([1187])
torch.Size([941])
to