In [5]:
# !pip install transformers, AutoTokenizer, torch

In [6]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, IterableDataset
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
stream_dataset = load_dataset(
    "roneneldan/TinyStories", 
    split="train", 
    streaming=True
)

print("Dataset loaded in streaming mode")
# Preview first example
for example in stream_dataset.take(1):
    print(f"Sample text: {example['text'][:200]}...")

Dataset loaded in streaming mode
Sample text: One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on...


In [8]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer initialized. Vocab size: {len(tokenizer)}")

Tokenizer initialized. Vocab size: 50257


In [9]:
def tokenize_function(examples):
    return tokenizer(examples["text"])

# Apply tokenization lazily to streaming dataset
tokenized_stream = stream_dataset.map(tokenize_function, batched=True)

print("Tokenization mapping applied")

Tokenization mapping applied


In [10]:
block_size = 128

def group_texts_streaming(dataset_iter, block_size):
    buffer = []
    for example in dataset_iter:
        buffer.extend(example["input_ids"])
        while len(buffer) >= block_size:
            chunk = buffer[:block_size]
            buffer = buffer[block_size:]
            yield {
                "input_ids": chunk,
                "attention_mask": [1] * block_size
            }

print(f"Block size set to: {block_size}")

Block size set to: 128


In [11]:
class StreamingLMIterableDataset(IterableDataset):
    def __init__(self, hf_iterable_dataset, block_size):
        self.dataset = hf_iterable_dataset
        self.block_size = block_size

    def __iter__(self):
        return group_texts_streaming(self.dataset, self.block_size)

grouped_iterable_dataset = StreamingLMIterableDataset(tokenized_stream, block_size)

print("Streaming dataset wrapper created")

Streaming dataset wrapper created


In [12]:
def collate_fn(batch):
    input_ids = torch.tensor([ex["input_ids"] for ex in batch], dtype=torch.long)
    attention_mask = torch.tensor([ex["attention_mask"] for ex in batch], dtype=torch.long)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": input_ids.clone()
    }

print("Collate function defined")

Collate function defined


In [13]:
train_loader = DataLoader(
    grouped_iterable_dataset, 
    batch_size=8, 
    collate_fn=collate_fn
)

print("DataLoader created with batch_size=8")

DataLoader created with batch_size=8


In [14]:
print("Sample streaming batches:")
for i, batch in enumerate(train_loader):
    print(f"Batch {i} -> input_ids shape: {batch['input_ids'].shape}, "
          f"labels shape: {batch['labels'].shape}")
    print(f"Sample tokens: {batch['input_ids'][0][:10]}")
    if i == 2:
        break

Sample streaming batches:
Batch 0 -> input_ids shape: torch.Size([8, 128]), labels shape: torch.Size([8, 128])
Sample tokens: tensor([ 3198,  1110,    11,   257,  1310,  2576,  3706, 20037,  1043,   257])
Batch 1 -> input_ids shape: torch.Size([8, 128]), labels shape: torch.Size([8, 128])
Sample tokens: tensor([  340,   257,  1263, 16225,    13,   383,  7586, 34681,   461,   373])
Batch 2 -> input_ids shape: torch.Size([8, 128]), labels shape: torch.Size([8, 128])
Sample tokens: tensor([   13,   679,  1234,   262, 21613,   287,   465,  6877,   290,  1718])
