In [None]:
#Install packages here, please note that this notebook has been optimized for Google Colabs
!pip install datasets torch einops
!pip install transformers>=4.28 # Ensure transformers version is at least 4.28


In [None]:
#Import everything here
import pandas as pd
import random
import datasets
from datasets import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, BertModel, AutoConfig
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
import os
from transformers import BertForSequenceClassification, BertTokenizer, TrainingArguments, Trainer
from datasets import load_dataset, load_metric
from datasets import Dataset
from collections import Counter
from torch.utils.data import Dataset

In [None]:
# Load DNABERT model and tokenizer
# Hugging Face Hub checkpoint id; presumably the 6-mer DNABERT variant — TODO confirm.
model_name = "zhihan1996/DNA_bert_6"
tokenizer = BertTokenizer.from_pretrained(model_name)
# num_labels=2 sizes the classification head for the binary `Label` column
# built further down (1 = AMR, 0 = anything else).
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

In [None]:
# Mount Google Drive at /content/drive so the dataset CSV stored under MyDrive
# can be read. `google.colab` is imported here, so this cell requires a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load E. coli dataset
# Later cells read the `Type` and `full_sequence` columns of this frame.
df = pd.read_csv('/content/drive/MyDrive/fullseq_microbigge_ecoli.csv')

In [None]:
# Relative frequency of each `Type` value in the full dataset (class-balance check).
df['Type'].value_counts(normalize=True)

In [None]:
# Work on the first 1500 rows only. `.copy()` makes df2 an independent frame, so
# the column assignment and row drops applied to it later act on df2 itself
# rather than on a slice of `df` (which triggers pandas' SettingWithCopyWarning).
df2 = df.head(1500).copy()

In [None]:
# Relative frequency of each `Type` value in the 1500-row subset, for comparison
# with the full-dataset distribution shown above.
df2['Type'].value_counts(normalize=True)

In [None]:
# Binary target column: a row whose Type is exactly 'AMR' gets 1, every other value gets 0
df2['Label'] = df2['Type'].map(lambda kind: int(kind == 'AMR'))

In [None]:
# Preview the first rows of the subset, including the newly added `Label` column.
df2.head()

In [None]:

def kmer_shift(sequence, k=6):
    """
    Builds the circular left-rotations of a sequence by 1 .. k-1 positions.

    For each offset the leading characters are moved to the end, so every
    returned string has the same length as the input. Returns a list of
    ``k - 1`` strings (empty when ``k <= 1``).
    """
    return [sequence[offset:] + sequence[:offset] for offset in range(1, k)]

In [None]:
def random_mutation(sequence, num_mutations=1):
    """
    Introduces random point substitutions in a DNA sequence.

    Each of the ``num_mutations`` rounds picks a random position and replaces
    the character there with a nucleotide that differs from it, so every round
    really changes the sequence. The replacement follows the case of the
    character it replaces (upper-case bases stay upper-case).

    Args:
        sequence: Nucleotide string to mutate.
        num_mutations: Number of substitution rounds; positions are drawn
            independently, so the same position may be hit more than once.

    Returns:
        The mutated sequence as a new string. An empty input is returned as an
        empty string instead of raising.
    """
    if not sequence:
        # random.randint(0, -1) would raise ValueError; nothing to mutate anyway.
        return ''

    bases = list(sequence)
    for _ in range(num_mutations):
        pos = random.randrange(len(bases))
        current = bases[pos]
        alphabet = "ACGT" if current.isupper() else "acgt"
        # Removing the current base from the alphabet guarantees a real change.
        # Characters outside the alphabet (e.g. 'n') leave all four candidates.
        bases[pos] = random.choice(alphabet.replace(current, ""))
    return ''.join(bases)

In [None]:
def reverse_complement(sequence):
    """
    Generates the reverse complement of a DNA sequence.

    Accepts lower- and upper-case bases as well as the IUPAC ambiguity codes
    (e.g. ``n``), preserving the case of each character.

    Args:
        sequence: Nucleotide string.

    Returns:
        The complemented sequence read in reverse order.

    Raises:
        KeyError: If the sequence contains a character that is not a
            nucleotide or IUPAC ambiguity code.
    """
    # Complement pairs: a<->t, c<->g, r<->y, m<->k, b<->v, d<->h; s, w, n map to themselves.
    symbols = "acgtrymkbdhvswn"
    partners = "tgcayrkmvhdbswn"
    complement = dict(zip(symbols + symbols.upper(), partners + partners.upper()))
    return ''.join(complement[base] for base in reversed(sequence))

In [None]:

def augment_data(sequences, labels, k=6, num_mutations=1):
    """
    Augments the data by k-mer shifting and introducing random mutations.

    For every (sequence, label) pair the output contains, in this order:
      1. the original sequence,
      2. the rotations returned by ``kmer_shift(seq, k)``,
      3. ``num_mutations`` copies of the original, each carrying one random
         single-base substitution.
    Every generated sequence inherits the label of its source sequence.

    Args:
        sequences: Iterable of nucleotide strings.
        labels: Iterable of labels, aligned with ``sequences``.
        k: Passed to ``kmer_shift``.
        num_mutations: Number of single-substitution variants per sequence.

    Returns:
        Tuple ``(augmented_sequences, augmented_labels)`` of equal-length lists.
    """
    augmented_sequences = []
    augmented_labels = []

    for seq, label in zip(sequences, labels):
        augmented_sequences.append(seq)  # Add original sequence
        augmented_labels.append(label)

        # K-mer shifting
        shifted_seqs = kmer_shift(seq, k)
        augmented_sequences.extend(shifted_seqs)
        augmented_labels.extend([label] * len(shifted_seqs))

        if not seq:
            # An empty sequence has no position to mutate; drawing an index
            # from an empty range would raise ValueError.
            continue

        # Introduce mutations (single point mutations)
        for _ in range(num_mutations):
            mutated_seq = list(seq)  # Convert to list for mutability
            mutation_index = random.randrange(len(seq))
            original_base = seq[mutation_index]
            # Follow the case of the replaced base so upper-case sequences do
            # not end up with mixed-case characters.
            alphabet = "ACGT" if original_base.isupper() else "acgt"
            # Excluding the original base guarantees the variant really differs.
            valid_nucleotides = alphabet.replace(original_base, "")
            mutated_seq[mutation_index] = random.choice(valid_nucleotides)
            augmented_sequences.append("".join(mutated_seq))
            augmented_labels.append(label)

    return augmented_sequences, augmented_labels


In [None]:
def compute_metrics(eval_pred):
    """
    Metric callback for the Trainer.

    Unpacks ``eval_pred`` into (logits, labels), takes the arg-max over the
    last logits axis as the predicted class, and returns a dict with the
    accuracy and the weighted-average F1 score.
    """
    raw_scores, gold = eval_pred
    predicted = np.argmax(raw_scores, axis=-1)
    return {
        "accuracy": accuracy_score(gold, predicted),
        "f1": f1_score(gold, predicted, average='weighted'),  # Adjust 'average' as needed
    }

In [None]:
# Function to clean, split, and augment data
def load_clean_split_augment_data(df, test_size=0.2):
    """
    Cleans a labelled sequence frame, splits it, and augments the training part.

    Args:
        df: DataFrame with a ``full_sequence`` column (nucleotide strings) and
            a ``Label`` column. The frame is not modified.
        test_size: Fraction of rows held out for the test split.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` of Hugging Face datasets. The
        training set holds the augmented sequences with columns
        ``full_sequence`` and ``Label``; the test set holds the untouched
        held-out rows.
    """
    # Drop rows with missing sequences. A new frame is built instead of using
    # inplace=True so the caller's DataFrame is left untouched.
    clean_df = df.dropna(subset=['full_sequence'])

    # Train-test split, stratified so both parts keep the label proportions
    train_data, test_data = train_test_split(
        clean_df, test_size=test_size, random_state=42, stratify=clean_df['Label']
    )

    # Augment training data only, so the test set contains no synthetic sequences
    augmented_sequences, augmented_labels = augment_data(
        train_data['full_sequence'].tolist(), train_data['Label'].tolist()
    )

    # Create new DataFrame for augmented training data
    augmented_train_data = pd.DataFrame({'full_sequence': augmented_sequences, 'Label': augmented_labels})

    # Convert to Hugging Face Dataset. The class is referenced through the
    # `datasets` module because the bare name `Dataset` is rebound by the
    # notebook's later `from torch.utils.data import Dataset` import, and the
    # torch class has no `from_pandas`.
    train_dataset = datasets.Dataset.from_pandas(augmented_train_data)
    test_dataset = datasets.Dataset.from_pandas(test_data)

    return train_dataset, test_dataset


In [None]:

# Load, clean, split, and augment dataset
# NOTE(review): the augmentation draws from `random` and no seed is set in the
# visible cells, so the augmented training set differs between runs — confirm
# whether that is intended.
train_dataset, test_dataset = load_clean_split_augment_data(df2)

In [None]:
# Size of the augmented training split.
train_dataset.shape

In [None]:
# Size of the held-out test split (not augmented).
test_dataset.shape

In [None]:
# Directory for training outputs; same path as the TrainingArguments output_dir below.
output_dir = './results'
# exist_ok=True keeps the cell re-runnable when the directory already exists.
os.makedirs(output_dir, exist_ok=True)


In [None]:
# Apply the tokenizer function to the datasets
#
# DNABERT's vocabulary consists of upper-case k-mers, and its documented input
# format is a whitespace-separated string of overlapping k-mers
# ("ATGCAT TGCATG ..."). A raw nucleotide string is a single out-of-vocabulary
# word for BertTokenizer, so sequences are converted to k-mer form first.
# NOTE(review): this assumes `full_sequence` holds raw contiguous nucleotides —
# confirm against the CSV. Strings that already contain spaces are passed through.
def sequence_to_kmers(sequence, k=6):
    """Return `sequence` as space-separated, overlapping, upper-case k-mers.

    A string that already contains a space is treated as k-mer formatted and
    returned unchanged. A sequence of at most k characters is returned
    upper-cased as a single token.
    """
    sequence = sequence.strip()
    if ' ' in sequence:
        return sequence
    sequence = sequence.upper()
    if len(sequence) <= k:
        return sequence
    return ' '.join(sequence[i:i + k] for i in range(len(sequence) - k + 1))


# max_length is explicit so truncation does not depend on the tokenizer's
# stored configuration; 512 is the documented DNABERT input limit.
train_encodings = tokenizer(
    [sequence_to_kmers(seq) for seq in train_dataset['full_sequence']],
    truncation=True,
    padding=True,
    max_length=512,
)
test_encodings = tokenizer(
    [sequence_to_kmers(seq) for seq in test_dataset['full_sequence']],
    truncation=True,
    padding=True,
    max_length=512,
)

In [None]:
# Pull the integer class labels out of each split as plain Python lists
train_labels = list(train_dataset['Label'])
test_labels = list(test_dataset['Label'])

In [None]:
# Define the training arguments
#
# transformers renamed the `evaluation_strategy` argument to `eval_strategy`,
# and the install cell sets no upper version bound. TrainingArguments is a
# dataclass, so the name accepted by the installed version can be read from its
# declared fields; this keeps the cell working on both older and newer releases.
eval_strategy_arg = (
    'eval_strategy'
    if 'eval_strategy' in TrainingArguments.__dataclass_fields__
    else 'evaluation_strategy'
)

training_args = TrainingArguments(
    output_dir=output_dir,  # the './results' directory created in the earlier cell
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    **{eval_strategy_arg: 'epoch'},  # evaluate at the end of every epoch
)

In [None]:
class SimpleDataset(Dataset):
    """
    Minimal map-style torch dataset over tokenizer output and labels.

    Args:
        encodings: Mapping from field name (e.g. ``input_ids``) to a
            per-example container. Values may be lists of lists (the tokenizer
            output when ``return_tensors`` is not set) or tensors.
        labels: Sequence of labels, aligned with the encodings.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, with a ``labels`` entry."""
        # as_tensor accepts both list and tensor rows (a plain list has no
        # .clone()); clone().detach() then yields a tensor that shares no
        # storage with the stored encodings.
        item = {
            key: torch.as_tensor(val[idx]).clone().detach()
            for key, val in self.encodings.items()
        }
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of examples, taken from the label sequence."""
        return len(self.labels)

In [None]:
# Wrap the tokenized data in Dataset objects (using datasets.Dataset)
# Each dataset combines the tokenizer output fields with a `labels` column.
# NOTE: this rebinds `train_dataset`, which until now held the un-tokenized
# split with a `full_sequence` column. Re-running the tokenization cell after
# this one fails on the missing column, so run the notebook top to bottom.
train_dataset = datasets.Dataset.from_dict({**train_encodings, 'labels': train_labels})
# Return the listed columns as torch tensors when the dataset is indexed.
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'labels'])

val_dataset = datasets.Dataset.from_dict({**test_encodings, 'labels': test_labels})
val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'labels'])


In [None]:
# Initialize the Trainer with Dataset objects
# The validation set is the held-out test split; compute_metrics reports
# accuracy and weighted F1 on it.
# NOTE(review): newer transformers releases appear to rename the `tokenizer`
# argument to `processing_class` — confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # tokenized, augmented training split
    eval_dataset=val_dataset,    # tokenized held-out split
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
# Evaluate the model
# Runs the Trainer's evaluation on `val_dataset` and prints the returned metrics.
# NOTE(review): no `trainer.train()` call is visible before this cell; if training
# does not happen elsewhere, this presumably scores the checkpoint with an
# untrained classification head — confirm.
eval_result = trainer.evaluate()
print(f"Evaluation result: {eval_result}")