# Dev Playground

In [1]:
from accelerate import Accelerator, DistributedDataParallelKwargs
import gymnasium as gym
import minari
import torch
from gato.training.arguments import TrainingArgs
from gato.tasks.text_task import TextTask

pygame 2.5.2 (SDL 2.28.2, Python 3.10.11)
Hello from the pygame community. https://www.pygame.org/contribute.html


# TextTask

In [2]:
# Minimal text-only configuration for a quick end-to-end smoke run:
# a single text dataset and only a handful of warmup/training steps.
# The device falls back to the CPU when no CUDA device is available, so the
# notebook still runs on machines without a GPU.
args = TrainingArgs(
    text_datasets=['wikitext-2-v1'],
    text_datasets_paths=['wikitext'],
    warmup_steps=1,
    training_steps=12,
    eval_episodes=1,
    log_eval_freq=4,
    batch_size=4,
    text_prop=1.0,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [3]:
# Build the text task from the dataset names/paths, sequence length and
# tokenizer name held in `args`.
text_task = TextTask(
    args.text_datasets,
    args.text_datasets_paths,
    args.sequence_length,
    args.tokenizer_model_name,
)

In [4]:
from gato.policy.gato_policy import GatoPolicy

In [5]:
# DDP is configured with find_unused_parameters=True and handed to the
# Accelerator through `kwargs_handlers`; the remaining options come from `args`.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    kwargs_handlers=[ddp_kwargs],
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    split_batches=True,
    mixed_precision=args.mixed_precision,
    cpu=args.cpu,
)

In [6]:
# Instantiate the policy. Every constructor argument is forwarded from `args`;
# the two positional-encoding switches are the negation of the corresponding
# `disable_*` flags.
model = GatoPolicy(
    device=args.device,
    # Language-model / tokenizer settings
    pretrained_lm=args.pretrained_lm,
    tokenizer_model_name=args.tokenizer_model_name,
    # Network size and regularisation
    embed_dim=args.embed_dim,
    layers=args.layers,
    heads=args.heads,
    dropout=args.dropout,
    activation_fn=args.activation_fn,
    flash=args.flash,
    # Tokenisation / embedding settings
    mu=args.mu,
    M=args.M,
    patch_size=args.patch_size,
    resid_mid_channels=args.resid_mid_channels,
    continuous_tokens=args.continuous_tokens,
    discrete_tokens=args.discrete_tokens,
    # Sequence handling
    context_len=args.sequence_length,
    pad_seq=args.pad_seq,
    use_pos_encoding=not args.disable_inner_pos_encoding,
    use_patch_pos_encoding=not args.disable_patch_pos_encoding,
)
# Move the model to the configured device; the trailing `;` hides the cell output.
model.to(args.device);

In [7]:
from gato.training.schedulers import get_linear_warmup_cosine_decay_scheduler

In [8]:
# AdamW over all of the model's parameters; hyperparameters come from `args`.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    weight_decay=args.weight_decay,
    eps=args.adam_eps,
    betas=(args.beta_1, args.beta_2),
    lr=args.learning_rate,
)

# Set up the learning-rate scheduler (linear warmup followed by cosine decay,
# going by the helper's name). The minimum LR is the base LR divided by
# `args.min_factor`, and cosine decay stays on unless `args.disable_cosine_decay`
# is set.
scheduler = get_linear_warmup_cosine_decay_scheduler(
    optimizer,
    args.warmup_steps,
    args.training_steps,
    cosine_decay=not args.disable_cosine_decay,
    min_lr=args.learning_rate / args.min_factor,
    init_lr=args.init_lr,
    base_lr=args.learning_rate,
)

In [9]:
from gato.training.trainer import Trainer
from datetime import datetime

In [10]:
# Train on the text task only. The experiment is named after the ISO-8601
# timestamp taken when this cell runs.
tasks = [text_task]
exp_name = datetime.now().isoformat()
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    accelerator=accelerator,
    tasks=tasks,
    exp_name=exp_name,
    args=args,
)
trainer.train()

Num of examples to test : 100 | Actual batch size of test data : 68
Iteration 0
training/learning_rate: 9.285640897740315e-05
time/sample_batch: 0.0016694068908691406
time/training: 0.9930353164672852
evaluation/text/loss: 9.903299871612997
evaluation/text/perplexity: 19996.255859375
time/total: 20.59604048728943
time/evaluation: 19.6029953956604
training/train_loss_mean: 10.898804664611816
training/train_loss_std: 0.22483853898092993
Num of examples to test : 100 | Actual batch size of test data : 69
Iteration 1
training/learning_rate: 4.859583227770217e-05
time/sample_batch: 0.001455068588256836
time/training: 0.3894932270050049
evaluation/text/loss: 9.390184202056
evaluation/text/perplexity: 11970.306640625
time/total: 44.298654556274414
time/evaluation: 23.31282091140747
training/train_loss_mean: 9.925028562545776
training/train_loss_std: 0.20472925814875015
Num of examples to test : 100 | Actual batch size of test data : 55
Iteration 2
training/learning_rate: 1.1822816187347625e-0

# Pulling apart the training loop.

# Train

# Train iteration

Set the PyTorch model to 'train' mode: https://stackoverflow.com/a/51433411/3937773

`model.train()` switches the model into training mode; it does not run any training by itself.
The mode matters for layers such as Dropout and BatchNorm, which are designed to behave differently during training and evaluation.
For instance, in training mode, BatchNorm updates a moving average on each new batch; whereas, in evaluation mode (`model.eval()`), these updates are frozen.

In [11]:
# Put the model in training mode before running manual optimisation steps.
# The trailing `;` suppresses the cell's output.
model.train();

In [12]:
# Sample one batch from the text task (size argument: `args.batch_size`).
# The same `inputs` object is reused for every step of the loop in the next cell.
inputs = text_task.sample_batch(args.batch_size)

In [13]:
# Take a few manual optimisation steps on the single batch held in `inputs`,
# printing the loss after each step.
for i in range(5):
    # Call the module itself rather than `.forward()`: `nn.Module.__call__`
    # runs any registered hooks, whereas calling `.forward()` directly skips them.
    logits, loss = model(inputs=inputs, compute_loss=True)
    # Clear gradients left over from the previous step before backpropagating.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Advance the learning-rate schedule once per optimiser step.
    scheduler.step()
    print(f'{i} loss: {loss.item()}')

0 loss: 9.800954818725586
1 loss: 9.787251472473145
2 loss: 9.740415573120117
3 loss: 9.643735885620117
4 loss: 9.542501449584961
