In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset

class AntigenTCRDataset(Dataset):
    def __init__(self, csv_file, total_max_length, CLS='[CLS]', SEP='[SEP]', PAD='[PAD]', UNK='[UNK]'):
        # Read the CSV file
        self.data = pd.read_csv(csv_file)
        
        # Special tokens
        self.CLS = CLS
        self.SEP = SEP
        self.PAD = PAD
        self.UNK = UNK

        # Encoding dictionary for amino acids and special tokens
        self.encodings = {
            f"{self.PAD}": 0, f"{self.CLS}": 1, f"{self.SEP}": 2, "[MASK]": 3,
            f"{self.UNK}": 4, "L": 5, "W": 6, "H": 7, "N": 8, "R": 9, "S": 10,
            "M": 11, "D": 12, "A": 13, "Q": 14, "C": 15, "F": 16, "V": 17,
            "K": 18, "G": 19, "I": 20, "E": 21, "Y": 22, "P": 23, "T": 24
        }

        # Process and pad the sequences
        self.data['combined_seqs'] = self.data.apply(lambda row: self.pad_combined_sequences(row['antigen'], row['TCR'], total_max_length), axis=1)

        # Convert to input tokens and attention masks
        self.input_tokens = [self.sequence_to_input_tokens(seq, self.encodings) for seq in self.data['combined_seqs']]
        self.attention_masks = [self.create_attention_mask(tokens, self.encodings) for tokens in self.input_tokens]

    def __len__(self):
        return len(self.data)

    def separate_aa(self, sequence):
        return ' '.join(sequence)

    def pad_combined_sequences(self, antigen_sequence, tcr_sequence, total_max_length):
        # Separate amino acids in each sequence
        separated_antigen = self.separate_aa(antigen_sequence)
        separated_tcr = self.separate_aa(tcr_sequence)

        # Combine sequences
        combined = f'{self.CLS} {separated_antigen} {self.SEP} {separated_tcr}'
        combined_length = len(combined.replace(' ', ''))  # Count characters excluding spaces

        # Calculate the needed padding and apply it
        padding_length = total_max_length - combined_length
        padding = ' '.join([self.PAD] * padding_length)
        if padding:  # Add a leading space if padding is not empty
            padding = ' ' + padding

        combined += padding
        return combined

    def sequence_to_input_tokens(self, sequence, encodings):
        return [encodings.get(elem, encodings[self.UNK]) for elem in sequence.split()]

    def create_attention_mask(self, input_tokens, encodings):
        return [1 if token != encodings[self.PAD] else 0 for token in input_tokens]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = {
            'antigen': self.data.iloc[idx]['antigen'],
            'tcr': self.data.iloc[idx]['TCR'],
            'combined_sequences': self.data.iloc[idx]['combined_seqs'],
            'input_ids': torch.tensor(self.input_tokens[idx]),
            'attention_mask': torch.tensor(self.attention_masks[idx])
        }

        return sample

# Example usage
# Note: You need to specify 'total_max_length' based on your data
dataset = AntigenTCRDataset('../data/data_balanced.csv', total_max_length=50)


In [2]:
# {f"{self.PAD}": 0, f"{self.CLS}": 1, f"{self.SEP}": 2, "[MASK]": 3, 
#                             f"{self.UNK}": 4, "L": 5, "W": 6, "H": 7, "N": 8, "R": 9, "S": 10, 
#                             "M": 11, "D": 12, "A": 13, "Q": 14, "C": 15, "F": 16, "V": 17, 
#                             "K": 18, "G": 19, "I": 20, "E": 21, "Y": 22, "P": 23, "T": 24}

In [3]:
dataset[0]

{'antigen': 'LLWNGPMAV',
 'tcr': 'CASSPIGGATDTQYF',
 'combined_sequences': '[CLS] L L W N G P M A V [SEP] C A S S P I G G A T D T Q Y F [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 'input_ids': tensor([ 1,  5,  5,  6,  8, 19, 23, 11, 13, 17,  2, 15, 13, 10, 10, 23, 20, 19,
         19, 13, 24, 12, 24, 14, 22, 16,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}

In [4]:
for i in range(10):
    print(f"input_ids: {len(dataset[i]['input_ids'])}, attention_mask: {len(dataset[i]['attention_mask'])}")

input_ids: 42, attention_mask: 42
input_ids: 42, attention_mask: 42
input_ids: 42, attention_mask: 42
input_ids: 42, attention_mask: 42
input_ids: 42, attention_mask: 42
input_ids: 42, attention_mask: 42
input_ids: 42, attention_mask: 42
input_ids: 42, attention_mask: 42
input_ids: 42, attention_mask: 42
input_ids: 42, attention_mask: 42
