# Example use

## Setup & Imports

In [None]:
# Enable python import reloading, so that edits to imported modules
# (e.g. the local alan_transformer package) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import math
import urllib
import urllib.request
from pathlib import Path
from typing import Tuple

import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import DataLoader
from torch.utils.data import dataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import WikiText2
from torchtext.vocab import build_vocab_from_iterator
from transformers import GPTNeoXTokenizerFast

from alan_transformer.transformer import Transformer
from alan_transformer.train import train_loop

## Task: Complete Works of Shakespeare

The steps are:

- Download text file from https://www.gutenberg.org/files/100/100-0.txt
- Tokenize it
- Break up into prompts, where each prompt is 1024 tokens long and its answer
  is the same size (but with a token offset of 1).
- One-hot-encode these
- Group prompts into batches
- Initialise model
- Run train loop - noting that we need to optimise based on all answers.

In [None]:
# Download text file (Complete Works of Shakespeare, Project Gutenberg)
data_dir = Path(".data")
data_dir.mkdir(parents=True, exist_ok=True)
data_path = data_dir / "shakespeare.txt"
data_url = "https://www.gutenberg.org/files/100/100-0.txt"

# Only fetch when no cached copy exists, so re-running the notebook does not
# hit the network every time. The download goes to a temporary ".part" file
# that is renamed on success, so an interrupted download never leaves a
# truncated file at `data_path` that later runs would mistake for a full copy.
if not data_path.exists():
    partial_path = data_path.with_name(data_path.name + ".part")
    urllib.request.urlretrieve(data_url, partial_path)
    partial_path.replace(data_path)

# Load as a dataset
dataset = load_dataset("text", data_files=str(data_path))

In [None]:
# Tokenize it
tokenizer = GPTNeoXTokenizerFast.from_pretrained(
    "EleutherAI/gpt-neox-20b",
    pad_token="<|endoftext|>",
)


def tokenize_batch(examples):
    """Tokenize the "text" field of a batch, padding/truncating to 1024 tokens."""
    # NOTE(review): every example is padded individually to the full length;
    # this is not the prompt/answer chunking described in the task outline
    # above - confirm which behaviour is intended.
    return tokenizer(
        examples["text"],
        max_length=1024,  # 1024 is the default max length for our transformer
        truncation=True,  # Truncate to the max length
        padding="max_length",  # Pad to the max length
    )


tokenized_dataset = dataset.map(tokenize_batch, batched=True)

In [None]:
# Group the tokenized training split into shuffled batches of 8 examples
train_split = tokenized_dataset["train"]
dataloader = DataLoader(train_split, shuffle=True, batch_size=8)
f"Number of batches: {len(dataloader)}"

In [None]:
# Initialise model, using the Transformer constructor's default arguments
model = Transformer()

In [None]:
# Display whether a CUDA device is available to PyTorch.
# NOTE(review): nothing in this notebook moves the model or batches to a
# device; presumably train_loop handles placement - confirm.
torch.cuda.is_available()

In [None]:
# Run the train loop on the model with the batched training data
train_loop(model, dataloader)
