In [18]:
# Core libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

# Transformers and PEFT
from transformers import AutoTokenizer, AutoModel, BertConfig
from peft import LoraConfig, get_peft_model

# Data processing and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# Utilities
import gc
from tqdm.auto import tqdm 

# Set style for prettier plots
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 100

print("âœ“ All libraries imported successfully!")

âœ“ All libraries imported successfully!


In [19]:
config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_config(config)

In [20]:
# Read in the new sequence data and run head to see the overall data structure
data_in = pd.read_csv("data/sequence-wide.tsv", sep='\t') 
data_in.head

<bound method NDFrame.head of                       genus       species  \
0               Alitibacter   langaaensis   
1               Alitibacter   langaaensis   
2               Roseovarius     maritimus   
3               Roseovarius        roseus   
4           Planosporangium      spinosum   
...                     ...           ...   
27727     Thermoclostridium  stercorarium   
27728           Clostridium      isatidis   
27729         Couchioplanes     caeruleus   
27730             Halomonas     koreensis   
27731  Pseudoflavonifractor   phocaeensis   

                                                sequence   identifier  \
0      ATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCT...  NR_118751.1   
1      ATTGAACGCTGGCGGCAGGCTTAACACATGCAAGTCGAACGGTAAC...  NR_042885.1   
2      CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...  NR_200035.1   
3      CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...  NR_200034.1   
4      TTGTTGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTG...  NR_200

In [21]:
# Check the columns
print("Column 1:")
print(data_in.iloc[:, 0])
print("Column 2:")
print(data_in.iloc[:, 1])
print("Column 3:")
print(data_in.iloc[:, 2])

Column 1:
0                 Alitibacter
1                 Alitibacter
2                 Roseovarius
3                 Roseovarius
4             Planosporangium
                 ...         
27727       Thermoclostridium
27728             Clostridium
27729           Couchioplanes
27730               Halomonas
27731    Pseudoflavonifractor
Name: genus, Length: 27732, dtype: object
Column 2:
0         langaaensis
1         langaaensis
2           maritimus
3              roseus
4            spinosum
             ...     
27727    stercorarium
27728        isatidis
27729       caeruleus
27730       koreensis
27731     phocaeensis
Name: species, Length: 27732, dtype: object
Column 3:
0        ATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCT...
1        ATTGAACGCTGGCGGCAGGCTTAACACATGCAAGTCGAACGGTAAC...
2        CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...
3        CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...
4        TTGTTGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTG...
                 

In [22]:
data_in.sequence

0        ATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCT...
1        ATTGAACGCTGGCGGCAGGCTTAACACATGCAAGTCGAACGGTAAC...
2        CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...
3        CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG...
4        TTGTTGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTG...
                               ...                        
27727    TGATCCTGGCTCAGGACGAACGCTGGCGGCGTGCCTAACACATGCA...
27728    GGCGTGCNTAACACATGCAAGTCGAGCGAGGTGATTTCNTTCGGGA...
27729    CGCTGGCGGCGTGCTTAACACATGCAAGTCGAGCGGAAAGGCCCCT...
27730    ACGATGGGAGCTTGCTCCCAGGCGTCGAGCGGCGGACGGGTGAGTA...
27731    AGAGTTTGATCCTGGCTCAGGATGAACGCTGGCGGCGTRCTTAACA...
Name: sequence, Length: 27732, dtype: object

In [23]:
# Convert the new sequences to input tokens, view the first element
inputs = tokenizer(data_in.sequence.to_list())["input_ids"]
print(inputs[0])

[1, 2061, 25, 222, 23, 224, 143, 3411, 403, 247, 53, 150, 527, 2759, 2834, 2734, 724, 873, 81, 118, 2470, 30, 708, 72, 61, 679, 29, 200, 88, 2894, 71, 117, 639, 72, 1478, 315, 137, 2787, 825, 1826, 966, 189, 1235, 45, 229, 4079, 314, 1340, 835, 427, 138, 316, 99, 120, 2139, 76, 36, 987, 75, 315, 8, 0, 41, 199, 0, 0, 5, 778, 460, 632, 59, 100, 72, 26, 0, 1212, 1527, 71, 148, 281, 0, 9, 3558, 238, 92, 635, 59, 111, 556, 2787, 135, 52, 259, 64, 72, 31, 120, 469, 2816, 50, 2638, 166, 29, 135, 1262, 31, 141, 17, 495, 1170, 317, 32, 443, 79, 78, 30, 619, 36, 247, 137, 1517, 2810, 19, 153, 1826, 20, 277, 1080, 332, 159, 15, 583, 458, 61, 783, 18, 486, 17, 540, 29, 200, 14, 183, 22, 236, 168, 37, 282, 3453, 71, 7, 0, 3386, 34, 123, 315, 103, 265, 194, 50, 42, 534, 171, 259, 166, 112, 2394, 200, 106, 59, 118, 1952, 409, 577, 117, 124, 1832, 0, 113, 205, 35, 553, 403, 38, 499, 16, 605, 788, 212, 8, 0, 0, 9, 3382, 169, 194, 233, 368, 38, 2785, 1149, 282, 1435, 66, 101, 39, 386, 8, 0, 9, 10, 846, 

In [24]:
# Set all random seeds for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
np.random.seed(42)

print("âœ“ Random seeds set for reproducibility")

âœ“ Random seeds set for reproducibility


In [26]:
# Helper function to count trainable parameters
def print_trainable_parameters(model):
    """Count and display trainable vs total parameters."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    
    trainable_pct = 100 * trainable_params / all_param
    print(f"  Trainable params: {trainable_params:,}")
    print(f"  Total params: {all_param:,}")
    print(f"  Trainable: {trainable_pct:.2f}%")

# Configure LoRA
lora_config = LoraConfig(
    r=8,                              # Rank (lower = fewer params)
    lora_alpha=32,                    # Scaling factor
    target_modules=["query", "value"], # Adapt attention layers
    lora_dropout=0.1,                 # Regularization
    bias="none",                      # Don't adapt bias terms
)

# Apply LoRA to the model
print("ðŸ”„ Applying LoRA adapters to model...\n")
ft_model = get_peft_model(model, lora_config)

print("ðŸ“Š Model Parameters Comparison:")
print("-" * 60)
print("Before LoRA (original model):")
print_trainable_parameters(model)
print("\nAfter LoRA (adapted model):")
print_trainable_parameters(ft_model)
print("-" * 60)
print("\nâœ“ LoRA adapters applied successfully!")

ðŸ”„ Applying LoRA adapters to model...

ðŸ“Š Model Parameters Comparison:
------------------------------------------------------------
Before LoRA (original model):
  Trainable params: 294,912
  Total params: 89,481,984
  Trainable: 0.33%

After LoRA (adapted model):
  Trainable params: 294,912
  Total params: 89,481,984
  Trainable: 0.33%
------------------------------------------------------------

âœ“ LoRA adapters applied successfully!


In [30]:
# Create the train-val-test splits
# Configuration
prop_train = 0.8
prop_val = 0.1
prop_test = 0.1
SEED = 42

# Create random splits
rng = np.random.default_rng(SEED)
random_idxs = rng.permutation(len(data_in))

# Calculate split sizes
n_total = len(data_in)
n_train = int(prop_train * n_total)
n_val = int(prop_val * n_total)

train_df = data_in.iloc[random_idxs[:n_train]]
val_df = data_in.iloc[random_idxs[n_train:n_train + n_val]]
test_df = data_in.iloc[random_idxs[n_train + n_val:]]

In [31]:
# Custom PyTorch Dataset for protein sequences
class SequenceDataset(Dataset):
    """
    Dataset class that tokenizes protein sequences on-the-fly.
    """
    def __init__(self, df, tokenizer, max_length=512):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sequence = row["sequence"]
        label = torch.tensor(row["species"], dtype=torch.float32)

        # Tokenize sequence
        inputs = self.tokenizer(
            sequence,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        # Remove batch dimension
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        return inputs, label

print("âœ“ SequenceDataset class defined")

âœ“ SequenceDataset class defined


In [32]:
# Create dataset instances
train_dataset = SequenceDataset(train_df, tokenizer)
val_dataset = SequenceDataset(val_df, tokenizer)
test_dataset = SequenceDataset(test_df, tokenizer)

print("âœ“ Datasets created:")
print(f"  Train: {len(train_dataset):,} samples")
print(f"  Val: {len(val_dataset):,} samples")
print(f"  Test: {len(test_dataset):,} samples")

âœ“ Datasets created:
  Train: 22,185 samples
  Val: 2,773 samples
  Test: 2,774 samples


In [None]:
# Training configuration
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# Create custom model with regression head
class ClassificationModel(nn.Module):
    """Combines LoRA-adapted DNA BERT model with classification head."""
    def __init__(self, base_model):
        super().__init__()
        self.base_model = model  # LoRA-wrapped model
        self.regressor = nn.Linear(base_model.config.hidden_size, 1)

    def forward(self, **inputs):
        outputs = self.base_model(**inputs)
        pooled_output = outputs.pooler_output
        return self.regressor(pooled_output).squeeze(-1)

# Instantiate the complete model
regression_model = RegressionModel(ft_model).to(device)

print("ðŸŽ¯ Complete Model Architecture:")
print("-" * 60)
print_trainable_parameters(regression_model)
print("-" * 60)

# Create data loaders
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
train_dataset = SequenceDataset(train_df, tokenizer)
val_dataset = SequenceDataset(val_df, tokenizer)
test_dataset = SequenceDataset(test_df, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Setup optimizer and loss
optimizer = AdamW(filter(lambda p: p.requires_grad, regression_model.parameters()), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

print("\nâœ“ Model and data loaders ready for training!")