In [1]:
import torch
import numpy as np

from gene_mat import gen_dataset_wt
from dcj_comp import dcj_dist

from multiprocessing import Pool

In [2]:
def check_dcj(x):
    """Tell whether the middle entry of ``x`` sits exactly halfway between its ends.

    ``x`` is indexed at positions 0, 1 and -1 (first, middle, last).  For each
    pair, the last element of ``dcj_dist``'s result is used as the pair's
    value (presumably the DCJ distance -- confirm against ``dcj_comp``).

    Returns True only when first->middle equals last->middle and their sum
    equals first->last; otherwise False.
    """
    first, middle, last = x[0], x[1], x[-1]

    first_to_mid = dcj_dist(first, middle)[-1]
    last_to_mid = dcj_dist(last, middle)[-1]
    end_to_end = dcj_dist(first, last)[-1]

    # Reject when the two halves differ or do not add up to the full span.
    rejected = first_to_mid != last_to_mid or (first_to_mid + last_to_mid) != end_to_end
    return not rejected

In [3]:
def gen_g2g_data(gene_len, graph_num, step, op_type, processes=22):
    """Collect ``graph_num`` (first, middle, last) triples accepted by ``check_dcj``.

    Batches of ``graph_num * 2`` candidates are requested from
    ``gen_dataset_wt(gene_len, graph_num * 2, 2 * step + 1, op_type)``.  From
    the first returned value, positions ``0``, ``step`` and ``-1`` along axis 1
    are kept, each triple is checked in parallel with ``check_dcj``, and the
    accepted ones are appended to the result until it is full.

    Args:
        gene_len: Size of the last axis of the result; forwarded to
            ``gen_dataset_wt``.
        graph_num: Number of triples to return.
        step: Index of the middle element; the generator is asked for
            ``2 * step + 1`` positions.
        op_type: Forwarded unchanged to ``gen_dataset_wt``.
        processes: Worker count for the multiprocessing pool (default 22,
            the value that used to be hard-coded).

    Returns:
        ``np.int32`` array of shape ``(graph_num, 3, gene_len)``.

    Note:
        The loop only exits once enough triples were accepted; if
        ``check_dcj`` never accepts anything it will not terminate.
    """
    filled = 0
    res = np.zeros((graph_num, 3, gene_len), dtype=np.int32)

    # Create the pool once so worker processes are reused across batches
    # instead of being spawned and torn down on every retry.
    with Pool(processes) as pool:
        while True:
            # Only the first returned value is needed here.
            seqs, _, _ = gen_dataset_wt(gene_len, graph_num * 2, 2 * step + 1, op_type)
            triples = seqs[:, (0, step, -1)]

            keep = pool.map(check_dcj, list(triples))
            triples = triples[keep]

            # Never write past the end of the result buffer.
            size = min(triples.shape[0], graph_num - filled)
            res[filled:filled + size] = triples[:size]
            filled += size

            if filled >= graph_num:
                return res

In [4]:
def save_g2g_dataset(gene_len, step, graph_num=None, fname=None):
    """Generate triples for every ``dist`` in ``range(step)`` and save them.

    For each ``dist`` the function calls
    ``gen_g2g_data(gene_len, graph_num, dist, op_type=2)`` and stores the
    first/last entries of every triple in ``source`` and the middle entry in
    ``target``.  The pair ``(source, target)`` is written with ``torch.save``.

    Args:
        gene_len: Size of the last axis of both arrays.
        step: Number of ``dist`` values to generate (``0 .. step - 1``).
        graph_num: Triples per ``dist`` value; defaults to 1000 when None.
        fname: Output path; defaults to ``'g2g_<gene_len>_<step>.pt'``.

    Saved arrays (both ``np.int32``):
        source: shape ``(graph_num * step, 2, gene_len)``
        target: shape ``(graph_num * step, gene_len)``
    """
    if graph_num is None:
        graph_num = 1000

    if fname is None:
        fname = 'g2g_' + str(gene_len) + '_' + str(step) + '.pt'

    source = np.zeros((graph_num * step, 2, gene_len), dtype=np.int32)
    target = np.zeros((graph_num * step, gene_len), dtype=np.int32)

    for dist in range(step):
        triples = gen_g2g_data(gene_len, graph_num, dist, op_type=2)

        # Each dist value owns one contiguous block of graph_num rows.
        rows = slice(dist * graph_num, (dist + 1) * graph_num)
        source[rows] = triples[:, (0, -1)]
        target[rows] = triples[:, 1]

    torch.save((source, target), fname)

In [5]:
import torch

# Build and save the gene_len=100, step=5 dataset with CUDA device index 1
# selected as the current device; graph_num and fname fall back to the
# function's own defaults (None).
with torch.cuda.device(1):
    save_g2g_dataset(gene_len=100, step=5)

In [6]:
from torch_geometric.data import InMemoryDataset

In [7]:
class G2GraphDataset(InMemoryDataset):
    """In-memory dataset backed by the arrays written by ``save_g2g_dataset``.

    The dataset directory is
    ``root + '_' + gene_len + '_' + step_range + '_' + graph_num``; the raw
    file is ``g2raw_<gene_len>_<step_range>.pt`` and the processed file is
    ``g2dat_<gene_len>_<step_range>.pt``.  ``transform``, ``pre_transform``
    and ``pre_filter`` are always passed to the base class as None.

    Args:
        root: Prefix of the dataset directory (a string).
        gene_len: Forwarded to ``save_g2g_dataset`` as ``gene_len``.
        step_range: Forwarded to ``save_g2g_dataset`` as ``step``.
        graph_num: Forwarded to ``save_g2g_dataset`` as ``graph_num``.
    """

    def __init__(self, root, gene_len, step_range, graph_num = 10000):
        # These must be set before calling the base constructor: the file-name
        # properties below read them, and the base class consults those
        # properties while deciding whether to download/process.
        self.gene_len = gene_len
        self.step_range = step_range
        self.graph_num = graph_num
        super().__init__(root + '_' + str(self.gene_len) + '_'
                         + str(self.step_range) + '_' + str(self.graph_num),
                         transform = None,
                         pre_transform = None,
                         pre_filter = None)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Single raw file name, keyed by gene length and step range."""
        return ['g2raw_' + str(self.gene_len) +
                '_' + str(self.step_range) + '.pt']

    @property
    def processed_file_names(self):
        """Single processed file name, keyed by gene length and step range."""
        return ['g2dat_' + str(self.gene_len) +
                '_' + str(self.step_range) + '.pt']

    def download(self):
        """Generate the raw ``(source, target)`` file instead of downloading it."""
        # Local import: `sys` is not imported at file level.
        import sys

        print('Generating...', file=sys.stderr)
        # `save_g2g_dataset` is the generator defined in this file; it writes
        # the (source, target) tuple that process() unpacks.
        save_g2g_dataset(self.gene_len, self.step_range,
                         graph_num = self.graph_num,
                         fname = self.raw_paths[0])

    def process(self):
        """Convert the raw arrays into a collated list of graph objects."""
        # Read data into huge `Data` list.
        source, target = torch.load(self.raw_paths[0],
                                    map_location=torch.device('cuda'))

        # todo: change gen_graph function
        # NOTE(review): `gen_graph` is not defined or imported in this file and
        # must be provided before this method can run.  Pairing each `source`
        # entry with its `target` entry as the label is the presumed intent --
        # confirm once gen_graph is adapted.
        data_list = [gen_graph(src, label = tgt) for src, tgt in zip(source, target)]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])