In [1]:
import argparse
import copy
import json
import os
import pickle as pkl
import statistics
from glob import glob

import numpy as np
import pandas as pd
import regex as re
import torch

from MisInfoSpread import MisInfoSpread
from MisInfoSpread import MisInfoSpreadState

def flatten(state):
    """Turn a state into one flat feature list for the value network.

    Row k of ``state.adjacency_matrix`` is multiplied element-wise by
    ``state.node_states[k]`` and the scaled rows are concatenated in order
    (pairing stops at the shorter of the two sequences, as ``zip`` does).
    For an n-node state with an n x n matrix this yields n*n values.
    """
    features = []
    for node_value, row in zip(state.node_states, state.adjacency_matrix):
        features.extend(node_value * edge for edge in row)
    return features

In [2]:
def _select_device():
    """Return the preferred torch device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _load_states(dataset_path):
    """Load the pickled list of states and binarise it in place.

    Each state's ``node_states`` is cast to a list of ints, and every
    adjacency entry becomes 1 if it was non-zero and 0 otherwise.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted datasets.
    with open(dataset_path, "rb") as fh:
        states = pkl.load(fh)
    for state in states:
        state.node_states = [int(x) for x in state.node_states]
        for row in state.adjacency_matrix:
            for j in range(len(row)):
                row[j] = 1 if row[j] != 0 else 0
    return states


def _rank_candidates(misinfo, model, state, cand_nodes, device):
    """Score each candidate blocker with the model and return them best-first.

    For every candidate, a deep copy of ``state`` is stepped with that single
    node blocked; the resulting state is flattened and fed to ``model``, and the
    candidates are returned in descending order of the model output.
    """
    scored = []
    for cand_node in cand_nodes:
        # Step a copy so probing a candidate never advances the real state.
        next_state, _, _ = misinfo.step(copy.deepcopy(state), [cand_node])
        input_tensor = torch.FloatTensor(flatten(next_state)).view(1, -1).to(device)
        score = model(input_tensor).detach().cpu().numpy()
        scored.append((score, cand_node))
    # NOTE(review): descending order selects the candidate with the HIGHEST model
    # output. If the model predicts expected infections (a cost), the best blocker
    # would be the LOWEST value -- confirm the sign convention used in training.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [node for _, node in scored]


def run_inference(dataset_path, model_path, nodes, max_steps, st, count_inf, count_actions, device=None):
    """Run greedy, model-guided node blocking over every state in a pickled dataset.

    Each round, for every state that still has candidate nodes, the candidates
    are ranked with the model (see ``_rank_candidates``) and the top
    ``count_actions`` of them are blocked via ``misinfo.step_batch``. Rounds
    stop when no state has candidates left or every state reports done.

    Args:
        dataset_path: path to a pickle holding a list of state objects
            (presumably ``MisInfoSpreadState``) with ``node_states`` and
            ``adjacency_matrix`` attributes.
        model_path: path to a state_dict for ``misinfo.get_nnet_model()``.
        nodes, max_steps, st, count_inf, count_actions: forwarded to
            ``MisInfoSpread`` as ``num_nodes``, ``max_time_steps``,
            ``trust_on_source``, ``count_infected_nodes`` and ``count_actions``.
        device: optional torch device (or device string); when None the best
            available of CUDA / MPS / CPU is used.

    Returns:
        ``(mean, std_dev, actions_dict)``. ``mean`` and ``std_dev`` are the mean
        and sample standard deviation, rounded to 4 dp, of the fraction of nodes
        whose final state is -1. ``std_dev`` is 0.0 when there is a single state.
        ``actions_dict`` maps each state index to the list of blocker lists
        chosen in every round where that state had candidates.

    Raises:
        ValueError: if the dataset contains no states.
    """
    device = torch.device(device) if device is not None else _select_device()
    print(f"Using device: {device}")

    misinfo = MisInfoSpread(num_nodes=nodes, max_time_steps=max_steps,
                            trust_on_source=st, count_infected_nodes=count_inf,
                            count_actions=count_actions)

    model = misinfo.get_nnet_model().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Inference only: put dropout / batch-norm layers (if any) in eval mode.
    model.eval()

    states = _load_states(dataset_path)
    if not states:
        raise ValueError(f"No states found in dataset {dataset_path!r}")

    actions_dict = {i: [] for i in range(len(states))}
    candidate_nodes = misinfo.find_neighbor_batch(states)

    # No gradients are needed for scoring; skip building the autograd graph.
    with torch.no_grad():
        while any(candidate_nodes):
            blockers = []
            for idx, (state, cand_nodes) in enumerate(zip(states, candidate_nodes)):
                print("Processing states ", idx, end='\r')
                if cand_nodes:
                    chosen = _rank_candidates(misinfo, model, state, cand_nodes, device)[:count_actions]
                    blockers.append(chosen)
                    # Store a copy so step_batch cannot alias the recorded history.
                    actions_dict[idx].append(list(chosen))
                else:
                    blockers.append([])

            states, _rewards, done = misinfo.step_batch(states, blockers)
            candidate_nodes = misinfo.find_neighbor_batch(states)
            print("Done: ", done.count(True), " " * 20)
            if all(done):
                break

    # Fraction of nodes whose final state is -1 (presumably "infected").
    inf_rate = [state.node_states.count(-1.0) / len(state.node_states) for state in states]
    mean = round(statistics.mean(inf_rate), 4)
    # statistics.stdev needs at least two samples; report 0.0 for a single state.
    std_dev = round(statistics.stdev(inf_rate), 4) if len(inf_rate) > 1 else 0.0
    print(f"Mean: {mean}, Std Dev: {std_dev}")
    return mean, std_dev, actions_dict

In [3]:
# Run parameters -- these must match the settings the checkpoint was trained with.
# The checkpoint file name appears to encode nodes (mn), max steps (ms) and
# source trust (st), so it is built from the same constants.
NUM_NODES = 10
MAX_STEPS = 50
TRUST_ON_SOURCE = 1.0
COUNT_INFECTED = 1
COUNT_ACTIONS = 1

# NOTE(review): the "1_1" prefix may encode count_infected/count_actions -- confirm before parameterising it.
MODEL_PATH = f"saved_models/target_model_1_1_mn{NUM_NODES}_ms{MAX_STEPS}_st{TRUST_ON_SOURCE}.pt"
# Set INFOSPREAD_DATASET to run on another machine instead of editing this absolute path.
DATASET_PATH = os.environ.get(
    "INFOSPREAD_DATASET",
    "/Users/bittu/Desktop/InfoSpread-server/dataset/generate_dataset/deg_dataset/10/10_deg_3.pkl",
)

mean, std_dev, actions_dict = run_inference(
    DATASET_PATH,
    MODEL_PATH,
    nodes=NUM_NODES,
    max_steps=MAX_STEPS,
    st=TRUST_ON_SOURCE,
    count_inf=COUNT_INFECTED,
    count_actions=COUNT_ACTIONS,
)


Using device: mps
Done:  88                     
Done:  697                     
Done:  996                     
Done:  1000                     
Mean: 0.4837, Std Dev: 0.1158


In [5]:
# Keep only the top-ranked blocker from each round: run_inference records a list
# of up to count_actions nodes per round. Entries that are already scalars are
# passed through unchanged, so re-running this cell is a no-op instead of
# raising TypeError on an int.
actions_dict = {
    key: [step[0] if isinstance(step, list) else step for step in steps]
    for key, steps in actions_dict.items()
}

In [None]:
ACTIONS_JSON_PATH = "actions_r0.json"


def _json_default(obj):
    """Convert numpy scalars/arrays to built-in Python types for json.dump."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save the per-state action history as indented JSON. Note that json turns the
# integer state indices into string keys. The default hook covers node ids that
# find_neighbor_batch may return as numpy integers.
with open(ACTIONS_JSON_PATH, "w", encoding="utf-8") as outfile:
    json.dump(actions_dict, outfile, indent=4, default=_json_default)