In [1]:
import pickle
import networkx as nx
import random
from torch.utils.data import Dataset, DataLoader
from transformers import DataCollatorWithPadding, DefaultDataCollator
import itertools
import torch


In [2]:
class TripletDataset(Dataset):
    """Map-style dataset of ``(node, child, grandchild)`` paths in a directed graph.

    A triplet is produced for every two-step path ``node -> child -> grandchild``
    whose starting node has out-degree >= 1 and an identifier of length > 1
    (node identifiers are assumed to be sized objects such as strings —
    presumably this skips single-character names; confirm against the data).

    The triplets are enumerated once, on first use, and cached. Without the
    cache every ``len()`` / ``[]`` call would walk the whole graph again, which
    makes a shuffled ``DataLoader`` epoch quadratic in the number of triplets.

    Args:
        graph: object exposing ``out_degree()`` (iterable of ``(node, degree)``
            pairs) and ``successors(node)``, e.g. a ``networkx.DiGraph``.
            The graph is assumed not to change after the first access; set
            ``self._triplets = None`` to force a rebuild.
    """

    def __init__(self, graph):
        self.graph = graph
        # Filled lazily by _materialize() so construction stays cheap.
        self._triplets = None

    def _materialize(self):
        """Return the cached list of triplets, building it on first call."""
        if self._triplets is None:
            self._triplets = list(self.generate_triplets())
        return self._triplets

    def __len__(self):
        """Number of triplets in the graph."""
        return len(self._materialize())

    def __getitem__(self, idx):
        """Return the ``idx``-th triplet in generation order.

        Raises:
            IndexError: if ``idx`` is out of range (the sequence protocol
                expects this rather than a leaked ``StopIteration``).
        """
        return self._materialize()[idx]

    def generate_triplets(self):
        """Yield every ``(node, child, grandchild)`` tuple, straight from the graph."""
        for node, degree in self.graph.out_degree():
            if degree < 1 or len(node) <= 1:
                continue
            for child in self.graph.successors(node):
                # Children without successors contribute nothing.
                for grandchild in self.graph.successors(child):
                    yield (node, child, grandchild)


In [36]:
class GraphDataCollator(DefaultDataCollator):
    """Collator that keeps only triplets far enough from every other triplet in the batch.

    Each feature is a 3-item sequence; only its first and last items are
    compared. A triplet is dropped when, for some *different* triplet in the
    same batch, ``lengths[a][b] < 2`` for any ``a`` in the triplet's endpoints
    and ``b`` in the other triplet's endpoints. The lookup is directional
    (``lengths[a][b]`` only), so of two conflicting triplets one may survive.

    Args:
        lengths: nested mapping ``lengths[source][target] -> distance``.
            Missing entries are treated as "no conflict".
    """

    def __init__(self, lengths):
        super().__init__()
        self.lengths = lengths

    def __call__(self, features):
        """Return the sub-list of ``features`` that pass the distance filter.

        The result keeps the input order and may be shorter than the input
        (possibly empty). Cost is quadratic in ``len(features)``.
        """
        return [
            triplet
            for triplet in features
            if all(
                self.check_neighbour(self.lengths, triplet, other)
                for other in features
                if triplet != other  # value comparison: equal triplets never conflict
            )
        ]

    def check_neighbour(self, lengths, triplet1, triplet2):
        """Return ``False`` if any endpoint of ``triplet1`` is closer than 2 to an endpoint of ``triplet2``.

        Endpoints are items ``[0]`` and ``[2]`` of each triplet. Pairs with no
        entry in ``lengths`` are skipped.
        """
        endpoints1 = (triplet1[0], triplet1[2])
        endpoints2 = (triplet2[0], triplet2[2])
        for source, target in itertools.product(endpoints1, endpoints2):
            try:
                if lengths[source][target] < 2:
                    return False
            except KeyError:
                # No recorded distance for this pair. Only the missing-key case
                # is ignored, so real errors and Ctrl-C are no longer swallowed.
                continue
        return True


In [3]:
# Load the edge list from disk as a directed graph.
G = nx.read_edgelist(
    '../data/noun/all.edgelist',
    create_using=nx.DiGraph,
)

In [7]:
# Shortest-path lengths over G, stored as lengths[source][target].
# nx.shortest_path_length(G) yields (source, {target: distance}) pairs.
leng = list(nx.shortest_path_length(G))
lengths = dict(leng)

In [4]:
# Wrap the graph so its (node, child, grandchild) triplets can be batched.
dataset = TripletDataset(graph=G)

In [37]:
# Shuffled batches of up to 1024 triplets; the collator then filters each batch,
# so delivered batches can be smaller. NOTE: shuffling is unseeded here.
dataloader = DataLoader(
    dataset,
    batch_size=1024,
    shuffle=True,
    collate_fn=GraphDataCollator(lengths),
)

In [38]:
# Pull exactly one batch through the loader and show it.
for batch in itertools.islice(dataloader, 1):
    print(batch)

[('power_saw.n.01', 'circular_saw.n.01', 'portable_circular_saw.n.01'), ('communicator.n.01', 'signaler.n.01', 'toller.n.01'), ('opening.n.10', 'hole.n.02', 'loophole.n.02'), ('bone.n.01', 'carpal_bone.n.01', 'triquetral.n.01'), ('generosity.n.01', 'unselfishness.n.01', 'altruism.n.01'), ('command.n.01', 'order.n.01', 'summons.n.02'), ('irrigation.n.02', 'enema.n.01', 'barium_enema.n.01'), ('trade.n.02', 'handicraft.n.02', 'needlework.n.02'), ('nerve_cell.n.01', 'brain_cell.n.01', "golgi's_cell.n.01"), ('truck.n.01', 'van.n.05', 'delivery_truck.n.01'), ('agent.n.03', 'bleaching_agent.n.01', 'sodium_hypochlorite.n.01'), ('linear_unit.n.01', 'metric_linear_unit.n.01', 'kilometer.n.01'), ('barbiturate.n.01', 'amobarbital.n.01', 'amobarbital_sodium.n.01'), ('eye_operation.n.01', 'keratotomy.n.01', 'radial_keratotomy.n.01'), ('line.n.04', 'curve.n.01', 'spiral.n.01'), ('transportation.n.02', 'delivery.n.01', 'post.n.11'), ('improvement.n.01', 'adjustment.n.01', 'domestication.n.03'), ('curr

In [39]:
# Number of triplets left in this batch after the collator's filtering.
len(batch)

320