# Get the pseudobulk data from scooby

In [1]:
import os
import numpy as np
import pandas as pd
import scipy
import torch
from tqdm import tqdm
import polars as pl
from torch.utils.data import DataLoader
from enformer_pytorch.data import GenomeIntervalDataset
from scooby.data import onTheFlyPseudobulkDataset
from scooby.utils.utils import undo_squashed_scale, get_gene_slice_and_strand
from scooby.utils.transcriptome import Transcriptome

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root directory of the Decima/scooby data. Set the SCOOBY_DATA_PATH environment variable to
# run this notebook against a different copy; otherwise the original cluster location is used.
data_path = os.environ.get('SCOOBY_DATA_PATH', '/gstore/data/resbioai/grelu/decima/scooby/')

## Read cell types used for training

In [3]:
# Cell-type table used for training: one row per cell type plus the list of its cell indices.
# Names get spaces turned into underscores, and the two names containing "/" are renamed
# exactly (presumably so they match the pseudobulk file names under base_path - confirm).
_celltype_renames = {"G/M_prog": "G+M_prog", "MK/E_prog": "MK+E_prog"}

cell_type_index = (
    pd.read_parquet(os.path.join(data_path, 'training_data', 'scooby_training_data', 'celltype_fixed.pq'))
    .assign(
        # Number of cells assigned to each cell type.
        size=lambda frame: frame['cellindex'].map(len),
        # Substring replace for spaces, then whole-value replacement for the "/" names.
        celltype=lambda frame: frame['celltype'].str.replace(' ', '_').replace(_celltype_renames),
    )
    .sort_values('celltype')
)
cell_type_index.head()

Unnamed: 0,celltype,cellindex,size
4,B1_B,"[5, 9, 20, 32, 112, 128, 151, 265, 294, 360, 3...",1747
1,CD14+_Mono,"[1, 11, 13, 19, 30, 38, 49, 50, 51, 58, 62, 64...",10338
3,CD16+_Mono,"[4, 17, 94, 315, 329, 370, 698, 709, 928, 936,...",1762
7,CD4+_T_activated,"[8, 24, 28, 40, 45, 48, 55, 63, 68, 75, 76, 82...",5157
6,CD4+_T_naive,"[7, 44, 47, 54, 56, 59, 88, 116, 123, 132, 140...",4170


## Load the genes

In [4]:
# GENCODE v32 gene annotation (sorted, gzipped GTF), loaded through scooby's Transcriptome helper.
gtf_file = os.path.join(
    data_path,
    "gencode.v32.annotation.sorted.gtf.gz",
)
transcriptome = Transcriptome(gtf_file)

In [12]:
# Every scooby training input lives under the same directory.
_scooby_dir = os.path.join(data_path, 'training_data', 'scooby_training_data')

fasta_file = os.path.join(_scooby_dir, "genome_human.fa")                    # reference genome
base_path = os.path.join(_scooby_dir, 'pseudobulks')                         # pseudobulk targets
genes_file = os.path.join(_scooby_dir, 'train_val_test_gene_sequences.csv')  # gene intervals

In [6]:
# Headerless, tab-separated gene table; as the preview shows, columns are
# chromosome, start, end, gene name and strand.
gene_coords = pd.read_csv(genes_file, sep='\t', header=None)
print(gene_coords.shape[0])
gene_coords.head()

25275


Unnamed: 0,0,1,2,3,4
0,chr1,663145,859753,AL669831.2,+
1,chr1,696101,892709,LINC01409,+
2,chr1,720299,916907,FAM87B,+
3,chr1,748365,944973,LINC01128,+
4,chr1,812100,1008708,AL645608.6,+


In [None]:
#import anndata
#genes = anndata.read_h5ad('/gstore/data/resbioai/grelu/decima/)

## Get the pseudobulk labels used for training

In [18]:
def _keep_nonnegative_start(frame):
    """Drop intervals whose second column (the start coordinate) is negative."""
    return frame.filter(pl.col('column_2') >= 0)


# Sequence dataset over the gene intervals, in file order.
ds = GenomeIntervalDataset(
    bed_file=genes_file,
    fasta_file=fasta_file,
    context_length=2**19,  # 524,288 bp of sequence context
    filter_df_fn=_keep_nonnegative_start,
    shift_augs=(0, 0),  # zero-width shift range: no shift augmentation
    return_augs=True,
)

In [19]:
# Wrap the interval dataset so each item also carries the pseudobulk targets that scooby
# builds on the fly from the files under base_path, for the training cell types in sorted order.
dataset_targets = onTheFlyPseudobulkDataset(
    ds=ds,
    base_path=base_path,
    cell_types=cell_type_index['celltype'].values,
)

In [20]:
# Collect, for every gene, the strand-matched pseudobulk targets summed over the gene's slice.
# shuffle=False keeps batch i aligned with row i of `dataset_targets.genome_ds.df`, which is
# where the gene name and coordinates are read from.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset_target_loader = DataLoader(dataset_targets, batch_size=1, shuffle=False)
print(len(dataset_target_loader))

interval_df = dataset_targets.genome_ds.df
all_targets = []
gene_names = []

# NOTE(review): a previous run of this loop hit Jupyter's IOPub message-rate limit; something
# called per iteration may be printing - confirm before re-running.
for i, batch in enumerate(tqdm(dataset_target_loader)):
    gene_name = interval_df[i, 'column_4']
    gene_slice, strand = get_gene_slice_and_strand(
        transcriptome, gene_name, interval_df[i, 'column_2'], span=False
    )
    # Skip genes without a slice. The batch has already been consumed by `enumerate`,
    # so the loader stays aligned with `interval_df`.
    if len(gene_slice) == 0:
        continue

    targets = batch[2].float().to(device)

    # '+' strand genes use the even target channels; '-' strand genes use the odd ones.
    if strand == '+':
        t = targets[0, gene_slice, ::2]
    elif strand == '-':
        t = targets[0, gene_slice, 1::2]
    else:
        # Without this, `t` would silently be reused from the previous gene.
        raise ValueError(f"Unexpected strand {strand!r} for gene {gene_name!r}")

    # Undo scooby's squashed scaling, then sum over axis 0 (the gene_slice dimension).
    all_targets.append(undo_squashed_scale(t, clip_soft=384).sum(axis=0).detach().clone().cpu().squeeze())
    # Take the name from the same dataset as the targets. It was previously read from the
    # undefined `val_dataset_targets`.
    gene_names.append(gene_name)


25263


 30%|████████████████████                                              | 7661/25263 [10:05<24:17, 12.08it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [21]:
# Stack the per-gene vectors into a NumPy array: one row per gene kept by the loop above,
# which is fewer than the loader length because genes without a slice are skipped.
# This rebinds `all_targets` from a list to an array, so re-running this cell on its own
# fails; re-run the collection loop first.
all_targets = (
    torch.vstack(all_targets)
    .clone()
    .numpy(force=True)
)
all_targets.shape

(25187, 21)

In [22]:
# Persist the per-gene pseudobulk counts and the matching gene names; row i of one file
# corresponds to row i of the other, so refuse to write misaligned outputs.
assert len(gene_names) == len(all_targets), (
    f"gene_names ({len(gene_names)}) and all_targets ({len(all_targets)}) are misaligned"
)
# NOTE: despite the `.pq` suffix, torch.save writes a pickled torch file, not Parquet. Read it
# back with torch.load. Recent PyTorch versions default to weights_only=True, which may refuse
# a pickled NumPy array, so weights_only=False may be needed. The filename is kept unchanged
# for existing consumers.
torch.save(all_targets, "count_target_test_no_neighbor.pq")
pd.DataFrame(gene_names).to_parquet("gene_names.pq")