DATALOADER

En este script se detalla el dataloader para preparar el training set. El training set de ejemplo está en el archivo train_old.csv.
De dicho archivo se utilizan las columnas 'mhc' y 'peptide', concatenadas, como input, y el target está compuesto por las columnas 'label' y 'masslabel'. También se utiliza la función collate_fn del dataset (pasada al DataLoader de PyTorch) para rellenar con padding los inputs de cada batch y asegurar que todos tengan el mismo tamaño.



In [20]:

from typing import Union, List, Tuple, Sequence, Dict, Any, Optional, Collection, Mapping
from pathlib import Path
from tape.tokenizers import TAPETokenizer
from tape.datasets import pad_sequences as tape_pad
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import torch

from torch.utils.data import DataLoader

TRAINSET

In [8]:
class CSVDataset(Dataset):
    """Map-style dataset of (MHC, peptide) pairs read from a CSV file or DataFrame.

    Each item's input sequence is the MHC sequence concatenated with the
    (possibly truncated) peptide sequence. The target of each row is built
    from the 'label' and 'masslabel' columns; a missing column is filled
    with NaN.

    Args:
        data_file: path to a CSV file, or an already-loaded DataFrame. A
            DataFrame is copied first, so the caller's object is never modified.
        max_pep_len: peptides longer than this keep only their first
            ``max_pep_len`` residues.
        train: when False, both target columns are replaced with NaN.

    Raises:
        ValueError: if ``train`` is True and neither 'label' nor 'masslabel'
            is present.
    """

    def __init__(self,
                 data_file: Union[str, Path, pd.DataFrame],
                 max_pep_len=30,
                 train: bool = True):
        if isinstance(data_file, pd.DataFrame):
            # Copy: the target columns below are added/overwritten in place,
            # which would otherwise silently mutate the caller's DataFrame.
            data = data_file.copy()
        else:
            data = pd.read_csv(data_file)
        self.mhc = data['mhc'].values
        self.peptide = data['peptide'].apply(lambda x: x[:max_pep_len]).values
        if not train:
            # Labels are not used outside training: blank them out.
            data['label'] = np.nan
            data['masslabel'] = np.nan
        if 'masslabel' not in data and 'label' not in data:
            raise ValueError("missing label.")
        if 'masslabel' not in data:
            data['masslabel'] = np.nan
        if 'label' not in data:
            data['label'] = np.nan

        # Target per row is [label, masslabel], stacked into an array of
        # shape (n_rows, 2). The original notes describe label as float and
        # masslabel as int; the resulting dtype follows NumPy type promotion.
        self.targets = np.stack([data['label'], data['masslabel']], axis=1)
        self.data = data
        if 'instance_weights' in data:
            self.instance_weights = data['instance_weights'].values
        else:
            # No per-row weights given: weight every row equally.
            self.instance_weights = np.ones(data.shape[0],)

    def __len__(self) -> int:
        return len(self.mhc)

    def __getitem__(self, index: int):
        """Return one sample as a dict; the model input is MHC + peptide concatenated."""
        seq = self.mhc[index] + self.peptide[index]
        return {
            "id": str(index),
            "primary": seq,
            "protein_length": len(seq),
            "targets": self.targets[index],
            "instance_weights": self.instance_weights[index]}

In [14]:
class BertDataset(Dataset):
    """Tokenized view over a CSVDataset, shaped for TAPE's pretrained BERT model.

    Args:
        input_file: anything CSVDataset accepts (CSV path or DataFrame).
        tokenizer: a TAPETokenizer instance, or a vocabulary name used to build one.
        max_pep_len: forwarded to CSVDataset as the peptide truncation length.
        in_memory: accepted for interface compatibility; not used by this class.
        instance_weight: when True, items and batches also carry 'instance_weights'.
        train: forwarded to CSVDataset.
    """

    def __init__(self,
                 input_file,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 max_pep_len=30,
                 in_memory: bool = False,
                 instance_weight: bool = False,
                 train: bool = True):
        self.tokenizer = (TAPETokenizer(vocab=tokenizer)
                          if isinstance(tokenizer, str) else tokenizer)
        self.data = CSVDataset(input_file, max_pep_len=max_pep_len, train=train)
        self.instance_weight = instance_weight

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        """Tokenize one MHC+peptide sequence; the mask is all ones (no padding yet)."""
        record = self.data[index]
        ids = self.tokenizer.encode(record['primary'])
        sample = dict(input_ids=ids,
                      input_mask=np.ones_like(ids),
                      targets=record['targets'])
        if self.instance_weight:
            sample['instance_weights'] = record['instance_weights']
        return sample

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Merge a list of samples into a batch of tensors.

        Token ids and masks are right-padded with 0 to the longest sample via
        tape's pad_sequences; targets become a float32 tensor.
        """
        # Transpose list-of-dicts into dict-of-lists, keyed by the first sample's fields.
        fields = {name: [sample[name] for sample in batch] for name in batch[0]}
        out = {
            'input_ids': torch.from_numpy(tape_pad(fields['input_ids'], 0)),
            'input_mask': torch.from_numpy(tape_pad(fields['input_mask'], 0)),
            'targets': torch.tensor(np.array(fields['targets']), dtype=torch.float32),
        }
        if self.instance_weight:
            out['instance_weights'] = torch.tensor(fields['instance_weights'],
                                                   dtype=torch.float32)
        return out


In [19]:
# Load the example training set: peptides truncated to 24 residues, no per-sample weights.
trainset = BertDataset(
    '../tests/data/train_old.csv',
    max_pep_len=24,
    instance_weight=False,
)
first_sample = trainset[0]
# input_ids: tokenizer vocabulary indices (not a one-hot encoding);
# input_mask: all ones for a single, unpadded sample; targets: [label, masslabel].
print(*(first_sample[name] for name in ('input_ids', 'input_mask', 'targets')), sep='\n')

[ 2  9 22 11 22 12  7  9 22 28 13 23 21 23 12 16 13 23 28  8 12 25 11 22
 19 12 13 23 27  8 27 26 17 12 22 13 12 15 25 13 12 21 13 21 23 15 13 11
 20  9  3]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0.75609 1.     ]


TRAINDATA

In [26]:
# Batch the dataset; trainset.collate_fn pads each batch's token ids and masks
# to a common length. shuffle is left at its default (False), so the first
# batch is simply the first 4 samples.
train_data = DataLoader(trainset,
                        batch_size=4,
                        num_workers=16,
                        # Pinned host memory only speeds up host-to-GPU copies;
                        # without CUDA, PyTorch warns and ignores the flag.
                        pin_memory=torch.cuda.is_available(),
                        collate_fn=trainset.collate_fn)


print(next(iter(train_data)))  # fetch and show the first batch

{'input_ids': tensor([[ 2,  9, 22, 11, 22, 12,  7,  9, 22, 28, 13, 23, 21, 23, 12, 16, 13, 23,
         28,  8, 12, 25, 11, 22, 19, 12, 13, 23, 27,  8, 27, 26, 17, 12, 22, 13,
         12, 15, 25, 13, 12, 21, 13, 21, 23, 15, 13, 11, 20,  9,  3,  0,  0,  0,
          0,  0],
        [ 2,  9, 22, 11, 22, 12,  7,  9, 22, 28, 13, 23, 21, 23, 12, 16, 13, 23,
         28,  8, 12, 25, 11, 22, 19, 12, 13, 23, 27,  8, 27, 26, 17, 12, 22,  5,
          5, 15, 15, 25, 25,  5, 25, 11, 15, 21, 25, 25,  7,  5, 14, 28,  5, 15,
          5,  3],
        [ 2,  9, 22, 11, 22, 12,  7,  9, 22, 28, 13, 23, 21, 23, 12, 16, 13, 23,
         28,  8, 12, 25, 11, 22, 19, 12, 13, 23, 27,  8, 27, 26, 17, 12, 22, 15,
         22, 25, 23,  9, 20, 22,  9, 10, 28, 10, 19, 21,  5, 19,  3,  0,  0,  0,
          0,  0],
        [ 2,  9, 22, 11, 22, 12,  7,  9, 22, 28, 13, 23, 21, 23, 12, 16, 13, 23,
         28,  8, 12, 25, 11, 22, 19, 12, 13, 23, 27,  8, 27, 26, 17, 12, 22, 10,
         10, 23,  9, 15,  8, 11, 25, 21, 