## DNABERT2 Fine-tuning to detect the location/region of the promoter

### Modules import

In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AutoConfig
from sklearn.model_selection import train_test_split
import os
from transformers import AutoTokenizer, AutoModel, BertConfig, AutoModelForMaskedLM, AutoModelForTokenClassification
import pandas as pd



#### data loading

data class define

https://drive.usercontent.google.com/download?id=1GRtbzTe3UXYF1oW27ASNhYX3SZ16D7N2&export=download&authuser=0&confirm=t&uuid=9a91e4d5-dfac-4ed1-869c-52ff8525f085&at=AENtkXaXrQnIdKo74wE_zRA19WYK%3A1732141945972

#### Implement DNABERT2 fro promoter classifier

freeze the pretrained DNABERT2, just add a simple layer for binary classifier (1="it is promoter", 0="it is not a promoter")

## DNABERT2 + Fine-tuned to locate the promoter (Ignore all parts above, only use this section)

### Package import

In [9]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, BertConfig, AutoModel
import random
from sklearn.model_selection import train_test_split
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_csv('combined_sequences_labels.csv')
print(df.head())
model_name = 'zhihan1996/DNABERT-2-117M'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

                                            Sequence  \
0  CGGTCCTGGATCCCACCCGCGCTGGGCTCAGGGCCGCGGGTTCGGG...   
1  GGCCCTCCTGGACGGAGTGGACCTTATGAGACACTCCCTAGCTGAA...   
2  TATCCCATTCCCTGCAGCTTTCCCTGCCGCACAGGCGGCAGGGTTG...   
3  CTCCCGGGCTCAGGCAATCCTCCCGCCTCAGCCCACGGAGTAGATG...   
4  TGGCAGGTATGACAGACGTTTGATTCCCAACCTTCTCCGCTTTGGT...   

                                    One-Hot Encoding  
0  0000000000000000000000000000000000000000000000...  
1  0000000000000000000000000000000000000000000000...  
2  0000000000000000000000000000000000000000000000...  
3  0000000000000000000000000000000000000000000000...  
4  0000000000000000000000000000000000000000000000...  


In [10]:
df_test1 = pd.read_csv('mouse_data.csv')
df_test2= pd.read_csv('pig_data.csv')

### file/model read

### model/dataloader class define

#### model class

##### purpose

To predict whether each nucleotide in a DNA sequence is part of a promoter region (label 1) or not (label 0).

##### components

Pre-trained DNABERT2 Model: Provides contextual embeddings for DNA sequences.   
Multi-head Attention Layer: Focuses on relevant positions in the sequence.   
Classifier Layer: Makes token-level predictions.  

In [21]:
from transformers import BertModel, BertConfig

class PromoterDetectionModel(nn.Module):
    def __init__(self, bert_model_name="bert-base-uncased"):
        super(PromoterDetectionModel, self).__init__()
        # Load pre-trained BERT model
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.attention = nn.MultiheadAttention(embed_dim=self.bert.config.hidden_size, num_heads=8)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)  # Only output 1 or 0 for each token

    def forward(self, input_ids, attention_mask):
        # Fetch the output from pre-trained BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        if isinstance(outputs, tuple):
            sequence_output = outputs[0]
        else:
            sequence_output = outputs.last_hidden_state  # (batch_size, seq_length, hidden_size)

        # Adjust the shape for MultiheadAttention
        sequence_output = sequence_output.permute(1, 0, 2)  # (seq_length, batch_size, hidden_size)

        # Apply attention
        attn_output, _ = self.attention(sequence_output, sequence_output, sequence_output)

        # Adjust the size back
        attn_output = attn_output.permute(1, 0, 2)  # (batch_size, seq_length, hidden_size)

        # Classifier
        logits = self.classifier(attn_output)  # (batch_size, seq_length, 2)
        return logits


In [11]:
import torch
import torch.nn as nn
from transformers import GPT2Model

class PromoterDetectionModel(nn.Module):
    def __init__(self, model_name="gpt2"):
        super(PromoterDetectionModel, self).__init__()
        # Load pre-trained GPT-2 model
        self.model = GPT2Model.from_pretrained(model_name)
        # Define multihead attention
        self.attention = nn.MultiheadAttention(embed_dim=self.model.config.hidden_size, num_heads=8)
        # Define a linear classifier
        self.classifier = nn.Linear(self.model.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        # Fetch the output from pre-trained GPT-2
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        if isinstance(outputs, tuple):
            sequence_output = outputs[0]  # Extract the last hidden state
        else:
            sequence_output = outputs.last_hidden_state

        # Adjust the shape for MultiheadAttention
        sequence_output = sequence_output.permute(1, 0, 2)  # (seq_length, batch_size, hidden_size)

        # Apply attention
        attn_output, _ = self.attention(sequence_output, sequence_output, sequence_output)

        # Adjust the size back
        attn_output = attn_output.permute(1, 0, 2)  # (batch_size, seq_length, hidden_size)

        # Classifier to predict 0 or 1 for each token
        logits = self.classifier(attn_output)  # (batch_size, seq_length, 2)
        return logits


dataloader class

In [22]:
class PromoterDataset(Dataset):
    def __init__(self, sequences, labels, tokenizer, max_length=512):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        label = self.labels[idx]

        # encode process
        encoded = self.tokenizer(seq,
                                 padding='max_length',
                                 truncation=True,
                                 max_length=self.max_length,
                                 return_tensors='pt')

        input_ids = encoded['input_ids'].squeeze(0)  # (max_length)
        attention_mask = encoded['attention_mask'].squeeze(0)  # (max_length)

        # transfer label into tensor and padding to max_length if needed
        label = [int(i) for i in label]
        if len(label) < self.max_length:
            label += [0] * (self.max_length - len(label))
        else:
            label = label[:self.max_length]
        label = torch.tensor(label, dtype=torch.long)  # (max_length)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': label
        }

model initialize

In [23]:
model = PromoterDetectionModel()
model.to(device)

PromoterDetectionModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

data loader initizalize

In [28]:
sequences = df['Sequence'].tolist()[:]
labels = df['One-Hot Encoding'].tolist()[:]
labels = [list(label) for label in labels]

train_sequences, val_sequences, train_labels, val_labels = train_test_split(
    sequences, labels, test_size=0.2, random_state=42)

train_dataset = PromoterDataset(train_sequences, train_labels, tokenizer)
val_dataset = PromoterDataset(val_sequences, val_labels, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)





test_sequences = df_test2['Sequence'].tolist()[:]
labels = df_test2['One-Hot Encoding'].tolist()[:]
labels = [list(label) for label in labels]

test_sequences, _, test_labels, aa = train_test_split(
    sequences, labels, test_size=0.001, random_state=42)

test_dataset = PromoterDataset(test_sequences, test_labels, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)


loss function

In [25]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)


### Model fine-tuned (only trained the new layers over the pre-trained DNABERT2, freeze the DNABERT2)

In [26]:
epochs = 200

for epoch in range(epochs):
    model.train()
    total_loss = 0

    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        logits = model(input_ids, attention_mask)

        # adjust the shape of logits
        logits = logits.view(-1, 2)  # (batch_size * seq_length, 2)
        labels_flat = labels.view(-1)  # (batch_size * seq_length)

        loss = criterion(logits, labels_flat)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

    # model eval
    model.eval()
    total_correct = 0
    total_count = 0

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask)
            predictions = torch.argmax(logits, dim=-1)

            total_correct += (predictions == labels).sum().item()
            total_count += labels.numel()

    accuracy = total_correct / total_count
    print(f'Validation Accuracy: {accuracy:.4f}')

# model saved
torch.save(model.state_dict(), 'promoter_detection_model_bert.pth')

Epoch 1/7, Loss: 1.1085
Validation Accuracy: 0.5771
Epoch 2/7, Loss: 0.6840
Validation Accuracy: 0.5771
Epoch 3/7, Loss: 0.6854
Validation Accuracy: 0.5771
Epoch 4/7, Loss: 0.6815
Validation Accuracy: 0.5771
Epoch 5/7, Loss: 0.6812
Validation Accuracy: 0.5771
Epoch 6/7, Loss: 0.6814
Validation Accuracy: 0.5771
Epoch 7/7, Loss: 0.6828
Validation Accuracy: 0.5771


In [29]:
model.load_state_dict(torch.load('promoter_detection_model_bert.pth'))
model.eval()  # 设置为评估模式

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

total_correct = 0
total_count = 0

with torch.no_grad():
    for batch in test_loader:  # 确保使用的是test_loader
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        logits = model(input_ids, attention_mask)
        predictions = torch.argmax(logits, dim=-1)

        total_correct += (predictions == labels).sum().item()
        total_count += labels.numel()

accuracy = total_correct / total_count
print(f'Test Accuracy: {accuracy:.4f}')

  model.load_state_dict(torch.load('promoter_detection_model_bert.pth'))


Test Accuracy: 0.5849
