In [None]:
import random
import pickle
import gzip
import numpy as np
import time

import torch
from torch.utils.data import random_split
from torch_geometric.data import Dataset, Data, DataLoader

import import_ipynb
from simulation import simul_multi, simul
from constants import *
from utils import *

In [None]:
# saving_tag must start with '-': get_data() recovers the graph name as the text before the first '-' in the dataset file name
def generate_dataset(graph_name, data_num, saving_tag='', print_epoch=100):
    """Simulate `data_num` random seed sets on a graph and pickle the results.

    For every sample a seed set is drawn uniformly without replacement,
    `simul_multi` is run on it, and the pair ``(is_seed, prob)`` is stored.
    A draw is repeated while the sum of the simulated `prob` equals the seed
    count (presumably meaning the seeds influenced no other node -- confirm
    against `simul_multi`).

    The collected list is written, gzip-compressed, to
    ``DATASET_DIR + graph_name + saving_tag + '.pkl.gz'``.

    Args:
        graph_name: Name handed to `txt2adj` to load the graph.
        data_num: Number of (is_seed, prob) samples to generate.
        saving_tag: Suffix appended to the file name; when non-empty it must
            start with '-' because `get_data` splits the file name on '-'.
        print_epoch: Print a progress line every `print_epoch` samples;
            0 or None disables progress output.
    """
    n, _, adj = txt2adj(graph_name)

    # np.random.randint samples from [low, high) and raises ValueError when
    # low >= high. The raw bounds hit that case whenever int(n*SEED_SIZE) <= 10,
    # and could also ask for more seeds than there are nodes. Clamp so that
    # low < high <= n + 1; bounds that were already valid are left unchanged.
    low = min(10, n)
    high = max(10, int(n * SEED_SIZE))
    high = min(max(high, low + 1), n + 1)
    seed_sizes = np.random.randint(low, high, size=data_num)

    data = []
    for i in range(data_num):
        # Guard against print_epoch == 0 (modulo by zero).
        if print_epoch and i % print_epoch == 0:
            print(f'{i}/{data_num}th simulation start')

        # Redraw until the simulated total differs from the seed count.
        # NOTE(review): this never terminates if no seed set of this size can
        # produce a different total (e.g. every node is a seed) -- confirm.
        while True:
            seed_idx = np.random.choice(n, seed_sizes[i], replace=False)
            is_seed = np.zeros(n, dtype=int)
            is_seed[seed_idx] = 1

            prob = simul_multi(adj, seed_idx)
            if prob.sum().item() != seed_sizes[i]:
                break

        data.append((is_seed, prob))

    with gzip.open(DATASET_DIR+graph_name+saving_tag+'.pkl.gz', 'wb') as f:
        pickle.dump(data, f, protocol=4)

In [None]:
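# Example usage (uncomment to run): generates 50 samples and writes 'Celebrity_test_LP-50_new.pkl.gz' under DATASET_DIR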
# generate_dataset('Celebrity_test_LP', 50, saving_tag='-50_new', print_epoch=1)

In [None]:
def get_data(dataset_name, data_num=None):
    """Load a pickled simulation dataset as a list of PyG `Data` objects.

    Args:
        dataset_name: File name (relative to DATASET_DIR) of a gzip-compressed
            pickle holding a sequence of ``(is_seed, prob)`` NumPy array pairs,
            named ``'<graph_name>-<tag>.pkl.gz'`` or ``'<graph_name>.pkl.gz'``.
        data_num: Maximum number of samples to load from the start of the
            file; None loads all of them.

    Returns:
        A list of `Data` objects. `x` is `is_seed` and `y` is `prob`, each
        converted to a float tensor with a trailing axis of size 1. All items
        reference the same `edge_index` / `edge_attr` returned by `txt2coo`.
    """
    # The graph name is everything before the first '-' of the file name.
    graph_name = dataset_name.split('-')[0]
    # generate_dataset's default saving_tag='' writes '<graph_name>.pkl.gz',
    # which contains no '-'; strip the extension so txt2coo still receives the
    # bare graph name instead of the file name.
    if graph_name.endswith('.pkl.gz'):
        graph_name = graph_name[:-len('.pkl.gz')]
    edge_index, edge_attr = txt2coo(graph_name)

    # SECURITY: pickle.load can execute arbitrary code; only open dataset
    # files that were produced locally / come from a trusted source.
    with gzip.open(DATASET_DIR+dataset_name, 'rb') as f:
        rawdata = pickle.load(f)
    if data_num is None:
        data_num = len(rawdata)

    data = []
    for is_seed, prob in rawdata[:data_num]:
        # Append a trailing axis before converting each array to a float tensor.
        is_seed = torch.from_numpy(np.expand_dims(is_seed, axis=-1)).float()
        prob = torch.from_numpy(np.expand_dims(prob, axis=-1)).float()
        data.append(Data(x=is_seed, edge_index=edge_index, edge_attr=edge_attr, y=prob))
    return data


def get_data2(dataset_name1, dataset_name2, data_num=None):
    """Load two datasets via `get_data` and return them merged and shuffled.

    `data_num` is forwarded to each `get_data` call separately, so it limits
    the number of samples taken from each dataset, not the combined total.
    The merged list is shuffled in place with the global `random` state.
    """
    combined = [
        *get_data(dataset_name1, data_num),
        *get_data(dataset_name2, data_num),
    ]
    random.shuffle(combined)
    return combined

    
def get_data_split(dataset_name, data_num=None):
    """Load one dataset and cut it, in stored order, into train/val/test lists.

    The first int(80%) of the samples form the training list, the next
    int(10%) the validation list, and whatever remains the test list.
    No shuffling is done here.

    Returns:
        Tuple ``(train_data, val_data, test_data)``.
    """
    samples = get_data(dataset_name, data_num)
    total = len(samples)

    # Slice boundaries: [0, val_start) train, [val_start, test_start) val, rest test.
    val_start = int(total*0.8)
    test_start = val_start + int(total*0.1)

    return samples[:val_start], samples[val_start:test_start], samples[test_start:]


def get_data_split2(dataset_name1, dataset_name2, data_num=None):
    """Load two datasets merged by `get_data2` and cut them into train/val/test.

    The merged list is split in the order `get_data2` returns it: the first
    int(80%) for training, the next int(10%) for validation, the rest for test.

    Returns:
        Tuple ``(train_data, val_data, test_data)``.
    """
    pool = get_data2(dataset_name1, dataset_name2, data_num)

    n_train = int(len(pool)*0.8)
    n_val = int(len(pool)*0.1)
    boundary = n_train + n_val

    train_part = pool[:n_train]
    val_part = pool[n_train:boundary]
    test_part = pool[boundary:]
    return train_part, val_part, test_part


def get_dataloader(dataset_name, data_num=None, batch_size=20):
    """Build train/val/test `DataLoader`s for a single dataset.

    The splits come from `get_data_split`. Only the training loader shuffles;
    the validation and test loaders keep the split order.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    train_set, val_set, test_set = get_data_split(dataset_name, data_num)

    # (subset, shuffle?) pairs, in the order the loaders are returned.
    plan = ((train_set, True), (val_set, False), (test_set, False))
    return tuple(
        DataLoader(subset, batch_size=batch_size, shuffle=reshuffle)
        for subset, reshuffle in plan
    )


def get_dataloader2(dataset_name1, dataset_name2, data_num=None, batch_size=32):
    """Build train/val/test `DataLoader`s over two merged datasets.

    The splits come from `get_data_split2`. Only the training loader shuffles;
    the validation and test loaders keep the split order.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    train_set, val_set, test_set = get_data_split2(dataset_name1, dataset_name2, data_num)

    def make_loader(subset, reshuffle):
        # All three loaders share the same batch size.
        return DataLoader(subset, batch_size=batch_size, shuffle=reshuffle)

    train_loader = make_loader(train_set, True)
    val_loader = make_loader(val_set, False)
    test_loader = make_loader(test_set, False)
    return train_loader, val_loader, test_loader