# In [1]:
import pickle
import networkx as nx
import random
from torch.utils.data import Dataset, DataLoader
from transformers import DataCollatorWithPadding, DefaultDataCollator
import torch


# In [10]:
class GraphDataset(Dataset):
    """Dataset of (child, parent, grandparent) triplets sampled from a directed graph.

    Edges are read as ``parent -> child``: a node's parent is one of its
    predecessors, and its grandparent is one of the parent's predecessors.
    All triplets are sampled once, at construction time.

    Args:
        graph: directed graph exposing ``nodes()`` and ``predecessors(node)``
            (e.g. a ``networkx.DiGraph``).
        num_triplets (int): number of triplets to sample.
        max_attempts (int | None): optional cap on the number of sampling
            attempts. ``None`` (default) keeps sampling until
            ``num_triplets`` triplets have been collected.

    Raises:
        ValueError: if triplets are requested from a graph with no nodes.
        RuntimeError: if ``max_attempts`` is set and is exhausted before
            ``num_triplets`` triplets have been collected.
    """

    def __init__(self, graph, num_triplets, max_attempts=None):
        self.graph = graph
        self.num_triplets = num_triplets
        self.max_attempts = max_attempts
        self.triplets = self.generate_triplets()

    def generate_triplets(self):
        """Randomly sample nodes together with a parent and a grandparent.

        The child is excluded from the parent candidates, and the child and
        parent are excluded from the grandparent candidates. Self-loops and
        2-cycles therefore never yield degenerate triplets, and an exhausted
        candidate list is skipped instead of crashing ``random.choice``.

        Returns:
            list: ``num_triplets`` tuples ``(child, parent, grandparent)``.
        """
        # The node set does not change while sampling, so build it only once.
        nodes = list(self.graph.nodes())
        if self.num_triplets > 0 and not nodes:
            raise ValueError("cannot sample triplets from an empty graph")

        triplets = []
        attempts = 0
        while len(triplets) < self.num_triplets:
            if self.max_attempts is not None and attempts >= self.max_attempts:
                raise RuntimeError(
                    f"sampled only {len(triplets)} of {self.num_triplets} "
                    f"triplets in {self.max_attempts} attempts"
                )
            attempts += 1

            child = random.choice(nodes)

            # Skip nodes without a (non-self-loop) predecessor.
            parents = [n for n in self.graph.predecessors(child) if n != child]
            if not parents:
                continue
            parent = random.choice(parents)

            # Drop the child (2-cycles) and the parent (self-loops); skip if nothing is left.
            grandparents = [
                n for n in self.graph.predecessors(parent) if n not in (child, parent)
            ]
            if not grandparents:
                continue
            grandparent = random.choice(grandparents)

            triplet = (child, parent, grandparent)
            if self.ensure_distance(triplets, triplet):
                triplets.append(triplet)

        return triplets

    def ensure_distance(self, triplets, potential_triplet):
        """Check a candidate triplet against ALL already sampled triplets.

        For every existing triplet, two shortest paths are measured: from its
        grandparent to the candidate's child, and from the candidate's
        grandparent to its child. The candidate is rejected as soon as either
        path has at most one node.

        NOTE(review): the size is ``len(path)``, i.e. a node count, so ``<= 1``
        only rejects the case where the two nodes coincide. Confirm this
        matches the intended "distance more than one".

        Args:
            triplets (list): already sampled triplets.
            potential_triplet (tuple): triplet to include.

        Returns:
            bool: True if the candidate passes the check against every
            existing triplet (vacuously True for an empty list).
        """
        child, _, grandparent = potential_triplet
        for other in triplets:
            if self._path_size(other[2], child) <= 1:
                return False
            if self._path_size(grandparent, other[0]) <= 1:
                return False
        return True

    def _path_size(self, source, target):
        """Return the node count of the shortest source->target path, or inf if none exists."""
        try:
            return len(nx.bidirectional_shortest_path(self.graph, source, target))
        except nx.NetworkXException:
            # No path (NetworkXNoPath) or missing node: treat as infinitely far apart.
            return float('inf')

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        # Each entry is already a (child, parent, grandparent) tuple.
        return self.triplets[idx]


class GraphDataCollator(DefaultDataCollator):
    """Turn a batch of (child, parent, grandparent) triplets into edge pairs.

    Each triplet is mapped to ``((child, parent), (parent, grandparent))``.
    The batch comes back as a plain Python list, one entry per input triplet,
    with no tensor conversion.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, features):
        # Split every triplet into its two consecutive (node, predecessor) pairs.
        return [
            ((child, parent), (parent, grandparent))
            for child, parent, grandparent in features
        ]

# In [None]:
if __name__ == '__main__':
    # Sample triplets from the noun hierarchy and print every collated batch.
    num_samples = 1000
    batch_len = 32

    hierarchy = nx.read_edgelist('../data/noun/all.edgelist', create_using=nx.DiGraph)
    loader = DataLoader(
        GraphDataset(hierarchy, num_samples),
        batch_size=batch_len,
        collate_fn=GraphDataCollator(),
    )

    for pairs in loader:
        print(*pairs, sep='\n')
        print(len(pairs))