# Cell 0: Install required packages


In [7]:
!pip install numpy scipy gensim scikit-learn transformers biopython torch



In [14]:
import numpy as np
from sklearn.neighbors import KDTree

def generate_query_subseq_embeddings(query, window_size, step_size, embed_func):
    subseqs = [query[i:i+window_size] for i in range(0, len(query)-window_size+1, step_size)]
    return [embed_func(subseq) for subseq in subseqs]

def find_seeds(query_embeddings, database_embeddings, threshold):
    db_ids, db_embeds = zip(*database_embeddings.items())
    db_embeds = np.array(db_embeds)
    if db_embeds.ndim == 1:
        db_embeds = db_embeds.reshape(-1, 1)
    tree = KDTree(db_embeds)
    seeds = []
    for i, q_embed in enumerate(query_embeddings):
        distances, indices = tree.query([q_embed], k=5)  # Get top 5 closest matches
        for dist, idx in zip(distances[0], indices[0]):
            if dist < threshold:
                seeds.append((i, db_ids[idx], dist))
    return seeds

def extend_seed(query, db_seq, seed_pos, gap_open=-10, gap_extend=-1):
    # Simple extension algorithm (can be improved)
    left_score, left_query, left_db = 0, "", ""
    right_score, right_query, right_db = 0, "", ""

    # Extend left
    i, j = seed_pos, 0
    while i > 0 and j > 0:
        if query[i-1] == db_seq[j-1]:
            left_score += 1
            left_query = query[i-1] + left_query
            left_db = db_seq[j-1] + left_db
            i -= 1
            j -= 1
        else:
            break

    # Extend right
    i, j = seed_pos, 0
    while i < len(query) and j < len(db_seq):
        if query[i] == db_seq[j]:
            right_score += 1
            right_query += query[i]
            right_db += db_seq[j]
            i += 1
            j += 1
        else:
            break

    return left_query + right_query, left_db + right_db, left_score + right_score

def evaluate_alignment(query_seq, db_seq, gap_open=-10, gap_extend=-1):
    # Simple scoring function (can be improved)
    score = sum(1 if q == d else -1 for q, d in zip(query_seq, db_seq))
    gaps = query_seq.count('-') + db_seq.count('-')
    score += gap_open * gaps + gap_extend * (len(query_seq) - len(db_seq.replace('-', '')))
    return score

def embedding_based_blast(query, database, database_embeddings, embed_func, window_size=3, step_size=1, threshold=0.1):
    query_embeddings = generate_query_subseq_embeddings(query, window_size, step_size, embed_func)
    seeds = find_seeds(query_embeddings, database_embeddings, threshold)

    alignments = []
    for seed in seeds:
        query_pos, db_id, dist = seed
        db_seq = database[db_id]
        extended_query, extended_db, score = extend_seed(query, db_seq, query_pos)
        alignment_score = evaluate_alignment(extended_query, extended_db)
        alignments.append((db_id, extended_query, extended_db, alignment_score))

    return sorted(alignments, key=lambda x: x[3], reverse=True)

In [11]:
import numpy as np
from transformers import BertModel, BertTokenizer
from gensim.models import Word2Vec
from Bio import SeqIO

def generate_protein_embeddings(sequence, model_type='ProtBERT'):
    if model_type == 'ProtBERT':
        tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        model = BertModel.from_pretrained("Rostlab/prot_bert")
        inputs = tokenizer(sequence, return_tensors='pt')
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().detach().numpy()
    elif model_type == 'ProtVec':
        # Assuming ProtVec model is pre-trained and loaded
        protvec_model = Word2Vec.load('path_to_protvec_model')
        words = [sequence[i:i+3] for i in range(len(sequence)-2)]
        embeddings = np.mean([protvec_model.wv[word] for word in words if word in protvec_model.wv], axis=0)
    return embeddings

def generate_nucleotide_embeddings(sequence, model_type='DNA2Vec'):
    if model_type == 'DNA2Vec':
        # Assuming DNA2Vec model is pre-trained and loaded
        dna2vec_model = Word2Vec.load('path_to_dna2vec_model')
        words = [sequence[i:i+11] for i in range(len(sequence)-10)]
        embeddings = np.mean([dna2vec_model.wv[word] for word in words if word in dna2vec_model.wv], axis=0)
    return embeddings

def generate_database_embeddings(fasta_file, sequence_type='protein'):
    embeddings = {}
    for record in SeqIO.parse(fasta_file, "fasta"):
        if sequence_type == 'protein':
            embeddings[record.id] = generate_protein_embeddings(str(record.seq))
        else:
            embeddings[record.id] = generate_nucleotide_embeddings(str(record.seq))
    return embeddings

In [12]:
import time
from Bio import SeqIO
from Bio.Blast import NCBIWWW, NCBIXML

def run_ncbi_blast(query, database_file):
    start_time = time.time()
    result_handle = NCBIWWW.qblast("blastp", "nr", query, entrez_query=f'"{database_file}"[PACC]')
    blast_record = NCBIXML.read(result_handle)
    end_time = time.time()

    alignments = []
    for alignment in blast_record.alignments:
        for hsp in alignment.hsps:
            alignments.append((alignment.title, hsp.query, hsp.sbjct, hsp.score))

    return alignments, end_time - start_time

def run_embedding_blast(query, database_file):
    start_time = time.time()
    database = {record.id: str(record.seq) for record in SeqIO.parse(database_file, "fasta")}
    database_embeddings = generate_database_embeddings(database_file, "protein")
    results = embedding_based_blast(query, database, database_embeddings, generate_protein_embeddings)
    end_time = time.time()

    return results, end_time - start_time

def compare_results(ncbi_results, embedding_results):
    ncbi_hits = set(result[0] for result in ncbi_results)
    embedding_hits = set(result[0] for result in embedding_results)

    common_hits = ncbi_hits.intersection(embedding_hits)
    ncbi_only = ncbi_hits - embedding_hits
    embedding_only = embedding_hits - ncbi_hits

    print(f"Common hits: {len(common_hits)}")
    print(f"NCBI BLAST only: {len(ncbi_only)}")
    print(f"Embedding BLAST only: {len(embedding_only)}")

    return len(common_hits) / len(ncbi_hits) if ncbi_hits else 0

In [16]:
import os

def main():
    # Load the query sequence from an environment variable or use a default
    query = os.environ.get('QUERY_SEQUENCE', "MVLSEGEWQLVLHVWAKVEADVAGHGQDILIRLFKSHPETLEKFDRFKHLKTEAEMKASEDLKKHGVTVLTALGAILKKKGHHEAELKPLAQSHATKHKIPIKYLEFISEAIIHVLHSRHPGNFGADAQGAMNKALELFRKDIAAKYKELGYQG")

    # Use the sample database provided in the Docker image
    database_file = "data/sample_database.fasta"

    print("Running Embedding-based BLAST...")
    embedding_results, embedding_time = run_embedding_blast(query, database_file)
    print(f"Embedding-based BLAST completed in {embedding_time:.2f} seconds")

    print("\nRunning NCBI BLAST...")
    ncbi_results, ncbi_time = run_ncbi_blast(query, database_file)
    print(f"NCBI BLAST completed in {ncbi_time:.2f} seconds")

    print("\nComparing results...")
    sensitivity = compare_results(ncbi_results, embedding_results)
    print(f"Sensitivity (proportion of NCBI BLAST hits found): {sensitivity:.2f}")

    print(f"\nSpeed comparison: Embedding-based BLAST was {ncbi_time / embedding_time:.2f}x faster")

if __name__ == "__main__":
    main()

Running Embedding-based BLAST...
Embedding-based BLAST completed in 363.72 seconds

Running NCBI BLAST...
NCBI BLAST completed in 63.49 seconds

Comparing results...
Common hits: 0
NCBI BLAST only: 0
Embedding BLAST only: 5
Sensitivity (proportion of NCBI BLAST hits found): 0.00

Speed comparison: Embedding-based BLAST was 0.17x faster
