In [None]:
using BenchmarkTools
using Plotly
using Printf
using Random
using StatsBase

In [None]:
# Seed the global random number generator so the random draws in the cells below are
# reproducible from run to run.
Random.seed!(20121020)

;

In [None]:
# Build a plotting style whose layout has automatic sizing switched off.
style = Style(layout = Layout(autosize = false))

# Register the style through `use_style!` — presumably it becomes the default for the
# plots drawn below; confirm against the plotting package documentation.
use_style!(style)

;

In [None]:
"""
    cumulate_sum_reverse(a::AbstractVector)

Return the cumulative sum of `a` accumulated from its last element towards its first.

Element `i` of the result is the sum of `a[i]` and every element after it, so the first
element is the total of `a` and the last element equals the last element of `a`.

Any `AbstractVector` is accepted (for example a `Vector`, a range or a view).

# Examples
```julia
cumulate_sum_reverse([1, 2, 3]) == [6, 5, 3]
```
"""
function cumulate_sum_reverse(a::AbstractVector)

    # Accumulate over a lazily reversed `a`, then flip the running totals back so that
    # they line up with the original element order.
    return reverse(cumsum(Iterators.reverse(a)))

end

;

In [None]:
# Toy universe of elements: the 26 capital letters, each as a one-character string.
element_ = map(string, 'A':'Z')

n_element = length(element_)

# One standard-normal random score per element.
element_score_ = randn(n_element)

;

In [None]:
# Draw 3 distinct elements at random (sampling without replacement).
set_element_ = sample(element_, 3; replace = false)

# NOTE(review): this assignment immediately replaces the random draw above with the fixed
# set "K", "W", "A", "T". The draw above still advances the random number generator, so
# removing it would change later random values. Confirm which definition is intended.
set_element_ = string.(collect("KWAT"))

;

In [None]:
# Rank the scores in ascending order.
sort_index_ = sortperm(element_score_)

element_score_ = element_score_[sort_index_]

# NOTE(review): the label reordering below is commented out, so the labels keep their
# original order and label i is paired with the i-th smallest score. Confirm that this
# pairing is intended.
# element_ = element_[sort_index_]

# Shared x axis for every plot in this notebook. Traces given only `y` are drawn at
# x = 0, 1, ..., n_element - 1, so exactly one tick value is supplied per label
# (`0:n_element` produced one extra tick that had no label).
xaxis = attr(title = "Element", tickvals = 0:(n_element - 1), ticktext = element_)

plot([scatter(y = element_score_)], Layout(yaxis_title = "Element Score", xaxis = xaxis))

## Is

In [None]:
# Lookup table keyed by the set's elements; only the keys matter, the values are unused.
set_element_to_ = Dict(member => nothing for member in set_element_)

# 1.0 where the element belongs to the set ("hit"), 0.0 otherwise.
is_h = map(label -> Float64(haskey(set_element_to_, label)), element_)

# Complement of the hit indicator: 1.0 where the element is not in the set ("miss").
is_m = map(hit -> 1 - hit, is_h)

plot(
    [
        scatter(name = label, y = indicator) for
        (label, indicator) in (("Is Hit", is_h), ("Is Miss", is_m))
    ],
    Layout(xaxis = xaxis),
)

In [None]:
# Scale the hit indicator so that it sums to 1.
is_h_p = is_h ./ sum(is_h)

# Running total accumulated from the first element towards the last.
is_h_p_cr = cumsum(is_h_p)

# Running total accumulated from the last element towards the first.
is_h_p_cl = cumulate_sum_reverse(is_h_p)

plot(
    [
        scatter(name = label, y = curve) for (label, curve) in (
            ("P( Is Hit )", is_h_p),
            ("CR( P( Is Hit ) )", is_h_p_cr),
            ("CL( P( Is Hit ) )", is_h_p_cl),
        )
    ],
    Layout(xaxis = xaxis),
)

In [None]:
# Scale the miss indicator so that it sums to 1.
is_m_p = is_m ./ sum(is_m)

# Running total accumulated from the first element towards the last.
is_m_p_cr = cumsum(is_m_p)

# Running total accumulated from the last element towards the first.
is_m_p_cl = cumulate_sum_reverse(is_m_p)

plot(
    [
        scatter(name = label, y = curve) for (label, curve) in (
            ("P( Is Miss )", is_m_p),
            ("CR( P( Is Miss ) )", is_m_p_cr),
            ("CL( P( Is Miss ) )", is_m_p_cl),
        )
    ],
    Layout(xaxis = xaxis),
)

## Amplitude

In [None]:
# Amplitude: the absolute value of every element score.
a = map(abs, element_score_)

plot(
    [scatter(y = a)],
    Layout(yaxis_title = "Amplitude", xaxis = xaxis),
)

In [None]:
# Scale the amplitudes so that they sum to 1.
a_p = a ./ sum(a)

# Running total accumulated from the first element towards the last.
a_p_cr = cumsum(a_p)

# Running total accumulated from the last element towards the first.
a_p_cl = cumulate_sum_reverse(a_p)

plot(
    [
        scatter(name = label, y = curve) for (label, curve) in (
            ("P( Amplitude )", a_p),
            ("CR( P( Amplitude ) )", a_p_cr),
            ("CL( P( Amplitude ) )", a_p_cl),
        )
    ],
    Layout(xaxis = xaxis),
)

In [None]:
# Keep the amplitude of the hits only; misses are multiplied by 0.
a_h = a .* is_h

# Scale the hit amplitudes so that they sum to 1.
a_h_p = a_h ./ sum(a_h)

# Running total accumulated from the first element towards the last.
a_h_p_cr = cumsum(a_h_p)

# Running total accumulated from the last element towards the first.
a_h_p_cl = cumulate_sum_reverse(a_h_p)

plot(
    [
        scatter(name = label, y = curve) for (label, curve) in (
            ("P( Amplitude Hit )", a_h_p),
            ("CR( P( Amplitude Hit ) )", a_h_p_cr),
            ("CL( P( Amplitude Hit ) )", a_h_p_cl),
        )
    ],
    Layout(xaxis = xaxis),
)

In [None]:
# Keep the amplitude of the misses only; hits are multiplied by 0.
a_m = a .* is_m

# Scale the miss amplitudes so that they sum to 1.
a_m_p = a_m ./ sum(a_m)

# Running total accumulated from the first element towards the last.
a_m_p_cr = cumsum(a_m_p)

# Running total accumulated from the last element towards the first.
a_m_p_cl = cumulate_sum_reverse(a_m_p)

plot(
    [
        scatter(name = label, y = curve) for (label, curve) in (
            ("P( Amplitude Miss )", a_m_p),
            ("CR( P( Amplitude Miss ) )", a_m_p_cr),
            ("CL( P( Amplitude Miss ) )", a_m_p_cl),
        )
    ],
    Layout(xaxis = xaxis),
)

## KS

In [None]:
# KS
# Element-wise difference between the reverse cumulative ("CL") hit and miss
# probabilities computed above.

s = is_h_p_cl - is_m_p_cl

plot([scatter(y = s)], Layout(title = "KS", xaxis = xaxis))

## JSD

In [None]:
# JSD
# Difference between `JSD` evaluated on the reverse cumulative ("CL") amplitude
# distributions and `JSD` evaluated on the forward cumulative ("CR") ones.
# This reassigns `s`, replacing the value computed in the KS cell.
# NOTE(review): `JSD` is not defined or imported in the visible part of this notebook —
# presumably a Jensen-Shannon divergence helper; confirm where it comes from and what
# its third argument represents.

s = JSD(a_h_p_cl, a_m_p_cl, a_p_cl) - JSD(a_h_p_cr, a_m_p_cr, a_p_cr)

## Score

In [None]:
# Collapse `s` into a single number: its arithmetic mean (sum divided by length).
score = sum(s) / length(s)

## High-level interface

In [None]:
# Read the element names from "genes.txt", dropping the first line of the file
# (presumably a header row — confirm against the file).
element_ = readlines("genes.txt")[2:end]

n_element = length(element_)

# One standard-normal random value per element.
element_value_ = randn(n_element)

# Table with one row per element: the element name, a constant column holding
# 1 / n_element, the random values, and the random values multiplied by 10.
# NOTE(review): `DataFrame` is not imported in the visible part of this notebook —
# confirm that the package providing it is loaded elsewhere.
element_x_sample = DataFrame(
    Symbol("Element") => element_,
    Symbol("Sample Constant") => fill(1 / n_element, n_element),
    Symbol("Sample Normal") => element_value_,
    Symbol("Sample Normal x 10") => element_value_ * 10,
)

In [None]:
# Draw 100 distinct elements at random (sampling without replacement) to act as a set.
set_element_ = sample(element_, 100, replace = false)

# NOTE(review): `read_gmt` is not defined in the visible part of this notebook;
# presumably it parses the GMT file into a mapping from set name to elements — confirm.
set_to_element_ = read_gmt("h.all.v6.2.symbols.gmt")

In [None]:
# NOTE(review): `check_is_in` and `score_set` are not defined in the visible part of this
# notebook. Presumably `is_in_` flags, for each element, whether it belongs to the set —
# confirm against their definitions.
is_in_ = check_is_in(element_, set_element_)

# Call `score_set` with the membership result instead of the set itself.
score_set(element_, element_value_, is_in_)

In [None]:
# Same call as in the previous cell, but passing the set's elements directly rather than
# the result of `check_is_in`.
score_set(element_, element_value_, set_element_)

In [None]:
# Plot the set against the element values, with custom title, subtitle and value-axis
# label text.
# NOTE(review): `plot_set_enrichment` is not defined in the visible part of this notebook;
# the meaning of each keyword is inferred from its name — confirm against its definition.
plot_set_enrichment(
    element_,
    element_value_,
    set_element_;
    title1_text = "Title",
    title2_text = "Description or any other text go here",
    element_value_name = "Element<br>Value<br>Metric",
)

In [None]:
# Time a `score_set` call. The global variables are interpolated with `$` so that the
# benchmark measures the call itself rather than the overhead of reading untyped globals.
benchmark_result = @benchmark score_set($element_, $element_value_, $set_element_)

# Report the fastest sample, converted from nanoseconds to milliseconds.
# `@printf` comes from the Printf standard library, which must be loaded.
# NOTE(review): the time is also divided by `length(set_element_)`, i.e. by the number of
# elements in the set, although the label reads "ms / set" — confirm the intended unit.
@printf "%.2f ms / set" minimum(benchmark_result.times) / 1e6 / length(set_element_)

# Leave the full benchmark summary as the cell's displayed value.
benchmark_result

In [None]:
# Path of the gene-by-sample input table.
# NOTE(review): this is an empty string — presumably a placeholder that must be filled in
# before the cell can run meaningfully; confirm.
gene_x_sample_file_path = ""

# Gene set collections (GMT files) to search.
# NOTE(review): "c4" is absent from this list — confirm that is intentional.
gmt_file_path_ = [
    "h.all.v6.2.symbols.gmt",
    "c1.all.v6.2.symbols.gmt",
    "c2.all.v6.2.symbols.gmt",
    "c3.all.v6.2.symbols.gmt",
    "c5.all.v6.2.symbols.gmt",
    "c6.all.v6.2.symbols.gmt",
    "c7.all.v6.2.symbols.gmt",
]

# Directory passed to `Kraft.gsea` — presumably where it writes its results; confirm.
directory_path = "output/"

# Gene set names passed as `gene_set_keyword_` — presumably used to select which gene sets
# are scored; confirm against `Kraft.gsea`.
gene_set_keyword_ = [
    "VANTVEER_BREAST_CANCER_ESR1",
    "DOANE_BREAST_CANCER_ESR1",
    "YANG_BREAST_CANCER_ESR1",
    "CHARAFE_BREAST_CANCER_LUMINAL_VS_BASAL",
    "AIGNER_ZEB1_TARGETS",
    "SANSOM_APC_TARGETS",
    "BCAT_GDS748",
    "BCAT.100_UP.V1",
    "PID_WNT_SIGNALING_PATHWAY",
    "LIU_CDX2_TARGETS",
    "KEGG_WNT_SIGNALING_PATHWAY",
];

# Run the analysis; the returned value is used below as the path of a file to read.
# NOTE(review): `Kraft` is not imported in the visible part of this notebook, and the
# meaning of the "-0-" normalization method is not visible here — confirm both.
gene_set_x_sample_tsv_file_path = Kraft.gsea(
    gene_x_sample_file_path,
    gmt_file_path_,
    directory_path;
    sample_normalization_method = "-0-",
    gene_set_keyword_ = gene_set_keyword_,
)

# Load the resulting table.
# NOTE(review): `CSV` is not imported in the visible part of this notebook. Recent CSV.jl
# releases also require a sink argument (e.g. `CSV.read(path, DataFrame)`) — verify
# against the installed version.
gene_set_x_sample = CSV.read(gene_set_x_sample_tsv_file_path)