In [None]:
import pandas as pd
import numpy as np
import anndata
import os
import sys
from plotnine import *

sys.path.append('/code/decima/src/decima/')
import preprocess

from grelu.data.preprocess import filter_chromosomes

## Paths

In [None]:
# Root directory of this Decima run. Defaults to the original cluster location; set the
# DECIMA_SAVE_DIR environment variable to run the notebook against another copy.
save_dir = os.environ.get("DECIMA_SAVE_DIR", "/gstore/data/resbioai/grelu/decima/20240823")
# Aggregated AnnData matrix: read in "Load matrix" and written back in "Save filtered anndata".
matrix_file = os.path.join(save_dir, "aggregated.h5ad")

## Load matrix

In [None]:
%%time
ad = anndata.read_h5ad(matrix_file)
print(ad.shape)

## Format .var

In [None]:
# Restrict .var to a fixed set of coordinate, annotation and QC columns
# (any other columns in the loaded file are dropped).
ad.var = ad.var.loc[
    :,
    [
        "chrom", "start", "end", "strand",
        "gene_name", "gene_type",
        "frac_nan", "mean_counts", "n_tracks",
    ],
]

In [None]:
# Keep a copy of the gene coordinates under their own names, presumably because
# start/end are rewritten when intervals are built below (TODO confirm), and record
# the gene length.
ad.var = ad.var.assign(
    gene_start=ad.var["start"].tolist(),
    gene_end=ad.var["end"].tolist(),
    gene_length=ad.var["end"] - ad.var["start"],
)

## Filter chromosomes

In [None]:
# Keep only genes on the chromosome set grelu calls "autosomesX".
ad = filter_chromosomes(ad, include="autosomesX")

## Make intervals

In [None]:
%%time

ad = preprocess.var_to_intervals(ad.copy(), chr_end_pad = 10000, genome="hg38")
print(ad.shape)
print(ad.var.start.min())

## Drop intervals with too many Ns (keep only intervals with < 40% N bases)

In [None]:
%%time
ad.var["frac_N"] = ad.var.apply(lambda row: preprocess.get_frac_N(row), axis=1)

In [None]:
# Maximum tolerated fraction of N bases in an interval; intervals at or above it are dropped.
MAX_FRAC_N = 0.4

print(ad.shape)
# Slicing an AnnData returns a view. `.copy()` materializes it, so the .var columns
# added in later cells are set on a real object rather than implicitly converting a
# view (which anndata warns about).
ad = ad[:, ad.var["frac_N"] < MAX_FRAC_N].copy()
print(ad.shape)

## How many intervals don't contain the gene end?

In [None]:
# Number of intervals whose gene mask ends exactly at position 524288, i.e. intervals
# that do not contain the gene end. 524288 is presumably the interval length — confirm.
ad.var["gene_mask_end"].eq(524288).sum()

## Visualize number of upstream and downstream bases

In [None]:
# Bases before the gene mask start, and bases between the gene mask end and position
# 524288 (presumably the interval length — confirm), stored for plotting below.
ad.var["Upstream bases"] = ad.var["gene_mask_start"]
ad.var["Downstream bases"] = 524288 - ad.var["gene_mask_end"]

In [None]:
# Histogram of upstream bases per interval, with a log-scaled count axis.
# `label_value` is plotnine's facet-strip labeller, not a tick-label formatter, so the
# log-scale breaks are formatted explicitly as plain integers with thousands separators.
(
    ggplot(ad.var, aes(x='Upstream bases'))
    + geom_histogram(fill='white', color='black', bins=50)
    + scale_y_log10(labels=lambda breaks: [f"{b:,.0f}" for b in breaks])
    + xlab("Number of bases upstream of TSS")
    + ylab('Count')
    + theme_classic()
    + theme(figure_size=(4, 2))
)

In [None]:
# Histogram of downstream bases per interval (bases after the gene mask end), linear counts.
(
    ggplot(ad.var, aes(x='Downstream bases'))
    + geom_histogram(fill='white', color='black', bins=50)
    + xlab("Number of bases downstream of gene")
    + ylab('Count')
    + theme_classic()
    + theme(figure_size=(4, 2))
)

## Save filtered AnnData (overwrites the input file `aggregated.h5ad`)

In [None]:
# NOTE: this overwrites the notebook's own input (`matrix_file`) with the filtered,
# interval-annotated matrix. Write to a temporary file in the same directory first and
# swap it in with os.replace, so a failure mid-write cannot leave the only copy of the
# data half-written.
tmp_root, tmp_ext = os.path.splitext(matrix_file)
tmp_file = f"{tmp_root}.tmp{tmp_ext}"
ad.write_h5ad(tmp_file)
os.replace(tmp_file, matrix_file)