In [None]:
!pip install fair-esm
!pip install torch

In [1]:
import requests
from Bio import AlignIO, Phylo, SeqIO
from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
import subprocess
import tempfile
# from google.colab import files
import matplotlib.pyplot as plt
import os
from io import StringIO
import urllib.parse

In [2]:
import esm
import torch 

In [3]:
# Pretrained ESM-2 checkpoint (its name indicates 12 layers, ~35M parameters) plus its token alphabet.
model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
model = model.eval()  # inference mode (disables dropout); eval() returns the same module
batch_converter = alphabet.get_batch_converter()  # turns (label, sequence) pairs into token tensors

In [4]:
def fetch_uniprot_sequences(keyword, max_sequences=50, output_path=None, timeout=60):
    """Download sequences from UniProt for a given query and save them as FASTA.

    Parameters
    ----------
    keyword : str
        UniProt query string, e.g. ``"protein_name:globin AND reviewed:true"``.
        It is percent-encoded before being placed in the URL.
    max_sequences : int
        Value sent as the ``size`` parameter of the search request.
    output_path : str or None
        Where to write the FASTA text. Defaults to ``output.fasta`` in the
        current working directory (overwritten on every call).
    timeout : float
        Seconds to wait for the HTTP request before giving up.

    Returns
    -------
    str
        Path of the FASTA file that was written.

    Raises
    ------
    ValueError
        If the request returns a non-200 status or an empty body.
    """
    encoded_keyword = urllib.parse.quote(keyword)
    url = (
        "https://rest.uniprot.org/uniprotkb/search"
        f"?query={encoded_keyword}&format=fasta&size={max_sequences}"
    )

    print(f"Fetching from URL: {url}")
    # Without a timeout, requests can block indefinitely on a stalled connection.
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        print(f"HTTP Status Code: {response.status_code}")
        print(f"Response content: {response.text[:200]}")  # first 200 chars, for debugging
        raise ValueError(f"Failed to fetch sequences: Status code {response.status_code}")

    if len(response.text.strip()) == 0:
        raise ValueError("No sequences found in the response")

    if output_path is None:
        output_path = os.path.join(os.getcwd(), "output.fasta")
    # Explicit UTF-8 so the written file does not depend on the platform's default encoding.
    with open(output_path, mode="w", encoding="utf-8") as fasta_file:
        fasta_file.write(response.text)
    return output_path

def obtain_fasta_file(keyword, title, max_sequences=50):
    """Download sequences for ``keyword``, keep a titled copy, and print a summary.

    Parameters
    ----------
    keyword : str
        UniProt query string forwarded to ``fetch_uniprot_sequences``.
    title : str
        Human-readable name. Spaces are replaced by underscores to build the
        backup file name ``<title>_sequences.fasta``.
    max_sequences : int
        Maximum number of entries requested from UniProt.

    Returns
    -------
    str or None
        Path of the downloaded FASTA file, or ``None`` when fewer than three
        sequences were retrieved (too few for downstream analysis).
        Callers must check for ``None`` before opening the file.
    """
    import shutil

    print(f"Downloading sequences for {keyword}...")
    fasta_file = fetch_uniprot_sequences(keyword, max_sequences=max_sequences)

    # Keep a copy under a descriptive name; the fetched file is rewritten on every download.
    original_sequences_file = f"{title.replace(' ', '_')}_sequences.fasta"
    shutil.copyfile(fasta_file, original_sequences_file)

    sequences = list(SeqIO.parse(fasta_file, "fasta"))
    print(f"\nNumber of sequences downloaded: {len(sequences)}")
    print("\nSequence IDs:")
    for record in sequences:
        # The FASTA description already starts with the record id, so print it alone.
        print(record.description)

    # On Colab: files.download(original_sequences_file)
    print(f"\nOriginal sequences saved to: {original_sequences_file}")

    if len(sequences) < 3:
        print("Not enough sequences found for analysis")
        return None

    return fasta_file
    
protein_families = [
    ("protein_name:globin AND reviewed:true", "Globin Family Tree"),
    ("protein_name:insulin receptor kinase AND reviewed:true", "Insulin Receptor Kinase Tree"),
    ("gene:RAS AND reviewed:true", "Ras Family Tree"),
]

# Start with a single family to validate the pipeline end to end.
keyword, title = protein_families[0]

fasta_file = obtain_fasta_file(keyword, title)
if fasta_file is None:
    # obtain_fasta_file returns None when fewer than 3 sequences were downloaded;
    # fail with a clear message instead of a TypeError from open(None).
    raise RuntimeError(f"Not enough sequences retrieved for {title!r}; cannot compute embeddings")

# (sequence_id, sequence) pairs: the input format the ESM batch converter takes.
with open(fasta_file, "r") as handle:
    sequences = [(record.id, str(record.seq)) for record in SeqIO.parse(handle, "fasta")]

# Convert sequences to ESM-2 token tensors.
# NOTE(review): all sequences go through the model as one padded batch; for families with
# many or long sequences, consider chunking to limit memory use.
batch_labels, batch_strs, batch_tokens = batch_converter(sequences)

# Read the model's final layer rather than hard-coding 12, so swapping ESM-2 checkpoints keeps working.
final_layer = model.num_layers

with torch.no_grad():
    results = model(batch_tokens, repr_layers=[final_layer])
    token_embeddings = results["representations"][final_layer]  # (batch_size, num_tokens, embedding_dim)

# Mean-pool residue tokens into one vector per sequence: position 0 is the prepended
# [CLS] token, and slicing to len(seq) + 1 drops padding and trailing special tokens.
embeddings = [
    token_embeddings[i, 1:len(seq) + 1].mean(0).numpy()
    for i, (_, seq) in enumerate(sequences)
]

# Summarize instead of dumping every full-length vector into the notebook output.
for (seq_id, _), embedding in zip(sequences, embeddings):
    print(f"Embedding for sequence {seq_id}: shape={embedding.shape}, first values={embedding[:5]}")
print(len(embeddings))

Downloading sequences for protein_name:globin AND reviewed:true...
Fetching from URL: https://rest.uniprot.org/uniprotkb/search?query=protein_name%3Aglobin%20AND%20reviewed%3Atrue&format=fasta&size=50

Number of sequences downloaded: 50

Sequence IDs:
sp|P02100|HBE_HUMAN: sp|P02100|HBE_HUMAN Hemoglobin subunit epsilon OS=Homo sapiens OX=9606 GN=HBE1 PE=1 SV=2
sp|P02042|HBD_HUMAN: sp|P02042|HBD_HUMAN Hemoglobin subunit delta OS=Homo sapiens OX=9606 GN=HBD PE=1 SV=2
sp|P02008|HBAZ_HUMAN: sp|P02008|HBAZ_HUMAN Hemoglobin subunit zeta OS=Homo sapiens OX=9606 GN=HBZ PE=1 SV=2
sp|P69905|HBA_HUMAN: sp|P69905|HBA_HUMAN Hemoglobin subunit alpha OS=Homo sapiens OX=9606 GN=HBA1 PE=1 SV=2
sp|P09105|HBAT_HUMAN: sp|P09105|HBAT_HUMAN Hemoglobin subunit theta-1 OS=Homo sapiens OX=9606 GN=HBQ1 PE=1 SV=2
sp|Q12800|TFCP2_HUMAN: sp|Q12800|TFCP2_HUMAN Alpha-globin transcription factor CP2 OS=Homo sapiens OX=9606 GN=TFCP2 PE=1 SV=2
sp|P68871|HBB_HUMAN: sp|P68871|HBB_HUMAN Hemoglobin subunit beta OS=Homo sapi