# Train CNN Classifier on human_ocr_ensembl dataset

The dataset comes from the [Genomic Benchmarks](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks). Best reaults achieved are reported in these [tables](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments)

In [1]:
!pip install genomic-benchmarks optuna Dataset

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from genomic_benchmarks.data_check import info
import optuna

[0m

In [2]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Get dataset

In [3]:
info("human_enhancers_cohn", 0)

Dataset `human_enhancers_cohn` has 2 classes: negative, positive.

All lengths of genomic intervals equals 500.

Totally 27791 sequences have been found, 20843 for training and 6948 for testing.


Unnamed: 0,train,test
negative,10422,3474
positive,10421,3474


In [4]:
dataset = load_dataset("katarinagresova/Genomic_Benchmarks_human_enhancers_cohn")

In [5]:
dataset['train'][0]

{'seq': 'TGGTGGTACTTGTCAGGACTTGGAGCAGCAGGTGCAAGATTTAGTGGGTTGGTTTTAGAATATCTGCTTGGAAAGTGGAAAAACTCAATGGATCATCTAGACTTTGGAATTTATCTCCTTCCCCACTTCTCCACTCCCCCAACAACAACAACAACAACAATGACAACAAAAACACCTGGAATAAACAGGTCATACAACGAGGTAGTTGATAGAATAATGTACTTTCCTTTCAGGCACCCCTTGGAGGAGGCAGATTCTGCCCTTTAAGCTGAATCTGCCTTTCCTGCATTTCCTGAAACTCCTGCATTTCCTGAAATCTTCCTGTATTTTCCTGAAATTTCCTGCCATTCCTGAAACTTTAAGGTAACTGTGTCATTAAAGGAAGGAGAGAAGGGAAGTATTAGGACTGCAGATTTGGGGTGCATGATCAGCCTGGCTCTGAGCTTGCAGACTCCCAGAGTCAGGGAAGGGAGGAGCCACCAGCAACCTTGTGGCTTACT',
 'label': 0}

## Encode and split dataset

In [6]:
def one_hot_encode(sequence, max_length=500):
    one_hot = torch.zeros((4, max_length), dtype=torch.float32)
    
    mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    
    for i, nucleotide in enumerate(sequence[:max_length]):
        if nucleotide in mapping:
            one_hot[mapping[nucleotide], i] = 1.0

    return one_hot
    
class DNADataset(Dataset):
    def __init__(self, data):
        self.dataset = data
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        seq = self.dataset[idx]['seq']
        label = self.dataset[idx]['label']
        encoded_seq = one_hot_encode(seq)
        return encoded_seq, label

In [7]:
ds = dataset["train"].with_format("torch")
ds = DNADataset(ds)

train_size = int(0.8 * len(ds))
val_size = len(ds) - train_size

train_ds, val_ds = torch.utils.data.random_split(ds, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

## Define model

In [8]:
class DNAClassifierCNN(nn.Module):
    def __init__(self, kernel_size=5, dropout_rate=0.3):
        super(DNAClassifierCNN, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=4, out_channels=16, kernel_size=kernel_size, stride=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=kernel_size, stride=1)

        self.relu = nn.LeakyReLU()
        
        self.dropout = nn.Dropout(dropout_rate)
        
        self.fc1 = nn.Linear(self.count_flatten_size(), 64)
        self.fc2 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def count_flatten_size(self):
        dummy_input = torch.zeros(1, 4, 500)
        dummy_output = self.pool(self.conv2(self.pool(self.conv1(dummy_input))))
        return dummy_output.view(-1).size(0)
        
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.reshape(x.size(0), -1) 
        
        x = self.dropout(x)
        
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x


In [9]:
# Training loop
def train_model(model, train_loader, optimizer, criterion):
    model.train()
    for batch in train_loader:
        inputs, labels = batch
        labels = labels.float().to(DEVICE)
        optimizer.zero_grad()
        
        outputs = model(inputs.to(DEVICE))
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()

In [10]:
def evaluate_model(model, test_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = batch
            labels = labels.float().to(DEVICE)
            
            outputs = model(inputs.to(DEVICE))
            loss = criterion(outputs.squeeze(), labels)
            total_loss += loss.item()
            preds = (outputs.squeeze() > 0.5).float()
            correct += (preds == labels).sum().item()
    
    avg_loss = total_loss / len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    return avg_loss, accuracy

In [22]:
def evaluation_loop(model, epochs, lr, optimizer, train_loader, val_loader, criterion):
    for epoch in range(epochs):

        model.train()
        train_loss = 0.0
        for batch in train_loader:
            inputs, labels = batch
            labels = labels.float().to(DEVICE)
            
            optimizer.zero_grad()  
            outputs = model(inputs.to(DEVICE))  
            loss = criterion(outputs.squeeze(), labels)  
            loss.backward()  
            optimizer.step()  
            
            train_loss += loss.item()  
        
        model.eval()  
        val_loss = 0.0
        correct = 0
        with torch.no_grad():
            for batch in val_loader:
                inputs, labels = batch
                labels = labels.float().to(DEVICE)
                
                outputs = model(inputs.to(DEVICE))  
                loss = criterion(outputs.squeeze(), labels) 
                val_loss += loss.item()
                
                preds = (outputs.squeeze() > 0.5).float() 
                correct += (preds == labels).sum().item()  
        
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = correct / len(val_loader.dataset)
        
        print(
            f"Epoch {epoch+1}/{epochs} - "
            f"Train Loss: {avg_train_loss:.4f}, "
            f"Val Loss: {avg_val_loss:.4f}, "
            f"Val Accuracy: {val_accuracy:.2%}"
        )
    
    return avg_val_loss, val_accuracy


## Perform training

In [28]:
from torch.utils.data import DataLoader, random_split
import torch

ds = dataset["train"].with_format("torch")
ds = DNADataset(ds)

train_size = int(0.8 * len(ds))
val_size = len(ds) - train_size

train_ds, val_ds = random_split(ds, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

model = DNAClassifierCNN().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 
criterion = torch.nn.BCELoss() 

avg_loss, accuracy = evaluation_loop(
    model=model,
    epochs=5,
    lr=0.001,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion
)

print(f"Final Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2%}")


Epoch 1/5 - Train Loss: 0.6135, Val Loss: 0.5910, Val Accuracy: 67.79%
Epoch 2/5 - Train Loss: 0.5927, Val Loss: 0.5891, Val Accuracy: 67.98%
Epoch 3/5 - Train Loss: 0.5790, Val Loss: 0.5847, Val Accuracy: 68.89%
Epoch 4/5 - Train Loss: 0.5595, Val Loss: 0.5786, Val Accuracy: 69.63%
Epoch 5/5 - Train Loss: 0.5354, Val Loss: 0.6074, Val Accuracy: 66.39%
Final Validation Loss: 0.6074, Accuracy: 66.39%


## Hyperparam optimization

Let's try to optimize the learning rate, number of training epochs and size of the convolution kernel

In [29]:
def objective(trial):
    lr = trial.suggest_float('learning_rate', 0.00001, 0.01)
    epochs = trial.suggest_int('epochs', 5, 10)
    kernel_size = trial.suggest_int('kernel_size', 3, 5)
    batch_size = trial.suggest_int('batch_size', 32, 128)
    dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2)
    optimizer_type = trial.suggest_categorical('optimizer', ['AdamW', 'RMSprop'])

    print(f"LR: {lr}, Epochs: {epochs}, Kernel size: {kernel_size}, Batch size: {batch_size}, Dropout rate: {dropout_rate}, Weight decay: {weight_decay}, Optimizer: {optimizer_type}")
    
    # Initialize model
    model = DNAClassifierCNN(kernel_size=kernel_size, dropout_rate=dropout_rate).to(DEVICE)

    # Select optimizer based on trial
    if optimizer_type == 'AdamW':
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:  # RMSprop
        optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    # Evaluate the model using evaluation_loop
    _, acc = evaluation_loop(model, epochs, lr, optimizer, train_loader, val_loader, nn.BCELoss())
    return acc  # Return the accuracy to be optimized by Optuna


In [30]:
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=4)

[I 2024-11-19 13:33:17,200] A new study created in memory with name: no-name-316221a0-210e-4c56-8baf-fc180dd25f4f


LR: 0.005892030384720968, Epochs: 9, Kernel size: 4, Batch size: 55, Dropout rate: 0.25733416864910474, Weight decay: 0.0063130231635391155, Optimizer: RMSprop
Epoch 1/9 - Train Loss: 0.7626, Val Loss: 0.6067, Val Accuracy: 66.95%
Epoch 2/9 - Train Loss: 0.6152, Val Loss: 0.6811, Val Accuracy: 57.83%
Epoch 3/9 - Train Loss: 0.6127, Val Loss: 0.5988, Val Accuracy: 67.26%
Epoch 4/9 - Train Loss: 0.6077, Val Loss: 0.5939, Val Accuracy: 67.14%
Epoch 5/9 - Train Loss: 0.6032, Val Loss: 0.5892, Val Accuracy: 68.31%
Epoch 6/9 - Train Loss: 0.5965, Val Loss: 0.5802, Val Accuracy: 69.23%
Epoch 7/9 - Train Loss: 0.5970, Val Loss: 0.6615, Val Accuracy: 62.94%
Epoch 8/9 - Train Loss: 0.5960, Val Loss: 0.6534, Val Accuracy: 62.94%


[I 2024-11-19 13:44:26,425] Trial 0 finished with value: 0.5325017989925641 and parameters: {'learning_rate': 0.005892030384720968, 'epochs': 9, 'kernel_size': 4, 'batch_size': 55, 'dropout_rate': 0.25733416864910474, 'weight_decay': 0.0063130231635391155, 'optimizer': 'RMSprop'}. Best is trial 0 with value: 0.5325017989925641.


Epoch 9/9 - Train Loss: 0.5937, Val Loss: 0.8622, Val Accuracy: 53.25%
LR: 0.008400862231711404, Epochs: 8, Kernel size: 3, Batch size: 89, Dropout rate: 0.10933790352794137, Weight decay: 0.0019893386317083425, Optimizer: RMSprop
Epoch 1/8 - Train Loss: 1.2581, Val Loss: 0.6185, Val Accuracy: 65.68%
Epoch 2/8 - Train Loss: 0.6159, Val Loss: 0.6831, Val Accuracy: 60.28%
Epoch 3/8 - Train Loss: 0.6107, Val Loss: 0.6084, Val Accuracy: 66.06%
Epoch 4/8 - Train Loss: 0.6089, Val Loss: 0.5916, Val Accuracy: 68.12%
Epoch 5/8 - Train Loss: 0.6062, Val Loss: 0.6299, Val Accuracy: 63.73%
Epoch 6/8 - Train Loss: 0.6110, Val Loss: 0.6811, Val Accuracy: 58.93%
Epoch 7/8 - Train Loss: 0.6030, Val Loss: 0.6186, Val Accuracy: 66.85%


[I 2024-11-19 13:54:21,724] Trial 1 finished with value: 0.6665867114415928 and parameters: {'learning_rate': 0.008400862231711404, 'epochs': 8, 'kernel_size': 3, 'batch_size': 89, 'dropout_rate': 0.10933790352794137, 'weight_decay': 0.0019893386317083425, 'optimizer': 'RMSprop'}. Best is trial 1 with value: 0.6665867114415928.


Epoch 8/8 - Train Loss: 0.6094, Val Loss: 0.6134, Val Accuracy: 66.66%
LR: 0.006141293521589701, Epochs: 6, Kernel size: 3, Batch size: 37, Dropout rate: 0.4858476580077141, Weight decay: 0.008534651415073886, Optimizer: AdamW
Epoch 1/6 - Train Loss: 0.6215, Val Loss: 0.6229, Val Accuracy: 66.15%
Epoch 2/6 - Train Loss: 0.6002, Val Loss: 0.6261, Val Accuracy: 65.56%
Epoch 3/6 - Train Loss: 0.5964, Val Loss: 0.5934, Val Accuracy: 68.31%
Epoch 4/6 - Train Loss: 0.5895, Val Loss: 0.5825, Val Accuracy: 68.27%
Epoch 5/6 - Train Loss: 0.5872, Val Loss: 0.5836, Val Accuracy: 70.18%


[I 2024-11-19 14:01:28,610] Trial 2 finished with value: 0.6946509954425522 and parameters: {'learning_rate': 0.006141293521589701, 'epochs': 6, 'kernel_size': 3, 'batch_size': 37, 'dropout_rate': 0.4858476580077141, 'weight_decay': 0.008534651415073886, 'optimizer': 'AdamW'}. Best is trial 2 with value: 0.6946509954425522.


Epoch 6/6 - Train Loss: 0.5827, Val Loss: 0.5823, Val Accuracy: 69.47%
LR: 0.0007016479903258781, Epochs: 10, Kernel size: 5, Batch size: 112, Dropout rate: 0.2115460523297264, Weight decay: 0.006574428090540692, Optimizer: RMSprop
Epoch 1/10 - Train Loss: 0.6195, Val Loss: 0.6039, Val Accuracy: 66.51%
Epoch 2/10 - Train Loss: 0.6043, Val Loss: 0.6030, Val Accuracy: 67.02%
Epoch 3/10 - Train Loss: 0.6015, Val Loss: 0.6078, Val Accuracy: 66.59%
Epoch 4/10 - Train Loss: 0.6010, Val Loss: 0.6039, Val Accuracy: 66.68%
Epoch 5/10 - Train Loss: 0.5990, Val Loss: 0.6005, Val Accuracy: 67.09%
Epoch 6/10 - Train Loss: 0.5980, Val Loss: 0.6230, Val Accuracy: 64.64%
Epoch 7/10 - Train Loss: 0.5963, Val Loss: 0.6828, Val Accuracy: 58.60%
Epoch 8/10 - Train Loss: 0.5957, Val Loss: 0.6311, Val Accuracy: 63.83%
Epoch 9/10 - Train Loss: 0.5947, Val Loss: 0.6024, Val Accuracy: 67.98%


[I 2024-11-19 14:14:54,755] Trial 3 finished with value: 0.5919884864475894 and parameters: {'learning_rate': 0.0007016479903258781, 'epochs': 10, 'kernel_size': 5, 'batch_size': 112, 'dropout_rate': 0.2115460523297264, 'weight_decay': 0.006574428090540692, 'optimizer': 'RMSprop'}. Best is trial 2 with value: 0.6946509954425522.


Epoch 10/10 - Train Loss: 0.5927, Val Loss: 0.7199, Val Accuracy: 59.20%


In [31]:
print(f"Best hyperparameters: {study.best_params}")
print(f"Best value (validation AU PRC): {study.best_value}")

Best hyperparameters: {'learning_rate': 0.006141293521589701, 'epochs': 6, 'kernel_size': 3, 'batch_size': 37, 'dropout_rate': 0.4858476580077141, 'weight_decay': 0.008534651415073886, 'optimizer': 'AdamW'}
Best value (validation AU PRC): 0.6946509954425522


In [34]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torch import optim

if not hasattr(study, 'best_params'):
    raise ValueError("study object or best_params not found")

best_model = DNAClassifierCNN(
    kernel_size=study.best_params['kernel_size'],
    dropout_rate=study.best_params['dropout_rate']
).to(DEVICE)

optimizer = optim.AdamW(
    best_model.parameters(),
    lr=study.best_params['learning_rate'],
    weight_decay=study.best_params['weight_decay']
)

criterion = nn.BCELoss()

if 'train' not in dataset or 'test' not in dataset:
    raise ValueError("Dataset must have 'train' and 'test' splits")

if 'val' not in dataset:
    print("No 'val' split found. Creating one from the 'train' split.")
    train_data = dataset['train']
    train_size = int(0.8 * len(train_data))  # 80% for training
    val_size = len(train_data) - train_size  # 20% for validation
    train_ds, val_ds = torch.utils.data.random_split(train_data, [train_size, val_size])
else:
    train_ds = dataset['train']
    val_ds = dataset['val']

train_loader = DataLoader(DNADataset(train_ds), batch_size=32, shuffle=True)
val_loader = DataLoader(DNADataset(val_ds), batch_size=32, shuffle=False)

avg_loss, accuracy = evaluation_loop(
    best_model,
    epochs=study.best_params['epochs'],
    lr=study.best_params['learning_rate'],
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion
)

print(f"Validation Loss: {avg_loss}, Validation Accuracy: {accuracy}")

test_loader = DataLoader(DNADataset(dataset["test"]), batch_size=32, shuffle=False)

test_loss, test_accuracy = test_model(best_model, test_loader, criterion)

print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}")


No 'val' split found. Creating one from the 'train' split.
Epoch 1/6 - Train Loss: 0.6278, Val Loss: 0.6046, Val Accuracy: 66.80%
Epoch 2/6 - Train Loss: 0.6004, Val Loss: 0.5936, Val Accuracy: 68.10%
Epoch 3/6 - Train Loss: 0.5887, Val Loss: 0.5797, Val Accuracy: 68.89%
Epoch 4/6 - Train Loss: 0.5813, Val Loss: 0.5761, Val Accuracy: 69.08%
Epoch 5/6 - Train Loss: 0.5807, Val Loss: 0.5966, Val Accuracy: 67.98%
Epoch 6/6 - Train Loss: 0.5813, Val Loss: 0.5973, Val Accuracy: 67.19%
Validation Loss: 0.5973212175241863, Validation Accuracy: 0.671863756296474
Test Loss: 0.6014408715274355, Test Accuracy: 0.6672423719055843
