# In [None]:
import torch
import esm
import numpy as np
from Bio import SeqIO
from typing import List, Tuple
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
import random

class ProteinDiversifier:
    """Propose and analyse single-point mutants of a protein using ESM embeddings.

    Sequences are embedded with a pretrained ESM model. Mutants are scored by
    the cosine similarity between their mean-pooled embedding and that of the
    original sequence.
    """

    # Residues a position may be mutated to: the 20 standard amino acids
    # (rare/ambiguous codes are deliberately excluded).
    AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

    # t-SNE perplexity used when enough samples are available (sklearn default).
    _DEFAULT_PERPLEXITY = 30.0

    def __init__(self, model_name: str = "esm2_t33_650M_UR50D"):
        """
        Initialize the ProteinDiversifier with a specific ESM model.

        Args:
            model_name: Name of the ESM model to use
        """
        # Load the ESM model and alphabet
        self.model, self.alphabet = esm.pretrained.load_model_and_alphabet(model_name)
        self.model.eval()  # Set to evaluation mode

    def _final_layer(self) -> int:
        """Return the index of the representation layer to extract.

        Uses the model's own ``num_layers`` when it exposes one, so that
        models other than the 33-layer default work; otherwise falls back to
        33, the value this class originally hard-coded.
        """
        return getattr(self.model, "num_layers", 33)

    def _mean_embedding(self, sequence: str) -> torch.Tensor:
        """Return the mean-pooled embedding of ``sequence`` with shape (1, dim)."""
        return self.get_embedding(sequence).mean(dim=0).unsqueeze(0)

    def get_embedding(self, sequence: str) -> torch.Tensor:
        """
        Generate ESM embedding for a protein sequence.

        Args:
            sequence: Amino acid sequence

        Returns:
            Per-residue embedding tensor with one row per residue of
            ``sequence``
        """
        # Prepare a batch holding just this sequence
        batch_converter = self.alphabet.get_batch_converter()
        _, _, batch_tokens = batch_converter([("protein1", sequence)])

        # Generate embedding from the model's last layer
        layer = self._final_layer()
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[layer], return_contacts=False)
        token_embeddings = results["representations"][layer]

        # Keep token positions 1..len(sequence): this skips the token that the
        # batch converter places before the first residue and anything that
        # follows the last residue.
        return token_embeddings[0, 1:len(sequence) + 1]

    def generate_mutations(self, sequence: str,
                           num_mutations: int = 5,
                           similarity_threshold: float = 0.8) -> List[str]:
        """
        Generate diverse single-point mutations of a protein sequence.

        Up to ``num_mutations`` distinct positions are drawn at random. At each
        position every alternative amino acid is tried, and the substitution
        whose mean embedding is *least* similar to the original while still
        above ``similarity_threshold`` is kept. Each returned sequence differs
        from ``sequence`` at exactly one position.

        Args:
            sequence: Original protein sequence
            num_mutations: Maximum number of mutated sequences to generate
            similarity_threshold: Minimum cosine similarity threshold

        Returns:
            List of mutated sequences. It may be shorter than
            ``num_mutations`` when the sequence has fewer positions than
            requested or when no substitution at a position clears the
            threshold. It is empty for an empty sequence or a non-positive
            ``num_mutations``.
        """
        # Nothing to mutate; return before doing any model work.
        if not sequence or num_mutations <= 0:
            return []

        # The reference embedding is loop-invariant, so compute it once.
        reference = self._mean_embedding(sequence)

        # Sample positions without replacement so the same mutant cannot be
        # produced twice.
        n_positions = min(num_mutations, len(sequence))
        positions = random.sample(range(len(sequence)), n_positions)

        mutations = []
        for pos in positions:
            best_mutation = None
            best_similarity = float('inf')

            for aa in self.AMINO_ACIDS:
                if aa == sequence[pos]:
                    continue

                # Create mutated sequence and score it against the original
                mutated_seq = sequence[:pos] + aa + sequence[pos + 1:]
                similarity = torch.cosine_similarity(
                    reference,
                    self._mean_embedding(mutated_seq)
                ).item()

                # Prefer the most diverse candidate that stays above the threshold
                if similarity_threshold < similarity < best_similarity:
                    best_mutation = mutated_seq
                    best_similarity = similarity

            if best_mutation is not None:
                mutations.append(best_mutation)

        return mutations

    def analyze_diversity(self, sequences: List[str],
                          n_clusters: int = 3) -> Tuple[np.ndarray, np.ndarray]:
        """
        Analyze the diversity of a set of sequences using dimensionality reduction
        and clustering.

        Args:
            sequences: List of protein sequences (at least two)
            n_clusters: Number of clusters for KMeans; capped at the number of
                sequences, since KMeans cannot form more clusters than samples

        Returns:
            Tuple of (TSNE coordinates, cluster assignments), both numpy arrays
            with one entry per input sequence

        Raises:
            ValueError: If fewer than two sequences are given.
        """
        n_samples = len(sequences)
        if n_samples < 2:
            raise ValueError(
                "analyze_diversity requires at least 2 sequences, got %d" % n_samples
            )

        # Mean-pool the per-residue embeddings into one vector per sequence
        embeddings_array = np.array(
            [self.get_embedding(seq).mean(dim=0).numpy() for seq in sequences]
        )

        # t-SNE requires perplexity < n_samples; the default of 30 would fail
        # for small inputs, so shrink it when necessary.
        perplexity = min(self._DEFAULT_PERPLEXITY, float(n_samples - 1))
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
        tsne_coords = tsne.fit_transform(embeddings_array)

        # Perform clustering
        kmeans = KMeans(n_clusters=min(n_clusters, n_samples), random_state=42)
        clusters = kmeans.fit_predict(embeddings_array)

        return tsne_coords, clusters

# Example usage
def main():
    """Run the demo: mutate one example sequence, cluster the results, print them.

    Builds a ProteinDiversifier with its default model, requests five
    mutations of the example sequence, runs the diversity analysis on the
    original plus its mutants, and prints the sequences and cluster labels.
    """
    # Example sequence (swap in the protein you care about)
    target = "MVKVGVNGFGRIGRLVTRAAFNSGKVDIVAINDPFIDLNYMVYMFQYDSTHGKFHGTVKAENGKLVINGNPITIFQERDPSKIKWGDAGAEYVVESTGVFTTMEKAGAHLQGGAKRVIISAPSADAPMFVMGVNHEKYDNSLKIISNASCTTNCLAPLAKVIHDNFGIVEGLMTTVHAITATQKTVDGPSGKLWRDGRGALQNIIPASTGAAKAVGKVIPELDGKLTGMAFRVPTANVSVVDLTCRLEKPAKYDDIKKVVKQASEGPLKGILGYTEHQVVSSDFNSDTHSSTFDAGAGIALNDHFVKLISWYDNEFGYSNRVVDLMAHMASKE"

    engine = ProteinDiversifier()

    # Propose mutants, then analyse them together with the original
    variants = engine.generate_mutations(target, num_mutations=5)
    _coords, labels = engine.analyze_diversity([target, *variants])

    print("Original sequence:", target)
    print("\nGenerated mutations:")
    for index, variant in enumerate(variants, start=1):
        print(f"Mutation {index}:", variant)

    print("\nCluster assignments:", labels)

# Run the demo only when this file is executed directly (script or notebook
# cell), so importing it does not load the model and run the demo.
if __name__ == "__main__":
    main()