In [None]:
import os
# Allow a duplicate Intel OpenMP runtime to be loaded in this process.
# NOTE(review): presumably a workaround for the "libiomp5 already initialized"
# abort that can occur when torch/numpy/MKL each bring their own OpenMP runtime.
# It is presumably set in its own cell so that it runs before the next cell
# imports numpy/torch. TODO confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [None]:
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
from collections import defaultdict

import random
import glob
import time
from datetime import datetime

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

In [None]:
# Model dimensions and global parameters. Sizes that follow from other
# settings are derived from them, so the settings cannot drift out of sync.
age_begin = 13                          # youngest modelled age (years)
age_end = 100                           # oldest modelled age (years, inclusive)
risk_groups = ['HM', 'HF', 'MSM']       # heterosexual male, heterosexual female, MSM
num_risk = len(risk_groups)             # 3
age_groups = age_end - age_begin + 1    # 88 single-year ages
number_of_risk_groups = num_risk
# 0 = susceptible, 1..20 = infected, 21 = not counted in population totals
number_of_compartments = 22
dt = 1/12                               # simulation time step in years (monthly)
num_age = age_groups
num_comp = number_of_compartments-2     # number of infected compartments (1..20)
prep_efficiency = 0.99

In [None]:
# Infected compartments 1..20 come in five blocks of four, one block per CD4
# stratum (acute, >500, 350-500, 200-350, <200). Within a block the offsets are:
# 0 = unaware, 1 = aware but not in care, 2 = on ART, 3 = virally suppressed.
unaware_index = tuple(range(1, 21, 4))        # (1, 5, 9, 13, 17)
aware_no_care_index = tuple(range(2, 21, 4))  # (2, 6, 10, 14, 18)
ART_VLS_index = tuple(k for start in range(3, 21, 4) for k in (start, start + 1))
VLS_index = tuple(range(4, 21, 4))            # (4, 8, 12, 16, 20)

pop_growth_rate = 0

# One row per risk group (HM, HF, MSM), one column per aware-no-care compartment.
# Used by dropout_rate() as the rate of entering care.
gamma = np.tile([0.5, 0.5, 0.5, 0.5, 1], (3, 1))

# One row per risk group, one column per ART_VLS_index compartment. The last two
# (CD4 <200 on ART / VLS) are 0, so they carry no weight in the dropout denominator.
scaling_factor_dropout = np.tile([1] * 8 + [0] * 2, (3, 1))

In [None]:
# FIPS codes of the modelled jurisdictions (matched against the 'FIPS'
# column of the input workbooks). This order defines the jurisdiction axis.
all_jurisdictions = [6001, 6037, 6059, 6065, 6067, 6071, 6073, 6]
num_jur = len(all_jurisdictions)

# One single-jurisdiction cluster per FIPS code, in the same order.
(cluster1, cluster2, cluster3, cluster4,
 cluster5, cluster6, cluster7, cluster8) = ([fips] for fips in all_jurisdictions)

In [None]:
# Population state indexed as (jurisdiction, risk group, single-year age, compartment).
data_array_cluster = np.zeros(
    shape=(num_jur, number_of_risk_groups, age_groups, number_of_compartments),
    dtype=float,
)

In [None]:
# Initial population per jurisdiction; each jurisdiction has its own sheet
# named by its FIPS code. The workbook is opened once rather than re-parsed
# for every jurisdiction.
with pd.ExcelFile('jurisdiction_pop_dist.xlsx') as pop_workbook:
    for jur, fips in enumerate(all_jurisdictions):
        df1 = pd.read_excel(pop_workbook, sheet_name=str(fips), index_col=0)
        # Columns from the 4th onward hold the counts. They are reshaped C-order,
        # i.e. risk-major, then age, then compartment.
        # NOTE(review): confirm this layout against the workbook.
        data_array_cluster[jur] = df1.iloc[:, 3:].to_numpy().reshape(
            number_of_risk_groups, age_groups, number_of_compartments)

In [None]:
# Total population and MSM population per jurisdiction, truncated to whole
# people with int(). Summation order matches np.apply_over_axes: risk, then
# age, then compartment.
_pop_sums = (data_array_cluster
             .sum(axis=1, keepdims=True)
             .sum(axis=2, keepdims=True)
             .sum(axis=3, keepdims=True))
total_pop = np.array(list(map(int, _pop_sums.reshape(num_jur,))))

# Sum over age then compartment, keeping (jurisdiction, risk); column 2 is MSM.
_risk_sums = data_array_cluster.sum(axis=2, keepdims=True).sum(axis=3, keepdims=True)
total_msm = np.array(list(map(int, _risk_sums.reshape(num_jur, 3)[:, 2])))

In [None]:
# Position of each single-jurisdiction cluster along the jurisdiction axis.
(cluster1_index, cluster2_index, cluster3_index, cluster4_index,
 cluster5_index, cluster6_index, cluster7_index, cluster8_index) = (
    np.array([pos]) for pos in range(8))

In [None]:
df_prep = pd.read_excel('Prep values.xlsx')

# Keep only the modelled jurisdictions and order the rows like
# `all_jurisdictions`. `total_msm` (assigned below) and the later positional
# reads of 'Prep' / 'PrEP Eligible' follow that order, so the Excel row order
# must not leak through.
_jur_position = {fips: pos for pos, fips in enumerate(all_jurisdictions)}
df_prep_clus = (
    df_prep.loc[df_prep['FIPS'].isin(all_jurisdictions)]
    .sort_values('FIPS', key=lambda col: col.map(_jur_position), kind='stable')
    .copy()  # independent frame: avoids SettingWithCopyWarning on the assignment below
)
assert len(df_prep_clus) == num_jur, "expected exactly one PrEP row per modelled jurisdiction"

jur_name = df_prep_clus['JUR'].to_list()
df_prep_clus['Total MSM'] = total_msm

In [None]:
# PrEP coverage per (jurisdiction, risk group); only the MSM column is non-zero.
prep_rates_clus = df_prep_clus['Prep'].values  # divided by 100 below, presumably percent
prep_eligible = df_prep_clus['PrEP Eligible'].values

prep_values = np.zeros((num_jur, num_risk))
prep_values[:, risk_groups.index('MSM')] = prep_rates_clus / 100

In [None]:
mixing_excel = 'JURI_mixing_weightedBydistance-6-3-2021.xlsx'

# Parse the workbook once and take the three risk-group mixing sheets from it
# (previously it was opened and parsed three times).
with pd.ExcelFile(mixing_excel) as mixing_workbook:
    mixing_df_hm = pd.read_excel(mixing_workbook, sheet_name='HETM_mixing')
    mixing_df_hf = pd.read_excel(mixing_workbook, sheet_name='HETF_mixing')
    mixing_df_msm = pd.read_excel(mixing_workbook, sheet_name='MSM_mixing')

In [None]:
def _restrict_to_jurisdictions(mixing_df, jur_list):
    """Return the rows of `mixing_df` whose FIPS is in `jur_list`, ordered like `jur_list`.

    Raises ValueError unless exactly one row per jurisdiction is found.
    """
    position = {fips: pos for pos, fips in enumerate(jur_list)}
    rows = mixing_df.loc[mixing_df['FIPS'].isin(jur_list)]
    if len(rows) != len(jur_list):
        raise ValueError(
            f"expected one mixing row per jurisdiction ({len(jur_list)}), found {len(rows)}")
    return rows.sort_values('FIPS', key=lambda col: col.map(position), kind='stable')


# Keep the FIPS column plus one column per modelled jurisdiction, in the order
# of `all_jurisdictions` (the workbook headers are the integer FIPS codes).
mixing_columns = ['FIPS', *all_jurisdictions]
mixing_df_hm = mixing_df_hm[mixing_columns]
mixing_df_hf = mixing_df_hf[mixing_columns]
mixing_df_msm = mixing_df_msm[mixing_columns]

# Rows must follow the same jurisdiction order as the columns. Downstream code
# treats each frame as a square (own jurisdiction, partner jurisdiction)
# matrix and normalises it row by row.
mixing_hm = _restrict_to_jurisdictions(mixing_df_hm, all_jurisdictions)
mixing_hf = _restrict_to_jurisdictions(mixing_df_hf, all_jurisdictions)
mixing_msm = _restrict_to_jurisdictions(mixing_df_msm, all_jurisdictions)

In [None]:
# Drop the FIPS column and row-normalise each mixing table so that every
# row sums to 1. Then stack them as mixing_matrix[risk, row_jur, col_jur]
# in the order HM, HF, MSM.
hm_array, hf_array, msm_array = (
    frame.values[:, 1:] for frame in (mixing_hm, mixing_hf, mixing_msm))

hm_sum, hf_sum, msm_sum = (
    np.sum(arr, axis=1) for arr in (hm_array, hf_array, msm_array))

hm_array_scaled = hm_array / hm_sum.reshape(-1, 1)
hf_array_scaled = hf_array / hf_sum.reshape(-1, 1)
msm_array_scaled = msm_array / msm_sum.reshape(-1, 1)

n_jurisdictions = len(all_jurisdictions)
mixing_matrix = np.zeros((num_risk, n_jurisdictions, n_jurisdictions))
for risk_pos, scaled in enumerate((hm_array_scaled, hf_array_scaled, msm_array_scaled)):
    mixing_matrix[risk_pos] = scaled

In [None]:
def generate_age_mixing_mat(age_mat, age_begin, age_end):
    """Expand a banded age-mixing table to a single-year-of-age matrix.

    Parameters
    ----------
    age_mat : pandas.DataFrame
        Square table whose column labels are the first age of each band, in
        ascending order. Its values are the mixing entries between bands
        (row band = own age, column band = partner age). The values are
        divided by 100 below, so they are presumably percentages (TODO confirm).
    age_begin, age_end : int
        First and last single-year ages of the model (inclusive). The first
        band must start at ``age_begin``; the last band runs up to ``age_end``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_ages, n_ages)`` with ``n_ages = age_end - age_begin + 1``.
        Entry ``[i, j]`` is the table value for (band of age i, band of age j),
        divided by the width of partner band j and by 100. Each band value is
        thus spread evenly over the single years of that band.

    Raises
    ------
    ValueError
        If the table is not square, is empty, or its bands do not start at
        ``age_begin`` in strictly ascending order.
    """
    band_starts = np.asarray(age_mat.columns.values, dtype=int)
    n_bands = band_starts.shape[0]
    band_values = age_mat.values.reshape(n_bands, n_bands)

    # Width (in single years) of each band; the last band ends at age_end.
    widths = np.diff(np.append(band_starts, age_end + 1))
    if n_bands == 0 or band_starts[0] != age_begin or np.any(widths <= 0):
        raise ValueError(
            f"age bands starting at {band_starts.tolist()} do not cover ages "
            f"{age_begin}..{age_end} in ascending order")

    # Repeat each band row / column once per single year of age it spans.
    expanded = np.repeat(np.repeat(band_values, widths, axis=0), widths, axis=1)

    # Spread each partner band's entry evenly over its single years, then /100.
    partner_band_width = np.repeat(widths, widths)
    return expanded / partner_band_width / 100

In [None]:
def read_new_inf_input(age_begin, age_end, risk_groups):
    """Load every input of the new-infection (force of infection) calculation.

    Sheets are read from ``input_new_infections.xlsx``. The testing
    multipliers come from ``input_estimating_unknown_rates_HOPE.xlsx``. Both
    paths are relative to the current working directory.

    Parameters
    ----------
    age_begin, age_end : int
        First and last single-year ages of the model (inclusive).
    risk_groups : list of str
        Risk-group labels (e.g. ``['HM', 'HF', 'MSM']``), used as column names
        and ``Group`` values in the sheets. Some inputs are positional and
        assume exactly the order HM, HF, MSM: the three age-mixing sheets, the
        condom-use columns and ``sex_act_calibration_param``.

    Returns
    -------
    dict
        - ``q``: (22, 22) transition template.
        - ``age_mixing_final_mat``: (n_risk, n_age, n_age).
        - ``pi``: (20,).
        - (n_risk, n_age): ``number_of_sex_acts_risk_age``, ``prop_anal_acts``,
          ``prop_condom_use_risk_age_main``, ``prop_condom_use_risk_age_casual``.
        - (n_risk,): ``prop_casual_partner_risk_casual[_only]``,
          ``num_partner_risk_casual[_only]``, ``trans_prob_v_acts``,
          ``trans_prob_a_acts``.
        - ``testing_mult_fac_risk``: (n_risk, 5).
        - Raw sheet arrays: ``condom_efficiency``, ``num_cas_part_main_cas``,
          ``sex_mixing``.
        - ``age_mixing_diagonals``: DataFrame.

    Raises
    ------
    ValueError
        If a risk group has no sex-act columns defined, or if a per-age or
        per-group lookup (sex acts, transmission probabilities) does not match
        exactly one row.
    """
    num_risk_groups = len(risk_groups)
    num_age_groups = age_end - age_begin + 1

    # Parse the workbook once; every sheet below is read from this handle.
    with pd.ExcelFile('input_new_infections.xlsx') as excel:

        # Transition-rate template between the 22 compartments. q_matrix()
        # later replaces its placeholder codes with computed rates.
        q = pd.read_excel(excel, sheet_name='qmat').values.reshape(22, 22)

        # Banded age-mixing tables expanded to single-year ages, stacked in
        # the fixed order HM, HF, MSM.
        A_hm = generate_age_mixing_mat(pd.read_excel(excel, sheet_name='A_hm'), age_begin, age_end)
        A_hf = generate_age_mixing_mat(pd.read_excel(excel, sheet_name='A_hf'), age_begin, age_end)
        A_msm = generate_age_mixing_mat(pd.read_excel(excel, sheet_name='A_msm'), age_begin, age_end)
        age_mixing_final_mat = np.vstack((A_hm, A_hf, A_msm)).reshape(
            num_risk_groups, num_age_groups, num_age_groups)

        # One value per infected compartment (1..20).
        pi = pd.read_excel(excel, sheet_name='pi').values.reshape(20)

        # Sex acts per age are computed as lower + (upper - lower) / param,
        # with one param per risk position. The huge values at positions 1
        # and 2 make those groups effectively use the lower bound.
        num_sex_acts = pd.read_excel(excel, sheet_name='num_sex_acts')  # female upper, female lower, male upper, male lower
        sex_act_calibration_param = [0.44, 90000000000000000, 90000000000000000]
        number_of_sex_acts_risk_age = np.zeros((num_risk_groups, num_age_groups))
        for risk, current_risk_group in enumerate(risk_groups):
            if current_risk_group in ("HM", "MSM", "MSMIDU", "IDUM"):
                upper_col, lower_col = 'Male_Upper', 'Male_Lower'
            elif current_risk_group in ("HF", "IDUF"):
                upper_col, lower_col = 'Female_Upper', 'Female_Lower'
            else:
                # Without this, an unknown group would reuse stale bounds.
                raise ValueError(f"No sex-act columns defined for risk group {current_risk_group!r}")
            for age_index in range(num_age_groups):
                rows = num_sex_acts[num_sex_acts['Age_group'] == age_index + age_begin]
                upper = rows[upper_col].values
                lower = rows[lower_col].values
                # .item() requires exactly one matching row and avoids NumPy's
                # deprecated implicit size-1-array-to-scalar conversion.
                number_of_sex_acts_risk_age[risk, age_index] = (
                    lower + ((upper - lower) / sex_act_calibration_param[risk])).item()

        condom_efficiency = pd.read_excel(excel, sheet_name='condom_efficiency').values

        # Proportion of acts that are anal, by age (one column per risk group).
        prop_acts_pd = pd.read_excel(excel, sheet_name='prop_anal_acts')
        prop_anal_acts = np.zeros((num_risk_groups, num_age_groups))
        for risk, current_risk_group in enumerate(risk_groups):
            prop_anal_acts[risk] = prop_acts_pd[current_risk_group].values

        # Sheet columns 1-2: [0] prob_casual, [1] prob_casual_only.
        prop_casual_partner_v = pd.read_excel(excel, sheet_name='prop_casual_partner')
        prop_casual_partner_risk = np.zeros((num_risk_groups, 2))
        for risk, current_risk_group in enumerate(risk_groups):
            prop_casual_partner_risk[risk] = prop_casual_partner_v[
                prop_casual_partner_v['Group'] == current_risk_group].values[:, 1:3]
        prop_casual_partner_risk_casual = prop_casual_partner_risk[:, 0]
        prop_casual_partner_risk_casual_only = prop_casual_partner_risk[:, 1]

        # Sheet columns 1-2: [0] casual-only, [1] casual. This is the opposite
        # order to prop_casual_partner.
        num_partner = pd.read_excel(excel, sheet_name='num_partner')
        num_partner_risk = np.zeros((num_risk_groups, 2))
        for risk, current_risk_group in enumerate(risk_groups):
            num_partner_risk[risk] = num_partner[
                num_partner['Group'] == current_risk_group].values[:, 1:3]
        num_partner_risk_casual_only = num_partner_risk[:, 0]
        num_partner_risk_casual = num_partner_risk[:, 1]

        num_cas_part_main_cas = pd.read_excel(excel, sheet_name='num_cas_part_main-cas').values

        # Condom use by age. Sheet columns are positional: main partners in
        # columns 3 (HM), 1 (HF), 5 (MSM); casual partners in 4, 2, 6.
        prop_condom_use = pd.read_excel(excel, sheet_name='prop_condom_use').values
        prop_condom_use_risk_age_casual = np.zeros((num_risk_groups, num_age_groups))
        prop_condom_use_risk_age_main = np.zeros((num_risk_groups, num_age_groups))
        for risk, (main_col, casual_col) in enumerate(((3, 4), (1, 2), (5, 6))):
            prop_condom_use_risk_age_main[risk] = prop_condom_use[:, main_col]
            prop_condom_use_risk_age_casual[risk] = prop_condom_use[:, casual_col]

        # Per-act transmission probabilities by risk group.
        trans_prob = pd.read_excel(excel, sheet_name='trans_prob')
        trans_prob_v_acts = np.zeros(num_risk_groups)
        trans_prob_a_acts = np.zeros(num_risk_groups)
        for risk, current_risk_group in enumerate(risk_groups):
            group_row = trans_prob[trans_prob['Group'] == current_risk_group]
            trans_prob_v_acts[risk] = group_row.vaginal.values.item()
            trans_prob_a_acts[risk] = group_row.anal.values.item()

        sex_mixing = pd.read_excel(excel, sheet_name='sex_mixing').values
        age_mixing_diag = pd.read_excel(excel, sheet_name='age_mix_diagonals')

    # Testing multipliers: five values per risk group (applied elsewhere to
    # the five unaware compartments).
    # Alternative input: 'input_estimating_unknown_rates_PATH.xlsx'.
    testing_mult_fac = pd.read_excel('input_estimating_unknown_rates_HOPE.xlsx',
                                     sheet_name='testing_mult_fac')
    testing_mult_fac_risk = np.zeros((num_risk_groups, 5))
    for risk, current_risk_group in enumerate(risk_groups):
        testing_mult_fac_risk[risk] = testing_mult_fac[current_risk_group]

    new_infections_data = {
        "q": q,
        "age_mixing_final_mat": age_mixing_final_mat,
        "pi": pi,
        "number_of_sex_acts_risk_age": number_of_sex_acts_risk_age,
        "condom_efficiency": condom_efficiency,
        "prop_anal_acts": prop_anal_acts,
        "prop_casual_partner_risk_casual": prop_casual_partner_risk_casual,
        "prop_casual_partner_risk_casual_only": prop_casual_partner_risk_casual_only,
        "num_partner_risk_casual": num_partner_risk_casual,
        "num_partner_risk_casual_only": num_partner_risk_casual_only,
        "num_cas_part_main_cas": num_cas_part_main_cas,
        "prop_condom_use_risk_age_main": prop_condom_use_risk_age_main,
        "prop_condom_use_risk_age_casual": prop_condom_use_risk_age_casual,
        "trans_prob_v_acts": trans_prob_v_acts,
        "trans_prob_a_acts": trans_prob_a_acts,
        "sex_mixing": sex_mixing,
        "testing_mult_fac_risk": testing_mult_fac_risk,
        "age_mixing_diagonals": age_mixing_diag}

    return new_infections_data

In [None]:
def read_death_rates(age_begin, age_end, risk_groups):
    """Load annual death probabilities from ``input_estimating_unknown_rates_MOD_v2.xlsx``.

    Parameters
    ----------
    age_begin, age_end : int
        First and last single-year ages of the model (inclusive).
    risk_groups : list of str
        Risk-group labels; each must be a column of the 'death_prob_uninf' sheet.

    Returns
    -------
    dict
        ``death_rate_uninf_risk_age`` : (n_risk, n_age) array for uninfected people.
        ``death_rate_inf_no_ART_*`` : arrays (one value per matching row,
            normally 1) for untreated PLWH by CD4 category: acute, >500,
            350-500, 200-350, <200.
        ``death_rate_inf_ART_*_age`` : per-age arrays for PLWH on ART
            (CD4 <200, 200-350, >350).
        ``death_rate_inf`` and ``death_rate_[ab]200_age_{2010,2016}`` : always 0.
            These are placeholders from the older HOPE/PATH input layout, kept so
            the dict keeps its original keys.
    """
    num_risk_groups = len(risk_groups)
    num_age_groups = age_end - age_begin + 1

    # Earlier versions read 'input_estimating_unknown_rates_PATH.xlsx' or
    # 'input_estimating_unknown_rates_HOPE.xlsx', which use a different sheet layout.
    excel = 'input_estimating_unknown_rates_MOD_v2.xlsx'

    with pd.ExcelFile(excel) as workbook:
        death_rate_uninf = pd.read_excel(workbook, sheet_name='death_prob_uninf')
        death_rate_inf_no_art = pd.read_excel(workbook, sheet_name='death_prob_inf_no_art')
        death_prob_inf_art = pd.read_excel(workbook, sheet_name='death_prob_inf_art')

    death_rate_uninf_risk_age = np.zeros((num_risk_groups, num_age_groups))
    for risk, current_risk_group in enumerate(risk_groups):
        death_rate_uninf_risk_age[risk] = death_rate_uninf[current_risk_group].values

    def no_art_rate(cd4_category):
        # 'death_rate' value(s) of the untreated-sheet rows in this CD4 category.
        rows = death_rate_inf_no_art['CD4_category'] == cd4_category
        return death_rate_inf_no_art[rows]["death_rate"].values

    death_prob_data = {
        "death_rate_uninf_risk_age": death_rate_uninf_risk_age,
        "death_rate_inf": 0,
        "death_rate_a200_age_2010": 0,
        "death_rate_a200_age_2016": 0,
        "death_rate_b200_age_2010": 0,
        "death_rate_b200_age_2016": 0,
        "death_rate_inf_no_ART_acute": no_art_rate("Acute"),
        "death_rate_inf_no_ART_above_500": no_art_rate("CD4 >500"),
        "death_rate_inf_no_ART_above_350_500": no_art_rate("CD4 350-500"),
        "death_rate_inf_no_ART_above_200_350": no_art_rate("CD4 200-350"),
        "death_rate_inf_no_ART_below_200": no_art_rate("CD4 <200"),
        "death_rate_inf_ART_below_200_age": death_prob_inf_art["CD4_b_200"].values,
        "death_rate_inf_ART_200_350_age": death_prob_inf_art["CD4_200_350"].values,
        "death_rate_inf_ART_above_350_age": death_prob_inf_art["CD4_a_350"].values}

    return death_prob_data

In [None]:
def calculate_deaths_vector(number_of_compartments, 
                            risk_groups, 
                            group, 
                            age, 
                            death_prob_data):
    """Annual death rates for every compartment of one (risk group, age) cell.

    Annual death probabilities p are converted to continuous annual rates
    ``-log(1 - p)``. Multiply by ``dt`` to get rates for a monthly step.

    Compartment layout (row of the returned column):

    - 0: susceptible, ``death_rate_uninf_risk_age[group, age]``.
    - 1..20: five CD4 strata of four compartments each (acute, >500,
      350-500, 200-350, <200). In each stratum the first two compartments
      use the untreated (no-ART) rate; the last two use the age-specific
      on-ART rate.
    - 21: 0.

    Parameters
    ----------
    number_of_compartments : int
        Length of the returned column; must be at least 22.
    risk_groups : sequence of str
        Not used; kept for interface compatibility.
    group : int
        Risk-group index (row of ``death_rate_uninf_risk_age``).
    age : int
        Age index (0 = youngest modelled age).
    death_prob_data : dict
        As returned by ``read_death_rates``.

    Returns
    -------
    numpy.ndarray
        Column vector of shape ``(number_of_compartments, 1)``.
    """
    # (untreated key, on-ART per-age key) for each CD4 stratum, in compartment order.
    strata = (
        ("death_rate_inf_no_ART_acute", "death_rate_inf_ART_above_350_age"),
        ("death_rate_inf_no_ART_above_500", "death_rate_inf_ART_above_350_age"),
        ("death_rate_inf_no_ART_above_350_500", "death_rate_inf_ART_above_350_age"),
        ("death_rate_inf_no_ART_above_200_350", "death_rate_inf_ART_200_350_age"),
        ("death_rate_inf_no_ART_below_200", "death_rate_inf_ART_below_200_age"),
    )

    death_col_annual = np.zeros((number_of_compartments, 1))
    death_col_annual[0, 0] = death_prob_data["death_rate_uninf_risk_age"][group, age]

    for stratum, (no_art_key, art_key) in enumerate(strata):
        first = 1 + 4 * stratum
        death_col_annual[first:first + 2, 0] = death_prob_data[no_art_key]
        death_col_annual[first + 2:first + 4, 0] = death_prob_data[art_key][age]

    death_col_annual[21, 0] = 0

    # Probability -> rate under a constant hazard.
    return -np.log(1 - death_col_annual)

In [None]:
def ltc_prep_values(ltc_excel, jur_list):
    """Read the linked-to-care (LTC) value of each jurisdiction and risk group.

    Parameters
    ----------
    ltc_excel : str or path-like
        Workbook with sheets 'jur_specific_care_cont_hm', 'jur_specific_care_cont_hf'
        and 'jur_specific_care_cont_msm', each having 'FIPS' and 'LTC' columns.
    jur_list : sequence of int
        FIPS codes; the output rows follow this order.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(jur_list), number_of_risk_groups)``; columns are HM, HF, MSM.

    Raises
    ------
    IndexError
        If a FIPS code is missing from any of the three sheets.
    """
    sheet_names = ['jur_specific_care_cont_hm',
                   'jur_specific_care_cont_hf',
                   'jur_specific_care_cont_msm']
    sheets = pd.read_excel(ltc_excel, sheet_name=sheet_names)  # dict: sheet name -> DataFrame

    ltc_risk = np.zeros((len(jur_list), number_of_risk_groups))
    for row, fips in enumerate(jur_list):
        # First matching row of each sheet, in HM, HF, MSM order.
        ltc_risk[row] = np.array([sheets[name].loc[sheets[name]['FIPS'] == fips, 'LTC'].values[0]
                                  for name in sheet_names])
    return ltc_risk

In [None]:
def M_x1_y1_value(new_infections_data):
    """Per-partnership probability of *not* acquiring HIV during one time step.

    For a susceptible in risk group x1 and age y1 whose partner is in infected
    compartment i, the non-transmission probability is the product of four
    terms. There is one term per act type (vaginal / anal) and protection
    status (protected / unprotected):

        (1 - p * pi_i) ** (acts of that kind per partner)

    Here p is the per-act transmission probability, reduced by the condom
    efficiency for protected acts, and pi_i is the per-compartment factor
    from the 'pi' sheet.

    Parameters
    ----------
    new_infections_data : dict
        As returned by ``read_new_inf_input``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_risk, n_age, number_of_compartments - 2)``, indexed by the
        susceptible's risk group, the susceptible's age, and the partner's
        infected compartment (1..20 mapped to 0..19).
        ``new_infections_per_month`` raises it to the number of partners in
        each compartment.
    """
    n_infected = number_of_compartments - 2

    # Additional condom use by partner compartment offset within each
    # 4-compartment stratum: 0 for offset 0 (unaware), 0.53 otherwise.
    # NOTE(review): origin of 0.53 is not documented here; confirm.
    condom_awareness = np.array([0, 0.53, 0.53, 0.53])
    c = np.tile(condom_awareness, 5).reshape(n_infected)
    aware = c[np.newaxis, np.newaxis, :]

    prob_condom_efficency = new_infections_data["condom_efficiency"][0, 0]

    pi = new_infections_data["pi"].reshape(n_infected)[np.newaxis, :]
    p_v_x1 = new_infections_data["trans_prob_v_acts"][:, np.newaxis]
    p_a_x1 = new_infections_data["trans_prob_a_acts"][:, np.newaxis]
    # Per-act transmission probabilities when a condom is used.
    p_bar_v_x1 = (1 - prob_condom_efficency) * p_v_x1
    p_bar_a_x1 = (1 - prob_condom_efficency) * p_a_x1

    num_sex_acts = new_infections_data["number_of_sex_acts_risk_age"]
    prob_anal_acts = new_infections_data["prop_anal_acts"]
    n_v_x1_y1 = num_sex_acts * (1 - prob_anal_acts)
    n_a_x1_y1 = num_sex_acts * prob_anal_acts

    # Presumably the share of acts with casual (nc) vs main (nm) partners
    # among people who have both partner types.
    num_cas_part_main_cas = new_infections_data["num_cas_part_main_cas"][0:num_risk]
    nc = (num_cas_part_main_cas * 2) / num_sex_acts
    nm = 1 - nc

    prob_casual_only = new_infections_data["prop_casual_partner_risk_casual_only"][:, np.newaxis]
    prob_casual = new_infections_data["prop_casual_partner_risk_casual"][:, np.newaxis]
    prob_condom_casual = new_infections_data["prop_condom_use_risk_age_casual"]
    prob_condom_main = new_infections_data["prop_condom_use_risk_age_main"]

    # Overall proportion of condom-protected acts, weighted over partnership types.
    prop_condom_use_calc = ((prob_casual_only * prob_condom_casual) +                          # casual only
                            (abs(prob_casual - prob_casual_only) * prob_condom_casual * nc) +  # casual among casual+main
                            (abs(prob_casual - prob_casual_only) * prob_condom_main * nm) +    # main among casual+main
                            ((1 - prob_casual) * prob_condom_main))                            # main only

    num_partners_tot = (new_infections_data["num_partner_risk_casual"]
                        + new_infections_data["num_partner_risk_casual_only"])
    # dt cancels between numerator and denominator: exponents are acts per partner.
    partners_dt = (num_partners_tot * dt)[:, np.newaxis, np.newaxis]

    condom_used = prop_condom_use_calc[:, :, np.newaxis]
    no_condom = (1 - prop_condom_use_calc)[:, :, np.newaxis]

    def _no_transmission(p_per_act, exponent):
        # (1 - p * pi) broadcast over age, raised to the acts-per-partner exponent.
        return (1 - (p_per_act * pi))[:, np.newaxis, :] ** exponent

    # Vaginal acts: protected (condom, or awareness-driven condom use) / unprotected.
    AA = _no_transmission(p_bar_v_x1,
                          ((n_v_x1_y1 * dt)[:, :, np.newaxis] * ((no_condom * aware) + condom_used)) / partners_dt)
    BB = _no_transmission(p_v_x1,
                          ((n_v_x1_y1 * dt)[:, :, np.newaxis] * no_condom * (1 - aware)) / partners_dt)
    # Anal acts: protected / unprotected.
    CC = _no_transmission(p_bar_a_x1,
                          ((n_a_x1_y1 * dt)[:, :, np.newaxis] * ((no_condom * aware) + condom_used)) / partners_dt)
    DD = _no_transmission(p_a_x1,
                          ((n_a_x1_y1 * dt)[:, :, np.newaxis] * no_condom * (1 - aware)) / partners_dt)

    # Probability of no transmission across all act types. Computed directly;
    # the former 1 - (1 - x) equals x and only added rounding error.
    return AA * BB * CC * DD

In [None]:
def new_infections_per_month(num_jur, data_array, new_infections_data, M_x1_y1_i, prep_risk):
    """Expected new HIV infections in one time step per jurisdiction, risk group and age.

    Parameters
    ----------
    num_jur : int
        Number of jurisdictions; must equal ``data_array.shape[0]``.
    data_array : numpy.ndarray
        Shape ``(num_jur, num_risk, num_age, number_of_compartments)``.
        Compartment 0 is susceptible and 1..20 are infected. Compartment 21 is
        not part of the population denominator.
    new_infections_data : dict
        As returned by ``read_new_inf_input``.
    M_x1_y1_i : numpy.ndarray
        Shape ``(num_risk, num_age, num_comp)``: per-partnership probability of
        not acquiring infection from a partner in each infected compartment
        (see ``M_x1_y1_value``).
    prep_risk : numpy.ndarray
        Shape ``(num_jur, num_risk)``: fraction of susceptibles protected by PrEP.

    Returns
    -------
    numpy.ndarray
        Shape ``(num_jur, num_risk, num_age)``.

    Raises
    ------
    ValueError
        If ``num_jur`` does not match ``data_array``.
    """
    if data_array.shape[0] != num_jur:
        raise ValueError(f"data_array has {data_array.shape[0]} jurisdictions, expected {num_jur}")

    risk_mat = np.asarray(new_infections_data["sex_mixing"][0:num_risk, 0:num_risk], dtype=float)
    age_mat = new_infections_data["age_mixing_final_mat"]
    partners = (new_infections_data["num_partner_risk_casual"]
                + new_infections_data["num_partner_risk_casual_only"])

    susceptible = data_array[:, :, :, 0]
    infected = data_array[:, :, :, 1:21]
    population = np.sum(data_array[:, :, :, 0:21], axis=3)

    # Share of each (jurisdiction, risk s, age b) population in each infected compartment.
    infected_share = infected / population[:, :, :, np.newaxis]

    # Mix over partner risk group s (sex_mixing) and partner age b (age mixing of
    # the acquiring group r), without materialising the 6-D product array:
    #   Q[j, r, a, c] = sum_{s, b} risk_mat[r, s] * age_mat[r, a, b] * infected_share[j, s, b, c]
    Q_inner = np.einsum('rs,rab,jsbc->jrac', risk_mat, age_mat, infected_share, optimize=True)
    q_x_y_i = partners[np.newaxis, :, np.newaxis, np.newaxis] * dt * Q_inner

    # Geographic mixing between jurisdictions, per risk group:
    #   q_mix[j, r, a, c] = sum_k mixing_matrix[r, j, k] * q_x_y_i[k, r, a, c]
    q_mix_sum = np.einsum('rjk,krac->jrac', mixing_matrix, q_x_y_i, optimize=True)

    # Probability of escaping infection from every partner, then its complement.
    M_power = M_x1_y1_i[np.newaxis, :, :, :] ** q_mix_sum
    infection_prob = 1 - np.prod(M_power, axis=3)

    return susceptible * (1 - prep_risk[:, :, np.newaxis]) * infection_prob

In [None]:
def calculate_proportions(data_array, num_jur, number_of_risk_groups, unaware_index, aware_no_care_index, ART_VLS_index,VLS_index):
    """Care-continuum proportions per (jurisdiction, risk group).

    Parameters
    ----------
    data_array : numpy.ndarray
        Population state indexed (jurisdiction, risk, age, compartment);
        compartments 1..20 are treated as PLWH and 0..20 as the population.
    num_jur, number_of_risk_groups : int
        Sizes of the first two axes of the outputs.
    unaware_index, aware_no_care_index, ART_VLS_index, VLS_index : tuple of int
        Compartment indices of each care-continuum stage.

    Returns
    -------
    tuple
        ``(total_pop, prevalence_prop, unaware_prop, aware_no_art_prop,
        aware_art_vls_prop, vls_prop)``:

        - ``total_pop``: shape (num_jur, 1), summed over all risk groups.
        - ``prevalence_prop``: PLWH of each risk group / total jurisdiction
          population (all risk groups).
        - The remaining four are stage counts / PLWH of the same risk group.
        - All but ``vls_prop`` are rounded to 6 decimals.
          NOTE(review): vls_prop is not rounded; confirm this is intentional.
    """
    
    plwh_risk = np.zeros((num_jur, number_of_risk_groups))
    unaware_risk = np.zeros((num_jur, number_of_risk_groups))
    aware_no_art_risk = np.zeros((num_jur, number_of_risk_groups))
    aware_art_vls_risk = np.zeros((num_jur, number_of_risk_groups))
    vls_risk = np.zeros((num_jur, number_of_risk_groups))
    
    for risk in range(number_of_risk_groups):
        # Basic slicing: (num_jur, num_age, 20) -> sum over age and compartment.
        plwh_risk[:,risk] = np.apply_over_axes(np.sum, data_array[:,risk,:,1:21], [1,2]).reshape(num_jur,)
        # With a tuple index, the scalar `risk` and the tuple are both advanced
        # indices separated by a slice. NumPy therefore moves their broadcast
        # dimension to the FRONT: the slice is (len(idx), num_jur, num_age).
        # Summing axes [0, 2] leaves the jurisdiction axis.
        unaware_risk[:,risk] = np.apply_over_axes(np.sum, data_array[:,risk,:,unaware_index], [0,2]).reshape(num_jur,)
        aware_no_art_risk[:,risk] = np.apply_over_axes(np.sum, data_array[:,risk,:,aware_no_care_index], [0,2]).reshape(num_jur,)
        aware_art_vls_risk[:,risk] = np.apply_over_axes(np.sum, data_array[:,risk,:,ART_VLS_index], [0,2]).reshape(num_jur,)
        vls_risk[:,risk] = np.apply_over_axes(np.sum, data_array[:,risk,:,VLS_index], [0,2]).reshape(num_jur,)
        
    
    # Whole-jurisdiction population (compartment 21 excluded), kept 2-D so it
    # broadcasts across the risk-group columns below.
    total_pop = np.apply_over_axes(np.sum, data_array[:,:,:,0:21], [1,2,3]).reshape(num_jur,1)
    
    prevalence_prop = plwh_risk/total_pop
    
    # Division by plwh_risk yields nan where a risk group has no PLWH.
    unaware_prop = unaware_risk/plwh_risk
    aware_no_art_prop = aware_no_art_risk/plwh_risk
    aware_art_vls_prop = aware_art_vls_risk/plwh_risk
    vls_prop = vls_risk/plwh_risk
    
    return total_pop, np.round(prevalence_prop, 6), np.round(unaware_prop, 6), \
                np.round(aware_no_art_prop, 6), np.round(aware_art_vls_prop, 6), vls_prop

In [None]:
def diagnosis_rate(data_array, num_jur, a_unaware, unaware_index, number_of_risk_groups, new_inf_per_month, unaware_prop, death_per_month_risk_age_compartments, testing_mult_fac_risk=None):
    """Back-calculate the per-step diagnosis rate that moves the unaware share to its target.

    For each jurisdiction and risk group, the diagnoses needed this step are

        new infections + current unaware - deaths among unaware - target unaware next step.

    The target is (PLWH + new infections - PLWH deaths) * (unaware_prop + a_unaware * dt).
    Dividing by the testing-weighted unaware population gives the rate.
    Negative rates are clipped to 0.

    Parameters
    ----------
    data_array : numpy.ndarray
        Population state (jurisdiction, risk, age, compartment).
    num_jur, number_of_risk_groups : int
        Output shape.
    a_unaware : numpy.ndarray
        Shape (num_jur, number_of_risk_groups): annual change of the unaware proportion.
    unaware_index : tuple of int
        The five unaware compartments.
    new_inf_per_month : numpy.ndarray
        Shape (num_jur, number_of_risk_groups, num_age).
    unaware_prop : numpy.ndarray
        Shape (num_jur, number_of_risk_groups): current unaware share of PLWH.
    death_per_month_risk_age_compartments : numpy.ndarray
        Deaths this step, same layout as ``data_array``.
    testing_mult_fac_risk : numpy.ndarray, optional
        Shape (number_of_risk_groups, 5): testing multipliers of the unaware
        compartments. Defaults to the notebook-level
        ``new_infections_data["testing_mult_fac_risk"]``.

    Returns
    -------
    numpy.ndarray
        Shape (num_jur, number_of_risk_groups). Entries are 0 where no one is
        unaware (previously nan/inf).
    """
    if testing_mult_fac_risk is None:
        testing_mult_fac_risk = new_infections_data["testing_mult_fac_risk"]

    diagnosis_rate_risk = np.zeros((num_jur, number_of_risk_groups))

    # Annual change in the unaware proportion, converted to one time step.
    a_unaware_t = np.round(a_unaware*dt, 6)
    # New infections per (jurisdiction, risk group), summed over age.
    new_inf_by_risk = np.sum(new_inf_per_month, axis=2)

    for risk in range(number_of_risk_groups):
        # data_array[:, risk, :, unaware_index] puts the compartment axis first:
        # (5, num_jur, num_age). Sums over axes [0, 2] keep the jurisdiction axis.

        # A: new infections this step
        A = new_inf_by_risk[:, risk]

        # B: people currently unaware of their infection
        B = np.apply_over_axes(np.sum, data_array[:,risk,:,unaware_index], [0,2]).reshape(num_jur,)

        # C: target number unaware next step =
        #    (current PLWH + new infections - PLWH deaths) * target unaware share
        plwh = np.apply_over_axes(np.sum, data_array[:,risk,:,1:21], [1,2]).reshape(num_jur,)
        plwh_deaths = np.apply_over_axes(np.sum, death_per_month_risk_age_compartments[:,risk,:,1:21], [1,2]).reshape(num_jur,)
        C = (plwh + A - plwh_deaths) * (unaware_prop[:,risk] + a_unaware_t[:,risk])

        # D: deaths among the unaware
        D = np.apply_over_axes(np.sum, death_per_month_risk_age_compartments[:,risk,:,unaware_index], [0,2]).reshape(num_jur,)

        # E: unaware population weighted by each compartment's testing multiplier
        E = np.sum(np.sum(data_array[:,risk,:,unaware_index], axis=2)
                   * testing_mult_fac_risk[risk].reshape(5,1), axis=0)

        # Where nobody can be diagnosed (E == 0) use 0 instead of nan/inf,
        # which would otherwise propagate into the transition matrices.
        needed = A + B - C - D
        with np.errstate(divide='ignore', invalid='ignore'):
            rate = np.where(E != 0, needed / E, 0.0)
        rate[rate < 0] = 0
        diagnosis_rate_risk[:,risk] = rate

    return diagnosis_rate_risk

In [None]:
def dropout_rate(num_jur, a_art, ART_VLS_index, diagnosis_rate_risk, ltc_risk, gamma, number_of_risk_groups, data_array, new_inf_per_month, unaware_prop, aware_no_art_prop, aware_art_vls_prop, death_per_month_risk_age_compartments, testing_mult_fac_risk=None):
    """Back-calculate the per-step ART dropout rate that moves the on-ART share to its target.

    For each jurisdiction and risk group:

        dropout = (F + G + H - I - J) / K

    where
        F = people on ART (incl. VLS) now,
        G = newly diagnosed people linked to care directly,
        H = aware-not-in-care people entering care (rate gamma),
        I = deaths among people on ART,
        J = target on-ART count next step
            = (PLWH + new infections - PLWH deaths) * (aware_art_vls_prop + a_art * dt),
        K = on-ART population weighted by ``scaling_factor_dropout``.

    Negative rates are clipped to 0, and the rate is 0 where K == 0
    (previously nan/inf).

    The function also reads the notebook-level ``unaware_index``,
    ``aware_no_care_index``, ``scaling_factor_dropout`` and ``dt``.
    ``unaware_prop`` and ``aware_no_art_prop`` are not used; they are kept
    for interface compatibility.

    Parameters
    ----------
    testing_mult_fac_risk : numpy.ndarray, optional
        Shape (number_of_risk_groups, 5). Defaults to the notebook-level
        ``new_infections_data["testing_mult_fac_risk"]``.

    Other parameters have the same layout as in ``diagnosis_rate``. ``gamma``
    and ``a_art`` are annual values converted to one step with ``dt``.

    Returns
    -------
    numpy.ndarray
        Shape (num_jur, number_of_risk_groups).
    """
    if testing_mult_fac_risk is None:
        testing_mult_fac_risk = new_infections_data["testing_mult_fac_risk"]

    dropout_rate_risk = np.zeros((num_jur, number_of_risk_groups))

    a_art_t = np.round(a_art *dt, 6)
    gamma_t = np.round(gamma *dt, 6)
    new_inf_by_risk = np.sum(new_inf_per_month, axis=2)

    for risk in range(number_of_risk_groups):
        # Tuple-indexed slices put the compartment axis first: (n_idx, num_jur, num_age).

        # F: people currently on ART (incl. VLS)
        F = np.apply_over_axes(np.sum, data_array[:,risk,:,ART_VLS_index], [0,2]).reshape(num_jur,)

        # K: on-ART population weighted by the per-compartment dropout scaling factors
        K = np.sum(np.sum(data_array[:,risk,:,ART_VLS_index], axis=2)
                   * scaling_factor_dropout[risk].reshape(10,1), axis=0)

        # G: diagnosed this step and linked to care
        G = diagnosis_rate_risk[:,risk]*ltc_risk[:,risk]*np.sum(
            np.sum(data_array[:,risk,:,unaware_index], axis=2)
            * testing_mult_fac_risk[risk].reshape(5,1), axis=0)

        # H: aware-but-not-in-care people entering care
        H = np.sum(gamma_t[risk].reshape(5,1)*np.sum(data_array[:,risk,:,aware_no_care_index], axis=2), axis=0)

        # I: deaths among people on ART (incl. VLS)
        I = np.apply_over_axes(np.sum, death_per_month_risk_age_compartments[:,risk,:,ART_VLS_index], [0,2]).reshape(num_jur,)

        # J: target on-ART count next step
        plwh = np.apply_over_axes(np.sum, data_array[:,risk,:,1:21], [1,2]).reshape(num_jur,)
        plwh_deaths = np.apply_over_axes(np.sum, death_per_month_risk_age_compartments[:,risk,:,1:21], [1,2]).reshape(num_jur,)
        J = (plwh + new_inf_by_risk[:,risk] - plwh_deaths) * (aware_art_vls_prop[:,risk] + a_art_t[:,risk])

        needed = F + G + H - I - J
        with np.errstate(divide='ignore', invalid='ignore'):
            rate = np.where(K != 0, needed / K, 0.0)
        rate[rate < 0] = 0
        dropout_rate_risk[:,risk] = rate

    return dropout_rate_risk

In [None]:
def q_matrix(num_jur, new_infections_data, diagnosis_rate_risk, dropout_rate_risk, ltc_risk):
    
    Q_MAT = new_infections_data['q']
    Q_matrix = np.zeros((num_jur, number_of_risk_groups, number_of_compartments, number_of_compartments))
    
    for jur in range(num_jur):
        for risk in range(number_of_risk_groups):
            Q_mat = Q_MAT.copy()
            Q_mat[np.where(Q_mat == 12345)] = (1 - ltc_risk[jur,risk]) * diagnosis_rate_risk[jur,risk]*new_infections_data["testing_mult_fac_risk"][risk]
            Q_mat[np.where(Q_mat == 123456)] = dropout_rate_risk[jur,risk]
            Q_mat[np.where(Q_mat == 1234567)] = ltc_risk[jur,risk] * diagnosis_rate_risk[jur,risk]*new_infections_data["testing_mult_fac_risk"][risk]

            Q_matrix[jur,risk] = Q_mat
            
        
    return Q_matrix  

In [None]:
def q_mat_diag(Q_matrix, num_jur):
    """Diagonal matrices holding the row sums of each (jurisdiction, risk) Q matrix.

    step() subtracts ``data @ Q_matrix_diagonal`` from ``data @ Q_matrix``, so each
    compartment loses exactly the total of its row of rates.

    Returns:
        array of shape (num_jur, number_of_risk_groups, number_of_compartments,
        number_of_compartments) with ``sum_j Q[i, j]`` at position (i, i) of each block.
    """
    diagonals = np.zeros((num_jur, number_of_risk_groups, number_of_compartments, number_of_compartments))
    for jur, risk in np.ndindex(num_jur, number_of_risk_groups):
        row_totals = Q_matrix[jur, risk].sum(axis=1)
        diagonals[jur, risk] = np.diag(row_totals)
    return diagonals

In [None]:
def aging(data_array, pop_susceptible_12_years):
    """Move every cohort up one age group.

    Age group ``a`` becomes ``a + 1``. The last group (index ``num_age - 1``) is
    dropped. Age group 0 is refilled with ``pop_susceptible_12_years`` in
    compartment 0 only; all other age-0 compartments start at zero.
    """
    aged = np.zeros((num_jur, number_of_risk_groups, num_age, number_of_compartments))
    aged[:, :, 0, 0] = pop_susceptible_12_years
    aged[:, :, 1:, :] = data_array[:, :, :num_age - 1, :]
    return aged

In [None]:
# Model inputs read from the Excel workbooks.
ltc_excel = 'CareContinuum-by_jur 7-26-2021 HT V2.xlsx'

# Linkage-to-care rate per jurisdiction and risk group.
ltc_risk = ltc_prep_values(ltc_excel, all_jurisdictions)

# New-infection inputs (q_matrix() reads its 'q' template and "testing_mult_fac_risk" from here).
new_infections_data = read_new_inf_input(age_begin, age_end, risk_groups)

M_x1_y1_i = M_x1_y1_value(new_infections_data)

death_prob_data = read_death_rates(age_begin, age_end, risk_groups)

# Death rates per (risk, age, compartment). They do not depend on the jurisdiction,
# so each vector is broadcast along the jurisdiction axis.
death_rate_risk_age_compartments = np.zeros((num_jur, number_of_risk_groups, age_groups, number_of_compartments))
for risk, age in np.ndindex(len(risk_groups), age_groups):
    death_rate_risk_age_compartments[:, risk, age, :] = calculate_deaths_vector(number_of_compartments, risk_groups, risk, age, death_prob_data).reshape(22)

In [None]:
def _cluster_state(data_array, prep_values, cluster_index):
    """Observation matrix for one cluster of jurisdictions (see extract_state)."""
    cluster_data = data_array[cluster_index, :, :, :]

    # Pool the cluster's jurisdictions, keeping a leading jurisdiction axis of length 1.
    pooled = np.sum(cluster_data, axis=0)[np.newaxis, :, :, :]
    _, prevalence, unaware, aware_no_art, aware_art_vls, _ = calculate_proportions(
        pooled, 1, number_of_risk_groups, unaware_index, aware_no_care_index, ART_VLS_index, VLS_index)

    # PrEP coverage among susceptible MSM (risk 2, compartment 0), weighted by each
    # jurisdiction's susceptible-MSM population. The other risk groups get 0.
    susceptible_msm = cluster_data[:, 2, :, 0]
    covered = susceptible_msm * prep_values[cluster_index, 2][:, np.newaxis]
    numerator = np.apply_over_axes(np.sum, covered, [0, 1]).item()
    denominator = np.apply_over_axes(np.sum, susceptible_msm, [0, 1]).item()
    # .item() yields Python floats, so a zero denominator would raise ZeroDivisionError.
    prep = np.round(numerator / denominator, 4) if denominator else 0.0
    prep_prop = np.array([0, 0, prep])

    return np.transpose(np.vstack((prevalence, unaware, aware_no_art, aware_art_vls, prep_prop)))


def extract_state(data_array, prep_values):
    """Per-cluster observations for the RL agents.

    For each of cluster1_index ... cluster8_index (module globals, each a list of
    jurisdiction positions), the cluster's jurisdictions are pooled. The result is
    the transpose of the stacked (prevalence, unaware, aware-no-ART, aware-ART/VLS,
    PrEP) proportions, presumably (number_of_risk_groups, 5) — confirm against
    calculate_proportions' return shapes.

    Args:
        data_array: population array (jurisdiction, risk, age, compartment).
        prep_values: (jurisdiction, risk) PrEP coverage fractions.

    Returns:
        tuple of 8 arrays, one per cluster, in cluster order.
    """
    cluster_indices = (cluster1_index, cluster2_index, cluster3_index, cluster4_index,
                       cluster5_index, cluster6_index, cluster7_index, cluster8_index)
    return tuple(_cluster_state(data_array, prep_values, idx) for idx in cluster_indices)


In [None]:
def initial_state(initial_data, prep_values):
    """Build the environment state at time step 0.

    Layout (the same one step() returns as next_state):
    (population array, 8 per-cluster observations, PrEP rates, time step).
    """
    s1, s2, s3, s4, s5, s6, s7, s8 = extract_state(initial_data, prep_values)
    start_time = 0
    return (initial_data, s1, s2, s3, s4, s5, s6, s7, s8, prep_values, start_time)

In [None]:
# Starting environment state: population array, per-cluster observations, PrEP rates, time 0.
state = initial_state(initial_data=data_array_cluster, prep_values=prep_values)

In [None]:
def shi_value(I, N):
    """Scaled infected share ``(0.0153 / 0.0057) * (I / N)``.

    ``I`` and ``N`` may be scalars or NumPy arrays (elementwise). cost() passes the
    infected count and the total population of one risk group per jurisdiction.
    """
    scale = 0.0153/0.0057
    share = I/N
    return scale*share

In [None]:
# ---- Testing-volume parameters (all used in cost()) ----
theta = 0.3   # scales shi in the denominator of outreach volume x; also weights x in n
tau = 0.3     # weight on the c_r_* / n_r_p unit costs; (1 - tau) goes on c_c_* / n_c_p
mu = 0.002    # rate applied to I*p_unaware in x and to the remaining population in n
alpha = 0.8   # share of (r + n) charged to the m_cl/f_c fixed cost; the rest goes to m_nc/f_nc

# ---- Per-test unit costs ----
# X_v = tau*c_r_n + (1-tau)*c_c_n + n_n + (1-alpha)*c_add                   (applied to n)
# Y_v = tau*(c_r_p + n_r_p) + (1-tau)*(c_c_p + n_c_p) + c_cnf + (1-alpha)*c_add  (applied to r)
c_r_n = 22.13
c_c_n = 10.36
n_n = 0.45
c_add = 52.66

c_r_p = 78.8
c_c_p = 58.91
n_r_p = 10.86
n_c_p = 5.88
c_cnf = 160.07

# ---- Outreach unit cost: O_v = O_v0 * exp(del_x * W) ----
# del_x = (x - x_t_a_0) / population, i.e. the per-capita change in people outreached.
O_v0 = 16.59
W = 0.2  # alternatives noted: 0.1, 0.2, 0.3

# ---- Fixed costs, charged as (volume / m) * f ----
m_cl = 1000   # with f_c: volume alpha * (r + n)
m_nc = 1000   # with f_nc: volume (1 - alpha) * (r + n)
f_c = 56379
f_nc = 64851

m_o = 1000    # with f_o: outreach volume x
f_o = 50000

# ---- Retention in care ----
# Unit cost R_v = R_art_0 * exp(del_p_art * Y), with del_p_art = ART/VLS share - p_art_0;
# plus a fixed cost (d / m_r) * f_r.
R_art_0 = 235  # alternatives noted: 117, 235, 300
Y = 0.2        # alternatives noted: 0.1, 0.2, 0.3
m_r = 500
f_r = 22708    # alternatives noted: 17977, 22708, 29330

In [None]:
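# Reference (not executed): this appears to be how the hard-coded x_t_a_0 baseline in the
# next cell was derived from the initial population and the baseline arrays delta_1 and p_unaware_0.
# I_t_0 = np.apply_over_axes(np.sum, data_array_cluster[:,:,:,1:21],[2,3]).reshape(num_jur, num_risk)
# N_t_0 = np.apply_over_axes(np.sum, data_array_cluster[:,:,:,0:21],[2,3]).reshape(num_jur, num_risk)

# x_t_a_0 = (delta_1*I_t_0*p_unaware_0 - mu*I_t_0*p_unaware_0) / (theta*shi_value(I_t_0,N_t_0))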

In [None]:
# Baseline values, one row per jurisdiction (8 rows = num_jur; presumably in
# all_jurisdictions order — confirm) and one column per risk group (index into
# risk_groups: HM, HF, MSM).

# Baseline number of people outreached; cost() measures del_x = (x - x_t_a_0[:, risk]) / population.
x_t_a_0 = np.array([[2.33180375e+03, 3.63425551e+03, 3.04112649e+02],
       [3.17936973e+04, 2.75670559e+04, 1.47418143e+03],
       [1.36887484e+04, 1.30437095e+04, 5.32839478e+02],
       [4.90237946e+03, 9.09907366e+03, 5.72640202e+02],
       [1.47264430e+03, 3.48346071e+03, 3.16193026e+02],
       [3.62561698e+03, 7.52180368e+03, 1.63632454e+02],
       [5.32412773e+03, 1.15369335e+04, 4.73913288e+02],
       [8.82399246e+02, 5.86348541e+02, 2.57480493e+02]])

# Presumably baseline diagnosis rates (cost() uses "delta" for the diagnosis rate).
# In the visible code it is only referenced by the commented-out x_t_a_0 derivation.
delta_1 = np.array([[0.0180713 , 0.03533054, 0.0314886 ],
       [0.05123306, 0.059546  , 0.03496893],
       [0.04669086, 0.05902862, 0.0339207 ],
       [0.02449952, 0.0553693 , 0.03036947],
       [0.01170278, 0.03310768, 0.03407533],
       [0.01437995, 0.03664002, 0.01971395],
       [0.01977946, 0.05301006, 0.02548554],
       [0.038004  , 0.031486  , 0.06326625]])

# Baseline share of infected people who are unaware of their status.
# In the visible code it is only referenced by the commented-out x_t_a_0 derivation.
p_unaware_0 = np.array([[0.17989 , 0.12448 , 0.167303],
       [0.132271, 0.090542, 0.123574],
       [0.196625, 0.137524, 0.184114],
       [0.196059, 0.138523, 0.185914],
       [0.208705, 0.142397, 0.188764],
       [0.282237, 0.20057 , 0.260866],
       [0.186459, 0.130231, 0.174968],
       [0.078656, 0.05356 , 0.074505]])

# Baseline ART/VLS share; cost() scales the retention unit cost by exp((p - p_art_0[:, risk]) * Y).
p_art_0 = np.array([[0.565534, 0.647454, 0.627884],
       [0.570325, 0.645897, 0.636072],
       [0.536505, 0.620825, 0.599717],
       [0.496017, 0.579814, 0.561497],
       [0.646201, 0.72944 , 0.697714],
       [0.543873, 0.639447, 0.6001  ],
       [0.53376 , 0.616783, 0.597922],
       [0.547729, 0.617069, 0.619356]])

In [None]:
def cost(data_t_1, data_t, unaware_prop, aware_art_vls_prop, diagnosis_rate_risk, dropout_rate_risk, prep_rate):
    """Programme cost for one monthly sub-step, per jurisdiction and risk group.

    Three components are added for each risk group:

    * testing: r = diagnosis_rate * I * p_unaware diagnoses, x people outreached and
      n other tests, each at a unit cost, plus volume-based fixed costs;
    * retention in care: d = (1 - dropout) * I * p_art_vls people at unit cost
      R_v, plus a fixed cost (d / m_r) * f_r;
    * PrEP (risk index 2 only): susceptible MSM * PrEP rate * (1431 + 12599) per year.

    The total is multiplied by ``dt``.

    Args:
        data_t_1: population array at the start of the sub-step (gives I and PrEP volume).
        data_t: population array at the end of the sub-step (gives the total population).
        unaware_prop, aware_art_vls_prop, diagnosis_rate_risk, dropout_rate_risk, prep_rate:
            (num_jur, number_of_risk_groups) arrays.

    Returns:
        np.ndarray of shape (num_jur, number_of_risk_groups).
    """
    out = np.zeros((num_jur, number_of_risk_groups))

    # Per-test unit costs do not depend on the risk group.
    unit_cost_n = tau*c_r_n + (1-tau)*c_c_n + n_n + (1-alpha)*c_add
    unit_cost_r = tau*(c_r_p + n_r_p) + (1-tau)*(c_c_p + n_c_p) + c_cnf + (1-alpha)*c_add

    for r in range(len(risk_groups)):
        pop_now = np.apply_over_axes(np.sum, data_t[:, r, :, 0:21], [1, 2]).reshape(num_jur,)
        infected_prev = np.apply_over_axes(np.sum, data_t_1[:, r, :, 1:21], [1, 2]).reshape(num_jur,)
        p_unaware = unaware_prop[:, r]
        p_art = aware_art_vls_prop[:, r]
        diag = diagnosis_rate_risk[:, r]
        drop = dropout_rate_risk[:, r]

        # ---- testing ----
        diagnosed = diag*infected_prev*p_unaware
        shi = shi_value(infected_prev, pop_now)
        outreached = (diagnosed - mu*infected_prev*p_unaware) / (theta*shi)
        # costed with the *_n unit costs (presumably people testing negative)
        n_tested = mu*(pop_now - infected_prev*p_unaware - outreached) + outreached*theta*(1-shi)

        per_capita_outreach_change = (outreached - x_t_a_0[:, r])/pop_now
        outreach_unit_cost = O_v0*np.exp(per_capita_outreach_change*W)

        fixed_cl = ((diagnosed + n_tested)*alpha / m_cl)*f_c
        fixed_ncl = ((diagnosed + n_tested)*(1 - alpha) / m_nc)*f_nc
        fixed_outreach = (outreached / m_o)*f_o

        testing = (diagnosed*unit_cost_r + outreached*outreach_unit_cost + n_tested*unit_cost_n
                   + fixed_cl + fixed_ncl + fixed_outreach)

        # ---- retention in care ----
        retained = (1 - drop)*infected_prev*p_art
        retention_unit_cost = R_art_0*np.exp((p_art - p_art_0[:, r])*Y)
        retention = retained*retention_unit_cost + (retained / m_r)*f_r

        # ---- PrEP (MSM only) ----
        prep = 0
        if r == 2:
            annual_prep_cost = 1431 + 12599   # adherence + medication, per person per year
            prep = np.sum(data_t_1[:, r, :, 0], axis=1)*prep_rate[:, r]*annual_prep_cost

        out[:, r] = testing + retention + prep

    return dt*out

In [None]:
# Monetary value per QALY; benefit() returns c_l * dt * (QALY-weighted population).
c_l = 54000

# Infected compartments grouped by CD4 stratum (compartment 0 = susceptible, 21 = deaths).
cd4_gt_350_index = tuple(range(1, 13))    # CD4 > 350
cd4_200_350_index = tuple(range(13, 17))  # CD4 200-350
cd_lt_200_index = tuple(range(17, 21))    # CD4 < 200

# QALY weights per stratum; susceptibles count with weight 1 in benefit().
# NOTE: benefit() applies QALY_val_250_350 to the cd4_200_350_index compartments.
QALY_val_gt_350 = 0.935
QALY_val_250_350 = 0.818
QALY_val_lt_200 = 0.702

In [None]:
def benefit(data_array):
    """Monetized QALYs for one monthly sub-step, per jurisdiction and risk group.

    Each susceptible counts 1 QALY. Infected people are weighted by CD4 stratum
    (QALY_val_gt_350, QALY_val_250_350, QALY_val_lt_200). The total is scaled by
    ``c_l * dt``.

    Returns:
        np.ndarray of shape (num_jur, number_of_risk_groups).
    """
    qalys = np.zeros((num_jur, number_of_risk_groups))
    strata = ((cd4_gt_350_index, QALY_val_gt_350),
              (cd4_200_350_index, QALY_val_250_350),
              (cd_lt_200_index, QALY_val_lt_200))

    for r in range(len(risk_groups)):
        total = np.sum(data_array[:, r, :, 0], axis=1)
        for compartments, weight in strata:
            count = np.apply_over_axes(np.sum, data_array[:, r, :, compartments], [0, 2]).reshape(num_jur,)
            total = total + weight*count
        qalys[:, r] = total

    return c_l*dt*qalys

In [None]:
def get_action(min_, max_, val):
    """Map ``val`` linearly from [-1, 1] onto [min_, max_].

    -1 maps to ``min_`` and 1 maps to ``max_``. Values outside [-1, 1] extrapolate
    (no clipping here). Works elementwise on NumPy arrays as well as on scalars.
    """
    span = max_ - min_
    return min_ + span * (val + 1) / 2

In [None]:
def change_action_range(action):
    """Map a raw 9-element policy output onto the model's action ranges.

    Each entry is clipped to [-1, 1] and then rescaled linearly:

    * entries 0-2: a_unaware per risk group, in [-0.005, 0]
    * entries 3-5: a_art per risk group, in [0, 0.04]
    * entries 6-7: a_prep for risk groups 0 and 1, always 0
    * entry 8:     a_prep for risk group 2 (MSM), in [0, 0.04]

    The caller's array is not modified, and the result is always float64.

    Args:
        action: array-like with 9 entries, nominally in [-1, 1].

    Returns:
        np.ndarray of shape (3, 3): rows = (a_unaware, a_art, a_prep),
        columns = risk groups.
    """
    lows = np.array([-0.005, -0.005, -0.005, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    highs = np.array([0.0, 0.0, 0.0, 0.04, 0.04, 0.04, 0.0, 0.0, 0.04])

    # np.clip returns a new array, so the input is never written to.
    clipped = np.clip(np.asarray(action, dtype=float), -1.0, 1.0)
    # Same formula as get_action(): low + (high - low) * (val + 1) / 2.
    scaled = lows + (highs - lows) * (clipped + 1) / 2

    return scaled.reshape(3, 3)

In [None]:
def action_tile(action1, action2, action3, action4, action5, action6, action7, action8):
    """Expand 8 per-cluster actions into one (num_jur, 3) array.

    Every jurisdiction listed in clusterK_index (module globals) gets ``actionK``.
    Clusters are written in order 1..8, and jurisdictions in no cluster stay 0.
    """
    cluster_indices = (cluster1_index, cluster2_index, cluster3_index, cluster4_index,
                       cluster5_index, cluster6_index, cluster7_index, cluster8_index)
    cluster_actions = (action1, action2, action3, action4, action5, action6, action7, action8)

    tiled = np.zeros((num_jur, 3))
    for idx, cluster_action in zip(cluster_indices, cluster_actions):
        tiled[idx, :] = cluster_action
    return tiled

In [None]:
def step(current_state, action1, action2, action3, action4, action5, action6, action7, action8):
    """Advance the model by one year (12 monthly sub-steps) under the given actions.

    Args:
        current_state: tuple (population array, 8 cluster observations, PrEP rates,
            time step), as built by initial_state() or a previous step().
        action1 ... action8: raw 9-element policy outputs, one per cluster.
            Each is mapped with change_action_range().

    Returns:
        (next_state, reward1, ..., reward8, done).
        Each cluster reward is minus the cluster's new infections over the year. It is
        reduced further by any spending above that cluster's annual budget.
        ``done`` is True once the 12th yearly step has been taken.
    """
    data_array = current_state[0]
    prep_values = current_state[9]
    current_time = current_state[10]

    months_per_year = 12
    episode_years = 12
    cluster_indices = (cluster1_index, cluster2_index, cluster3_index, cluster4_index,
                       cluster5_index, cluster6_index, cluster7_index, cluster8_index)
    # Annual budget per cluster, in the same order as cluster_indices.
    cluster_budgets = (2.00e6, 7.89e6, 2.00e6, 1.28e6, 2.00e6, 1.28e6, 2.56e6, 3.53e7)

    # Each mapped action is (3, 3): rows = (a_unaware, a_art, a_prep), columns = risk groups.
    mapped = [change_action_range(a) for a in (action1, action2, action3, action4,
                                                 action5, action6, action7, action8)]
    a_unaware = action_tile(*(m[0] for m in mapped))
    a_art = action_tile(*(m[1] for m in mapped))
    a_prep = action_tile(*(m[2] for m in mapped))

    # PrEP coverage is a fraction of susceptibles, so keep it within [0, 1].
    prep_rate = np.clip(prep_values + a_prep, 0.0, 1.0)

    # Susceptibles in the youngest age group; reused (times 1 + pop_growth_rate)
    # as the new entrants when the population is aged at the end of the year.
    pop_susceptible_12_years = data_array[:,:,0,0]

    _, _, unaware_prop, aware_no_art_prop, aware_art_vls_prop, _ = \
        calculate_proportions(data_array, num_jur, number_of_risk_groups, unaware_index, aware_no_care_index, ART_VLS_index, VLS_index)

    # The following flows and rates are computed once from the start-of-year
    # population and held fixed for all 12 monthly sub-steps.
    # NOTE(review): new infections and deaths are absolute counts, so the same
    # numbers are applied every month — confirm this is intended.
    new_inf_per_month = new_infections_per_month(num_jur, data_array, new_infections_data, M_x1_y1_i, prep_rate)
    death_per_month_risk_age_compartments = data_array*death_rate_risk_age_compartments*dt

    diagnosis_rate_risk = diagnosis_rate(data_array, num_jur, a_unaware, unaware_index, number_of_risk_groups, new_inf_per_month, unaware_prop, death_per_month_risk_age_compartments)
    dropout_rate_risk = dropout_rate(num_jur, a_art, ART_VLS_index, diagnosis_rate_risk, ltc_risk, gamma, number_of_risk_groups, data_array, new_inf_per_month, unaware_prop, aware_no_art_prop, aware_art_vls_prop, death_per_month_risk_age_compartments)

    Q_matrix = q_matrix(num_jur, new_infections_data, diagnosis_rate_risk, dropout_rate_risk, ltc_risk)
    Q_matrix_diagonal = q_mat_diag(Q_matrix, num_jur)

    total_inf = 0
    total_cost = 0
    for _ in range(months_per_year):
        data_t_1 = data_array   # never modified in place, so no copy needed
        new_data = np.zeros((num_jur, number_of_risk_groups, age_groups, number_of_compartments))

        for risk in range(number_of_risk_groups):
            # inflows (data @ Q) minus outflows (data @ diag(row sums of Q)) minus deaths
            new_data[:,risk,:,:] = data_array[:,risk,:,:] + \
                                    np.matmul(data_array[:,risk,:,:], Q_matrix[:,risk,:,:]) - \
                                    np.matmul(data_array[:,risk,:,:], Q_matrix_diagonal[:,risk,:,:]) - \
                                    death_per_month_risk_age_compartments[:,risk,:,:]

            # new infections move from susceptible (0) to acute unaware (1)
            new_data[:,risk,:,0] = new_data[:,risk,:,0] - new_inf_per_month[:,risk,:]
            new_data[:,risk,:,1] = new_data[:,risk,:,1] + new_inf_per_month[:,risk,:]

            # last column records this month's deaths
            new_data[:,risk,:,21] = np.sum(death_per_month_risk_age_compartments[:,risk,:,:], axis=2)

        total_cost += cost(data_t_1, new_data, unaware_prop, aware_art_vls_prop, diagnosis_rate_risk, dropout_rate_risk, prep_rate)
        total_inf += new_inf_per_month
        data_array = new_data

    new_pop_dist = aging(data_array, pop_susceptible_12_years*(1+pop_growth_rate))  # add next year's entrants

    cluster_states = extract_state(new_pop_dist, prep_rate)
    next_state = (new_pop_dist, *cluster_states, prep_rate, current_time+1)

    # Rewards are infection-based; benefit() (monetized QALYs) is not part of them.
    rewards = []
    for idx, budget in zip(cluster_indices, cluster_budgets):
        reward = -np.sum(total_inf[idx, :])
        cluster_cost = np.sum(total_cost[idx, :])
        if cluster_cost > budget:
            reward -= (cluster_cost - budget)
        rewards.append(reward)

    done = bool(current_time + 1 >= episode_years)

    return (next_state, *rewards, done)

In [None]:
print("============================================================================================")

# Use the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

print("============================================================================================")

In [None]:
class RolloutBuffer:
    """Storage for one PPO rollout.

    Holds five parallel lists: actions, states, logprobs, rewards and is_terminals.
    """

    def __init__(self):
        self.actions, self.states, self.logprobs = [], [], []
        self.rewards, self.is_terminals = [], []

    def clear(self):
        """Empty every list in place; existing references to the lists stay valid."""
        for stored in (self.actions, self.states, self.logprobs, self.rewards, self.is_terminals):
            stored.clear()

In [None]:
class ActorCritic(nn.Module):
    """Actor and critic networks used by PPO.

    actor:  state -> action_dim outputs. In the continuous case these are Tanh-bounded
            means of a Gaussian with diagonal covariance diag(action_var). In the
            discrete case they are Softmax probabilities of a Categorical.
    critic: state -> scalar state-value estimate.
    Both are MLPs with two hidden layers of 64 units. ``action_var`` is created on
    the module-level ``device``.
    """
    def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):
        super(ActorCritic, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_dim = action_dim
            # per-dimension variance std**2, the same for every action dimension
            self.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)

        # actor
        if has_continuous_action_space :
            # Tanh keeps the mean in [-1, 1]; sampled actions are clipped later by change_action_range()
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.ReLU(),
                            nn.Linear(64, 64),
                            nn.ReLU(),
                            nn.Linear(64, action_dim),
                            nn.Tanh()
                        )
        else:
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh(),
                            nn.Linear(64, action_dim),
                            nn.Softmax(dim=-1)
                        )

        
        # critic
        self.critic = nn.Sequential(
                        nn.Linear(state_dim, 64),
                        nn.ReLU(),
                        nn.Linear(64, 64),
                        nn.ReLU(),
                        nn.Linear(64, 1)
                    )
        
    def set_action_std(self, new_action_std):
        """Replace the Gaussian exploration std (continuous spaces only; otherwise warn)."""

        if self.has_continuous_action_space:
            self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")


    def forward(self):
        # Not used: call act() for sampling and evaluate() for PPO updates.
        raise NotImplementedError
    

    def act(self, state):
        """Sample an action for ``state``.

        Returns:
            (action, log_prob), both detached from the autograd graph.
        """

        if self.has_continuous_action_space:
            action_mean = self.actor(state)

            # (1, action_dim, action_dim) diagonal covariance with a leading batch dimension
            cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(action_mean, cov_mat)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)
        
        return action.detach(), action_logprob.detach()
    

    def evaluate(self, state, action):
        """Score stored actions under the current policy.

        Returns:
            (action_logprobs, state_values, dist_entropy); gradients are kept.
        """

        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            # broadcast the shared variance to every row of the batch
            action_var = self.action_var.expand_as(action_mean)
            cov_mat = torch.diag_embed(action_var).to(device)
            dist = MultivariateNormal(action_mean, cov_mat)
            
            # for single action continuous environments
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)

        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)
        
        return action_logprobs, state_values, dist_entropy

In [None]:
class PPO:
    """Proximal Policy Optimization agent with a clipped surrogate objective.

    Keeps two ``ActorCritic`` copies: ``policy`` is optimized in ``update`` and
    ``policy_old`` collects rollouts in ``select_action``. After every update
    the new weights are copied into ``policy_old``.

    Args:
        state_dim: Length of the flattened observation vector.
        action_dim: Action dimension, passed through to ``ActorCritic``.
        lr_actor: Adam learning rate for the actor parameters.
        lr_critic: Adam learning rate for the critic parameters.
        gamma_: Discount factor for Monte Carlo returns.
        K_epochs: Optimization passes over each rollout buffer.
        eps_clip: PPO clipping parameter.
        has_continuous_action_space: Gaussian policy if True, categorical otherwise.
        action_std_init: Initial std of the Gaussian policy (continuous only).
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip, has_continuous_action_space, action_std_init=0.6):

        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_std = action_std_init

        self.gamma_ = gamma_
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.optimizer = torch.optim.Adam([
                        {'params': self.policy.actor.parameters(), 'lr': lr_actor},
                        {'params': self.policy.critic.parameters(), 'lr': lr_critic}
                    ])

        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()


    def set_action_std(self, new_action_std):
        """Set the Gaussian std on both the current and the rollout policy."""
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)

        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling PPO::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")


    def decay_action_std(self, action_std_decay_rate, min_action_std):
        """Linearly decrease the action std, clamped below at ``min_action_std``.

        The new value is rounded to 4 decimals to avoid float drift across
        many decays.
        """
        if self.has_continuous_action_space:
            self.action_std = self.action_std - action_std_decay_rate
            self.action_std = round(self.action_std, 4)
            if (self.action_std <= min_action_std):
                self.action_std = min_action_std
            self.set_action_std(self.action_std)

        else:
            print("WARNING : Calling PPO::decay_action_std() on discrete action space policy")


    def select_action(self, state):
        """Sample an action from ``policy_old`` and record the transition.

        The state tensor, action and log-probability are appended to the
        rollout buffer. Rewards and terminal flags are appended by the caller.

        Returns:
            A flat numpy array (continuous) or a Python scalar (discrete).
        """
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action, action_logprob = self.policy_old.act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)

        if self.has_continuous_action_space:
            return action.detach().cpu().numpy().flatten()
        return action.item()


    def _discounted_returns(self):
        """Discounted return for every buffered step, reset at episode ends."""
        returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma_ * discounted_reward)
            returns.append(discounted_reward)
        # built back-to-front with append (O(n)) instead of insert(0, ...) (O(n^2))
        returns.reverse()
        return returns


    def update(self):
        """Run ``K_epochs`` PPO optimization passes over the rollout buffer.

        Targets are Monte Carlo discounted returns, normalized to zero mean
        (and unit variance when more than one return is available). The loss
        combines the clipped surrogate objective, a 0.5-weighted MSE value loss
        and a 0.01-weighted entropy bonus. Afterwards the weights are copied to
        ``policy_old`` and the buffer is cleared.
        """
        if not self.buffer.rewards:
            # Nothing collected since the last update; stacking an empty list
            # would raise, so just reset the buffer.
            self.buffer.clear()
            return

        # Monte Carlo estimate of returns
        rewards = torch.tensor(self._discounted_returns(), dtype=torch.float32).to(device)

        # Normalizing the returns. The (unbiased) std of a single sample is NaN,
        # which would turn the whole loss into NaN, so only scale when n > 1.
        if rewards.numel() > 1:
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)
        else:
            rewards = rewards - rewards.mean()

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)

        # Optimize policy for K epochs
        for _ in range(self.K_epochs):

            # Evaluating old actions and values
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # match state_values tensor dimensions with rewards tensor
            state_values = torch.squeeze(state_values)

            # Finding the ratio (pi_theta / pi_theta__old)
            ratios = torch.exp(logprobs - old_logprobs)

            # Finding Surrogate Loss
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages

            loss = -torch.min(surr1, surr2) + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy

            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        # clear buffer
        self.buffer.clear()


    def save(self, checkpoint_path):
        """Write the rollout policy's ``state_dict`` to ``checkpoint_path``."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)


    def load(self, checkpoint_path):
        """Load one checkpoint into both the rollout and the trainable policy.

        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        # deserialize once; load_state_dict copies values into each module's own tensors
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy_old.load_state_dict(state_dict)
        self.policy.load_state_dict(state_dict)


In [None]:
################################### Training ###################################

####### initialize environment hyperparameters ######

action_std_decay_rate = 0.0046    # linear decrement applied to action_std at every decay
min_action_std = 0.05             # floor for action_std
action_std_decay_freq = 1000      # decay action_std every n timesteps

has_continuous_action_space = True

max_ep_len = 12                   # max timesteps in one episode
max_training_timesteps = 100000   # break training loop if timesteps > max_training_timesteps

print_freq = max_ep_len * 10     # print avg reward in the interval (in num timesteps)
log_freq = max_ep_len * 2       # log avg reward in the interval (in num timesteps)
save_model_freq = 2000      # save model frequency (in num timesteps)
plot_freq = 1200            # redraw reward curves every n timesteps

action_std = 0.4            # starting std of the Gaussian action distribution

#####################################################


################ PPO hyperparameters ################


update_timestep = 120     # update policy every n timesteps
K_epochs = 20               # update policy for K epochs
eps_clip = 0.2              # clip parameter for PPO
gamma_ = 0.99                # discount factor

lr_actor = 0.0003       # learning rate for actor network
lr_critic = 0.0003       # learning rate for critic network

random_seed = 10   # set random seed if required (0 = no random seed)

#####################################################

env_name = 'HIV Jurisdiction'

print("training environment name : " + env_name)


# state space dimension
state_dim = 15

# action space dimension
if has_continuous_action_space:
    action_dim = 9
else:
    action_dim = 1

###################### logging ######################

#### log files for multiple runs are NOT overwritten

log_dir = "PPO_logs" + '/' + env_name + '/'
os.makedirs(log_dir, exist_ok=True)   # also creates the parent "PPO_logs"


#### get number of log files in log directory
current_num_files = next(os.walk(log_dir))[2]
run_num = len(current_num_files)


#### create new log file for each run (log_dir already ends with '/')
log_f_name = log_dir + 'PPO_' + env_name + "_log_" + str(run_num) + ".csv"

print("current logging run number for " + env_name + " : ", run_num)
print("logging at : " + log_f_name)


run_num_pretrained = 0

directory = "PPO_preTrained" + '/' + env_name + '/'
os.makedirs(directory, exist_ok=True)   # also creates the parent "PPO_preTrained"


# one checkpoint file per cluster agent: <directory>Cluster<k>PPO_<env>_<seed>_<run>.pth
checkpoint_paths = [
    directory + 'Cluster{}'.format(k) + "PPO_{}_{}_{}.pth".format(env_name, random_seed, run_num_pretrained)
    for k in range(1, 9)
]
(checkpoint_path1, checkpoint_path2, checkpoint_path3, checkpoint_path4,
 checkpoint_path5, checkpoint_path6, checkpoint_path7, checkpoint_path8) = checkpoint_paths

print("save checkpoint path : " + checkpoint_path1)

############# print all hyperparameters #############

print("--------------------------------------------------------------------------------------------")

print("max training timesteps : ", max_training_timesteps)
print("max timesteps per episode : ", max_ep_len)

print("model saving frequency : " + str(save_model_freq) + " timesteps")
print("log frequency : " + str(log_freq) + " timesteps")
print("printing average reward over episodes in last : " + str(print_freq) + " timesteps")

print("--------------------------------------------------------------------------------------------")

print("state space dimension : ", state_dim)
print("action space dimension : ", action_dim)

print("--------------------------------------------------------------------------------------------")

if has_continuous_action_space:
    print("Initializing a continuous action space policy")
    print("--------------------------------------------------------------------------------------------")
    print("starting std of action distribution : ", action_std)
    print("decay rate of std of action distribution : ", action_std_decay_rate)
    print("minimum std of action distribution : ", min_action_std)
    print("decay frequency of std of action distribution : " + str(action_std_decay_freq) + " timesteps")

else:
    print("Initializing a discrete action space policy")

print("--------------------------------------------------------------------------------------------")

print("PPO update frequency : " + str(update_timestep) + " timesteps")
print("PPO K epochs : ", K_epochs)
print("PPO epsilon clip : ", eps_clip)
print("discount factor (gamma_) : ", gamma_)

print("--------------------------------------------------------------------------------------------")

print("optimizer learning rate actor : ", lr_actor)
print("optimizer learning rate critic : ", lr_critic)

if random_seed:
    print("--------------------------------------------------------------------------------------------")
    print("setting random seed to ", random_seed)
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    # the stdlib `random` module is imported at the top of the notebook; seed it too
    random.seed(random_seed)

print("============================================================================================")



In [None]:
################# training procedure ################
# One independent PPO agent is trained per cluster. Each agent observes only its
# own slice of the environment state (state[1] ... state[8]) and is trained on
# its own reward returned by `step`.

# initialize one PPO agent per cluster
ppo_agent1 = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip, has_continuous_action_space, action_std)
ppo_agent2 = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip, has_continuous_action_space, action_std)
ppo_agent3 = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip, has_continuous_action_space, action_std)
ppo_agent4 = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip, has_continuous_action_space, action_std)
ppo_agent5 = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip, has_continuous_action_space, action_std)
ppo_agent6 = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip, has_continuous_action_space, action_std)
ppo_agent7 = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip, has_continuous_action_space, action_std)
ppo_agent8 = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip, has_continuous_action_space, action_std)

all_agents = (ppo_agent1, ppo_agent2, ppo_agent3, ppo_agent4, ppo_agent5, ppo_agent6, ppo_agent7, ppo_agent8)
all_checkpoint_paths = (checkpoint_path1, checkpoint_path2, checkpoint_path3, checkpoint_path4,
                        checkpoint_path5, checkpoint_path6, checkpoint_path7, checkpoint_path8)

# track total training time
start_time = datetime.now().replace(microsecond=0)
# datetime.now() returns local wall-clock time, not GMT
print("Started training at (local time) : ", start_time)

print("============================================================================================")

# average episode reward per print window (one entry every `print_freq` timesteps), per cluster
rew_list1 = []
rew_list2 = []
rew_list3 = []
rew_list4 = []
rew_list5 = []
rew_list6 = []
rew_list7 = []
rew_list8 = []
all_rew_lists = (rew_list1, rew_list2, rew_list3, rew_list4, rew_list5, rew_list6, rew_list7, rew_list8)

# total reward of every finished episode, per cluster
ep_rew_list1 = []
ep_rew_list2 = []
ep_rew_list3 = []
ep_rew_list4 = []
ep_rew_list5 = []
ep_rew_list6 = []
ep_rew_list7 = []
ep_rew_list8 = []

# printing and logging variables (reset after every print / log)
print_running_reward1 = 0
print_running_reward2 = 0
print_running_reward3 = 0
print_running_reward4 = 0
print_running_reward5 = 0
print_running_reward6 = 0
print_running_reward7 = 0
print_running_reward8 = 0

print_running_episodes = 0

# NOTE: the CSV log only tracks agent 1's (cluster 1's) episode reward
log_running_reward = 0
log_running_episodes = 0

time_step = 0
i_episode = 0

# `with` guarantees the log file is closed even if training raises or is interrupted
with open(log_f_name, "w+") as log_f:
    log_f.write('episode,timestep,reward\n')

    while time_step <= max_training_timesteps:

        state = initial_state(data_array_cluster, prep_values)

        current_ep_reward1 = 0
        current_ep_reward2 = 0
        current_ep_reward3 = 0
        current_ep_reward4 = 0
        current_ep_reward5 = 0
        current_ep_reward6 = 0
        current_ep_reward7 = 0
        current_ep_reward8 = 0

        # at most `max_ep_len` steps per episode (range(0, max_ep_len+1) allowed one extra)
        for t in range(1, max_ep_len + 1):

            # select action with each cluster's policy
            action1 = ppo_agent1.select_action(state[1].flatten())
            action2 = ppo_agent2.select_action(state[2].flatten())
            action3 = ppo_agent3.select_action(state[3].flatten())
            action4 = ppo_agent4.select_action(state[4].flatten())
            action5 = ppo_agent5.select_action(state[5].flatten())
            action6 = ppo_agent6.select_action(state[6].flatten())
            action7 = ppo_agent7.select_action(state[7].flatten())
            action8 = ppo_agent8.select_action(state[8].flatten())

            state, reward1, reward2, reward3, reward4, reward5, reward6, reward7, reward8, done = \
                step(state, action1, action2, action3, action4, action5, action6, action7, action8)

            # saving reward and is_terminals
            for agent, agent_reward in zip(all_agents, (reward1, reward2, reward3, reward4,
                                                        reward5, reward6, reward7, reward8)):
                agent.buffer.rewards.append(agent_reward)
                agent.buffer.is_terminals.append(done)

            time_step += 1
            current_ep_reward1 += reward1
            current_ep_reward2 += reward2
            current_ep_reward3 += reward3
            current_ep_reward4 += reward4
            current_ep_reward5 += reward5
            current_ep_reward6 += reward6
            current_ep_reward7 += reward7
            current_ep_reward8 += reward8

            # update PPO agents
            if time_step % update_timestep == 0:
                for agent in all_agents:
                    agent.update()

            # anneal the exploration noise of the Gaussian policies
            if has_continuous_action_space and time_step % action_std_decay_freq == 0:
                for agent in all_agents:
                    agent.decay_action_std(action_std_decay_rate, min_action_std)

            # log in logging file (skipped until an episode has finished, which
            # would otherwise divide by zero)
            if time_step % log_freq == 0 and log_running_episodes > 0:

                # log average reward till last episode
                log_avg_reward = log_running_reward / log_running_episodes
                log_avg_reward = round(log_avg_reward, 4)

                log_f.write('{},{},{}\n'.format(i_episode, time_step, log_avg_reward))
                log_f.flush()

                log_running_reward = 0
                log_running_episodes = 0

            # printing average reward (same zero-episode guard as above)
            if time_step % print_freq == 0 and print_running_episodes > 0:

                # print average reward till last episode
                print_avg_reward1 = round(print_running_reward1 / print_running_episodes, 2)
                print_avg_reward2 = round(print_running_reward2 / print_running_episodes, 2)
                print_avg_reward3 = round(print_running_reward3 / print_running_episodes, 2)
                print_avg_reward4 = round(print_running_reward4 / print_running_episodes, 2)
                print_avg_reward5 = round(print_running_reward5 / print_running_episodes, 2)
                print_avg_reward6 = round(print_running_reward6 / print_running_episodes, 2)
                print_avg_reward7 = round(print_running_reward7 / print_running_episodes, 2)
                print_avg_reward8 = round(print_running_reward8 / print_running_episodes, 2)

                print_avgs = (print_avg_reward1, print_avg_reward2, print_avg_reward3, print_avg_reward4,
                              print_avg_reward5, print_avg_reward6, print_avg_reward7, print_avg_reward8)

                for k, avg in enumerate(print_avgs, start=1):
                    print("Agent{} => Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format(k, i_episode, time_step, avg))

                for cluster_rews, avg in zip(all_rew_lists, print_avgs):
                    cluster_rews.append(avg)

                print_running_reward1 = 0
                print_running_reward2 = 0
                print_running_reward3 = 0
                print_running_reward4 = 0
                print_running_reward5 = 0
                print_running_reward6 = 0
                print_running_reward7 = 0
                print_running_reward8 = 0

                print_running_episodes = 0

            if time_step % plot_freq == 0:
                fig, ax = plt.subplots(2, 4, sharex=False, sharey=False, figsize=(20, 7))
                fig.tight_layout(h_pad=3, w_pad=1)
                # ax.flat is row-major: clusters 1-4 on the top row, 5-8 on the bottom
                for k, (axis, cluster_rews) in enumerate(zip(ax.flat, all_rew_lists), start=1):
                    axis.plot(range(len(cluster_rews)), cluster_rews)
                    axis.set_title('Rewards for Cluster {}'.format(k), pad=12)
                plt.show()
                plt.close(fig)   # avoid accumulating open figures over a long run

            # save model weights
            if time_step % save_model_freq == 0:
                print("--------------------------------------------------------------------------------------------")
                print("saving model at : " + checkpoint_path1)

                for agent, path in zip(all_agents, all_checkpoint_paths):
                    agent.save(path)

                print("model saved")
                print("Elapsed Time  : ", datetime.now().replace(microsecond=0) - start_time)
                print("--------------------------------------------------------------------------------------------")

            if done:
                break

        print_running_reward1 += current_ep_reward1
        print_running_reward2 += current_ep_reward2
        print_running_reward3 += current_ep_reward3
        print_running_reward4 += current_ep_reward4
        print_running_reward5 += current_ep_reward5
        print_running_reward6 += current_ep_reward6
        print_running_reward7 += current_ep_reward7
        print_running_reward8 += current_ep_reward8

        ep_rew_list1.append(current_ep_reward1)
        ep_rew_list2.append(current_ep_reward2)
        ep_rew_list3.append(current_ep_reward3)
        ep_rew_list4.append(current_ep_reward4)
        ep_rew_list5.append(current_ep_reward5)
        ep_rew_list6.append(current_ep_reward6)
        ep_rew_list7.append(current_ep_reward7)
        ep_rew_list8.append(current_ep_reward8)

        print_running_episodes += 1

        log_running_reward += current_ep_reward1
        log_running_episodes += 1

        i_episode += 1

In [None]:
def test_step(current_state, action1, action2, action3, action4, action5, action6, action7, action8,
              budgets=(2.00e6, 7.89e6, 2.00e6, 1.28e6, 2.00e6, 1.28e6, 2.56e6, 3.53e7)):
    """Advance the evaluation environment by one step (12 monthly sub-steps).

    Args:
        current_state: State tuple. Index 0 is the population array, 1-8 are
            the per-cluster observations, 9 is the current PrEP values and 10
            is the step counter.
        action1 ... action8: Raw actions of the eight cluster agents. Each is
            mapped through ``change_action_range``; elements 0, 1 and 2 are
            used as the unaware, ART and PrEP controls.
        budgets: Per-cluster cost budget (clusters 1..8). Cost above a
            cluster's budget is subtracted from that cluster's reward. The
            defaults are the previously hard-coded values.

    Returns:
        ``(next_state, reward1..reward8, done, total_inf, diagnosis_rate_risk,
        dropout_rate_risk, u_p, a_art_vls_p, vls_p, prep_rate, total_cost)``.
        Each reward is minus the cluster's summed new infections, minus any
        over-budget cost. ``done`` becomes True when the step counter reaches 12.

    Raises:
        ValueError: if ``budgets`` does not contain exactly eight values.
    """
    if len(budgets) != 8:
        raise ValueError("budgets must contain one value per cluster (8), got {}".format(len(budgets)))

    data_array = current_state[0]
    prep_values = current_state[9]
    current_time = current_state[10]

    # map every agent's raw action into model range, then combine the
    # per-cluster controls with action_tile
    actions = [change_action_range(a) for a in (action1, action2, action3, action4,
                                                  action5, action6, action7, action8)]
    a_unaware = action_tile(*(a[0] for a in actions))
    a_art = action_tile(*(a[1] for a in actions))
    a_prep = action_tile(*(a[2] for a in actions))

    # PrEP level for this step = previous level + agent increment
    prep_rate = prep_values + a_prep

    pop_susceptible_12_years = data_array[:,:,0,0]

    total_reward = 0
    total_cost = 0
    total_inf = 0
    done = False

    _, _, unaware_prop, aware_no_art_prop, aware_art_vls_prop, _ = \
        calculate_proportions(data_array, num_jur, number_of_risk_groups, unaware_index, aware_no_care_index, ART_VLS_index, VLS_index)

    # rates below are computed once from the start-of-step population and then
    # reused for all 12 monthly sub-steps
    new_inf_per_month = new_infections_per_month(num_jur, data_array, new_infections_data, M_x1_y1_i, prep_rate)
    death_per_month_risk_age_compartments = data_array*death_rate_risk_age_compartments*dt

    diagnosis_rate_risk = diagnosis_rate(data_array, num_jur, a_unaware, unaware_index, number_of_risk_groups, new_inf_per_month, unaware_prop, death_per_month_risk_age_compartments)

    dropout_rate_risk = dropout_rate(num_jur, a_art, ART_VLS_index, diagnosis_rate_risk, ltc_risk, gamma, number_of_risk_groups, data_array, new_inf_per_month, unaware_prop, aware_no_art_prop, aware_art_vls_prop, death_per_month_risk_age_compartments)

    Q_matrix = q_matrix(num_jur, new_infections_data, diagnosis_rate_risk, dropout_rate_risk, ltc_risk)

    Q_matrix_diagonal = q_mat_diag(Q_matrix, num_jur)

    for _month in range(12):

        new_data = np.zeros((num_jur, number_of_risk_groups, age_groups, number_of_compartments))

        data_t_1 = data_array.copy()

        for risk in range(number_of_risk_groups):

            # calculate flow of infected to diff compartments and subtract from that compartment
            new_data[:,risk,:,:] = data_array[:,risk,:,:] + \
                                    np.matmul(data_array[:,risk,:,:], Q_matrix[:,risk,:,:]) - \
                                    np.matmul(data_array[:,risk,:,:], Q_matrix_diagonal[:,risk,:,:]) - \
                                    death_per_month_risk_age_compartments[:,risk,:,:]

            # subtract from susceptible and add to acute unaware
            new_data[:,risk,:,0] = new_data[:,risk,:,0] - new_inf_per_month[:,risk,:]

            new_data[:,risk,:,1] = new_data[:,risk,:,1] + new_inf_per_month[:,risk,:]

            # store this month's total deaths in the last column
            new_data[:,risk,:,21] = np.sum(death_per_month_risk_age_compartments[:,risk,:,:], axis=2)

        cost_per_month = cost(data_t_1, new_data, unaware_prop, aware_art_vls_prop, diagnosis_rate_risk, dropout_rate_risk, prep_rate)

        benefit_per_month = benefit(new_data)

        reward_per_month = benefit_per_month - cost_per_month

        # NOTE(review): total_reward is accumulated but does not feed the returned rewards
        total_reward += reward_per_month

        total_cost += cost_per_month

        total_inf += new_inf_per_month

        data_array = new_data.copy()

    new_pop_dist = aging(data_array, pop_susceptible_12_years*(1+pop_growth_rate)) # adding new pop

    new_state1,new_state2,new_state3,new_state4,new_state5,new_state6,new_state7,new_state8 = extract_state(new_pop_dist, prep_rate)

    next_state = (new_pop_dist,new_state1,new_state2,new_state3,new_state4,new_state5,new_state6,new_state7,new_state8, prep_rate, current_time+1)

    # sum new infections over axis 2 (kept as a length-1 axis)
    total_inf = np.apply_over_axes(np.sum, total_inf, [2])

    # reward = -(new infections) - (cost above the cluster's budget)
    cluster_indices = (cluster1_index, cluster2_index, cluster3_index, cluster4_index,
                       cluster5_index, cluster6_index, cluster7_index, cluster8_index)
    rewards = []
    for idx, budget in zip(cluster_indices, budgets):
        reward = -np.sum(total_inf[idx,:])
        spent = np.sum(total_cost[idx,:])
        if spent > budget:
            reward -= (spent - budget)
        rewards.append(reward)

    total_cost = np.sum(total_cost, axis=1)

    if current_time+1 == 12:
        done = True

    _, _, u_p, _, a_art_vls_p, vls_p = \
        calculate_proportions(data_array, num_jur, number_of_risk_groups, unaware_index, aware_no_care_index, ART_VLS_index, VLS_index)

    return (next_state, *rewards, done, total_inf, diagnosis_rate_risk, dropout_rate_risk,
            u_p, a_art_vls_p, vls_p, prep_rate, total_cost)

In [None]:
print("============================================================================================")

env_name = "HIV Jurisdiction"
has_continuous_action_space = True
max_ep_len = 1000           # upper bound on timesteps per episode; test_step can set done earlier
action_std = 0.05           # action std handed to PPO (presumably the Gaussian policy's std)

total_test_episodes = 10    # total num of testing episodes

K_epochs = 20               # update policy for K epochs
eps_clip = 0.2              # clip parameter for PPO
gamma_ = 0.99               # discount factor; trailing underscore keeps the model's `gamma` array intact

lr_actor = 0.0003           # learning rate for actor
lr_critic = 0.0003          # learning rate for critic

#####################################################

# state space dimension
state_dim = 15

# action space dimension
if has_continuous_action_space:
    action_dim = 9
else:
    # This notebook has no gym-style `env` object to read a discrete action
    # count from, so fail with an explicit message instead of a NameError.
    raise NotImplementedError(
        "Discrete action spaces are not supported: no `env` is defined to size the action space."
    )

# One PPO agent per cluster (cluster1 .. cluster8); test_step takes one action per cluster.
num_clusters = 8

# Initialize the agents in cluster order (keeps torch RNG consumption order unchanged).
ppo_agents = [
    PPO(state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs, eps_clip,
        has_continuous_action_space, action_std)
    for _ in range(num_clusters)
]
# Numbered names are kept because later cells refer to them individually.
(ppo_agent1, ppo_agent2, ppo_agent3, ppo_agent4,
 ppo_agent5, ppo_agent6, ppo_agent7, ppo_agent8) = ppo_agents

random_seed = 10           #### set this to load a particular checkpoint trained on random seed
# `is not None` so that a seed of 0 is still applied.
if random_seed is not None:
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)

run_num_pretrained = 0      #### set this to load a particular checkpoint num

# preTrained weights directory. Files are named like
# "PPO_preTrained/HIV Jurisdiction/Cluster1PPO_HIV Jurisdiction_10_0.pth".
directory = "PPO_preTrained" + '/' + env_name + '/'
checkpoint_paths = [
    directory + 'Cluster{}'.format(k) + "PPO_{}_{}_{}.pth".format(env_name, random_seed, run_num_pretrained)
    for k in range(1, num_clusters + 1)
]
(checkpoint_path1, checkpoint_path2, checkpoint_path3, checkpoint_path4,
 checkpoint_path5, checkpoint_path6, checkpoint_path7, checkpoint_path8) = checkpoint_paths

# Load each cluster's pretrained weights into its agent.
for agent, ckpt_path in zip(ppo_agents, checkpoint_paths):
    print("loading network from : " + ckpt_path)
    agent.load(ckpt_path)

print("--------------------------------------------------------------------------------------------")

# Roll out the eight loaded agents jointly through test_step and record the
# per-step outputs that the analysis cells consume.
test_running_reward = 0
reward_list = []

diag_list = []
drop_list = []
prep_list = []
inf_list = []
cost_list = []
unaware_list = []
art_vls_list = []
vls_list = []

agents = [ppo_agent1, ppo_agent2, ppo_agent3, ppo_agent4,
          ppo_agent5, ppo_agent6, ppo_agent7, ppo_agent8]
num_agents = len(agents)

# One list of episode returns per cluster. The numbered names alias the same
# list objects, so appending via episode_reward_lists updates them too.
episode_reward_lists = [[] for _ in range(num_agents)]
(episode_reward_list1, episode_reward_list2, episode_reward_list3, episode_reward_list4,
 episode_reward_list5, episode_reward_list6, episode_reward_list7, episode_reward_list8) = episode_reward_lists

# The analysis cells read inf_list[0..11] as a single episode, so exactly one
# episode is evaluated here (total_test_episodes is deliberately not used).
num_eval_episodes = 1

for ep in range(num_eval_episodes):
    state = initial_state(data_array_cluster, prep_values)
    current_ep_rewards = [0] * num_agents

    for _ in range(1, max_ep_len + 1):
        # Actions are drawn in cluster order so the RNG sampling sequence is fixed.
        # NOTE(review): agent k reads state[k + 1]; confirm state[0] is meant to be skipped.
        actions = [agent.select_action(state[k + 1].flatten()) for k, agent in enumerate(agents)]

        (state, *rewards, done, total_inf, diag_rate, drop_rate,
         un_prop, art_prop, vls_prop, prep_rate, total_cost) = test_step(state, *actions)
        assert len(rewards) == num_agents, "test_step returned an unexpected number of rewards"

        diag_list.append(diag_rate)
        drop_list.append(drop_rate)
        prep_list.append(prep_rate)
        inf_list.append(total_inf)
        cost_list.append(total_cost)
        unaware_list.append(un_prop)
        art_vls_list.append(art_prop)
        vls_list.append(vls_prop)

        for k, reward in enumerate(rewards):
            current_ep_rewards[k] += reward

        if done:
            break

    # No policy update happens during evaluation; empty each agent's buffer
    # (presumably filled by select_action).
    for agent in agents:
        agent.buffer.clear()

    for k, ep_return in enumerate(current_ep_rewards):
        episode_reward_lists[k].append(ep_return)

    # Only cluster 1's return is accumulated here — TODO confirm whether all clusters were intended.
    test_running_reward += current_ep_rewards[0]
    for cluster_id, ep_return in enumerate(current_ep_rewards, start=1):
        print('Episode: {} \t\t Cluster: {} \t\t Reward: {}'.format(ep, cluster_id, round(ep_return, 2)))

    # Numbered per-cluster returns, kept for later cells.
    (current_ep_reward1, current_ep_reward2, current_ep_reward3, current_ep_reward4,
     current_ep_reward5, current_ep_reward6, current_ep_reward7, current_ep_reward8) = current_ep_rewards


print("============================================================================================")

In [None]:
# Re-shape the per-step infection outputs into one monthly series per risk
# group for every jurisdiction: new_inf_dict[name][r][month].
# NOTE(review): risk index r presumably follows `risk_groups` (HM, HF, MSM), and
# the trailing [0] selects the first entry of the innermost axis of total_inf — confirm.
months_per_episode = 12  # test_step sets done when current_time+1 == 12; assumes inf_list holds that episode

new_inf_dict = {}
for i in range(num_jur):
    new_inf_dict[jur_name[i]] = [
        [inf_list[month][i][r][0] for month in range(months_per_episode)]
        for r in range(number_of_risk_groups)
    ]

In [None]:
# Total new infections per jurisdiction per month: element-wise sum of the
# three risk-group series stored in new_inf_dict (first, second, third).
total_inf_dict = {}

for jur in jur_name:
    per_group = new_inf_dict[jur]
    monthly_total = np.array(per_group[0]) + np.array(per_group[1]) + np.array(per_group[2])
    total_inf_dict[jur] = list(monthly_total)