# In [None]:
import polygraphs as pg

from polygraphs import timer
from polygraphs import init
from polygraphs import graphs
from polygraphs import hyperparameters as hparams

from polygraphs.ops import math

import dgl
import torch
import dgl.function as fn


def getsampler(size, epsilon, trials):
    """
    Builds the binomial sampler from which payoffs are drawn.

    Args:
        size: shape of the per-node tensors (forwarded to the `init` helpers).
        epsilon: offset added to 0.5 to obtain the per-trial success
            probability.
        trials: number of Bernoulli trials summarised by one draw.

    Returns:
        A `torch.distributions.binomial.Binomial` in which every element has
        `total_count=trials` and `probs=0.5 + epsilon`.
    """
    binomial = torch.distributions.binomial.Binomial
    # Per-trial success probability p = 0.5 + epsilon, one entry per element
    successprob = init.halfs(size) + epsilon
    # Every element runs the same number of Bernoulli trials
    ntrials = init.zeros(size) + trials
    return binomial(total_count=ntrials, probs=successprob)


def converged(graph):
    """
    Checks whether the simulation can stop early.

    Returns:
        A Python bool that is True iff every value in
        `graph.ndata["beliefs"]` is strictly greater than 0.99.
    """
    beliefs = graph.ndata["beliefs"]
    everyoneconfident = beliefs.gt(0.99).all()
    return everyoneconfident.item()


def filterfn(edges):
    """
    Edge predicate: keeps edges whose source node ran at least one trial.

    Column 1 of the source's "payoffs" holds its number of trials.
    Returns a boolean tensor with one entry per edge.
    """
    sourcetrials = edges.src["payoffs"][:, 1]
    return sourcetrials > 0.0


def messagefn(edges):
    """
    Message function: forwards each source node's "payoffs" unchanged.
    """
    sourcepayoffs = edges.src["payoffs"]
    return dict(payoffs=sourcepayoffs)


def reducefn(nodes):
    """
    Reduce function: sums the received "payoffs" messages per node.

    The mailbox is summed over dim 1, the message axis in DGL's mailbox
    layout, so each node ends up with its neighbours' total payoffs.
    """
    received = nodes.mailbox["payoffs"]
    return {"payoffs": received.sum(dim=1)}


def applyfn(nodes):
    """
    Node apply function: Bayesian update of each node's belief.

    Reads from `nodes.data`:
        "payoffs": aggregated evidence; column 0 is the number of successful
            trials received and column 1 the total number of trials.
        "logits":  passed to `math.Evidence` as the success-rate parameter.
            In this script it holds `sampler.logits`, i.e. the log-odds of
            action B's success probability, not a log probability.
        "beliefs": the prior P(H), the belief that action B is better.

    Returns:
        {"beliefs": posterior P(H|E)} as computed by `math.bayes`.
    """
    evidencecounts = nodes.data["payoffs"]
    successes, trials = evidencecounts[:, 0], evidencecounts[:, 1]
    evidence = math.Evidence(nodes.data["logits"], successes, trials)
    return {"beliefs": math.bayes(nodes.data["beliefs"], evidence)}


def cooop(graph, sampler, filteredges, usebuiltins):
    """
    Performs one round of experimentation and belief updates on a
    coo-formatted graph.

    Only nodes whose belief exceeds 0.5 experiment. Each draws a sample from
    `sampler`; every other node reports zero successes out of zero trials.
    The resulting (successes, trials) pairs are sent along edges, summed at
    each destination node, and `applyfn` turns them into new beliefs.
    Per DGL's `send_and_recv`, only nodes that receive messages are updated.

    Args:
        graph: DGL graph with "beliefs" and "logits" node data. It is
            mutated in place: ndata["payoffs"] and ndata["beliefs"] are
            overwritten.
        sampler: binomial distribution, presumably with one entry per node
            (as built by `getsampler`).
        filteredges: if True, only send along edges whose source node ran
            at least one trial.
        usebuiltins: if True, use DGL built-in message/reduce functions
            instead of `messagefn`/`reducefn`.
    """
    # Consider only nodes whose belief is greater than 0.5
    mask = graph.ndata["beliefs"] > 0.5
    sample = sampler.sample() * mask
    # Stack to (2, N) and transpose so each row is (successes, trials) for
    # one node (assumes beliefs is 1-D of length N)
    result = torch.stack((sample, sampler.total_count * mask))
    graph.ndata["payoffs"] = result.T
    if filteredges:
        # Skip edges from non-experimenting nodes; they carry zero payoffs
        edges = graph.filter_edges(filterfn)
    else:
        edges = graph.edges()
    if usebuiltins:
        # fn.copy_u is the supported replacement for the deprecated
        # fn.copy_src (identical semantics: copy source node feature)
        graph.send_and_recv(
            edges,
            fn.copy_u("payoffs", "payoffs"),
            fn.sum("payoffs", "payoffs"),
            applyfn,
        )
    else:
        graph.send_and_recv(edges, messagefn, reducefn, applyfn)


def cscop(graph, sampler, filteredges, usebuiltins):
    """
    Performs one round of experimentation and belief updates on a
    csc-formatted graph.

    Only nodes whose belief exceeds 0.5 experiment. Each draws a sample from
    `sampler`; every other node reports zero successes out of zero trials.
    Payoffs are aggregated over every edge with `update_all`, using DGL
    built-ins, and `applyfn` turns them into new beliefs.

    Args:
        graph: DGL graph with "beliefs" and "logits" node data. It is
            mutated in place: ndata["payoffs"] and ndata["beliefs"] are
            overwritten.
        sampler: binomial distribution, presumably with one entry per node
            (as built by `getsampler`).
        filteredges: ignored; accepted so this function shares `cooop`'s
            signature.
        usebuiltins: ignored; accepted so this function shares `cooop`'s
            signature.
    """
    # Unused parameters
    del filteredges, usebuiltins
    # Consider only nodes who believe action B is better
    mask = graph.ndata["beliefs"] > 0.5
    sample = sampler.sample() * mask
    # Stack to (2, N) and transpose so each row is (successes, trials) for
    # one node (assumes beliefs is 1-D of length N)
    result = torch.stack((sample, sampler.total_count * mask))
    graph.ndata["payoffs"] = result.T
    # fn.copy_u is the supported replacement for the deprecated fn.copy_src
    # (identical semantics: copy source node feature)
    graph.update_all(
        fn.copy_u("payoffs", "payoffs"),
        fn.sum("payoffs", "payoffs"),
        applyfn,
    )

# Create a PolyGraph configuration
params = hparams.PolyGraphHyperParameters()

# Initial beliefs are random uniform between 0 and 1
params.init.kind = 'uniform'
# Action B succeeds on each trial with probability 0.5 + epsilon (see
# getsampler), i.e. epsilon is B's edge over a fair coin
params.epsilon = 0.001
# Number of Bernoulli trials each experimenting node performs per step
params.trials = 32

params.network.kind = 'complete'
params.network.size = 4

# Graph format: 'coo' or 'csc'
FORMAT = 'csc'

# Only applicable in the case of 'coo':
#
# Filter edges
filteredges = True
# Use built-in functions
usebuiltins = True

# Set random seed
pg.random(123456789)

# Create graph (from hyper-parameters)
G = graphs.create(params.network)

# Get the shape of all node attributes
size = (G.num_nodes(),)

# Set initial beliefs
G.ndata['beliefs'] = init.init(size, params.init)
print("Initial beliefs is", G.ndata['beliefs'])

# Configure binomial distribution with specific parameters for now
sampler = getsampler(size, params.epsilon, params.trials)

# Store the sampler's logits (log-odds of action B's per-trial success
# probability) as a node attribute; applyfn passes them to math.Evidence
G.ndata["logits"] = sampler.logits

# Set graph format
G = G.formats(FORMAT)
print(G.formats())

# Set computation function
computefn = {
    'coo': cooop,
    'csc': cscop,
}[FORMAT]

# Maximum number of simulation steps
steps = 100

t = timer.Timer()

t.start()
# Count completed steps explicitly so the summary below stays valid even
# when the loop body never runs (e.g. steps == 0 leaves `step` unbound)
completed = 0
for step in range(steps):
    computefn(G, sampler, filteredges, usebuiltins)
    completed += 1
    if converged(G):
        # Early stopping
        break
dt = t.dt()
print(G.ndata["beliefs"])
print(f"Finished {completed} steps in {dt:5.2f}s")
print("Bye.")
