In [None]:
using Pkg;
# Pkg.add("Distributions")
# Pkg.add("Random")
# Pkg.add("POMDPs")
# Pkg.add("QuickPOMDPs")
# Pkg.add("POMDPModelTools")
# Pkg.add("POMDPModels")
# Pkg.add("POMDPSimulators")
# Pkg.add("POMDPGifs")
# Pkg.add("QMDP")
# Pkg.add("CSV")
# Pkg.add("DataFrames")
# Pkg.add("LocalApproximationValueIteration")
# Pkg.add("GridInterpolations")
# Pkg.add("LocalFunctionApproximation")
# Pkg.add("StaticArrays")
# Pkg.add("PyPlot")
# Pkg.add("PyCall")
# Pkg.add("Reel")
# Pkg.add("TextWrap")

In [None]:
using Random, Distributions, LinearAlgebra, StaticArrays, Statistics;
using CSV, DataFrames;
using POMDPs, QuickPOMDPs, POMDPModelTools, POMDPSimulators, POMDPModels, MCTS, POMDPGifs;
using GridInterpolations;
using LocalFunctionApproximation;
using LocalApproximationValueIteration;
using PyPlot, PyCall, Reel, TextWrap;
@pyimport matplotlib.patches as patches

In [None]:
# Read the three input CSV files into DataFrames.
populations = CSV.File("data/populations.csv") |> DataFrame;
self_reporting_distributions = CSV.File("data/distributions.csv") |> DataFrame;
demographics = CSV.File("data/indicators.csv") |> DataFrame;

In [None]:
# Build a single table keyed by state: inner-join the populations with the
# self-reporting distributions (STATE == state), then with the demographics
# (STATE == id).
tract_data = innerjoin(
    innerjoin(populations, self_reporting_distributions; on = :STATE => :state),
    demographics;
    on = :STATE => :id,
);

In [None]:
# Drop the households column (we have a separate population column).
# NOTE: the column is selected by position (12), so this depends on the column
# order of the joined table; re-running this on the same `tract_data` would
# drop whichever column is then at position 12.
select!(tract_data, Not(12));

In [None]:
# Number of simulation timesteps per period of the input data. The self-reporting
# means/stdevs are divided by this value and the time horizon is multiplied by it.
timestep_divisor = 2;

In [None]:
# Get vectors from data.
# Columns are read with `df[!, name]` (returns the stored column without copying);
# the bare `df[name]` form is rejected by current DataFrames releases.
tracts_populations = tract_data[!, "POPESTIMATE2019"];

# Self-reporting parameters: `mu1`/`std1` are used up to the distribution switch,
# `mu2`/`std2` afterwards. Each is divided by 100 (presumably percent -> fraction;
# confirm against the data) and by `timestep_divisor` to get a per-timestep value.
self_reporting_means_start = tract_data[!, "mu1"] / (100.0 * timestep_divisor);
self_reporting_stdevs_start = tract_data[!, "std1"] / (100.0 * timestep_divisor);
self_reporting_means_end = tract_data[!, "mu2"] / (100.0 * timestep_divisor);
self_reporting_stdevs_end = tract_data[!, "std2"] / (100.0 * timestep_divisor);

# Every column from `demographic_start_col` onwards is treated as a demographic
# indicator and gathered into an (n_tracts x n_demographics) Float64 matrix.
demographic_start_col = 10;
tracts_demographics = Matrix{Float64}(tract_data[:, demographic_start_col:end]);
n_tracts = length(tracts_populations);
n_demographics = size(tracts_demographics, 2);

In [None]:
# Total number of simulation timesteps in the horizon (28 input periods, each
# split into `timestep_divisor` steps).
total_timesteps = 28 * timestep_divisor;

# We fit two distributions to our data based on whether or not
# the timestep is within the first 7 weeks. This is to
# accommodate an inflection point on the 7th week.
distribution_switch_timestep = 7 * timestep_divisor;

# Step/search budget: a few steps more than the horizon, to check that the
# terminal-step logic works.
iterate_timesteps = total_timesteps + 5;

In [None]:
# Effect of a visit on a tract's response percentage, modelled with a fixed mean
# and standard deviation. `random_step` adds the mean to the visited tract's
# transition mean and combines the stdev with the tract's own stdev.
visit_distribution_mean = 0.1;
visit_distribution_stdev = 0.05;

In [None]:
# Take a step from state s with action a, applying self-reporting and visit increases.
#
# `s` is laid out as [time; tract_percentages...]. `a` is 0 for "no visit" or the
# index of the tract being visited. `rng` is the random number generator used for
# sampling. Returns a new state vector with time advanced by one; `s` is not mutated.
function random_step(rng, s, a)
    time = s[1]
    # Slicing already allocates a fresh vector, so no extra copy is needed.
    tract_percentages = s[2:end]

    # The self-reporting parameters switch after `distribution_switch_timestep`.
    early_phase = time <= distribution_switch_timestep
    self_reporting_means = early_phase ? self_reporting_means_start : self_reporting_means_end
    self_reporting_stdevs = early_phase ? self_reporting_stdevs_start : self_reporting_stdevs_end

    # `+` allocates a new vector; the stdevs are copied so the global
    # parameter vectors are never modified by the visit adjustment below.
    transition_means = self_reporting_means + tract_percentages
    transition_stdevs = copy(self_reporting_stdevs)

    if a != 0
        # Apply the visit increase to the visited tract's mean/stdev.
        # Note that summing independent normal distributions requires adding
        # variances (stdev ^ 2), not stdevs directly.
        transition_means[a] += visit_distribution_mean
        transition_stdevs[a] = sqrt(transition_stdevs[a]^2 + visit_distribution_stdev^2)
    end

    # Step the tracts. Each tract below 1.0 is resampled from a normal truncated
    # to [current value, 1.0], so percentages never decrease and never exceed 1.
    for i in eachindex(tract_percentages)
        if tract_percentages[i] < 1.0
            step_distribution = TruncatedNormal(
                transition_means[i], transition_stdevs[i], tract_percentages[i], 1.0
            )
            tract_percentages[i] = Base.rand(rng, step_distribution)
        end
    end

    return [time + 1; tract_percentages]
end;

In [None]:
# Define a fixed cost for visits. It is negative because `reward_fn` returns it
# directly as the reward of a non-final step in which a visit is made.
fixed_visit_cost = -0.01;

In [None]:
# Mean of every demographic column over the people counted so far.
# Each tract is weighted by (its response fraction * its population); the result
# is a 1 x n_demographics row of weighted column means.
function get_demographic_means(tract_percentages)
    counted = tract_percentages .* tracts_populations
    weighted_columns = counted .* tracts_demographics
    column_totals = sum(weighted_columns; dims=1)
    return column_totals ./ sum(counted)
end;

In [None]:
# Define the final reward using both the population and demographic terms.
#
# The reward is the fraction of the total population that was counted, minus a
# weighted penalty for how far the counted sample's demographic means deviate
# from the full-population demographic means.
function reward_final_fn(tract_percentages)
    # Constants
    weight_demographics = ones(n_demographics) # TODO: Assign reasonable weights here.

    # Fraction of the total population that has been counted.
    people_counted_per_tract = tract_percentages .* tracts_populations
    people_counted_fraction = sum(people_counted_per_tract) / sum(tracts_populations)

    # Demographic means if the whole population was counted.
    demographic_means_total = get_demographic_means(ones(n_tracts))

    # Demographic means of the counted sample, relative to the full-count means
    # (1.0 means the sample matches the population for that demographic).
    demographic_ratio = get_demographic_means(tract_percentages) ./ demographic_means_total

    # Absolute deviation from a perfectly representative sample, weighted and summed.
    diff_demographic_counts = abs.(1 .- demographic_ratio)
    demographic_penalty = dot(weight_demographics, diff_demographic_counts)

    # The deviation is a penalty, so it is subtracted: a less representative
    # sample must lower the reward, not raise it.
    return people_counted_fraction - demographic_penalty
end;

# Define our reward function using fixed visit cost & final state reward.
#
# `s` is [time; tract_percentages...]; `a` is 0 for no visit or a tract index.
# Always returns a Float64 so the reward type does not depend on the branch taken.
function reward_fn(s, a)
    time = s[1]
    tract_percentages = s[2:end]

    # On the last step before the terminal time, pay out the final reward
    # (computed from the percentages in `s`; no visit cost is charged here).
    if time == total_timesteps - 1
        return reward_final_fn(tract_percentages)
    end

    # Not making a visit is free.
    if a == 0
        return 0.0
    end

    return fixed_visit_cost
end;

In [None]:
# The initial state is [time; tract percentages...]: time starts at 0 and every
# tract starts at 0 percent.
initial_state = Deterministic(vcat(0.0, zeros(n_tracts)));

In [None]:
# Our actions are 0 for no visits and i for visiting the ith tract.
# The range must run up to `n_tracts`; with only action 0 available the policy
# could never choose a visit.
actions = collect(0:n_tracts);

In [None]:
# A state is terminal once its time component (the first entry) has reached
# `total_timesteps`.
is_terminal_fn(s) = first(s) >= total_timesteps;

In [None]:
# Define functions to render MDP state as a bar chart.
#
# Draws one bar per tract showing its response percentage in state `s`
# ([time; tract_percentages...]); the bar of the visited tract `a` (if any) gets
# a different colour. Returns the created figure.
function render_sa(s, a)
    time_step = s[1];
    percentages = s[2:end];
    # Tract labels, read once with the supported `df[!, name]` column access.
    tract_names = tract_data[!, "SHORTNAME"];

    f = figure(figsize=(10,5), dpi=300);
    plt.style.use("grayscale")
    # TODO: update axis as percentages approach 1
    ylim(0,100);
    tracts = collect(1:n_tracts);
    colors = [i == a ? "C0" : "C1" for i in tracts];
    # tighten space with using width when graph gets bigger
    bar(tracts, percentages * 100, color=colors, align="center", width=1);
    draw()

    xticks(tracts, tract_names, rotation=90);
    xlabel("Tracts");
    ylabel("% Response");

    # State time is 0-based; show it 1-based in the title.
    human_time_step = convert(Int, time_step + 1);
    action_label = a == 0 ? "no visit" : "visit $(tract_names[a])";
    title("Time step $human_time_step / $total_timesteps: $action_label");
    return f;
end;

# Render a simulation step by drawing its state `step.s` and action `step.a`.
render_fn(step) = render_sa(step.s, step.a);

In [None]:
# Assemble the MDP from the pieces defined above: a generative model that samples
# the next state and computes the reward, the action set, the initial state
# distribution, the terminal test and the renderer.
mdp = QuickMDP(
    gen = function (s, a, rng)
        sp = random_step(rng, s, a)
        r = reward_fn(s, a)
        return (sp=sp, r=r)
    end,
    actions = actions,
    initialstate = initial_state,
    isterminal = is_terminal_fn,
    render = render_fn,
);

In [None]:
# Configure the MCTS solver; the search depth covers the full iteration budget.
solver = MCTSSolver(;
    n_iterations = 10000,
    depth = iterate_timesteps,
    exploration_constant = 5.0,
);

In [None]:
# Learn our policy by applying the MCTS solver to the MDP.
policy = solve(solver, mdp);

In [None]:
# Simulate our policy and save process as GIF
# `frames` collects one rendered figure per step; `rsum` accumulates the rewards.
frames = Frames(MIME("image/png"), fps=2)
rsum = 0.0
final_state = nothing;
for (s,a,r) in stepthrough(mdp, policy, "s,a,r", max_steps=iterate_timesteps)
    println("s: $s, a: $a")
    push!(frames, render_sa(s, a));
    # `global` marks these as the top-level variables assigned from inside the loop.
    global rsum += r;
    # Keeps the state `s` of the most recent step (the state the action was taken in).
    global final_state = s;
end
println("Undiscounted reward was $rsum.")
println()
# NOTE(review): assumes the ./outputs directory already exists — confirm.
write("./outputs/stepthrough.gif", frames)

In [None]:
# Plot demographic comparison of two different states.
#
# For each state ([time; tract_percentages...]) the demographic means of the
# counted sample are divided by the full-count means, and both ratio curves are
# drawn against the constant population line at 1. Returns the created figure.
function plot_demographics(s_no_demo, s_w_demo)
    plt.style.use("grayscale")

    # Full-count means are the common denominator for both curves.
    full_means = vec(get_demographic_means(ones(n_tracts)));
    ratio_no_demo = vec(get_demographic_means(s_no_demo[2:end])) ./ full_means;
    ratio_w_demo = vec(get_demographic_means(s_w_demo[2:end])) ./ full_means;

    f = figure(figsize=(10,5), dpi=300);
    positions = collect(1:n_demographics);
    plot(positions, ratio_no_demo, linestyle=":", label="Sample mean with population-based reward");
    plot(positions, ratio_w_demo, linestyle="--", label="Sample mean with demographics-based reward");
    plot(positions, ones(n_demographics), linestyle="-", label="Population mean");

    # Demographic column names, wrapped so they fit under the axis.
    column_names = names(tract_data)[demographic_start_col:end];
    x_labels = map(name -> wrap(name, width=20), column_names);
    xticks(positions, x_labels);
    ylabel("Ratio of sample mean to population mean");
    title("Comparison of demographics' sample mean to population mean");
    legend()
    return f;
end;

In [None]:
# Compare the end states of two runs.
# NOTE(review): `pop_state` and `dem_state` are not assigned in the visible cells;
# presumably they are `final_state` values saved from a population-only and a
# demographics-aware run — confirm they are defined before running this cell.
plot_demographics(pop_state, dem_state);