In [1]:
import torch
from torch import nn
from enformer_pytorch import from_pretrained, Enformer
from enformer_pytorch.finetune import HeadAdapterWrapper
import numpy as np
import random
from scipy.stats import pearsonr
import matplotlib.pyplot as plt 
import matplotlib.cm as cm
import seaborn as sns
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import torchmetrics
from torchmetrics.regression import PearsonCorrCoef

# Enformer Fine-Tuning

In this notebook, we are going to attempt to fine-tune the Enformer model and optimize the parameters needed. Here is a list of parameters we are going to test: 

**1)** Use just the TCGA data for fine-tuning (already confirmed that it works better. See PCA).\
**2)** Ensure data loader works properly and we get the sequences we want. \
**3)** Normalize (log-transform) gene counts and accordingly remove the softplus layer at the end of the model. This also means using MSE as the loss function. \
**4)** Take the left of the two center tensors instead of averaging the middle two tensors together. This gives a better metric. 

**NOTE**: this notebook can be adjusted based on different parameters - ideas include predicting raw vs. normalized counts, max-pooling/optimizing which tensor to pick, learning rate change, downsampling - feel free to edit this notebook accordingly. 

**DOWNSAMPLE TRAINING DATA NOW**: we will use only 1500 genes/sequences in order to save time for our training. The goal is just to see if we CAN overfit on new training data. We can comment out the downsampling step below to fine-tune on all of the genes.  

### Filter for just TCGA data

In [2]:
%%html
<style>
table {float:left}
</style>

| Cohort | Donors | Samples |
|:------:|:---:|:---:|
| TCGA-COAD | 459 | 522 |
| TCGA-READ | 168 | 173 |
| HCMI-CMDC | 66 | 138 |
| **TOTAL** | **693** | **833** | 

In [3]:
# Read the tab-separated gene-count matrix and the per-sample metadata table
gene_counts = pd.read_table('/pollard/home/aravi1/CRC_TCGA_HCMI_data/crc_gene_counts.tsv')
metadata = pd.read_table('/pollard/home/aravi1/CRC_TCGA_HCMI_data/crc_metadata.tsv')

In [4]:
# Rename count columns to each biosample ID.
# The ID is the part of 'File.Name' before the first '.'.
biosample_id = metadata['File.Name'].str.split('.').str[0]

# Columns 0-2 are kept as they are; every later column is relabelled with the
# biosample ID at the same position in `metadata`.
# NOTE(review): this assumes the count columns are in the same order as the
# metadata rows - confirm against how the two files were produced.
n_sample_columns = gene_counts.shape[1] - 3
assert n_sample_columns == len(biosample_id), (
    f"{n_sample_columns} count columns but {len(biosample_id)} metadata rows"
)

# Assign a fresh column index instead of writing into `columns.values`, which
# edits the Index's underlying array behind pandas' back.
gene_counts.columns = [*gene_counts.columns[:3], *biosample_id]
metadata['id'] = biosample_id

In [5]:
# Preview the first rows of the count matrix after the column relabelling
gene_counts.head()

Unnamed: 0,gene_id,gene_name,gene_type,90c9f8cd-4c8c-4f07-af2f-e17db69bd561,0f35c851-1cb8-4f75-a661-eae9111b7362,ff11a9e3-d32b-431b-9ebb-c5a3d9eb0e4f,8ee55a63-4e87-4c00-8012-4c87efdcb7ed,5c3c4b79-0682-4f19-96aa-071316a354d4,05167d53-0b47-4131-bfa4-450b236b9fd5,a49e0bfd-c6f5-4fc5-9eb5-0eeab117124f,...,e356598b-7611-492a-98ed-ef2ec1b77b7a,a3380ca6-6c65-4543-9bdf-957b26c1daaf,c5f2e898-f42a-44a2-9b3f-8af491c99857,47661ed9-5c0e-442d-b072-3da8b14fab02,9ed0c331-aa64-4928-83d1-7fb6be9b0d24,5472fa65-a8e2-4593-abd8-e241d1bdec84,c98732eb-48c4-4554-bdf6-0e9a0a9273ec,9da01737-be8f-4af5-9f8f-47bb892b6339,dfb9c45f-cb52-4e36-b663-1a9e8c7c0b47,ad30e5e1-182b-4758-8e70-53e0bcf78072
0,ENSG00000000003.15,TSPAN6,protein_coding,102.6828,180.1091,196.979,178.5252,200.7512,36.3378,103.4066,...,129.5642,167.1648,138.24,122.9009,99.9213,329.6668,88.3904,84.0686,168.7268,187.4724
1,ENSG00000000005.6,TNMD,protein_coding,0.5909,4.2922,7.044,2.069,1.9239,1.4758,0.2529,...,1.5335,0.2034,0.346,15.011,0.5999,4.1491,0.0,4.6429,2.1025,5.4324
2,ENSG00000000419.13,DPM1,protein_coding,189.5382,237.603,259.5645,190.5531,156.6911,16.1225,134.7653,...,209.5608,175.5828,287.7883,74.4915,253.1841,269.691,127.441,106.8103,207.7205,234.1476
3,ENSG00000000457.14,SCYL3,protein_coding,15.8945,6.3773,6.4526,9.2412,5.0334,2.0466,7.9265,...,14.239,8.1337,13.0432,2.5129,9.182,5.2358,13.4162,7.1353,11.1024,9.5908
4,ENSG00000000460.17,C1orf112,protein_coding,14.0779,5.9787,6.6893,5.0715,4.6298,2.9191,10.1571,...,14.2176,6.0589,7.1255,1.9394,8.0648,6.3916,9.4148,2.7023,14.1215,5.8139


In [6]:
# Columns to keep: the three gene annotation columns plus every sample whose
# metadata 'Project' value contains 'TCGA'
annotation_columns = pd.Series(['gene_id', 'gene_name', 'gene_type'])
is_tcga_sample = metadata['Project'].str.contains('TCGA')
tcga_samples = pd.concat([annotation_columns, metadata.id[is_tcga_sample]])

keep_column = gene_counts.columns.isin(tcga_samples)
gene_counts = gene_counts.loc[:, keep_column]

Now, our gene count matrix only has TCGA samples. 

### Normalization/median calculation + Filter

First, let's filter our count matrix to only the genes that were originally used to train Enformer. This way, we can maintain the same training/test splits that the original model used, so as to prevent bias.

At this step, we can take two paths: 
- **Normalization:** for Enformer to predict normalized counts and to log-transform/normalize our target data, set this to True. If  normalization is False, Enformer resorts to default - predicting raw counts. When normalization=False, target data=raw counts. 
- **Filter Zero Genes:** to train Enformer on genes with non-zero TPM counts, set to True 



In [7]:
# normalization: set to True to train on, and have Enformer predict,
# log-transformed TPM counts; leave False to work with raw TPM counts.

# filter_zero_genes: set to True to apply the median-expression gene filter before training.
normalization = False
filter_zero_genes = False

In [8]:
# Keep the first row for every gene_name. The explicit copy gives an
# independent frame, so the in-place log transform below does not write into a
# slice of `gene_counts` (the source of pandas' SettingWithCopyWarning).
filtered_counts = gene_counts[~gene_counts['gene_name'].duplicated()].copy()


# NORMALIZATION HERE - normalization means taking ln(TPM + pseudocount).
# The pseudocount keeps zero counts finite under the logarithm.
if normalization:
    filtered_counts.iloc[:, 3:] = np.log(filtered_counts.iloc[:, 3:] + 0.00001)

print(filtered_counts['gene_name'].nunique())
print(filtered_counts.shape)

59427
(59427, 698)


In [9]:
# Calculate the median across samples (every column after the three gene
# annotation columns) and keep only the annotations plus that median.
# `assign` builds a new frame, so nothing is written into a filtered slice.
sample_median = filtered_counts.iloc[:, 3:].median(axis=1)
filtered_counts = filtered_counts[['gene_id', 'gene_name', 'gene_type']].assign(median=sample_median)

if filter_zero_genes:
    # Keep genes whose median expression is ABOVE the floor, i.e. drop the
    # (near-)zero genes. The floor of -7.5 is expressed in ln(TPM + 1e-5)
    # space; when the counts were left raw, convert it back to TPM so both
    # modes use approximately the same cutoff.
    log_floor = -7.5
    expression_floor = log_floor if normalization else np.exp(log_floor) - 0.00001
    filtered_counts = filtered_counts[filtered_counts['median'] > expression_floor]

print(filtered_counts.shape)

(59427, 4)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_counts['median'] = filtered_counts.iloc[:,3:].median(axis=1)


In [10]:
### INPUTS
# training sequences
# validation sequences
# test sequences

# Table of TSS-centered Enformer regions; only genes listed here are kept
TSS_centered_genes = pd.read_csv('Enformer_genomic_regions_TSSCenteredGenes_FixedOverlapRemoval.csv')

# Build the membership mask once and reuse it for the report and the filter
in_enformer_regions = filtered_counts['gene_name'].isin(TSS_centered_genes['gene_name'])
print(filtered_counts[in_enformer_regions].shape)
filtered_counts = filtered_counts[in_enformer_regions]

(17539, 4)


### Downsampling/Training-Test-Validation Split

In [11]:
# Keep only the Enformer regions whose gene also has expression data. The
# explicit copy makes the coordinate edits below act on an independent frame
# rather than on a filtered slice of the table read from disk.
TSS_centered_genes = TSS_centered_genes[TSS_centered_genes['gene_name'].isin(filtered_counts['gene_name'])].copy()

# obtain training, validation and test genes here (from the table's 'set' column)
gene_split = TSS_centered_genes['set']
training_genes = TSS_centered_genes['gene_name'][gene_split == "train"]
test_genes = TSS_centered_genes['gene_name'][gene_split == "test"]
validation_genes = TSS_centered_genes['gene_name'][gene_split == "valid"]

# downsample training genes here - comment this out to train on all of the genes
training_genes = training_genes.sample(n=1500, random_state=42)

### INPUTS
# training sequences
# validation sequences
# test sequences

### shorten sequences from 196,608 bp to 49,152 bp as interval length by
### trimming 3/8 of the full length from each side.
# Integer arithmetic keeps the offset exact (no float round-trip).
# NOTE: re-running this cell trims the already-trimmed intervals again; the
# assertions below then fail, so restart from the cell that reads the CSV.
flank_trim = (196608 // 8) * 3
TSS_centered_genes['starts'] = (TSS_centered_genes['starts'] + flank_trim).astype(int)
TSS_centered_genes['ends'] = (TSS_centered_genes['ends'] - flank_trim).astype(int)
# Every interval must now span exactly 49,152 bp and end 24,576 bp after gene_start
assert (TSS_centered_genes['ends'] - TSS_centered_genes['starts'] == 49152).all()
assert (TSS_centered_genes['ends'] - TSS_centered_genes['gene_start'].astype(int) == 24576).all()

training_sequences = TSS_centered_genes[TSS_centered_genes['gene_name'].isin(training_genes)]
validation_sequences = TSS_centered_genes[TSS_centered_genes['gene_name'].isin(validation_genes)]
test_sequences = TSS_centered_genes[TSS_centered_genes['gene_name'].isin(test_genes)]

In [12]:
# Number of genes assigned to the "test" split
len(test_genes)

1800

In [13]:
### OUTPUTS
# training tracks
# validation tracks
# testing tracks

def _tracks_for(gene_names):
    """Return the rows of `filtered_counts` whose gene_name is in `gene_names`."""
    return filtered_counts[filtered_counts['gene_name'].isin(gene_names)]


training_tracks = _tracks_for(training_genes)
validation_tracks = _tracks_for(validation_genes)
test_tracks = _tracks_for(test_genes)

In [15]:
### Write input sequences to bed files for GenomeIntervalDataset to load to fine-tune

def _with_targets(sequences, tracks):
    """Inner-join the interval table with its expression targets on gene_name."""
    return sequences.merge(tracks, on="gene_name", how="inner")


training_sequences = _with_targets(training_sequences, training_tracks)
validation_sequences = _with_targets(validation_sequences, validation_tracks)
test_sequences = _with_targets(test_sequences, test_tracks)

# Write each merged table as a tab-separated file with no header row and no index column
bed_outputs = (
    (training_sequences, "training_sequences.bed"),
    (validation_sequences, "validation_sequences.bed"),
    (test_sequences, "test_sequences.bed"),
)
for merged_table, bed_path in bed_outputs:
    merged_table.to_csv(bed_path, sep='\t', header=None, index=None)

### Data Loader/Batch Training-Test Split Samples

Next, let's ensure that our data loader works properly and we are getting the sequences that we want. Our sequence length that we will use is 49,152 bp (196,608 / 4). This allows us to effectively batch our sequences. 

In [None]:
#### DEFINE TARGET VARIABLES
# Number of output positions requested from the Enformer trunk
target_length = 384
# Indices of the two center positions (left, right). Floor division keeps them
# integers so they can be used directly as tensor indices.
TSS_tensor_pos1, TSS_tensor_pos2 = (target_length // 2) - 1, (target_length // 2)
batch_size = 4

In [16]:
import random

np.random.seed(150)

import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import HeadAdapterWrapper
from scipy.stats import pearsonr

import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset

# Row filters handed to GenomeIntervalDataset: each keeps the rows of the
# loaded .bed table whose 'column_4' value equals the split name.
# NOTE(review): presumably 'column_4' is the split label written by the export
# cell - confirm against the column order of the .bed files.
filter_train = lambda df: df.filter(pl.col('column_4') == 'train')
filter_test = lambda df: df.filter(pl.col('column_4') == 'test')
filter_valid = lambda df: df.filter(pl.col('column_4') == 'valid')


def _make_interval_dataset(bed_file, filter_df_fn):
    """Build a GenomeIntervalDataset over `bed_file` with the shared settings.

    bed_file: .bed file whose columns 0, 1, 2 must be <chromosome>, <start position>, <end position>.
    filter_df_fn: function applied to the loaded dataframe to select rows.
    """
    return GenomeIntervalDataset(
        bed_file = bed_file,
        fasta_file = 'hg38.fa',             # path to fasta file
        filter_df_fn = filter_df_fn,
        return_seq_indices = True,          # return nucleotide indices (ACGTN) rather than one hot encodings
        # Random shift augmentation (shift_augs) is deliberately left off.
        context_length = 49_152,
        # context_length can be longer than the interval designated in the .bed file,
        # in which case the interval is lengthened on either side, with proper
        # padding at the ends of chromosomes.
    )


training_ds = _make_interval_dataset('training_sequences.bed', filter_train)
test_ds = _make_interval_dataset('test_sequences.bed', filter_test)
validation_ds = _make_interval_dataset('validation_sequences.bed', filter_valid)

In [17]:
# Create a Dataset class to easily join input sequences and target TPM counts together. 

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """Pairs each input sequence with its expression target.

    Args:
        sequences: indexable collection of sequence tensors; item ``i`` is the
            model input for sample ``i``.
        targets: indexable collection of scalar targets, paired with
            ``sequences`` by position. A pandas Series/DataFrame is accepted
            and is always indexed by position, never by label.
        device: device both tensors are moved to before being returned.
            Defaults to ``"cuda"``, matching the original behaviour.
    """

    def __init__(self, sequences, targets, device="cuda"):
        self.sequences = sequences
        self.targets = targets
        self.device = device

    def __len__(self):
        """Number of samples, defined by the sequence collection."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(sequence, target)`` for sample ``idx`` on ``self.device``."""
        # Fetch the sequence once: indexing the underlying dataset may be
        # expensive, so avoid looking the same item up twice.
        sequence = self.sequences[idx]
        if len(sequence) == 0:
            # Handling empty lists: substitute a single zero
            sequence = torch.zeros((1,))

        # Use positional access for pandas objects so a non-default index
        # cannot silently pair a sequence with the wrong target.
        positional_targets = self.targets.iloc if hasattr(self.targets, "iloc") else self.targets
        target = torch.tensor(positional_targets[idx])

        return sequence.to(self.device), target.to(self.device)

In [18]:
# Inputs and targets are paired purely by position, so each interval dataset
# must have exactly as many rows as the table its targets come from. Fail
# loudly here rather than train on misaligned pairs.
assert len(training_ds) == len(training_sequences), "training inputs/targets differ in length"
assert len(validation_ds) == len(validation_sequences), "validation inputs/targets differ in length"
assert len(test_ds) == len(test_sequences), "test inputs/targets differ in length"

# The regression target for every gene is its 'median' expression value
training = MyDataset(training_ds, training_sequences['median'])
validation = MyDataset(validation_ds, validation_sequences['median'])
test = MyDataset(test_ds, test_sequences['median'])

In [19]:
# Create DataLoader instances
# Seed the torch and NumPy RNGs first - presumably so the shuffled training
# order is reproducible between runs.
torch.manual_seed(42)
np.random.seed(150)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(training, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(validation, batch_size=batch_size)
test_loader = DataLoader(test, batch_size=batch_size)

## Model Architecture 

In [20]:
# Define Enformer revised/fine-tuned model here. 
from torch.nn import Sequential 

# 1,536 * 2 - pointwise convolutional
# compute target length based on tensor shape 

class EnformerFineTuning(nn.Module):
    """Enformer trunk followed by a linear head that predicts gene expression.

    The head is applied to the embedding at a single sequence position. Two
    notebook-level settings are read from module scope rather than passed in:
    ``normalization`` (read once, at construction) decides whether a Softplus
    follows the linear layer, and ``TSS_tensor_pos1`` (read on every forward
    pass) selects the embedding position handed to the head.
    """

    def __init__(
        self,
        enformer,
        num_tracks,
        post_transformer_embed = False, # whether to take the embeddings from right after the transformer, instead of after the final pointwise convolutional - this would add another layernorm
        ):
        """
        Args:
            enformer: an ``Enformer`` instance; it is moved to CUDA here.
            num_tracks: number of values the head outputs per sequence.
            post_transformer_embed: sizes the head input as ``enformer.dim``
                instead of ``2 * enformer.dim``.
                NOTE(review): only the head width depends on this flag;
                ``forward`` requests the same embeddings either way, so
                ``True`` looks unsupported as written - confirm before using.
        """
        super().__init__()
        assert isinstance(enformer, Enformer)
        enformer_hidden_dim = enformer.dim * (2 if not post_transformer_embed else 1)

        self.enformer = enformer.cuda()
        if normalization:
            # predicting normalized counts - no softplus activation layer,
            # so the output is unconstrained
            self.to_gene_counts = Sequential(
                nn.Linear(enformer_hidden_dim, num_tracks, bias=True).cuda(),
            )
        else:
            # predicting raw counts - add softplus activation layer to keep
            # the predictions positive
            self.to_gene_counts = Sequential(
                nn.Linear(enformer_hidden_dim, num_tracks, bias=True).cuda(),
                nn.Softplus(beta=1, threshold=20).cuda()
            )

    def forward(
        self,
        seq,
        target = None,
        freeze_enformer = False,
        finetune_enformer_ln_only = False,
        finetune_last_n_layers_only = None,
    ):
        """Predict expression for a batch of sequences.

        Args:
            seq: batch of input sequences, passed straight to the Enformer trunk.
            target: accepted for signature compatibility; not used.
            freeze_enformer: when True the trunk runs under ``torch.no_grad()``
                so only the head receives gradients.
            finetune_enformer_ln_only: accepted for signature compatibility;
                currently ignored.
            finetune_last_n_layers_only: accepted for signature compatibility;
                currently ignored.

        Returns:
            Output of the head for the selected embedding position.
        """
        # returning only the embeddings here
        if freeze_enformer:
            with torch.no_grad():
                embeddings = self.enformer(seq, return_only_embeddings=True)
        else:
            embeddings = self.enformer(seq, return_only_embeddings=True)

        # take center positions of embedding - we take left tensor in the middle
        # (embeddings are indexed as [batch, position, ...])
        TSS_tensor1 = embeddings[:, int(TSS_tensor_pos1)]

        # convert embedding to gene count value
        preds = self.to_gene_counts(TSS_tensor1)

        return preds

    def _log(self, t, eps = 1e-20):
        """Logarithm with the input clamped to at least ``eps`` to avoid log(0)."""
        return torch.log(t.clamp(min = eps))

    # LOSS FUNCTION for Enformer
    def poisson_loss(self, pred, target):
        """Mean of ``pred - target * log(pred)``.

        This is the Poisson negative log-likelihood without the term that
        depends only on ``target``, so the value can be negative.
        """
        return (pred - target * self._log(pred)).mean()

## Fine-tuning: Training Step 

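Finally, we will train and fine-tune our model in this step. After creating our data loader and preprocessing our target counts, we use a Poisson loss when predicting raw counts and an MSE loss when predicting log-transformed counts. The Poisson loss used here omits the term that depends only on the target, so negative loss values are expected. The "Accuracy" printed for each epoch is the Pearson correlation between predictions and targets (computed in log space when predicting raw counts). 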

While the model usually converges in fewer than 3 epochs, we can control the number of epochs with `num_iterations`. We use an Adam optimizer with a learning rate of 0.00005. 

In [21]:
# Load the pre-trained Enformer model
enformer = from_pretrained('EleutherAI/enformer-official-rough', target_length=target_length, use_tf_gamma=False).cuda()
torch.manual_seed(42)

device = torch.device("cuda:0")

# Wrap the pre-trained trunk with the single-track regression head
model = EnformerFineTuning(enformer=enformer, num_tracks=1, post_transformer_embed=False).cuda()
model = model.to(device)
model.train()

# Number of passes over the training data
num_iterations = 5

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)

# MSE is used only when predicting log-transformed counts; raw counts use the
# model's Poisson loss instead (see _batch_loss).
criterion = torch.nn.MSELoss()

# Per-epoch histories: one inner list per epoch, one entry per batch
training_targets=[]
training_outputs=[]
training_loss=[]

validation_targets=[]
validation_outputs=[]
validation_loss=[]

pearson = PearsonCorrCoef().to("cuda:0")


def _batch_loss(output, target):
    """Loss for one batch: MSE on log-transformed counts, Poisson on raw counts."""
    if normalization:
        return criterion(output.float(), target.float())
    return model.poisson_loss(output, target)


def _epoch_correlation(outputs, targets):
    """Pearson correlation over all batches of an epoch.

    Raw-count predictions and targets are compared in log space (with the same
    pseudocount used for normalization) so both modes report a comparable metric.
    """
    predictions, truth = torch.cat(outputs), torch.cat(targets)
    if not normalization:
        predictions = torch.log(predictions + 0.00001)
        truth = torch.log(truth + 0.00001)
    return pearson(predictions, truth)


# Training loop
for i in range(num_iterations):
    model.train()

    training_loss_epoch = []
    outputs=[]
    targets=[]
    for j, (seq, target) in enumerate(train_loader):
        # Targets arrive as a flat batch; give them the (batch, 1) shape of the output
        target = torch.reshape(target, [target.size()[0], 1])

        optimizer.zero_grad()

        # Forward pass
        output = model(seq)
        loss = _batch_loss(output, target)

        training_loss_epoch.append(loss.item())

        # Store the prediction without its autograd history: the values are
        # only needed for metrics and for saving, and keeping the history
        # attached holds on to graph state for every batch of every epoch.
        outputs.append(output.detach())
        targets.append(target)

        # Backward pass
        loss.backward()

        if (j % 100 == 0):
            print(j)

        optimizer.step()

    training_outputs.append(outputs)
    training_targets.append(targets)
    training_loss.append(training_loss_epoch)
    average_loss = np.mean(np.array(training_loss_epoch))

    # "Accuracy" here is the Pearson correlation over the whole epoch
    average_accuracy = _epoch_correlation(outputs, targets)

    print(f"Epoch {i+1}, Loss: {round(average_loss, 5)}, Accuracy: {round(average_accuracy.item(), 5)}")

    model.eval()

    val_loss_epoch = []
    val_outputs=[]
    val_targets=[]
    with torch.no_grad():
        for k, (val_seq, val_target) in enumerate(val_loader):

            val_target = torch.reshape(val_target, [val_target.size()[0], 1])

            val_output = model(val_seq)
            val_loss = _batch_loss(val_output, val_target)

            val_loss_epoch.append(val_loss.item())

            val_outputs.append(val_output)
            val_targets.append(val_target)

            if (k % 100 == 0):
                print(k)

        validation_targets.append(val_targets)
        validation_outputs.append(val_outputs)
        validation_loss.append(val_loss_epoch)

        validation_loss_epoch = np.mean(np.array(val_loss_epoch))

        validation_accuracy = _epoch_correlation(val_outputs, val_targets)

        print(f"Epoch {i+1}, Validation Loss: {validation_loss_epoch}, Validation Accuracy: {round(validation_accuracy.item(), 5)}")
        

0
100
200
300
Epoch 1, Loss: -62.16759, Accuracy: 0.43704
0
100
200
300
Epoch 1, Validation Loss: -64.62198406904345, Validation Accuracy: 0.61948
0
100
200
300
Epoch 2, Loss: -109.68409, Accuracy: 0.61692
0
100
200
300
Epoch 2, Validation Loss: -102.7616125268387, Validation Accuracy: 0.65311
0
100
200
300
Epoch 3, Loss: -122.40199, Accuracy: 0.65467
0
100
200
300
Epoch 3, Validation Loss: -103.62525866625177, Validation Accuracy: 0.66587
0
100
200
300
Epoch 4, Loss: -132.52928, Accuracy: 0.68618
0
100
200
300
Epoch 4, Validation Loss: -107.57043206828367, Validation Accuracy: 0.64918
0
100
200
300
Epoch 5, Loss: -142.23221, Accuracy: 0.68424
0
100
200
300
Epoch 5, Validation Loss: -112.09365465196903, Validation Accuracy: 0.62735


## Fine-tuning: evaluate on test data

In [59]:
# Evaluate the fine-tuned model on the held-out test loader

test_targets = []
test_outputs = []
test_loss = []

model.eval()  # Set the model to evaluation mode

with torch.no_grad():
    for l, (test_seq, test_target) in enumerate(test_loader):
        # Give the flat batch of targets the (batch, 1) shape of the output
        test_target = test_target.reshape([test_target.size()[0], 1])
        test_output = model(test_seq)

        # Record the loss as a plain float in BOTH modes so `test_loss` has one
        # format, matching the per-batch training and validation loss lists.
        if normalization:
            test_loss_batch = criterion(test_output.float(), test_target.float()).item()
        else:
            test_loss_batch = model.poisson_loss(test_output, test_target).item()

        test_targets.append(test_target)
        test_outputs.append(test_output)
        test_loss.append(test_loss_batch)

        if (l % 100 == 0):
            print(l)

test_loss_mean = np.mean(np.array(test_loss))

# Pearson correlation over the whole test set; raw-count predictions and
# targets are compared in log space with the same pseudocount used for normalization.
all_test_outputs = torch.cat(test_outputs)
all_test_targets = torch.cat(test_targets)
if not normalization:
    all_test_outputs = torch.log(all_test_outputs + 0.00001)
    all_test_targets = torch.log(all_test_targets + 0.00001)

# scipy's pearsonr returns (statistic, p-value); only the statistic is reported
test_accuracy = pearsonr(
    all_test_outputs.cpu().detach().numpy().flatten(),
    all_test_targets.cpu().detach().numpy().flatten(),
)

print(f"Test Loss: {test_loss_mean}, Test Accuracy: {round(test_accuracy[0], 5)}")

0
Test Loss: 0.45113492012023926, Test Accuracy: nan


  test_accuracy = pearsonr(torch.cat(test_outputs).cpu().detach().numpy().flatten(), torch.cat(test_targets).cpu().detach().numpy().flatten())


### Save fine-tuned model and all training/test/validation outputs for evaluation. 

In [60]:
import os
import pickle as pkl

file_path = '240506_enformer_raw/enformer_finetuned.pkl'

# Create the output directory first so a missing folder cannot make the save
# fail after training has already finished.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# NOTE(review): pickle serialises the whole model object and records its class
# by reference, so loading this file needs the EnformerFineTuning definition to
# be available, and should only ever be done with files from a trusted source.
with open(file_path, 'wb') as f:
    pkl.dump(model, f)

In [61]:
# Persist the targets, outputs and losses collected for the training,
# validation and test stages, one file per list.
run_artifacts = [
    # training variables
    (training_targets, '240506_enformer_raw/training_targets.pt'),
    (training_outputs, '240506_enformer_raw/training_outputs.pt'),
    (training_loss, '240506_enformer_raw/training_loss.pt'),
    # validation inputs
    (validation_targets, '240506_enformer_raw/validation_targets.pt'),
    (validation_outputs, '240506_enformer_raw/validation_outputs.pt'),
    (validation_loss, '240506_enformer_raw/validation_loss.pt'),
    # test variables
    (test_targets, '240506_enformer_raw/test_targets.pt'),
    (test_outputs, '240506_enformer_raw/test_outputs.pt'),
    (test_loss, '240506_enformer_raw/test_loss.pt'),
]

for payload, destination in run_artifacts:
    torch.save(payload, destination)