# RootCauseDiscovery (seperate the two datasets and try more patients with ground truth in the appendix as Ines suggested)

Here for a known "root cause gene", lets run a Lasso regression to filter down the number of features, before running Root Cause Discovery algorithm

+ After processing, we obtained 2 datasets
    1. "Interventional data": there are 32 samples for which we have ground truth data. Each row is a gene and each column is a sample. 
    2. "observational data": these are a few hundred samples for which we did not do any intervention. Each row is a gene and each column is a sample.
+ The main idea is that an intervention supposedly caused a disturbance in gene expression, so we can compare the data of intervention to baseline (i.e. observational data), and try to which gene is most differentially expressed

In [3]:
using RootCauseDiscovery
using DataFrames
using CSV
using DelimitedFiles
using JLD2
using Random
using ProgressMeter
using LinearAlgebra
BLAS.set_num_threads(1)

using Plots
gr(fmt=:png)

# source: Table 1 of https://genomemedicine.biomedcentral.com/articles/10.1186/s13073-022-01019-9
# patient_RNA_defects = Dict{String, String}(
#     # table S2
#     "R62943" => "AE,AS",
#     "R98254" => "AE",
#     "R86287" => "AS",
#     "R89912" => "AE,MAE",
#     "R19100" => "AE",
#     "R15264" => "AE,AS,Var",
#     "R36605" => "AE,AS",
#     "R61100" => "AE,AS",
#     "R77611" => "AE,MAE",
#     "R16472" => "AE,AS,Var",
#     "R51757" => "AE",
#     "R80346" => "AE,Var",
#     "R20754" => "AE",
#     "R25473" => "AE,Var",
#     "R28774" => "AE,AS,Var",
#     "R96820" => "AE,AS",
#     "R21147" => "AE",
#     "R64921" => "AE,AS",
#     "R52016" => "AE,AS,Var",
#     "R46723" => "AE,AS,Var",
#     "R58859" => "AS",
#     "R80184" => "AE,MAE",
#     "R59185" => "AE",
#     "R63087" => "AS",
#     "R44456" => "AE,AS",
#     "R33391" => "AS,Var",
#     "R66696" => "AS",
#     "R24289" => "AE,MAE",
#     "R98349" => "AS",
#     "R91273" => "AE",
#     "R60537" => "AE,Var",
#     "R70961" => "AS"
# )

# helper function to submit 1 job to run 1 command
function submit(command::String, ncores::Int, total_mem::Number, joblog_dir::String; 
        jobname="submit", waitfor=Int[], verbose=true)
    mem = round(Int, total_mem / ncores) # memory per core
    filename = "$jobname.sh"
    open(filename, "w") do io
        println(io, "#!/bin/bash")
        println(io, "#")
        println(io, "#SBATCH --job-name=$jobname")
        println(io, "#")
        println(io, "#SBATCH --time=48:00:00")
        println(io, "#SBATCH --cpus-per-task=$ncores")
        println(io, "#SBATCH --mem-per-cpu=$(mem)G")
        println(io, "#SBATCH --partition=candes,zihuai,normal,qsu,owners")
        println(io, "#SBATCH --output=$(joinpath(joblog_dir, "slurm-%j.out"))")
        println(io, "")
        println(io, "#save job info on joblog:")
        println(io, "echo \"Job \$JOB_ID started on:   \" `hostname -s`")
        println(io, "echo \"Job \$JOB_ID started on:   \" `date `")
        println(io, "")
        println(io, "# load the job environment:")
        println(io, "module load julia/1.10")
        println(io, "export JULIA_DEPOT_PATH=\"/home/groups/sabatti/.julia\"")
        println(io, "")
        println(io, "# run code")
        println(io, "echo \"$command\"")
        println(io, "$command")
        println(io, "")
        println(io, "#echo job info on joblog:")
        println(io, "echo \"Job \$JOB_ID ended on:   \" `hostname -s`")
        println(io, "echo \"Job \$JOB_ID ended on:   \" `date `")
        println(io, "#echo \" \"")
    end
    # submit job and capture job ID
    io = IOBuffer()
    if length(waitfor) != 0
        run(pipeline(`sbatch --dependency=afterok:$(join(waitfor, ':')) $filename`; stdout=io))
    else
        run(pipeline(`sbatch $filename`; stdout=io))
    end
    msg = String(take!(io))
    verbose && println(stdout, msg)
    jobid = parse(Int, strip(msg)[21:end])
    # clean up and return job ID
    close(io)
    rm(filename, force=true)
    return jobid
end

# helper function to submit 1 job to run multiple commands
function submit(commands::Vector{String}, ncores::Int, 
        total_mem::Number, joblog_dir::String; 
        jobname="submit", waitfor=Int[], verbose=true)
    mem = round(Int, total_mem / ncores) # memory per core
    filename = "$jobname.sh"
    open(filename, "w") do io
        println(io, "#!/bin/bash")
        println(io, "#")
        println(io, "#SBATCH --job-name=$jobname")
        println(io, "#")
        println(io, "#SBATCH --time=48:00:00")
        println(io, "#SBATCH --cpus-per-task=$ncores")
        println(io, "#SBATCH --mem-per-cpu=$(mem)G")
        println(io, "#SBATCH --partition=candes,zihuai,normal,qsu,owners")
        println(io, "#SBATCH --output=$(joinpath(joblog_dir, "slurm-%j.out"))")
        println(io, "")
        println(io, "#save job info on joblog:")
        println(io, "echo \"Job \$JOB_ID started on:   \" `hostname -s`")
        println(io, "echo \"Job \$JOB_ID started on:   \" `date `")
        println(io, "")
        println(io, "# load the job environment:")
        println(io, "module load julia/1.10")
        println(io, "export JULIA_DEPOT_PATH=\"/home/groups/sabatti/.julia\"")
        println(io, "")
        for command in commands
            println(io, "echo \"$command\"")
            println(io, "$command")
        end
        println(io, "")
        println(io, "#echo job info on joblog:")
        println(io, "echo \"Job \$JOB_ID ended on:   \" `hostname -s`")
        println(io, "echo \"Job \$JOB_ID ended on:   \" `date `")
        println(io, "#echo \" \"")
    end
    # submit job and capture job ID
    io = IOBuffer()
    if length(waitfor) != 0
        run(pipeline(`sbatch --dependency=afterok:$(join(waitfor, ':')) $filename`; stdout=io))
    else
        run(pipeline(`sbatch $filename`; stdout=io))
    end
    msg = String(take!(io))
    verbose && println(stdout, msg)
    jobid = parse(Int, strip(msg)[21:end])
    # clean up and return job ID
    close(io)
    rm(filename, force=true)
    return jobid
end

function make_plot(
        Z_score_rank::AbstractVector, 
        RCD_rank::AbstractVector, 
        patient_name::AbstractVector;
        plot1_ymax = 5500,
        plot2_ymin = -2500,
        plot2_ymax = 2500,
        title1 = "Rank of root cause (Z-score vs Cholesky-score)",
        title2 = "Z-score rank - Cholesky-score rank"
    )
    # sort by diff in Z scores
    Z_score_diff = Z_score_rank - RCD_rank
    perm = sortperm(Z_score_diff)

    #
    # plot 1
    #
    plt1 = scatter(
        1:length(patient_name), 
        Z_score_rank[perm], 
        ylabel="rank", 
        xticks=(1:length(patient_name), patient_name[perm]),
        xrotation = 50,
        label = "Z-score rank",
        title = title1,
        legend=:top,
        dpi=300,
        ylim=(0, plot1_ymax)
    )
    scatter!(plt1, 1:length(patient_name), RCD_rank[perm], label="Cholesky-score rank")

    #
    # plot 2
    #
    palettes = theme_palette(:auto)
    my_colors = []
    for i in eachindex(Z_score_diff)
        if isinf(Z_score_diff[i])
            push!(my_colors, :black)
            Z_score_diff[i] = plot2_ymin
        elseif Z_score_diff[i] > 0
            push!(my_colors, palettes[1])
        elseif Z_score_diff[i] == 0
            push!(my_colors, palettes[3])
        else
            push!(my_colors, palettes[2])
        end
    end

    # make plot
    plt2 = scatter(
        1:length(patient_name_success), 
        Z_score_diff[perm],
        title=title2, 
        ylabel="Rank difference",
        xticks=(1:length(patient_name), patient_name[perm]),
        xrotation = 50,
        ylim=(plot2_ymin, plot2_ymax),
        label = nothing,
        legend=:outerbottom,
        color=my_colors[perm],
        dpi=300,
    )
    hline!(plt2, [0], color=:black, linestyle=:dash, label=nothing)

    #
    # combine plots
    #
    return plot(plt1, plt2, layout = (2, 1), size=(800, 600))
end

make_plot (generic function with 1 method)

## Read processed data
First, read in data and log (base 2) transform them

In [66]:
function read_data(low_count, threshold, max_cor, concatenate, dir)
#     low_count = 10
#     threshold = 0.1
#     max_cor = 0.999
#     concatenate = "ns"
#     dir = "/scratch/users/bbchu/RootCauseDiscovery/8.26.2024/data"

    # read processed data
    isdir(dir) || mkpath(dir)
    trainsform_int_file = joinpath(dir, "transform_int_$concatenate")
    transform_obs_file = joinpath(dir, "transform_obs_$concatenate")
    ground_truth_file = joinpath(dir, "ground_truth_$concatenate")

    if !(isfile(trainsform_int_file) && isfile(transform_obs_file) && isfile(ground_truth_file))
        # process raw data & save
        transform_int, transform_obs, ground_truth = QC_gene_expression_data(;
            low_count = low_count, 
            threshold = threshold, 
            max_cor = max_cor, 
            concatenate = concatenate
        )
        CSV.write(trainsform_int_file, transform_int)
        CSV.write(transform_obs_file, transform_obs)
        CSV.write(ground_truth_file, ground_truth)
    else
        transform_int = CSV.read(trainsform_int_file, DataFrame)
        transform_obs = CSV.read(transform_obs_file, DataFrame)
        ground_truth = CSV.read(ground_truth_file, DataFrame)
    end

    # numeric matrices
    Xobs = transform_obs[:, 2:end] |> Matrix |> transpose
    Xint = transform_int[:, 2:end] |> Matrix |> transpose
    
    return Xobs, Xint, ground_truth
end
low_count = 10
threshold = 0.1
max_cor = 0.999
concatenate = "all"
dir = "/scratch/users/bbchu/RootCauseDiscovery/8.26.2024/data"
Xobs, Xint, ground_truth = read_data(low_count, threshold, max_cor, concatenate, dir)

([6.905019519912375 7.257078563799141 … -0.06477114998921524 5.752340009973989; 7.02426909562131 7.246401013060676 … 0.12757476399859816 5.769481834936712; … ; 6.839604518578518 7.29331359490887 … 2.7416948901919214 5.792997014900723; 6.634964406531482 7.155258338398599 … 1.9075765878191422 5.778777598727033], [7.16242663712475 7.523928621811798 … 0.8758080027262863 5.863970440970483; 7.282548459614407 7.484357868453626 … -0.04128210658791 6.128328625903546; … ; 6.791879945272701 7.149247489954807 … 2.2252196140464853 5.952250834124798; 6.954662707932615 7.206785478347909 … 2.43723143625253 5.900881832561235], [1m69×6 DataFrame
[1m Row │[1m Patient ID [1m Genetic diagnosis [1m gene_id         [1m patient column index in ⋯
     │[90m String7    [90m String15          [90m String15        [90m Int64                   ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ R62943      C19orf70           ENSG00000174917                          ⋯

## Submit main jobs (1 job for each patient)

In [None]:
# put this in script /scratch/users/bbchu/RootCauseDiscovery/run_realData_RCD_ines.jl

using RootCauseDiscovery
using DataFrames
using CSV
using DelimitedFiles
using JLD2
using Random
using LinearAlgebra
BLAS.set_num_threads(1)

# inputs to the script
patient_id = ARGS[1]
y_idx_z_threshold = parse(Float64, ARGS[2])
outfile = ARGS[3]
low_count = parse(Int, ARGS[4]) # min number of gene count exceeding `threshold` is needed to prevent gene from being filtered out
threshold = parse(Float64, ARGS[5]) # for processing raw data
seed = parse(Int, ARGS[6])
nshuffles = parse(Int, ARGS[7])
method = ARGS[8] # "cv" or "nhalf"
max_cor = parse(Float64, ARGS[9])
concatenate = ARGS[10] # "all", "ns", or "ss"

# for testing
# patient_id = "R62943"
# y_idx_z_threshold = 1.5
# outfile = "fdsa"
# low_count = 10
# threshold = 0.1
# seed = 1111
# nshuffles = 1
# method = "cv"
# max_cor = 0.999
# concatenate = "all"

# process raw data
transform_int, transform_obs, _ = 
    QC_gene_expression_data(low_count = low_count, 
    threshold = threshold, max_cor = max_cor, 
    concatenate = concatenate);

# numeric matrices
Xobs = transform_obs[:, 2:end] |> Matrix |> transpose
Xint = transform_int[:, 2:end] |> Matrix |> transpose
i = findfirst(x -> x == patient_id, names(transform_int)[2:end])
if !isnothing(i)
    Xint_sample = Xint[i, :]

    # concat Xobs
    nint = size(Xint, 1)
    Xobs_full = vcat(Xobs, Xint[setdiff(1:nint, i), :])

    # run main alg
    Random.seed!(seed)
    @time root_cause_score = root_cause_discovery_high_dimensional(
        Xobs_full, Xint_sample, method, y_idx_z_threshold=y_idx_z_threshold,
        nshuffles=nshuffles
    );

    # save result
    writedlm(outfile, root_cause_score)
else
    println("patient $patient_id not found in interventional samples!")
end

println("Done!")

### Submit jobs for each patient

20 runs take ~300 sec, so if a patient requires 4000 tries, it'll take ~16.6h

In [9]:
julia_exe = "/scratch/users/bbchu/RootCauseDiscovery/run_realData_RCD_ines.jl"
dirs = [
    "/scratch/users/bbchu/RootCauseDiscovery/8.26.2024"
]
joblog_dirs = [joinpath(dirs[i], "joblogs") for i in eachindex(dirs)]
mkpath.(joblog_dirs)
seeds = [8262024]
low_counts = [10]
thresholds = [0.1]
nshuffles = 10
y_idx_z_threshold = 1.5
method = "cv"
max_cor = 0.999
concatenates = ["all", "ns", "ss"]

# patient IDs
patient_name = ["R62943", "R98254", "R89912", "R19100", "R15264", "R36605", "R61100", 
    "R77611", "R16472", "R51757", "R80346", "R20754", "R25473", "R28774", "R96820", 
    "R21147", "R64921", "R52016", "R46723", "R80184", "R59185", "R44456", "R24289", 
    "R91273", "R60537", "R72253", "R75000", "R91016", "R82353", "R34834", "R78764", 
    "R30367", "R76358", "R12128", "R45867", "R31640", "R95723", "R55237", "R18626", 
    "R34820", "R25912", "R11258", "R64046", "R54158", "R19907", "R27473", "R70186", 
    "R26710", "R30525", "R29620", "R64055", "R15748", "R66814", "R77365", "R21993", 
    "R42505", "R64948", "R21470"]

# for testing
# nshuffles = 2
# y_idx_z_threshold = 20
# patient_name = ["R62943", "R98254"]

for concat in concatenates
    for (dir, joblog_dir, seed, low_count, threshold) in zip(dirs, joblog_dirs, seeds, low_counts, thresholds)
        for patient in patient_name
            outdir = joinpath(dir, concat, method)
            isdir(outdir) || mkpath(outdir)

            outfile = joinpath(outdir, "$(patient).csv")
            if !isfile(outfile)
                cmd = "julia -t 24 $julia_exe $patient $y_idx_z_threshold $outfile $low_count $threshold $seed $nshuffles $method $max_cor $concat"
                submit(cmd, 24, 64, joblog_dir, jobname=patient)
            end
        end
    end
end

Submitted batch job 52100815

Submitted batch job 52100817

Submitted batch job 52100819

Submitted batch job 52100822

Submitted batch job 52100824

Submitted batch job 52100826

Submitted batch job 52100829

Submitted batch job 52100831

Submitted batch job 52100832

Submitted batch job 52100835

Submitted batch job 52100837

Submitted batch job 52100840

Submitted batch job 52100842

Submitted batch job 52100845

Submitted batch job 52100847

Submitted batch job 52100849

Submitted batch job 52100850

Submitted batch job 52100852

Submitted batch job 52100855

Submitted batch job 52100858

Submitted batch job 52100860

Submitted batch job 52100863

Submitted batch job 52100865

Submitted batch job 52100868

Submitted batch job 52100869

Submitted batch job 52100870

Submitted batch job 52100873

Submitted batch job 52100874

Submitted batch job 52100875

Submitted batch job 52100877

Submitted batch job 52100879

Submitted batch job 52100881

Submitted batch job 52100882

Submitted 

In [2]:
patient_name = ["R62943", "R98254", "R89912", "R19100", "R15264", "R36605", "R61100", 
    "R77611", "R16472", "R51757", "R80346", "R20754", "R25473", "R28774", "R96820", 
    "R21147", "R64921", "R52016", "R46723", "R80184", "R59185", "R44456", "R24289", 
    "R91273", "R60537", "R72253", "R75000", "R91016", "R82353", "R34834", "R78764", 
    "R30367", "R76358", "R12128", "R45867", "R31640", "R95723", "R55237", "R18626", 
    "R34820", "R25912", "R11258", "R64046", "R54158", "R19907", "R27473", "R70186", 
    "R26710", "R30525", "R29620", "R64055", "R15748", "R66814", "R77365", "R21993", 
    "R42505", "R64948", "R21470"]
length(patient_name)

58

# Plot result

### 8.26.2024 (`nshuffle=10`, new real data analysis based on Ines’ suggestion)

In [75]:
function aggregate_result(concatenate::String, dir;
    low_count = 10,
    threshold = 0.1,
    max_cor = 0.999,
    )
    cv_outdir = joinpath(dir, "$concatenate/cv")
    data_dir = joinpath(dir, "data")
    
    # import data
    Xobs, Xint, ground_truth = read_data(low_count, threshold, max_cor, concatenate, data_dir)
    
    # only consider patients that is AE
    idx = findall(isone, ground_truth[!, "is AE"])
    ground_truth_filtered = ground_truth[idx, :]
    patient_name_filtered = ground_truth_filtered[!, "Patient ID"]

    # things to return
    df = DataFrame(
        "Patient Name"=>String[],
        "Z-score rank"=>Float64[],
        "Cholesky-score rank"=>Float64[],
        "RC Z-score"=>Float64[],
        "RC Cholesky-score"=>Float64[],
    )
    z_all = Vector{Float64}[]

    # compare our method against z score method
    @showprogress for (i, id) in enumerate(patient_name_filtered)
        try
            # compute z-score for current sample
            Xint_sample = Xint[i, :]
            Xobs_full = vcat(Xobs, Xint[setdiff(1:size(Xint, 1), i), :])
            z = RootCauseDiscovery.zscore(Xobs_full, Xint_sample)
            push!(z_all, z)
            
            # compute rank of z-score
            patient_root_cause_idx = ground_truth_filtered[i, "root cause row index in genecounts"]
            patient_root_cause_zscore = z[patient_root_cause_idx]
            root_cause_zscore_rank = count(x -> x > patient_root_cause_zscore, z)

            # cv result
            file = joinpath(cv_outdir, "$id.csv")
            cholesky_score = readdlm(file)
            root_cause_cholesky_score = cholesky_score[patient_root_cause_idx]
            root_cause_chol_rank_cv = count(x -> x > root_cause_cholesky_score, cholesky_score)

            push!(df, [id, root_cause_zscore_rank, root_cause_chol_rank_cv, 
                           round(patient_root_cause_zscore, digits=3), 
                            round(root_cause_cholesky_score, digits=3)])
        catch e
            println(e)
            continue
        end
    end

    # @show length(cholesky_score)
    df = df[!, [1, 3, 2, 4, 5]]
    sort!(df, "Cholesky-score rank")
    return df, z_all
end
dir = "/scratch/users/bbchu/RootCauseDiscovery/8.26.2024"
df_ns, z_ns = aggregate_result("ns", dir)
df_ss, z_ss = aggregate_result("ss", dir)
df_all, z_all = aggregate_result("all", dir);

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:04[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:06[39m


In [76]:
# add S2/S3/S4 columns
data_dir = RootCauseDiscovery.datadir()
genecounts = process_data(RootCauseDiscovery.datadir(), "all")
S2_file = "13073_2022_1019_MOESM1_ESM_S2.csv"
S3_file = "13073_2022_1019_MOESM1_ESM_S3.csv"
S4_file = "13073_2022_1019_MOESM1_ESM_S4.csv"
df2 = RootCauseDiscovery.process_one_root_cause_ground_truth_table(S2_file, genecounts, data_dir)
df3 = RootCauseDiscovery.process_one_root_cause_ground_truth_table(S3_file, genecounts, data_dir)
df4 = RootCauseDiscovery.process_one_root_cause_ground_truth_table(S4_file, genecounts, data_dir)
S2_samples = df2[!, "Patient ID"]
S3_samples = df3[!, "Patient ID"]
S4_samples = df4[!, "Patient ID"]
supp_col = ["" for _ in 1:size(df_all, 1)]
for (i, id) in enumerate(df_all[!, "Patient Name"])
    supp_col[i] = id ∈ S2_samples ? "S2" : id ∈ S3_samples ? "S3" : "S4"
end
df_all[!, "origin table"] = supp_col

@show df_all;

df_all = 58×6 DataFrame
 Row │ Patient Name  Cholesky-score rank  Z-score rank  RC Z-score  RC Cholesky-score  origin table
     │ String        Float64              Float64       Float64     Float64            String
─────┼──────────────────────────────────────────────────────────────────────────────────────────────
   1 │ R19100                        0.0           0.0      24.275              8.858  S2
   2 │ R61100                        0.0           1.0      33.222             17.558  S2
   3 │ R77611                        0.0           9.0       8.15              13.888  S2
   4 │ R16472                        0.0           0.0      39.138             12.179  S2
   5 │ R28774                        0.0           0.0     133.791             11.93   S2
   6 │ R96820                        0.0        1632.0       2.084             10.357  S2
   7 │ R64921                        0.0           2.0      17.351             14.071  S2
   8 │ R80184                        0.0          1

## Save result

In [81]:
# final ranking
CSV.write("/scratch/users/bbchu/RootCauseDiscovery/8.26.2024/result_all.csv", df_all)
CSV.write("/scratch/users/bbchu/RootCauseDiscovery/8.26.2024/result_ns.csv", df_ns)
CSV.write("/scratch/users/bbchu/RootCauseDiscovery/8.26.2024/result_ss.csv", df_ss)

"/scratch/users/bbchu/RootCauseDiscovery/8.26.2024/result_ss.csv"

Also save Z-scores

In [80]:
writedlm("/scratch/users/bbchu/RootCauseDiscovery/8.26.2024/zscores_all.csv", hcat(z_all...)')
writedlm("/scratch/users/bbchu/RootCauseDiscovery/8.26.2024/zscores_ns.csv", hcat(z_ns...)')
writedlm("/scratch/users/bbchu/RootCauseDiscovery/8.26.2024/zscores_ss.csv", hcat(z_ss...)')