In [1]:
import sys

sys.path.append('../GraphStructureLearning')

In [2]:
from glob import glob
import yaml
from easydict import EasyDict as edict

In [3]:
import pickle
import os
from os import path
import torch
import numpy

from torch_geometric.data import Data

In [144]:
class MakeDataset:
    def __init__(self, config):
        super(MakeDataset, self).__init__()

        self.data_dir = config.dataset.root
        self.window_size = config.dataset.window_size
        self.slide = config.dataset.slide
        self.pred_step = config.dataset.pred_step
        self.idx_ratio = config.dataset.idx_ratio
        self.train_valid_test = config.dataset.train_valid_test
        self.encoder_step = config.encoder_step
        self.decoder_step = config.decoder_step
        
        self.load = None
        
        self.total_input_size = (self.encoder_step + self.decoder_step -1) * self.slide + self.window_size
        self.batch_idx = int(self.total_input_size * self.idx_ratio)

        self.save_dir = path.join(config.dataset.save, f'{config.dataset.name}_{self.window_size}_{self.slide}')

        if path.exists(self.save_dir):
            self.dataset = pickle.load(open(self.save_dir, 'rb'))
            self.load = True
        else:
            self.spk_bin = pickle.load(open('./data/spk_bin_n100.pickle', 'rb'))
            self.lam_bin = pickle.load(open('./data/lam_bin_n100.pickle', 'rb'))
            self.load = False

    def _save_dataset(self):
        pass

    def _valid_sampling(self, i):
        if i == 0:
            start = 0
            total_length = self.train_valid_test[i]
        else:
            start = self.train_valid_test[i - 1]
            total_length = self.train_valid_test[i] - self.train_valid_test[i - 1]

        valid_sampling_locations = []
        valid_sampling_locations += [
            i
            for i in range(start, start + total_length + 1 - self.pred_step - self.total_input_size)
            if (i % self.batch_idx) == 0
        ]
        
        return valid_sampling_locations

    def _split(self, i):
        if i == 0:
            data = self.spk_bin[:self.train_valid_test[i]]
            lam = self.lam_bin[:self.train_valid_test[i]]
        else:
            data = self.spk_bin[self.train_valid_test[i - 1]:self.train_valid_test[i]]
            lam = self.lam_bin[self.train_valid_test[i - 1]:self.train_valid_test[i]]

        return data, lam

    def make(self):
        data_dict = {'train': None,
                     'valid': None,
                     'test': None}

        if not self.load:
            for i, types in enumerate(list(data_dict.keys())):
                data, lam = self._split(i)
                valid_sampling_locations = self._valid_sampling(i)

                data_list = []
                for start_idx in valid_sampling_locations:
                    spike_input = data[:, start_idx:start_idx + self.total_input_size]
                    lam_output = lam[:, start_idx+self.encoder_step*self.slide+self.window_size:start_idx+self.total_input_size+self.pred_step]
                    
                    print(start_idx+self.encoder_step*self.slide+self.window_size)
                    print(start_idx+self.total_input_size+self.pred_step)
                    data_item = Data(x=torch.FloatTensor(spike_input), edge_index=None, y=torch.FloatTensor(lam_output))
                    data_list.append(data_item)

                if types == 'train':
                    data_dict['train'] = data_list
                elif types == 'valid':
                    data_dict['valid'] = data_list
                elif types == 'test':
                    data_dict['test'] = data_list

#                 pickle.dump(data_dict, open(self.save_dir, 'wb'))

            return data_dict

        else:
            return self.dataset


In [145]:
config_file = glob('./config/GTS/*.yaml')[0]
config = edict(yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader))

In [149]:
config.decoder_step =10

In [150]:
data = MakeDataset(config)

In [151]:
data.total_input_size

1150

In [152]:
a = data.make()

700
1200
1275
1775
1850
2350
2425
2925
3000
3500
3575
4075
4150
4650
4725
5225
5300
5800
5875
6375
6450
6950
7025
7525
7600
8100
8175
8675
8750
9250
9325
9825
9900
10400
10475
10975
11050
11550
11625
12125
12200
12700
12775
13275
13350
13850
13925
14425
14500
15000
15075
15575
15650
16150
16225
16725
16800
17300
17375
17875
17950
18450
18525
19025
19100
19600
19675
20175
20250
20750
20825
21325
21400
21900
21975
22475
22550
23050
23125
23625
23700
24200
24275
24775
24850
25350
25425
25925
26000
26500
26575
27075
27150
27650
27725
28225
28300
28800
28875
29375
29450
29950
30025
30525
30600
31100
31175
31675
31750
32250
32325
32825
32900
33400
33475
33975
34050
34550
34625
35125
35200
35700
35775
36275
36350
36850
36925
37425
37500
38000
38075
38575
38650
39150
39225
39725
40950
41450
41525
42025
42100
42600
42675
43175
43250
43750
44975
45475
45550
46050
46125
46625
46700
47200
47275
47775


In [153]:
from torch_geometric.loader import DataLoader

In [154]:
d = DataLoader(a['train'], batch_size=1)

In [155]:
next(iter(d))

DataBatch(x=[100, 1150], y=[100, 500], batch=[100], ptr=[2])

In [123]:
for batch in d:
    pass

In [92]:
batch

DataBatch(x=[100, 800], y=[100, 50], batch=[100], ptr=[2])

In [99]:
batch.x[:, 100:200].shape

torch.Size([100, 100])

In [105]:
batch.y.shape

torch.Size([100, 50])

In [101]:
total_input_size = 800

In [102]:
        valid_sampling_locations = []
        valid_sampling_locations += [
            i
            for i in range(0, total_input_size)
            if (i % 50) == 0
        ]

In [103]:
valid_sampling_locations

[0, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750]

In [104]:
len(valid_sampling_locations)

16