In [6]:
import math
import itertools
import collections
import numpy as np
import matplotlib.pyplot as plt

from functools import reduce

from utils import *

# %matplotlib inline

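### Pooling policies based on the $\Sigma$ matrix

Each policy assigns one dosage level to every arm. Two policies that differ by exactly one dosage level in a single arm $k$ are linked when $\Sigma_{k,d} = 1$, where $d$ is the lower of the two dosages (1-indexed column of $\Sigma$). The pools are the connected components of the resulting graph.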

In [127]:
def __enumerate_policies_classic__(M, R):
    """Enumerate every policy when all M arms share the same dosage range.

    Each arm takes an intensity in 1..R-1, so for R >= 1 the result holds
    (R-1)**M policies, in itertools.product (lexicographic) order.

    Args:
        M: number of arms.
        R: number of dosage levels per arm; intensities 1..R-1 are used.

    Returns:
        list of M-tuples whose entries are NumPy integer scalars.
    """
    levels = 1 + np.arange(R - 1)
    return list(itertools.product(levels, repeat=M))

def __enumerate_policies_complex__(R):
    """Enumerate every policy when arms may have different dosage ranges.

    Arm i takes an intensity in 1..R[i]-1. The result is the Cartesian
    product of those per-arm ranges, in itertools.product order.

    Args:
        R: sequence of per-arm dosage-level counts.

    Returns:
        list of len(R)-tuples whose entries are NumPy integer scalars.
    """
    per_arm_levels = [1 + np.arange(levels - 1) for levels in R]
    return list(itertools.product(*per_arm_levels))
        

def enumerate_policies(M, R):
    """Enumerate all dosage policies.

    Args:
        M: number of arms. Only used when every arm shares one dosage range.
        R: either a single integer (Python ``int`` or NumPy integer) or a
            length-1 sequence giving the shared number of dosage levels, or
            a sequence of per-arm level counts. In the per-arm case ``M`` is
            not consulted; the number of arms is ``len(R)``.

    Returns:
        list of tuples, one per policy, with intensities 1..R-1 for each arm
        (or 1..R[i]-1 for arm i in the per-arm case).
    """
    # NumPy integer scalars (e.g. np.int64) are not Python ints and have no
    # len(), so they must be recognised here rather than fall through.
    if isinstance(R, (int, np.integer)):
        return __enumerate_policies_classic__(M, R)
    if len(R) == 1:
        return __enumerate_policies_classic__(M, R[0])
    return __enumerate_policies_complex__(R)

In [166]:
def lattice_adjacencies(sigma, policies):
    """List lattice edges between policies that should be pooled.

    Two policies are adjacent when they differ by exactly one dosage level in
    exactly one arm. Such a pair becomes an edge only if
    ``sigma[arm, d - 1] == 1``, where ``d`` is the lower of the two dosages
    in that arm.

    Args:
        sigma: 2-D array indexed as ``sigma[arm, dosage - 1]``.
        policies: sequence of equal-length integer dosage vectors (e.g. the
            tuples returned by ``enumerate_policies``).

    Returns:
        list of ``(i, j)`` index pairs into ``policies`` with ``i < j``,
        sorted lexicographically.
    """
    # Map each policy to every index it occurs at. Each policy's "one level
    # up" neighbours are then found by lookup: O(P * M) lookups instead of
    # O(P**2) pairwise array comparisons.
    positions = collections.defaultdict(list)
    for idx, pol in enumerate(policies):
        positions[tuple(pol)].append(idx)

    edges = []
    for i, pol in enumerate(policies):
        pol = tuple(pol)
        for arm, dosage in enumerate(pol):
            upper = pol[:arm] + (dosage + 1,) + pol[arm + 1:]
            matches = positions.get(upper)
            # Confirm the neighbour exists before indexing sigma: a policy at
            # the highest dosage may have no corresponding sigma column.
            if not matches or sigma[arm, dosage - 1] != 1:
                continue
            for j in matches:
                edges.append((min(i, j), max(i, j)))
    # Same order the pairwise scan produced: by i, then by j.
    edges.sort()
    return edges

In [155]:
# Helper to merge components
def __merge_components__(parent, x):
    """Return the root (representative) of node x in a disjoint-set forest.

    Despite its name this is the DSU "find" operation. It walks parent
    pointers iteratively, so long chains cannot exceed Python's recursion
    limit. It also compresses the visited path by re-pointing each node
    directly at the root, which keeps later lookups short.

    Args:
        parent: list where parent[i] is the parent of node i; roots satisfy
            parent[i] == i. Modified in place by path compression (which
            nodes are roots, and which tree each node is in, do not change).
        x: node index.

    Returns:
        index of the root of x's tree.
    """
    root = x
    while parent[root] != root:
        root = parent[root]
    # Path compression: point every node on the walked path at the root.
    while parent[x] != root:
        parent[x], x = root, parent[x]
    return root
 

# Disjoint Set Union Algorithm from https://cp-algorithms.com/data_structures/disjoint_set_union.html
# Implementation modified from https://www.geeksforgeeks.org/connected-components-in-an-undirected-graph/
def connected_components(n, edges):
    """Group nodes 0..n-1 into the connected components of an undirected graph.

    Uses a disjoint-set union with union by size. Every tree therefore has
    depth O(log n), so root lookups stay shallow even on long path-like
    graphs.

    Args:
        n: number of nodes; nodes are the integers 0..n-1.
        edges: iterable of (u, v) node-index pairs.

    Returns:
        list of components, each a list of node indices in ascending order.
        Components are ordered by their smallest node.
    """
    parent = list(range(n))
    size = [1] * n
    for edge in edges:
        root_u = __merge_components__(parent, edge[0])
        root_v = __merge_components__(parent, edge[1])
        if root_u == root_v:
            continue
        # Union by size: hang the smaller tree under the larger one.
        if size[root_u] > size[root_v]:
            root_u, root_v = root_v, root_u
        parent[root_u] = root_v
        size[root_v] += size[root_u]

    # Resolve every node to its root. Dict insertion order puts components in
    # order of their smallest member, with members ascending.
    cc_map = collections.defaultdict(list)
    for i in range(n):
        cc_map[__merge_components__(parent, i)].append(i)
    return list(cc_map.values())

In [185]:

# Example: two arms (rows of sigma) and n = 3 sigma columns, so each arm
# takes dosages 1..n+1 (R = n + 2 levels).
sigma = np.array([[1, 1, 0],
                  [0, 1, 1]])

M, n = sigma.shape
R = n + 2

policies = enumerate_policies(M, R)
# Derive the count from the enumeration itself so it can never disagree with
# the policy list (e.g. if R becomes a per-arm sequence).
num_policies = len(policies)
relations = lattice_adjacencies(sigma, policies)
pools = connected_components(num_policies, relations)

for pool in pools:
    print(pool)

[0, 4, 8]
[1, 2, 3, 5, 6, 7, 9, 10, 11]
[12]
[13, 14, 15]


In [186]:
# Function to calculate B

# Function to calculate Q