# Figure S2: Validation of age-dependent model inference

In [None]:
# Shared mESC analysis setup. `srcpath`, `datapath`, `θ_G1_S` and `θ_S_G2M` are
# read below without being defined in this cell, so presumably these included
# scripts define them — TODO confirm.
include("../analysis/mESC/load_analysis.jl")
include("../analysis/mESC/filter_prior.jl")

# Point at the age-dependent model sources and the directory of its fits,
# then load that model's distribution and maximum-likelihood code.
srcpath = normpath(srcpath*"../age-dependent/")
fitpath = datapath*"fits_age-dependent/"
include(srcpath*"dists.jl")
include(srcpath*"mle.jl")

# fix the replication point to be in the middle of the S phase
# (midpoint between θ_G1_S and θ_S_G2M; presumably the G1/S and S/G2M
# transition points defined by the included scripts — confirm)
θᵣ = θ_G1_S + (θ_S_G2M - θ_G1_S)/2

# Fits of the age-dependent model on the real data. filter_post.jl is included
# only after this load, presumably because it uses `fits_main` — TODO confirm.
fits_main = load(fitpath*"fits_main.jld2", "fits_main")
include("../analysis/mESC/filter_post.jl")
# From here on `datapath` is the (normalized) age-dependent fit directory, i.e.
# the same location as `fitpath` unless an include above changed `datapath`;
# the synthetic-validation data below is loaded from it.
datapath = normpath(datapath*"fits_age-dependent/");

# Supplementary Figure S2: synthetic validation

In [102]:
# Burst frequency and burst size of every fitted model in `fits_main`, in G1 and
# in G2/M, followed by their G2/M-to-G1 ratios (the "true" Qf and Qb used in
# the validation panels).
burst_freqs_G1_th_dep = map(get_burst_frequency_G1, fits_main)
burst_freqs_G2M_th_dep = map(get_burst_frequency_G2M, fits_main)
burst_sizes_G1_th_dep = map(m -> get_burst_size_G1(m, thetaG1), fits_main)
burst_sizes_G2M_th_dep = map(m -> get_burst_size_G2M(m, thetaG2M), fits_main)
ratio_f_th_dep = @. burst_freqs_G2M_th_dep / burst_freqs_G1_th_dep
ratio_b_th_dep = @. burst_sizes_G2M_th_dep / burst_sizes_G1_th_dep;

In [103]:
# Genes selected for the synthetic validation (indices into `fits_main`) and,
# for each of them, the collection of models refitted to its synthetic data.
inds = load(datapath*"synthetic_data_gene_inds.jld2", "inds")
all_fits = load(datapath*"synthetic_data_fits.jld2", "all_fits");

In [None]:
# Number of validation genes, and number of refits (counted on the first gene)
n_examples = length(inds)
n_samples = length(first(all_fits))

In [None]:
# True parameters of the selected genes: one column per gene, one row per model
# parameter. `reduce(hcat, …)` avoids splatting a gene-count-long argument list
# into `hcat` (which forces a compile per argument count).
true_ps = reduce(hcat, [collect(params(m)) for m in fits_main[inds]])

In [106]:
# Per-gene summary of the refitted parameters: for each parameter, the median
# and the 25%/75% quantiles across the refits of that gene.
# Refits whose inference failed (flagged by a zero parameter, the same criterion
# used further down to detect failures) are skipped here, so these statistics do
# not depend on whether the cell that deletes the failed fit has already run.
ys_median = Vector{Vector{Float64}}(undef, n_examples)
ys_Q1 = similar(ys_median)
ys_Q3 = similar(ys_median)

for i in 1:n_examples
    ok_fits = [m for m in all_fits[i] if !any(iszero, params(m))]
    # one column per refit, one row per parameter
    refit_ps = reduce(hcat, [collect(params(m)) for m in ok_fits])
    ys_median[i] = median.(eachrow(refit_ps))
    ys_Q1[i] = quantile.(eachrow(refit_ps), 0.25)
    ys_Q3[i] = quantile.(eachrow(refit_ps), 0.75)
end

In [None]:
"""
    plot_param(ax::Axis, i::Int; true_vals=true_ps, medians=ys_median,
               lower=ys_Q1, upper=ys_Q3, color=c1)

Draw a true-vs-observed validation panel for parameter `i` on `ax` and return `ax`.

The x values are row `i` of `true_vals` (one column per gene). The y values are the
`i`-th entries of the per-gene `medians`, with error bars spanning the per-gene
`lower`/`upper` values (the interquartile range in this notebook). A horizontal
line at y = 0 and a dashed identity line are added. Both axes share limits that
cover the data, with each bound moved outward by 10% of its magnitude.

The keywords default to the notebook globals, so `plot_param(ax, i)` behaves as
before; pass other summaries (or another `color`) to reuse the panel layout.
"""
function plot_param(ax::Axis, i::Int; true_vals=true_ps, medians=ys_median,
                    lower=ys_Q1, upper=ys_Q3, color=c1)
    x = true_vals[i, :]
    y = getindex.(medians, i)
    y_Q1 = getindex.(lower, i)
    y_Q3 = getindex.(upper, i)
    y_err_Q1 = y .- y_Q1
    y_err_Q3 = y_Q3 .- y

    # Shared limits: cover the true values and the error bars, then move each
    # bound outward by 10% of its magnitude (sign decides the factor).
    xi = min(minimum(x), minimum(y_Q1))
    xi = xi < 0 ? xi*1.1 : xi*0.9
    xf = max(maximum(x), maximum(y_Q3))
    xf = xf < 0 ? xf*0.9 : xf*1.1

    errorbars!(ax, x, y, y_err_Q1, y_err_Q3, whiskerwidth=3, linewidth=0.6,
               color=(color, 0.4), direction=:y)
    scatter!(ax, x, y, markersize=3, color=(color, 0.7))
    hlines!(ax, 0, color=(:black, 0.2), linewidth=0.5)
    # Identity line (observed == true), starting at 0 or below. Two endpoints
    # suffice; unlike a 0.01-step range, the line also reaches xf exactly.
    lo = min(xi, 0)
    lines!(ax, [lo, xf], [lo, xf], color=(:black, 0.4), linestyle=:dash, linewidth=0.5)
    xlims!(ax, xi, xf)
    ylims!(ax, xi, xf)

    return ax
end

In [None]:
# Flag refits whose inference failed. For every validation gene, `x` holds the
# positions of its refits that have at least one zero parameter; `rinds` lists
# the genes with at least one such refit.
x = [findall(m -> any(iszero, params(m)), fits) for fits in all_fits]
rinds = findall(!isempty, x)
rinds

In [None]:
# Positions of the refits of the first flagged gene whose G1 burst frequency is zero
findall(iszero ∘ get_burst_frequency_G1, all_fits[rinds[1]])

In [None]:
# Inspect the failed refits (all-zero parameters indicate that inference
# failed). They are recomputed from the data rather than addressed by hard-coded
# (gene, refit) indices: such indices are specific to the current data, and
# would show a valid fit once the failed one has been deleted.
[m for fits in all_fits for m in fits if any(iszero, params(m))]

In [111]:
# Remove the refits whose inference failed, flagged by a zero parameter (the
# same criterion used to detect them above). In the current data this is a
# single fit, with little impact on the interquartile range statistics.
# `filter!` is idempotent: re-running this cell cannot drop a valid fit, as
# `deleteat!` with hard-coded indices would.
foreach(fits -> filter!(m -> !any(iszero, params(m)), fits), all_fits);

In [None]:
f = Figure(size = (size_pt[1]*3.0, size_pt[2]*3.0), figure_padding = 1)
ga = GridLayout(f[1,1])

# One validation panel per model parameter. Each entry is:
# (grid row, grid column, row of the parameter in `true_ps`, label,
#  shared (low, high) axis limits, or `nothing` to keep plot_param's limits).
# Row 1 holds f₁, f₂, ρ₁, ρ₂; row 2 holds β₁–β₄.
ax11, ax12, ax13, ax14, ax21, ax22, ax23, ax24 = map([
    (1, 1, 1, "f₁", (0, 3.4)),
    (1, 2, 2, "f₂", (0, 1.5)),
    (1, 3, 3, "ρ₁", (0, 6.5)),
    (1, 4, 4, "ρ₂", (0, 4)),
    (2, 1, 5, "β₁", nothing),
    (2, 2, 6, "β₂", nothing),
    (2, 3, 7, "β₃", (-3, 15)),
    (2, 4, 8, "β₄", (-3, 11)),
]) do (row, col, k, name, lims)
    panel = plot_param(Axis(ga[row, col], xlabel = "True " * name, ylabel = "Observed " * name), k)
    if !isnothing(lims)
        lo, hi = lims
        xlims!(panel, low = lo, high = hi)
        ylims!(panel, low = lo, high = hi)
    end
    panel
end

rowgap!(ga, 7)
colgap!(ga, 8)

f

In [113]:
# G2/M-to-G1 ratios of burst frequency (Qf) and burst size (Qb) for every refit:
# one vector per validation gene, one entry per refit of that gene.
vec_ratios_f = Vector{Vector{Float64}}(undef, n_examples)
map!(vec_ratios_f, view(all_fits, 1:n_examples)) do fits
    freqs_G1 = get_burst_frequency_G1.(fits)
    freqs_G2M = get_burst_frequency_G2M.(fits)
    freqs_G2M ./ freqs_G1
end

vec_ratios_b = Vector{Vector{Float64}}(undef, n_examples)
map!(vec_ratios_b, view(all_fits, 1:n_examples)) do fits
    sizes_G1 = get_burst_size_G1.(fits, Ref(thetaG1))
    sizes_G2M = get_burst_size_G2M.(fits, Ref(thetaG2M))
    sizes_G2M ./ sizes_G1
end

In [None]:
ax31 = Axis(ga[3,2], xlabel="True Qf", ylabel="Observed Qf")

# True burst-frequency ratio of each validation gene against the median of its
# refitted ratios, with interquartile error bars.
x = ratio_f_th_dep[inds]
y = map(median, vec_ratios_f)
y_Q1 = [quantile(r, 0.25) for r in vec_ratios_f]
y_Q3 = [quantile(r, 0.75) for r in vec_ratios_f]
y_err_Q1 = @. y - y_Q1
y_err_Q3 = @. y_Q3 - y
xi, xf = 0, 3.2

errorbars!(ax31, x, y, y_err_Q1, y_err_Q3, whiskerwidth=3, linewidth=0.6, color=(c1, 0.4), direction=:y)
scatter!(ax31, x, y, markersize=3, color=(c1, 0.7))
# dashed identity line (observed == true)
lines!(ax31, xi:0.001:xf, xi:0.001:xf, color=(:black, 0.4), linestyle=:dash, linewidth=0.5)
xlims!(ax31, xi, xf)
ylims!(ax31, xi, xf)

f

In [None]:
ax32 = Axis(ga[3,3], xlabel="True Qb", ylabel="Observed Qb")

# True burst-size ratio of each validation gene against the median of its
# refitted ratios, with interquartile error bars.
x = ratio_b_th_dep[inds]
y = map(median, vec_ratios_b)
y_Q1 = [quantile(r, 0.25) for r in vec_ratios_b]
y_Q3 = [quantile(r, 0.75) for r in vec_ratios_b]
y_err_Q1 = @. y - y_Q1
y_err_Q3 = @. y_Q3 - y
xi, xf = 0, 3.8

errorbars!(ax32, x, y, y_err_Q1, y_err_Q3, whiskerwidth=3, linewidth=0.6, color=(c1, 0.4), direction=:y)
scatter!(ax32, x, y, markersize=3, color=(c1, 0.7))
# dashed identity line (observed == true)
lines!(ax32, xi:0.001:xf, xi:0.001:xf, color=(:black, 0.4), linestyle=:dash, linewidth=0.5)
xlims!(ax32, xi, xf)
ylims!(ax32, xi, xf)

f

In [None]:
# Re-apply the 7-unit row gap now that the third row (Qf/Qb panels) exists;
# presumably the earlier `rowgap!` call does not cover rows added afterwards —
# TODO confirm against Makie's GridLayout behavior.
rowgap!(ga, 7)
f