# Demonstrate PyTorch Lightning on Jupyter

**IMPORTANT**: on Jupyter, the trainer must be created with `strategy="dp"`.
- Lightning's error message suggests `strategy=None|dp`, but contrary to that message, we found that `strategy=None` does not work.
- Related issue: https://github.com/Lightning-AI/lightning/issues/7550

This means that on Jupyter, multi-GPU training works but multi-Gaudi training does not, because
Gaudi's PyTorch does not support DP (data parallel).

In [None]:
%load_ext autoreload
%autoreload 2

import argparse
import logging
import math
import time

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from bench import Text8Dataset

# NOTE(review): later cells also use `GPT`, `WarmupCosineLearningRateDecay` and
# `sample`, none of which are imported anywhere in this notebook. Their home
# modules are not visible from here -- add the matching project imports.

## Training configuration

In [None]:
# Training settings, collected in a Namespace so that later cells can read them
# as plain attributes (args.batch_size, args.block_size, ...).
args = argparse.Namespace(
    num_epochs=5,
    batch_size=64,
    block_size=128,
    num_workers=0,
    pin_memory=0,  # stored as an int; converted with bool() when the loader kwargs are built
    precision=16,
    default_root_dir=".",
)
print(vars(args))

# Module-level logger plus root-logger configuration: timestamped records at
# INFO level and above.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)

## Prepare data loaders

In [None]:
logging.info("preparing the data loaders")
# NOTE: REDUCED DATA SIZE FOR DEBUGGING, TODO CLEAN BEFORE MERGE IF EVER
# All three splits read the same file and differ only in their `crop` tuple
# (presumably (start offset, length) -- confirm against Text8Dataset).
train_dataset = Text8Dataset(
    "../data/text8",
    args.block_size,
    crop=(0, int(90e4)),
)

# The validation and test splits reuse the training vocabulary; `unknown_ch2i=0`
# is forwarded unchanged to Text8Dataset.
heldout_kwargs = dict(override_vocab=train_dataset.vocab, unknown_ch2i=0)
val_dataset = Text8Dataset(
    "../data/text8", args.block_size, crop=(int(90e4), int(5e4)), **heldout_kwargs
)
test_dataset = Text8Dataset(
    "../data/text8", args.block_size, crop=(int(95e4), int(5e4)), **heldout_kwargs
)

# DataLoader options shared by every split (reused later for the test loader).
common = dict(
    batch_size=args.batch_size,
    pin_memory=bool(args.pin_memory),
    num_workers=args.num_workers,
)
# Only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, shuffle=True, **common)
val_dataloader = DataLoader(val_dataset, shuffle=False, **common)

logging.info("creating the model")
model = GPT(
    train_dataset.vocab_size,
    args.block_size,
    n_layer=8,
    n_head=8,
    n_embd=256,
)

logging.info("preparing the learning rate schedule")
# Tokens backpropped in one iteration: one full batch of block_size-long sequences.
iter_tokens = args.batch_size * args.block_size
# Tokens per epoch: iterations per epoch (last partial batch rounded up) times
# the tokens per iteration.
epoch_tokens = iter_tokens * math.ceil(len(train_dataset) / args.batch_size)
# Warm up over half an epoch's worth of tokens; the schedule ends after
# num_epochs epochs' worth.
lr_decay = WarmupCosineLearningRateDecay(
    learning_rate=6e-4,
    warmup_tokens=epoch_tokens // 2,
    final_tokens=args.num_epochs * epoch_tokens,
)

## Start training

In [None]:
# Time the whole training run with a monotonic clock: unlike time.time(),
# perf_counter() cannot jump when the system clock is adjusted mid-run.
t0 = time.perf_counter()
logging.info("training...")
trainer = pl.Trainer(
    accelerator="auto",
    benchmark=True,
    max_epochs=args.num_epochs,
    gradient_clip_val=1.0,
    # lr_decay drives the learning-rate schedule; ModelSummary prints the
    # module tree down to two levels.
    callbacks=[lr_decay, pl.callbacks.ModelSummary(max_depth=2)],
    precision=args.precision,
    default_root_dir=args.default_root_dir,
    # "dp" is the strategy this notebook found to work on Jupyter (see the
    # notebook introduction).
    strategy="dp",
)
trainer.fit(model, train_dataloader, val_dataloader)
t1 = time.perf_counter()

# Report the total wall time and the average time per configured epoch.
elapsed = t1 - t0
logging.info(
    "%d epochs took %fs, or %fs/epoch", args.num_epochs, elapsed, elapsed / args.num_epochs
)

## Test split

In [None]:
logging.info("testing...")
# Evaluate on the held-out test split, reusing the loader options from training.
test_dataloader = DataLoader(test_dataset, shuffle=False, **common)
trainer.test(dataloaders=test_dataloader)

logging.info("sampling:")
context = "anarchism originated as a term of"
# Encode the prompt one character at a time with the training vocabulary;
# [None, ...] adds a leading batch dimension of size 1.
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...]
# Move the prompt to whichever device holds the model's parameters. Unlike a
# bare `.cuda()`, this targets the model's actual device (not just the current
# CUDA device) and also covers non-CUDA accelerators and CPU.
x = x.to(next(model.parameters()).device)
# Generate 200 further tokens and take the first (only) sequence in the batch.
y = sample(model, x, 200, temperature=1.0, sample=True, top_k=None)[0]
# Map token ids back to characters.
completion = "".join(train_dataset.itos[int(i)] for i in y)
print(completion)