# Goal
- Train a GPT-2 model on WikiText-2 (`wikitext-2-raw-v1`)

## Steps
- Download and load the data
- Train a byte-level BPE tokenizer
- Prepare data loaders of shifted input/target token pairs
- Build the GPT-2 model
- Run the train/validation loop
- Generate text

### Download and load data

In [None]:
from datasets import load_dataset

# Fetch the three standard WikiText-2 (raw) splits, in train/validation/test order.
train_dataset, val_dataset, test_dataset = (
    load_dataset("wikitext", "wikitext-2-raw-v1", split=split_name)
    for split_name in ("train", "validation", "test")
)

train_dataset

In [None]:
# Inspect the Python type of a single raw text row.
type(train_dataset[1]["text"])

### Train tokenizer

In [None]:
import tokenizers
import transformers

# Byte-level BPE tokenizer: the ByteLevel pre-tokenizer maps text to byte symbols
# before merges are learned.
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=False)

# Seed the vocabulary with all 256 byte symbols so no input byte is silently dropped
# (the BPE model has no unk token). Register "<pad>" here so it is part of the
# trained, saved vocabulary instead of being appended later by the wrapper.
trainer = tokenizers.trainers.BpeTrainer(
    vocab_size=25000,
    special_tokens=["<|endoftext|>", "<pad>"],
    initial_alphabet=tokenizers.pre_tokenizers.ByteLevel.alphabet(),
)
tokenizer.train_from_iterator(train_dataset["text"], trainer=trainer)
tokenizer.post_processor = tokenizers.processors.ByteLevel(trim_offsets=False)
# Attach the decoder BEFORE saving so the serialized tokenizer decodes byte symbols
# back to text when it is reloaded from disk.
tokenizer.decoder = tokenizers.decoders.ByteLevel()

tokenizer.save("../data/tokenizer.json")

# Expose the trained tokenizer through the transformers API (padding, truncation, batching).
wrapped_tokenizer = transformers.PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    bos_token="<|endoftext|>",
    eos_token="<|endoftext|>",
    padding_side="left",
    pad_token="<pad>",
)

tokenizer.encode("Hello my name is Ajay").tokens


In [None]:
# Same sentence through the transformers wrapper: yields token ids rather than token strings.
wrapped_tokenizer("Hello my name is Ajay").input_ids

### Prepare input/target data loaders

In [None]:
def tokenize(examples, tokenizer=None, max_length=20):
    """Build next-token-prediction pairs for a batch of raw text rows.

    Each text is tokenized ONCE to ``max_length + 1`` tokens (truncated and padded to
    that length). The input is every token but the last, and the target is every token
    but the first. Position ``i`` of ``output_ids`` is therefore the token that follows
    position ``i`` of ``input_ids``.

    Shifting must happen at the token level. Shifting characters and tokenizing the two
    strings separately gives different tokenizations that do not line up.

    Args:
        examples: Batched mapping with a ``'text'`` list (as passed by
            ``Dataset.map(..., batched=True)``).
        tokenizer: Callable tokenizer returning a mapping with ``'input_ids'``; defaults
            to the module-level ``wrapped_tokenizer``.
        max_length: Length of each returned input/target sequence.

    Returns:
        ``examples`` with ``'input_ids'`` and ``'output_ids'`` lists of length
        ``max_length`` added.
    """
    if tokenizer is None:
        tokenizer = wrapped_tokenizer
    encoded = tokenizer(
        examples['text'],
        truncation=True,
        max_length=max_length + 1,
        padding="max_length",
    )['input_ids']
    examples['input_ids'] = [ids[:-1] for ids in encoded]
    # NOTE(review): padded positions of the target carry the pad id; confirm the
    # training loss ignores them.
    examples['output_ids'] = [ids[1:] for ids in encoded]
    return examples

# Tokenize every split in batches; `tokenize` adds 'input_ids' and 'output_ids' columns.
tokenized_train_dataset, tokenized_val_dataset, tokenized_test_dataset = (
    split.map(tokenize, batched=True)
    for split in (train_dataset, val_dataset, test_dataset)
)

tokenized_train_dataset

In [None]:
import torch 
from datasets import Dataset as HFDataset
from torch.utils.data import Dataset

class HuggingFaceDataset(Dataset):
    """
    Wrap a Hugging Face ``Dataset`` so it can be consumed by a PyTorch ``DataLoader``.

    Each item is returned as an ``(input, target)`` tuple read from two columns of the
    wrapped dataset. By default these are ``'input_ids'`` and ``'output_ids'``, the
    columns produced by ``tokenize``.
    """

    def __init__(
        self,
        hf_dataset: "HFDataset",
        input_column: str = "input_ids",
        target_column: str = "output_ids",
    ):
        # Any object supporting len() and integer indexing that yields mappings works.
        self.hf_dataset = hf_dataset
        self.input_column = input_column
        self.target_column = target_column

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]
        return item[self.input_column], item[self.target_column]

def collate_fn(batch):
    """Stack a list of ``(input_ids, output_ids)`` pairs into two tensors.

    Returns a tuple ``(inputs, targets)``. Each element is built with ``torch.tensor``
    from the corresponding field of every sample.
    """
    return tuple(
        torch.tensor([sample[field] for sample in batch])
        for field in (0, 1)
    )

batch_size = 128

# Wrap each tokenized split so torch's DataLoader can index it.
train_torch_dataset, val_torch_dataset, test_torch_dataset = (
    HuggingFaceDataset(split)
    for split in (tokenized_train_dataset, tokenized_val_dataset, tokenized_test_dataset)
)

# Only the training split is shuffled; validation and test keep a fixed order.
train_torch_dataloader, val_torch_dataloader, test_torch_dataloader = (
    torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )
    for ds, shuffle in (
        (train_torch_dataset, True),
        (val_torch_dataset, False),
        (test_torch_dataset, False),
    )
)

train_torch_dataloader

In [None]:
# Pull a single training batch and show the shapes of its two tensors.
batch = next(iter(train_torch_dataloader))  # (input_ids, output_ids)
input_ids, output_ids = batch[0], batch[1]
tuple(t.shape for t in (input_ids, output_ids))

### Build the GPT-2 model

In [None]:
from models import GPT2

config = {
    "emb_dim": 768,
    "heads": 12,
    "layers": 12,
    # Size the embedding/output layers to the tokenizer actually trained in this
    # notebook (not GPT-2's original 50257-token vocabulary). len() counts the full
    # vocabulary, including special tokens such as "<pad>".
    "vocab_size": len(wrapped_tokenizer),
    "context_length": 128,
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "drop_out": 0.1,
    "train_test_split": 0.8,
    "num_epochs": 5,
    "model_path": "./models/gpt2.pth",
}

gpt2 = GPT2(config)
gpt2.to(config['device'])
gpt2

### Train with the train/validation loop

In [None]:
from utils import train

# Fit the model on the training loader (validating on the val loader), then persist
# only its parameters.
train(gpt2, train_torch_dataloader, val_torch_dataloader, config)
weights = gpt2.state_dict()
torch.save(weights, config["model_path"])

KeyboardInterrupt: 

### Generate text

In [None]:

def generate_text_greedy(model, starting_context, max_new_tokens, context_size, config, tokenizer=None):
    """Greedily extend ``starting_context`` by ``max_new_tokens`` tokens.

    Args:
        model: Callable returning logits indexed as ``(batch, seq, vocab)``.
            The rank is implied by ``logits[:, -1, :]``.
        starting_context: Prompt text to continue.
        max_new_tokens: Number of tokens to append; 0 returns the decoded prompt.
        context_size: Only the last ``context_size`` tokens are fed to the model at
            each step.
        config: Mapping providing ``"device"``.
        tokenizer: Object with ``encode(text, return_tensors="pt")`` and
            ``decode(ids)``; defaults to the module-level ``wrapped_tokenizer``.

    Returns:
        The prompt plus the generated continuation, decoded to a string.

    The model is put in eval mode (dropout off) and run without autograd; its
    previous train/eval mode is restored afterwards.
    """
    if tokenizer is None:
        tokenizer = wrapped_tokenizer
    device = config["device"]
    model.to(device)

    # Encode the prompt ONCE. Re-encoding inside the loop would throw away every
    # token generated on earlier iterations.
    idx = tokenizer.encode(starting_context, return_tensors="pt").to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Consider only the last `context_size` tokens.
                logits = model(idx[:, -context_size:])
                next_token_logits = logits[:, -1, :]
                # argmax over logits picks the same token as argmax over softmax(logits).
                idx_next = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                idx = torch.cat((idx, idx_next), dim=1)
    finally:
        model.train(was_training)

    return tokenizer.decode(idx[0].tolist())

starting_context = "The cat"

# Greedy continuation of the prompt; context_size=100 stays within config's 128-token context_length.
generate_text_greedy(gpt2, starting_context, max_new_tokens=100, context_size=100, config=config)

