### import from loss functions


In [1]:
import torch
import torch.nn as nn
#For knowledge distillation
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as functional

mse_loss = nn.MSELoss()

def pad_to_match(teacher_kernel, student_kernel):
    """
    Just a precaution function. It assures that tokens embeddings in both teacher and student
    representations have the same shape. This will apply zero-padding the kernel with less dimensions.
    """
    rows = max(teacher_kernel.shape[0], student_kernel.shape[0])
    cols = max(teacher_kernel.shape[1], student_kernel.shape[1])
    new_teacher_kernel = functional.pad(teacher_kernel, (0, cols - teacher_kernel.shape[1], 
                                                            0, rows - teacher_kernel.shape[0]))
    new_student_kernel = functional.pad(student_kernel, (0, cols - student_kernel.shape[1], 
                                                            0, rows - student_kernel.shape[0]))
    return new_teacher_kernel, new_student_kernel


def kernel_similarity_matrix(kernel):
    """
    Calculates the cosine similarity between each pair of token embeddings on the kernel
    """
    return cosine_similarity(kernel.cpu().detach().numpy())

def kernel_mse_alignment_loss(teacher_kernel, student_kernel):
    """
    Calculates the MSE kernel alignment loss between teacher and student
    """
    teacher_matrix = torch.tensor(kernel_similarity_matrix(teacher_kernel))
    student_matrix = torch.tensor(kernel_similarity_matrix(student_kernel))

    if teacher_matrix.shape != student_matrix.shape:
        teacher_matrix, student_matrix = pad_to_match(teacher_matrix, student_matrix)

    return mse_loss(teacher_matrix, student_matrix)

def logits_mse_loss(teacher_logits, student_logits):
    """
    Calculates the MSE loss between teacher and student logits
    """
    return mse_loss(teacher_logits, student_logits)


class DistillationLoss(nn.Module):
    def __init__(self, weight_rep=1.0, weight_logits=1.0):
        super(DistillationLoss, self).__init__()
        self.weight_rep = weight_rep
        self.weight_logits = weight_logits

    def forward(self, teacher_rep, teacher_logits, student_rep, student_logits):

        alignment_loss = kernel_mse_alignment_loss(teacher_rep, student_rep)
        logits_loss = logits_mse_loss(teacher_logits, student_logits)
        return self.weight_rep * alignment_loss + self.weight_logits * logits_loss


### import from proteus

In [8]:
def get_seq_rep(results, batch_lens):
    """
    Get sequence representations from esm_compute
    """
    print(len(results["representations"]))
    token_representations = results["representations"][33]
 
    # Generate per-sequence representations via averaging
    sequence_representations = []
    for i, tokens_len in enumerate(batch_lens):
        sequence_representations.append(token_representations[i, 1: tokens_len - 1].mean(0))
 
    return sequence_representations
 

def get_logits(results):
    """
    Get logits from esm_compute
    """
    logits = results["logits"]
    return logits

### import from dataloader

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import random
import torch
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
import esm


def dict_collate_fn(batch):
# Return the batch as-is (a list of dictionaries)
    return batch


def connect_db():
    # Connect to MongoDB with connection string
    string_path = "mongodb://127.0.0.1:27017/?directConnection=true&serverSelectionTimeoutMS=2000&appName=mongosh+2.3.3"
    
    client = MongoClient(string_path)

    # Create db
    db = client['proteins']

    # Access collection
    collection = db['uniref_test']

    return client, db, collection


def get_taxon_sequence_data(collection):

    # Define the query filter to exclude documents where sequence.length > 1024
    query = {
        "sequence.length": {"$lte": 1024}
    }
    projection = {
    "primaryAccession": 1,    
    "organism.taxonId": 1,   # Include taxonId from the organism field
    "sequence.value": 1,      # Include value from the sequence field
    "sequence.length": 1      # Include length from the sequence field
    }
    documents =  collection.find(query, projection)

    minimal_documents = [] # Initialize new empty dictionary

    for doc in documents:
    # Create a new dictionary with only the desired properties
        new_obj = {
            "primaryAccession": doc.get("primaryAccession",{}),
            "taxonId": doc.get("organism", {}).get("taxonId"),
            "value": doc.get("sequence", {}).get("value"),
            "length": doc.get("sequence", {}).get("length")
        }
        minimal_documents.append(new_obj)

    return minimal_documents


class ProteinDataset(Dataset):

    def __init__(self, data):

        self.data = data

    def __len__(self):

        return len(self.data)

    def __getitem__(self, index):

        item = self.data[index]

        # Extract relevant fields
        sequence = item['value']  # The protein sequence
        length = item['length']   # Length of the sequence
        taxon_id = item['taxonId']  # Taxon ID
        primary_accession = item['primaryAccession']  # Primary accession

        return {
            'sequence': sequence,  # Return sequence as a string (you could modify this later)
            'length': length,
            'taxon_id': taxon_id,
            'primary_accession': primary_accession
        }


class TaxonIdSampler(Sampler):
    def __init__(self, dataset: Dataset, batch_size, length_bin_size=5, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.length_bin_size = length_bin_size
        self.shuffle = shuffle

        # Group sample indices by taxonId
        self.taxon_length_bins = defaultdict(lambda: defaultdict(list))

        for idx, sample in enumerate(dataset):
            taxon_id = sample['taxon_id']
            sequence_length = sample['length']
            length_bin = (sequence_length // length_bin_size) * length_bin_size  # integer division to know in which bucket the sequence is

            # Ensure that length_bin is properly initialized
            self.taxon_length_bins[taxon_id][length_bin].append(idx)

        '''
        structure of self.taxon_length_bins:

        {
            taxon_id_1: {
                length_bin_1: [sample_idx_1, sample_idx_2, ...],
                length_bin_2: [sample_idx_3, sample_idx_4, ...],
                ...
            },
            taxon_id_2: {
                length_bin_3: [sample_idx_5, sample_idx_6, ...],
                ...
            },
        }
        '''
        
        # Prepare batches based on taxon groups
        self.batches = []

        for taxon, length_bins in self.taxon_length_bins.items():
            for length_bin, indices in length_bins.items():
                if self.shuffle:
                    random.shuffle(indices)  # Shuffle the indices if needed
                for i in range(0, len(indices), batch_size):
                    self.batches.append(indices[i:i + batch_size])

        # Shuffle the batches if needed
        if self.shuffle:
            random.shuffle(self.batches)

    def __iter__(self):
        for batch in self.batches:
            yield batch

    def __len__(self):
        return len(self.batches)

#### prepare batches and models

In [4]:
batch_size = 1
num_epochs = 1
learning_rate = 1e-4
weight_rep = 0.5
weight_logits = 0.5

checkpoints = True
cp_dir = "checkpoints"
cp_freq = 200

# get data
#_, _, collection = connect_db()
#dataset = ProteinDataset(get_taxon_sequence_data(collection))

############################ TEMP FOR TESTING
import json
import pandas as pd

def get_taxon_sequence_data2(collection):

    documents = collection['results']

    minimal_documents = [] # Initialize new empty dictionary

    for doc in documents:
    # Create a new dictionary with only the desired properties
        new_obj = {
            "primaryAccession": doc.get("primaryAccession",{}),
            "taxonId": doc.get("organism", {}).get("taxonId"),
            "value": doc.get("sequence", {}).get("value"),
            "length": doc.get("sequence", {}).get("length")
        }
        minimal_documents.append(new_obj)
    return minimal_documents

with open('../data/processed/uniref100_test.json', 'r') as file :
    json_data = json.load(file)
collection = pd.read_json(json.dumps(json_data))
dataset = ProteinDataset(get_taxon_sequence_data2(collection))
############################################

sampler = TaxonIdSampler(dataset, batch_size=batch_size, shuffle=True)
dataloader = DataLoader(dataset, batch_sampler=sampler, collate_fn=dict_collate_fn)

# load models
teacher_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
student_model, _ = esm.pretrained.esm2_t6_8M_UR50D()

# initialize batch converter
batch_converter = alphabet.get_batch_converter()

# train only student
teacher_model.eval()
student_model.train()

# Detect device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Available device: ", device)
teacher_model.to(device)
student_model.to(device)

# define optimizer and loss
optimizer = torch.optim.Adam(student_model.parameters(), lr=learning_rate)
distillation_loss = DistillationLoss(weight_rep=1.0, weight_logits=1.0)

  collection = pd.read_json(json.dumps(json_data))
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t12_35M_UR50D.pt" to C:\Users\Kacper/.cache\torch\hub\checkpoints\esm2_t12_35M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t12_35M_UR50D-contact-regression.pt" to C:\Users\Kacper/.cache\torch\hub\checkpoints\esm2_t12_35M_UR50D-contact-regression.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t6_8M_UR50D.pt" to C:\Users\Kacper/.cache\torch\hub\checkpoints\esm2_t6_8M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t6_8M_UR50D-contact-regression.pt" to C:\Users\Kacper/.cache\torch\hub\checkpoints\esm2_t6_8M_UR50D-contact-regression.pt


Available device:  cpu


### training loop

In [9]:
# training loop
for epoch in range(num_epochs):

    for batch in dataloader:

        # extract sequences and names from the batch
        sequences = [item['sequence'] for item in batch]
        names = [item['primary_accession'] for item in batch]

        # prepare data for batch conversion
        if names is None:
            names = [f'seq{i}' for i in range(len(sequences))]
        data = list(zip(names, sequences))

        # check datatype of sequences - str or biotite
        if all(isinstance(x[0], str) and isinstance(x[1], str) for x in data):
            pass  # all elements are strings
        else:
            data = [(x[0], str(x[1])) for x in data]

        # convert data to batch tensors
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
        batch_tokens = batch_tokens.to(device)

        # zero the gradients
        optimizer.zero_grad()

        # forward pass - teacher
        with torch.no_grad():
            teacher_res = teacher_model(batch_tokens, repr_layers=[33], return_contacts=False)
            teacher_logits = get_logits(teacher_res)
            teacher_reps = get_seq_rep(teacher_res, batch_lens)

        # forward pass - student
        student_res = student_model(batch_tokens, repr_layers=[6], return_contacts=False)
        student_logits = get_logits(student_res)
        student_reps = get_seq_rep(student_res, batch_lens)

        # compute loss and backprop
        loss = distillation_loss(teacher_reps, teacher_logits, student_reps, student_logits)
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

    # tensorflow-like checkpoints
    if checkpoints:
        if (epoch + 1) % cp_freq == 0:
            path = f'cp_epoch_{epoch+1}.pt'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
            }, path)
            print(f'Checkpoint saved: {path}')


0


KeyError: 33