In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import random
import torch
from torch.utils.data import Dataset, DataLoader
from pymongo import MongoClient
from functools import reduce
from collections import defaultdict
import proteusAI as pai
import os
from typing import Union
from pathlib import Path
import esm

  alphabet = torch.load(os.path.join(Path(__file__).parent, "alphabet.pt"))


# Functions required(made by johnny):

In [2]:
def esm_compute(seqs: list, names: list=None, model: Union[str, torch.nn.Module]="esm1v", rep_layer: int=33, device=None):
    """
    Compute the of esm_tools models for a list of sequences.
 
    Args:
        seqs (list): protein sequences either as str or biotite.sequence.ProteinSequence.
        names (list, default None): list of names/labels for protein sequences.
            If None sequences will be named seq1, seq2, ...
        model (str, torch.nn.Module): choose either esm2, esm1v or a pretrained model object.
        rep_layer (int): choose representation layer. Default 33.
        device (str): Choose hardware for computation. Default 'None' for autoselection
                          other options are 'cpu' and 'cuda'.
 
    Returns: representations (list) of sequence representation, batch lens and batch labels
 
    Example:
        seqs = ["AGAVCTGAKLI", "AGHRFLIKLKI"]
        results, batch_lens, batch_labels = esm_compute(seqs)
    """
    # detect device
    if device == None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)
 
    # on M1 if mps available
    #if device == torch.device(type='cpu'):
    #    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
 
    # load model
    if isinstance(model, str):
        if model == "esm2":
            model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        elif model == "esm1v":
            model, alphabet = esm.pretrained.esm1v_t33_650M_UR90S()
        else:
            raise ValueError(f"{model} is not a valid model")
    elif isinstance(model, torch.nn.Module):
        alphabet = torch.load(os.path.join(Path(__file__).parent, "alphabet.pt"))
    else:
        raise TypeError("Model should be either a string or a torch.nn.Module object")
 
 
    batch_converter = alphabet.get_batch_converter()
    model.eval()
    model.to(device)
 
    if names == None:
        names = names = [f'seq{i}' for i in range(len(seqs))]
 
    data = list(zip(names, seqs))
 
    # check datatype of sequences - str or biotite
    if all(isinstance(x[0], str) and isinstance(x[1], str) for x in data):
        pass  # all elements are strings
    else:
        data = [(x[0], str(x[1])) for x in data]
 
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
 
    # Extract per-residue representations (on CPU)
    with torch.no_grad():
        results = model(batch_tokens.to(device), repr_layers=[rep_layer], return_contacts=True)
 
    return results, batch_lens, batch_labels, alphabet

# Get representations
def get_seq_rep(results, batch_lens):
    """
    Get sequence representations from esm_compute
    """
    token_representations = results["representations"][33]
 
    # Generate per-sequence representations via averaging
    sequence_representations = []
    for i, tokens_len in enumerate(batch_lens):
        sequence_representations.append(token_representations[i, 1: tokens_len - 1].mean(0))
 
    return sequence_representations
 
def get_logits(results):
    """
    Get logits from esm_compute
    """
    logits = results["logits"]
    return logits

# Dataloader for with mongo DB


In [62]:
# Connects to the internal db and reads the example within it.
def connect_db():
    # Connect to MongoDB with connection string
    string_path = "mongodb://127.0.0.1:27017/?directConnection=true&serverSelectionTimeoutMS=2000&appName=mongosh+2.3.3"
    
    client = MongoClient(string_path)

    # Create db
    db = client['proteins']

    # Access collection
    collection = db['uniref_test']

    return client, db, collection

# Selects correct data from the db retrieved data:
def get_taxon_sequence_data(collection):

    # Define the query filter to exclude documents where sequence.length > 1024
    query = {
        "sequence.length": {"$lte": 1024}
    }
    projection = {
    "primaryAccession": 1,    
    "organism.taxonId": 1,   # Include taxonId from the organism field
    "sequence.value": 1,      # Include value from the sequence field
    "sequence.length": 1      # Include length from the sequence field
    }
    documents =  collection.find(query, projection)

    minimal_documents = [] # Initialize new empty dictionary

    for doc in documents:
    # Create a new dictionary with only the desired properties
        new_obj = {
            "primaryAccession": doc.get("primaryAccession",{}),
            "taxonId": doc.get("organism", {}).get("taxonId"),
            "value": doc.get("sequence", {}).get("value"),
            "length": doc.get("sequence", {}).get("length")
        }
        minimal_documents.append(new_obj)


    return minimal_documents


def get_taxon_sequence_data(collection):

    documents = collection['results']

    minimal_documents = [] # Initialize new empty dictionary

    for doc in documents:
    # Create a new dictionary with only the desired properties
        new_obj = {
            "primaryAccession": doc.get("primaryAccession",{}),
            "taxonId": doc.get("organism", {}).get("taxonId"),
            "value": doc.get("sequence", {}).get("value"),
            "length": doc.get("sequence", {}).get("length")
        }
        minimal_documents.append(new_obj)


    return minimal_documents
    

In [6]:
client, db, collection = connect_db()

data = get_taxon_sequence_data(collection=collection)

In [11]:
data[0]

{'primaryAccession': 'O13437',
 'taxonId': 5477,
 'value': 'MKIVLVLYDAGKHAADEEKLYGCTENKLGIANWLKDQGHELITTSDKEGETSELDKHIPDADIIITTPFHPAYITKERLDKAKNLKLVVVAGVGSDHIDLDYINQTGKKISVLEVTGSNVVSVAEHVVMTMLVLVRNFVPAHEQIINHDWEVAAIAKDAYDIEGKTIATIGAGRIGYRVLERLLPFNPKELLYYDYQALPKEAEEKVGARRVENIEELVAQADIVTVNAPLHAGTKGLINKELLSKFKKGAWLVNTARGAICVAEDVAAALESGQLRGYGGDVWFPQPAPKDHPWRDMRNKYGAGNAMTPHYSGTTLDAQTRYAEGTKNILESFFTGKFDYRPQDIILLNGEYVTKAYGKHDKK',
 'length': 364}

In [46]:
class ProteinDataset(Dataset):

    def __init__(self, data):

        self.data = data
        #data[' length'] = data[' length'].str.replace(r'\^\^<.*?>', '', regex=True).astype(int)
        #data[' taxon'] = data[' taxon'].str.extract(r'h.*/(\d+)/?$')[0].astype(int)
    def __len__(self):

        return len(self.data)
    
    def __getitem__(self, index):
        if isinstance(index, list):
            return [self.data[i] for i in index]
        
        item = self.data[index]

        # Extract relevant fields
        sequence = item['value']  # The protein sequence
        length = item['length']   # Length of the sequence
        taxon_id = item['taxonId']  # Taxon ID
        primary_accession = item['primaryAccession']  # Primary accession

        return {
            'sequence': sequence,  # Return sequence as a string (you could modify this later)
            'length': length,
            'taxon_id': taxon_id,
            'primary_accession': primary_accession
        }

## Dummy sampler
class CustomSampler(Sampler):

    def __init__(self, data_source : Dataset, shuffle = True):

        self.data_source = data_source
        self.indices = list(range(len(data_source)))
        self.shuffle = shuffle

    def __iter__(self):
        
        if self.shuffle:
                random.shuffle(self.indices)  # Shuffle the indices if needed
            
        # Return the indices one by one (you could also apply custom logic here)
        for idx in self.indices:
            yield idx
            
    def __len__(self):
        """
        Return the number of items in the dataset.
        """
        return len(self.data_source)


class TaxonIdSampler(Sampler):
    def __init__(self, dataset: Dataset, batch_size, length_bin_size=5, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.length_bin_size = length_bin_size
        self.shuffle = shuffle

        # Group sample indices by taxonId
        self.taxon_length_bins = defaultdict(lambda: defaultdict(list))

        for idx, sample in enumerate(dataset):
            taxon_id = sample['taxon_id']
            sequence_length = sample['length']
            length_bin = (sequence_length // length_bin_size) * length_bin_size  # integer division to know in which bucket the sequence is

            # Ensure that length_bin is properly initialized
            self.taxon_length_bins[taxon_id][length_bin].append(idx)

        '''
        structure of self.taxon_length_bins:

        {
            taxon_id_1: {
                length_bin_1: [sample_idx_1, sample_idx_2, ...],
                length_bin_2: [sample_idx_3, sample_idx_4, ...],
                ...
            },
            taxon_id_2: {
                length_bin_3: [sample_idx_5, sample_idx_6, ...],
                ...
            },
        }
        '''
        
        # Prepare batches based on taxon groups
        self.batches = []

        for taxon, length_bins in self.taxon_length_bins.items():
            for length_bin, indices in length_bins.items():
                if self.shuffle:
                    random.shuffle(indices)  # Shuffle the indices if needed
                for i in range(0, len(indices), batch_size):
                    self.batches.append(indices[i:i + batch_size])

        # Shuffle the batches if needed
        if self.shuffle:
            random.shuffle(self.batches)

    def __iter__(self):
        for batch in self.batches:
            yield batch

    def __len__(self):
        return len(self.batches)

def dict_collate_fn(batch):
    # Return the batch as-is (a list of dictionaries)
    return batch

In [31]:
datasetdb = ProteinDataset(data)
#samplerdb = CustomSampler(datasetdb, shuffle=True)
samplerdb = TaxonIdSampler(datasetdb, batch_size = 5 , shuffle=True)
#dataloaderdb = DataLoader(dataset = datasetdb ,batch_size = 5, sampler = samplerdb)
dataloaderdb = DataLoader(dataset = datasetdb , sampler = samplerdb, collate_fn=dict_collate_fn)

In [32]:
for batch in dataloaderdb:
    print(batch)

[[{'primaryAccession': 'A6ZN46', 'taxonId': 307796, 'value': 'MSKGKVLLVLYEGGKHAEEQEKLLGCIENELGIRNFIEEQGYELVTTIDKDPEPTSTVDRELKDAEIVITTPFFPAYISRNRIAEAPNLKLCVTAGVGSDHVDLEAANERKITVTEVTGSNVVSVAEHVMATILVLIRNYNGGHQQAINGEWDIAGVAKNEYDLEDKIISTVGAGRIGYRVLERLVAFNPKKLLYYDYQELPAEAINRLNEASKLFNGRGDIVQRVEKLEDMVAQSDVVTINCPLHKDSRGLFNKKLISHMKDGAYLVNTARGAICVAEDVAEAVKSGKLAGYGGDVWDKQPAPKDHPWRTMDNKDHVGNAMTVHISGTSLDAQKRYAQGVKNILNSYFSKKFDYRPQDIIVQNGSYATRAYGQKK', 'length': 376}]]
[[{'primaryAccession': 'P33160', 'taxonId': 33067, 'value': 'MAKVLCVLYDDPVDGYPKTYARDDLPKIDHYPGGQTLPTPKAIDFTPGQLLGSVSGELGLRKYLESNGHTLVVTSDKDGPDSVFERELVDADVVISQPFWPAYLTPERIAKAKNLKLALTAGIGSDHVDLQSAIDRNVTVAEVTYCNSISVAEHVVMMILSLVRNYLPSHEWARKGGWNIADCVSHAYDLEAMHVGTVAAGRIGLAVLRRLAPFDVHLHYTDRHRLPESVEKELNLTWHATREDMYPVCDVVTLNCPLHPETEHMINDETLKLFKRGAYIVNTARGKLCDRDAVARALESGRLAGYAGDVWFPQPAPKDHPWRTMPYNGMTPHISGTTLTAQARYAAGTREILECFFEGRPIRDEYLIVQGGALAGTGAHSYSKGNATGGSEEAAKFKKAV', 'length': 401}]]
[[{'primaryAccession': 'G0SGU4', 'taxonId': 759272, 'value':

# Dataloader for .csv

In [35]:
import csv
import pandas as pd

# with open('../data/raw/uniref100_sample_n20.csv', newline='') as csvfile:
#     csvreader = csv.reader(csvfile)
#     for row in csvreader:
#         print(row)  # Each row is a list of values

df = pd.read_csv('../data/raw/uniref100_sample_n20.csv')
df.head()


Unnamed: 0,clusterid,kingdomid,proteinid,length,sequence,taxon
0,UniRef100_Q181B5,10239,A0A125YE16,92^^<http://www.w3.org/2001/XMLSchema#int>,MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIA...,http://purl.uniprot.org/taxonomy/1263196
1,UniRef100_Q181B5,10239,A0A125YE16,92^^<http://www.w3.org/2001/XMLSchema#int>,MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIA...,http://purl.uniprot.org/taxonomy/1263198
2,UniRef100_Q181B5,10239,A0A125YE16,92^^<http://www.w3.org/2001/XMLSchema#int>,MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIA...,http://purl.uniprot.org/taxonomy/1266660
3,UniRef100_Q181B5,10239,A0A125YE16,92^^<http://www.w3.org/2001/XMLSchema#int>,MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIA...,http://purl.uniprot.org/taxonomy/1269277
4,UniRef100_Q181B5,10239,A0A125YE16,92^^<http://www.w3.org/2001/XMLSchema#int>,MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIA...,http://purl.uniprot.org/taxonomy/1269278


In [36]:
class ProteinDataset(Dataset):

    def __init__(self, data):

        self.data = data
        data[' length'] = data[' length'].str.replace(r'\^\^<.*?>', '', regex=True).astype(int)
        data[' taxon'] = data[' taxon'].str.extract(r'h.*/(\d+)/?$')[0].astype(int)
    def __len__(self):

        return len(self.data)
    
    def __getitem__(self, index):

        
        item = self.data.iloc[index]

        # Extract relevant fields
        sequence = item[' sequence']  # The protein sequence
        length = item[' length']   # Length of the sequence
        taxon_id = item[' kingdomid']  # Taxon ID
        primary_accession = item[' proteinid']  # Primary accession

        return {
            'sequence': sequence,  # Return sequence as a string (you could modify this later)
            'length': length,
            'taxon_id': taxon_id,
            'primary_accession': primary_accession
        }
    

In [None]:


#df = pd.read_csv('../../data_uniprot/cleaned_file.csv')
#df = pd.read_csv('../data/raw/uniref100_sample_n20.csv')

#df[' length'] = df[' length'].str.replace(r'\^\^<.*?>', '', regex=True).astype(int)
#df[' taxon'] = df[' taxon'].str.extract(r'h.*/(\d+)/?$')[0].astype(int)
#df[' length'] = df[' length'].str.replace(r'\^\^<.*?>', '', regex=True)#.astype(int)
#item = df.iloc[1]
#df.head()
#len(df)
#print(item[" length"])
# TO do:
# Remove strings from length
# Convert numbers string into  int.
#invalid_values = df[~df[' length'].str.replace(r'\^\^<.*?>', '', regex=True).str.isnumeric()]
#print(invalid_values)
#print(df)

In [31]:
header_row = df.iloc[0]
repeated_headers = (df == header_row).all(axis=1)
df_cleaned = df[~repeated_headers].reset_index(drop=True)
print(len(df_cleaned))

35848003


In [30]:
import numpy as np

# Get the header row
header_row = df.iloc[0].values

# Find rows that match the header row
repeated_headers = np.all(df.values == header_row, axis=1)

# Filter out the repeated headers
df_cleaned = df.loc[~repeated_headers].reset_index(drop=True)
print(len(df_cleaned))

35848003


In [36]:
#print(invalid_values)
header_row = df.iloc[0]  # assuming the first row is the actual header
#repeated_headers = (df == header_row).all(axis=1)

# Filter out the repeated headers
#df_cleaned = df[~repeated_headers].reset_index(drop=True)
len(df)
len(df_cleaned)

df.iloc[10000]
df[~df[' length'].str.replace(r'\^\^<.*?>', '', regex=True).str.isnumeric()]

#df.columns = df.columns.str.strip() 

Unnamed: 0,clusterid,kingdomid,proteinid,length,sequence,taxon
1000,clusterid,kingdomid,proteinid,length,sequence,taxon
2001,clusterid,kingdomid,proteinid,length,sequence,taxon
3002,clusterid,kingdomid,proteinid,length,sequence,taxon
4003,clusterid,kingdomid,proteinid,length,sequence,taxon


In [None]:
class ProteinDataset(Dataset):

    def __init__(self, data):

        self.data = data
        data[' length'] = data[' length'].str.replace(r'\^\^<.*?>', '', regex=True).astype(int)
        data[' taxon'] = data[' taxon'].str.extract(r'h.*/(\d+)/?$')[0].astype(int)
    def __len__(self):

        return len(self.data)
    
    def __getitem__(self, index):

        
        item = self.data.iloc[index]

        # Extract relevant fields
        sequence = item[' sequence']  # The protein sequence
        length = item[' length']   # Length of the sequence
        taxon_id = item[' taxon']  # Taxon ID
        primary_accession = item[' proteinid']  # Primary accession

        return {
            'sequence': sequence,  # Return sequence as a string (you could modify this later)
            'length': length,
            'taxon_id': taxon_id,
            'primary_accession': primary_accession
        }
    
class TaxonIdSampler(Sampler):
    def __init__(self, dataset: Dataset, batch_size, length_bin_size=5, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.length_bin_size = length_bin_size
        self.shuffle = shuffle

        # Group sample indices by taxonId
        self.taxon_length_bins = defaultdict(lambda: defaultdict(list))

        for idx, sample in enumerate(dataset):
            taxon_id = sample['taxon_id']
            sequence_length = sample['length']
            length_bin = (sequence_length // length_bin_size) * length_bin_size  # integer division to know in which bucket the sequence is

            # Ensure that length_bin is properly initialized
            self.taxon_length_bins[taxon_id][length_bin].append(idx)

        '''
        structure of self.taxon_length_bins:

        {
            taxon_id_1: {
                length_bin_1: [sample_idx_1, sample_idx_2, ...],
                length_bin_2: [sample_idx_3, sample_idx_4, ...],
                ...
            },
            taxon_id_2: {
                length_bin_3: [sample_idx_5, sample_idx_6, ...],
                ...
            },
        }
        '''
        
        # Prepare batches based on taxon groups
        self.batches = []

        for taxon, length_bins in self.taxon_length_bins.items():
            for length_bin, indices in length_bins.items():
                if self.shuffle:
                    random.shuffle(indices)  # Shuffle the indices if needed
                for i in range(0, len(indices), batch_size):
                    self.batches.append(indices[i:i + batch_size])

        # Shuffle the batches if needed
        if self.shuffle:
            random.shuffle(self.batches)

    def __iter__(self):
        for batch in self.batches:
            yield batch

    def __len__(self):
        return len(self.batches)
    

In [None]:


#dataset = ProteinDataset(data = df)
df[' length']

#dataset[0]
#df[' length'].unique()




#sample = dataset[0]

#print(dataset[0])
# print("Original Sequence:", sample['sequence'])
# print("Masked Sequence:", sample['masked_sequence'])
# print("Mask:", sample['mask'])
# print("Length:", sample['length'])
# print("Taxon ID:", sample['taxon_id'])
# print("Primary Accession:", sample['primary_accession'])

In [38]:
datasetcsv = ProteinDataset(data = df)
samplercsv = TaxonIdSampler(dataset = datasetcsv, batch_size = 5, shuffle = True)
dataloadercsv = DataLoader(dataset = datasetcsv, batch_sampler=samplercsv)

In [39]:
# d1 = ProteinDataset(data = df)
# sampler = CustomSampler(d1, shuffle=True)
# dataloader = DataLoader(d1, batch_size=5, sampler=sampler)

# Iterate over the DataLoader
for batch in dataloadercsv:
    print(batch)  

{'sequence': ['MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIAMPSRKVGEGNFRDIAHPINAEMRQVLEDAVLQAYHEALVQWEVAAE', 'MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIAMPSRKVGEGNFRDIAHPINAEMRQVLEDAVLQAYHEALVQWEVAAE', 'MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIAMPSRKVGEGNFRDIAHPINAEMRQVLEDAVLQAYHEALVQWEVAAE', 'MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIAMPSRKVGEGNFRDIAHPINAEMRQVLEDAVLQAYHEALVQWEVAAE', 'MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIAMPSRKVGEGNFRDIAHPINAEMRQVLEDAVLQAYHEALVQWEVAAE'], 'length': tensor([92, 92, 92, 92, 92]), 'taxon_id': tensor([10239, 10239, 10239, 10239, 10239]), 'primary_accession': ['A0A125YE16', 'A0A125YE16', 'A0A125YE16', 'A0A125YE16', 'A0A125YE16']}
{'sequence': ['MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIAMPSRKVGEGNFRDIAHPINAEMRQVLEDAVLQAYHEALVQWEVAAE', 'MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIAMPSRKVGEGNFRDIAHPINAEMRQVLEDAVLQAYHEALVQWEVAAE', 'MKITDVRVRKLTEEGKMKCIVSITFDNLFVVHDIKVIEGHNGLFIAMPSRKVGEGNFRDIAHPINAEMRQVLEDAVLQAYHEALVQWEVAAE', 'MKITDVRVRKLT

In [41]:
class TaxonIdSampler(Sampler):
    def __init__(self, dataset: Dataset, batch_size, length_bin_size=5, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.length_bin_size = length_bin_size
        self.shuffle = shuffle

        # Group sample indices by taxonId
        self.taxon_length_bins = defaultdict(lambda: defaultdict(list))

        for idx, sample in enumerate(dataset):
            taxon_id = sample['taxon_id']
            sequence_length = sample['length']
            length_bin = (sequence_length // length_bin_size) * length_bin_size  # integer division to know in which bucket the sequence is

            # Ensure that length_bin is properly initialized
            self.taxon_length_bins[taxon_id][length_bin].append(idx)

        '''
        structure of self.taxon_length_bins:

        {
            taxon_id_1: {
                length_bin_1: [sample_idx_1, sample_idx_2, ...],
                length_bin_2: [sample_idx_3, sample_idx_4, ...],
                ...
            },
            taxon_id_2: {
                length_bin_3: [sample_idx_5, sample_idx_6, ...],
                ...
            },
        }
        '''
        
        # Prepare batches based on taxon groups
        self.batches = []

        for taxon, length_bins in self.taxon_length_bins.items():
            for length_bin, indices in length_bins.items():
                if self.shuffle:
                    random.shuffle(indices)  # Shuffle the indices if needed
                for i in range(0, len(indices), batch_size):
                    self.batches.append(indices[i:i + batch_size])

        # Shuffle the batches if needed
        if self.shuffle:
            random.shuffle(self.batches)

    def __iter__(self):
        for batch in self.batches:
            yield batch

    def __len__(self):
        return len(self.batches)

In [59]:
import random
import torch
from collections import defaultdict
from torch.utils.data import Sampler

class TaxonIdSampler(Sampler):
    def __init__(self, dataset, batch_size, length_bin_size=5, mask_percentage=0.15, shuffle=True, seed=42):
        self.dataset = dataset
        self.batch_size = batch_size
        self.length_bin_size = length_bin_size
        self.mask_percentage = mask_percentage
        self.shuffle = shuffle
        self.seed = seed

        # Group sample indices by taxonId and length bin
        self.taxon_length_bins = defaultdict(lambda: defaultdict(list))
        for idx, sample in enumerate(dataset):
            taxon_id = sample['taxon_id']
            sequence_length = sample['length']
            length_bin = (sequence_length // length_bin_size) * length_bin_size
            self.taxon_length_bins[taxon_id][length_bin].append(idx)

        # Prepare batches based on taxon groups
        self.batches = []
        for taxon, length_bins in self.taxon_length_bins.items():
            for length_bin, indices in length_bins.items():
                if self.shuffle:
                    random.shuffle(indices)
                for i in range(0, len(indices), batch_size):
                    self.batches.append(indices[i:i + batch_size])

        if self.shuffle:
            random.shuffle(self.batches)

    def __iter__(self):
        random.seed(self.seed)  # Ensure reproducibility of masking per epoch
        for batch_indices in self.batches:
            batch_sequences = [self.dataset[idx]['sequence'] for idx in batch_indices]
            max_len = max(len(seq) for seq in batch_sequences)

            # Mask and pad each sequence in the batch
            masked_sequences = []
            masks = []
            for idx in batch_indices:
                sequence_tensor = torch.tensor([ord(c) for c in self.dataset[idx]['sequence']])
                padded_sequence, mask = self.mask_and_pad(sequence_tensor, max_len)
                masked_sequences.append(padded_sequence)
                masks.append(mask)

            # Yield the masked and padded batch
            yield {
                'masked_sequences': torch.stack(masked_sequences),
                'masks': torch.stack(masks),
                'lengths': torch.tensor([len(self.dataset[idx]['sequence']) for idx in batch_indices]),
                'taxon_id': torch.tensor([self.dataset[idx]['taxon_id'] for idx in batch_indices])
            }

    def mask_and_pad(self, sequence, max_len):
        """Apply masking and pad the sequence to max_len."""
        num_mask = int(self.mask_percentage * len(sequence))
        mask = torch.zeros(max_len, dtype=torch.bool)
        masked_sequence = torch.full((max_len,), 0, dtype=sequence.dtype)  # Assuming '0' is the mask token
        
        # Copy sequence to masked_sequence
        masked_sequence[:len(sequence)] = sequence
        mask_indices = random.sample(range(len(sequence)), num_mask)

        for idx in mask_indices:
            prob = random.random()
            if prob < 0.8:
                masked_sequence[idx] = 0  # Mask token
                mask[idx] = True
            elif prob < 0.9:
                masked_sequence[idx] = random.randint(1, 20)  # Random integer mutation
            # 10% remains unchanged

        return masked_sequence, mask

    def __len__(self):
        return len(self.batches)

In [73]:
def dict_collate_fn(batch):
    # Return the batch as-is (a list of dictionaries)
    return batch


import torch
import random

def batch_masking_collate_fn(batch, mask_ratio=0.15, mask_prob=0.8, mutate_prob=0.1, unchanged_prob=0.1, seed=42):
    random.seed(seed)
    torch.manual_seed(seed)

    # Extract sequences and determine the max length in this batch
    sequences = [item['sequence'] for item in batch]
    max_len = max(len(seq) for seq in sequences)

    masked_batch = []
    for item in batch:
        sequence = item['sequence']
        seq_len = len(sequence)
        
        # Decide which positions to mask
        num_masked = int(mask_ratio * seq_len)
        mask_indices = random.sample(range(seq_len), num_masked)
        
        masked_sequence = list(sequence)  # Convert to list for mutability
        
        for idx in mask_indices:
            choice = random.choices(
                ['mask', 'mutate', 'unchanged'], 
                weights=[mask_prob, mutate_prob, unchanged_prob]
            )[0]
            
            if choice == 'mask':
                masked_sequence[idx] = '[MASK]'  # Use [MASK] token, or any placeholder for masked
            elif choice == 'mutate':
                masked_sequence[idx] = random.choice('ACDEFGHIKLMNPQRSTVWY')  # Random amino acid (mutation)
            # 'unchanged' leaves it as is

        # Update item with masked sequence
        item['masked_sequence'] = ''.join(masked_sequence)

    # Pad sequences and masked sequences to max_len
    for item in batch:
        sequence = item['masked_sequence']
        padding_len = max_len - len(sequence)
        padded_sequence = sequence + ' ' * padding_len  # Use a suitable padding token here
        item['padded_sequence'] = padded_sequence
    
    # Return the batch in a way compatible with DataLoader
    return {
        'padded_sequences': torch.tensor([list(item['padded_sequence']) for item in batch]),
        'taxon_id': [item['taxon_id'] for item in batch],
        'primary_accession': [item['primary_accession'] for item in batch]
    }


In [43]:
# Initialize the TaxonBatchSampler with your dataset
batch_size = 50
dataset = ProteinDataset(data = df)
sampler = TaxonIdSampler(dataset, batch_size=batch_size, shuffle=True)

dataloader = DataLoader(dataset, batch_sampler=sampler)

# Iterate over the DataLoader
#for batch in dataloader:
 #   print(batch['sequence'])

In [76]:
def get_seq_rep(results, batch_lens):
    """
    Get sequence representations from esm_compute
    """
    token_representations = results["representations"][33]
 
    # Generate per-sequence representations via averaging
    sequence_representations = []
    for i, tokens_len in enumerate(batch_lens):
        sequence_representations.append(token_representations[i, 1: tokens_len - 1].mean(0))
 
    return sequence_representations
 
 
def get_logits(results):
    """
    Get logits from esm_compute
    """
    logits = results["logits"]
    return logits

In [98]:
seqs = ["AGAVCTGAKLI", "AGHRFLIKLKI"]
results, batch_lens, batch_labels, alphabet = esm_compute(seqs)



### First logits computation

In [100]:
LOGITS = []
REPRESENTATIOS = []
for batch in dataloader:
    #try:
        sequences = batch['sequence']
        names = batch['primary_accession']

        results, batch_lens, batch_labels, alphabet = esm_compute(seqs = sequences, names = names )
        
        sequence_representations = get_seq_rep(results, batch_lens)

        for i, x in enumerate(sequence_representations):
            torch.save(x, f"../demo/demo_results/{names[i]}.pt")

   # except:
        pass
    # Print or process the data
    # print("Sequences:", seqs)
    # print("Names:", names)

In [105]:
tensors = [torch.load('reps/'+f) for f in os.listdir('reps') if f.endswith('.pt')]



  tensors = [torch.load('reps/'+f) for f in os.listdir('reps') if f.endswith('.pt')]


In [66]:
import pandas as pd
from pandas import json_normalize
import json

with open('../data/processed/uniref100_test.json', 'r') as file :
    json_data = json.load(file)


# jsondb = json_data['results']
# jsondb[23]

jsondb = pd.read_json(json.dumps(json_data))

jsondb = get_taxon_sequence_data(jsondb)

datasetjson = ProteinDataset(jsondb)
sampler = TaxonIdSampler(batch_size= 5, dataset=datasetjson, shuffle= True)
dataloader = DataLoader(dataset = datasetjson, sampler = sampler)

for batch in dataloader:
    print(batch)

[{'primaryAccession': ['Q17335'], 'taxonId': tensor([6239]), 'value': ['MSSTAGQVINCKAAVAWSAKAPLSIETIQVAPPKAHEVRVKILYTAVCHTDAYTLDGHDPEGLFPVVLGHEGSGIVESVGEGVTGFAPGDHVVPLYVPQCKECEYCKNPKTNLCQKIRISQGNGFMPDGSSRFTCNGKQLFHFMGCSTFSEYTVVADISLCKVNPEAPLEKVSLLGCGISTGYGAVLNTCKVEEGSTVAVWGLGAVGLAVIMGAKAAGAKKIVGIDLIESKFESAKFFGATECINPKSVELPEGKSFQAWLVEQFDGGFDYTFECIGNVHTMRQALEAAHKGWGVSCIIGVAGAGQEIATRPFQLVTGRTWKGTAFGGWKSVESVPRLVDDYMNKKLLIDEFITHRWNIDDINTAFDVLHKGESLRSVLAFEKI'], 'length': tensor([384])}]
[{'primaryAccession': ['C8ZHD6'], 'taxonId': tensor([643680]), 'value': ['MSKGKVLLVLYEGGKHAEEQEKLLGCIENELGIRNFIEEQGYELVTTIDKDPEPTSTVDRELKDAEIVITTPFFPAYISRNRIAEAPNLKLCVTAGVGSDHVDLEAANERKITVTEVTGSNVVSVAEHVMATILVLIRNYNGGHQQAINGEWDIAGVAKNEYDLEDKIISTVGAGRIGYRVLERLVAFNPKKLLYYDYQELPAEAINRLNEASKLFNGRGDIVQRVEKLEDMVAQSDVVTINCPLHKDSRGLFNKKLISHMKDGAYLVNTARGAICVAEDVAEAVKSGKLAGYGGDVWDKQPAPKDHPWRTMDNKDHVGNAMTVHISGTSLDAQKRYAQGVKNILNSYFSKKFDYRPQDIIVQNGSYATRAYGQKK'], 'length': tensor([376])}]
[{'primaryAccession': ['Q03134'], 