In [20]:
import sys
import os

# Make the project's `src/` directory importable so `text_prediction` resolves.
# Set WIKI_LLM_SRC to override; by default assume this notebook sits one level
# below the repo root (the same assumption as the `parent_path=".."` used here).
module_path = os.environ.get("WIKI_LLM_SRC", os.path.abspath(os.path.join("..", "src")))
if module_path not in sys.path:
    sys.path.append(module_path)

In [96]:
import functools
import logging
import os
import random

import pyarrow as pa
import torch
from datasets import load_dataset, load_from_disk, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler, Dataset as TDataset
from torch.utils.data import Dataset as TorchDataset

from text_prediction.utils import RankFilter

class TokenizedDataset(TorchDataset):
    """Map-style dataset yielding one random (input, next-token label) window per article.

    Item ``idx`` draws a fresh random window of ``block_size`` tokens from
    article ``idx``; the label window is the same span shifted by one token.
    Articles with ``block_size`` tokens or fewer are dropped because they
    cannot supply a shifted label window.
    """

    def __init__(self, tokenized_dataset, block_size, device="cpu"):
        """
        tokenized_dataset: Iterable of examples, each a mapping with an
            ``'input_ids'`` list of token IDs (e.g. a tokenized ``datasets.Dataset``).
        block_size: Length of each input/label sequence.
        device: Device every returned tensor is moved to.
        """
        self.device = device
        self.block_size = block_size

        # Keep only articles with at least block_size + 1 tokens so the
        # shifted label window always fits.
        self._data = [
            torch.tensor(example["input_ids"], dtype=torch.long)
            for example in tokenized_dataset
            if len(example["input_ids"]) > block_size
        ]

    def __len__(self):
        # One item per retained article. _data is indexed by article, so
        # subtracting block_size here would hide the last articles and go
        # negative for small corpora.
        return len(self._data)

    def __getitem__(self, idx):
        """Return ``(input_seq, label_seq)``, two 1-D LongTensors of length ``block_size``."""
        example = self._data[idx]
        # randint's upper bound is exclusive, so the largest start is
        # len - block_size - 1 and the label window ends exactly in bounds.
        start = torch.randint(0, len(example) - self.block_size, (1,)).item()
        input_seq = example[start:start + self.block_size]
        label_seq = example[start + 1:start + self.block_size + 1]
        return input_seq.to(self.device), label_seq.to(self.device)

class DataPipeline:
    """Build next-token-prediction DataLoaders from a cached, tokenized Wikipedia subset.

    The first ``num_samples`` articles of the ``wikimedia/wikipedia`` (20231101.en)
    split are wrapped in ``<ARTICLE_START>``/``<ARTICLE_END>`` markers, tokenized,
    and saved under ``{parent_path}/data/{tokenizer.name_or_path}/wiki/{split}/``.
    Later runs reuse that cache unless ``regenerate`` is set or it cannot be read.
    """

    def __init__(self, tokenizer, max_len, block_size, regenerate=False, num_samples=10000, verbose=False, augment_data=False, parent_path="."):
        """
        tokenizer: Hugging Face-style tokenizer. It is mutated: the two article
            marker tokens are registered as additional special tokens.
        max_len: Maximum tokens kept per article (longer articles are truncated).
        block_size: Window length of each training example.
        regenerate: Rebuild the tokenized cache even if one exists.
        num_samples: Number of articles taken from the start of the split.
        verbose: Emit progress messages via ``log``.
        augment_data: Apply random token masking while tokenizing.
        parent_path: Directory under which the ``data/`` cache tree lives.
        """
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.block_size = block_size
        self.regenerate = regenerate
        self.num_samples = num_samples
        self.train_dataloader = None
        self.val_dataloader = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.verbose = verbose
        self.augment_data = augment_data
        self.parent_path = parent_path

        # basicConfig is a no-op if the root logger already has handlers.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
        self.logger = logging.getLogger(__name__)
        # RankFilter comes from text_prediction.utils; presumably it limits
        # output to rank 0 in distributed runs. Confirm against its definition.
        self.logger.addFilter(RankFilter(0))
        # NOTE: this grows the tokenizer's vocabulary, so a model consuming these
        # IDs presumably needs its embeddings resized to len(tokenizer).
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["<ARTICLE_START>", "<ARTICLE_END>"]})

    def log(self, message, level=logging.INFO):
        """Log ``message`` at ``level``, but only when the pipeline was built with verbose=True."""
        if self.verbose:
            self.logger.log(level, message)

    def _tokenize_function(self, examples):
        """Tokenize a batch of ``examples["text"]`` wrapped in article markers.

        Used with ``Dataset.map(batched=True)``. Truncates to ``max_len`` and,
        when ``augment_data`` is set, applies ``_augment`` to the token IDs.
        """
        # Wrap every article in explicit boundary markers.
        texts_with_special_tokens = [
            "<ARTICLE_START>" + text + "<ARTICLE_END>" for text in examples["text"]
        ]

        tokenized = self.tokenizer(
            texts_with_special_tokens,
            truncation=True,
            max_length=self.max_len,
            # The markers were added manually above, so skip the tokenizer's own.
            add_special_tokens=False,
        )

        if self.augment_data:
            tokenized["input_ids"] = self._augment(tokenized["input_ids"])

        return tokenized

    def _augment(self, input_ids, mask_prob=0.1):
        """Randomly replace one token in some sequences with the mask token.

        Each sequence independently has a ``mask_prob`` chance of having exactly
        one randomly chosen position replaced by ``tokenizer.mask_token_id``.
        If the tokenizer defines no mask token, sequences are returned unchanged
        rather than being corrupted with ``None``. Input lists are copied, not
        modified in place, and empty sequences are passed through.

        input_ids: List of token-ID lists.
        mask_prob: Per-sequence probability of masking one token.
        Returns: A new list of token-ID lists.
        """
        mask_id = getattr(self.tokenizer, "mask_token_id", None)
        if mask_id is None:
            self.log("augment_data is enabled but the tokenizer has no mask token; skipping masking", logging.WARNING)
            return [list(ids) for ids in input_ids]

        augmented = []
        for ids in input_ids:
            ids = list(ids)
            # An empty sequence has no position to mask (randint(0, -1) raises).
            if ids and random.random() < mask_prob:
                ids[random.randint(0, len(ids) - 1)] = mask_id
            augmented.append(ids)
        return augmented

    def _get_tokenized_dataset(self, split="train"):
        """Return the tokenized ``split``, loading it from the on-disk cache when possible.

        The cache is (re)built when ``regenerate`` is set or when the cached
        copy is missing or unreadable according to ``_is_dataset_valid``.
        """
        suffix = "tokenized_augmented" if self.augment_data else "tokenized"
        tokenized_dataset_path = f"{self.parent_path}/data/{self.tokenizer.name_or_path}/wiki/{split}/{suffix}"

        if self.regenerate or not self._is_dataset_valid(tokenized_dataset_path):
            self.log("Local cache of dataset not found, downloading and tokenizing dataset...")
            ds = load_dataset("wikimedia/wikipedia", "20231101.en", split=split)
            # Clamp so select() does not raise when the split is smaller than num_samples.
            ds = ds.select(range(min(self.num_samples, len(ds))))
            # Keep only the 'text' column.
            ds = ds.remove_columns([col for col in ds.column_names if col != "text"])
            ds = ds.map(self._tokenize_function, batched=True)
            ds.save_to_disk(tokenized_dataset_path)
        else:
            self.log("Local cache of dataset found, loading tokenized dataset...")
            ds = load_from_disk(tokenized_dataset_path)
        return ds

    def _is_dataset_valid(self, dataset_path):
        """Return True if a dataset saved at ``dataset_path`` loads and yields a first row."""
        try:
            ds = load_from_disk(dataset_path)
            # Reading one row surfaces corrupted Arrow files that load lazily.
            _ = ds[:1]
            return True
        except (pa.lib.ArrowInvalid, FileNotFoundError):
            return False

    @staticmethod
    def custom_collate(batch, tokenizer):
        """Stack ``(input, label)`` 1-D tensors into padded ``(batch, seq)`` tensors.

        Padding uses ``tokenizer.pad_token_id``, falling back to
        ``tokenizer.eos_token_id`` when the tokenizer has no pad token.
        NOTE(review): padded label positions use that same ID rather than an
        ignore index such as -100, so they would count toward a loss if
        unequal-length sequences are ever batched.
        """
        inputs, labels = zip(*batch)

        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        inputs = pad_sequence(inputs, batch_first=True, padding_value=pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=pad_token_id)
        return inputs, labels

    def get_dataloader(self, batch_size, shuffle=True):
        """Return a DataLoader over random ``block_size`` windows of the training split.

        batch_size: Number of windows per batch.
        shuffle: Reshuffle article order every epoch.
        """
        tds = self._get_tokenized_dataset()
        dataset = TokenizedDataset(tds, self.block_size, self.device)
        # functools.partial, unlike a lambda, is picklable, so the loader keeps
        # working if num_workers > 0 is used with the spawn start method.
        collate = functools.partial(DataPipeline.custom_collate, tokenizer=self.tokenizer)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate,
        )

In [99]:
from transformers import AutoTokenizer
# GPT-2 tokenizer plus a pipeline over a 10k-article Wikipedia subset,
# with the data cache rooted one directory up.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dp = DataPipeline(
    tokenizer,
    max_len=512,
    block_size=128,
    regenerate=False,
    num_samples=10000,
    verbose=True,
    augment_data=False,
    parent_path="..",
)

In [None]:
train_loader = dp.get_dataloader(32)

# Decode the first input/label pair from each of the first six batches.
for i, (x, y) in enumerate(train_loader):
    print(f"example {i} inputs")
    print(tokenizer.decode(x[0]))
    print(f"example {i} labels")
    print(tokenizer.decode(y[0]))
    if i == 5:
        break

2025-03-19 12:16:07,997 - Local cache of dataset found, loading tokenized dataset...


example 0 inputs
The common law offence of affray was abolished for England and Wales on 1 April 1987. Affray is now a statutory offence that is triable either way. It is created by section 3 of the Public Order Act 1986 which provides:

The term "violence" is defined by section 8.

Section 3(6) once provided that a constable could arrest without warrant anyone he reasonably suspected to be committing affray, but that subsection was repealed by paragraph 26(2) of Schedule 7 to, and Schedule 17 to, the Serious Organised Crime and Police Act 2005, which includes more general provisions for police to make arrests
example 0 labels
 common law offence of affray was abolished for England and Wales on 1 April 1987. Affray is now a statutory offence that is triable either way. It is created by section 3 of the Public Order Act 1986 which provides:

The term "violence" is defined by section 8.

Section 3(6) once provided that a constable could arrest without warrant anyone he reasonably suspect

In [7]:
import text_prediction

ModuleNotFoundError: No module named 'text_prediction'

In [8]:
# `text_prediction` is this project's own package under `src/`. Running
# `pip install text_prediction` would resolve that name on PyPI instead of
# using the local source, so put `src/` on sys.path (override with WIKI_LLM_SRC).
import os
import sys

module_path = os.environ.get("WIKI_LLM_SRC", os.path.abspath(os.path.join("..", "src")))
if module_path not in sys.path:
    sys.path.append(module_path)



In [1]:
import sys
import os

# Make the project's `src/` directory importable so `text_prediction` resolves.
# Set WIKI_LLM_SRC to override; by default assume this notebook sits one level
# below the repo root (the same assumption as the `parent_path=".."` used here).
module_path = os.environ.get("WIKI_LLM_SRC", os.path.abspath(os.path.join("..", "src")))
if module_path not in sys.path:
    sys.path.append(module_path)


In [2]:
from text_prediction.data_pipeline import DataPipeline

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Pipeline settings gathered in one place so they are easy to tweak.
pipeline_config = dict(
    max_len=512,
    block_size=128,
    num_samples=10000,
    verbose=True,
    augment_data=False,
    parent_path="..",
)
dp = DataPipeline(tokenizer=tokenizer, **pipeline_config)

In [4]:
dl = dp.get_dataloader(shuffle=True, batch_size=32)  # shuffled batches of 32 windows

2025-03-18 23:14:19,030 - Local cache of dataset found, loading tokenized dataset...


getting tokenized dataset


In [9]:
# Inspect the first five batches from a single pass over the loader.
# Reuse one iterator: calling `next(iter(dl))` each time restarts the loader
# (reshuffling, and re-spawning workers if any) just to take its first batch.
for i, (x, y) in zip(range(5), dl):
    print(f"Batch {i + 1}")
    print("Input shape:", x.shape)
    print("Label shape:", y.shape)
    print(f"Input sequence: {tokenizer.decode(x[0].tolist())}")
    print(f"Output sequence: {tokenizer.decode(y[0].tolist())}")

Batch 1
Input shape: torch.Size([32, 128])
Label shape: torch.Size([32, 128])
Input sequence: : Army, Navy and Aerospace Force (Title VII, chapter VII, Art. 217)

This is a subtle yet important distinction, both in terms of emphasizing the civil nature of the National Police, but also adapting the national police to function as a paramilitary force which can perform military duties as a result of the Colombian Conflict. This has led to some of the most important police units adopting military training and conducting special operations alongside the Colombian Army, Aerospace Force, and Navy. ThereforeThe history of Colombia includes the settlements and society by indigenous peoples, most notably, the Muisca Confederation, Quimbaya Civilization, and Tairona Chief
Output sequence:  Army, Navy and Aerospace Force (Title VII, chapter VII, Art. 217)

This is a subtle yet important distinction, both in terms of emphasizing the civil nature of the National Police, but also adapting the nationa