In [1]:
# import the libraries 
suppressPackageStartupMessages(library(dplyr))
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(argparse))
suppressPackageStartupMessages(library(cowplot))

# Shared plotting helpers. Objects used further down but not defined in this
# file (figure_theme, shuffled_labels, shuffled_linetypes) are presumably
# provided by this script — confirm in figure_themes.R. The directory name is
# kept exactly as in the original; verify it on disk before "correcting" it.
figure_theme_path <- file.path("..", "visulaization_utils", "figure_themes.R")
source(figure_theme_path)

“package ‘dplyr’ was built under R version 4.2.3”
“package ‘ggplot2’ was built under R version 4.2.3”
“package ‘argparse’ was built under R version 4.2.3”
“package ‘cowplot’ was built under R version 4.2.3”


In [None]:
# Command-line interface, used when this notebook is executed as a script
parser <- ArgumentParser(description = "Visualize MLP results")
parser$add_argument(
    "--cell_type",
    type = "character",
    help = "Cell type to visualize"
)
parser$add_argument(
    "--model_name",
    type = "character",
    help = "Model name to visualize"
)
parser$add_argument(
    "--selected_treatment_comparisons",
    type = "character",
    help = "Selected treatment comparisons to visualize"
)

# Read the options from the command line into a named list
args <- parser$parse_args()

# Copy the parsed options into the globals the rest of the script uses
celltype <- args[["cell_type"]]
model_name <- args[["model_name"]]
selected_treatment_comparisons <- args[["selected_treatment_comparisons"]]

In [2]:
# Defaults for interactive (notebook) runs. When the command-line parsing cell
# has run, `args` holds the parsed options, and any value supplied there is
# kept; a default is only used for an option that was not given. In a notebook
# session `args` is not defined in the global environment, so editing and
# re-running this cell always takes effect.
have_cli_args <- exists("args", inherits = FALSE) && is.list(args)

if (!have_cli_args || is.null(args[["cell_type"]])) {
    celltype <- "PBMC"
}
if (!have_cli_args || is.null(args[["model_name"]])) {
    model_name <- "LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025"
}
if (!have_cli_args || is.null(args[["selected_treatment_comparisons"]])) {
    selected_treatment_comparisons <- "DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025, DMSO_0.100_DMSO_0.025 vs Thapsigargin_10.000_DMSO_0.025, LPS_100.000_DMSO_0.025 vs Thapsigargin_10.000_DMSO_0.025"
}


In [3]:
# Figures for this model / cell type share one folder; the plot and its
# detached legend are written there as two separate PNGs
pr_figure_dir <- file.path(
    "..", "..", "figures", "Binary_Classification", model_name, celltype
)
output_file_plot <- file.path(pr_figure_dir, "pr_curves_testing_plot.png")
output_file_legend <- file.path(pr_figure_dir, "pr_curves_testing_legend.png")

In [4]:
# Testing-set PR metrics produced for this model / cell type
results_dir <- file.path(
    "..", "..", "results", "Binary_Classification", model_name, celltype
)
results_file <- file.path(results_dir, "testing_metrics.csv")

# Load the metrics and show a short preview
df <- read.csv(results_file)
head(df, 3)

Unnamed: 0_level_0,treatments_tested,model,group,shuffled_data,PR_Threshold,Precision,Recall
Unnamed: 0_level_1,<chr>,<chr>,<chr>,<chr>,<dbl>,<dbl>,<dbl>
1,DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,0.0,0.4256538,1.0
2,DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,5.971659000000001e-39,0.4256675,0.9997867
3,DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,4.875545999999999e-38,0.425666,0.9997805


In [5]:
unique(df[["treatments_tested"]])  # every comparison present in the results

In [6]:
# Turn the comma-separated comparison list into a character vector.
# Splitting on a bare comma and trimming whitespace accepts "a, b", "a,b"
# and "a ,  b" alike; empty entries (e.g. from a trailing comma) are dropped.
# Re-running this cell on an already-split vector leaves it unchanged.
if (!is.character(selected_treatment_comparisons) ||
        length(selected_treatment_comparisons) == 0) {
    stop(
        "selected_treatment_comparisons must be a non-empty character ",
        "string (e.g. 'A vs B, A vs C')",
        call. = FALSE
    )
}
selected_treatment_comparisons <- trimws(
    unlist(strsplit(selected_treatment_comparisons, ",", fixed = TRUE))
)
selected_treatment_comparisons <- selected_treatment_comparisons[
    nzchar(selected_treatment_comparisons)
]
selected_treatment_comparisons

In [7]:
# Keep only the rows whose comparison is one of the selected ones. The column
# is referenced by its full name with `[[`: the previous `df$treatment` only
# resolved through `$`'s partial matching to `treatments_tested`.
# `%in%` never returns NA, so no which() wrapper is needed.
tmp_df <- df[df[["treatments_tested"]] %in% selected_treatment_comparisons, ]

# Fail loudly rather than silently saving an empty figure
if (nrow(tmp_df) == 0) {
    stop(
        "None of the selected treatment comparisons were found in the ",
        "results' treatments_tested column",
        call. = FALSE
    )
}
unique(tmp_df[["treatments_tested"]])

In [8]:
head(tmp_df)

# Replace each " vs " separator with a line break so the two sides of a
# comparison stack vertically (presumably for the colour legend, which is
# keyed on this column in the plot below)
tmp_df[["treatments_tested"]] <- gsub(
    " vs ", "\n", tmp_df[["treatments_tested"]], fixed = TRUE
)

head(tmp_df)

Unnamed: 0_level_0,treatments_tested,model,group,shuffled_data,PR_Threshold,Precision,Recall
Unnamed: 0_level_1,<chr>,<chr>,<chr>,<chr>,<dbl>,<dbl>,<dbl>
1,DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,0.0,0.4256538,1.0
2,DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,5.971659000000001e-39,0.4256675,0.9997867
3,DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,4.875545999999999e-38,0.425666,0.9997805
4,DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,5.259709999999999e-38,0.4256671,0.9997805
5,DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,1.3825489999999999e-37,0.4256683,0.9997805
6,DMSO_0.100_DMSO_0.025 vs LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,3.36168e-37,0.4256694,0.9997805


Unnamed: 0_level_0,treatments_tested,model,group,shuffled_data,PR_Threshold,Precision,Recall
Unnamed: 0_level_1,<chr>,<chr>,<chr>,<chr>,<dbl>,<dbl>,<dbl>
1,DMSO_0.100_DMSO_0.025 LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,0.0,0.4256538,1.0
2,DMSO_0.100_DMSO_0.025 LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,5.971659000000001e-39,0.4256675,0.9997867
3,DMSO_0.100_DMSO_0.025 LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,4.875545999999999e-38,0.425666,0.9997805
4,DMSO_0.100_DMSO_0.025 LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,5.259709999999999e-38,0.4256671,0.9997805
5,DMSO_0.100_DMSO_0.025 LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,1.3825489999999999e-37,0.4256683,0.9997805
6,DMSO_0.100_DMSO_0.025 LPS_100.000_DMSO_0.025,LPS_100.000_DMSO_0.025_vs_DMSO_0.100_DMSO_0.025,test,False,3.36168e-37,0.4256694,0.9997805


In [9]:
# Precision-recall curves for every selected comparison: colour encodes the
# comparison, line type encodes the shuffled_data flag.
pr_curve_gg <- (
    ggplot(tmp_df, aes(x = Recall, y = Precision))
    + geom_line(aes(color = treatments_tested, linetype = shuffled_data))
    + theme_bw()
    + xlab("Recall")
    + ylab("Precision")

    # shuffled_labels / shuffled_linetypes are not defined in this file;
    # presumably they come from the sourced figure_themes.R — confirm there
    + scale_linetype_manual(
        name = "Shuffled\ntraining\ndata",
        labels = shuffled_labels,
        values = shuffled_linetypes
    )

    # Comparison legend first, shuffled-data legend second. No trailing comma:
    # older ggplot2 versions collect guides(...) with list(), which rejects an
    # empty trailing argument.
    + guides(
        color = guide_legend(order = 1),
        linetype = guide_legend(order = 2)
    )
    + coord_fixed()
    + figure_theme
    # Decrease spacing in legend
    + theme(
        legend.spacing.y = unit(0.1, "cm"),
        legend.box.spacing = unit(0.2, "cm"),
        legend.key.size = unit(2.5, "lines"),
        legend.key.width = unit(1, "lines")
    )
    + ggtitle(paste0("Precision-Recall Curve for ","\n", model_name, " model"))
)

# Pull the legend out as its own grob, then remove it from the main plot so
# the curves and the legend are saved as two separate images
legend <- get_legend(pr_curve_gg)
pr_curve_gg <- pr_curve_gg + theme(legend.position = "none")

# ggsave() does not create missing directories by default, so make sure the
# output folders exist before writing (a no-op when they already exist)
dir.create(dirname(output_file_plot), recursive = TRUE, showWarnings = FALSE)
dir.create(dirname(output_file_legend), recursive = TRUE, showWarnings = FALSE)

ggsave(output_file_plot, pr_curve_gg, height = 5.5, width = 8.5, dpi = 500)
ggsave(output_file_legend, legend, height = 5.5, width = 8.5, dpi = 500)