# dataloaders

> Dataset, collator and dataloader helpers for training on encoded Uniclust30 MSA clusters.

In [None]:
#| default_exp dataloaders

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from ProtMamba_ssm.utils import AA_TO_ID
from ProtMamba_ssm.fim import NoFIM, SingleSpanFIM, MultipleSpanFIM
import pickle
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from dataclasses import dataclass
from typing import Dict, Sequence


In [None]:
#| export

# Make dataset
class Uniclust30_Dataset(Dataset):
    """
        Dataset class used to import the encoded Uniclust30 clusters.

        Each item is one cluster: a subset of its sequences is selected, handed to a
        fill-in-the-middle (FIM) helper that concatenates them, optionally reversed and
        truncated, and returned as `torch.int64` tensors.

        `filename` is appended to `filepath` and loaded with pickle; the loaded object is used as a
        mapping from cluster name to a flat array of tokens. If `filename` is empty/None no file
        is loaded and the dataset is empty.
        If `sample` = True, it will sample a random number of sequences (at least one) from each cluster.
        If `sample` = False, it will use all the sequences from each cluster (in random order).
        To limit the length of the MSAs, set `max_msa_len` to a positive integer.
        If `reverse` = True, it will reverse the sequences with probability 0.5 and move the last token to the front.
        `seed` seeds the global NumPy random generator.
        `fim_strategy` selects the FIM helper and must be one of the keys of `_FIM`:
        "no-scramble", "one_span" or "multiple_span". `max_patches`, `mask_fraction` and
        `always_mask` are forwarded to that helper.
        `add_position_ids` controls the extra outputs ("none", "1d" or "2d", see `_POSIDS`):
        "1d" adds `position_ids` (clamped to `max_position_embeddings - 1`), "2d" additionally adds
        `seq_position_ids` (clamped to `max_seq_position_embeddings - 1`).
        If `troubleshoot` = True, debugging information is printed.
    """
    _FIM = {"no-scramble": NoFIM, "one_span": SingleSpanFIM, "multiple_span": MultipleSpanFIM}
    _POSIDS = {"none", "1d", "2d"}

    def __init__(self, filename="encoded_MSAs_train.pkl",
                 filepath="/nvme1/common/OpenProteinSet/",
                 sample=False,
                 max_msa_len=-1,
                 reverse=False,
                 seed=42,
                 troubleshoot=False,
                 fim_strategy="no-scramble",
                 max_patches=5,
                 mask_fraction=0.2,
                 always_mask=False,
                 max_position_embeddings=2048,
                 max_seq_position_embeddings=512,
                 add_position_ids="none", ):
        """Load the pickled clusters (if any) and build the FIM helper.

        Raises:
            ValueError: if `fim_strategy` is not a key of `_FIM`.
        """
        np.random.seed(seed)
        self.path = filepath
        if filename:
            # NOTE(review): pickle can execute arbitrary code while loading;
            # only point this at trusted files.
            # The context manager closes the file handle once loading is done.
            with open(self.path + filename, "rb") as handle:
                self.dataset = pickle.load(handle)
            self.cluster_names = list(self.dataset.keys())
        else:
            self.dataset = None
            self.cluster_names = []
        self.sample = sample
        self.max_msa_len = max_msa_len
        self.reverse = reverse
        self.fim_strategy = fim_strategy
        if fim_strategy not in Uniclust30_Dataset._FIM:
            raise ValueError(f'Fill in the middle stragy "{fim_strategy}" not recognized.')
        self.fim = Uniclust30_Dataset._FIM[fim_strategy](max_patches=max_patches,
                                                         mask_fraction=mask_fraction,
                                                         always_mask=always_mask,
                                                         add_position_ids=add_position_ids != "none",
                                                         troubleshoot=troubleshoot)
        self.max_position_embeddings = max_position_embeddings
        self.max_seq_position_embeddings = max_seq_position_embeddings
        self.add_position_ids = add_position_ids

        self.troubleshoot = troubleshoot

    def __len__(self):
        """Return the number of clusters."""
        return len(self.cluster_names)

    def __getitem__(self, idx):
        """Build the training example for the cluster with index `idx`.

        Returns a dict with `input_ids` and `labels` (the same tensor), plus
        `position_ids` when `add_position_ids` is "1d" or "2d" and
        `seq_position_ids` when it is "2d".
        """
        # get all the sequences in the cluster
        sequences = self.get_sequences(idx)
        # get total number of sequences in the cluster and choose how many to sample
        orig_num_sequences = len(self.get_index_start_of_sequences(sequences))
        num_sequences = np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences
        # sample the sequences
        sequences, position_ids = self.sample_sequences(sequences, num_sequences)
        # With probability 0.5, reverse the sequences and move the last token to the front.
        # This must be an `if` statement: written as `a, b = f() if cond else a, b`
        # the conditional only covers the first tuple element, so `sequences` would
        # receive the whole (sequences, position_ids) pair returned by
        # `reverse_sequences` and `position_ids` would never be reversed.
        if self.reverse and np.random.rand() > 0.5:
            sequences, position_ids = self.reverse_sequences(sequences, position_ids)
        # limit the length of the MSA
        if self.max_msa_len > 0:
            sequences = sequences[:self.max_msa_len]
            if self.add_position_ids != "none":
                position_ids = position_ids[:self.max_msa_len]
        # convert to tensor
        sequences = torch.asarray(sequences, dtype=torch.int64)
        if self.add_position_ids != "none":
            # clamp so that the ids stay below `max_position_embeddings`
            position_ids = torch.asarray(position_ids, dtype=torch.int64).clamp(
                0, self.max_position_embeddings - 1)
        else:
            position_ids = None

        if self.troubleshoot:
            print(
                f"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. Total MSA length: {len(sequences)}")
        if self.add_position_ids == "1d":
            return dict(input_ids=sequences, position_ids=position_ids, labels=sequences)
        if self.add_position_ids == "2d":
            # running count of <cls> tokens, i.e. the index of the sequence each token belongs to
            seq_position_ids = (sequences == AA_TO_ID["<cls>"]).int().cumsum(-1).clamp(
                0, self.max_seq_position_embeddings - 1).contiguous()
            return dict(input_ids=sequences, position_ids=position_ids, seq_position_ids=seq_position_ids,
                        labels=sequences)
        return dict(input_ids=sequences, labels=sequences)

    def get_sequences(self, idx):
        """Get the sequences in the cluster with index `idx`."""
        cluster_name = self.cluster_names[idx]
        return self.dataset[cluster_name]

    def get_index_start_of_sequences(self, sequences):
        """Get the positions of the start of each sequence in the cluster.

        A sequence start is any position holding token 0
        (presumably the `<cls>` id -- TODO confirm against `AA_TO_ID`).
        """
        return np.where(sequences == 0)[0]

    def reverse_sequences(self, sequence, position_ids=None):
        """Reverse the sequences and move the last token to the front.

        `position_ids`, when given, get the same treatment.
        Returns the pair `(sequence, position_ids)`; the second element is None
        when no `position_ids` were passed.
        """
        sequence = sequence[::-1]
        rotated_sequence = np.concatenate([sequence[-1:], sequence[:-1]])
        if position_ids is None:
            return rotated_sequence, None
        position_ids = position_ids[::-1]
        return rotated_sequence, np.concatenate([position_ids[-1:], position_ids[:-1]])

    def sample_sequences(self, sequences, num_sequences):
        """Sample `num_sequences` from the sequences in the cluster.

        Returns whatever `self.fim.apply` returns for the selected
        (start, end) spans, unpacked as `(sequences, position_ids)`.
        """
        L = len(sequences)
        # get the indexes of the start of each sequence
        inds = self.get_index_start_of_sequences(sequences)
        # check that there are sequences in the cluster and that there are enough of them
        assert len(inds) > 0, "No sequences found in cluster."
        assert len(inds) >= num_sequences, "Not enough sequences in cluster."
        # sample n_sequences randomly from the sequences
        which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
        # get the tuples of start and end indexes of the sequences (the last sequence ends at L)
        tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs]
        if self.troubleshoot:
            print(f"Sampled sequences: {tuples}")
        # concatenate the sequences
        sequences, position_ids = self.fim.apply(sequences, tuples)
        return sequences, position_ids


def make_dataloader(dataset):
    """Wrap `dataset` in a `DataLoader` created with default settings."""
    return DataLoader(dataset)

In [None]:
#| export

@dataclass
class DataCollatorForUniclust30Dataset(object):
    """
    Turn a list of dataset items into one right-padded batch.

    `input_ids` are padded with the `<pad>` token id, `labels` with -100 and the
    optional `position_ids` / `seq_position_ids` with 0. The batch always carries an
    `attention_mask` that is True wherever `input_ids` differs from the `<pad>` id.
    """

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = AA_TO_ID["<pad>"]
        pad = torch.nn.utils.rnn.pad_sequence

        # NOTE: labels are built from each item's "input_ids" entry, not from "labels".
        token_rows = [item["input_ids"] for item in instances]
        input_ids = pad(token_rows, batch_first=True, padding_value=pad_id)
        labels = pad(token_rows, batch_first=True, padding_value=-100)
        batch = dict(input_ids=input_ids, labels=labels)

        # The first item decides which optional fields the whole batch carries.
        first = instances[0]
        if "position_ids" in first:
            batch["position_ids"] = pad(
                [item["position_ids"] for item in instances],
                batch_first=True, padding_value=0)
            # seq_position_ids are only emitted together with position_ids
            if "seq_position_ids" in first:
                batch["seq_position_ids"] = pad(
                    [item["seq_position_ids"] for item in instances],
                    batch_first=True, padding_value=0)

        batch["attention_mask"] = input_ids.ne(pad_id)
        return batch

In [None]:
import random
from protmamba.utils import MASK_TO_ID, AA_TO_ID

import cProfile
import pstats
import io

def profile(func):
    """Decorator that profiles every call of `func` with cProfile.

    After a successful call the statistics, sorted by cumulative time, are
    printed to stdout and the return value of `func` is passed through.
    If `func` raises, the profiler is still switched off and the exception
    propagates without any statistics being printed.
    """
    import functools  # local import: not among the imports of this cell

    @functools.wraps(func)  # keep the name/docstring of the wrapped function
    def wrapper(*args, **kwargs):
        pr = cProfile.Profile()
        pr.enable()
        try:
            retval = func(*args, **kwargs)
        finally:
            # never leave the profiler running, even when `func` raises
            pr.disable()
        s = io.StringIO()
        sortby = 'cumulative'
        ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
        ps.print_stats()
        print(s.getvalue())
        return retval
    return wrapper

class ConcatenateSequences:
    """Concatenate selected sequences of a cluster, optionally rearranging each one with mask tokens."""

    def __init__(self, 
                 max_patches=5,
                 mask_fraction=0.2,
                 scrambling_strategy="",
                 mask_tokens=MASK_TO_ID,
                 eos_token=AA_TO_ID["<eos>"],
                 troubleshoot=False):
        """
        This class is designed to concatenate sequences based on different scrambling strategies.

        Args:
            max_patches: upper limit on the number of masked patches per sequence ("inpaint" strategy).
            mask_fraction: fraction of the available length that bounds the length of a patch.
            scrambling_strategy: "no-scramble", "OpenAI" or "inpaint" (checked in `concatenate`).
            mask_tokens: mapping with keys "<mask-1>", "<mask-2>", ... giving the mask token ids.
            eos_token: token appended after the unmasked part of a sequence that has masked patches.
            troubleshoot: if True, print the sampled patches.
        """
        self.troubleshoot = troubleshoot
        self.max_patches = max_patches
        self.scrambling_strategy = scrambling_strategy
        self.mask_fraction = mask_fraction
        self.mask_tokens = mask_tokens
        assert len(self.mask_tokens)>=self.max_patches, "Number of mask tokens must be bigger than max number of patches."
        self.eos_token = eos_token

    def concatenate(self, sequences, tuples):
        """
        This function concatenates the sequences based on the scrambling strategy.

        Args:
            sequences: flat array with the tokens of all the sequences of a cluster.
            tuples: list of (start, end) indices, one per sequence to include.
        Returns:
            The concatenation of the (possibly rearranged) selected sequences.
        Raises:
            ValueError: if `self.scrambling_strategy` is not one of the known strategies
                (previously this case silently returned None).
        """
        strategy = self.scrambling_strategy
        if strategy == "no-scramble":
            return np.concatenate([sequences[slice(t[0], t[1])] for t in tuples])
        # We could remove this, same as max_patches = 1
        if strategy == "OpenAI":
            return np.concatenate([self.create_and_concatenate_parts_openAI(sequences, t) for t in tuples])
        if strategy == "inpaint":
            return np.concatenate([self.create_and_concatenate_parts_inpaint(sequences, t) for t in tuples])
        raise ValueError(f'Scrambling strategy "{strategy}" not recognized.')

    def split_sequences(self, sequences, t, masked_tuples):
        """
        This function splits the sequences into unmasked and masked parts based on the given tuples.
        Args:
            sequences: flat array with the tokens of all the sequences of a cluster.
            t (tuple): The start and end index of the sequence.
            masked_tuples (list): A list of (start, end) tuples specifying the masked regions,
                in increasing order.
        Returns:
            unmasked_parts (list): The tokens outside the masked regions, with the token
                "<mask-i>" standing in for the i-th region, followed by the eos token
                when there is at least one masked region.
            masked_parts (list): For every region, the token "<mask-i>" followed by
                the tokens of that region.
        """
        start, end = t
        unmasked_parts, masked_parts = [], []
        for i, region in enumerate(masked_tuples):
            mask_token = self.mask_tokens[f"<mask-{i+1}>"]
            # tokens between the previous region (or the sequence start) and this one
            unmasked_parts.extend(sequences[slice(start, region[0])])
            unmasked_parts.append(mask_token)
            masked_parts.append(mask_token)
            masked_parts.extend(sequences[slice(region[0], region[1])])
            start = region[1]
        # remaining tokens after the last region (the whole sequence if nothing is masked)
        unmasked_parts.extend(sequences[slice(start, end)])
        if len(masked_tuples) > 0:
            unmasked_parts.append(self.eos_token)
        return unmasked_parts, masked_parts
   
    def sample_lengths(self, start,end):
        """
        Sample a length uniformly from 1 to max_L*self.mask_fraction (must be bigger than 1).
        If the length is larger than max_L, return max_L.
        """
        max_L = end-start
        upper = max(int(max_L*self.mask_fraction), 2)
        length = 1+int(random.random() * (upper-1))
        return min(length, max_L)

    def create_and_concatenate_parts_openAI(self, sequences, t):
        """
        This function creates and concatenates parts of the sequences based on the OpenAI scrambling strategy.
        It randomly selects two indices within the range of the given tuple,
        splits the sequence into three parts based on these indices, and then concatenates them with the 
        masked patch at the end: part1, <mask-1>, part3, <mask-1>, part2.
        The two indices are drawn without replacement from `t[0]+1 .. t[1]-1`, so
        `np.random.choice` raises a ValueError if fewer than two such positions exist.
        """
        new_tuple = tuple(np.sort(np.random.choice(np.arange(t[0]+1, t[1]), 2, replace=False)))
        mask_token = self.mask_tokens["<mask-1>"]
        part1 = sequences[t[0]:new_tuple[0]]
        part2 = sequences[new_tuple[0]:new_tuple[1]]
        part3 = sequences[new_tuple[1]:t[1]]
        return np.concatenate([part1, [mask_token], part3, [mask_token], part2])

    def create_and_concatenate_parts_inpaint(self, sequences, t):
        """
        This function creates and concatenates parts of the sequences based on the inpaint scrambling strategy.
        It samples a number of patches (at most `self.max_patches`), a starting point for each of them
        and a length for each of them, splits the sequence into unmasked and masked parts accordingly,
        and then concatenates them.
        The concatenation is done by joining all unmasked parts (interleaved with mask tokens) and afterwards
        all masked parts (each preceded by its mask token). At the end of the unmasked parts the eos token
        is added whenever at least one patch is masked.
        Returns a list of tokens.
        """
        # sample num_patches from a discrete poisson distribution with upper limit max_patches
        # (rejection sampling: draw again until the value is within the limit)
        while True:
            num_patches = np.random.poisson(1)
            if num_patches <= self.max_patches:
                break
        # A patch can start at any position after the first token of the sequence; short
        # sequences may offer fewer start positions than sampled patches, in which case
        # `random.sample` would raise, so the number of patches is capped.
        num_patches = min(num_patches, max(t[1] - t[0] - 1, 0))
        
        # sample num_patches starting points for the masked positions (+ final position)
        start_patches = sorted(random.sample(range(t[0]+1, t[1]), num_patches)) + [t[1]]
        
        # sample num_patches lengths of the patches; each patch ends at or before the next starting point
        len_patches = [self.sample_lengths(start_patches[i],start_patches[i+1])
                       for i in range(len(start_patches)-1)]
        
        # create masked tuples with start and end indices of the patches
        masked_tuples = [(start_patches[i], start_patches[i]+len_patches[i]) for i in range(len(start_patches)-1)]
        # split the sequences into unmasked and masked parts
        unmasked_sequence, masked_sequence = self.split_sequences(sequences, t, masked_tuples)
        if self.troubleshoot:
            print(f"For sequence in {t}: sampled {num_patches=}, {start_patches=}, {len_patches=}, {masked_tuples=}")
        # concatenate the unmasked and masked parts
        return unmasked_sequence + masked_sequence

# Make dataset
class Uniclust30_Dataset_old(Dataset):
    """
        Older dataset class used to import the encoded Uniclust30 clusters.

        `filename` is appended to `filepath` and loaded with pickle; the loaded object is used as a
        mapping from cluster name to a flat array of tokens.
        If `sample` = True, it will sample a random number of sequences (at least one) from each cluster.
        If `sample` = False, it will use all the sequences from each cluster (in random order).
        To limit the length of the MSAs, set `max_msa_len` to a positive integer.
        If `reverse` = True, it will reverse the sequences with probability 0.5 and move the last token to the front.
        `seed` seeds the global NumPy random generator.
        `fim_strategy` is passed to `ConcatenateSequences` as its `scrambling_strategy`:
        "no-scramble" simply concatenates the sequences, "OpenAI" and "inpaint" rearrange each
        sequence around mask tokens. `max_patches` and `mask_fraction` are forwarded as well.
        If `troubleshoot` = True, debugging information is printed by this class.
    """
    def __init__(self, filename="encoded_MSAs_train.pkl",
                 filepath="/nvme1/common/OpenProteinSet/",
                 sample=False,
                 max_msa_len=-1,
                 reverse=False,
                 seed=42,
                 troubleshoot=False,
                 fim_strategy="no-scramble",
                 max_patches=5,
                 mask_fraction=0.2):
        """Load the pickled clusters and build the sequence concatenator."""
        np.random.seed(seed)
        self.path = filepath
        # NOTE(review): pickle can execute arbitrary code while loading;
        # only point this at trusted files.
        # The context manager closes the file handle once loading is done.
        with open(self.path + filename, "rb") as handle:
            self.dataset = pickle.load(handle)
        self.cluster_names = list(self.dataset.keys())
        self.sample = sample
        self.max_msa_len = max_msa_len
        self.reverse = reverse
        self.Concatenate = ConcatenateSequences(max_patches=max_patches,
                                                mask_fraction=mask_fraction,
                                                scrambling_strategy=fim_strategy)
        self.troubleshoot = troubleshoot

    def __len__(self):
        """Return the number of clusters."""
        return len(self.cluster_names)

    def __getitem__(self, idx):
        """Build the training example for the cluster with index `idx`.

        Returns a dict whose `input_ids` and `labels` are the same int64 tensor.
        """
        # get all the sequences in the cluster
        sequences = self.get_sequences(idx)
        # get total number of sequences in the cluster and choose how many to sample
        orig_num_sequences = self.get_number_of_sequences(sequences)
        num_sequences = np.random.randint(1, orig_num_sequences+1) if self.sample else orig_num_sequences
        # sample the sequences
        sequences = self.sample_sequences(sequences, num_sequences)
        # with probability 0.5, reverse the sequences and move the last token to the front
        if self.reverse and np.random.rand() > 0.5:
            sequences = self.reverse_sequences(sequences)
        # limit the length of the MSA
        if self.max_msa_len > 0:
            sequences = sequences[:self.max_msa_len]
        # convert to tensor
        sequences = torch.asarray(sequences, dtype=torch.int64)
        if self.troubleshoot:
            print(f"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. Total MSA length: {len(sequences)}")
        return dict(input_ids=sequences, labels=sequences)
    
    def get_sequences(self, idx):
        """Get the sequences in the cluster with index `idx`."""
        cluster_name = self.cluster_names[idx]
        return self.dataset[cluster_name]
       
    def get_index_start_of_sequences(self, sequences):
        """Get the positions of the start of each sequence in the cluster.

        A sequence start is any position holding token 0.
        """
        return np.where(sequences == 0)[0]

    def get_number_of_sequences(self, sequences):
        """Get the number of sequences in the cluster."""
        return len(self.get_index_start_of_sequences(sequences))
    
    def reverse_sequences(self, sequence):
        """Reverse the sequences and move the last token to the front."""
        sequence = sequence[::-1]
        return np.concatenate([sequence[-1:], sequence[:-1]])
    
    # @profile
    def sample_sequences(self, sequences, num_sequences):
        """Sample `num_sequences` from the sequences in the cluster and concatenate them."""
        L = len(sequences)
        # get the indexes of the start of each sequence
        inds = self.get_index_start_of_sequences(sequences)
        # check that there are sequences in the cluster and that there are enough of them
        assert len(inds) > 0, "No sequences found in cluster."
        assert len(inds) >= num_sequences, "Not enough sequences in cluster."
        # sample n_sequences randomly from the sequences
        which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
        # get the tuples of start and end indexes of the sequences (the last sequence ends at L)
        tuples = [(inds[i],inds[i+1]) if i<len(inds)-1 else (inds[i], L) for i in which_seqs]
        if self.troubleshoot:
            print(f"Sampled sequences: {tuples}")
        # concatenate the sequences
        return self.Concatenate.concatenate(sequences, tuples)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()