In [47]:
import random
import torch
from torch_geometric.data import Data

In [48]:
def influence_maximization(graph:Data, k):
    seeds = []
    cur_reward = 0
    for _ in range(k):
        max_increase = 0
        node_to_add = None
        for node in range(graph.x.shape[0]):
            if node not in seeds:
                increase = simulate_cascade(graph, seeds + [node]) - cur_reward
                if increase > max_increase:
                    max_increase = increase
                    node_to_add = node
        
        seeds.append(node_to_add)        
        cur_reward += max_increase
    return seeds

def simulate_cascade(graph:Data, seeds, max_iter=100):
    influenced = set(seeds)
    new_influenced = set(seeds)
    for _ in range(max_iter):
        if not new_influenced:
            break
        new_influenced = set()
        for node in new_influenced:
            neighs_index = torch.where(graph.edge_index[0]==node)
            neighs = graph.edge_index[:, neighs_index][1]
            probs = graph.edge_attr[neighs_index][0]
            for (neigh, prob) in zip(neighs, probs):
                if neigh.item() not in influenced and random.random() < prob.item():
                    new_influenced.add(neigh.item())
        influenced.update(new_influenced)
    return len(influenced)

In [49]:
edges = torch.tensor([
    [0, 1, 0.6], [1, 0, 0.6], [0, 2, 0.3], [0, 4, 0.4], 
    [2, 1, 0.2], [3, 1, 0.2], [3, 2, 0.1], [4, 2, 0.5],
    [4, 3, 0.3], [4, 5, 0.5], [5, 4, 0.5], [3, 5, 0.2],
    [5, 3, 0.2]
])
edges_index = edges[:, :2].T
edge_attr = edges[:, [-1]]
graph = Data(x=torch.arange(6), edge_index=edges_index, edge_attr=edge_attr)

In [50]:
influence_maximization(graph, 2)

[0, 1]