In [1]:
from pathlib import Path
from gears import PertData, GEARS
import scanpy as sc
import json
#import numpy as np
#import pandas as pd

# import gseapy as gp
# from tqdm import tqdm
# from pathlib import Path
# from Bio import Entrez, SeqIO
# import mygene
# import gzip
# from transformers import AutoTokenizer, AutoModel
# import torch

In [2]:
data_dir = Path('/Users/djemec/data/jepa/v0_3')
raw_dir = data_dir / 'raw_files'
pert_dir = data_dir / 'pert_embd'


#protein_dir = data_dir / 'uniprot'
#train_dir = data_dir / 'training'
#pt_dir = data_dir / 'pretraining'
splits= ['train','val','test']


n_genes = 8192 # 2**13
count_normalize_target = 1e4

**Breakdown of perturbation**

* Replogle K562 Ess - CRISPRi (Repression)
* Replogle K562 GW - CRISPRi (Repression)
* Replogle RPE1 - CRISPRi (Repression)
* Adamson - CRISPRi (Repression)
* Norman - CRISPRa (Activation) + Combinatorial
* SciPlex 3 - Drug Treatment

Because of this breakdown, for now we'll just use the CRISPRi datasets.   

In [3]:
datasets= {
    'k562e':raw_dir / 'k562e',
    'k562gw':raw_dir /'k562gw'/'replogle_k562_gw_expanded_8k.h5ad',
    'rep1e':raw_dir /'rep1e'/'perturb_processed.h5ad',
    'adamson':raw_dir /'adamson'/'perturb_processed.h5ad'
}
total_cells = 0

## Extract Anchor File - K562 essential
we'll build out 4 important variables from our key dataset.  k562 will be the GEARS dataset we test against and so we need to be extra careful with data leakage on this set. 
1. `all_genes` - all genes with expression we'll have in our model
2. `anchor_genes` - minimum set of genes based on k562-essential. We have to have these to be able to do validation
3. `all_perts`- all perturbations we've seen that we'll train on minus the forbidden perturbations
4. `forbidden_perturbations` - perturbations we cannot include from the k562 dataset for our validation. If we do, we'll "leak" for our validation

We'll also track how many cells we're training on in general.  note that this will include all the cells in test/train/val so we'll actually only train on a subset. 

In [4]:
pert_data = PertData(datasets['k562e']) 
pert_data.load(data_name='replogle_k562_essential')

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['C7orf26+ctrl' 'C14orf178+ctrl' 'RPS10-NUDT3+ctrl' 'SEM1+ctrl' 'FAU+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


In [5]:
pert_data.prepare_split(split='simulation', seed=1) 
adata = pert_data.adata 

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:272
Done!


here1


In [6]:
def clean_gears_name(name):
    # GEARS format is 'Gene+ctrl' -> We want 'Gene'
    if name.endswith('+ctrl'):
        return name.replace('+ctrl', '')
    if name.startswith('ctrl+'):
        return name.replace('ctrl+', '')
    elif name == 'ctrl':
        return 'control'
    return str.strip(name)

In [7]:
hold_out_perts = set(pert_data.set2conditions['test'])
hold_out_perts = set([clean_gears_name(i) for i in hold_out_perts])
len(hold_out_perts)

272

In [8]:
sc.pp.normalize_total(adata, target_sum=count_normalize_target)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=n_genes, subset=True)

In [9]:
anchor_genes = adata.var.gene_name.tolist()
all_genes = sorted(list(set(anchor_genes)))
len(anchor_genes), len(all_genes)

(5000, 4999)

In [10]:
all_perts = adata.obs['condition'].unique()
all_perts = [clean_gears_name(p) for p in all_perts if p != 'ctrl']
all_perts = sorted(list(set(all_perts)))
len(all_perts)

1087

In [11]:
k562e_total_cells = adata.n_obs
total_cells += k562e_total_cells
k562e_total_cells, total_cells

(162246, 162246)

## Add K562 genome wide
we'd previously already filtered down the K562 genome wide dataset based on forbidden perturbations and our target gene count so for this dataset it should be a matter of reading the file and updating our variables.  we'll be adding to:
1. `all_genes` until we reach our limit
2. `all_perts` since this is a k562 dataset we'll make sure to not add any forbidden perturbations

In [12]:
adata_gw = sc.read_h5ad(datasets['k562gw'], backed='r')

In [13]:
k652gw_total_cells = adata_gw.n_obs
total_cells += k652gw_total_cells
k652gw_total_cells, total_cells

(1943010, 2105256)

**Unify Genes** Now we'll merge the genes we extracted with the anchor_genes.  Since we preprocessed the genome-wide dataset we know that this will come up to exactly our target gene count.  in the future we'd have to be careful to only add up to the target with each dataset to not explode our model size. 

In [14]:
gw_genes = adata_gw.var_names.tolist()
gw_genes = sorted(list(set(gw_genes)))
len(gw_genes), gw_genes[:10]

(8192,
 ['A1BG',
  'AAAS',
  'AACS',
  'AAGAB',
  'AAK1',
  'AAMDC',
  'AAMP',
  'AAR2',
  'AARS',
  'AARS2'])

In [15]:
all_genes = set(all_genes).union(set(gw_genes))
len(all_genes), len(all_genes) <= n_genes

(8192, True)

**Unify Perturbations** 

Now we'll merge the perturbations we extracted with the previous perturbations.  We need to make sure to flag perturbations we don't want 

In [16]:
gw_perts = adata_gw.obs['perturbation'].unique()
gw_perts = [clean_gears_name(p) for p in gw_perts if p != 'control']
gw_perts = sorted(list(set(gw_perts)))
len(gw_perts)

9594

check that no forbidden perturbations leaked before merging the perturbations in. 

In [17]:
set(hold_out_perts) & set(gw_perts)

set()

In [18]:
all_perts = set(all_perts).union(set(gw_perts))
len(all_perts)

9869

In [19]:
f'total cells: {total_cells} | total genes {len(all_genes)} | total perturbations {len(all_perts)}'

'total cells: 2105256 | total genes 8192 | total perturbations 9869'

## rep1e Dataset

In [20]:
ds_adata = sc.read_h5ad(datasets['rep1e'], backed='r')

In [21]:
ds_total_cells = ds_adata.n_obs
total_cells += ds_total_cells
ds_total_cells, total_cells

(162733, 2267989)

In [22]:
ds_genes = ds_adata.var_names.tolist()
ds_genes = sorted(list(set(gw_genes)))
len(ds_genes), ds_genes[:10]

(8192,
 ['A1BG',
  'AAAS',
  'AACS',
  'AAGAB',
  'AAK1',
  'AAMDC',
  'AAMP',
  'AAR2',
  'AARS',
  'AARS2'])

In [23]:
if len(set(all_genes).union(set(ds_genes))) <= n_genes:
    all_genes = set(all_genes).union(set(ds_genes))
    print('including all genes')
else:
    skipped_genes = set(ds_genes) - set(all_genes)
    print(f'skipping {skipped_genes}')

including all genes


In [24]:
ds_perts = ds_adata.obs['condition'].unique()
ds_perts = [clean_gears_name(p) for p in ds_perts if p != 'ctrl' and p != 'control']
ds_perts = sorted(list(set(ds_perts)))
remaining_ds_perts = sorted(list(set(ds_perts) - set(hold_out_perts)))
len(ds_perts), len(remaining_ds_perts)

(1543, 1329)

In [25]:
all_perts = set(all_perts).union(set(remaining_ds_perts))
len(all_perts)

9869

In [26]:
f'total cells: {total_cells} | total genes {len(all_genes)} | total perturbations {len(all_perts)}'

'total cells: 2267989 | total genes 8192 | total perturbations 9869'

## adamson Dataset

In [27]:
ds_adata = sc.read_h5ad(datasets['adamson'], backed='r')

In [28]:
ds_total_cells = ds_adata.n_obs
total_cells += ds_total_cells
ds_total_cells, total_cells

(68603, 2336592)

In [29]:
ds_genes = ds_adata.var_names.tolist()
ds_genes = sorted(list(set(gw_genes)))
len(ds_genes), ds_genes[:10]

(8192,
 ['A1BG',
  'AAAS',
  'AACS',
  'AAGAB',
  'AAK1',
  'AAMDC',
  'AAMP',
  'AAR2',
  'AARS',
  'AARS2'])

In [30]:
if len(set(all_genes).union(set(ds_genes))) <= n_genes:
    all_genes = set(all_genes).union(set(ds_genes))
    print('including all genes')
else:
    skipped_genes = set(ds_genes) - set(all_genes)
    print(f'skipping {skipped_genes}')

including all genes


In [31]:
ds_perts = ds_adata.obs['condition'].unique()
ds_perts = [clean_gears_name(p) for p in ds_perts if p != 'ctrl' and p != 'control']
ds_perts = sorted(list(set(ds_perts)))
remaining_ds_perts = sorted(list(set(ds_perts) - set(hold_out_perts)))
len(ds_perts), len(remaining_ds_perts)

(86, 75)

In [32]:
all_perts = set(all_perts).union(set(remaining_ds_perts))
len(all_perts)

9875

In [33]:
f'total cells: {total_cells} | total genes {len(all_genes)} | total perturbations {len(all_perts)}'

'total cells: 2336592 | total genes 8192 | total perturbations 9875'

## Save Genes and perturbations

In [37]:
genes_and_perts = {
    'all_genes':list(all_genes),
    'anchor_genes': list(anchor_genes),
    'all_perts': list(all_perts),
    'hold_out_perts': list(hold_out_perts)
}

In [38]:
with open(raw_dir / 'gene_and_perts.json', 'w') as f:
    json.dump(genes_and_perts, f)