In [None]:
# Core libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

# Transformers and PEFT
from transformers import AutoTokenizer, AutoModel, BertConfig
from peft import LoraConfig, get_peft_model

# Data processing and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# Utilities
import gc
from tqdm.auto import tqdm

In [2]:
config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_config(config)

'(MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /zhihan1996/DNABERT-2-117M/resolve/main/config.json (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x7f7ce2cfedb0>: Failed to resolve \'huggingface.co\' ([Errno -3] Temporary failure in name resolution)"))'), '(Request ID: 60fad852-f480-42dd-ba82-ff3b39c0645c)')' thrown while requesting HEAD https://huggingface.co/zhihan1996/DNABERT-2-117M/resolve/main/config.json
Retrying in 1s [Retry 1/5].


In [4]:
df = pd.read_csv('lossers/data/sequence-cleaner.tsv', sep='\t')
df.head()

Unnamed: 0,genus,species,sequence,identifier,is_complete,full_name,length,genus_label,species_label
0,Kitasatospora,hibisci,TTCACGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTG...,NR_200017.1,complete sequence,Kitasatospora_hibisci,1517,99,905
1,Peterkaempfera,podocarpi,TTCACGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTG...,NR_200001.1,complete sequence,Peterkaempfera_podocarpi,1516,166,1570
2,Streptomyces,citrinus,AGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTGCTTAACA...,NR_199987.1,partial sequence,Streptomyces_citrinus,1489,205,468
3,Dickeya,ananatis,AAATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGG...,NR_199979.1,complete sequence,Dickeya_ananatis,1542,63,130
4,Microbacterium,wangruii,AGAGTTTGATCATGGCTCAGGATGAACGCTGGCGGCGTGCTTAACA...,NR_199966.1,partial sequence,Microbacterium_wangruii,1487,130,2148


In [9]:
# Configure compute device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"âœ“ Using device: {device}")
if torch.cuda.is_available():
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
model.to(device)

âœ“ Using device: cuda
  GPU: NVIDIA A100-SXM4-40GB
  Memory: 42.29 GB


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(4096, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)

In [6]:
# Calculate total parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"âœ“ Model loaded successfully!")
print(f"  Total parameters: {total_params/1e6:.1f} million")
print(f"  Hidden size: {model.config.hidden_size}")
print(f"  Number of layers: {model.config.num_hidden_layers}")

âœ“ Model loaded successfully!
  Total parameters: 89.2 million
  Hidden size: 768
  Number of layers: 12


In [7]:
print("âœ“ Tokenizer loaded!")
print(f"  Vocabulary size: {len(tokenizer)}")
print("\nðŸ“– Tokenizer vocabulary:")
print(tokenizer.get_vocab())

âœ“ Tokenizer loaded!
  Vocabulary size: 4096

ðŸ“– Tokenizer vocabulary:
{'GTGAAAA': 1033, 'TCAGTTTT': 2800, 'TGAGTA': 724, 'CTTTAA': 598, 'TCATACA': 3490, 'GATTTGA': 3829, 'GTG': 30, 'CTATCTA': 3570, 'CGCAAA': 1223, 'CTTTTTTA': 3034, 'GGACTTA': 3753, 'GGGCTG': 905, 'GGAGATA': 3171, 'GAAGTCA': 1841, 'GCACCTT': 2582, 'TCATTATT': 3892, 'TGAAGTA': 2923, 'CCTTCTCC': 3281, 'TACGA': 1599, 'TATGTGA': 3988, 'CCTCTCTT': 3670, 'GCGAGA': 1158, 'TAATTAAAA': 2868, 'TCAGGAA': 1676, 'GGTGAA': 958, 'CTTCCTG': 1329, 'TTAGTG': 1529, 'GAAATGTG': 3756, 'TACTGTT': 3800, 'TGACATT': 1475, 'CCTGAAA': 1131, 'TCAGAAAA': 2696, 'TTTTATTA': 1849, 'CACAAAAA': 3137, 'GTAGGAA': 2499, 'TCAGGTG': 2334, 'TCCAAA': 388, 'AAATTA': 3090, 'TAAATTA': 673, 'CTTTTAA': 629, 'TCCTCTC': 3031, 'TACTGAAA': 3761, 'GCTTCAA': 3530, 'GTTTCA': 366, 'CATTCATT': 1962, 'GTACCA': 1222, 'TACAGTA': 1324, 'GTTCTTTT': 2843, 'GTAA': 68, 'TGAGTGA': 1596, 'GCGAGC': 3435, 'GTCTCGA': 3423, 'CTACTA': 869, 'TCTGGG': 2607, 'GCAGA': 390, 'CAAATTAA': 287

In [27]:
sequence = df['sequence'][1]
sequence

'TTCACGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTGCTTAACACATGCAAGTCGAACGGTGAAGCCCTTCGGGGTGGATCAGTGGCGAACGGGTGAGTAACACGTGGGCAATCTGCCCTGCACTCTGGGACAACACCGGGAAACCGGTGCTAATACCGGATATGACGCACTCCTGCATGGGGGTGCGTGGAAAGCTCCGGCGGTGCAGGATGAGCCCGCGGCCTATCAGCTTGTTGGTGGGGTGATGGCCTACCAAGGCGACGACGGGTAGCCGGCCTGAGAGGGCGACCGGCCACACTGGGACTGAGACACGGCCCAGACTCCTACGGGAGGCAGCAGTGGGGAATATTGCACAATGGGCGAAAGCCTGATGCAGCGACGCCGCGTGAGGGATGACGGCCTTCGGGTTGTAAACCTCTTTCAGCAGGGAAGAAGCGCAAGTGACGGTACCTGCAGAAGAAGCACCGGCTAACTACGTGCCAGCAGCCGCGGTAATACGTAGGGTGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTCGTAGGCGGCCTGTCGCGTCGGATGTGAAAGCCCGGGGCTTAACCCCGGGTCTGCATTCGATACGGGCAGGCTAGAGTTCGGTAGGGGAGATCGGAATTCCTGGTGTAGCGGTGAAATGCGCAGATATCAGGAGGAACACCGGTGGCGAAGGCGGATCTCTGGGCCGATACTGACGCTGAGGAGCGAAAGCGTGGGGAGCGAACAGGATTAGATACCCTGGTAGTCCACGCCGTAAACGTTGGGAACTAGGTGTGGGCGACATTCCACGTCGTCCGTGCCGCAGCTAACGCATTAAGTTCCCCGCCTGGGGAGTACGGCCGCAAGGCTAAAACTCAAAGGAATTGACGGGGGCCCGCACAAGCGGCGGAGCATGTGGCTTAATTCGACGCAACGCGAAGAACCTTACCAAGGCTTGACATACGCCGGAAAACCGTAGAGATACGGTCC

In [32]:
NSAMPLES = df.shape[0] // 3
SEED = 42
dataset = df.copy()


# Create random splits
rng = np.random.default_rng(SEED)
random_idxs = rng.choice(len(dataset), len(dataset), replace=False)

train_df = dataset.iloc[random_idxs[:NSAMPLES]]
val_df = dataset.iloc[random_idxs[NSAMPLES:2*NSAMPLES]]
test_df = dataset.iloc[random_idxs[2*NSAMPLES:3*NSAMPLES]]


In [36]:
from sklearn.model_selection import train_test_split

# Stratified train/val/test split preserving class balance

# Choose label to stratify by (change to 'genus_label' if you prefer)
label_col = 'genus_label'

# Desired proportions
test_prop = 0.2
val_prop = 0.2  # of the original dataset

# First split off test set
train_val_df, test_df = train_test_split(
    dataset,
    test_size=test_prop,
    stratify=dataset[label_col],
    random_state=SEED,
)

# Then split train and validation from the remaining portion
# val_prop relative to original -> compute relative val fraction on train_val_df
val_relative = val_prop / (1.0 - test_prop)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=val_relative,
    stratify=train_val_df[label_col],
    random_state=SEED,
)


# Inspect resulting class distributions
print("Split sizes:", len(train_df), len(val_df), len(test_df))
print("Train class distribution (example):")
print(train_df[label_col].value_counts().head(10))

Split sizes: 4067 1356 1356
Train class distribution (example):
genus_label
205    606
146    118
153    113
217     88
176     88
204     77
54      68
56      54
67      53
27      50
Name: count, dtype: int64


In [37]:
train_df['genus_label'].nunique(), test_df['genus_label'].nunique(), val_df['genus_label'].nunique()

(222, 222, 222)

In [43]:
# Custom PyTorch Dataset for protein sequences
class SequenceDataset(Dataset):
    """
    Dataset class that tokenizes protein sequences on-the-fly.
    """
    def __init__(self, df, tokenizer, max_length=512):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sequence = row["sequence"]
        label = torch.tensor(row["genus_label"], dtype=torch.float32)

        # Tokenize sequence
        inputs = self.tokenizer(
            sequence,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        # Remove batch dimension
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        return inputs, label

print("âœ“ SequenceDataset class defined")

âœ“ SequenceDataset class defined


In [44]:
# Create dataset instances
train_dataset = SequenceDataset(train_df, tokenizer)
val_dataset = SequenceDataset(val_df, tokenizer)
test_dataset = SequenceDataset(test_df, tokenizer)

print("âœ“ Datasets created:")
print(f"  Train: {len(train_dataset):,} samples")
print(f"  Val: {len(val_dataset):,} samples")
print(f"  Test: {len(test_dataset):,} samples")

âœ“ Datasets created:
  Train: 4,067 samples
  Val: 1,356 samples
  Test: 1,356 samples


In [48]:
# Configuration
BATCH_SIZE = 512
NUM_WORKERS = 12

print("ðŸ”„ Generating embeddings with frozen model...")
print("   This may take a few minutes...\n")

# Generate embeddings for training and test sets
train_embeddings = []
test_embeddings = []
val_embeddings = []

with torch.no_grad():
    # Training embeddings
    print("ðŸ“Š Processing training set...")
    for batch in tqdm(DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)):
        inputs, labels = batch
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
        train_embeddings.append(outputs.pooler_output.cpu().numpy())
    train_embeddings = np.vstack(train_embeddings)
    print(f"âœ“ Training embeddings: {train_embeddings.shape}\n")

    # Test embeddings
    print("ðŸ“Š Processing validation set...")
    for batch in tqdm(DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)):
        inputs, labels = batch
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
        val_embeddings.append(outputs.pooler_output.cpu().numpy())
    val_embeddings = np.vstack(val_embeddings)
    print(f"âœ“ Validation embeddings: {val_embeddings.shape}\n")

    # Test embeddings
    print("ðŸ“Š Processing test set...")
    for batch in tqdm(DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)):
        inputs, labels = batch
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
        test_embeddings.append(outputs.pooler_output.cpu().numpy())
    test_embeddings = np.vstack(test_embeddings)
    print(f"âœ“ Test embeddings: {test_embeddings.shape}\n")

print("âœ“ All embeddings generated!")

ðŸ”„ Generating embeddings with frozen model...
   This may take a few minutes...

ðŸ“Š Processing training set...


  0%|          | 0/8 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling pa

âœ“ Training embeddings: (4067, 768)

ðŸ“Š Processing validation set...


  0%|          | 0/3 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENI

âœ“ Validation embeddings: (1356, 768)

ðŸ“Š Processing test set...


  0%|          | 0/3 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling pa

âœ“ Test embeddings: (1356, 768)

âœ“ All embeddings generated!



