In [None]:
import os
import torch
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib as mpl
from torch_geometric.utils import to_networkx

In [None]:
class TripletDataset:
    """Iterate over triplets stored as batch files inside a directory.

    Every entry of ``dataset_path`` is treated as one batch file. A batch file
    is loaded with ``torch.load`` and must be iterable; each item it yields is
    one triplet.
    """

    def __init__(self, dataset_path, batch_size=1):
        """
        Args:
            dataset_path (str): Directory whose entries are the batch files.
            batch_size (int): Number of triplets assumed to be stored in each
                batch file. Only used by ``__len__``; defaults to 1, in which
                case ``len(dataset)`` equals the number of batch files.
        """
        self.dataset_path = dataset_path
        self.batch_size = batch_size
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # iteration order reproducible across runs and machines.
        self.batch_files = sorted(os.listdir(dataset_path))

    def iter_triplets(self):
        """Yield every triplet, loading one batch file at a time."""
        for batch_file in self.batch_files:
            file_path = os.path.join(self.dataset_path, batch_file)
            # NOTE(review): torch.load unpickles the file content, so only
            # point this class at directories with trusted files.
            batch = torch.load(file_path)
            yield from batch

    def __len__(self, batch_size=None):
        """Return the number of batch files times the batch size.

        This is an estimate: the files are not opened, every file is assumed
        to hold exactly ``batch_size`` triplets.

        Args:
            batch_size (int, optional): Overrides the batch size given to the
                constructor. It is optional so that the built-in ``len()``,
                which cannot pass arguments, works on the dataset.
        """
        if batch_size is None:
            batch_size = self.batch_size
        return len(self.batch_files) * batch_size

In [None]:


def visualize_heterodata(data, node_colors=None, node_size=300, font_size=12):
    """
    Draw a HeteroData graph with one color per node type.

    The graph is converted to NetworkX, laid out with a spring layout and
    shown in a new 12x8 Matplotlib figure together with a node-type legend.

    Args:
        data (HeteroData): The heterogeneous graph to draw.
        node_colors (dict): Optional. Maps each node type to a color. When
            omitted, colors are taken from the 'tab20' colormap in the order
            of ``data.node_types``.
        node_size (int): Marker size of the nodes (also used in the legend).
        font_size (int): Font size of the node labels.
    """
    graph = to_networkx(data, node_attrs=['x'], edge_attrs=['edge_attr'])

    # Fall back to the 'tab20' palette, cycling if there are more node types
    # than palette entries.
    if node_colors is None:
        palette = mpl.colormaps.get_cmap('tab20')
        node_colors = {}
        for index, type_name in enumerate(data.node_types):
            node_colors[type_name] = palette(index % palette.N)

    # One color per node, looked up through the node's 'type' attribute.
    per_node_colors = [
        node_colors[attrs['type']] for _, attrs in graph.nodes(data=True)
    ]

    layout = nx.spring_layout(graph)

    plt.figure(figsize=(12, 8))
    nx.draw(
        graph,
        layout,
        with_labels=True,
        node_color=per_node_colors,
        node_size=node_size,
        font_size=font_size,
        cmap=plt.get_cmap('tab20'),
    )

    # Empty scatter plots exist only to give the legend one handle per type.
    for type_name, type_color in node_colors.items():
        plt.scatter([], [], c=[type_color], label=type_name, s=node_size)
    plt.legend(scatterpoints=1, frameon=False, labelspacing=1, loc='upper left')

    plt.show()

# Example usage
# Load the triplet dataset from disk and plot the three graphs (anchor,
# positive, negative) of its first triplet only.
path = './data/triplet_dataset'
dataset = TripletDataset(path)

triplet = next(dataset.iter_triplets(), None)
if triplet is not None:
    anchor = triplet['anchor']['data']
    visualize_heterodata(anchor)
    for role in ('pos', 'neg'):
        visualize_heterodata(triplet[role]['data'])
