In [1]:
import torch
import pandas as pd
import os
import json
from torch.utils.data import DataLoader
import pickle
# Specify the working directory. Defaults to the original author's checkout;
# set the FINETUNE_EMBED_DIR environment variable to run on another machine.
PROJECT_DIR = os.environ.get('FINETUNE_EMBED_DIR', '/Users/david/Desktop/FinetuneEmbed')
os.chdir(PROJECT_DIR)

# Load NCBI gene descriptions: a dict keyed by gene symbol (values are presumably
# description strings -- TODO confirm). JSON text is UTF-8 by spec, so decode it
# explicitly instead of relying on the platform's default encoding.
with open("./data/gene_text/hs_ncbi_gene_text.json", "r", encoding="utf-8") as file:
    gene_descriptions = json.load(file)

## Long- vs. short-range TFs
An alternative input is the supplementary table of Chen et al. (2020) (link: https://www.nature.com/articles/s41467-020-16106-x). The cell that loads it is kept below for reference, but it is commented out and not run.

In [None]:
# Alternative input (DISABLED): supplementary table from Chen et al. (2020).
# Kept for reference only -- every line is commented out and the cell is not run;
# the Geneformer pickle loaded in the next cell is used instead.
# If re-enabled, it would keep rows whose 'assignment' column equals
# 'long-range TF' / 'short-range TF', read gene names from the unnamed first
# column ('Unnamed: 0', presumably gene symbols -- TODO confirm), and keep only
# genes that also have an NCBI description.
# long_short_range_tf = pd.read_csv('./data/long_vs_shortTF/41467_2020_16106_MOESM4_ESM.csv')
# long_range_tf_gene = list(long_short_range_tf[long_short_range_tf['assignment']=='long-range TF']\
#                                 ['Unnamed: 0'])
# long_range_tf_gene = list(set(long_range_tf_gene) & set(gene_descriptions.keys())) # find the intersected genes

# short_range_tf_gene = list(long_short_range_tf[long_short_range_tf['assignment']=='short-range TF']\
#                                 ['Unnamed: 0'])
# short_range_tf_gene = list(set(short_range_tf_gene) & set(gene_descriptions.keys())) # find the intersected genes

The input data actually used here were downloaded from the Geneformer paper's Hugging Face dataset repository (Genecorpus-30M, `example_input_files/gene_classification`; link: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification).

In [2]:
# Load the Geneformer TF regulatory-range labels.
# NOTE(review): pickle.load can execute arbitrary code; this file comes from an
# external download, so only load a copy you trust.
with open("./data/long_vs_shortTF/example_input_files_gene_classification_tf_regulatory_range_tf_regulatory_range.pickle", "rb") as f:
    check_data = pickle.load(f)

# Gene identifiers per class; presumably Ensembl gene IDs (the mygene no-hit
# report for this cell lists ENSG IDs) -- TODO confirm.
long_range_tf_gene = check_data['long_range']
short_range_tf_gene = check_data['short_range']

import mygene
# convert gene id to gene symbols
mg = mygene.MyGeneInfo()
long_range_query = mg.querymany(long_range_tf_gene, species='human')
short_range_query = mg.querymany(short_range_tf_gene, species='human')
# Queries with no hit come back without a 'symbol' key; skip them for BOTH
# classes so an unmatched ID cannot raise KeyError.
long_range_gene_name = [x['symbol'] for x in long_range_query if 'symbol' in x]
short_range_gene_name = [x['symbol'] for x in short_range_query if 'symbol' in x]

# Keep only genes that have an NCBI description. sorted() gives a stable order:
# iteration order of a set of strings changes between Python sessions (hash
# randomization), which would make the seeded train/test split downstream
# irreproducible across kernel restarts.
long_range_tf_gene = sorted(set(long_range_gene_name) & set(gene_descriptions.keys()))
short_range_tf_gene = sorted(set(short_range_gene_name) & set(gene_descriptions.keys()))

2 input query terms found no hit:	['ENSG00000269603', 'ENSG00000267841']


In [3]:
from sklearn.model_selection import train_test_split

# Mapping IDs to symbols can collapse IDs from both classes onto one symbol; such a
# gene would appear twice with contradictory labels (possibly on both sides of the
# split), so fail loudly instead of silently building a leaky dataset.
overlap = set(long_range_tf_gene) & set(short_range_tf_gene)
assert not overlap, f"genes labelled both long- and short-range: {sorted(overlap)}"

tf_genes = long_range_tf_gene + short_range_tf_gene
labels = [1] * len(long_range_tf_gene) + [0] * len(short_range_tf_gene) # 1 for long-range TF, 0 for short-range TF
# Stratified 80/20 split; random_state fixes the shuffle for a given input order.
genes_train, genes_test, labels_train, labels_test = train_test_split(
    tf_genes, labels, test_size=0.2, stratify=labels, random_state=7)

desc_train = [gene_descriptions[gene] for gene in genes_train]
desc_test = [gene_descriptions[gene] for gene in genes_test]

# Persist each split as a {'genes', 'desc', 'labels'} dict of parallel lists.
# NOTE: despite its name, val_to_save is the held-out *test* split (test_data.pkl).
train_to_save = {'genes': genes_train, 'desc': desc_train, 'labels': labels_train}
val_to_save = {'genes': genes_test, 'desc': desc_test, 'labels': labels_test}
out_dir = "./data/long_vs_shortTF"
for fname, payload in (("train_data.pkl", train_to_save), ("test_data.pkl", val_to_save)):
    with open(os.path.join(out_dir, fname), "wb") as f:
        pickle.dump(payload, f)