# Goal
- Train a GPT-2 model on the WikiText-2 dataset (`wikitext-2-raw-v1`)

## Steps
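- Download and read the data
- Load the pretrained GPT-2 tokenizer (no tokenizer is trained in this notebook)
- Prepare the sliding window data loader (inputs and targets shifted by one token)
- Build the GPT2 model
- Run the train/validation loop
- Generate text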

### Read, download data

In [1]:
from datasets import load_dataset

# Fetch the three WikiText-2 (raw) splits with one expression; the generator
# variable does not leak into the notebook namespace.
train_dataset, val_dataset, test_dataset = (
    load_dataset("wikitext", "wikitext-2-raw-v1", split=split_name)
    for split_name in ("train", "validation", "test")
)

# Display the train split's summary as the cell output.
train_dataset

KeyboardInterrupt: 

In [None]:
# Check the Python type of one example's text. Indexing the row first fetches a
# single record instead of pulling the whole 'text' column to read one entry.
type(train_dataset[1]['text'])

### Load tokenizer

In [None]:

# Load the pretrained GPT-2 tokenizer. The end-of-text token is reused as the
# pad token and padding is applied on the left.
from transformers import GPT2Tokenizer
wrapped_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
wrapped_tokenizer.pad_token = wrapped_tokenizer.eos_token
wrapped_tokenizer.padding_side = "left"
# `model_max_length` is the attribute the tokenizer consults for its length
# limit. Plain `truncation` / `max_length` attributes are never read, so
# truncation has to be requested per call (truncation=True, max_length=...).
wrapped_tokenizer.model_max_length = 1024

# Quick sanity check: show the ids produced for a short sentence.
wrapped_tokenizer.encode("Hello, my dog is cute")



In [None]:
# Encode the literal end-of-text marker and show the resulting input ids.
wrapped_tokenizer("<|endoftext|>")['input_ids']

In [None]:
# Encode a one-element list of strings and show the resulting input ids.
wrapped_tokenizer(["Hello my name is Ajay"])['input_ids']

### Prepare sliding window data loader

In [None]:
import torch 

def slide_window(text_batch, max_length=1024, tokenizer=None):
    """
    Build next-token-prediction pairs for a batch of texts.

    Each text is tokenized and an EOS token is appended. The input sequence is
    every token but the last; the target sequence is every token but the first
    (the input shifted by one position). Sequences are cut so that inputs and
    targets hold at most ``max_length`` tokens, then padded to exactly
    ``max_length`` by the tokenizer (padding side and pad token follow the
    tokenizer's own configuration).

    Args:
        text_batch (dict): Batch with a 'text' key holding a list of strings.
        max_length (int): Length of every padded input/target sequence.
        tokenizer: Tokenizer to use. When omitted, the notebook-level
            ``wrapped_tokenizer`` is used.

    Returns:
        dict: 'input_ids', 'attention_mask' and 'output_ids' tensors of width
        ``max_length``, plus the unpadded debugging columns 'input_words',
        'output_words', 'input_ids_raw' and 'output_ids_raw'.
    """
    tok = wrapped_tokenizer if tokenizer is None else tokenizer

    eos_token = tok.eos_token
    eos_token_id = tok.convert_tokens_to_ids(eos_token)

    input_words = []
    output_words = []
    input_ids_raw = []
    output_ids_raw = []

    # The EOS-append + shift logic is done per text; the padding is batched below.
    for text in text_batch['text']:
        tokens = tok.tokenize(text)
        token_ids = tok.convert_tokens_to_ids(tokens)

        # Append EOS to both views so words and ids stay aligned one-to-one.
        tokens = list(tokens) + [eos_token]
        token_ids = list(token_ids) + [eos_token_id]

        # `tokenizer.pad` only pads, it never truncates. Keep max_length + 1
        # tokens so that both shifted views end up with at most max_length.
        tokens = tokens[:max_length + 1]
        token_ids = token_ids[:max_length + 1]

        # Shift by one: the model sees token t and is asked for token t + 1.
        input_ids_raw.append(token_ids[:-1])
        output_ids_raw.append(token_ids[1:])
        input_words.append(tokens[:-1])
        output_words.append(tokens[1:])

    # Pad the inputs and let the tokenizer build the matching attention mask.
    padded_inputs = tok.pad(
        {"input_ids": input_ids_raw},
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
        return_attention_mask=True,
    )

    # Pad the targets the same way so they stay position-aligned with the inputs.
    # NOTE(review): padded target positions hold the pad token id; confirm the
    # training loop excludes them from the loss (e.g. via attention_mask).
    padded_outputs = tok.pad(
        {"input_ids": output_ids_raw},
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
    )

    return {
        "input_ids": padded_inputs["input_ids"],
        "attention_mask": padded_inputs["attention_mask"],
        "output_ids": padded_outputs["input_ids"],
        # Unpadded views kept for debugging; an empty text yields empty lists here.
        'input_words': input_words,
        'output_words': output_words,
        'input_ids_raw': input_ids_raw,
        'output_ids_raw': output_ids_raw,
    }

def _tokenize_and_drop_empty(split):
    """Apply slide_window to a split, then drop rows whose text produced no input tokens."""
    tokenized = split.map(slide_window, batched=True)
    # 'input_ids' is always padded to a fixed length, so its length can never be 0;
    # the unpadded 'input_ids_raw' column is what reveals an empty row.
    return tokenized.filter(lambda row: len(row['input_ids_raw']) > 0)


tokenized_train_dataset = _tokenize_and_drop_empty(train_dataset)
tokenized_val_dataset = _tokenize_and_drop_empty(val_dataset)
tokenized_test_dataset = _tokenize_and_drop_empty(test_dataset)

tokenized_train_dataset


In [None]:
# Show one processed row (all columns returned by slide_window plus the original text).
tokenized_train_dataset[1]

In [None]:
import torch 
from datasets import Dataset as HFDataset
from torch.utils.data import Dataset

class HuggingFaceDataset(Dataset):
    """
    Wraps a Hugging Face Dataset so it can be used with a PyTorch DataLoader.

    Every row of the wrapped dataset must provide the 'input_ids',
    'output_ids' and 'attention_mask' columns.
    """

    def __init__(self, hf_dataset: "HFDataset"):
        # Only a reference is stored; rows are looked up on demand in __getitem__.
        self.hf_dataset = hf_dataset

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Return the (input_ids, output_ids, attention_mask) triple for row ``idx``."""
        item = self.hf_dataset[idx]
        return item['input_ids'], item['output_ids'], item['attention_mask']

def collate_fn(batch):
    """
    Combine (input_ids, output_ids, attention_mask) samples into three
    long tensors, returned in that same order.
    """
    def _stack(field_index):
        # Gather one field across all samples and convert it to an int64 tensor.
        return torch.tensor([sample[field_index] for sample in batch], dtype=torch.long)

    return _stack(0), _stack(1), _stack(2)

batch_size = 20

# Wrap each tokenized split so a PyTorch DataLoader can index it.
train_torch_dataset = HuggingFaceDataset(tokenized_train_dataset)
val_torch_dataset = HuggingFaceDataset(tokenized_val_dataset)
test_torch_dataset = HuggingFaceDataset(tokenized_test_dataset)


def _make_loader(torch_dataset, shuffle):
    """Build a DataLoader using the shared batch size and collate function."""
    return torch.utils.data.DataLoader(
        torch_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )


# Only the training split is shuffled; validation and test keep their order.
train_torch_dataloader = _make_loader(train_torch_dataset, shuffle=True)
val_torch_dataloader = _make_loader(val_torch_dataset, shuffle=False)
test_torch_dataloader = _make_loader(test_torch_dataset, shuffle=False)

train_torch_dataloader

In [None]:
# Pull a single batch: (input_ids, output_ids, attention_mask).
batch = next(iter(train_torch_dataloader))
input_ids, output_ids, attention_mask = batch

# Show the shape of each of the three tensors.
tuple(tensor.shape for tensor in batch)

### Use GPT2 model

In [None]:
from models import GPT2

# Take the counts from the loaders themselves: they keep the final, smaller
# batch, which `num_rows // batch_size` would leave out.
num_train_batches = len(train_torch_dataloader)
num_val_batches = len(val_torch_dataloader)

config = {
    "emb_dim": 768,
    "heads": 12,
    "layers": 12,
    "vocab_size": 50257,
    "context_length": 1024,
    # Fall back to CPU so the notebook still builds the model without a GPU.
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "drop_out": 0.1,
    "train_test_split": 0.8,
    "num_epochs": 25,
    "model_path": "../model_files/gpt2.pth",
    "num_train_batches": num_train_batches,
    "learning_rate": 1e-4,
    # This key receives the validation batch count, since the validation
    # loader is the one handed to train().
    "num_test_batches": num_val_batches,
}

gpt2 = GPT2(config)
gpt2.to(config['device'])
gpt2

### Use train/test loop

In [None]:
from utils import train

# Hand the model, the train and validation loaders, and the config to utils.train.
# NOTE(review): use_fp_16=True presumably enables half-precision training — confirm in utils.train.
train(gpt2, train_torch_dataloader, val_torch_dataloader, config, use_fp_16=True)

### Generate text

In [None]:
# Exploratory: list every attribute and method name available on the tokenizer.
# NOTE(review): this produces a long output; consider removing it from the final notebook.
dir(wrapped_tokenizer)

In [None]:
# Tokenize a short prompt, padded/truncated to 100 tokens, for a quick forward-pass check.
tokenized = wrapped_tokenizer("Hello my name is", truncation=True, max_length=100, padding="max_length", return_tensors="pt")

# NOTE(review): the attention mask is only printed here; it is not passed to the model.
attention_mask = tokenized['attention_mask'].to(config["device"])
input_ids = tokenized['input_ids'].to(config["device"])

print(attention_mask)
print(input_ids)

# Inference only: switch dropout off and skip autograd bookkeeping so the
# check is deterministic and does not hold on to a gradient graph.
gpt2.eval()
with torch.no_grad():
    prediction = gpt2(input_ids)
next_token = prediction.argmax(dim=-1)

print(prediction.shape)

In [None]:

def generate_text(starting_text, model, tokenizer, config, num_output_tokens=100):
    """
    Greedily extend ``starting_text`` and print the prompt plus the generated text.

    Args:
        starting_text (str): Prompt to continue.
        model: Callable that, given a tensor of token ids, returns logits
            indexable as [batch, position, vocab]. A single-sequence batch is
            expected because each step reads the chosen id with ``.item()``.
        tokenizer: Tokenizer used to encode the prompt and decode each new id;
            generation stops early once its ``eos_token_id`` is produced.
        config (dict): Must contain 'device'. 'context_length' (default 1024)
            sets the padded prompt length and the maximum number of ids fed to
            the model at each step.
        num_output_tokens (int): Maximum number of tokens to generate.

    Returns:
        None. The result is printed as "<starting_text> -> <generated text>".
    """
    device = config["device"]
    context_length = config.get("context_length", 1024)

    input_encoding = tokenizer(starting_text, truncation=True, max_length=context_length, padding="max_length", return_tensors="pt")
    input_ids = input_encoding['input_ids'].to(device)

    output_text = f"{starting_text} -> "

    # Generation is inference: disable dropout and gradient tracking, and put
    # the model back into its previous mode afterwards.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(num_output_tokens):
                # Greedy decoding: take the most likely id at the last position.
                next_token_logits = model(input_ids)[:, -1, :]
                next_token = next_token_logits.argmax(dim=-1)
                next_token_id = next_token.item()

                output_text += tokenizer.decode(next_token_id)

                # Append the predicted token and keep only the most recent
                # context_length ids for the next step.
                input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
                input_ids = input_ids[:, -context_length:]

                # The EOS token is included in the output before stopping.
                if next_token_id == tokenizer.eos_token_id:
                    break
    finally:
        if was_training:
            model.train()

    print(output_text)

# Generate a continuation for a sample prompt with the trained model.
generate_text(
    starting_text="The capital is",
    model=gpt2,
    tokenizer=wrapped_tokenizer,
    config=config,
)


In [None]:
# Show the token ids the tokenizer produces for a sample question.
wrapped_tokenizer.encode("Who is the president of the United States?")