In [11]:
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import random
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from pymongo import MongoClient
from functools import reduce
from collections import defaultdict
#import proteusAI as pai
import os
from typing import Union
from pathlib import Path
import esm
#For knowledge distillation
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as functional

#### Defining `esm2_compute` with the smaller ESM-2 models (student and teacher)

In [2]:
def esm2_compute(seqs: list, names: list=None, model: Union[str, torch.nn.Module]="student", rep_layer: int=33, device=None):
    """
    Run an ESM-2 model on a batch of protein sequences and return its raw outputs.

    Args:
        seqs (list): protein sequences, either as str or as objects whose str()
            is the sequence (e.g. biotite.sequence.ProteinSequence).
        names (list, default None): one name/label per sequence.
            If None, sequences are named seq0, seq1, ...
        model (str, torch.nn.Module): "student" (esm2_t6_8M_UR50D), "teacher"
            (esm2_t12_35M_UR50D) or an already instantiated model object.
        rep_layer (int): representation layer requested from the model. Default 33.
            NOTE(review): the built-in checkpoints are named t6 and t12, so the
            default presumably exceeds their layer count - pass rep_layer
            explicitly when the representations are needed (confirm against esm docs).
        device (str): hardware for computation. Default None selects 'cuda' when
            available and 'cpu' otherwise; any value accepted by torch.device works.

    Returns:
        tuple: (results, batch_lens, batch_labels, alphabet)
            results: the object returned by the model's forward pass (the rest of
                this notebook reads results["representations"] and results["logits"]).
            batch_lens: number of non-padding tokens per sequence.
            batch_labels: labels returned by the alphabet's batch converter.
            alphabet: the alphabet used for tokenisation.

    Raises:
        ValueError: if `model` is an unknown string, or if `names` and `seqs`
            have different lengths.
        TypeError: if `model` is neither a string nor a torch.nn.Module.

    Example:
        seqs = ["AGAVCTGAKLI", "AGHRFLIKLKI"]
        results, batch_lens, batch_labels, alphabet = esm2_compute(seqs, rep_layer=6)
    """
    # detect device
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    # load model
    if isinstance(model, str):
        if model == "student":
            model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
        elif model == "teacher":
            model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
        else:
            raise ValueError(f"{model} is not a valid model")
    elif isinstance(model, torch.nn.Module):
        # Prefer the alphabet the model carries itself, when it exposes one.
        alphabet = getattr(model, "alphabet", None)
        if alphabet is None:
            # `__file__` is undefined inside a notebook kernel, so fall back to the
            # current working directory instead of raising NameError.
            module_file = globals().get("__file__")
            base_dir = Path(module_file).parent if module_file else Path.cwd()
            # NOTE(review): torch.load unpickles arbitrary objects - only load an
            # alphabet.pt that comes from a trusted source.
            alphabet = torch.load(os.path.join(base_dir, "alphabet.pt"))
    else:
        raise TypeError("Model should be either a string or a torch.nn.Module object")

    batch_converter = alphabet.get_batch_converter()
    model.eval()
    model.to(device)

    if names is None:
        names = [f'seq{i}' for i in range(len(seqs))]
    elif len(names) != len(seqs):
        # zip() below would otherwise silently drop the unmatched entries
        raise ValueError(
            f"names and seqs must have the same length, got {len(names)} and {len(seqs)}"
        )

    # sequences may be str or biotite objects; str() is a no-op for plain strings
    data = [(name, str(seq)) for name, seq in zip(names, seqs)]

    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    # Forward pass on the selected device, without building an autograd graph
    with torch.no_grad():
        results = model(batch_tokens.to(device), repr_layers=[rep_layer], return_contacts=True)

    return results, batch_lens, batch_labels, alphabet

#### Helper functions: fetching sequences from MongoDB and extracting representations and logits

In [3]:
def get_taxon_sequence_data(collection):
    """
    Fetch a flattened view of the documents in `collection` whose sequence
    length is at most 1024.

    Args:
        collection: object exposing `find(query, projection)` that yields
            dict-like documents (e.g. a pymongo collection).

    Returns:
        list of dicts with the keys "primaryAccession", "taxonId", "value" and
        "length". A missing "primaryAccession" becomes {}; the other three
        become None when the corresponding field is absent.
    """
    # Keep only documents with sequence.length <= 1024
    length_filter = {"sequence.length": {"$lte": 1024}}
    # Fields requested from the database
    wanted_fields = {
        "primaryAccession": 1,
        "organism.taxonId": 1,
        "sequence.value": 1,
        "sequence.length": 1,
    }

    def _flatten(doc):
        # Pull the nested organism/sequence fields up to the top level
        organism = doc.get("organism", {})
        sequence = doc.get("sequence", {})
        return {
            "primaryAccession": doc.get("primaryAccession", {}),
            "taxonId": organism.get("taxonId"),
            "value": sequence.get("value"),
            "length": sequence.get("length"),
        }

    return [_flatten(doc) for doc in collection.find(length_filter, wanted_fields)]

def get_seq_rep(results, batch_lens, rep_layer = 33):
    """
    Build one vector per sequence from the output of esm2_compute.

    For sequence i, the token representations at positions 1 .. len_i - 2 of
    results["representations"][rep_layer] are averaged, i.e. the first and last
    token counted in batch_lens are left out (presumably the begin/end tokens
    added by the batch converter - confirm).

    Returns:
        list with one averaged representation per entry of batch_lens.
    """
    per_token = results["representations"][rep_layer]

    # Mean over the token axis of each sequence's inner slice
    return [
        per_token[idx, 1: length - 1].mean(0)
        for idx, length in enumerate(batch_lens)
    ]
 
 
def get_logits(results):
    """
    Return the "logits" entry of the results produced by esm2_compute.
    """
    return results["logits"]

#### Dummy test with student model

In [4]:
# Use the GPU when one is visible to torch, otherwise the CPU.
# The esm2_compute calls below do not pass this value; they auto-select the same way.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Two short toy sequences used as dummy input for both models
seqs = [
    "AGAVCTGAKLI",
    "AGHRFLIKLKI",
]
# The default model is the student; layer 0 is the representation requested
results, batch_lens, batch_labels, alphabet = esm2_compute(seqs=seqs, rep_layer=0)

#### Extract representations and logits from results

In [6]:
# Student outputs: raw logits and one averaged layer-0 representation per sequence
logits_student = get_logits(results)
representations_student = get_seq_rep(results, batch_lens, rep_layer=0)

#### Using teacher model

In [7]:
# NOTE: this rebinds results / batch_lens / batch_labels / alphabet, which the
# student cells above also assign - re-run the student inference cell before
# re-extracting the student representations.
results, batch_lens, batch_labels, alphabet = esm2_compute(seqs, model="teacher", rep_layer=0)
logits_teacher = get_logits(results)
representations_teacher = get_seq_rep(results, batch_lens, rep_layer=0)

#### Function for addressing differences in dimensionality between kernels

In [8]:
def pad_to_match(teacher_kernel, student_kernel):
    """
    Zero-pad two 2-D tensors so that both end up with the same shape.

    The target shape is the element-wise maximum of the two input shapes over
    the first two dimensions. Zeros are appended at the end of each dimension
    (extra rows at the bottom, extra columns on the right).

    Returns:
        (new_teacher_kernel, new_student_kernel), both of the target shape.
    """
    teacher_rows, teacher_cols = teacher_kernel.shape[0], teacher_kernel.shape[1]
    student_rows, student_cols = student_kernel.shape[0], student_kernel.shape[1]

    target_rows = max(teacher_rows, student_rows)
    target_cols = max(teacher_cols, student_cols)

    # functional.pad takes (left, right, top, bottom), last dimension first
    teacher_padding = (0, target_cols - teacher_cols, 0, target_rows - teacher_rows)
    student_padding = (0, target_cols - student_cols, 0, target_rows - student_rows)

    padded_teacher = functional.pad(teacher_kernel, teacher_padding)
    padded_student = functional.pad(student_kernel, student_padding)
    return padded_teacher, padded_student

#### Distillation loss functions

In [9]:
# Shared mean-squared-error criterion used by the loss functions in this cell
mse_loss = nn.MSELoss(reduction="mean")

def kernel_similarity_matrix(kernel):
    """
    Pairwise cosine similarity between the rows of `kernel`.

    The tensor is moved to the CPU, detached and converted to NumPy first, so
    the returned matrix is a NumPy array with no autograd graph attached.
    """
    embeddings = kernel.cpu().detach().numpy()
    similarity = cosine_similarity(embeddings)
    return similarity

def kernel_mse_alignment_loss(teacher_kernel, student_kernel):
    """
    MSE between the pairwise cosine-similarity matrices of teacher and student.

    Each kernel is a 2-D tensor with one embedding per row. The similarity
    matrices are computed with torch operations, so the loss stays on the
    student's device and is differentiable with respect to `student_kernel`.
    The teacher matrix is computed without gradient tracking and acts as a
    fixed target.

    If the two similarity matrices differ in shape they are zero-padded to a
    common shape with pad_to_match before the MSE is taken.

    Returns:
        0-dim tensor holding the mean squared difference.
    """
    def _cosine_kernel(embeddings):
        # L2-normalise each row; the Gram matrix of unit rows is the cosine similarity
        unit_rows = functional.normalize(embeddings, p=2, dim=1)
        return unit_rows @ unit_rows.T

    student_matrix = _cosine_kernel(student_kernel)
    with torch.no_grad():
        teacher_matrix = _cosine_kernel(teacher_kernel)
    # Teacher and student may live on different devices or use different dtypes
    teacher_matrix = teacher_matrix.to(device=student_matrix.device, dtype=student_matrix.dtype)

    if teacher_matrix.shape != student_matrix.shape:
        teacher_matrix, student_matrix = pad_to_match(teacher_matrix, student_matrix)

    return functional.mse_loss(student_matrix, teacher_matrix)

def logits_mse_loss(teacher_logits, student_logits):
    """
    Mean squared error between the teacher's and the student's logits,
    computed with the module-level `mse_loss` criterion.
    """
    logits_loss = mse_loss(teacher_logits, student_logits)
    return logits_loss

def distillation_loss(teacher_rep, teacher_logits, student_rep, student_logits, weight_rep = 1, weight_logits = 1):
    """
    Weighted sum of the two distillation terms.

    Returns weight_rep * kernel_mse_alignment_loss(teacher_rep, student_rep)
    plus weight_logits * logits_mse_loss(teacher_logits, student_logits).
    """
    rep_term = weight_rep * kernel_mse_alignment_loss(teacher_rep, student_rep)
    logits_term = weight_logits * logits_mse_loss(teacher_logits, student_logits)
    return rep_term + logits_term

#### Testing KD loss computations

In [10]:
# Stack the per-sequence representations into one tensor per model (one row per sequence)
teacher_rep = torch.stack(representations_teacher)
student_rep = torch.stack(representations_student)

# Logits are used as returned by the models
teacher_logits = logits_teacher
student_logits = logits_student

# Both loss terms are weighted equally (default weights)
kd_loss = distillation_loss(
    teacher_rep=teacher_rep,
    teacher_logits=teacher_logits,
    student_rep=student_rep,
    student_logits=student_logits,
)

print("Knowledge distillation loss (dummy data): {}".format(kd_loss))

Knowledge distillation loss (dummy data): 0.6618165969848633
