In [1]:
import random
import pickle
import gzip
from collections import deque
import numpy as np
import torch
from torch_geometric.data import Data
from multiprocessing import Pool

In [2]:
SIMUL_NUM = 10000
STEP_MAX = 100

In [3]:
def simul_fromData(data, hard_mask):
    x = data.x.numpy().squeeze().astype(int)
    edge_index = data.edge_index.numpy()
    edge_attr = data.edge_attr.numpy().squeeze()
    n = len(x)
    
    seed_idx = np.where(x==1)[0]
    adj_list = [[] for _ in range(n)]
    for i in range(n):
        if not hard_mask[i]: continue
        u,v,p = edge_index[0,i], edge_index[1,i],edge_attr[i]
        adj_list[u].append((v,p))
    return simul(adj_list, seed_idx)

In [4]:
def simul(adj_list, seed_idx, simul_num=SIMUL_NUM, step_max=STEP_MAX):
    n = len(adj_list)
    seed = np.zeros(n, dtype=int)
    seed[seed_idx] = 1
    
    active_num = np.zeros(n, dtype=int)
    
    for _ in range(simul_num):
        active = seed.copy()
        
        Q = deque(seed_idx)
        for _ in range(step_max):
            if not Q: break
            for _ in range(len(Q)):
                u = Q.popleft()
                for v,p in adj_list[u]:
                    if active[v]: continue
                    if random.random()<p:
                        active[v]=1
                        Q.append(v)
        active_num+=active
    
    prob = active_num/simul_num
    return prob

In [5]:
# seed set 크기는 일정 이상으로 커져봤자 큰 차이 없으니 우선 20%까지만 실험
def simul_random_seeds(adj, data_num):    
    n = len(adj)
    seeds = []
    probs = []
    
    seed_sizes = np.random.randint(1, max(1,int(n*0.2)), size=data_num)
    for i in range(data_num):
        if i%100==0: print(f'{i}/{data_num}th simulation start')
        seed_idx = np.random.choice(n,seed_sizes[i],replace=False)
        seed = np.zeros(n, dtype=int)
        seed[seed_idx] = 1
        
        prob = simul(adj,seed_idx)
        
        seeds.append(seed)
        probs.append(prob)
    
    seeds = np.array(seeds)
    probs = np.array(probs)
    
    return seeds, probs



def simul_random_seeds_multi_helper(adj, seed_size):
    n = len(adj)
    seed_idx = np.random.choice(n,seed_size,replace=False)
    seed = np.zeros(n, dtype=int)
    seed[seed_idx] = 1

    prob = simul(adj,seed_idx)
    
    return np.array([seed,prob])

def simul_random_seeds_multi(adj, data_num):    
    n = len(adj)
    seeds = []
    probs = []
    
    seed_sizes = np.random.randint(1, max(1,int(n*0.2)), size=data_num)
    
    result = Pool(5).starmap(simul_random_seeds_multi_helper, [(adj, seed_sizes[i]) for i in range(data_num)])
    result = np.stack(result)
    
    return result[:,0],result[:,1]

In [9]:
def generate_data(graph_dir, save_dir, graph_name, data_num):
    with open(graph_dir+'/'+graph_name+'.txt','r') as f:
        n,m = map(int,f.readline().split())
        adj = [[] for i in range(n)]
        for _ in range(m):
            u,v,p = f.readline().split()
            u,v,p = int(u),int(v),float(p)
            adj[u].append((v,p))
    
    seeds,probs = simul_random_seeds_multi(adj, data_num)
    
    with gzip.open(save_dir+'/'+graph_name+'.pkl.gz','wb') as f:
        pickle.dump((seeds,probs), f, protocol=4)

In [10]:
#generate_data('/data/URP/graphs', '/data/URP/raw', 'Celebrity_train_JI', 350)
# 1proc으로 100개에 15분정도

In [54]:
# test code
"""
adj = [[(1,0.5)],[(0,0.3),(2,0.2)],[]]
seed,prob = simul_random_seeds(adj,10)
for i in range(10):
    print(seed[i])
    print(prob[i])
    print()
"""

0/10th simulation start
[1 0 0]
[1.     0.4949 0.0932]

[0 1 1]
[0.3006 1.     1.    ]

[0 1 1]
[0.2965 1.     1.    ]

[1 0 1]
[1.    0.492 1.   ]

[1 1 0]
[1.     1.     0.2058]

[0 1 1]
[0.3003 1.     1.    ]

[1 0 0]
[1.     0.4968 0.1007]

[1 1 0]
[1.    1.    0.199]

[1 0 1]
[1.     0.5057 1.    ]

[0 1 0]
[0.2931 1.     0.1954]



In [43]:
#test
"""
x = torch.tensor([[0], [1], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
edge_attr = torch.tensor([[0], [1], [1], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

print(simul_fromData(data, [0,0,0,0]))
print(simul_fromData(data, [0,0,1,1]))
print(simul_fromData(data, [1,1,1,1]))
"""

[0. 1. 1.]
[0. 1. 1.]
[1. 1. 1.]
