In [None]:
import sys
from loguru import logger

import io
import os
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
import matplotlib.animation as animation

from pyeed import Pyeed
from pyeed.analysis.mutation_detection import MutationDetection
from pyeed.analysis.embedding_analysis import EmbeddingTool
from pyeed.analysis.standard_numbering import StandardNumberingTool
from pyeed.embeddings.processor import get_processor
from pyeed.embeddings import free_memory

# Reset loguru's handlers, then log to stderr at DEBUG verbosity.
logger.remove()
# NOTE(review): despite its name, `level` receives whatever logger.add() returns
# (per the loguru docs, the id of the newly added handler), not a log level.
level = logger.add(sys.stderr, level="DEBUG")

In [2]:
# Instantiate pyeed's EmbeddingTool.
# NOTE(review): `et` is not referenced in the cells visible here — confirm
# whether a later cell still needs it.
et = EmbeddingTool()

In [None]:
# Connection settings for the Neo4j instance behind pyeed.
# Values are read from the environment so credentials do not have to live in the
# notebook; the fallbacks reproduce the original hardcoded setup.
# NOTE(review): the fallback password is still committed here — rotate it and
# drop the defaults once NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD are set everywhere.
uri = os.environ.get("NEO4J_URI", "bolt://129.69.129.130:7688")
user = os.environ.get("NEO4J_USER", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "12345678")

eedb = Pyeed(uri, user=user, password=password)

# WARNING: destructive. This wipes the database before the constraints are
# re-initialised, so every Run-All starts from an empty graph.
# NOTE(review): the `date` argument is passed through unchanged; its exact
# semantics are defined by pyeed — confirm before changing it.
eedb.db.wipe_database(date="2025-05-30")
eedb.db.initialize_db_constraints(user, password)

In [None]:
# Load the TEM lactamase table (semicolon-separated) and preview the first rows.
df = pd.read_csv('/home/nab/Niklas/TEM-lactamase/data/002_combined_data/TEM_lactamase.csv', sep=';')
print(df.head())
# Example of the expected layout (names and database ids per protein):
#   protein_name phenotype    protein_id protein_id_database
# 0        TEM-1        2b      AAP20891          AAP20891.1

# Only rows that carry a database accession are used; ids and names are taken
# from the same filtered rows so they stay index-aligned.
has_database_id = df['protein_id_database'].notna()
rows_with_id = df.loc[has_database_id, ['protein_id_database', 'protein_name']]
ids = rows_with_id['protein_id_database'].tolist()
names = rows_with_id['protein_name'].tolist()
print(f"IDs: {ids}")
print(f"Names: {names}")

# Pull the proteins from NCBI into the database, then (judging by the pyeed
# method names) the matching DNA entries and coding-sequence regions.
eedb.fetch_from_primary_db(ids, db="ncbi_protein")
eedb.fetch_dna_entries_for_proteins()
eedb.create_coding_sequences_regions()

In [5]:
# Number of leading residues dropped from every fetched protein sequence: the
# query cell slices each sequence with [offset_signal:]. 0 keeps full sequences.
# NOTE(review): presumably intended for trimming a signal peptide — confirm.
offset_signal = 0

In [6]:
# Look up the stored amino-acid sequence of every accession in `ids`.
# `sequences` is kept in the same order as `ids`, which later cells rely on
# when pairing embeddings with names and phenotypes.
query_cypher = """
MATCH (p:Protein {accession_id: $accession_id})
RETURN p.sequence
"""
sequences = []
for accession_id in ids:
    records = eedb.db.execute_read(query_cypher, {"accession_id": accession_id})
    # Fail loudly with the offending accession instead of an opaque IndexError /
    # TypeError; silently skipping would break the alignment with `ids`.
    if not records or records[0]['p.sequence'] is None:
        raise ValueError(
            f"No protein sequence found in the database for accession_id {accession_id!r}; "
            "was it fetched successfully?"
        )
    # Drop the first `offset_signal` residues.
    sequences.append(records[0]['p.sequence'][offset_signal:])

In [7]:
# Embedding models to compare; the order here is the order in which they are
# loaded and plotted.
model_name_list = [
    "prot_t5_xl_bfd",
    "esmc_600m",
    "esmc_300m",
    "facebook/esm2_t33_650M_UR50D",
    "prot_t5_xl_uniref50",
    "facebook/esm2_t36_3B_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t30_150M_UR50D",
]

In [None]:
# Embed every sequence with each model in model_name_list.
# Result: embeddings_all_sequences_last_layer[model_name] is a list with one
# last-hidden-state embedding per entry of `sequences`, in the same order.
import torch  # kept local: torch is not imported in the notebook's first cell

embeddings_all_sequences_last_layer = {}
processor = get_processor()

# The original setup pins the work to the second GPU ('cuda:1'). Fall back to the
# first GPU, or the CPU, so the cell does not crash on hosts with fewer devices.
# NOTE(review): CPU fallback will be very slow for the larger models — confirm
# that pyeed's processor accepts a CPU device.
if torch.cuda.device_count() > 1:
    embedding_device = torch.device('cuda:1')
elif torch.cuda.is_available():
    embedding_device = torch.device('cuda:0')
else:
    embedding_device = torch.device('cpu')

try:
    for model_name in model_name_list:
        print(f"Embedding sequences with {model_name}...")

        # Load (or reuse) the model on the selected device.
        model = processor.get_or_create_model(model_name=model_name, device=embedding_device)
        try:
            # One embedding per sequence, preserving the order of `sequences`.
            embeddings_all_sequences_last_layer[model_name] = [
                model.get_single_embedding_last_hidden_state(sequence=seq)
                for seq in sequences
            ]
        finally:
            # Release the model even if an embedding call failed, so a crashed
            # run does not leave it loaded on the device.
            processor.remove_model(model_name)
            del model

        # Short pause between models, kept from the original cell.
        time.sleep(10)
finally:
    # Final cleanup, also executed when the loop aborts with an exception.
    processor.cleanup()

In [None]:
# PCA scatter plot (PC1 vs PC2) of the mean-pooled last-layer embeddings, one
# figure per model, coloured by phenotype and saved as PNG into `output_dir`.
# plt / np / os / pd come from the notebook's first import cell; only PCA is
# imported here because it is not part of that cell.
from sklearn.decomposition import PCA

# Create output directory if it doesn't exist
output_dir = "/home/nab/Niklas/TEM-lactamase/data/001_results/010_Model_Comparisons"
os.makedirs(output_dir, exist_ok=True)

# Set to True to label every point with its protein name (from `names`).
annotate_names = False

# Phenotype per embedded sequence: same row filter that produced `ids`, so the
# labels line up with the embeddings. Missing phenotypes become 'unknown' so the
# corresponding points are still drawn instead of silently disappearing.
phenotypes = [
    'unknown' if pd.isna(p) else str(p)
    for p in df.loc[~df['protein_id_database'].isna(), 'phenotype']
]
phenotype_array = np.array(phenotypes)

# Sorted so that every phenotype gets the same colour in every figure and run.
unique_phenotypes = sorted(set(phenotypes))
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_phenotypes)))

for model_name in model_name_list:
    print(f"Creating PCA plot for {model_name}...")

    model_embeddings = embeddings_all_sequences_last_layer[model_name]
    if len(model_embeddings) != len(phenotypes):
        raise ValueError(
            f"{model_name}: {len(model_embeddings)} embeddings but "
            f"{len(phenotypes)} phenotype labels; they must be aligned"
        )

    # Mean pool over axis 0 of each embedding.
    # NOTE(review): assumes each embedding is (sequence_length, hidden_dim) and
    # numpy-compatible — confirm against pyeed's return type.
    embeddings_array = np.array([np.mean(emb, axis=0) for emb in model_embeddings])

    # Project onto the first two principal components.
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(embeddings_array)

    fig, ax = plt.subplots(figsize=(10, 8))

    # One scatter call per phenotype so each gets its own legend entry.
    for color, phenotype in zip(colors, unique_phenotypes):
        mask = phenotype_array == phenotype
        ax.scatter(pca_result[mask, 0], pca_result[mask, 1],
                   c=[color], label=phenotype, alpha=0.7, s=50)

    if annotate_names:
        for i, name in enumerate(names):
            ax.annotate(name, (pca_result[i, 0], pca_result[i, 1]),
                        xytext=(5, 5), textcoords='offset points', fontsize=8, alpha=0.8)

    ax.set_xlabel(f'PC1 (Variance Explained: {pca.explained_variance_ratio_[0]:.2%})')
    ax.set_ylabel(f'PC2 (Variance Explained: {pca.explained_variance_ratio_[1]:.2%})')
    ax.set_title(f'PCA of Mean-Pooled Embeddings - {model_name}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    # Save the plot; '/' and ':' in model names are not safe in file names.
    safe_model_name = model_name.replace('/', '_').replace(':', '_')
    fig.savefig(os.path.join(output_dir, f'pca_embeddings_{safe_model_name}.png'),
                dpi=300, bbox_inches='tight')
    plt.show()
    # Close the figure so figures do not pile up in memory across models.
    plt.close(fig)
