## 📦 Step 1: Setup and Imports

We'll use:
- **`transformers`**: Load the pre-trained DNABERT-2 model and tokenizer from Hugging Face
- **`peft`**: Apply LoRA adapters for parameter-efficient fine-tuning
- **`torch`**: Deep learning framework for training
- **`scikit-learn`**: Metrics and baseline models

In [1]:
import torch
from transformers import AutoTokenizer, AutoModel, BertConfig

# Data processing and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity

print("✓ All libraries imported successfully!")

✓ All libraries imported successfully!


In [2]:
# Configure compute device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✓ Using device: {device}")
if torch.cuda.is_available():
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

✓ Using device: cuda
  GPU: NVIDIA A100-SXM4-40GB
  Memory: 42.29 GB


## 🤖 Step 2: Load Pre-trained DNABERT-2 Model

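We'll use DNABERT-2 (`zhihan1996/DNABERT-2-117M`), a BERT-style model for DNA sequences that uses a byte-pair-encoding (BPE) tokenizer. Note that the cell below builds the model with `AutoModel.from_config`, which creates the architecture from the checkpoint's configuration with randomly initialized weights; it does not load the pre-trained weights.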

In [3]:
# Model checkpoint on the Hugging Face Hub
MODEL_CHECKPOINT = "zhihan1996/DNABERT-2-117M"

# Read the architecture hyperparameters (hidden size, layer count, ...) from the checkpoint
config = BertConfig.from_pretrained(MODEL_CHECKPOINT)

print(f"📥 Loading model: {MODEL_CHECKPOINT}")
# NOTE: from_config builds the architecture with randomly initialized weights;
# it does not load the pre-trained weights stored in the checkpoint.
model = AutoModel.from_config(config)
model.eval()  # Evaluation mode: disables dropout

# Calculate total parameters
total_params = sum(p.numel() for p in model.parameters())
print("✓ Model loaded successfully!")
print(f"  Total parameters: {total_params/1e6:.1f} million")
print(f"  Hidden size: {model.config.hidden_size}")
print(f"  Number of layers: {model.config.num_hidden_layers}")

📥 Loading model: zhihan1996/DNABERT-2-117M
✓ Model loaded successfully!
  Total parameters: 89.2 million
  Hidden size: 768
  Number of layers: 12
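
The 89.2 million parameters printed above can be checked by hand. The sketch below tallies the parameters of a plain BERT encoder with a pooling layer, using the hidden size (768) and layer count (12) printed by the notebook and the tokenizer's vocabulary size (4,096). The feed-forward width (3,072), maximum position count (512), and token-type vocabulary (2) are assumed BERT-base defaults rather than values shown in this notebook, and `bert_param_count` is a hypothetical helper written for this illustration.

```python
def bert_param_count(vocab_size, hidden, layers, intermediate,
                     max_positions, type_vocab):
    """Count parameters of a plain BERT encoder with a pooler (no LM head)."""
    layer_norm = 2 * hidden  # weight + bias

    embeddings = (
        vocab_size * hidden        # token embeddings
        + max_positions * hidden   # position embeddings
        + type_vocab * hidden      # token-type (segment) embeddings
        + layer_norm
    )

    # Query, key, value, and output projections, each with a bias
    attention = 4 * (hidden * hidden + hidden) + layer_norm
    feed_forward = (
        hidden * intermediate + intermediate   # up projection
        + intermediate * hidden + hidden       # down projection
        + layer_norm
    )
    pooler = hidden * hidden + hidden

    return embeddings + layers * (attention + feed_forward) + pooler


total = bert_param_count(
    vocab_size=4096, hidden=768, layers=12,              # from the notebook outputs
    intermediate=3072, max_positions=512, type_vocab=2,  # assumed BERT-base defaults
)
print(f"{total:,} parameters ({total / 1e6:.1f} million)")
```

Under these assumptions the tally comes to 89,187,072, which rounds to the 89.2 million reported above.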


In [4]:
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, trust_remote_code=True)

print("✓ Tokenizer loaded!")
print(f"  Vocabulary size: {len(tokenizer)}")
print("\n📖 First 10 keys in Tokenizer vocabulary:")

# get_vocab() returns a token -> ID dict; show the first 10 tokens
for token in list(tokenizer.get_vocab())[:10]:
    print(token)

✓ Tokenizer loaded!
  Vocabulary size: 4096

📖 First 10 keys in Tokenizer vocabulary:
TATTTA
GGTTATT
GTGATT
CACATTTT
GAATATA
CTTCAAA
CCAAGG
TATTTATTTT
GTCGTG
TCCCCAA


## Step 3: Load and Prepare Dataset

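Load the tab-separated sequence table. Each row holds one nucleotide sequence along with its genus, species, accession identifier, and a completeness label.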

In [5]:
# Load the cleaned dataset
DATASET_NAME = "data/sequence-wide.tsv"

print(f"📂 Loading dataset: {DATASET_NAME}")
dataset = pd.read_csv(DATASET_NAME, sep='\t')

print("✓ Dataset loaded successfully!")
print(f"  Total samples: {len(dataset):,}")
print("\nFirst few rows:")
dataset.head()

📂 Loading dataset: data/sequence-wide.tsv
✓ Dataset loaded successfully!
  Total samples: 27,732

First few rows:


| Unnamed: 0 | genus | species | sequence | identifier | is_complete |
| --- | --- | --- | --- | --- | --- |
| 0 | Alitibacter | langaaensis | ATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCT... | NR_118751.1 | partial sequence |
| 1 | Alitibacter | langaaensis | ATTGAACGCTGGCGGCAGGCTTAACACATGCAAGTCGAACGGTAAC... | NR_042885.1 | partial sequence |
| 2 | Roseovarius | maritimus | CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG... | NR_200035.1 | complete sequence |
| 3 | Roseovarius | roseus | CAACTTGAGAGTTTGATCCTGGCTCAGAACGAACGCTGGCGGCAGG... | NR_200034.1 | complete sequence |
| 4 | Planosporangium | spinosum | TTGTTGGAGAGTTTGATCCTGGCTCAGGACGAACGCTGGCGGCGTG... | NR_200033.1 | complete sequence |
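
Before tokenizing, it can be worth checking that the sequences contain only the four canonical bases, since sequence records sometimes include ambiguity codes such as `N`. The sketch below is one possible sanity check, not part of the original notebook. `non_canonical_counts` is a hypothetical helper and the two example strings are made up; in the notebook the inputs would come from `dataset["sequence"]`.

```python
from collections import Counter

CANONICAL = set("ACGT")


def non_canonical_counts(sequence):
    """Return a Counter of characters outside A, C, G, T (case-insensitive)."""
    return Counter(ch for ch in sequence.upper() if ch not in CANONICAL)


# Hypothetical examples for illustration only
examples = {
    "clean": "ATTGAAGAGTTTGATCATGG",
    "ambiguous": "ATTGNAGAGTTYGATCATGG",
}

for name, seq in examples.items():
    extras = non_canonical_counts(seq)
    print(f"{name}: {dict(extras) if extras else 'only A/C/G/T'}")
```

An empty counter means the sequence uses only `A`, `C`, `G`, and `T`; anything else lists the unexpected characters and how often they occur.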


## Step 4: Tokenization Example

Let's tokenize a sample sequence to see the tokenizer in action.

In [6]:
# Sample DNA sequence: Alitibacter langaaensis NR_118751.1
sequence = dataset["sequence"][0]

print("Sample sequence (first row of dataset):")
print(f"  Identifier: {dataset['identifier'][0]}")
print(f"  Length: {len(sequence)} nucleotides")
print(f"  Sequence: {sequence[:50]}...")

Sample sequence (first row of dataset):
  Identifier: NR_118751.1
  Length: 1477 nucleotides
  Sequence: ATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCTTAAC...


In [7]:
# Tokenize the sequence
inputs = tokenizer(sequence, return_tensors="pt").to(device)

print("✓ Sequence tokenized!")

✓ Sequence tokenized!


In [8]:
# Examine tokenized output
print(f"Original sequence length: {len(sequence)} nucleotides")
print(f"Tokenized input IDs shape: {inputs['input_ids'].shape}")
print(f"Token IDs: {inputs['input_ids'][0][:20]}...")  # Show first 20 tokens

print("\n💡 Why is tokenized length different?")
print("   BPE merges several nucleotides into one token, and the tokenizer adds special tokens like [CLS] (start) and [SEP] (end)!")

# DNABERT-2 uses BPE tokenization: tokens are variable-length groups of nucleotides, not single bases

Original sequence length: 1477 nucleotides
Tokenized input IDs shape: torch.Size([1, 335])
Token IDs: tensor([   1, 2061,   25,  222,   23,  224,  143, 3411,  403,  247,   53,  150,
         527, 2759, 2834, 2734,  724,  873,   81,  118], device='cuda:0')...

💡 Why is tokenized length different?
   BPE merges several nucleotides into one token, and the tokenizer adds special tokens like [CLS] (start) and [SEP] (end)!
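
The drop from 1,477 characters to 335 tokens (roughly 4.4 nucleotides per token, if two of the tokens are special tokens) comes mostly from BPE merging frequent runs of bases into single tokens. The sketch below is a deliberately simplified BPE trainer that learns merges from one short string, here the first 50 bases of the sample sequence. It is not the DNABERT-2 tokenizer, which was trained on a large corpus, and `train_bpe` is a hypothetical helper.

```python
from collections import Counter


def train_bpe(text, num_merges):
    """Learn BPE merges on a single string, starting from single characters."""
    tokens = list(text)
    merges = []
    for _ in range(num_merges):
        pairs = Counter(zip(tokens, tokens[1:]))
        if not pairs:
            break
        # Most frequent adjacent pair; ties broken alphabetically for determinism
        best = min(pairs, key=lambda p: (-pairs[p], p))
        if pairs[best] < 2:
            break  # nothing repeats, so further merges would not compress
        merges.append(best)

        # Rewrite the token list, fusing each occurrence of the chosen pair
        merged, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == best:
                merged.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens, merges


toy_sequence = "ATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCTTAAC"
tokens, merges = train_bpe(toy_sequence, num_merges=10)
print(f"{len(toy_sequence)} nucleotides -> {len(tokens)} tokens")
print(tokens)
```

Each merge replaces a frequent adjacent pair with one longer token, so the token count shrinks while the concatenated tokens still spell out the original sequence.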


## 🎯 Step 5: Zero-Shot Prediction

DNABERT-2 was pre-trained with masked language modeling (predicting masked nucleotide tokens), and embeddings from such a model can be used **directly** for downstream tasks without additional training. This is called **zero-shot learning**.

**Use case**: Find the reference microbe whose sequence is most similar to a query sequence, based on embedding similarity.

In [9]:
# Reference database of known microbes with different genus/species
reference = {
    "Alitibacter langaaensis (NR_118751.1)": 
        dataset["sequence"][0],
    
    "Alitibacter langaaensis (NR_042885.1)": 
        dataset["sequence"][1],
    
    "Roseovarius maritimus (NR_200035.1)": 
        dataset["sequence"][2]
}

print(f"✓ Reference database created with {len(reference)} sequences")

✓ Reference database created with 3 sequences


In [10]:
# Query microbe - which reference microbe is it most similar to?
query = {
    "Roseovarius roseus (NR_200034.1)": 
        dataset["sequence"][3]
}

print(f"🔍 Query microbe: {list(query.keys())[0]}")

🔍 Query microbe: Roseovarius roseus (NR_200034.1)


In [11]:
# Generate embeddings for reference database
print("🔄 Generating embeddings for reference database...")
reference_embeddings = {}
with torch.no_grad():
    for name, seq in reference.items():
        ref_inputs = tokenizer(seq, return_tensors="pt")["input_ids"]
        ref_outputs = model(ref_inputs)[0]  # Last hidden states: (1, num_tokens, hidden_size)
        # Mean-pool over tokens to get one fixed-size vector per sequence
        reference_embeddings[name] = torch.mean(ref_outputs[0], dim=0).numpy().reshape(1,-1)
print(f"✓ Generated {len(reference_embeddings)} reference embeddings")

# Generate embedding for query microbe (same mean pooling as above)
query_name, query_sequence = next(iter(query.items()))
print(f"\n🔄 Generating embedding for query: {query_name}...")
with torch.no_grad():
    query_inputs = tokenizer(query_sequence, return_tensors="pt")["input_ids"]
    query_outputs = model(query_inputs)[0]
    query_embedding = torch.mean(query_outputs[0], dim=0).numpy().reshape(1,-1)
print("✓ Query embedding generated")

🔄 Generating embeddings for reference database...
✓ Generated 3 reference embeddings

🔄 Generating embedding for query: Roseovarius roseus (NR_200034.1)...
✓ Query embedding generated
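
The cell above averages over every token position, which works because each sequence is encoded on its own with no padding. If several sequences were batched and padded to a common length, the padded positions would need to be left out of the average. The sketch below illustrates that idea with plain Python lists standing in for tensors. `masked_mean_pool` is a hypothetical helper and the numbers are toy values, not model outputs.

```python
def masked_mean_pool(token_vectors, attention_mask):
    """Average the token vectors whose mask entry is 1, ignoring padded positions."""
    kept = [vec for vec, keep in zip(token_vectors, attention_mask) if keep]
    if not kept:
        raise ValueError("attention_mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


# Toy hidden states: 4 token positions, 3 dimensions; the last position is padding
hidden_states = [
    [1.0, 2.0, 3.0],
    [3.0, 2.0, 1.0],
    [2.0, 2.0, 2.0],
    [9.0, 9.0, 9.0],  # padding vector that should not affect the result
]
mask = [1, 1, 1, 0]

pooled_masked = masked_mean_pool(hidden_states, mask)            # padding excluded
pooled_unmasked = masked_mean_pool(hidden_states, [1, 1, 1, 1])  # padding included
print(pooled_masked)
print(pooled_unmasked)
```

With the mask applied the pooled vector is `[2.0, 2.0, 2.0]`; without it the padding vector pulls every dimension up to `3.75`.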


In [12]:
# Calculate cosine similarity between query and each reference
similarities = {}
for name, ref_emb in reference_embeddings.items():
    similarity = cosine_similarity(query_embedding, ref_emb)[0][0]
    similarities[name] = similarity

# Find best match
best_match = max(similarities, key=similarities.get)

print("=" * 70)
print("🔍 ZERO-SHOT SIMILARITY SEARCH RESULTS")
print("=" * 70)
print(f"\nQuery Microbe: {query_name}\n")
print("Similarity Scores (sorted by relevance):")
print("-" * 70)
for name, score in sorted(similarities.items(), key=lambda item: item[1], reverse=True):
    bar = "█" * int(score * 50)
    print(f"{name:45s} {score:.4f} {bar}")

print("\n" + "=" * 70)
print(f"🎯 Best Match: {best_match}")
print(f"   Similarity: {similarities[best_match]:.4f}")
print("=" * 70)

🔍 ZERO-SHOT SIMILARITY SEARCH RESULTS

Query Microbe: Roseovarius roseus (NR_200034.1)

Similarity Scores (sorted by relevance):
----------------------------------------------------------------------
Roseovarius maritimus (NR_200035.1)           0.9998 █████████████████████████████████████████████████
Alitibacter langaaensis (NR_042885.1)         0.9977 █████████████████████████████████████████████████
Alitibacter langaaensis (NR_118751.1)         0.9977 █████████████████████████████████████████████████

🎯 Best Match: Roseovarius maritimus (NR_200035.1)
   Similarity: 0.9998


In [13]:
similarities.items()

dict_items([('Alitibacter langaaensis (NR_118751.1)', np.float32(0.9976617)), ('Alitibacter langaaensis (NR_042885.1)', np.float32(0.997679)), ('Roseovarius maritimus (NR_200035.1)', np.float32(0.9997591))])
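
Cosine similarity compares only the direction of two vectors: $\cos(u, v) = \dfrac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}$, which is 1 for vectors pointing the same way and 0 for orthogonal ones. The sketch below computes it by hand on made-up 3-dimensional vectors to show what the scores mean. The notebook itself uses `sklearn.metrics.pairwise.cosine_similarity`, and `cosine` here is a hypothetical stand-in.

```python
import math


def cosine(u, v):
    """Cosine similarity: dot(u, v) / (|u| * |v|)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)


# Hypothetical 3-dimensional "embeddings" for illustration only
query_vec = [1.0, 2.0, 2.0]
same_direction = [2.0, 4.0, 4.0]   # scaled copy of the query
orthogonal = [2.0, -1.0, 0.0]      # dot product with the query is zero

print(round(cosine(query_vec, same_direction), 6))
print(round(cosine(query_vec, orthogonal), 6))
```

In the results above all three scores exceed 0.997, so the ranking rests on differences in the third decimal place.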