# In [None]:
from datasets import load_from_disk
import random
from tqdm import tqdm
from transformers import T5Tokenizer

# Tokenizer for the 't5-small' checkpoint, requested with legacy=False
tokenizer = T5Tokenizer.from_pretrained('t5-small', legacy=False)

# Read the train / val / test splits that were saved to disk, in that order
tokenized_train_dataset, tokenized_val_dataset, tokenized_test_dataset = map(
    load_from_disk,
    (
        'tokenized_datasets/train',
        'tokenized_datasets/val',
        'tokenized_datasets/test',
    ),
)

# Draw a reproducible random subset of a dataset
def shuffle_and_select(dataset, num_samples: int, seed: int = 42):
    """Return a reproducible random subset of ``dataset``.

    The row indices ``0 .. len(dataset) - 1`` are shuffled with a generator
    seeded by ``seed``; the first ``num_samples`` shuffled indices are then
    passed to ``dataset.select``.

    Args:
        dataset: Object supporting ``len()`` and ``select(indices)``.
        num_samples: Maximum number of rows to keep. When it exceeds
            ``len(dataset)`` every row is selected (in shuffled order).
        seed: Seed for the shuffle. The same seed and dataset length always
            produce the same selection.

    Returns:
        Whatever ``dataset.select`` returns for the chosen indices.

    Raises:
        ValueError: If ``num_samples`` is negative (a negative slice bound
            would otherwise silently drop rows from the end instead).
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    indices = list(range(len(dataset)))
    # A private generator gives the same ordering as seeding the module-level
    # one, but leaves the global ``random`` state untouched for other code.
    random.Random(seed).shuffle(indices)
    return dataset.select(indices[:num_samples])

# Draw fixed-size random subsets of the train (5000) and val (1000) splits
print("Shuffling and selecting smaller train dataset...")
small_train_dataset = shuffle_and_select(
    dataset=tokenized_train_dataset,
    num_samples=5000,
)
print("Shuffling and selecting smaller val dataset...")
small_val_dataset = shuffle_and_select(
    dataset=tokenized_val_dataset,
    num_samples=1000,
)

# Re-pad a batch of tokenized examples to one fixed length
def adjust_padding(examples, max_length=512):
    """Pad the inputs and labels of a batch to ``max_length``.

    Uses the module-level ``tokenizer`` to pad ``input_ids`` together with
    ``attention_mask``, and separately ``labels``, using the ``"max_length"``
    padding strategy. The padded values are written back into ``examples``
    as plain nested lists.

    Args:
        examples: Batch mapping with ``"input_ids"``, ``"attention_mask"``
            and ``"labels"`` entries; it is modified in place.
        max_length: Length passed to ``tokenizer.pad`` as the padding target.

    Returns:
        The same ``examples`` object, with the three entries replaced.
    """
    # Both pad calls share these settings. Only padding is requested here;
    # no truncation is performed by this function.
    pad_options = dict(
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

    padded_inputs = tokenizer.pad(
        {key: examples[key] for key in ("input_ids", "attention_mask")},
        **pad_options,
    )

    # Labels are routed through the "input_ids" slot so that the tokenizer
    # pads them like any other id sequence.
    # NOTE(review): this fills labels with the tokenizer's regular pad id; if
    # the training loss is expected to skip padded positions (commonly done
    # with -100), confirm that this is handled downstream.
    padded_labels = tokenizer.pad(
        {"input_ids": examples["labels"]},
        **pad_options,
    )

    # Tensors were only an intermediate form; store ordinary Python lists.
    examples["input_ids"] = padded_inputs["input_ids"].tolist()
    examples["attention_mask"] = padded_inputs["attention_mask"].tolist()
    examples["labels"] = padded_labels["input_ids"].tolist()
    return examples

# Re-pad both subsets batch-wise to a fixed length of 512
print("Adjusting padding for smaller train dataset...")
small_train_dataset = small_train_dataset.map(
    adjust_padding,
    batched=True,
    fn_kwargs={"max_length": 512},
)
print("Adjusting padding for smaller val dataset...")
small_val_dataset = small_val_dataset.map(
    adjust_padding,
    batched=True,
    fn_kwargs={"max_length": 512},
)