In [1]:
import torch
import pandas as pd
import os
import json
from torch.utils.data import DataLoader
import pickle
# Specify the working directory
os.chdir('/Users/david/Desktop/FinetuneEmbed')

In [2]:
# load gene ncbi descriptions
with open("./data/gene_text/hs_ncbi_gene_text.json", "r") as file:
    gene_descriptions = json.load(file)

## Long- vs short- range TFs
The input data used here are downloaded from Chen et al. (2020) (link: https://www-nature-com.stanford.idm.oclc.org/articles/s41467-020-16106-x).

In [None]:
# long_short_range_tf = pd.read_csv('./data/long_vs_shortTF/41467_2020_16106_MOESM4_ESM.csv')
# long_range_tf_gene = list(long_short_range_tf[long_short_range_tf['assignment']=='long-range TF']\
#                                 ['Unnamed: 0'])
# long_range_tf_gene = list(set(long_range_tf_gene) & set(gene_descriptions.keys())) # find the intersected genes

# short_range_tf_gene = list(long_short_range_tf[long_short_range_tf['assignment']=='short-range TF']\
#                                 ['Unnamed: 0'])
# short_range_tf_gene = list(set(short_range_tf_gene) & set(gene_descriptions.keys())) # find the intersected genes

The input data used here are downloaded from the geneformer paper Hugging Face website (link: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification).

In [3]:
with open("./data/long_vs_shortTF/example_input_files_gene_classification_tf_regulatory_range_tf_regulatory_range.pickle", "rb") as f:
    check_data = pickle.load(f)

In [4]:
long_range_tf_gene = check_data['long_range']
short_range_tf_gene = check_data['short_range']

import mygene
# convert gene id to gene symbols
mg = mygene.MyGeneInfo()
long_range_query = mg.querymany(long_range_tf_gene, species='human')
short_range_query = mg.querymany(short_range_tf_gene, species='human')
long_range_gene_name = [x['symbol'] for x in long_range_query]
short_range_gene_name = [x['symbol'] for x in short_range_query if 'symbol' in x]

long_range_tf_gene = list(set(long_range_gene_name) & set(gene_descriptions.keys())) # find the intersected genes
short_range_tf_gene = list(set(short_range_gene_name) & set(gene_descriptions.keys())) # find the intersected genes

2 input query terms found no hit:	['ENSG00000269603', 'ENSG00000267841']


In [5]:
from mod.mod import GeneDataset
from sklearn.model_selection import train_test_split
tf_genes = long_range_tf_gene + short_range_tf_gene
labels = [1] * len(long_range_tf_gene) + [0] * len(short_range_tf_gene) # 1 for long-range TF, 0 for short-range TF
# Split into train and test sets
genes_train, genes_test, labels_train, labels_test = train_test_split(tf_genes, labels, test_size=0.3, stratify=labels, random_state=7)

desc_train = [gene_descriptions[gene] for gene in genes_train]
desc_test = [gene_descriptions[gene] for gene in genes_test]

In [6]:
# Save the data
train_to_save = {'genes':genes_train, 'desc':desc_train, 'labels':labels_train}
val_to_save = {'genes':genes_test, 'desc':desc_test, 'labels':labels_test}
# Save as a pickle file
with open("./data/long_vs_shortTF/train_data.pkl", "wb") as f:
    pickle.dump(train_to_save, f)
with open("./data/long_vs_shortTF/test_data.pkl", "wb") as f:
    pickle.dump(val_to_save, f)

In [8]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Define your dataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        
        # Remove the batch dimension added by the tokenizer (squeeze the single dimension)
        encoding = {key: value.squeeze(0) for key, value in encoding.items()}
        encoding["label"] = torch.tensor(label, dtype=torch.long)

        return encoding

# Load model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Choose an appropriate Sentence BERT model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Prepare datasets
train_texts, train_labels = desc_train, labels_train   # your training texts and labels
test_texts, test_labels = desc_test, labels_test    # your test texts and labels

train_dataset = TextDataset(train_texts, train_labels, tokenizer)
test_dataset = TextDataset(test_texts, test_labels, tokenizer)

  Referenced from: <9A4710B9-0DA3-36BB-9129-645F282E64B2> /Users/david/anaconda3/envs/myenv/lib/python3.10/site-packages/torchvision/image.so
  warn(
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=1e-3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    weight_decay=0.01,
)

# Define the compute_metrics function for AUC
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    probs = torch.nn.functional.softmax(torch.tensor(logits), dim=1)[:, 1].numpy()  # Get probability of the positive class
    auc = roc_auc_score(labels, probs)
    return {"AUC": auc}

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

# Fine-tune the model
trainer.train()

# Evaluate on the test set
results = trainer.evaluate()
print("Test AUC:", results["eval_AUC"])



  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.5751219391822815, 'eval_AUC': 0.5021367521367521, 'eval_runtime': 0.203, 'eval_samples_per_second': 172.452, 'eval_steps_per_second': 24.636, 'epoch': 1.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.6117882132530212, 'eval_AUC': 0.5, 'eval_runtime': 0.1876, 'eval_samples_per_second': 186.529, 'eval_steps_per_second': 26.647, 'epoch': 2.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.572949230670929, 'eval_AUC': 0.5598290598290598, 'eval_runtime': 0.1775, 'eval_samples_per_second': 197.193, 'eval_steps_per_second': 28.17, 'epoch': 3.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.5788039565086365, 'eval_AUC': 0.5, 'eval_runtime': 0.1883, 'eval_samples_per_second': 185.875, 'eval_steps_per_second': 26.554, 'epoch': 4.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.5742241740226746, 'eval_AUC': 0.5, 'eval_runtime': 0.1787, 'eval_samples_per_second': 195.888, 'eval_steps_per_second': 27.984, 'epoch': 5.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.5736111402511597, 'eval_AUC': 0.4935897435897436, 'eval_runtime': 0.154, 'eval_samples_per_second': 227.283, 'eval_steps_per_second': 32.469, 'epoch': 6.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.5719671249389648, 'eval_AUC': 0.7649572649572649, 'eval_runtime': 0.1953, 'eval_samples_per_second': 179.192, 'eval_steps_per_second': 25.599, 'epoch': 7.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.576185941696167, 'eval_AUC': 0.8376068376068377, 'eval_runtime': 0.1818, 'eval_samples_per_second': 192.563, 'eval_steps_per_second': 27.509, 'epoch': 8.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.5721469521522522, 'eval_AUC': 0.5, 'eval_runtime': 0.1979, 'eval_samples_per_second': 176.861, 'eval_steps_per_second': 25.266, 'epoch': 9.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 0.5711102485656738, 'eval_AUC': 0.3589743589743589, 'eval_runtime': 0.174, 'eval_samples_per_second': 201.174, 'eval_steps_per_second': 28.739, 'epoch': 10.0}
{'train_runtime': 28.1621, 'train_samples_per_second': 49.357, 'train_steps_per_second': 6.392, 'train_loss': 0.5899030897352431, 'epoch': 10.0}


  0%|          | 0/5 [00:00<?, ?it/s]

Test AUC: 0.3589743589743589


In [11]:
len(desc_train_aug), len(genes_train_aug), len(labels_train_aug)

(556, 556, 556)

In [12]:
from mod.mod import collate_fn
# Create datasets and dataloaders
train_dataset = GeneDataset(genes_train_aug, desc_train_aug, labels_train_aug)
val_dataset = GeneDataset(genes_test, desc_test, labels_test)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=2, collate_fn=collate_fn)

# Save the data
train_to_save = {'genes':genes_train_aug, 'dataloader':train_loader, 'labels':labels_train_aug}
val_to_save = {'genes':genes_test, 'dataloader':val_loader, 'labels':labels_test}

In [14]:
# Example usage
for genes, labels, descs in train_loader:
    print(genes)

ValueError: too many dimensions 'str'

In [12]:
# Save as a pickle file
with open("./data/long_vs_shortTF/train_data.pkl", "wb") as f:
    pickle.dump(train_to_save, f)
with open("./data/long_vs_shortTF/test_data.pkl", "wb") as f:
    pickle.dump(val_to_save, f)

## Dosage sensitive vs insensitive TFs

In [None]:
# link_file = "https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/raw/main/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv"
with open(f"./data/DosageSensitivity/example_input_files_gene_classification_dosage_sensitive_tfs_dosage_sensitivity_TFs.pickle", "rb") as fp:
    dosage_tfs = pickle.load(fp)
sensitive = dosage_tfs["Dosage-sensitive TFs"]
insensitive = dosage_tfs["Dosage-insensitive TFs"]

In [None]:
import mygene
# convert gene id to gene symbols
mg = mygene.MyGeneInfo()
sensitive_query = mg.querymany(sensitive, species='human')
in_sensitive_query = mg.querymany(insensitive, species='human')
sensitive_gene_name = [x['symbol'] for x in sensitive_query]
insensitive_gene_name = [x['symbol'] for x in in_sensitive_query if 'symbol' in x]

1 input query terms found no hit:	['ENSG00000215271']


In [None]:
sensitive_gene = list(set(sensitive_gene_name) & set(gene_descriptions.keys())) # find the intersected genes
insensitive_gene = list(set(insensitive_gene_name) & set(gene_descriptions.keys())) # find the intersected genes

In [None]:
from mod.mod import TextDataset
from sklearn.model_selection import train_test_split
genes = sensitive_gene + insensitive_gene
labels = [1] * len(sensitive_gene) + [0] * len(insensitive_gene) # 1 for sensitive, 0 for insensitive
# Split into train and test sets
genes_train, genes_test, labels_train, labels_test = train_test_split(genes, labels, test_size=0.2, stratify=labels, random_state=7)

desc_train = [gene_descriptions[gene] for gene in genes_train]
desc_test = [gene_descriptions[gene] for gene in genes_test]

In [None]:
# Create datasets and dataloaders
train_dataset = TextDataset(desc_train, labels_train)
val_dataset = TextDataset(desc_test, labels_test)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)

# Save the data
train_to_save = {'genes':genes_train, 'dataloader':train_loader, 'labels':labels_train}
val_to_save = {'genes':genes_test, 'dataloader':val_loader, 'labels':labels_test}

In [None]:
# Save as a pickle file
with open("./data/DosageSensitivity/train_data.pkl", "wb") as f:
    pickle.dump(train_to_save, f)
with open("./data/DosageSensitivity/test_data.pkl", "wb") as f:
    pickle.dump(val_to_save, f)

## Methylation state prediction
The csv files are downloaded from https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/

#### Bivalent vs. lys4

In [None]:
with open(f"./data/MethylationState/example_input_files_gene_classification_bivalent_promoters_bivalent_vs_lys4_only.pickle", "rb") as fp:
    bivalent_vs_lys4 = pickle.load(fp)

In [None]:
bivalent_gene_labels = bivalent_vs_lys4['bivalent']
lysine_gene_labels = bivalent_vs_lys4['lys4_only']

In [None]:
import mygene
# convert gene id to gene symbols
mg = mygene.MyGeneInfo()
bivalent_query = mg.querymany(bivalent_gene_labels, species='human')
lysine_query = mg.querymany(lysine_gene_labels, species='human')
bivalent_gene_name = [x.get('symbol', '') for x in bivalent_query]
lysine_gene_name = [x.get('symbol', '') for x in lysine_query if 'symbol' in x]

10 input query terms found dup hits:	[('ENSG00000007372', 2), ('ENSG00000110693', 2), ('ENSG00000117707', 2), ('ENSG00000120093', 2), ('E
2 input query terms found dup hits:	[('ENSG00000196628', 2), ('ENSG00000198728', 2)]


In [None]:
bivalent_gene = list(set(bivalent_gene_name) & set(gene_descriptions.keys())) # find the intersected genes
lysine_gene = list(set(lysine_gene_name) & set(gene_descriptions.keys())) # find the intersected genes

In [None]:
from mod.mod import TextDataset
from sklearn.model_selection import train_test_split
genes = bivalent_gene + lysine_gene
labels = [1] * len(bivalent_gene) + [0] * len(lysine_gene) # 1 for bivalent_gene, 0 for lysine_gene
# Split into train and test sets
genes_train, genes_test, labels_train, labels_test = train_test_split(genes, labels, test_size=0.2, stratify=labels, random_state=7)

desc_train = [gene_descriptions[gene] for gene in genes_train]
desc_test = [gene_descriptions[gene] for gene in genes_test]

In [None]:
# Create datasets and dataloaders
train_dataset = TextDataset(desc_train, labels_train)
val_dataset = TextDataset(desc_test, labels_test)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)

# Save the data
train_to_save = {'genes':genes_train, 'dataloader':train_loader, 'labels':labels_train}
val_to_save = {'genes':genes_test, 'dataloader':val_loader, 'labels':labels_test}

In [None]:
# Save as a pickle file
with open("./data/MethylationState/bivalent_vs_lys4/train_data.pkl", "wb") as f:
    pickle.dump(train_to_save, f)
with open("./data/MethylationState/bivalent_vs_lys4/test_data.pkl", "wb") as f:
    pickle.dump(val_to_save, f)

#### Bivalent vs. no methyl

In [None]:
with open(f"./data/MethylationState/example_input_files_gene_classification_bivalent_promoters_bivalent_vs_no_methyl.pickle", "rb") as fp:
    bivalent_vs_no_methyl = pickle.load(fp)

In [None]:
bivalent_gene_labels = bivalent_vs_no_methyl['bivalent']
no_methylation_gene_labels = bivalent_vs_no_methyl['no_methylation']

In [None]:
import mygene
# convert gene id to gene symbols
mg = mygene.MyGeneInfo()
bivalent_query = mg.querymany(bivalent_gene_labels, species='human')
no_methylation_query = mg.querymany(no_methylation_gene_labels, species='human')
bivalent_gene_name = [x.get('symbol', '') for x in bivalent_query]
no_methylation_gene_name = [x.get('symbol', '') for x in no_methylation_query if 'symbol' in x]

10 input query terms found dup hits:	[('ENSG00000007372', 2), ('ENSG00000110693', 2), ('ENSG00000117707', 2), ('ENSG00000120093', 2), ('E
2 input query terms found dup hits:	[('ENSG00000147488', 2), ('ENSG00000151322', 2)]


In [None]:
bivalent_gene = list(set(bivalent_gene_name) & set(gene_descriptions.keys())) # find the intersected genes
no_methylation_gene = list(set(no_methylation_gene_name) & set(gene_descriptions.keys())) # find the intersected genes

In [None]:
from mod.mod import TextDataset
from sklearn.model_selection import train_test_split
genes = bivalent_gene + no_methylation_gene
labels = [1] * len(bivalent_gene) + [0] * len(no_methylation_gene) # 1 for bivalent_gene, 0 for no_methylation_gene
# Split into train and test sets
genes_train, genes_test, labels_train, labels_test = train_test_split(genes, labels, test_size=0.2, stratify=labels, random_state=7)

desc_train = [gene_descriptions[gene] for gene in genes_train]
desc_test = [gene_descriptions[gene] for gene in genes_test]

In [None]:
# Create datasets and dataloaders
train_dataset = TextDataset(desc_train, labels_train)
val_dataset = TextDataset(desc_test, labels_test)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)

# Save the data
train_to_save = {'genes':genes_train, 'dataloader':train_loader, 'labels':labels_train}
val_to_save = {'genes':genes_test, 'dataloader':val_loader, 'labels':labels_test}

In [None]:
# Save as a pickle file
with open("./data/MethylationState/bivalent_vs_no_methyl/train_data.pkl", "wb") as f:
    pickle.dump(train_to_save, f)
with open("./data/MethylationState/bivalent_vs_no_methyl/test_data.pkl", "wb") as f:
    pickle.dump(val_to_save, f)

In [15]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from typing import List, Dict

# Sample data
genes = ['gene1', 'gene2', 'gene3']  # List of genes
labels = [0, 1, 0]  # Corresponding labels for each gene
descriptions = [
    "This gene is involved in metabolic processes.",
    "This gene plays a role in cell division.",
    "This gene is associated with the immune response."
]  # Descriptions for each gene

# Custom Dataset class
class GeneDataset(Dataset):
    def __init__(self, genes: List[str], labels: List[int], descriptions: List[str]):
        self.genes = genes
        self.labels = labels
        self.descriptions = descriptions

    def __len__(self):
        return len(self.genes)

    def __getitem__(self, idx):
        gene = self.genes[idx]
        label = self.labels[idx]
        description = self.descriptions[idx]
        return {"gene": gene, "label": torch.tensor(label), "description": description}

# Custom collate function for variable-length descriptions
def collate_fn(batch: List[Dict]):
    genes = [item['gene'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])
    
    # Tokenize descriptions at character level
    descriptions = [torch.tensor([ord(char) for char in item['description']], dtype=torch.long) for item in batch]
    
    # Pad descriptions to the same length
    descriptions_padded = pad_sequence(descriptions, batch_first=True, padding_value=0)
    
    return {"genes": genes, "labels": labels, "descriptions": descriptions_padded}

# Instantiate the dataset and data loader
gene_dataset = GeneDataset(genes, labels, descriptions)
gene_loader = DataLoader(gene_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

# Example usage
for batch in gene_loader:
    print(batch)

{'genes': ['gene2', 'gene1'], 'labels': tensor([1, 0]), 'descriptions': tensor([[ 84, 104, 105, 115,  32, 103, 101, 110, 101,  32, 112, 108,  97, 121,
         115,  32,  97,  32, 114, 111, 108, 101,  32, 105, 110,  32,  99, 101,
         108, 108,  32, 100, 105, 118, 105, 115, 105, 111, 110,  46,   0,   0,
           0,   0,   0],
        [ 84, 104, 105, 115,  32, 103, 101, 110, 101,  32, 105, 115,  32, 105,
         110, 118, 111, 108, 118, 101, 100,  32, 105, 110,  32, 109, 101, 116,
          97,  98, 111, 108, 105,  99,  32, 112, 114, 111,  99, 101, 115, 115,
         101, 115,  46]])}
{'genes': ['gene3'], 'labels': tensor([0]), 'descriptions': tensor([[ 84, 104, 105, 115,  32, 103, 101, 110, 101,  32, 105, 115,  32,  97,
         115, 115, 111,  99, 105,  97, 116, 101, 100,  32, 119, 105, 116, 104,
          32, 116, 104, 101,  32, 105, 109, 109, 117, 110, 101,  32, 114, 101,
         115, 112, 111, 110, 115, 101,  46]])}
