In [1]:
using DataFrames, Distributions, CSV, Turing, Optim, StatsBase, StatsFuns, Random

# Seed the global random number generator.
Random.seed!(1234)

# Read the vitamin D dataset
data = CSV.read("../../data/vitd/nhanes.csv", DataFrame)

# Shift RIAGENDR down by one (presumably from a 1/2 coding to 0/1 — confirm
# against the data dictionary). DPQ020 is left unchanged here.
data[!, :RIAGENDR] = data[!, :RIAGENDR] .- 1

# Look at the first few rows of the data
first(data, 5)

Row,BMXBMI,RIAGENDR,RIDAGEYR,LBXVIDMS,INDFMMPI,DPQ020,SMQ040
Unnamed: 0_level_1,Float64,Float64,Float64,Float64,Float64,Float64,Float64
1,27.8,0.0,62.0,76.1,4.14,0.0,3.0
2,30.8,0.0,53.0,56.5,0.0,0.0,1.0
3,28.8,0.0,78.0,87.5,1.81,0.0,3.0
4,28.0,0.0,22.0,47.2,2.98,0.0,2.0
5,27.6,0.0,46.0,44.5,1.73,0.0,1.0


In [2]:
#= Convert categorical data to integers =#
for col in (:RIAGENDR, :DPQ020, :SMQ040)
    data[!, col] = convert(Array{Int}, data[!, col])
end

# Get variables (each one is the DataFrame's own column vector, not a copy)
bmi, gender, age = data.BMXBMI, data.RIAGENDR, data.RIDAGEYR
vitd, poverty, smoking = data.LBXVIDMS, data.INDFMMPI, data.SMQ040

# Collapse DPQ020 to a binary indicator in place: codes 0 and 1 become 0,
# codes 2 and 3 become 1.
# NOTE: this mutates `data`; running the cell a second time would map the
# freshly written 1s back to 0.
replace!(data.DPQ020, 0 => 0, 1 => 0, 2 => 1, 3 => 1)
depression = data.DPQ020;

In [3]:
# Count how often each value of the binary depression indicator occurs.
countmap(depression)

Dict{Int64, Int64} with 2 entries:
  0 => 1453
  1 => 175

In [4]:
# Make sure the poverty column is a plain Vector{Float64}; this is the vector
# later passed to `depression_model`.
poverty_array = convert(Vector{Float64}, poverty);

In [5]:
# Turing Model OPTIM
#
# Generative model for depression given vitamin D and its covariates.
# Every coefficient below is a fixed numeric constant, not a random variable
# (presumably point estimates obtained with Optim — confirm). The only
# quantities left for a sampler are therefore the entries of arguments that
# are passed as `missing`.
#
# Arguments are per-person vectors: `bmi`, `gender` (0/1), `age`, `vitd`,
# `poverty`, `depression` (0/1) and `smoking` (category 1, 2 or 3).
# If their lengths differ, only the first `min(length...)` entries are modelled.
@model function depression_model(bmi, gender, age, vitd, poverty, depression, smoking)
    # Number of individuals: the shortest input decides how many rows are used
    n = min(length(bmi), length(gender), length(age), length(vitd), length(poverty), length(depression), length(smoking))

    # Fixed coefficient values (nothing here is sampled)
    μ_bmi = 3.859643431931376
    β_age_bmi = -0.021245195124222155
    β_smoking_bmi = 0.2608648954109639
    σ_bmi = 3.598419252059555e-7
    α_age_1 = 31.76691197040743
    α_age_2 = 51.5380975623051
    α_vitd = 4.11474524081595
    β_smoking_vitd = -0.19495145263217273
    β_age_vitd = -0.13587552554192822
    β_gender_vitd = -0.25606847088071794
    β_bmi_vitd = 0.3318863296959625
    σ_vitd = 1.2299204116924695e-11
    α_depression = -0.008191236513563314
    β_vitd_depression = -0.109039388026058
    β_bmi_depression = 0.2817057374804185
    β_age_depression_1 = 0.14305384978470523
    β_age_depression_2 = 0.1940741920532467
    β_age_depression_3 = 0.011078412341716712
    β_poverty_depression = -0.03931861435928967
    β_gender_depression = 0.18639727014944052
    β_smoking_depression = -0.17446385090257444

    # The covariate distributions are the same for every individual, so they
    # are built once instead of on every loop iteration.
    gender_dist = Bernoulli(0.4)
    age_dist = MixtureModel([Normal(α_age_1, 7.2), Normal(α_age_2, 8.7)], [0.4, 0.6])
    smoking_dist = Categorical([0.35, 0.12, 0.53])
    poverty_dist = MixtureModel([Normal(1, 2), Normal(5, 0.001)], [0.5, 0.5])

    for i in 1:n
        # Exogenous covariates
        gender[i] ~ gender_dist
        age[i] ~ age_dist
        smoking[i] ~ smoking_dist
        poverty[i] ~ poverty_dist

        # BMI depends on age and smoking
        bmi[i] ~ LogNormal(μ_bmi + β_age_bmi * age[i] + β_smoking_bmi * smoking[i], σ_bmi)

        # Vitamin D depends on smoking, age, gender and BMI
        vitd[i] ~ LogNormal(
            α_vitd + β_smoking_vitd * smoking[i] + β_age_vitd * age[i] +
            β_gender_vitd * gender[i] + β_bmi_vitd * bmi[i],
            σ_vitd
        )

        # The age coefficient is piecewise: one value below 30, another from 30
        # up to (not including) 50, and a third otherwise. Everything else in
        # the linear predictor is shared across the three age bands.
        β_age_depression = if age[i] < 30
            β_age_depression_1
        elseif age[i] < 50
            β_age_depression_2
        else
            β_age_depression_3
        end

        linear_predictor = α_depression + vitd[i] * β_vitd_depression +
            bmi[i] * β_bmi_depression + age[i] * β_age_depression +
            poverty[i] * β_poverty_depression + gender[i] * β_gender_depression +
            smoking[i] * β_smoking_depression

        # Binary depression outcome through a logistic link
        depression[i] ~ Bernoulli(logistic(linear_predictor))
    end
end


depression_model (generic function with 2 methods)

In [6]:
# Pass every depression outcome as `missing` so that it is sampled rather than
# observed, then draw 6 threaded chains of 1000 iterations each with PG(10).
model_depr = depression_model(
    bmi, gender, age, vitd, poverty_array, fill(missing, length(bmi)), smoking
)
chain_depr = sample(model_depr, PG(10), MCMCThreads(), 1000, 6)

[32mSampling (6 threads)   0%|                              |  ETA: N/A[39m
[32mSampling (6 threads)  17%|█████                         |  ETA: 5:50:02[39m
[32mSampling (6 threads)  33%|██████████                    |  ETA: 2:20:20[39m
[32mSampling (6 threads)  50%|███████████████               |  ETA: 1:10:10[39m
[32mSampling (6 threads)  67%|████████████████████          |  ETA: 0:35:05[39m
[32mSampling (6 threads)  83%|█████████████████████████     |  ETA: 0:14:02[39m
[32mSampling (6 threads) 100%|██████████████████████████████| Time: 3:08:15[39m
[90mSampling (6 threads) 100%|██████████████████████████████| Time: 3:08:15[39m


Chains MCMC chain (1000×1630×6 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 6
Samples per chain = 1000
Wall duration     = 11283.11 seconds
Compute duration  = 32222.32 seconds
parameters        = depression[1], depression[2], depression[3], depression[4], depression[5], depression[6], depression[7], depression[8], depression[9], depression[10], depression[11], depression[12], depression[13], depression[14], depression[15], depression[16], depression[17], depression[18], depression[19], depression[20], depression[21], depression[22], depression[23], depression[24], depression[25], depression[26], depression[27], depression[28], depression[29], depression[30], depression[31], depression[32], depression[33], depression[34], depression[35], depression[36], depression[37], depression[38], depression[39], depression[40], depression[41], depression[42], depression[43], depression[44], depression[45], depression[46], depression[47], depression[48], depression[49], dep

In [41]:
# Mean of the sampled depression indicator for each individual, taken over
# everything stored for that parameter in `chain_depr`.
depression_days = Float64[mean(chain_depr["depression[$i]"]) for i in eachindex(depression)]

# Classification threshold: the 88th percentile of those means.
cutoff = quantile(depression_days, 0.88)

# Label an individual 1 when their mean reaches the cutoff, 0 otherwise.
# This must be `>=`, not `==`: an exact floating-point match with an
# interpolated quantile would label values above the cutoff as 0.
depression_days_adjusted = Int.(depression_days .>= cutoff)

countmap(depression_days_adjusted)

Dict{Int64, Int64} with 2 entries:
  0 => 1324
  1 => 304

In [8]:
# Number of individuals whose thresholded prediction equals the observed label,
# and the corresponding fraction.
total = count(i -> depression_days_adjusted[i] == depression[i], 1:length(depression))
good_rat = total / length(depression)

0.8925061425061425

In [9]:
# Intervention subgroup: rows with DPQ020 == 1 whose vitamin D value lies below
# the 20th percentile of the whole dataset. The percentile is computed once
# here rather than once per row inside the predicate.
low_vitd_cutoff = quantile(data.LBXVIDMS, 0.2)
intervention_data = filter(row -> row.DPQ020 == 1 && row.LBXVIDMS < low_vitd_cutoff, data);

In [10]:
#= Convert categorical data to integers =#
for col in (:RIAGENDR, :DPQ020, :SMQ040)
    intervention_data[!, col] = convert(Array{Int}, intervention_data[!, col])
end

# Get variables (column vectors of the intervention subgroup, not copies)
bmi_interv = intervention_data.BMXBMI
gender_interv = intervention_data.RIAGENDR
age_interv = intervention_data.RIDAGEYR
vitd_interv = intervention_data.LBXVIDMS
poverty_interv = intervention_data.INDFMMPI
depression_interv = intervention_data.DPQ020
smoking_interv = intervention_data.SMQ040

# Poverty as a plain Vector{Float64} for the model
poverty_array_interv = convert(Vector{Float64}, poverty_interv);

In [11]:
# Intervention run: every individual's vitamin D is set to the 80th percentile
# of the full-sample `vitd`, and depression is passed as `missing` so it is sampled.
model_depr_interv = depression_model(
    bmi_interv,
    gender_interv,
    age_interv,
    fill(quantile(vitd, 0.8), length(bmi_interv)),
    poverty_array_interv,
    fill(missing, length(bmi_interv)),
    smoking_interv
)
chain_depr_interv = sample(model_depr_interv, PG(10), 1000)

[32mSampling   0%|                                          |  ETA: N/A[39m
[32mSampling   0%|▎                                         |  ETA: 0:05:06[39m
[32mSampling   1%|▍                                         |  ETA: 0:04:14[39m
[32mSampling   2%|▋                                         |  ETA: 0:03:56[39m
[32mSampling   2%|▉                                         |  ETA: 0:03:47[39m
[32mSampling   2%|█                                         |  ETA: 0:03:41[39m
[32mSampling   3%|█▎                                        |  ETA: 0:03:36[39m
[32mSampling   4%|█▌                                        |  ETA: 0:03:33[39m
[32mSampling   4%|█▋                                        |  ETA: 0:03:30[39m
[32mSampling   4%|█▉                                        |  ETA: 0:03:28[39m
[32mSampling   5%|██▏                                       |  ETA: 0:03:26[39m
[32mSampling   6%|██▎                                       |  ETA: 0:03:24[39m
[32mSampling   6%|█

Chains MCMC chain (1000×39×1 Array{Float64, 3}):

Log evidence      = -2.704418444328565e24
Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 207.41 seconds
Compute duration  = 207.41 seconds
parameters        = depression[1], depression[2], depression[3], depression[4], depression[5], depression[6], depression[7], depression[8], depression[9], depression[10], depression[11], depression[12], depression[13], depression[14], depression[15], depression[16], depression[17], depression[18], depression[19], depression[20], depression[21], depression[22], depression[23], depression[24], depression[25], depression[26], depression[27], depression[28], depression[29], depression[30], depression[31], depression[32], depression[33], depression[34], depression[35], depression[36], depression[37]
internals         = lp, logevidence

Summary Statistics
 [1m     parameters [0m [1m    mean [0m [1m     std [0m [1m naive_se [0m [1m    mcse [0m [1m 

In [15]:
# Show the first five rows again to check the column types after the integer conversion.
first(data, 5)

Row,BMXBMI,RIAGENDR,RIDAGEYR,LBXVIDMS,INDFMMPI,DPQ020,SMQ040
Unnamed: 0_level_1,Float64,Int64,Float64,Float64,Float64,Int64,Int64
1,27.8,0,62.0,76.1,4.14,0,3
2,30.8,0,53.0,56.5,0.0,0,1
3,28.8,0,78.0,87.5,1.81,0,3
4,28.0,0,22.0,47.2,2.98,0,2
5,27.6,0,46.0,44.5,1.73,0,1


In [26]:
# Counterfactual subgroup: rows with DPQ020 == 1 whose vitamin D value lies
# strictly between the 80th and 90th percentiles of the whole dataset. Both
# percentiles are computed once instead of once per row inside the predicate.
vitd_q80 = quantile(data.LBXVIDMS, 0.8)
vitd_q90 = quantile(data.LBXVIDMS, 0.9)
counterfact_data = filter(row -> row.DPQ020 == 1 && vitd_q80 < row.LBXVIDMS < vitd_q90, data)

Row,BMXBMI,RIAGENDR,RIDAGEYR,LBXVIDMS,INDFMMPI,DPQ020,SMQ040
Unnamed: 0_level_1,Float64,Int64,Float64,Float64,Float64,Int64,Int64
1,28.2,0,69.0,94.5,1.01,1,3
2,19.1,1,47.0,85.4,0.99,1,1
3,24.5,1,66.0,88.6,1.25,1,3
4,36.8,0,59.0,90.0,1.86,1,1
5,27.1,1,48.0,86.7,4.16,1,1
6,36.8,1,51.0,89.0,0.2,1,1
7,34.4,0,64.0,92.9,3.75,1,2
8,22.4,1,48.0,94.6,0.95,1,1
9,35.3,1,62.0,91.0,0.89,1,3
10,20.2,1,77.0,93.5,1.66,1,1


In [27]:
# Convert categorical data to integers
for col in (:RIAGENDR, :DPQ020, :SMQ040)
    counterfact_data[!, col] = convert(Array{Int}, counterfact_data[!, col])
end

# Get variables for counterfactual analysis (column vectors, not copies)
bmi_counterfact = counterfact_data.BMXBMI
gender_counterfact = counterfact_data.RIAGENDR
age_counterfact = counterfact_data.RIDAGEYR
vitd_counterfact = counterfact_data.LBXVIDMS
poverty_counterfact = counterfact_data.INDFMMPI
depression_counterfact = counterfact_data.DPQ020
smoking_counterfact = counterfact_data.SMQ040

# Poverty as a plain Vector{Float64} for the model
poverty_array_counterfact = convert(Vector{Float64}, poverty_counterfact);

In [28]:
# Counterfactual run: every individual's vitamin D is set to the 95th
# percentile of the full-sample `vitd`, and depression is passed as `missing`
# so it is sampled. The `missing` vector is sized to the counterfactual
# subgroup itself, matching the other arguments, rather than to the full dataset.
counterfact_model_depr = depression_model(
    bmi_counterfact,
    gender_counterfact,
    age_counterfact,
    fill(quantile(vitd, 0.95), length(bmi_counterfact)),
    poverty_array_counterfact,
    fill(missing, length(bmi_counterfact)),
    smoking_counterfact
)
counterfact_chain_depr = sample(counterfact_model_depr, PG(10), 1000)

[32mSampling   0%|                                          |  ETA: N/A[39m
[32mSampling   1%|▍                                         |  ETA: 0:01:31[39m
[32mSampling   2%|▉                                         |  ETA: 0:01:46[39m
[32mSampling   3%|█▎                                        |  ETA: 0:01:41[39m
[32mSampling   4%|█▋                                        |  ETA: 0:01:37[39m
[32mSampling   5%|██▏                                       |  ETA: 0:01:34[39m
[32mSampling   6%|██▌                                       |  ETA: 0:01:31[39m
[32mSampling   7%|███                                       |  ETA: 0:01:29[39m
[32mSampling   8%|███▍                                      |  ETA: 0:01:27[39m
[32mSampling   9%|███▊                                      |  ETA: 0:01:26[39m
[32mSampling  10%|████▎                                     |  ETA: 0:01:24[39m
[32mSampling  11%|████▋                                     |  ETA: 0:01:23[39m
[32mSampling  12%|█

Chains MCMC chain (100×17×1 Array{Float64, 3}):

Log evidence      = -6.980184644688326e23
Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
Wall duration     = 88.7 seconds
Compute duration  = 88.7 seconds
parameters        = depression[1], depression[2], depression[3], depression[4], depression[5], depression[6], depression[7], depression[8], depression[9], depression[10], depression[11], depression[12], depression[13], depression[14], depression[15]
internals         = lp, logevidence

Summary Statistics
 [1m     parameters [0m [1m    mean [0m [1m     std [0m [1m naive_se [0m [1m    mcse [0m [1m      ess [0m [1m    rhat[0m ⋯
 [90m         Symbol [0m [90m Float64 [0m [90m Float64 [0m [90m  Float64 [0m [90m Float64 [0m [90m  Float64 [0m [90m Float64[0m ⋯

   depression[1]    0.0500    0.2190     0.0219    0.0224   117.0884    0.9920 ⋯
   depression[2]    0.9200    0.2727     0.0273    0.0249   126.5885    0.9953 ⋯
   depression[3]   

In [93]:
# Baseline run for the intervention subgroup: the observed vitamin D values are
# used as they are, and depression is passed as `missing` so it is sampled.
interv_data_no_intervention_model = depression_model(
    bmi_interv,
    gender_interv,
    age_interv,
    vitd_interv,
    poverty_array_interv,
    fill(missing, length(bmi_interv)),
    smoking_interv
)
chain_depr_interv_no_interv = sample(interv_data_no_intervention_model, PG(10), 1000)

[32mSampling   0%|                                          |  ETA: N/A[39m
[32mSampling   0%|▎                                         |  ETA: 0:04:22[39m
[32mSampling   1%|▍                                         |  ETA: 0:04:12[39m
[32mSampling   2%|▋                                         |  ETA: 0:03:56[39m
[32mSampling   2%|▉                                         |  ETA: 0:03:48[39m
[32mSampling   2%|█                                         |  ETA: 0:04:09[39m
[32mSampling   3%|█▎                                        |  ETA: 0:04:04[39m
[32mSampling   4%|█▌                                        |  ETA: 0:04:01[39m
[32mSampling   4%|█▋                                        |  ETA: 0:04:10[39m
[32mSampling   4%|█▉                                        |  ETA: 0:04:12[39m
[32mSampling   5%|██▏                                       |  ETA: 0:04:06[39m
[32mSampling   6%|██▎                                       |  ETA: 0:04:00[39m
[32mSampling   6%|█

Chains MCMC chain (1000×39×1 Array{Float64, 3}):

Log evidence      = -3.809874428883141e24
Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 221.11 seconds
Compute duration  = 221.11 seconds
parameters        = depression[1], depression[2], depression[3], depression[4], depression[5], depression[6], depression[7], depression[8], depression[9], depression[10], depression[11], depression[12], depression[13], depression[14], depression[15], depression[16], depression[17], depression[18], depression[19], depression[20], depression[21], depression[22], depression[23], depression[24], depression[25], depression[26], depression[27], depression[28], depression[29], depression[30], depression[31], depression[32], depression[33], depression[34], depression[35], depression[36], depression[37]
internals         = lp, logevidence

Summary Statistics
 [1m     parameters [0m [1m    mean [0m [1m     std [0m [1m naive_se [0m [1m    mcse [0m [1m 

In [94]:
# Baseline run for the counterfactual subgroup: the observed vitamin D values
# are used as they are, and depression is passed as `missing` so it is sampled.
# The `missing` vector is sized to the counterfactual subgroup itself, matching
# the other arguments, rather than to the full dataset.
counterfact_data_no_interv_model = depression_model(
    bmi_counterfact,
    gender_counterfact,
    age_counterfact,
    vitd_counterfact,
    poverty_array_counterfact,
    fill(missing, length(bmi_counterfact)),
    smoking_counterfact
)
counterfact_chain_depr_no_intev = sample(counterfact_data_no_interv_model, PG(10), 1000)

[32mSampling   0%|                                          |  ETA: N/A[39m
[32mSampling   0%|▎                                         |  ETA: 0:01:38[39m
[32mSampling   1%|▍                                         |  ETA: 0:01:43[39m
[32mSampling   2%|▋                                         |  ETA: 0:01:38[39m
[32mSampling   2%|▉                                         |  ETA: 0:01:35[39m
[32mSampling   2%|█                                         |  ETA: 0:01:33[39m
[32mSampling   3%|█▎                                        |  ETA: 0:01:32[39m
[32mSampling   4%|█▌                                        |  ETA: 0:01:31[39m
[32mSampling   4%|█▋                                        |  ETA: 0:01:30[39m
[32mSampling   4%|█▉                                        |  ETA: 0:01:30[39m
[32mSampling   5%|██▏                                       |  ETA: 0:01:31[39m
[32mSampling   6%|██▎                                       |  ETA: 0:01:31[39m
[32mSampling   6%|█

Chains MCMC chain (1000×17×1 Array{Float64, 3}):

Log evidence      = -7.330051420789005e23
Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 93.7 seconds
Compute duration  = 93.7 seconds
parameters        = depression[1], depression[2], depression[3], depression[4], depression[5], depression[6], depression[7], depression[8], depression[9], depression[10], depression[11], depression[12], depression[13], depression[14], depression[15]
internals         = lp, logevidence

Summary Statistics
 [1m     parameters [0m [1m    mean [0m [1m     std [0m [1m naive_se [0m [1m    mcse [0m [1m       ess [0m [1m    rha[0m ⋯
 [90m         Symbol [0m [90m Float64 [0m [90m Float64 [0m [90m  Float64 [0m [90m Float64 [0m [90m   Float64 [0m [90m Float6[0m ⋯

   depression[1]    0.0970    0.2961     0.0094    0.0080    792.5261    0.999 ⋯
   depression[2]    0.9920    0.0891     0.0028    0.0028    810.8932    0.999 ⋯
   depression[3]

In [95]:
# Mean sampled depression per individual under the intervention ...
depression_days = Float64[mean(chain_depr_interv["depression[$i]"]) for i in eachindex(depression_interv)]
# ... and under the observed vitamin D values.
baseline_days = Float64[mean(chain_depr_interv_no_interv["depression[$i]"]) for i in eachindex(depression_interv)]

# Per-individual effect: intervention mean minus baseline mean.
ates = depression_days .- baseline_days

# Summary over individuals.
# NOTE(review): mean ± 1.96 * std describes the spread of the individual
# effects; an interval for the mean itself would use std / sqrt(n) — confirm
# which is intended.
mean_ate = mean(ates)
std_ate = std(ates)
ci_lower_ate = mean_ate - 1.96 * std_ate
ci_upper_ate = mean_ate + 1.96 * std_ate

(mean_ate, std_ate, ci_lower_ate, ci_upper_ate)

(-0.2213513513513514, 0.2942094129674961, -0.7980018007676437, 0.35529909806494087)

In [96]:
# Number of individuals in the intervention subgroup.
size(depression_interv)

(37,)

In [97]:
# Pool, over all individuals, the difference between every sampled depression
# value under the intervention and that individual's observed label.
# The accumulator has a concrete element type, and `append!` mutates it in
# place, so no global is rebound inside the loop.
ates = Float64[]

for i in eachindex(depression_interv)
    depression_sample = Array(chain_depr_interv["depression[$i]"])
    append!(ates, depression_sample .- depression_interv[i])
end

# Summary over all pooled differences (mean ± 1.96 standard deviations).
mean_ate = mean(ates)
std_ate = std(ates)
ci_lower_ate = mean_ate - 1.96 * std_ate
ci_upper_ate = mean_ate + 1.96 * std_ate

mean_ate, std_ate, ci_lower_ate, ci_upper_ate

(-0.22667567567567568, 0.4186866985364378, -1.0473016048070938, 0.5939502534557424)

In [104]:
# Per-individual effect in the counterfactual subgroup: mean sampled depression
# with vitamin D set to the intervention value minus the mean with the observed
# vitamin D value.
ates = Array{Float64}(undef, length(depression_counterfact))

for i in eachindex(ates)
    # Computed inline so that no global is assigned inside the loop; in a
    # notebook, assigning a scalar to `depression_days` here would overwrite
    # the global vector of that name.
    ates[i] = mean(counterfact_chain_depr["depression[$i]"]) -
        mean(counterfact_chain_depr_no_intev["depression[$i]"])
end

# Summary over individuals (mean ± 1.96 standard deviations).
mean_ate = mean(ates)
std_ate = std(ates)
ci_lower_ate = mean_ate - 1.96 * std_ate
ci_upper_ate = mean_ate + 1.96 * std_ate

mean_ate, std_ate, ci_lower_ate, ci_upper_ate

(-0.214, 0.19437592443510074, -0.5949768118927974, 0.16697681189279742)

In [99]:
# Number of individuals in the counterfactual subgroup.
size(depression_counterfact)

(15,)

In [105]:
# The full-sample vitamin D percentiles used as subgroup bounds and intervention values.
quantile(vitd, 0.8), quantile(vitd, 0.9), quantile(vitd, 0.95)

(84.5, 96.52999999999999, 108.0)