In [1]:
from typing import List
import pickle
import torch
from torch_geometric.data import Data

In [2]:


class PairData(Data):

    def __init__(self, edge_index_s=None, x_s=None, edge_index_t=None, x_t=None, y=None):
        super().__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t
        self.y = y
        
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        else:
            return super().__inc__(key, value, *args, **kwargs)

In [3]:
from torch_geometric.loader import DataLoader

edge_index_s = torch.tensor([
     [0, 0, 0, 0],
     [1, 2, 3, 4],
 ])
    
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
     [0, 0, 0],
     [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

data = PairData(edge_index_s, x_s, edge_index_t, x_t, y=1)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
    
print(batch.edge_index_s)

print(batch.edge_index_t)


PairDataBatch(edge_index_s=[2, 8], x_s=[10, 16], edge_index_t=[2, 6], x_t=[8, 16], y=[2])
tensor([[0, 0, 0, 0, 5, 5, 5, 5],
        [1, 2, 3, 4, 6, 7, 8, 9]])
tensor([[0, 0, 0, 4, 4, 4],
        [1, 2, 3, 5, 6, 7]])


In [4]:
batch.y

tensor([1, 1])

In [5]:
loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t'])
batch = next(iter(loader))

print(batch)

PairDataBatch(edge_index_s=[2, 8], x_s=[10, 16], x_s_batch=[10], x_s_ptr=[3], edge_index_t=[2, 6], x_t=[8, 16], x_t_batch=[8], x_t_ptr=[3], y=[2])


In [6]:
print(batch.x_s_batch)

tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])


# Load datset

In [7]:
import sys, os
from pathlib import Path

sys.path.append(str(Path(os.path.abspath("")).parent))


In [8]:
from dataset import TorchMemoryDataset, BenchmarkType

In [9]:
!ls ../id_files

axiom_caption_test.txt	example.txt	      jjt_fof.txt	test.txt
axiom_test.txt		jjt_fof_half.txt      jjt_sine_1_0.txt	train.txt
deepmath.txt		jjt_fof_other.txt     mizar40.txt	validation.txt
dev_100.txt		jjt_fof_sine_1_0.txt  single.txt


In [10]:
dataset_path = '../unsupervised_data/train/'
id_file = '../id_files/train.txt'

In [44]:

def get_pair_dataset(dataset, dataset_path) -> List[PairData]:
    
    # Get files
    with open(os.path.join(dataset_path, 'target.pkl'), 'rb') as f:
        targets = pickle.load(f)

    with open(os.path.join(dataset_path, 'idx.pkl'), 'rb') as f:
        ids = pickle.load(f)
        
    assert len(ids) == len(targets)
    
    
    
    # Load problem pairs
    pair_list = []
    for target, pair in zip(targets, ids):
        # Get left and right graph in the pair
        data_s = dataset.get(pair[0])
        data_t = dataset.get(pair[1])

        # Construct pair data point
        data = PairData(edge_index_s=data_s.edge_index, x_s=data_s.x, 
                        edge_index_t=data_t.edge_index, x_t=data_t.x, 
                        y=target)

        # Add to list
        pair_list += [data]
    
    return pair_list

In [45]:
    
dataset = TorchMemoryDataset(id_file) # TODO this should be loaded as normal

pair_list = get_pair_dataset(dataset, dataset_path)
data

TorchPairDataset(100)

In [46]:
loader = DataLoader(pair_list, batch_size=2, follow_batch=['x_s', 'x_t'])

In [47]:
batch = next(iter(loader))

print(batch)

PairDataBatch(edge_index_s=[2, 1767], x_s=[1055], x_s_batch=[1055], x_s_ptr=[3], edge_index_t=[2, 1237], x_t=[793], x_t_batch=[793], x_t_ptr=[3], y=[2])


In [49]:
batch.edge_index_s

tensor([[   0,    3,    4,  ..., 1053,  997, 1054],
        [   1,    1,    2,  ..., 1051, 1053,  997]])

In [15]:
from enum import Enum

In [16]:
class LearningTask(Enum):

    PREMISE = "premise"
    SIMILARITY = "similarity"

    def __str__(self):
        return self.value


In [17]:
d = DataLoader(pair_list, batch_size=2, follow_batch=None)

In [18]:
a = next(iter(d))
a

PairDataBatch(edge_index_s=[2, 1767], x_s=[2], edge_index_t=[2, 1237], x_t=[793], y=[2])

In [19]:
a.x_s

[Data(x=[552], edge_index=[2, 903], premise_index=[4], conjecture_index=[1], name='l87_oppcat_1', y=[4]),
 Data(x=[503], edge_index=[2, 864], premise_index=[16], conjecture_index=[1], name='t24_substlat', y=[16])]

In [20]:
from torch_geometric.data import Dataset, InMemoryDataset

In [21]:
class MyOwnDataset(InMemoryDataset):
    def __init__(self, transform=None, pre_transform=None):
        super().__init__(None, transform, pre_transform)

In [22]:
test = MyOwnDataset()

In [23]:
test

MyOwnDataset()

In [24]:
class TorchPairDataset(InMemoryDataset):
    
    def __init__(self, data_pairs: List[PairData], id_file: str, benchmark_type: BenchmarkType = BenchmarkType.DEEPMATH):
        

        # Set the pair ids
        self.pair_ids = list(map(str, range(0, len(data_pairs))))
        
        self.id_partition = Path(id_file).stem
        self.benchmark_type = benchmark_type
        
        # TODO hack - might consume too much memory
        self.data_pairs = data_pairs
        
        # Initialise the super
        self.root = Path(".")
        #super().__init__(self.root.name, transform, pre_transform)
        super().__init__(self.root.name)
        
        # Start process of getting the data
        self.data, self.slices = torch.load(self.processed_paths[0])
        

        
    @property
    def raw_file_names(self) -> List[str]:
        return self.pair_ids

    @property
    def processed_file_names(self) -> List[str]:
        # return [Path(prob).stem + ".pt" for prob in self.problems]
        return [f"pair_{self.benchmark_type}_{self.id_partition}.pt"]

    
    def len(self) -> int:
        return len(self.raw_file_names)



    def process(self):
        # Read data into huge `Data` list.
        data, slices = self.collate(self.data_pairs)
        out = Path(self.processed_dir) / self.processed_file_names[0]
        torch.save((data, slices), out)
        del self.data_pairs


In [25]:
data = TorchPairDataset(pair_list, id_file)

In [39]:
next(iter(data)).x_s

Data(x=[552], edge_index=[2, 903], premise_index=[4], conjecture_index=[1], name='l87_oppcat_1', y=[4])

In [26]:
pair_list

[PairData(edge_index_s=[2, 903], x_s=Data(x=[552], edge_index=[2, 903], premise_index=[4], conjecture_index=[1], name='l87_oppcat_1', y=[4]), edge_index_t=[2, 801], x_t=[510], y=39.480156),
 PairData(edge_index_s=[2, 864], x_s=Data(x=[503], edge_index=[2, 864], premise_index=[16], conjecture_index=[1], name='t24_substlat', y=[16]), edge_index_t=[2, 436], x_t=[283], y=28.310375),
 PairData(edge_index_s=[2, 2758], x_s=Data(x=[1650], edge_index=[2, 2758], premise_index=[22], conjecture_index=[1], name='t22_nfcont_1', y=[22]), edge_index_t=[2, 214], x_t=[141], y=37.670086),
 PairData(edge_index_s=[2, 1027], x_s=Data(x=[616], edge_index=[2, 1027], premise_index=[10], conjecture_index=[1], name='t72_filter_2', y=[10]), edge_index_t=[2, 1981], x_t=[1192], y=62.929962),
 PairData(edge_index_s=[2, 245], x_s=Data(x=[170], edge_index=[2, 245], premise_index=[6], conjecture_index=[1], name='t45_nat_1', y=[6]), edge_index_t=[2, 1362], x_t=[839], y=37.85287),
 PairData(edge_index_s=[2, 1393], x_s=Da

In [27]:
d = next(iter(data))

In [28]:
d.edge_index_s

tensor([[  0,   3,   4,  ..., 550, 463, 551],
        [  1,   1,   2,  ..., 517, 550, 463]])

In [29]:
d.x_s

Data(x=[552], edge_index=[2, 903], premise_index=[4], conjecture_index=[1], name='l87_oppcat_1', y=[4])

In [30]:
len(d)

5

In [31]:
d

PairData(edge_index_s=[2, 903], x_s=Data(x=[552], edge_index=[2, 903], premise_index=[4], conjecture_index=[1], name='l87_oppcat_1', y=[4]), edge_index_t=[2, 801], x_t=[510], y=39.480156)

In [32]:
#DataLoader(pair_list, id_file)

In [33]:
def get_default_args(func):
    signature = inspect.signature(func)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }

In [34]:
get_default_args(DataLoader())

TypeError: DataLoader.__init__() missing 1 required positional argument: 'dataset'

In [None]:
d = DataLoader(pair_list, batch_size=2)

In [None]:
d.__dict__

In [None]:
dataset_path = '../unsupervised_data/validation/'
# Get files
with open(os.path.join(dataset_path, 'target.pkl'), 'rb') as f:
    targets = pickle.load(f)

In [None]:
targets

In [None]:
len(pair_list)

In [None]:
len()