In [1]:
import torch
import pandas as pd
import os
import json
from torch.utils.data import DataLoader
import pickle
# Specify the working directory. Every data path below is relative to the
# project root. Set the FINETUNE_EMBED_DIR environment variable to point
# elsewhere instead of editing this machine-specific default path.
PROJECT_DIR = os.environ.get('FINETUNE_EMBED_DIR', '/Users/david/Desktop/FinetuneEmbed')
os.chdir(PROJECT_DIR)

In [2]:
# Load the NCBI gene descriptions. Presumably a JSON object mapping gene symbol
# -> description text (its keys are matched against gene symbols below).
# JSON is UTF-8 by specification, so decode it explicitly as UTF-8 rather
# than relying on the platform's default locale encoding.
with open("./data/gene_text/hs_ncbi_gene_text.json", "r", encoding="utf-8") as file:
    gene_descriptions = json.load(file)

## Long- vs. short-range TFs
The input data used here were downloaded from the supplementary data of Chen et al. (2020), *Nature Communications* (https://www.nature.com/articles/s41467-020-16106-x).

In [3]:
# Chen et al. (2020) supplementary table. The unnamed first column is matched
# against gene_descriptions keys (presumably gene symbols); 'assignment' holds
# the TF class.
long_short_range_tf = pd.read_csv('./data/long_vs_shortTF/41467_2020_16106_MOESM4_ESM.csv')
described_genes = set(gene_descriptions.keys())

# Keep only TFs that have an NCBI description. The intersections are sorted
# because iteration order of a set of strings changes between Python processes
# (string hash randomization). Unsorted, the seeded train/test split below
# would differ on every kernel restart.
long_range_tf_gene = sorted(
    set(long_short_range_tf.loc[long_short_range_tf['assignment'] == 'long-range TF', 'Unnamed: 0'])
    & described_genes
)
short_range_tf_gene = sorted(
    set(long_short_range_tf.loc[long_short_range_tf['assignment'] == 'short-range TF', 'Unnamed: 0'])
    & described_genes
)

In [4]:
from mod.mod import TextDataset
from sklearn.model_selection import train_test_split

# Pool both TF classes. Label 1 marks a long-range TF, label 0 a short-range TF.
tf_genes = long_range_tf_gene + short_range_tf_gene
labels = [1 for _ in long_range_tf_gene] + [0 for _ in short_range_tf_gene]

# Stratified 60/40 train/test split with a fixed seed.
genes_train, genes_test, labels_train, labels_test = train_test_split(
    tf_genes,
    labels,
    test_size=0.4,
    stratify=labels,
    random_state=44,
)

# Look up the description text of every gene in each split.
desc_train, desc_test = (
    [gene_descriptions[g] for g in split] for split in (genes_train, genes_test)
)

In [5]:
# Wrap each split's texts and labels in a TextDataset. Only the training
# loader shuffles; both use a batch size of 2.
train_dataset, val_dataset = TextDataset(desc_train, labels_train), TextDataset(desc_test, labels_test)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)

# Bundle gene names, loader and labels per split so they can be pickled together.
train_to_save = dict(genes=genes_train, dataloader=train_loader, labels=labels_train)
val_to_save = dict(genes=genes_test, dataloader=val_loader, labels=labels_test)

In [6]:
# Persist both split bundles (including their DataLoader objects) as pickle files.
for out_path, payload in (
    ("./data/long_vs_shortTF/train_data.pkl", train_to_save),
    ("./data/long_vs_shortTF/test_data.pkl", val_to_save),
):
    with open(out_path, "wb") as f:
        pickle.dump(payload, f)

## Dosage-sensitive vs. dosage-insensitive TFs

In [7]:
# Original label CSV: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/raw/main/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv
# The pickle loaded here holds the dosage-sensitive and dosage-insensitive TF
# ID lists. pickle.load can execute arbitrary code, so only load trusted files.
dosage_path = "./data/DosageSensitivity/example_input_files_gene_classification_dosage_sensitive_tfs_dosage_sensitivity_TFs.pickle"
with open(dosage_path, "rb") as fp:
    dosage_tfs = pickle.load(fp)
sensitive, insensitive = dosage_tfs["Dosage-sensitive TFs"], dosage_tfs["Dosage-insensitive TFs"]

In [8]:
import mygene
# Convert the gene IDs to gene symbols via MyGene.info.
mg = mygene.MyGeneInfo()
sensitive_query = mg.querymany(sensitive, species='human')
in_sensitive_query = mg.querymany(insensitive, species='human')
# Query terms with no hit come back as entries without a 'symbol' key, so
# filter them out of BOTH lists. Previously only the insensitive list was
# guarded, and a single miss among the sensitive IDs raised KeyError.
sensitive_gene_name = [x['symbol'] for x in sensitive_query if 'symbol' in x]
insensitive_gene_name = [x['symbol'] for x in in_sensitive_query if 'symbol' in x]

1 input query terms found no hit:	['ENSG00000215271']


In [9]:
# Keep only genes that have an NCBI description. The result is sorted so the
# gene order, and hence the seeded train/test split, does not depend on
# per-process string hash randomization of set iteration order.
described_genes = set(gene_descriptions.keys())
sensitive_gene = sorted(set(sensitive_gene_name) & described_genes)
insensitive_gene = sorted(set(insensitive_gene_name) & described_genes)

In [10]:
from mod.mod import TextDataset
from sklearn.model_selection import train_test_split

# Label 1 marks a dosage-sensitive TF, label 0 a dosage-insensitive TF.
genes = sensitive_gene + insensitive_gene
labels = [1] * len(sensitive_gene)
labels += [0] * len(insensitive_gene)

# Stratified 60/40 train/test split with a fixed seed.
split_result = train_test_split(genes, labels, test_size=0.4, stratify=labels, random_state=44)
genes_train, genes_test, labels_train, labels_test = split_result

# Map every gene in each split to its description text.
desc_train = list(map(gene_descriptions.__getitem__, genes_train))
desc_test = list(map(gene_descriptions.__getitem__, genes_test))

In [11]:
# Wrap each split in a TextDataset and a DataLoader (batch size 2; only the
# training loader shuffles).
train_dataset, val_dataset = TextDataset(desc_train, labels_train), TextDataset(desc_test, labels_test)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)

# Bundle gene names, loader and labels per split so they can be pickled together.
train_to_save = dict(genes=genes_train, dataloader=train_loader, labels=labels_train)
val_to_save = dict(genes=genes_test, dataloader=val_loader, labels=labels_test)

In [12]:
# Persist both split bundles (including their DataLoader objects) as pickle files.
for out_path, payload in (
    ("./data/DosageSensitivity/train_data.pkl", train_to_save),
    ("./data/DosageSensitivity/test_data.pkl", val_to_save),
):
    with open(out_path, "wb") as f:
        pickle.dump(payload, f)

## Methylation state prediction
The input pickle files were downloaded from the example gene-classification inputs of https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/

#### Bivalent vs. lys4

In [13]:
# Load the bivalent vs. lys4-only gene ID lists. pickle.load can execute
# arbitrary code, so only load trusted files.
pickle_path = "./data/MethylationState/example_input_files_gene_classification_bivalent_promoters_bivalent_vs_lys4_only.pickle"
with open(pickle_path, "rb") as fp:
    bivalent_vs_lys4 = pickle.load(fp)

In [14]:
# Gene ID lists for each promoter class (looked up as symbols in the next cell).
bivalent_gene_labels, lysine_gene_labels = bivalent_vs_lys4['bivalent'], bivalent_vs_lys4['lys4_only']

In [15]:
import mygene
# Convert the gene IDs to gene symbols via MyGene.info.
mg = mygene.MyGeneInfo()
bivalent_query = mg.querymany(bivalent_gene_labels, species='human')
lysine_query = mg.querymany(lysine_gene_labels, species='human')
# Drop no-hit entries (they carry no 'symbol') from both lists. The bivalent
# list used to keep them as '' placeholders while the lys4 list dropped them.
# NOTE(review): IDs reported as "dup hits" contribute one symbol per hit.
# Confirm whether only the best hit per ID should be kept.
bivalent_gene_name = [x['symbol'] for x in bivalent_query if 'symbol' in x]
lysine_gene_name = [x['symbol'] for x in lysine_query if 'symbol' in x]

10 input query terms found dup hits:	[('ENSG00000007372', 2), ('ENSG00000110693', 2), ('ENSG00000117707', 2), ('ENSG00000120093', 2), ('E
2 input query terms found dup hits:	[('ENSG00000196628', 2), ('ENSG00000198728', 2)]


In [16]:
# Keep only genes that have an NCBI description. The result is sorted so the
# gene order, and hence the seeded train/test split, does not depend on
# per-process string hash randomization of set iteration order.
described_genes = set(gene_descriptions.keys())
bivalent_gene = sorted(set(bivalent_gene_name) & described_genes)
lysine_gene = sorted(set(lysine_gene_name) & described_genes)

In [17]:
from mod.mod import TextDataset
from sklearn.model_selection import train_test_split

# Label 1 marks a bivalent gene, label 0 a lys4-only gene.
genes = bivalent_gene + lysine_gene
labels = [1 for _ in bivalent_gene] + [0 for _ in lysine_gene]

# Stratified 60/40 train/test split with a fixed seed.
genes_train, genes_test, labels_train, labels_test = train_test_split(
    genes,
    labels,
    test_size=0.4,
    stratify=labels,
    random_state=44,
)

# Look up the description text of every gene in each split.
desc_train, desc_test = (
    [gene_descriptions[g] for g in split] for split in (genes_train, genes_test)
)

In [18]:
# Wrap each split in a TextDataset and a DataLoader (batch size 2; only the
# training loader shuffles).
train_dataset, val_dataset = TextDataset(desc_train, labels_train), TextDataset(desc_test, labels_test)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)

# Bundle gene names, loader and labels per split so they can be pickled together.
train_to_save = dict(genes=genes_train, dataloader=train_loader, labels=labels_train)
val_to_save = dict(genes=genes_test, dataloader=val_loader, labels=labels_test)

In [19]:
# Save both split bundles as pickle files. They go into a per-comparison
# subdirectory that nothing earlier in this notebook creates, so create it
# first; otherwise open() fails with FileNotFoundError on a fresh checkout.
out_dir = "./data/MethylationState/bivalent_vs_lys4"
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "train_data.pkl"), "wb") as f:
    pickle.dump(train_to_save, f)
with open(os.path.join(out_dir, "test_data.pkl"), "wb") as f:
    pickle.dump(val_to_save, f)

#### Bivalent vs. no methyl

In [20]:
# Load the bivalent vs. no-methylation gene ID lists. pickle.load can execute
# arbitrary code, so only load trusted files.
pickle_path = "./data/MethylationState/example_input_files_gene_classification_bivalent_promoters_bivalent_vs_no_methyl.pickle"
with open(pickle_path, "rb") as fp:
    bivalent_vs_no_methyl = pickle.load(fp)

In [21]:
# Gene ID lists for each promoter class (looked up as symbols in the next cell).
bivalent_gene_labels, no_methylation_gene_labels = bivalent_vs_no_methyl['bivalent'], bivalent_vs_no_methyl['no_methylation']

In [22]:
import mygene
# Convert the gene IDs to gene symbols via MyGene.info.
mg = mygene.MyGeneInfo()
bivalent_query = mg.querymany(bivalent_gene_labels, species='human')
no_methylation_query = mg.querymany(no_methylation_gene_labels, species='human')
# Drop no-hit entries (they carry no 'symbol') from both lists. The bivalent
# list used to keep them as '' placeholders while the other list dropped them.
# NOTE(review): IDs reported as "dup hits" contribute one symbol per hit.
# Confirm whether only the best hit per ID should be kept.
bivalent_gene_name = [x['symbol'] for x in bivalent_query if 'symbol' in x]
no_methylation_gene_name = [x['symbol'] for x in no_methylation_query if 'symbol' in x]

10 input query terms found dup hits:	[('ENSG00000007372', 2), ('ENSG00000110693', 2), ('ENSG00000117707', 2), ('ENSG00000120093', 2), ('E
2 input query terms found dup hits:	[('ENSG00000147488', 2), ('ENSG00000151322', 2)]


In [23]:
# Keep only genes that have an NCBI description. The result is sorted so the
# gene order, and hence the seeded train/test split, does not depend on
# per-process string hash randomization of set iteration order.
described_genes = set(gene_descriptions.keys())
bivalent_gene = sorted(set(bivalent_gene_name) & described_genes)
no_methylation_gene = sorted(set(no_methylation_gene_name) & described_genes)

In [24]:
from mod.mod import TextDataset
from sklearn.model_selection import train_test_split

# Label 1 marks a bivalent gene, label 0 a gene without promoter methylation.
genes = bivalent_gene + no_methylation_gene
labels = [1] * len(bivalent_gene)
labels += [0] * len(no_methylation_gene)

# Stratified 60/40 train/test split with a fixed seed.
split_result = train_test_split(genes, labels, test_size=0.4, stratify=labels, random_state=44)
genes_train, genes_test, labels_train, labels_test = split_result

# Map every gene in each split to its description text.
desc_train = list(map(gene_descriptions.__getitem__, genes_train))
desc_test = list(map(gene_descriptions.__getitem__, genes_test))

In [25]:
# Wrap each split in a TextDataset and a DataLoader (batch size 2; only the
# training loader shuffles).
train_dataset, val_dataset = TextDataset(desc_train, labels_train), TextDataset(desc_test, labels_test)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)

# Bundle gene names, loader and labels per split so they can be pickled together.
train_to_save = dict(genes=genes_train, dataloader=train_loader, labels=labels_train)
val_to_save = dict(genes=genes_test, dataloader=val_loader, labels=labels_test)

In [26]:
# Save both split bundles as pickle files. They go into a per-comparison
# subdirectory that nothing earlier in this notebook creates, so create it
# first; otherwise open() fails with FileNotFoundError on a fresh checkout.
out_dir = "./data/MethylationState/bivalent_vs_no_methyl"
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "train_data.pkl"), "wb") as f:
    pickle.dump(train_to_save, f)
with open(os.path.join(out_dir, "test_data.pkl"), "wb") as f:
    pickle.dump(val_to_save, f)