In [1]:
import torch
from torch import nn
from enformer_pytorch import from_pretrained, Enformer
from enformer_pytorch.finetune import HeadAdapterWrapper
import numpy as np
import random
from scipy.stats import pearsonr
import matplotlib.pyplot as plt 
import matplotlib.cm as cm
import seaborn as sns
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import torchmetrics
from torchmetrics.regression import PearsonCorrCoef
import pickle as pkl

# Evaluation on different cancer types

In this model, we are going to evaluate our fine-tuned model on cancer samples across different tissues. The framework here will be to check how well model performance correlates with tissue/cancer similarity.  

In [2]:
# Load gene counts and metadata in
all_cell_type_counts = pd.read_csv("/pollard/home/aravi1/TCGA_matrices/all_cell_types/all_cell_type_counts.tsv.gz", sep='\t')
all_cell_type_metadata = pd.read_csv("/pollard/home/aravi1/TCGA_matrices/all_cell_types/all_cell_type_metadata.tsv", sep='\t')

In [3]:
%%html
<style>
table {float:left}
</style>

| Cohort | Samples | Cancer Type | 
|:------:|:---:|:-----------|
| COAD-READ | 695 | Adenocarcinoma |
| BLCA | 431 | Carcinoma | 
| KIRC | 614 | Carcinoma |
| LGG | 534 | Glioma |
| LUAD | 600 | Adenocarcinoma | 
| LUSC | 553 | Sarcoma |
| PRAD | 554 | Adenocarcinoma |
| **TOTAL** | **3981** |

In [4]:
gene_names = pd.read_csv("crc_gene_counts.tsv", sep="\t")
gene_names = gene_names['gene_name']

all_cell_type_counts = pd.concat([gene_names, all_cell_type_counts.iloc[:, 1:]], axis=1)

In [5]:
TSS_centered_genes = pd.read_csv('Enformer_genomic_regions_TSSCenteredGenes_FixedOverlapRemoval.csv')
TSS_centered_genes = TSS_centered_genes[TSS_centered_genes['gene_name'].isin(all_cell_type_counts['gene_name'])]

# obtain training, validation and test genes here
test_genes = TSS_centered_genes['gene_name'][TSS_centered_genes['set'] == "test"]


### INPUTS
# validation sequences
# test sequences 

### shorten sequences from 196,608 bp to 49,152 as interval length. 
TSS_centered_genes['starts'] = (TSS_centered_genes['starts'] + ((196608 / 8) * 3)).astype(int)
TSS_centered_genes['ends'] = (TSS_centered_genes['ends'] - ((196608 / 8) * 3)).astype(int)
assert (TSS_centered_genes['ends'] - TSS_centered_genes['starts'] == 49152).all()
assert (TSS_centered_genes['ends'] - TSS_centered_genes['gene_start'].astype(int) == 24576).all()

test_sequences = TSS_centered_genes[TSS_centered_genes['gene_name'].isin(test_genes)]
test_sequences = test_sequences[test_sequences['gene_name'].isin(all_cell_type_counts['gene_name'])]

In [6]:
normalization=False
batch_size = 4

In [7]:
all_cell_type_metadata

Unnamed: 0.1,Unnamed: 0,Project,sample_id
0,0,TCGA-COAD,90c9f8cd-4c8c-4f07-af2f-e17db69bd561
1,1,TCGA-COAD,0f35c851-1cb8-4f75-a661-eae9111b7362
2,2,TCGA-READ,ff11a9e3-d32b-431b-9ebb-c5a3d9eb0e4f
3,3,TCGA-COAD,8ee55a63-4e87-4c00-8012-4c87efdcb7ed
4,4,TCGA-COAD,5c3c4b79-0682-4f19-96aa-071316a354d4
...,...,...,...
3976,549,TCGA-PRAD,1f3cabd7-c794-4083-a010-bc828e82608f
3977,550,TCGA-PRAD,9493a088-7cc8-4d7f-ad8b-946b9b9c4a2b
3978,551,TCGA-PRAD,fad685af-19be-4f40-9dd2-34ed966ffeae
3979,552,TCGA-PRAD,ac8abf7f-05eb-40d8-b8a8-5a2e2febfcbd


# Data Loader Here

In [8]:
import random
import torch
from enformer_pytorch import from_pretrained, Enformer, GenomeIntervalDataset
from enformer_pytorch.finetune import HeadAdapterWrapper
from scipy.stats import pearsonr
import polars as pl

np.random.seed(150)

filter_test = lambda df: df.filter(pl.col('column_4') == 'test')
test_ds = GenomeIntervalDataset(
    bed_file = 'test_sequences.bed',                       
    fasta_file = 'hg38.fa',
    filter_df_fn = filter_test,
    return_seq_indices = True,                         
    context_length = 49_152
)

In [9]:
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, sequences, targets, genes, cell_type):
        self.sequences = sequences
        self.targets = targets
        self.gene = genes 
        self.cell_type = cell_type

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = ((self.sequences[idx]) if len(self.sequences[idx]) > 0 else torch.zeros((1,)).clone().detach()).cuda() # Handling empty lists
        target = torch.tensor(self.targets[idx]).cuda()
        gene = self.gene[idx]
        cell_type = self.cell_type
        return sequence, target, gene, cell_type

In [10]:
#### DEFINE TARGET VARIABLES 
target_length = 384
TSS_tensor_pos1, TSS_tensor_pos2 = (target_length / 2) - 1, (target_length / 2)
batch_size = 4

# Define Enformer revised/fine-tuned model here. 
from torch.nn import Sequential 

# 1,536 * 2 - pointwise convolutional
# compute target length based on tensor shape 

class EnformerFineTuning(nn.Module):
    def __init__(
        self,
        enformer,
        num_tracks,
        post_transformer_embed = False, # whether to take the embeddings from right after the transformer, instead of after the final pointwise convolutional - this would add another layernorm
        ):
        super().__init__()
        assert isinstance(enformer, Enformer)
        enformer_hidden_dim = enformer.dim * (2 if not post_transformer_embed else 1)
        
        self.enformer = enformer.cuda()

        if (normalization == True):
            self.to_gene_counts = Sequential(
                nn.Linear(enformer_hidden_dim, num_tracks, bias=True).cuda(),
                # nn.Softplus(beta=1, threshold=20).cuda()
            )
        else:
            self.to_gene_counts = Sequential(
                nn.Linear(enformer_hidden_dim, num_tracks, bias=True).cuda(),
                nn.Softplus(beta=1, threshold=20).cuda()
            )
    
    def forward(
        self,
        seq,
        target = None,
        freeze_enformer = False,
        finetune_enformer_ln_only = False,
        finetune_last_n_layers_only = None,
    ):
        enformer_kwargs = dict()

        # enformer_kwargs = dict(target_length = 1)
        
        # returning only the embeddings here 
        embeddings = self.enformer(seq, return_only_embeddings=True)
        # print(embeddings.size())
        
        # batch_dim, seq_dim, feature_dim = embeddings.size()
        
        # compute center tensor by floor dividing - convention is just to take the left 
        
        # take center positions of embedding and average the two 
        TSS_tensor1 = embeddings[:, int(TSS_tensor_pos1)]
        # TSS_tensor2 = embeddings[:, int(TSS_tensor_pos2)]
        # mean_embeddings_TSS = torch.add(TSS_tensor1, TSS_tensor2) / 2
        
        # convert mean embedding to gene count value 
        preds = self.to_gene_counts(TSS_tensor1)

        return preds
    
    def _log(self, t, eps = 1e-20):
        return torch.log(t.clamp(min = eps))
    
    # LOSS FUNCTION for Enformer 
    def poisson_loss(self, pred, target):
        return (pred - target * self._log(pred)).mean()

In [11]:
file_path = "/pollard/home/aravi1/240506_enformer_raw/enformer_finetuned.pkl"
with open(file_path, 'rb') as f:
    model = pkl.load(f)

In [13]:
# Create DataLoader instances
torch.manual_seed(42)
np.random.seed(150)

def eval_per_cell_type(cell_type, test_sequences, test_ds):
    
    test_targets = []
    test_outputs = []
    test_loss = []
    test_genes = []
    
    print(cell_type)

    # Create DataLoader instances
    torch.manual_seed(42)
    np.random.seed(150)

    cell_type_samples = all_cell_type_metadata['sample_id'][all_cell_type_metadata['Project'] == cell_type]
    cell_type_counts = all_cell_type_counts.loc[:, all_cell_type_counts.columns.isin(cell_type_samples)]
    cell_type_counts['median'] = cell_type_counts.median(axis=0)

    cell_type_counts['median'] = cell_type_counts.median(axis=1)
    cell_type_counts = pd.concat([all_cell_type_counts['gene_name'], cell_type_counts], axis=1)
    cell_type_counts = cell_type_counts[cell_type_counts['gene_name'].isin(test_sequences['gene_name'])]
    cell_type_counts = cell_type_counts.drop_duplicates(subset=['gene_name'], keep='first')
    
    cell_type_counts['gene_name'] = pd.Categorical(cell_type_counts['gene_name'], test_sequences['gene_name'])
    cell_type_counts_revised = cell_type_counts.sort_values("gene_name")
    
    cell_type_counts_revised = cell_type_counts_revised.reset_index(drop=True)

    test = MyDataset(test_ds, cell_type_counts_revised['median'], cell_type_counts_revised['gene_name'], cell_type)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)
    
    model.eval()  # Set the model to evaluation mode
    
    with torch.no_grad():
        l = 0
        for test_seq, test_target, test_gene, cancer_type in test_loader:
            test_target = test_target.reshape([test_target.size()[0], 1])
            test_output = model(test_seq)
            if (normalization == True):
                test_loss_batch = criterion(test_output.float(), test_target.float())
            else: 
                test_loss_batch = model.poisson_loss(test_output, test_target).item()

            test_genes.append(test_gene)
            test_targets.append(test_target)
            test_outputs.append(test_output)
            test_loss.append(test_loss_batch)

            if (l % 100 == 0):
                print(l)

            l += 1

    if (normalization == True):
        test_accuracy = pearsonr(torch.cat(test_outputs).cpu().detach().numpy().flatten(), torch.cat(test_targets).cpu().detach().numpy().flatten())
        test_loss_mean = np.mean(np.array([tensor.cpu().numpy() for tensor in test_loss]))
    else: 
        test_loss_mean = np.mean(np.array(test_loss))
        test_accuracy = pearsonr(torch.log(torch.cat(test_outputs) + 0.00001).cpu().detach().numpy().flatten(), torch.log(torch.cat(test_targets) + 0.00001).cpu().detach().numpy().flatten())

    print(f"Test Loss: {test_loss_mean}, Test Accuracy: {round(test_accuracy[0], 5)}")
    
    return test_targets, test_outputs, test_loss, test_genes
    
    # return test, cell_type_counts_revised 

# test, decoy = eval_per_cell_type("TCGA-COAD")
# i = 0

# COAD_targets, COAD_outputs, COAD_loss, COAD_genes = eval_per_cell_type("TCGA-COAD")
#READ_targets, READ_outputs, READ_loss, READ_genes = eval_per_cell_type("TCGA-READ")
#BLCA_targets, BLCA_outputs, BLCA_loss, BLCA_genes = eval_per_cell_type("TCGA-BLCA")
#KIRC_targets, KIRC_outputs, KIRC_loss, KIRC_genes = eval_per_cell_type("TCGA-KIRC")
#LGG_targets, LGG_outputs, LGG_loss, LGG_genes = eval_per_cell_type("TCGA-LGG")
#LUAD_targets, LUAD_outputs, LUAD_loss, LUAD_genes = eval_per_cell_type("TCGA-LUAD")
#LUSC_targets, LUSC_outputs, LUSC_loss, LUSC_genes = eval_per_cell_type("TCGA-LUSC")
#PRAD_targets, PRAD_outputs, PRAD_loss, PRAD_genes = eval_per_cell_type("TCGA-PRAD")


TypeError: eval_per_cell_type() missing 2 required positional arguments: 'test_sequences' and 'test_ds'

# Evaluate Model Performance on CRC-relevant genes

In [14]:
colon_relevant_genes = ['MUC2', 'KRT20', 'CDX2', 'ALPI', 'SSTR2', 'SLC26A3', 'FABP6', 
                        'GUCA2B', 'TFF3', 'LGR5', 'CHGA', 'CALB1']
print(len(colon_relevant_genes))
colon_relevant_file = TSS_centered_genes[TSS_centered_genes['gene_name'].isin(colon_relevant_genes)]

12


In [15]:
colon_relevant_file.to_csv("colon_sequences.bed", sep='\t', header=None, index=None)

In [16]:
crc_ds = GenomeIntervalDataset(
    bed_file = 'colon_sequences.bed',                       
    fasta_file = 'hg38.fa',
    return_seq_indices = True,                         
    context_length = 49_152
)

In [None]:
COAD_targets, COAD_outputs, COAD_loss, COAD_genes = eval_per_cell_type("TCGA-COAD", colon_relevant_file, crc_ds)
READ_targets, READ_outputs, READ_loss, READ_genes = eval_per_cell_type("TCGA-READ", colon_relevant_file, crc_ds)
BLCA_targets, BLCA_outputs, BLCA_loss, BLCA_genes = eval_per_cell_type("TCGA-BLCA", colon_relevant_file, crc_ds)
KIRC_targets, KIRC_outputs, KIRC_loss, KIRC_genes = eval_per_cell_type("TCGA-KIRC", colon_relevant_file, crc_ds)
LGG_targets, LGG_outputs, LGG_loss, LGG_genes = eval_per_cell_type("TCGA-LGG", colon_relevant_file, crc_ds)
LUAD_targets, LUAD_outputs, LUAD_loss, LUAD_genes = eval_per_cell_type("TCGA-LUAD", colon_relevant_file, crc_ds)
LUSC_targets, LUSC_outputs, LUSC_loss, LUSC_genes = eval_per_cell_type("TCGA-LUSC", colon_relevant_file, crc_ds)
PRAD_targets, PRAD_outputs, PRAD_loss, PRAD_genes = eval_per_cell_type("TCGA-PRAD", colon_relevant_file, crc_ds)

TCGA-COAD


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  cell_type_counts['median'] = cell_type_counts.median(axis=0)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  cell_type_counts['median'] = cell_type_counts.median(axis=1)


In [26]:
outputs = torch.cat(COAD_outputs).cpu().numpy().flatten()
BLCA_targets = torch.cat(BLCA_targets).cpu().numpy().flatten()
COAD_targets = torch.cat(COAD_targets).cpu().numpy().flatten()
KIRC_targets = torch.cat(KIRC_targets).cpu().numpy().flatten()
LGG_targets = torch.cat(LGG_targets).cpu().numpy().flatten()
LUAD_targets = torch.cat(LUAD_targets).cpu().numpy().flatten()
LUSC_targets = torch.cat(LUSC_targets).cpu().numpy().flatten()
PRAD_targets = torch.cat(PRAD_targets).cpu().numpy().flatten()
READ_targets = torch.cat(READ_targets).cpu().numpy().flatten()

output_target_table = pd.concat([pd.Series(outputs), pd.Series(BLCA_targets), pd.Series(COAD_targets), pd.Series(KIRC_targets), 
                                 pd.Series(LGG_targets), pd.Series(LUAD_targets), pd.Series(LUSC_targets), pd.Series(PRAD_targets), pd.Series(READ_targets)], axis=1)

In [31]:
output_target_table.columns = ['output', 'BLCA', 'COAD', 'KIRC', 
                               'LGG', 'LUAD', 'LUSC', 'PRAD', 'READ']

output_target_table.index = CRC_bed_file['gene_name']


In [34]:
output_target_table.to_csv('/pollard/home/aravi1/TCGA_matrices/TCGA_output_target.tsv', sep='\t')