
# DNA Sequence Classification with BERT

In this notebook, we will classify DNA sequences using a BERT-based model. We will load the data, preprocess it into k-mers, and train the model on the dataset.


In [1]:
# Import necessary libraries
import pandas as pd
from Bio import SeqIO
import gzip
import ftplib
import os
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np


## 1. Function to Download DNA Sequences from FTP

This section defines a function to download DNA sequences from provided FTP links.


In [2]:
def download_fasta(ftp_url):
    """Downloads a FASTA file from the given FTP URL."""
    try:
        # Parse the FTP URL
        url_parts = ftp_url.split("/")
        host = url_parts[2]
        filepath = "/".join(url_parts[3:])
        
        # Connect to the FTP server
        ftp = ftplib.FTP(host)
        ftp.login()  # Login anonymously

        # Download the file
        local_filename = os.path.basename(filepath)
        with open(local_filename, 'wb') as f:
            ftp.retrbinary(f"RETR {filepath}", f.write)

        ftp.quit()
        return local_filename
    except Exception as e:
        print(f"Failed to download {ftp_url}: {e}")
        return None


## 2. Load the Metadata File

In this section, we load the DNA sequence metadata from a tab-delimited text file and display its contents.


In [3]:
# Read the metadata from the text file
input_file = 'E-MTAB-5530.sdrf.txt'  # Path to your input text file
df = pd.read_csv(input_file, sep="\t")

# Display the first few rows of the DataFrame
df.head()


Unnamed: 0,Source Name,Comment[ENA_SAMPLE],Comment[BioSD_SAMPLE],Comment[experiment],Characteristics[organism],Characteristics[strain],Characteristics[developmental stage],Characteristics[genotype],Characteristics[organism part],Characteristics[cell type],...,Comment[ENA_EXPERIMENT],Scan Name,Comment[SUBMITTED_FILE_NAME],Comment[ENA_RUN],Comment[FASTQ_URI],Comment[SPOT_LENGTH],Comment[READ_INDEX_1_BASE_COORD],Factor Value[genotype],Factor Value[organism part],Factor Value[reporter fluorescence]
0,CD41_2_H5,ERS1610822,SAMEA103921654,Original,Danio rerio,Tubingen Long Fin,adult,Tg(cd41:EGFP),testis,mixed cell types,...,ERX1948556,SLX-10876.N721_S522.C93G6ANXX.s_1.r_1.fq.gz,SLX-10876.N721_S522.C93G6ANXX.s_1.r_1.fq.gz,ERR1888058,ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR188/008/...,250,126,Tg(cd41:EGFP),testis,


## 3. Save Metadata to CSV

We will save the loaded metadata DataFrame to a CSV file.


In [4]:
# Save the DataFrame df as a CSV file
output_file = 'output.csv'  # Specify your desired output filename
df.to_csv(output_file, index=False)  # Save without row index


## 4. Download and Process DNA Sequences

For each row in the metadata, we download the DNA sequences using the provided FTP links.


In [5]:
# Prepare to open the CSV file for writing
output_file = "dna_sequences.csv"  # Output file name
with open(output_file, 'w') as out_csv:
    # Write header for the CSV file
    out_csv.write("seq,organism_part\n")

    # Iterate through the rows of the DataFrame
    for index, row in df.iterrows():
        ftp_url_1 = row['Comment[FASTQ_URI]']
        ftp_url_2 = ftp_url_1.replace("_1.", "_2.")  # Handle paired-end read
        
        print('ftp_url_2:', ftp_url_2)

        # Download DNA sequences from the FTP URLs
        for ftp_url in [ftp_url_1, ftp_url_2]:
            local_file = download_fasta(ftp_url)
            
            if local_file:
                # Read the downloaded FASTQ file
                with gzip.open(local_file, 'rt') as f:
                    for record in SeqIO.parse(f, "fastq"):
                        # Print the sequence
                        print(f"Read DNA Sequence: {record.seq}")

                        # Write to CSV immediately
                        out_csv.write(f"{record.seq},{row['Characteristics[organism part]']}\n")
                        
                # Clean up the local file
                os.remove(local_file)

print(f"DNA sequences and organism parts saved to {output_file}.")


ftp_url_2: ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR188/008/ERR1888058/ERR1888058_2.fastq.gz
Read DNA Sequence: CAGATGGAGTTCATGTTGCAGAATCATGTCTTCCCTCTGTTCCGCAGTGAGCTGGGCTACATGAGAGCGAGGGCCTGCTGGGTACTTCATTACTTCTGTGAGGTTAAGTTCAAGAATGACCAGAA
Read DNA Sequence: NTCCTCGTCTTTGCATCTCTGCTGTCTGGTTCTCCACGATGGCCGCAGATGAAGCTCTCTGTTCGGACGCTCCTCCAGCGCGACCGTGGCTCCTGAACTCTGGTGAATCTCGTGTAATCATTCTC
Read DNA Sequence: TTGCTGATTCTTATTGGATTTCTGATATTTCCTATTGTTATGGGCTATTACATCTCTAAGGAATTGGTGAAGTAAAATGGTGAAGCTTATGAACTGTCTCTTATACACATCTCCGAGCCCACGAG
Read DNA Sequence: GCCAGCGGACGCGTGGTTCTCAATGATATAGACCTGTGCCACTGTCTCTTATACACATCTCCGAGCCCACGAGACTACGCTGCATCTCGTATGCCGTCTTCTGCTTGAAAAAAAAAAACCCCACC
Read DNA Sequence: CCACTACATCCAGGATGGCCTCGCAGTGTTCTCTGTAGAGCGTCTGAAGGGCTTTGACATCCTCTGGGCCGACGCCATCCACTGCCTCTCCAAGATCCAGTTCAACAAACTCCGGCAGAGCTCTG
Read DNA Sequence: AGATAGGGGTGCGTCCAACCTCCAGACCAGATGCCAGTGTTGCAATTCGGCCCATTACTGTACGGTCTCCAGTGCTGATGACGATGCCGCGAGCAGTGCCTTCAACACAGTTGGTTGAGAAGAAA
Read DNA Sequence: AGGTCATCATGAGCTCTGTCA

KeyboardInterrupt: 

## 5. Define the Dataset and DataLoader

In this section, we define a custom dataset and dataloader for training the model.


In [6]:
class SequenceDataset(Dataset):
    """Custom dataset for loading DNA sequences and labels."""
    
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sequence = self.data.iloc[idx, 0]  # DNA sequence
        label = self.data.iloc[idx, 1]      # Organism part
        return sequence, label

# Initialize the dataset and dataloader
dataset = SequenceDataset(output_file)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


## 6. Define the Model



In [7]:
# Define the Dataset class for DNA sequences
class DNASequenceDataset(Dataset):
    """
    Dataset class for handling DNA sequences for classification.
    """
    def __init__(self, sequences, labels, tokenizer, k):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.k = k

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """
        Get item from dataset by index.
        """
        sequence = self.sequences[idx]
        # Tokenize the sequence into k-mers
        kmers = [sequence[i:i+self.k] for i in range(len(sequence) - self.k + 1)]
        kmers_str = " ".join(kmers)
        encoded = self.tokenizer(kmers_str, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return encoded['input_ids'].squeeze(), encoded['attention_mask'].squeeze(), label


In [8]:
# Load and preprocess the data
def load_data(file_path):
    """
    Load data from a CSV file and preprocess it.
    """
    df = pd.read_csv(file_path)
    df = df.sample(frac=0.05, random_state=42)  # Randomly select 5% of the data
    le = LabelEncoder()
    df['label'] = le.fit_transform(df['label'])  # Encode labels
    return df, le


In [9]:
# Create DataLoader objects for training and validation datasets
def create_data_loaders(df, tokenizer, k, batch_size=8):
    """
    Create DataLoader objects for training and validation datasets.
    """
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    train_dataset = DNASequenceDataset(train_df['seq'].values, train_df['label'].values, tokenizer, k)
    val_dataset = DNASequenceDataset(val_df['seq'].values, val_df['label'].values, tokenizer, k)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader


In [10]:
# Train the model
def train_model(model, train_loader, val_loader, num_epochs=3, learning_rate=2e-5):
    """
    Train the model with the specified training and validation data loaders.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for input_ids, attention_mask, labels in train_loader:
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        val_loss = 0
        model.eval()
        with torch.no_grad():
            for input_ids, attention_mask, labels in val_loader:
                input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
                outputs = model(input_ids, attention_mask=attention_mask)
                loss = criterion(outputs.logits, labels)
                val_loss += loss.item()

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {total_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}')


In [None]:
# Load tokenizer
tokenizer = BertTokenizer.from_pretrained("zhihan1996/DNA_bert_6", do_lower_case=False)

# Load and preprocess data
df, label_encoder = load_data('updated_dna_sequences.csv')

# Create data loaders
train_loader, val_loader = create_data_loaders(df, tokenizer, k=6)

# Define the model
model = BertForSequenceClassification.from_pretrained("zhihan1996/DNA_bert_6", num_labels=len(label_encoder.classes_))

# Train the model
train_model(model, train_loader, val_loader, num_epochs=3)
