In [None]:
import random
import pickle
import gzip
from collections import deque
import numpy as np
import torch
from torch_geometric.data import Data
from multiprocessing import Pool

import import_ipynb
from constants import *

In [None]:
def simul_fromData(data, hard_mask):
    """Run the parallel cascade simulation on a graph restricted to masked-in edges.

    Args:
        data: graph object exposing ``x``, ``edge_index`` and ``edge_attr``
            tensors (e.g. a ``torch_geometric.data.Data``). Nodes whose
            (squeezed, int-cast) feature equals 1 are used as seeds;
            ``edge_index`` is indexed as ``[0, i]`` / ``[1, i]`` (source /
            target of edge ``i``); ``edge_attr`` supplies one value per edge,
            which the simulation compares against ``random.random()`` as an
            activation probability.
        hard_mask: indexable per-edge mask; edge ``i`` is kept iff
            ``hard_mask[i]`` is truthy.

    Returns:
        The per-node activation frequencies returned by ``simul_multi``.
    """
    # .detach().cpu() lets this accept tensors that require grad or live on a
    # GPU (plain .numpy() raises for both). atleast_1d keeps a single-node or
    # single-edge graph from squeezing down to a 0-d array, where len() fails.
    x = np.atleast_1d(data.x.detach().cpu().numpy().squeeze()).astype(int)
    edge_index = data.edge_index.detach().cpu().numpy()
    edge_attr = np.atleast_1d(data.edge_attr.detach().cpu().numpy().squeeze())
    n = len(x)

    seed_idx = np.where(x == 1)[0]
    adj_list = [[] for _ in range(n)]
    for i, (u, v, p) in enumerate(zip(edge_index[0], edge_index[1], edge_attr)):
        if hard_mask[i]:
            adj_list[u].append((v, p))
    return simul_multi(adj_list, seed_idx)

In [None]:
def simul(adj_list, seed_idx, simul_num=SIMUL_NUM, step_max=100):
    """Estimate per-node activation frequencies by sequential cascade simulation.

    Each of ``simul_num`` runs starts with the seed nodes active and expands
    one BFS level per round, for at most ``step_max`` rounds. Every node that
    became active in the previous round gets one attempt to activate each
    still-inactive out-neighbour ``v``, succeeding when
    ``random.random() < p``.

    Args:
        adj_list: list of length n; ``adj_list[u]`` is a list of ``(v, p)``
            pairs (target node index, activation probability).
        seed_idx: seed node indices (used as a numpy index and iterated).
        simul_num: number of Monte Carlo runs.
        step_max: maximum number of propagation rounds per run.

    Returns:
        np.ndarray of shape (n,), float: fraction of runs in which each node
        ended active.
    """
    node_count = len(adj_list)
    seed_mask = np.zeros(node_count, dtype=int)
    seed_mask[seed_idx] = 1

    def run_cascade():
        # One run: propagate outward from the seeds, one level per round.
        state = seed_mask.copy()
        frontier = list(seed_idx)
        for _ in range(step_max):
            if not frontier:
                break
            reached = []
            for src in frontier:
                for dst, prob in adj_list[src]:
                    if not state[dst] and random.random() < prob:
                        state[dst] = 1
                        reached.append(dst)
            frontier = reached
        return state

    hit_counts = np.zeros(node_count, dtype=int)
    for _ in range(simul_num):
        hit_counts += run_cascade()
    return hit_counts / simul_num



def simul_multi_helper(adj_list, seed_idx, step_max=100):
    """Run one cascade simulation and return the final activation vector.

    Starting from the seed nodes, the cascade expands one BFS level per round.
    Every newly activated node ``u`` gets exactly one attempt to activate each
    still-inactive out-neighbour ``v`` listed in ``adj_list[u]``, succeeding
    when ``random.random() < p``.

    Args:
        adj_list: list of length n; ``adj_list[u]`` is a list of ``(v, p)``
            pairs (target node index, activation probability).
        seed_idx: seed node indices (used as a numpy index and iterated).
        step_max: maximum number of propagation rounds (default 100, the
            value that was previously hard-coded).

    Returns:
        np.ndarray of shape (n,), dtype int: 1 for nodes active at the end,
        0 otherwise.
    """
    n = len(adj_list)
    active = np.zeros(n, dtype=int)
    active[seed_idx] = 1
    # De-duplicate seeds (order preserved) so a seed listed twice does not get
    # two activation attempts per outgoing edge.
    Q = deque(dict.fromkeys(seed_idx))
    for _ in range(step_max):
        if not Q:
            break
        # Process exactly the current level; nodes appended below form the next one.
        for _ in range(len(Q)):
            u = Q.popleft()
            for v, p in adj_list[u]:
                if active[v]:
                    continue
                if random.random() < p:
                    active[v] = 1
                    Q.append(v)
    return active
def simul_multi(adj_list, seed_idx, simul_num=SIMUL_NUM, processes=10):
    """Estimate per-node activation frequencies with parallel cascade simulations.

    Runs ``simul_num`` independent calls of ``simul_multi_helper`` across a
    ``multiprocessing.Pool`` and averages the resulting activation vectors.

    Args:
        adj_list: list of length n; ``adj_list[u]`` holds ``(v, p)`` pairs.
        seed_idx: seed node indices.
        simul_num: number of simulations; must be >= 1.
        processes: worker count passed to ``Pool`` (default 10, the value that
            was previously hard-coded; ``None`` means ``os.cpu_count()``).

    Returns:
        np.ndarray of shape (n,), float: fraction of runs in which each node
        ended active.

    Raises:
        ValueError: if ``simul_num < 1`` (rather than returning a scalar NaN).
    """
    import os
    from functools import partial
    from itertools import repeat

    if simul_num < 1:
        raise ValueError(f"simul_num must be >= 1, got {simul_num}")

    # NOTE(review): workers unpickle simul_multi_helper by reference. Under the
    # "spawn" start method (Windows; macOS default) this can fail for functions
    # defined in an interactive notebook's __main__. Confirm on target platforms.
    workers = processes or os.cpu_count() or 1
    # imap_unordered defaults to chunksize=1; batch runs per task to keep IPC
    # overhead comparable to map/starmap's automatic chunking.
    chunksize = max(1, simul_num // (workers * 4))
    run_once = partial(simul_multi_helper, adj_list)

    # Accumulate as results stream in, instead of holding all simul_num
    # activation vectors (a simul_num x n array) in memory at once.
    active_num = np.zeros(len(adj_list), dtype=int)
    with Pool(processes) as pool:
        for active in pool.imap_unordered(run_once, repeat(seed_idx, simul_num), chunksize=chunksize):
            active_num += active
    return active_num / simul_num