# Figure S4: Alternative model with age-dependent burst frequency & fixed replication timing

In [None]:
# Run the shared mESC analysis setup and the prior filtering script.
# NOTE(review): the names used below without a definition in this notebook (srcpath, datapath,
# θ_G1_S, θ_S_G2M, load, ...) are presumably provided by these scripts — confirm.
include("../analysis/mESC/load_analysis.jl")
include("../analysis/mESC/filter_prior.jl")

# Redirect srcpath to the age-dependent model sources and build the path to the stored fits.
srcpath = normpath(srcpath*"../age-dependent/")
fitpath = normpath(datapath*"fits_age-dependent/")
# Pull in dists.jl and mle.jl from the redirected source directory.
include(srcpath*"dists.jl")
include(srcpath*"mle.jl")

# fix the replication point to be in the middle of the S phase
θᵣ = θ_G1_S + (θ_S_G2M - θ_G1_S)/2

# Main-model fits, read from the entry "fits_main" of fits_main.jld2.
fits_main = load(fitpath*"fits_main.jld2", "fits_main");

In [29]:
# Alternative-model fits, read from the entry "fits_alt2" of fits_alt_2_age-dep_bf.jld2.
fits_alt2 = load(fitpath*"fits_alt_2_age-dep_bf.jld2", "fits_alt2");

Perform a series of filtering steps for both the alternative age-dependent model and the main-text model, then keep only the genes that pass the filters of both models.

In [30]:
# --- Step 1 ---

# Drop every gene whose mean expression predicted by the alternative age-dependent model is
# negatively correlated with the cell age θ in either the G1 or the G2/M cell cycle phase.
# This is usually indicative of a bad model fit (due to larger deviations from the ratio
# of 2 between the cells in θ_f and in θ_i).

# inds[i] is true when the predicted mean of gene i correlates positively with θ in both phases
inds = Bool[
    let fit = fits_alt2[i], decay = decay_rates[i]
        predicted_G1 = [mean(fit, θ, T_cycle, decay, θᵣ, θ_G1_S, θ_S_G2M) for θ in θs_G1]
        predicted_G2M = [mean(fit, θ, T_cycle, decay, θᵣ, θ_G1_S, θ_S_G2M) for θ in θs_G2M]
        (cor(predicted_G1, θs_G1) > 0) & (cor(predicted_G2M, θs_G2M) > 0)
    end for i in 1:ngenes
]

# positions of the genes that pass this step
inds1 = findall(inds);

In [31]:
# --- Step 2 ---
# Remove genes that have a clearly bad fit -- indicated by a negative R^2 value for either the mean or the variance over the entire cell cycle

# Shared implementation of the two R² helpers below. `stat` (`mean` or `var`) is applied
# both to the binned counts (observed values) and to the fitted model `m` (predicted values).
function _compute_rsq_alt2(stat, ind::Int, m)
    # observed statistic in every θ bin, concatenated in the order G1, S, G2/M
    y = vcat(
        stat.([xG1[ind][_inds] for _inds in inds_θs_G1]),
        stat.([xS[ind][_inds] for _inds in inds_θs_S]),
        stat.([xG2M[ind][_inds] for _inds in inds_θs_G2M]),
    )

    # model prediction evaluated at every point of θs
    preds = stat.(Ref(m), θs, Ref(T_cycle), Ref(decay_rates[ind]), Ref(θᵣ), Ref(θ_G1_S), Ref(θ_S_G2M))

    # 1 - (residual sum of squares) / (total sum of squares); negative when the model
    # predicts worse than the constant mean(y). A zero total sum of squares is not guarded.
    return 1 - sum((y .- preds).^2) / sum((y .- mean(y)).^2)
end

"""
    compute_rsq_mean_alt2(ind::Int, m=fits_alt2[ind])

Return the R² between the binned sample means of gene `ind` and the mean predicted by
model `m` (by default the alternative-model fit of that gene) over all points in `θs`.
"""
compute_rsq_mean_alt2(ind::Int, m=fits_alt2[ind]) = _compute_rsq_alt2(mean, ind, m)

"""
    compute_rsq_var_alt2(ind::Int, m=fits_alt2[ind])

Return the R² between the binned sample variances of gene `ind` and the variance predicted
by model `m` (by default the alternative-model fit of that gene) over all points in `θs`.
"""
compute_rsq_var_alt2(ind::Int, m=fits_alt2[ind]) = _compute_rsq_alt2(var, ind, m)

# R² of the alternative-model mean and variance predictions, one value per gene
mean_rsqs_alt2 = map(compute_rsq_mean_alt2, 1:ngenes)
var_rsqs_alt2 = map(compute_rsq_var_alt2, 1:ngenes)

# positions of the genes for which both R² values are positive
inds2 = findall((mean_rsqs_alt2 .> 0) .& (var_rsqs_alt2 .> 0));

In [None]:
# --- Step 3 ---
# Remove genes for which the alternative model estimates a burst size tending to the lower parameter bound (leads to unrealistic burst parameter ratio estimates)

# burst size parameters b₁ and b₂ of every alternative-model fit
b1s = [m.b₁ for m in fits_alt2]
b2s = [m.b₂ for m in fits_alt2]
# keep genes whose b₁ and b₂ both exceed 0.002
# NOTE(review): 0.002 presumably sits just above the lower bound used during fitting — confirm
inds3 = intersect(findall(b1s .> 0.002), findall(b2s .> 0.002))

# genes passing all three steps; computed once so fits and names are subset identically
inds_alt2 = intersect(inds1, inds2, inds3)
fits_alt2 = fits_alt2[inds_alt2]
gene_names_alt2 = gene_names[inds_alt2]
println("$(length(gene_names_alt2)) genes left after filtering for alt2 model.");

In [None]:
# perform filtering for the main model
include("../analysis/mESC/filter_post.jl")

In [None]:
# Restrict every per-gene container to the genes present in both filtered gene sets.
# inds indexes into the main-model containers, _inds into the alternative-model ones.
inds = findall(in(gene_names_alt2), gene_names)
_inds = findall(in(gene_names), gene_names_alt2)

# fits_alt2[k] and fits_main[k] are compared gene by gene further down, so both subsets
# must list the same genes in the same order; check before anything is overwritten
gene_names[inds] == gene_names_alt2[_inds] ||
    error("Filtered gene sets of the alternative and main models do not line up gene by gene.")

fits_alt2 = fits_alt2[_inds]
# keep the alternative-model gene names in sync with fits_alt2
gene_names_alt2 = gene_names_alt2[_inds]

xG1 = xG1[inds]
xS = xS[inds]
xG2M = xG2M[inds]

counts_spliced = counts_spliced[inds]
gene_names = gene_names[inds]
decay_rates = decay_rates[inds]

G1_th_ind_fits = G1_th_ind_fits[inds]
G2M_th_ind_fits = G2M_th_ind_fits[inds]
fits_main = fits_main[inds]

# ndiff must be computed before ngenes is overwritten
ndiff = ngenes - length(inds)
ngenes = length(inds)

println("Removed $ndiff genes that did not overlap between the filtered gene sets.")
println("$ngenes genes left remaining.");

In [35]:
# Burst parameters of the alternative model for every gene; here the burst frequency
# getters take the additional argument thetaG1 / thetaG2M, the burst size getters do not.
burst_freqs_G1_alt2 = map(m -> get_burst_frequency_G1(m, thetaG1), fits_alt2)
burst_freqs_G2M_alt2 = map(m -> get_burst_frequency_G2M(m, thetaG2M), fits_alt2)
burst_sizes_G1_alt2 = map(get_burst_size_G1, fits_alt2)
burst_sizes_G2M_alt2 = map(get_burst_size_G2M, fits_alt2);

In [36]:
# Burst parameters of the main model for every gene; here the burst size getters take
# the additional argument thetaG1 / thetaG2M, the burst frequency getters do not.
burst_freqs_G1_main = map(get_burst_frequency_G1, fits_main)
burst_freqs_G2M_main = map(get_burst_frequency_G2M, fits_main)
burst_sizes_G1_main = map(m -> get_burst_size_G1(m, thetaG1), fits_main)
burst_sizes_G2M_main = map(m -> get_burst_size_G2M(m, thetaG2M), fits_main);

In [37]:
# Per-gene G2/M-to-G1 ratios of burst frequency (f) and burst size (b) for both models
ratio_f_main = @. burst_freqs_G2M_main / burst_freqs_G1_main
ratio_b_main = @. burst_sizes_G2M_main / burst_sizes_G1_main
ratio_f_alt2 = @. burst_freqs_G2M_alt2 / burst_freqs_G1_alt2
ratio_b_alt2 = @. burst_sizes_G2M_alt2 / burst_sizes_G1_alt2;

In [None]:
# Median and 25% / 75% quantiles of the main-model burst frequency ratio
@show median(ratio_f_main)
@show quantile(ratio_f_main, 0.25)
@show quantile(ratio_f_main, 0.75);

In [None]:
# Median and 25% / 75% quantiles of the alternative-model burst frequency ratio
@show median(ratio_f_alt2)
@show quantile(ratio_f_alt2, 0.25)
@show quantile(ratio_f_alt2, 0.75);

In [None]:
# Median and 25% / 75% quantiles of the main-model burst size ratio
@show median(ratio_b_main)
@show quantile(ratio_b_main, 0.25)
@show quantile(ratio_b_main, 0.75);

In [None]:
# Median and 25% / 75% quantiles of the alternative-model burst size ratio
@show median(ratio_b_alt2)
@show quantile(ratio_b_alt2, 0.25)
@show quantile(ratio_b_alt2, 0.75);

In [42]:
# Scatter plot of the alternative-model burst frequency ratio (x) against its burst size
# ratio (y), with the marginal distribution of x on top and of y on the right. The
# main-model ratios are overlaid in the marginal panels for comparison.
x = ratio_f_alt2
y = ratio_b_alt2
f = Figure(size = (size_pt[1]*1.1, size_pt[2]*1.2), figure_padding = 1)

# Layout: axtop at [1, 1], axmain at [2, 1], axright at [2, 2]; the marginal axes have
# all four spines hidden.
ga = f[1, 1] = GridLayout()
axtop = Axis(ga[1, 1], 
             leftspinevisible = false,
             rightspinevisible = false,
             bottomspinevisible = false,
             topspinevisible = false)
axmain = Axis(ga[2, 1], xlabel = "", ylabel = "",
              yminorticks = IntervalsBetween(2),
              yminorticksvisible = true,
              yminorticksize = 1.5,
              yminortickwidth = 0.7,
              xticksmirrored = true,
              yticksmirrored = true,
              rightspinecolor = (c1, 1),
              topspinecolor = (c2, 1))
axright = Axis(ga[2, 2],
               leftspinevisible = false,
               rightspinevisible = false,
               bottomspinevisible = false,
               topspinevisible = false)

# The scatter shares its y axis with the right panel and its x axis with the top panel.
linkyaxes!(axmain, axright)
linkxaxes!(axmain, axtop)

hidedecorations!(axtop, grid = false)
hidedecorations!(axright, grid = false)
scatter!(axmain, x, y, color=(:gray, 0.4), markersize=2)
# dashed reference lines at a ratio of 1 on both axes
vlines!(axmain, 1, color=(:black, 0.4), linestyle=:dash)
hlines!(axmain, 1, color=(:black, 0.4), linestyle=:dash)
xlims!(axmain, low = 0, high = 2.0)

# Top panel: density of the alternative-model frequency ratios in c2, density of the
# main-model frequency ratios in (c2, 0.2), a baseline at 0 and a horizontal boxplot of x.
density!(axtop, x, color=(c2), npoints=1000)
density!(axtop, ratio_f_main, color=(c2, 0.2), npoints=1000, strokewidth=0.1)
hlines!(axtop, 0, color=(:black, 0.3), linewidth=0.3)
boxplot!(axtop, fill(0.0, length(x)), x, orientation=:horizontal, strokewidth = 0.7, 
         width=0.7, whiskerwidth=0, show_outliers=false, color=(c2, 0))
ylims!(axtop, low=-0.4, high=3.2)

# Right panel: the same construction for the burst size ratios, drawn along y in c1.
density!(axright, y, direction = :y, color=(c1), npoints=1000)
density!(axright, ratio_b_main, direction = :y, color=(c1, 0.2), npoints=1000)
vlines!(axright, 0, color=(:black, 0.3), linewidth=0.3)
boxplot!(axright, fill(0.0, length(y)), y, strokewidth = 0.7, 
         width=0.5, whiskerwidth=0, show_outliers=false, color=(c1, 0))
# NOTE(review): axright is y-linked to axmain, so this presumably also fixes the y range
# of the scatter — confirm.
ylims!(axright, low=0, high=3.0)
xlims!(axright, low=-0.4, high=3.2)

# tighten the gaps and give the marginal column / row a relative size of 1.2/3
colgap!(ga, 2)
rowgap!(ga, 2)
colsize!(ga, 2, Relative(1.2/3))
rowsize!(ga, 1, Relative(1.2/3))

In [None]:
# evaluate the figure as the last expression of the cell so that it is displayed
f

In [44]:
# Colors for the fit plots below: cb is used for the binned data (lines and markers) and
# cv2 for the vertical line at θᵣ; cv1 is an alias of c1.
# NOTE(review): cv3 and cv1 are not used in the cells visible here — confirm they are needed.
cb = colorant"#40b2dd"
cv3 = colorant"#57C6FF"
cv2 = colorant"#9881FD"
cv1 = c1;

In [None]:
"""
    plot_mean_fit(f::GridPosition, ind::Int)

Create an axis at grid position `f` showing, for gene `ind`, the sample mean of the counts
in every θ bin together with the mean predicted by the main model (gray) and by the
alternative model (color `c3`). Vertical lines mark `θ_G1_S`, `θ_S_G2M` and `θᵣ`.
Returns the axis.
"""
function plot_mean_fit(f::GridPosition, ind::Int)
    main_fit = fits_main[ind]
    alt_fit = fits_alt2[ind]
    decay = decay_rates[ind]

    # sample mean of the counts in every θ bin, concatenated in the order G1, S, G2/M
    observed = vcat(
        [mean(xG1[ind][bin]) for bin in inds_θs_G1],
        [mean(xS[ind][bin]) for bin in inds_θs_S],
        [mean(xG2M[ind][bin]) for bin in inds_θs_G2M],
    )

    ax = Axis(f, xlabel="", ylabel="Mean", xticks=(0:0.2:1))

    # phase boundaries first, then the fixed replication point
    for (θ_mark, mark_color) in ((θ_G1_S, (:black, 0.2)), (θ_S_G2M, (:black, 0.2)), (θᵣ, (cv2, 0.6)))
        vlines!(ax, θ_mark, linewidth=0.8, color=mark_color)
    end

    # binned data: thin connecting line plus markers
    lines!(ax, θs, observed, color=(cb, 0.3), linewidth=0.3)
    scatter!(ax, θs, observed, color=(cb, 0.7), markersize=3.0, strokecolor=(c1, 0.85), strokewidth=0.2)

    # model predictions, main model drawn before the alternative model
    for (fit, fit_color) in ((main_fit, (:gray, 0.7)), (alt_fit, (c3, 0.7)))
        predicted = mean.(Ref(fit), θs, Ref(T_cycle), Ref(decay), Ref(θᵣ), Ref(θ_G1_S), Ref(θ_S_G2M))
        lines!(ax, θs, predicted, color=fit_color, linewidth=1.8)
    end

    xlims!(ax, low=-0.02, high=1.01)

    return ax
end

In [None]:
"""
    plot_var_fit(f::GridPosition, ind::Int)

Create an axis at grid position `f` showing, for gene `ind`, the sample variance of the
counts in every θ bin together with the variance predicted by the main model (gray) and by
the alternative model (color `c3`). Vertical lines mark `θ_G1_S`, `θ_S_G2M` and `θᵣ`.
Returns the axis.
"""
function plot_var_fit(f::GridPosition, ind::Int)
    main_fit = fits_main[ind]
    alt_fit = fits_alt2[ind]
    decay = decay_rates[ind]

    # sample variance of the counts in every θ bin, concatenated in the order G1, S, G2/M
    observed = vcat(
        [var(xG1[ind][bin]) for bin in inds_θs_G1],
        [var(xS[ind][bin]) for bin in inds_θs_S],
        [var(xG2M[ind][bin]) for bin in inds_θs_G2M],
    )

    ax = Axis(f, xlabel="Cell age θ", ylabel="Variance", xticks=(0:0.2:1))

    # phase boundaries first, then the fixed replication point
    for (θ_mark, mark_color) in ((θ_G1_S, (:black, 0.2)), (θ_S_G2M, (:black, 0.2)), (θᵣ, (cv2, 0.6)))
        vlines!(ax, θ_mark, linewidth=0.8, color=mark_color)
    end

    # binned data: thin connecting line plus markers
    lines!(ax, θs, observed, color=(cb, 0.3), linewidth=0.3)
    scatter!(ax, θs, observed, color=(cb, 0.7), markersize=3.0, strokecolor=(c1, 0.85), strokewidth=0.2)

    # model predictions, main model drawn before the alternative model
    for (fit, fit_color) in ((main_fit, (:gray, 0.7)), (alt_fit, (c3, 0.7)))
        predicted = var.(Ref(fit), θs, Ref(T_cycle), Ref(decay), Ref(θᵣ), Ref(θ_G1_S), Ref(θ_S_G2M))
        lines!(ax, θs, predicted, color=fit_color, linewidth=1.8)
    end

    xlims!(ax, low=-0.02, high=1.01)
    return ax
end

In [None]:
# Mean (top row) and variance (bottom row) fits of both models for three example genes,
# one gene per column.
f = Figure(size = (size_pt[1]*3.0, size_pt[2]*1.4), figure_padding = 1, fontsize=7)
ga = f[1,1] = GridLayout()

# --- column 1: Uqcrh ---
# NOTE(review): findfirst returns nothing when the gene is no longer in gene_names (it is
# subset by the filtering above), in which case the indexing below fails.
ind = findfirst(gene_names .== "Uqcrh")
println(gene_names[ind])
ax11 = plot_mean_fit(ga[1,1], ind)
ax11.title = gene_names[ind]
ax11.titlefont = "Arial"
ax11.xlabel = ""
ax11.xticksvisible = false
ax11.xticklabelsvisible = false
# red marker at a hard-coded θ grid point; what it marks is not evident from this cell — TODO document
vlines!(ax11, θs_S[20], linewidth=0.8, color=(:red, 0.2))

ax21 = plot_var_fit(ga[2,1], ind)
ax21.xlabel = ""

# --- column 2: Lsm8 (the only column that keeps the x label of the variance panel) ---
ind = findfirst(gene_names .== "Lsm8")
ax12 = plot_mean_fit(ga[1,2], ind)
println(gene_names[ind])
ax12.title = gene_names[ind]
ax12.titlefont = "Arial"
ax12.xlabel = ""
ax12.ylabel = ""
ax12.xticksvisible = false
ax12.xticklabelsvisible = false
# red marker at a hard-coded θ grid point — TODO document
vlines!(ax12, θs_S[5], linewidth=0.8, color=(:red, 0.2))

ax22 = plot_var_fit(ga[2,2], ind)
ax22.ylabel = ""

# --- column 3: Vim (manual y ticks, two red markers) ---
ind = findfirst(gene_names .== "Vim")
println(gene_names[ind])
ax13 = plot_mean_fit(ga[1,3], ind)
ax13.title = gene_names[ind]
ax13.titlefont = "Arial"
ax13.xlabel = ""
ax13.ylabel = ""
ax13.yticks = 10:10:30
ax13.xticksvisible = false
ax13.xticklabelsvisible = false
# red markers at hard-coded θ grid points in G1 and G2/M — TODO document
vlines!(ax13, θs_G1[8], linewidth=0.8, color=(:red, 0.2))
vlines!(ax13, θs_G2M[14], linewidth=0.8, color=(:red, 0.2))

ax23 = plot_var_fit(ga[2,3], ind)
ax23.xlabel = ""
ax23.ylabel = ""
ax23.yticks = 50:100:250

colgap!(ga, 10)
rowgap!(ga, 6)

# display the figure
f