In [3]:
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

import pickle
import requests
from sklearn.model_selection import train_test_split
import torch
import h5py
import ast
import scanpy as sc
import anndata as ad
from scipy.io import mmread

from scipy.sparse import csr_matrix
from gears import PertData


In [None]:
def get_status(gene_id, timeout: float = 30.0, max_retries: int = 3):
    """Return whether Ensembl REST serves a protein sequence for ``gene_id``.

    Queries ``/sequence/id/{gene_id}?type=protein&multiple_sequences=1`` and
    reports ``response.ok`` of the final request. This notebook interprets a
    False result as "not found in Ensembl, or not a protein-coding gene".

    Ensembl throttles REST clients with HTTP 429. Those responses are retried
    (waiting for the ``Retry-After`` header, default 1 s) so that a throttled
    gene is not misclassified as a failure.

    Args:
        gene_id: Ensembl stable gene ID.
        timeout: Per-request timeout in seconds.
        max_retries: Maximum number of retries after an HTTP 429 response.

    Returns:
        bool: ``response.ok`` of the last request made.

    Raises:
        requests.exceptions.RequestException: on network errors or timeouts.
    """
    import time

    url = f"https://rest.ensembl.org/sequence/id/{gene_id}"
    params = {"type": "protein", "multiple_sequences": 1}
    headers = {"Content-Type": "text/x-fasta"}

    attempts = max(max_retries, 0) + 1
    for attempt in range(attempts):
        response = requests.get(url, params=params, headers=headers, timeout=timeout)
        if response.status_code != 429 or attempt == attempts - 1:
            break
        # Rate-limited: wait as long as the server asks before retrying.
        try:
            delay = float(response.headers.get("Retry-After", 1))
        except ValueError:
            delay = 1.0
        time.sleep(delay)
    return response.ok


In [None]:
def get_canonical_transcript_id(ensembl_gene_id: str, timeout: float = 30.0) -> "str | None":
    """Look up the canonical transcript ID of an Ensembl gene.

    Calls ``/lookup/id/{ensembl_gene_id}?expand=1`` and reads the
    ``canonical_transcript`` field of the JSON response.

    Args:
        ensembl_gene_id: Ensembl stable gene ID.
        timeout: Per-request timeout in seconds.

    Returns:
        The transcript ID exactly as returned by Ensembl (callers in this
        notebook strip a ".N" version suffix from it), or None when the
        response has no ``canonical_transcript`` field.

    Raises:
        requests.HTTPError: if Ensembl answers with an error status.
    """
    response = requests.get(
        f"https://rest.ensembl.org/lookup/id/{ensembl_gene_id}",
        params={"expand": 1},
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json().get("canonical_transcript")


In [None]:
def get_protein_translation_id(canonical_transcript_id: str, timeout: float = 30.0) -> "str | None":
    """Look up the protein (translation) ID of an Ensembl transcript.

    Calls ``/lookup/id/{canonical_transcript_id}?expand=1`` and reads
    ``Translation.id`` from the JSON response.

    Args:
        canonical_transcript_id: Ensembl transcript ID (unversioned in this notebook).
        timeout: Per-request timeout in seconds.

    Returns:
        The translation ID, or None when the response has no (or a null)
        ``Translation`` entry, e.g. for a transcript without a protein product.

    Raises:
        requests.HTTPError: if Ensembl answers with an error status.
    """
    response = requests.get(
        f"https://rest.ensembl.org/lookup/id/{canonical_transcript_id}",
        params={"expand": 1},
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    response.raise_for_status()
    # `or {}` also covers a Translation key that is present but null.
    translation = response.json().get("Translation") or {}
    return translation.get("id")


In [None]:
def get_protein_sequence(protein_id: str, timeout: float = 30.0) -> str:
    """Fetch the amino-acid sequence of an Ensembl protein as plain text.

    Calls ``/sequence/id/{protein_id}?type=protein`` requesting ``text/plain``.

    Args:
        protein_id: Ensembl translation (protein) ID.
        timeout: Per-request timeout in seconds.

    Returns:
        The response body with surrounding whitespace stripped.

    Raises:
        requests.HTTPError: if Ensembl answers with an error status.
    """
    response = requests.get(
        f"https://rest.ensembl.org/sequence/id/{protein_id}",
        params={"type": "protein"},
        headers={"Content-Type": "text/plain"},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.text.strip()

In [None]:
# Data processing for the Replogle RPE1 dataset.
# NOTE(review): assumes the data lives in ../data relative to the notebook.
# "..data/..." (missing "/") would name a folder literally called "..data".
data_folder_rpe1 = "../data/gene_perturb_data/replogle_rpe1"

raw_data_path = os.path.join(data_folder_rpe1, "rpe1_raw_singlecell_01.h5ad")
replogle_rpe1 = sc.read_h5ad(raw_data_path)

In [None]:
# Map each perturbed gene symbol to its Ensembl gene ID, taking the first cell
# in which the symbol appears. drop_duplicates keeps that first row per symbol
# in a single vectorised pass. It avoids an O(cells x genes) list-membership
# loop and deprecated positional Series[i] indexing on the barcode index.
gene_ensembl_dict = (
    replogle_rpe1.obs.drop_duplicates(subset="gene")
    .set_index("gene")["gene_id"]
    .to_dict()
)
# Unique gene symbols in order of first appearance.
used_genes = list(gene_ensembl_dict)

In [None]:
# Persist the gene-symbol -> Ensembl-ID lookup next to the RPE1 raw data.
dict_path_rpe1 = os.path.join(data_folder_rpe1, "gene_ensembl_dict.pkl")
with open(dict_path_rpe1, mode="wb") as handle:
    pickle.dump(gene_ensembl_dict, handle)

In [None]:
# Library-size normalise, log-transform and keep the 5000 most highly variable
# genes. These calls modify replogle_rpe1 in place, so re-running the cell would
# log-transform twice. Skip the work when sc.pp.log1p has already left its
# "log1p" marker in .uns.
if "log1p" not in replogle_rpe1.uns:
    sc.pp.normalize_total(replogle_rpe1)
    sc.pp.log1p(replogle_rpe1)
    sc.pp.highly_variable_genes(replogle_rpe1, n_top_genes=5000, subset=True)
else:
    print("replogle_rpe1 is already log-transformed; skipping normalisation/HVG selection.")

In [None]:
# Ensembl IDs of all perturbed genes in first-appearance order. Exclude the
# non-targeting control and missing IDs, which may be stored either as the
# string 'nan' or as a real NaN. Filtering instead of list.remove() does not
# raise when one of these labels is absent.
_excluded_gene_ids = {"non-targeting", "nan"}
perturbed_genes = [
    gene_id
    for gene_id in replogle_rpe1.obs["gene_id"].unique()
    if isinstance(gene_id, str) and gene_id not in _excluded_gene_ids
]

In [None]:
# Split the perturbed genes by whether Ensembl serves a protein sequence for them.
# Successes are treated as protein-coding genes. Failures are IDs that were
# either not found in Ensembl or are non-coding genes.
successes = []
failures = []

for gene_id in tqdm(perturbed_genes):
    if get_status(gene_id):
        successes.append(gene_id)
    else:
        failures.append(gene_id)

In [None]:
# Fetch the canonical protein sequence of every gene that passed the status check:
# gene -> canonical transcript -> translation (protein) ID -> amino-acid sequence.
# Each step runs only when the previous one returned an ID. A gene without a
# canonical transcript or translation is therefore recorded as a failure, and
# it cannot reuse a protein ID from an earlier iteration. HTTP/network errors
# also mark the gene as a failure rather than aborting the loop.
aa_dict = {}

successful_genes = []
failure_genes = []

for gene_id in tqdm(successes):
    protein_seq = None
    try:
        canon_transcript_id = get_canonical_transcript_id(gene_id)
        if canon_transcript_id is not None:
            # Drop the ".N" version suffix before the transcript lookup.
            ct_id_new = canon_transcript_id.split(".")[0]
            protein_id = get_protein_translation_id(ct_id_new)
            if protein_id is not None:
                protein_seq = get_protein_sequence(protein_id)
    except requests.exceptions.RequestException as err:
        tqdm.write(f"{gene_id}: Ensembl request failed ({err})")

    if protein_seq:
        successful_genes.append(gene_id)
        aa_dict[gene_id] = protein_seq
    else:
        failure_genes.append(gene_id)


In [None]:
# Write the retrieved protein sequences as a two-column CSV (gene_id, aa_sequence),
# one row per gene, in insertion order of aa_dict.
csv_path_rpe1 = os.path.join(data_folder_rpe1, "perturbed_genes.csv")

aa_df = pd.DataFrame(
    {"gene_id": list(aa_dict.keys()), "aa_sequence": list(aa_dict.values())}
)
aa_df.to_csv(csv_path_rpe1, index=False)

In [None]:
# Work on an independent copy so replogle_rpe1 is left untouched, and store its
# expression matrix in CSR sparse format.
cell_type = 'rpe1'
filtered_rpe1 = replogle_rpe1.copy()
filtered_rpe1.X = csr_matrix(filtered_rpe1.X)

In [None]:
# Rename the perturbation column to "condition" and suffix each label with
# "+ctrl" (single-gene perturbations). Tag every cell with cell_type, then keep
# only these two columns.
filtered_rpe1.obs = (
    filtered_rpe1.obs
    .rename(columns={'gene': 'condition'})
    .assign(
        condition=lambda frame: [label + '+ctrl' for label in frame['condition']],
        cell_type=cell_type,
    )
    .loc[:, ['condition', 'cell_type']]
)

In [None]:
# Relabel the non-targeting control cells as "ctrl"; every other condition
# label maps to itself.
mapper = {label: label for label in filtered_rpe1.obs['condition'].unique()}
mapper.update({'non-targeting+ctrl': 'ctrl'})
filtered_rpe1.obs['condition'] = filtered_rpe1.obs['condition'].map(mapper)

In [None]:
# Build "<cell_type>_<condition>_<dose>" group names used by the DE ranking below.
# Perturbed conditions get dose "1+1"; control cells get "<cell_type>_ctrl_1".
# The prefix comes from the per-cell obs['cell_type'] column rather than the
# notebook-global `cell_type`, which the K562 section later reassigns.
filtered_rpe1.obs['cov_drug_dose_name'] = [
    f"{ct}_{cond}_1" if cond == 'ctrl' else f"{ct}_{cond}_1+1"
    for ct, cond in zip(filtered_rpe1.obs['cell_type'], filtered_rpe1.obs['condition'])
]

In [None]:
from gears.data_utils import rank_genes_groups_by_cov

# Rank genes for each cov_drug_dose_name group against the control group,
# separately per cell_type, keeping the top 20 per group.
# NOTE(review): presumably combines covariate + control_group into
# "<cell_type>_ctrl_1" to locate the control cells — confirm in GEARS.
rank_genes_groups_by_cov(
    filtered_rpe1,
    groupby='cov_drug_dose_name',
    covariate='cell_type',
    control_group='ctrl_1',
    n_genes=20,
)

In [None]:
# adata_rpe1 is an alias of filtered_rpe1 (same object, not a copy); make sure
# its expression matrix is CSR-sparse before handing it to GEARS.
adata_rpe1 = filtered_rpe1
adata_rpe1.X = csr_matrix(adata_rpe1.X)

In [None]:
# GEARS stores the processed data under <saved folder>/<dataset_name>, and the
# next cell reloads it from data_folder_rpe1. Split that path into its parent
# (saved folder) and final component (dataset name) rather than passing the
# full path as dataset_name.
pert_data = PertData(os.path.dirname(data_folder_rpe1))  # saved folder
pert_data.new_data_process(
    dataset_name=os.path.basename(data_folder_rpe1),  # dataset folder name
    adata=adata_rpe1,
)


In [None]:
# Reload the processed RPE1 data (saved folder + dataset name), create the
# "simulation" train/test split with seed 1, and build the data loaders.
pert_data.load(data_path=data_folder_rpe1)
pert_data.prepare_split(split='simulation', seed=1)
pert_data.get_dataloader(
    batch_size=32,
    test_batch_size=128,
)

In [None]:
# Data processing for the Replogle K562 (essential genes) dataset.
# NOTE(review): assumes the data lives in ../data relative to the notebook.
# "..data/..." (missing "/") would name a folder literally called "..data".
data_folder_k562 = "../data/gene_perturb_data/replogle_k562"

raw_data_path = os.path.join(data_folder_k562, "K562_essential_raw_singlecell_01.h5ad")
replogle_k562 = sc.read_h5ad(raw_data_path)

In [None]:
# Map each perturbed gene symbol to its Ensembl gene ID, taking the first cell
# in which the symbol appears. drop_duplicates keeps that first row per symbol
# in a single vectorised pass. It avoids an O(cells x genes) list-membership
# loop and deprecated positional Series[i] indexing on the barcode index.
gene_ensembl_dict = (
    replogle_k562.obs.drop_duplicates(subset="gene")
    .set_index("gene")["gene_id"]
    .to_dict()
)
# Unique gene symbols in order of first appearance.
used_genes = list(gene_ensembl_dict)

In [None]:
# Persist the gene-symbol -> Ensembl-ID lookup next to the K562 raw data.
dict_path_k562 = os.path.join(data_folder_k562, "gene_ensembl_dict.pkl")
with open(dict_path_k562, mode="wb") as handle:
    pickle.dump(gene_ensembl_dict, handle)

In [None]:
# Library-size normalise, log-transform and keep the 5000 most highly variable
# genes. These calls modify replogle_k562 in place, so re-running the cell would
# log-transform twice. Skip the work when sc.pp.log1p has already left its
# "log1p" marker in .uns.
if "log1p" not in replogle_k562.uns:
    sc.pp.normalize_total(replogle_k562)
    sc.pp.log1p(replogle_k562)
    sc.pp.highly_variable_genes(replogle_k562, n_top_genes=5000, subset=True)
else:
    print("replogle_k562 is already log-transformed; skipping normalisation/HVG selection.")

In [None]:
# Ensembl IDs of all perturbed genes in first-appearance order. Exclude the
# non-targeting control and any missing IDs ('nan' string or real NaN), the
# same as the RPE1 section. Filtering does not raise when a label is absent.
_excluded_gene_ids = {"non-targeting", "nan"}
perturbed_genes = [
    gene_id
    for gene_id in replogle_k562.obs["gene_id"].unique()
    if isinstance(gene_id, str) and gene_id not in _excluded_gene_ids
]

In [None]:
# Split the perturbed genes by whether Ensembl serves a protein sequence for them.
# Successes are treated as protein-coding genes. Failures are IDs that were
# either not found in Ensembl or are non-coding genes.
successes = []
failures = []

for gene_id in tqdm(perturbed_genes):
    if get_status(gene_id):
        successes.append(gene_id)
    else:
        failures.append(gene_id)

In [None]:
# Fetch the canonical protein sequence of every gene that passed the status check:
# gene -> canonical transcript -> translation (protein) ID -> amino-acid sequence.
# Each step runs only when the previous one returned an ID. A gene without a
# canonical transcript or translation is therefore recorded as a failure, and
# it cannot reuse a protein ID from an earlier iteration. HTTP/network errors
# also mark the gene as a failure rather than aborting the loop.
aa_dict = {}

successful_genes = []
failure_genes = []

for gene_id in tqdm(successes):
    protein_seq = None
    try:
        canon_transcript_id = get_canonical_transcript_id(gene_id)
        if canon_transcript_id is not None:
            # Drop the ".N" version suffix before the transcript lookup.
            ct_id_new = canon_transcript_id.split(".")[0]
            protein_id = get_protein_translation_id(ct_id_new)
            if protein_id is not None:
                protein_seq = get_protein_sequence(protein_id)
    except requests.exceptions.RequestException as err:
        tqdm.write(f"{gene_id}: Ensembl request failed ({err})")

    if protein_seq:
        successful_genes.append(gene_id)
        aa_dict[gene_id] = protein_seq
    else:
        failure_genes.append(gene_id)


In [None]:
# Write the retrieved protein sequences as a two-column CSV (gene_id, aa_sequence),
# one row per gene, in insertion order of aa_dict.
csv_path_k562 = os.path.join(data_folder_k562, "perturbed_genes.csv")

aa_df = pd.DataFrame(
    {"gene_id": list(aa_dict.keys()), "aa_sequence": list(aa_dict.values())}
)
aa_df.to_csv(csv_path_k562, index=False)

In [None]:
# Work on an independent copy so replogle_k562 is left untouched, and store its
# expression matrix in CSR sparse format.
cell_type = 'K562'
filtered_k562 = replogle_k562.copy()
filtered_k562.X = csr_matrix(filtered_k562.X)

In [None]:
# Rename the perturbation column to "condition" and suffix each label with
# "+ctrl" (single-gene perturbations). Tag every cell with cell_type, then keep
# only these two columns.
filtered_k562.obs = (
    filtered_k562.obs
    .rename(columns={'gene': 'condition'})
    .assign(
        condition=lambda frame: [label + '+ctrl' for label in frame['condition']],
        cell_type=cell_type,
    )
    .loc[:, ['condition', 'cell_type']]
)

In [None]:
# Relabel the non-targeting control cells as "ctrl"; every other condition
# label maps to itself.
mapper = {label: label for label in filtered_k562.obs['condition'].unique()}
mapper.update({'non-targeting+ctrl': 'ctrl'})
filtered_k562.obs['condition'] = filtered_k562.obs['condition'].map(mapper)

In [None]:
# Build "<cell_type>_<condition>_<dose>" group names used by the DE ranking below.
# Perturbed conditions get dose "1+1"; control cells get "<cell_type>_ctrl_1".
# The prefix comes from the per-cell obs['cell_type'] column rather than the
# notebook-global `cell_type`, which is shared with the RPE1 section.
filtered_k562.obs['cov_drug_dose_name'] = [
    f"{ct}_{cond}_1" if cond == 'ctrl' else f"{ct}_{cond}_1+1"
    for ct, cond in zip(filtered_k562.obs['cell_type'], filtered_k562.obs['condition'])
]

In [None]:
from gears.data_utils import rank_genes_groups_by_cov

# Rank genes for each cov_drug_dose_name group against the control group,
# separately per cell_type, keeping the top 20 per group.
# NOTE(review): presumably combines covariate + control_group into
# "<cell_type>_ctrl_1" to locate the control cells — confirm in GEARS.
rank_genes_groups_by_cov(
    filtered_k562,
    groupby='cov_drug_dose_name',
    covariate='cell_type',
    control_group='ctrl_1',
    n_genes=20,
)

In [None]:
# adata_k562 is an alias of filtered_k562 (same object, not a copy); make sure
# its expression matrix is CSR-sparse before handing it to GEARS.
adata_k562 = filtered_k562
adata_k562.X = csr_matrix(adata_k562.X)

In [None]:
# GEARS stores the processed data under <saved folder>/<dataset_name>, and the
# next cell reloads it from data_folder_k562. Split that path into its parent
# (saved folder) and final component (dataset name) rather than passing the
# full path as dataset_name.
pert_data = PertData(os.path.dirname(data_folder_k562))  # saved folder
pert_data.new_data_process(
    dataset_name=os.path.basename(data_folder_k562),  # dataset folder name
    adata=adata_k562,
)

In [None]:
# Reload the processed K562 data (saved folder + dataset name), create the
# "simulation" train/test split with seed 1, and build the data loaders.
pert_data.load(data_path=data_folder_k562)
pert_data.prepare_split(split='simulation', seed=1)
pert_data.get_dataloader(
    batch_size=32,
    test_batch_size=128,
)