<a href="https://colab.research.google.com/github/MartinekV/DL-for-bio-course/blob/master/04_DNA_tasks_ADVANCED.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Optional: Turn on GPU runtime

## Setup

In [1]:
!pip install -q genomic-benchmarks
!pip install torchmetrics -q
!pip install -q pytorch-lightning

## Text preprocessing

In [2]:
import torch
# Toy DNA fragment used to demonstrate the encoding steps in the next cells.
example_seq = 'ACCCTGCCAACACGGGACTTTAC'
# Nucleotide -> integer index; the position inside 'ACTG' is the index.
vocab = {nucleotide: index for index, nucleotide in enumerate('ACTG')}

In [None]:
# Replace every nucleotide of the example by its integer index from `vocab`.
numericalized = [vocab[base] for base in example_seq]
numericalized

In [None]:
# One-hot encode the index sequence: one row per position, one column per class.
numericalized_tensor = torch.as_tensor(numericalized)
ohe_seq = torch.nn.functional.one_hot(numericalized_tensor, num_classes=4)
ohe_seq

## Data preparation and exploration

In [None]:
from genomic_benchmarks.dataset_getters.pytorch_datasets import DemoCodingVsIntergenomicSeqs, HumanNontataPromoters

# Load the 'train' split of the demo coding-vs-intergenomic benchmark.
train_dset =  DemoCodingVsIntergenomicSeqs('train') 
# Alternative dataset (uncomment to use it instead of the line above):
# train_dset =  HumanNontataPromoters('train') 


In [None]:
# Inspect the first training example; items unpack as a (sequence, label) pair.
train_dset[0]

In [None]:
# Number of examples in the training split.
len(train_dset)

In [None]:
from collections import Counter

# Collect sequence lengths and labels in a single pass, so every dataset item
# is fetched once instead of twice.
lens = []
labels = []
for seq, label in train_dset:
    lens.append(len(seq))
    labels.append(label)

# Frequency tables: which sequence lengths occur, and how balanced the classes are.
print(Counter(lens))
print(Counter(labels))

## Pytorch dataset

In [9]:
from torch.utils.data import Dataset, DataLoader
import torch

class MyDataset(Dataset):
    """One-hot encoding wrapper around a raw ``(sequence, label)`` dataset.

    Every item of ``raw_dataset`` is unpacked as ``(sequence, label)``. The
    sequence is upper-cased, mapped through ``self.vocab`` and one-hot encoded.
    Symbols that are not in the vocabulary (e.g. ambiguity codes) are encoded
    as ``'N'`` instead of raising ``KeyError``.

    Args:
        raw_dataset: indexable, sized collection of ``(sequence, label)`` pairs,
            where ``sequence`` is a string.
    """

    def __init__(self, raw_dataset):
        self.raw_dataset = raw_dataset
        # Index 0 ('N') doubles as the channel for unknown symbols.
        self.vocab = {'N':0,'A':1,'C':2,'T':3,'G':4}

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.raw_dataset)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for example ``idx``.

        ``x`` is a float tensor of shape ``(len(vocab), sequence_length)`` and
        ``y`` is the label wrapped in a tensor.
        """
        sequence, label = self.raw_dataset[idx]

        unknown = self.vocab['N']
        numericalized = [self.vocab.get(c, unknown) for c in sequence.upper()]
        # Explicit integer dtype: an empty list would otherwise become a float
        # tensor, which one_hot rejects.
        numericalized_tensor = torch.tensor(numericalized, dtype=torch.long)
        ohe_seq = torch.nn.functional.one_hot(numericalized_tensor, num_classes=len(self.vocab))
        x = ohe_seq.permute(1, 0)  # length x 5 -> 5 x length, the layout Conv1d expects

        y = torch.tensor(label)

        return x.float(), y

# Wrap the raw training split in the one-hot encoding dataset and batch it.
dset = MyDataset(raw_dataset=train_dset)
# Shuffling matters here: per the original note, training accuracy shows up as
# 100% without it (presumably the raw data is ordered by class - TODO confirm).
train_loader = DataLoader(dataset=dset, batch_size=32, shuffle=True)

In [10]:
# Pull a single batch from the training loader to inspect inputs and labels.
x_batch, y_batch = next(iter(train_loader))

In [None]:
# Shape of the input batch: items of shape 5 x length stacked along a leading batch dimension.
x_batch.size()

In [None]:
# Labels of the sampled batch.
y_batch

## Model and training logic

In [13]:
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy

class CNN_PL(pl.LightningModule):
    """Small 1D-CNN sequence classifier wrapped as a Lightning module.

    ``self.net`` maps a batch with 5 input channels to raw class scores
    (logits). ``forward`` returns softmax probabilities, but the training loss
    is computed from the logits: ``F.cross_entropy`` applies ``log_softmax``
    internally, so feeding it probabilities would apply softmax twice and
    flatten the gradients.

    Args:
        num_classes: number of output classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        # The final softmax is intentionally not part of `net`; it is applied
        # in `forward` so that `training_step` can work on the logits.
        self.net = nn.Sequential(
            nn.Conv1d(in_channels=5, out_channels=16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3),
            nn.Flatten(),
            # LazyLinear infers its input size on the first forward pass.
            nn.LazyLinear(num_classes),
        )
        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return class probabilities (softmax over the last dimension)."""
        return F.softmax(self.net(x), dim=-1)

    def training_step(self, batch, batch_idx):
        """Compute and log cross-entropy loss and accuracy for one training batch."""
        x, y = batch
        logits = self.net(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss, prog_bar=True)

        # argmax over logits equals argmax over probabilities, so accuracy is unaffected.
        accuracy = self.accuracy(logits, y)
        self.log('train_accuracy', accuracy, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Log the accuracy of one evaluation batch under the key 'accuracy'."""
        x, y = batch
        pred = self(x)
        metrics = {'accuracy': self.accuracy(pred, y)}
        self.log_dict(metrics)

    def configure_optimizers(self):
        """Use Adam over all model parameters with a learning rate of 1e-4."""
        return torch.optim.Adam(self.parameters(), lr=0.0001)


## Training and evaluation

In [None]:
model = CNN_PL(num_classes=2)

# accelerator='auto' uses an available accelerator (e.g. a GPU) and falls back
# to the CPU otherwise, so the notebook also runs without the optional GPU runtime.
trainer = pl.Trainer(max_epochs=1, accelerator='auto')
trainer.fit(model, train_loader)

In [None]:
# Run the test loop on the training data, i.e. accuracy on examples the model has already seen.
trainer.test(dataloaders=train_loader)

## Testing

In [None]:
# Build the 'test' split and encode it exactly like the training data.
test_dataset_original = DemoCodingVsIntergenomicSeqs('test')
# Alternative dataset (must match the one chosen for training):
# test_dataset_original = HumanNontataPromoters('test')
test_dset = MyDataset(raw_dataset=test_dataset_original)
# Evaluation does not need shuffling.
test_loader = DataLoader(dataset=test_dset, batch_size=32, shuffle=False)

trainer.test(dataloaders=test_loader)

## Interpretability

In [17]:
!pip install -q captum

In [None]:
import random

# Reuse the dataset's own symbol -> index mapping ({'N':0,'A':1,'C':2,'T':3,'G':4})
# so the baseline is encoded exactly like the model inputs.
vocab = test_dset.vocab

# Example to explain: encoded input (with a leading batch dimension) and raw text.
datapoint_index = 20000
x, y = test_dset[datapoint_index]
x = x.unsqueeze(0)
x.requires_grad_()
sequence = test_dataset_original[datapoint_index][0]

# Baseline for Integrated Gradients: a random A/C/T/G sequence. Its length is
# taken from the input because the baseline must have the same shape as `x`;
# a fixed seed keeps the attributions reproducible between runs.
BASELINE_SEED = 0
baseline_len = x.shape[-1]
nucleotide_ids = [index for symbol, index in vocab.items() if symbol != 'N']
baseline = torch.tensor([random.Random(BASELINE_SEED).choices(nucleotide_ids, k=baseline_len)])
# Alternative baselines:
# baseline = torch.tensor([random.choices(list(vocab.values()), k=baseline_len)])  # random, 'N' included
# baseline = torch.tensor([[0]*baseline_len])  # all 'N'
ohe_baseline = torch.nn.functional.one_hot(baseline, num_classes=len(vocab)).permute(0,2,1).float()

sequence

In [None]:
from captum.attr import IntegratedGradients
import random

# Explain the model output for class index 1, relative to the baseline input.
ig = IntegratedGradients(model)

# n_steps is the number of interpolation steps between baseline and input.
attributions = ig.attribute(x, baselines=ohe_baseline, target=1, n_steps=500).squeeze()

# Report the shape of the attribution map (not its values).
print('Attributions:', attributions.size())


In [None]:
from IPython.display import display, HTML
import numpy as np
# Convert attributions to a NumPy array (moved to the CPU first, in case they live on a GPU).
if torch.is_tensor(attributions):
    attributions = attributions.detach().cpu().numpy()

# Min-max normalize to [0, 1] for color coding. A constant attribution map has
# zero range; dividing by it would produce NaNs and make int() below fail, so
# fall back to the neutral mid value in that case.
attr_min = attributions.min()
attr_range = attributions.max() - attr_min
if attr_range > 0:
    attributions = (attributions - attr_min) / attr_range
else:
    attributions = np.full_like(attributions, 0.5)

# One colored <span> per nucleotide: low attribution -> red, high -> green.
colored_nucleotides = []
for i, nucleotide in enumerate(sequence):
    idx = vocab[nucleotide]

    # Attribution of the channel that is actually set at this position.
    color_intensity = attributions[idx, i]
    # color_intensity = np.max(np.abs(attributions[idx,:]), axis=-1)

    # Define green and red intensities based on the color intensity
    green_intensity = int(color_intensity * 255)
    red_intensity = 255 - green_intensity

    colored_nucleotides.append(
        f'<span style="color: rgb({red_intensity}, {green_intensity}, 0);">{nucleotide}</span>'
    )

# Join once instead of growing the string inside the loop.
html_text = "".join(colored_nucleotides)

# Display the html text
display(HTML(html_text))
print('label',y.item(), 'pred', model(x).tolist())


## Helpers

In [21]:
#Optional helper for variable length datasets
def pad_truncate(sequence, max_len):
    """Force ``sequence`` to be exactly ``max_len`` characters long.

    Longer sequences are cut off at ``max_len``; shorter ones are right-padded
    with ``'N'``.

    Args:
        sequence: nucleotide string.
        max_len: target length; must be non-negative.

    Returns:
        A string of length ``max_len``.

    Raises:
        ValueError: if ``max_len`` is negative. A negative value would otherwise
            be interpreted as a slice offset and silently drop trailing characters.
    """
    if max_len < 0:
        raise ValueError(f"max_len must be non-negative, got {max_len}")
    if len(sequence) > max_len:
        return sequence[:max_len]
    # Shorter than (or equal to) max_len: pad the missing positions with 'N'.
    return sequence + 'N' * (max_len - len(sequence))