In [None]:
pip install umap-learn

In [None]:
!pip install git+https://huggingface.co/ctheodoris/Geneformer

In [2]:
import pandas as pd
import numpy as np
import ast
from tqdm import tqdm
import json

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [4]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [5]:
def trunc(text):
    return text[:-4]

In [6]:
ase_rows = pd.read_feather("/content/drive/MyDrive/ase_data/ase_rows_only_with_p_localized.feather")

In [7]:
output_list = pd.read_csv('/content/drive/MyDrive/ase_data/TCGA_localized_PCa_subtypes.txt', sep= "\t")

In [8]:
output_list = output_list.rename(columns={"Unnamed: 0": "caseID"})

In [9]:
file = open("/content/drive/MyDrive/ase_data/ASE_localized_recurr10perc.txt", "r")
content = file.read()
file.close()

ase_localized_recurr10perc = content.split("\n")
ase_localized_recurr10perc.pop()

''

In [10]:
import pickle
with open("/content/drive/MyDrive/ase_data/geneformer_token_dictionary.pkl", "rb") as f:
  gene_token_dict = pickle.load(f)

# Assuming 'ase_rows' is your DataFrame and 'genes' is a Series from this DataFrame
# Your 'parse' function
def parse(gene):
    return gene.split("_")[1].split(".")[0]

# Apply the 'parse' function to each gene in the 'genes' column of 'ase_rows' to create 'ensembl_genes'
ase_rows['ensembl_genes'] = ase_rows['gene'].apply(parse)

# Initialize an empty set to hold genes not covered by the gene_token_dict
excluded_genes = set()

# Check each gene in 'ensembl_genes' against the gene_token_dict to find excluded genes
for gene in set(ase_rows['ensembl_genes']):
    if gene not in gene_token_dict:
        excluded_genes.add(gene)

# Filter 'ase_rows' to exclude rows with genes that are not covered by gene_token_dict
ase_rows_filtered = ase_rows[~ase_rows['ensembl_genes'].isin(excluded_genes)]

In [None]:
ase_rows

In [None]:
ase_rows_filtered

In [11]:
ase_rows_filtered['caseID'] = ase_rows_filtered['caseID'].apply(trunc)
ase_rows_filtered = ase_rows_filtered[['caseID', 'genename', 'foldc_adjust_new', 'ensembl_genes']]
ase_rows_filtered

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ase_rows_filtered['caseID'] = ase_rows_filtered['caseID'].apply(trunc)


Unnamed: 0,caseID,genename,foldc_adjust_new,ensembl_genes
0,TCGA-ZG-A9NI,ZYX,1.129678,ENSG00000159840
1,TCGA-ZG-A9NI,ZWINT,2.144539,ENSG00000122952
2,TCGA-ZG-A9NI,ZRANB2,1.230258,ENSG00000132485
3,TCGA-ZG-A9NI,ZNFX1,1.110871,ENSG00000124201
4,TCGA-ZG-A9NI,ZNF84,1.879898,ENSG00000198040
...,...,...,...,...
1493003,TCGA-2A-A8VO,AARS,1.026190,ENSG00000090861
1493004,TCGA-2A-A8VO,AAK1,1.946546,ENSG00000115977
1493005,TCGA-2A-A8VO,AACS,1.160942,ENSG00000081760
1493006,TCGA-2A-A8VO,A4GALT,0.689941,ENSG00000128274


In [12]:
# Convert boolean columns to integers (True to 1, False to 0)
binary_columns = ['SPOP', 'FOXA1', 'MSI', 'TMPRSS2--ERG fusion', 'PTEN deletion']
output_list[binary_columns] = output_list[binary_columns].astype(int)

# Extract binary vectors per row (caseID)
binary_vectors = output_list[binary_columns].values.tolist()

# Optionally, create a dictionary with caseID as keys and binary vectors as values
caseID_vectors = dict(zip(output_list['caseID'], binary_vectors))

In [None]:
caseID_vectors

In [13]:
# Group by 'caseID' and aggregate 'genename' and 'fold_c_value' into lists
aggregated_ase_rows = ase_rows_filtered.groupby('caseID').agg({
    'ensembl_genes': lambda x: list(x),
    'foldc_adjust_new': lambda x: list(x)
}).reset_index()

# Now, aggregated_df contains one row per caseID,
# with 'genename' and 'fold_c_value' columns containing lists of values.

aggregated_ase_rows

Unnamed: 0,caseID,ensembl_genes,foldc_adjust_new
0,TCGA-2A-A8VO,"[ENSG00000091436, ENSG00000131848, ENSG0000019...","[0.18942017743179512, 0.6828771359250264, 1.23..."
1,TCGA-2A-A8VT,"[ENSG00000159840, ENSG00000174442, ENSG0000016...","[0.6213034048266602, 1.6194680584463998, 1.443..."
2,TCGA-2A-A8VV,"[ENSG00000074755, ENSG00000159840, ENSG0000015...","[1.09120612144389, 1.1662320926778185, 0.83460..."
3,TCGA-2A-A8VX,"[ENSG00000091436, ENSG00000159840, ENSG0000012...","[3.0323215295822448, 0.6323230667045006, 0.769..."
4,TCGA-2A-A8W1,"[ENSG00000074755, ENSG00000162378, ENSG0000007...","[0.9524830479450136, 1.8592842071080753, 0.944..."
...,...,...,...
490,TCGA-ZG-A9M4,"[ENSG00000074755, ENSG00000162415, ENSG0000014...","[0.6113333830379262, 0.516567508774503, 0.4334..."
491,TCGA-ZG-A9MC,"[ENSG00000036549, ENSG00000070476, ENSG0000017...","[3.2966079478793766, 0.34299176406443294, 1.88..."
492,TCGA-ZG-A9N3,"[ENSG00000074755, ENSG00000132003, ENSG0000013...","[0.5179727310287839, 0.22299939514922978, 6.39..."
493,TCGA-ZG-A9ND,"[ENSG00000074755, ENSG00000070476, ENSG0000012...","[4.0927262289111255, 1.515322580745505, 1.5944..."


In [14]:
aggregated_ase_rows['caseID'] = aggregated_ase_rows['caseID'].astype(str)

# Apply the mapping
aggregated_ase_rows['target'] = aggregated_ase_rows['caseID'].map(caseID_vectors)

# Now, aggregated_df has a new column 'target' with binary vectors for each caseID
aggregated_ase_rows

Unnamed: 0,caseID,ensembl_genes,foldc_adjust_new,target
0,TCGA-2A-A8VO,"[ENSG00000091436, ENSG00000131848, ENSG0000019...","[0.18942017743179512, 0.6828771359250264, 1.23...","[0, 0, 0, 0, 0]"
1,TCGA-2A-A8VT,"[ENSG00000159840, ENSG00000174442, ENSG0000016...","[0.6213034048266602, 1.6194680584463998, 1.443...","[0, 0, 0, 1, 0]"
2,TCGA-2A-A8VV,"[ENSG00000074755, ENSG00000159840, ENSG0000015...","[1.09120612144389, 1.1662320926778185, 0.83460...","[0, 0, 0, 1, 0]"
3,TCGA-2A-A8VX,"[ENSG00000091436, ENSG00000159840, ENSG0000012...","[3.0323215295822448, 0.6323230667045006, 0.769...","[0, 0, 0, 0, 0]"
4,TCGA-2A-A8W1,"[ENSG00000074755, ENSG00000162378, ENSG0000007...","[0.9524830479450136, 1.8592842071080753, 0.944...","[0, 0, 0, 0, 0]"
...,...,...,...,...
490,TCGA-ZG-A9M4,"[ENSG00000074755, ENSG00000162415, ENSG0000014...","[0.6113333830379262, 0.516567508774503, 0.4334...","[0, 0, 0, 0, 0]"
491,TCGA-ZG-A9MC,"[ENSG00000036549, ENSG00000070476, ENSG0000017...","[3.2966079478793766, 0.34299176406443294, 1.88...","[0, 0, 0, 1, 1]"
492,TCGA-ZG-A9N3,"[ENSG00000074755, ENSG00000132003, ENSG0000013...","[0.5179727310287839, 0.22299939514922978, 6.39...","[0, 0, 0, 1, 1]"
493,TCGA-ZG-A9ND,"[ENSG00000074755, ENSG00000070476, ENSG0000012...","[4.0927262289111255, 1.515322580745505, 1.5944...","[0, 0, 1, 0, 0]"


In [15]:
def rank_genes(gene_vector, gene_tokens):
    """
    Rank gene expression vector.
    """
    # sort by median-scaled gene values
    sorted_indices = np.argsort(-gene_vector)
    return gene_tokens[sorted_indices]


def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
    """
    # create array of gene vector with token indices
    # mask undetected genes
    nonzero_mask = np.nonzero(gene_vector)[0]
    # rank by median-scaled gene values
    return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])

In [16]:
# Step 1: Convert ensembl_genes to numerical indices
aggregated_ase_rows['gene_indices'] = aggregated_ase_rows['ensembl_genes'].apply(lambda genes: [gene_token_dict[gene] for gene in genes if gene in gene_token_dict])

In [17]:
def compare_gene_lengths(df):
    # Initialize an empty list to store the results of comparisons
    discrepancies = []

    # Iterate through each row in the DataFrame
    for index, row in df.iterrows():
        # Get the lengths of gene_indices and ensembl_genes for the current row
        len_gene_indices = len(row['gene_indices'])
        len_ensembl_genes = len(row['ensembl_genes'])

        if len_gene_indices != len_ensembl_genes:
            discrepancies.append((index, len_ensembl_genes, len_gene_indices))
            print(f"Row {index}: Length of ensembl_genes is {len_ensembl_genes}, but length of gene_indices is {len_gene_indices}.")

    return discrepancies

discrepancies = compare_gene_lengths(aggregated_ase_rows)

# check if there were any discrepancies
if discrepancies:
    print("There are discrepancies in the lengths.")
else:
    print("All lengths match.")


All lengths match.


In [18]:
# Step 2: Apply tokenize_cell function to each row
aggregated_ase_rows['tokenized_genes'] = aggregated_ase_rows.apply(lambda row: tokenize_cell(np.array(row['foldc_adjust_new']), np.array(row['gene_indices'])), axis=1)

#Collate and Dataset

In [19]:
from torch.utils.data import Dataset
import pandas as pd
import numpy as np

class ASEDataset(Dataset):
    def __init__(self, dataframe):
        """
        Initializes the dataset.
        :param dataframe: A pandas DataFrame containing the data.
        """
        self.dataframe = dataframe

    def __len__(self):
        """
        Returns the size of the dataset.
        """
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Retrieves a single data point from the dataset.
        """
        row = self.dataframe.iloc[idx]
        gene_tokens = np.array(row['tokenized_genes'])
        fold_change = np.array(row['foldc_adjust_new'])
        target = np.array(row['target'], dtype=np.float32)

        # Return gene_tokens and target as the data dictionary
        return {'gene_tokens': gene_tokens, 'fold_change': fold_change, 'target': target}

In [20]:
import torch

def collate_fn(batch, max_len=512, pad_token=0):
    batch_inputs, batch_targets = [], []

    for item in batch:
        # Extract gene_tokens and target from the current item
        gene_tokens = item['gene_tokens']
        target = item['target']

        # Truncate the gene_tokens to max_len if necessary
        if len(gene_tokens) > max_len:
            gene_tokens = gene_tokens[:max_len]

        # Pad gene_tokens to ensure they are all the same length
        if len(gene_tokens) < max_len:
            padding_length = max_len - len(gene_tokens)
            gene_tokens.extend([pad_token] * padding_length)

        # Append the processed gene_tokens and target to our batch lists
        batch_inputs.append(gene_tokens)
        batch_targets.append(target)

    # Convert lists to PyTorch tensors for model processing
    batch_inputs_tensor = torch.tensor(batch_inputs, dtype=torch.long)
    batch_targets_tensor = torch.tensor(batch_targets, dtype=torch.float)  # Use float for multilabel classification

    return batch_inputs_tensor, batch_targets_tensor


In [21]:
from sklearn.model_selection import train_test_split

# Perform the split
train_df, test_df = train_test_split(aggregated_ase_rows, test_size=0.2, random_state=42)

# Create training and testing datasets
train_dataset = ASEDataset(train_df)
test_dataset = ASEDataset(test_df)

In [22]:
from torch.utils.data import DataLoader

# Define batch size and collate_fn (as previously described)
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Model

In [23]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM

# tokenizer = AutoTokenizer.from_pretrained("ctheodoris/Geneformer")
base_model = AutoModelForMaskedLM.from_pretrained("ctheodoris/Geneformer")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/41.2M [00:00<?, ?B/s]

In [24]:
import torch
from torch import nn
from transformers import BertModel

class ASEGeneformer(nn.Module):
    def __init__(self, base_model, num_labels):
        super(ASEGeneformer, self).__init__()
        self.num_labels = num_labels

        # Initialize the BERT model
        self.base_model = base_model.bert

        # Classification head
        self.classification_head = nn.Linear(256, num_labels)

        # Activation function to get probabilities for each label
        self.activation = nn.Sigmoid()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        # Get the outputs from BERT model
        outputs = self.base_model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # Extract the last hidden state of the [CLS] token
        cls_output = outputs.last_hidden_state[:, 0, :]

        # Pass through the classification head
        logits = self.classification_head(cls_output)

        # Apply activation function to get probabilities
        probs = self.activation(logits)

        return probs


In [25]:
ase_model = ASEGeneformer(base_model, 5)

In [26]:
ase_model

ASEGeneformer(
  (base_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementw

#Training

In [27]:
from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss
import torch

def evaluate_metrics(y_true, y_pred):
    """
    Computes evaluation metrics for multilabel classification.
    """
    # Threshold predictions to convert probabilities to binary output
    y_pred = torch.sigmoid(y_pred) > 0.5
    y_pred = y_pred.cpu().numpy()
    y_true = y_true.cpu().numpy()

    metrics = {
        'F1 Score (Macro)': f1_score(y_true, y_pred, average='macro'),
        'Precision (Macro)': precision_score(y_true, y_pred, average='macro'),
        'Recall (Macro)': recall_score(y_true, y_pred, average='macro'),
        'Hamming Loss': hamming_loss(y_true, y_pred)
    }

    return metrics

def train(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in train_loader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(train_loader)

def evaluate(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in test_loader:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()
            all_preds.append(outputs)
            all_targets.append(targets)

    all_preds = torch.cat(all_preds, dim=0)
    all_targets = torch.cat(all_targets, dim=0)

    metrics = evaluate_metrics(all_targets, all_preds)

    return total_loss / len(test_loader), metrics

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ase_model.to(device)
optimizer = torch.optim.Adam(ase_model.parameters(), lr=1e-5)
criterion = torch.nn.BCELoss()

# Training and evaluation loop
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(ase_model, train_loader, optimizer, criterion, device)
    test_loss, metrics = evaluate(ase_model, test_loader, criterion, device)

    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
    for metric_name, metric_value in metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")
        print("")


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1, Train Loss: 0.3600, Test Loss: 0.2980
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848



  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2, Train Loss: 0.3067, Test Loss: 0.2657
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848



  _warn_prf(average, modifier, msg_start, len(result))


Epoch 3, Train Loss: 0.2849, Test Loss: 0.2495
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848



  _warn_prf(average, modifier, msg_start, len(result))


Epoch 4, Train Loss: 0.2721, Test Loss: 0.2420
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848



  _warn_prf(average, modifier, msg_start, len(result))


Epoch 5, Train Loss: 0.2669, Test Loss: 0.2347
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848



  _warn_prf(average, modifier, msg_start, len(result))


Epoch 6, Train Loss: 0.2618, Test Loss: 0.2343
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848



  _warn_prf(average, modifier, msg_start, len(result))


Epoch 7, Train Loss: 0.2562, Test Loss: 0.2311
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848



  _warn_prf(average, modifier, msg_start, len(result))


Epoch 8, Train Loss: 0.2590, Test Loss: 0.2268
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848



  _warn_prf(average, modifier, msg_start, len(result))


Epoch 9, Train Loss: 0.2572, Test Loss: 0.2310
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848

Epoch 10, Train Loss: 0.2565, Test Loss: 0.2261
F1 Score (Macro): 0.1748

Precision (Macro): 0.1152

Recall (Macro): 0.6000

Hamming Loss: 0.8848



  _warn_prf(average, modifier, msg_start, len(result))
