In [94]:
from transformers import GPT2Tokenizer
from torch.utils.data import DataLoader
import torch
# Dedicated padding / unknown / sequence-boundary tokens registered on top of
# the pretrained GPT-2 vocabulary.
# NOTE(review): these extend the vocabulary, so any model paired with this
# tokenizer presumably needs its embedding table resized — confirm where the
# model is built.
special_tokens = {
    "pad_token": "[PAD]",
    "unk_token": "[UNK]",
    "bos_token": "[BOS]",
    "eos_token": "[EOS]",
}

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Kept as the cell's final expression so the notebook still displays the
# call's return value.
tokenizer.add_special_tokens(special_tokens)

4

In [96]:
from datasets import load_dataset

# Open the OpenWebText training split in streaming mode (streaming=True), so
# the result is consumed as an iterable rather than as an indexable dataset.
dataset = load_dataset(
    "Skylion007/openwebtext",
    split="train",
    streaming=True,
)

In [97]:
def process_batch(batch, max_length=512):
    """Tokenize a batch of raw texts into fixed-length model inputs.

    Each text is wrapped in the tokenizer's BOS/EOS tokens, then truncated
    and padded to exactly ``max_length`` token ids.

    Args:
        batch: Mapping with a ``"text"`` entry holding a list of strings.
        max_length: Length every encoded sequence is truncated/padded to.

    Returns:
        Dict with two parallel lists, one entry per input text:
        ``"input_ids"`` holds 1-D tensors of ``max_length`` token ids, and
        ``"attention_mask"`` holds 1-D int32 tensors that are 1 wherever the
        token id differs from the pad id and 0 elsewhere.

    Note:
        Every example is a 1-D ``(max_length,)`` tensor. Encoding one string
        at a time with ``return_tensors="pt"`` yields ``(1, max_length)``
        tensors instead, which stack into ``(B, 1, L)`` batches.

        NOTE(review): with ``truncation=True`` an over-long text presumably
        loses its trailing EOS token — confirm that is acceptable.
    """
    texts = [
        tokenizer.bos_token + text + tokenizer.eos_token
        for text in batch["text"]
    ]

    # Nothing to tokenize: return the same empty structure without asking the
    # tokenizer to pad an empty batch.
    if not texts:
        return {"input_ids": [], "attention_mask": []}

    # A single batched call produces one (num_texts, max_length) tensor.
    encoded = tokenizer(
        texts,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]

    # Same mask rule as before: a position is attended iff it is not padding.
    attention_mask = (input_ids != tokenizer.pad_token_id).int()

    # Split along the batch axis so each example is its own 1-D tensor.
    return {
        "input_ids": list(input_ids.unbind(0)),
        "attention_mask": list(attention_mask.unbind(0)),
    }


In [98]:
# Apply process_batch to the dataset 100 texts at a time; the raw "text"
# column is removed from the mapped output.
processed_dataset = dataset.map(
    function=process_batch,
    remove_columns=["text"],
    batched=True,
    batch_size=100,
)

In [100]:
def collate_fn(batch):
    """Stack per-example features into batched tensors.

    Args:
        batch: Sequence of examples, each a mapping with ``"input_ids"`` and
            ``"attention_mask"`` entries of equal shape across examples.
            Entries may be tensors or plain (nested) lists of ints.

    Returns:
        Dict with ``"input_ids"`` and ``"attention_mask"`` tensors, each with
        a new leading batch axis of size ``len(batch)``.
    """
    def _stack(key):
        # torch.as_tensor leaves existing tensors untouched (same dtype) and
        # converts list-valued features, which torch.stack would reject.
        return torch.stack([torch.as_tensor(item[key]) for item in batch])

    return {
        "input_ids": _stack("input_ids"),
        "attention_mask": _stack("attention_mask"),
    }
# Pipeline knobs, named so they can be tuned in one place.
SHUFFLE_SEED = 42
SHUFFLE_BUFFER_SIZE = 1000
BATCH_SIZE = 32
NUM_WORKERS = 4

# Seeded so the shuffle order can be reproduced on a re-run.
shuffled_dataset = processed_dataset.shuffle(
    seed=SHUFFLE_SEED,
    buffer_size=SHUFFLE_BUFFER_SIZE,
)

dataloader = DataLoader(
    shuffled_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=NUM_WORKERS,
    # Only request pinned host memory when a CUDA device is present.
    pin_memory=torch.cuda.is_available(),
)