In [None]:
# Gene names: lines 2 through 20000 of the gene list file.
# NOTE(review): line 1 is skipped, presumably because it is a header; confirm.
const genes = let gene_lines = readlines("/home/kwat/github/kraft/notebook/genes.txt")
    gene_lines[2:20000]
end

In [None]:
# Uniform scores: every gene gets 1 / (number of genes).
const scores = let gene_count = length(genes)
    fill(1 / gene_count, gene_count)
end

In [None]:
# Path of the gene-set (.gmt) file that is parsed with `read_gmt_file_path`.
const gmt_file_path = "/home/kwat/garden/data/gene_set/msigdb_v6.2/h.all.v6.2.symbols.gmt"


"""
    read_gmt_file_path(gmt_file_path)

Read a tab-separated gene-set file and return a dictionary mapping each
gene-set name to its genes.

Every non-blank line is split on tabs: field 1 is the gene-set name, field 2 is
ignored, and the remaining non-empty fields are the genes.

# Arguments
- `gmt_file_path::String`: path of the file to read.

# Returns
- `Dict{String, Array{String, 1}}`: gene-set name => genes. When a name occurs
  on several lines, the last line wins.
"""
function read_gmt_file_path(gmt_file_path::String)

    gene_set_dict = Dict{String, Array{String, 1}}()

    # `eachline` streams the file instead of materializing every line first.
    for line in eachline(gmt_file_path)

        # A blank line (e.g. at the end of the file) would otherwise be stored
        # as a gene set named "".
        if isempty(strip(line))
            continue
        end

        line_split = String.(split(line, '\t'))

        # Trailing or doubled tabs produce empty fields; never report "" as a
        # gene.
        gene_set_dict[line_split[1]] = filter(!isempty, line_split[3:end])

    end

    return gene_set_dict

end


# Every gene set parsed from the GMT file.
const gene_set_dict = read_gmt_file_path(gmt_file_path)

# Sorted genes of the gene set whose name comes first in sorted order.
const gene_set_genes = let gene_set_names = sort!(collect(keys(gene_set_dict)))
    sort(gene_set_dict[first(gene_set_names)])
end

In [None]:
using BenchmarkTools

In [None]:
"""
    make_hits(elements, elements_to_find)

Return a vector with one entry per element of `elements`: `1` when that element
also occurs in `elements_to_find`, `0` otherwise.
"""
function make_hits(
    elements::Array{String, 1},
    elements_to_find::Array{String, 1},
)

    # Hash-based membership lookup, built once for all elements.
    wanted = Set(elements_to_find)

    return Int64[Int64(element in wanted) for element in elements]

end


# 0/1 indicator of `gene_set_genes` over `genes`, computed once.
const hits = make_hits(genes, gene_set_genes)

# Time the construction of the indicator.
@benchmark make_hits(genes, gene_set_genes)

In [None]:
"""
    sum_hit_scores(scores, hits)

Return the sum of `abs(scores[i])` over every index `i` where `hits[i] == 1`.

# Arguments
- `scores::Array{Float64, 1}`: one score per element.
- `hits::Array{Int64, 1}`: indicator with the same length as `scores`.

# Returns
- `Float64`: summed absolute scores of the hits (`0.0` when there are none).

# Throws
- `DimensionMismatch`: if `scores` and `hits` differ in length.
"""
function sum_hit_scores(
    scores::Array{Float64, 1},
    hits::Array{Int64, 1},
)

    # The loop below runs without bounds checks, so a length mismatch has to be
    # rejected here instead of becoming an out-of-bounds read.
    if length(scores) != length(hits)
        throw(DimensionMismatch(
            "scores has length $(length(scores)) but hits has length $(length(hits))"
        ))
    end

    sum_ = 0.0

    @inbounds @fastmath @simd for i in eachindex(scores, hits)

        if hits[i] == 1

            sum_ += abs(scores[i])

        end

    end

    return sum_

end


# Time the hit-score summation on the precomputed inputs.
@benchmark sum_hit_scores(scores, hits)

In [None]:
"""
    sum_hits(hits)

Return the sum of all entries of `hits`; for a 0/1 indicator this is the number
of hits.
"""
function sum_hits(hits::Array{Int64, 1})

    return sum(hits)

end


# Time `sum_hits` on the precomputed `hits` vector.
@benchmark sum_hits(hits)

In [None]:
"""
    compute_gene_set_enrichment(genes, scores, gene_set_genes; statistic="ks", hits=nothing)

Compute a running-sum enrichment score of a gene set, walking `scores` in the
order given.

At each position the running value increases by
`abs(scores[i]) / (sum of the absolute hit scores)` when position `i` is a hit,
and decreases by `1 / (number of non-hits)` otherwise.

# Arguments
- `genes::Array{String, 1}`: gene names; only read when `hits` is `nothing`.
- `scores::Array{Float64, 1}`: one score per position.
- `gene_set_genes::Array{String, 1}`: genes of the set; only read when `hits`
  is `nothing`.

# Keywords
- `statistic::String`: `"ks"` returns the extreme of the running value with the
  larger absolute value (the minimum on ties); `"auc"` returns the sum of the
  running values.
- `hits`: optional precomputed 0/1 indicator with the same length as `scores`.
  When `nothing`, it is built with `make_hits(genes, gene_set_genes)`.

# Returns
- `Float64`: the enrichment score for the requested `statistic`.

# Throws
- `ArgumentError`: if `statistic` is neither `"ks"` nor `"auc"`.
- `DimensionMismatch`: if the hit indicator cannot have the same length as
  `scores`.
"""
function compute_gene_set_enrichment(
    genes::Array{String, 1},
    scores::Array{Float64, 1},
    gene_set_genes::Array{String, 1};
    statistic::String="ks",
    hits::Union{Nothing, Array{Int64, 1}}=nothing,
)

    # Reject an unsupported statistic before doing any work; otherwise the
    # result variable would simply never be assigned.
    if statistic != "ks" && statistic != "auc"
        throw(ArgumentError("statistic must be \"ks\" or \"auc\", got \"$statistic\""))
    end

    n = length(scores)

    # The main loop runs without bounds checks, so the indicator must be
    # guaranteed to have exactly `n` entries.
    if hits === nothing

        if length(genes) != n
            throw(DimensionMismatch(
                "genes has length $(length(genes)) but scores has length $n"
            ))
        end

        hits_ = make_hits(
            genes,
            gene_set_genes,
        )

    else

        if length(hits) != n
            throw(DimensionMismatch(
                "hits has length $(length(hits)) but scores has length $n"
            ))
        end

        hits_ = hits

    end

    hit_scores_sum = sum_hit_scores(
        scores,
        hits_,
    )

    # Step taken at every non-hit position.
    d_down = -1 / (n - sum_hits(hits_))

    value = 0.0

    auc = 0.0

    min_ = 0.0

    max_ = 0.0

    # Each iteration depends on the running `value` of the previous one, so the
    # loop is deliberately not marked `@simd`.
    @inbounds for i in 1:n

        d_value = hits_[i] == 1 ? abs(scores[i]) / hit_scores_sum : d_down

        value += d_value

        auc += value

        if value < min_

            min_ = value

        elseif max_ < value

            max_ = value

        end

    end

    if statistic == "auc"

        return auc

    end

    # "ks": the extreme with the larger magnitude; the minimum wins ties.
    return abs(min_) < abs(max_) ? max_ : min_

end


# Statistic requested from the enrichment computations.
const statistic = "ks"

# Print the score once, reusing the precomputed `hits`.
println(compute_gene_set_enrichment(genes, scores, gene_set_genes, statistic=statistic, hits=hits))

# Time the same call.
@benchmark compute_gene_set_enrichment(genes, scores, gene_set_genes; statistic=statistic, hits=hits)

In [None]:
using PyCall

In [None]:
# Python modules loaded through PyCall.
pd = pyimport("pandas")

# NOTE(review): presumably the Python package whose GSEA implementation is
# compared against the Julia one above; confirm.
kraft = pyimport("kraft")

In [None]:
# pandas Series of the scores, indexed by gene name.
gene_score = pd.Series(scores, index=genes)

In [None]:
# Print the result of the Python implementation for the same inputs.
println(kraft.run_single_sample_gsea(gene_score, gene_set_genes, hit=hits, statistic=statistic, plot=false))

# Time the same Python call.
@benchmark kraft.run_single_sample_gsea(gene_score, gene_set_genes, hit=hits, statistic=statistic, plot=false)