# Imports

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import random
import time
from functools import partial
import json

# Sampler

## Basic sampler

In [2]:
def basic_rand_sampler(seq, sample_len):
    """
    Basic random text sampler.
    If sample_len is greater than the length of the seq, the seq is returned.
    """
    seq_len   = len(seq)
    if seq_len > sample_len:
        start_idx = random.randint(0, min(seq_len,seq_len - sample_len))
        end_idx   = start_idx+sample_len
        return seq[start_idx:end_idx]
    else:
        return seq

In [3]:
text = "ABC DEF GHI JKL!"

In [4]:
[basic_rand_sampler(text, 8) for i in range(6)]

['F GHI JK', 'C DEF GH', 'ABC DEF ', 'C DEF GH', 'GHI JKL!', 'F GHI JK']

In [5]:
[basic_rand_sampler(text, 200) for i in range(3)]

['ABC DEF GHI JKL!', 'ABC DEF GHI JKL!', 'ABC DEF GHI JKL!']

## Identity sampler

In [6]:
identity_sampler = lambda x: x

In [7]:
assert text == identity_sampler(text)

# Tokenizer

## Basic aminoacid tokenizer

In [8]:
def basic_aa_tokenizer(seq, context_length, return_mask=True):
    """
    Maps a number between 0 and 21 to each 21 proteogenic aminoacids.
    Unknown char input gets mapped to 22.
    """
    aa = "ACDEFGHIKLMNOPQRSTUVWY"
    d = {a: i for i, a in enumerate(aa)}
    seq_len = len(seq)
    seq_empty = torch.zeros(context_length - len(seq))
    seq_tok   = torch.tensor([d[a] if a in aa else 22 for a in seq])
    seq = torch.cat([seq_tok, seq_empty], dim=0)#.unsqueeze(0)
    if return_mask:
        mask = torch.zeros_like(seq).bool()
        mask[0:seq_len+1] = True
        return seq, mask
    else:
        return seq

In [9]:
aa_seq = "ACDEFGHIKLMNOPQRSTUVWYZZZ"

In [10]:
basic_aa_tokenizer(aa_seq, 30)

(tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15., 16., 17., 18., 19., 20., 21., 22., 22., 22.,  0.,  0.,  0.,
          0.,  0.]),
 tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True, False, False, False, False]))

## Text tokenizer

In [11]:
from simple_tokenizer import tokenize

In [12]:
tokenize(text, context_length=30, return_mask=True)

(tensor([[ 5334, 11649, 22279,    73, 14134,   256,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),
 tensor([[ True,  True,  True,  True,  True,  True, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False]]))

In [13]:
t, m = tokenize(text, context_length=30, return_mask=True)

# Dataset

In [33]:
class CLASPDataset(Dataset):
    """
    Basic CLASP dataset that loads the preprocessed csv file into RAM.
        path: path to the csv file
    """
    def __init__(self, path, text_sampler, bioseq_sampler, text_tok, bioseq_tok):
        super().__init__()
        
        self.path = path

        tp = time.time()
        with open(path, "r") as reader:
            self.data = reader.readlines()
        print(f"Load data time: {time.time() - tp:.3f} s")

        self.cols = self.data.pop(0).split(",")
        self.len  = len(self.data)

        self.text_sampler   = text_sampler
        self.bioseq_sampler = bioseq_sampler

        self.text_tok   = text_tok
        self.bioseq_tok = bioseq_tok

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        sample = self.data[idx][:-2] # without "\n"
        sample = sample.split(",")
        sample = [x for x in sample if len(x) > 0]

        text   = " ".join(sample[:-2])
        bioseq = sample[-1]

        text   = self.text_sampler(text)
        bioseq = self.bioseq_sampler(bioseq)
        
#        print(text)
#        print(bioseq)

        text, text_mask = self.text_tok(text)
        bioseq, bioseq_mask = self.bioseq_tok(bioseq)

        return text, text_mask, bioseq, bioseq_mask

In [34]:
str_sampler = partial(basic_rand_sampler, sample_len=100)
text_tok    = partial(tokenize, context_length=120, return_mask=True)
bioseq_tok  = partial(basic_aa_tokenizer, context_length=120, return_mask=True)

In [35]:
ds = CLASPDataset(path="uniprot_100_reduced.csv",
                  text_sampler=str_sampler,
                  bioseq_sampler=str_sampler,
                  text_tok=text_tok,
                  bioseq_tok=bioseq_tok)

Load data time: 0.002 s


In [36]:
ds[0]

(tensor([[  331, 17650,  8729,  1095,  7549, 17015,  3683,   537, 14909,  6321,
          18001,   539,  8368,  1160,   675,  9460,  8741,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),
 tensor([[ True,  True,  True,  True,

# Dataloader

In [21]:
dl = DataLoader(ds, 32)

In [22]:
batch = next(iter(dl))

In [23]:
[b.shape for b in batch]

[torch.Size([32, 1, 120]),
 torch.Size([32, 1, 120]),
 torch.Size([32, 120]),
 torch.Size([32, 120])]

# CLASPRankSplitDataset

In [14]:
path_offset_dict = '/home/mmp/hdd1/ProTexCLIP/uniprot_sprot_offset_dict.json'

In [15]:
with open(path_offset_dict, "r", encoding='utf-8') as data_file:    
    offset_dict = json.load(data_file)

In [16]:
len(offset_dict.keys())

564278

In [17]:
file_path = "/home/mmp/hdd1/ProTexCLIP/uniprot_sprot.csv"

In [18]:
class RankSplitDataset(Dataset):
    def __init__(self, file_path, offset_dict, rank, world_size):
        self.file_path        = file_path
        self.offset_dict      = offset_dict
        self.total_len        = len(offset_dict.keys())
        self.rank_len         = self.total_len // world_size
        self.rank_line_offset = self.rank_len * rank
        self.rank_byte_offset = self.offset_dict[str(self.rank_line_offset)] # because json keys are strings after it is saved
        
        print(f"rank: {rank:<5}")
        print(f"total len: {self.total_len}")
        print(f"rank len: {self.rank_len}")
        print(f"rank line offset: {self.rank_line_offset}")
        print(f"rank byte offset: {self.rank_byte_offset}")
        
        tp = time.time()
        with open(self.file_path, 'r', encoding='utf-8') as f:
            f.seek(self.rank_byte_offset) # move to the line for the specific rank
            lines = []
            for i in range(self.rank_len): # load all the lines for the rank
                lines.append(f.readline())
        print(f"dataset load data time: {time.time() - tp:.3f} s")
        
        self.data = lines
        print(f"dataset len: {len(self.data)}")

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

In [19]:
class CLASPRankSplitDataset(RankSplitDataset):
    """
    CLASP rank split dataset that loads equally sized pieces for each rank
    of the preprocessed csv file into RAM.
        path: path to the csv file
    """
    def __init__(self, file_path, offset_dict, rank, world_size,
                 text_sampler, bioseq_sampler, text_tok, bioseq_tok):
        super().__init__(file_path, offset_dict, rank, world_size)
        
        self.text_sampler   = text_sampler
        self.bioseq_sampler = bioseq_sampler

        self.text_tok   = text_tok
        self.bioseq_tok = bioseq_tok

    def __getitem__(self, idx):
        sample = self.data[idx][:-2] # without "\n"
        sample = sample.split(",")
        sample = [x for x in sample if len(x) > 0]

        text   = " ".join(sample[:-2])
        bioseq = sample[-1]

        text   = self.text_sampler(text)
        bioseq = self.bioseq_sampler(bioseq)

        text, text_mask = self.text_tok(text)
        bioseq, bioseq_mask = self.bioseq_tok(bioseq)

        return text, text_mask, bioseq, bioseq_mask

In [20]:
str_sampler = partial(basic_rand_sampler, sample_len=100)
text_tok    = partial(tokenize, context_length=120, return_mask=True)
bioseq_tok  = partial(basic_aa_tokenizer, context_length=120, return_mask=True)

In [21]:
!free -h

              total        used        free      shared  buff/cache   available
Mem:           31Gi       9.4Gi        17Gi       0.0Ki       4.8Gi        21Gi
Swap:         979Mi       973Mi       6.0Mi


In [22]:
ds1 = CLASPRankSplitDataset(file_path=file_path,
                           offset_dict=offset_dict,
                           rank=0,
                           world_size=2,
                           text_sampler=str_sampler,
                           bioseq_sampler=str_sampler,
                           text_tok=text_tok,
                           bioseq_tok=bioseq_tok)

total len: 564278
rank len: 282139
rank line offset: 0
rank byte offset: 118
dataset load data time: 1.423 s
dataset len: 282139


In [23]:
!free -h

              total        used        free      shared  buff/cache   available
Mem:           31Gi        10Gi        15Gi       0.0Ki       4.8Gi        19Gi
Swap:         979Mi       973Mi       6.0Mi


In [24]:
ds2 = CLASPRankSplitDataset(file_path=file_path,
                           offset_dict=offset_dict,
                           rank=1,
                           world_size=2,
                           text_sampler=str_sampler,
                           bioseq_sampler=str_sampler,
                           text_tok=text_tok,
                           bioseq_tok=bioseq_tok)

total len: 564278
rank len: 282139
rank line offset: 282139
rank byte offset: 1520906478
dataset load data time: 1.221 s
dataset len: 282139


In [25]:
!free -h

              total        used        free      shared  buff/cache   available
Mem:           31Gi        12Gi        14Gi       0.0Ki       4.8Gi        18Gi
Swap:         979Mi       973Mi       6.0Mi


In [26]:
ds1[0] == ds2[0]

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [27]:
ds1[0]

(tensor([[  278,   271,   268,   279,   275,   263,   273,   271,   271,   275,
           1818,    12, 43499,  8093,   281, 48765, 26769,   269,   346, 10077,
            281,   271,   271,   271,   271,   274,   271,   276,    92,   269,
           3692,   331,   282,   551,   276,   275,   279,   275,   279,   275,
            282, 35004,   271,   280,   277,   277,   271,   269,   272,   282,
             12,   282,   326,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),
 tensor([[ True,  True,  True,  True,

In [28]:
ds2[0]

(tensor([[ 1735,   282,  1783,  4853,    78,  9429,  1160,   320,   281,   328,
            269, 18833, 17640,   537,   841,  2343,   324,  8902, 18650,   282,
          25693,  1029, 27751,   934,   269,   274,   277,   281,   273,   274,
            277,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),
 tensor([[ True,  True,  True,  True,

# End