In [1]:
from gene_mat import gen_m3g_data, mid_dcj

In [2]:
# Small sample for inspection. Positional order is (gene_len, graph_num, step),
# the same order save_g3m_dataset uses; op_type=2 semantics live in gene_mat.
m_seq, t_seq = gen_m3g_data(10, 1000, 7, op_type = 2)

In [3]:
# (num_samples, 3, gene_len): three sequences per sample
m_seq.shape

(1000, 3, 10)

In [4]:
# (num_samples, 1, gene_len): a single target sequence per sample
t_seq.shape

(1000, 1, 10)

In [5]:
# Sanity check: every generated target must satisfy mid_dcj against its
# sample's three sequences (presumably "is a DCJ median" — see gene_mat).
# Each target is stored as (1, gene_len), so its single row is passed.
assert all(
    mid_dcj(target_row[0], input_trio)
    for target_row, input_trio in zip(t_seq, m_seq)
)

In [6]:
import sys

import numpy as np
import torch

from gene_mat import gen_dataset_wt, gen_dataset_wb, gen_g2g_data
from genome_graph import gen_graph, gen_g2g_graph, gen_g2b_graph

from torch_geometric.data import InMemoryDataset

In [15]:
def save_g3m_dataset(gene_len, step, graph_num = None, fname = None):
    """Generate a 3-sequence median dataset and write it with ``torch.save``.

    ``gen_m3g_data`` is called ``step`` times; each call yields ``graph_num``
    samples, which are stacked row-wise into two int32 arrays.

    Args:
        gene_len: length of every sequence.
        step: number of generation rounds; also forwarded to ``gen_m3g_data``
            as its third (step) argument.
        graph_num: samples per round; ``None`` means 100.
        fname: output path; ``None`` means ``'g3m_<gene_len>_<step>.pt'``.

    Writes ``(source, target)`` to ``fname``, where ``source`` has shape
    ``(graph_num * step, 3, gene_len)`` and ``target`` has shape
    ``(graph_num * step, gene_len)``.
    """
    if graph_num is None:
        graph_num = 100

    if fname is None:
        fname = f'g3m_{gene_len}_{step}.pt'

    source = np.zeros((graph_num * step, 3, gene_len), dtype = np.int32)
    target = np.zeros((graph_num * step, gene_len), dtype = np.int32)

    for dist in range(step):
        # NOTE(review): every round passes the same `step` to gen_m3g_data even
        # though the loop variable is named `dist`. If each chunk was meant to
        # use its own distance (e.g. dist + 1), change this argument.
        m_seq, t_seq = gen_m3g_data(gene_len, graph_num, step, op_type = 2)
        rows = slice(dist * graph_num, (dist + 1) * graph_num)
        source[rows] = m_seq
        # t_seq is (graph_num, 1, gene_len): drop only the singleton axis.
        # A bare squeeze() also removed the gene axis when gene_len == 1,
        # which then failed to broadcast into `target`.
        target[rows] = t_seq.squeeze(axis = 1)

    torch.save((source, target), fname)

In [1]:
from genome_graph import gen_graph

In [9]:
# Build a graph from three signed sequences of length 3 (sign presumably
# encodes gene orientation — TODO confirm against genome_graph.gen_graph).
data = gen_graph([[1,-2,3], [1,2,3], [-1,2,3]])

In [10]:
# Inspect connectivity: a 2 x 15 tensor of node indices (one column per edge).
data.edge_index

tensor([[0, 1, 3, 2, 4, 0, 1, 2, 3, 4, 1, 0, 2, 3, 4],
        [1, 3, 2, 4, 5, 1, 2, 3, 4, 5, 0, 2, 3, 4, 5]])

In [None]:
class G3MedianDataset(InMemoryDataset):
    """In-memory PyG dataset of generated 3-sequence median instances.

    Raw arrays are produced by ``save_g3m_dataset`` (``download`` generates
    them locally; nothing is fetched). ``process`` turns each
    ``(source, target)`` pair into a graph with ``gen_g2g_graph`` and saves
    the collated result.

    Args:
        root: base directory; the effective root is
            ``'<root>_<gene_len>_<step_range>_<graph_num>'``.
        gene_len: length of every sequence.
        step_range: forwarded to ``save_g3m_dataset`` as ``step``.
        graph_num: samples generated per round (default 100).
        transform, pre_transform, pre_filter: standard PyG dataset hooks.
            All default to ``None``, which was the previously hard-coded
            behaviour.
    """

    def __init__(self, root, gene_len, step_range, graph_num = 100,
                 transform = None, pre_transform = None, pre_filter = None):
        # Set before super().__init__: the base constructor may invoke
        # download()/process() and the *_file_names properties, which read these.
        self.gene_len = gene_len
        self.step_range = step_range
        self.graph_num = graph_num
        dataset_root = f'{root}_{gene_len}_{step_range}_{graph_num}'
        super().__init__(dataset_root,
                         transform = transform,
                         pre_transform = pre_transform,
                         pre_filter = pre_filter)
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which may reject pickled PyG objects — confirm with the torch in use.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Name of the generated raw ``(source, target)`` file."""
        return [f'g3raw_{self.gene_len}_{self.step_range}.pt']

    @property
    def processed_file_names(self):
        """Name of the collated ``(data, slices)`` cache file."""
        return [f'g3dat_{self.gene_len}_{self.step_range}.pt']

    def download(self):
        """Generate the raw arrays locally into ``self.raw_paths[0]``."""
        print('Generating...', file=sys.stderr)
        save_g3m_dataset(self.gene_len, self.step_range,
                         graph_num = self.graph_num,
                         fname = self.raw_paths[0])

    def process(self):
        """Convert raw pairs to graphs, apply hooks, collate and save."""
        # The raw file holds numpy arrays written by save_g3m_dataset.
        # NOTE(review): same weights_only caveat as in __init__ for newer torch.
        source, target = torch.load(self.raw_paths[0])

        data_list = [gen_g2g_graph(s, t) for s, t in zip(source, target)]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])