In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import tiktoken

In [8]:
class SpamDataset(Dataset):
    """Map-style dataset of tokenized, fixed-length texts read from a CSV file.

    The CSV must provide a ``"Text"`` column (tokenized eagerly in
    ``__init__``) and a ``"Label"`` column (read in ``__getitem__``).
    Every encoded text is truncated to ``max_length`` and then right-padded
    with ``pad_token_id`` so that all sequences share the same length.

    Args:
        csv_file: Path or buffer accepted by ``pandas.read_csv``.
        tokenizer: Object exposing ``encode(text)`` that returns a list of
            token ids.
        max_length: Target sequence length. ``None`` (the default) uses the
            length of the longest encoded text found in ``csv_file``.
        pad_token_id: Token id appended as padding (default 50256).
            NOTE(review): presumably the GPT-2 end-of-text id, matching the
            "gpt2" encoding used in this notebook; confirm if another
            tokenizer is plugged in.

    Raises:
        ValueError: If ``max_length`` is negative.
    """

    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        self.encoded_texts = [self._encode_text(text) for text in self.data["Text"]]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        elif max_length < 0:
            # A negative length would slice tokens off the *end* of each text
            # and produce ragged sequences that only fail later, at batching.
            raise ValueError(f"max_length must be non-negative, got {max_length}")
        else:
            self.max_length = max_length

        # Truncate first, then pad, so every sequence has exactly max_length ids.
        self.encoded_texts = [
            self._pad_encoded_text(encoded[:self.max_length], self.max_length, pad_token_id)
            for encoded in self.encoded_texts
        ]

    def __getitem__(self, idx):
        """Return ``(token_ids, label)`` for row ``idx`` as ``torch.long`` tensors."""
        encoded = self.encoded_texts[idx]
        # Index the column first: avoids materialising a whole row Series
        # on every access while yielding the same label value.
        label = self.data["Label"].iloc[idx]
        return (torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long))

    def __len__(self):
        """Return the number of rows loaded from the CSV file."""
        return len(self.data)

    def _longest_encoded_length(self):
        """Return the length of the longest encoded text (0 if there are none)."""
        return max((len(encoded_text) for encoded_text in self.encoded_texts), default=0)

    def _pad_encoded_text(self, encoded_text, max_length, pad_token_id):
        """Return ``encoded_text`` right-padded with ``pad_token_id`` up to ``max_length``.

        Texts already at least ``max_length`` long are returned unpadded
        (a non-positive repeat count yields an empty pad list).
        """
        return encoded_text + [pad_token_id] * (max_length - len(encoded_text))

    def _encode_text(self, text):
        """Encode a single text with the dataset's tokenizer."""
        return self.tokenizer.encode(text)
    

In [10]:
# Load tiktoken's "gpt2" encoding; its encode() output feeds the dataset.
tokenizer = tiktoken.get_encoding("gpt2")

# max_length=None lets the dataset pad to the longest encoded text in train.csv.
train_dataset = SpamDataset(
    csv_file="train.csv",
    tokenizer=tokenizer,
    max_length=None,
)

# Show the sequence length the dataset settled on.
train_dataset.max_length

120