In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
import lightning

from retnet import GPTR, GPTRConfig, GPTRClassifier
from lra import ListOps

In [3]:
dataset = ListOps("listops-1000")
dataset.setup()

In [4]:
train_dataloader = dataset.train_dataloader(batch_size=4, num_workers=23)
valid_dataloader = dataset.val_dataloader(batch_size=4, num_workers=23)

In [5]:
len(train_dataloader)//8

3000

In [6]:
config = GPTRConfig(vocab_size=dataset.vocab_size,
                    context_window=None,
                    nclasses=10,
                    embedding_dim=512,
                    nheads=8,
                    nlayers=6,
                    nhidden=2048
                    )
model = GPTRClassifier(config)

In [7]:
class LLMClassifier(lightning.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.warmup_steps = 1000*8

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y, args = batch
        lengths = args['lengths']
        logits = self.model(x, lengths)
        loss = torch.nn.CrossEntropyLoss()(logits.logits, batch[1])
        acc = (torch.argmax(logits.logits, axis=-1) == batch[1]).float().mean()
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, args = batch
        logits = self.model.forward(x, args['lengths'])
        loss = torch.nn.CrossEntropyLoss()(logits.logits, batch[1])
        acc = (torch.argmax(logits.logits, axis=-1) == batch[1]).float().mean()
        self.log("valid_loss", loss, prog_bar=True)
        self.log("valid_acc", acc, prog_bar=True)
        return loss

    def create_optimizer(self):
        return torch.optim.AdamW(self.parameters(), lr=0.05, weight_decay=0.1)
            
    def lr_warmup_config(self):
        def warmup(step):
            """
            This method will be called for ceil(warmup_batches/accum_grad_batches) times,
            warmup_steps has been adjusted accordingly
            """
            if self.warmup_steps <= 0:
                factor = 1
            else:
                factor = min(step / self.warmup_steps, 1)
            return factor

        opt1 = self.create_optimizer()
        return {
            'frequency': 1,
            'optimizer': opt1,
            'lr_scheduler': {
                'scheduler': torch.optim.lr_scheduler.LambdaLR(opt1, warmup),
                'interval': 'step',
                'frequency': 1,
                'name': 'lr/warmup'
            },
        }

    def configure_optimizers(self):
        return (
            self.lr_warmup_config(),
        )


In [8]:
module = LLMClassifier(model)

In [9]:
os.cpu_count()

24

In [10]:
trainer = lightning.Trainer(max_epochs=2, accumulate_grad_batches=8)
trainer.fit(model=module, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


/home/lcadame/miniconda3/envs/ddpm_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/lcadame/miniconda3/envs/ddpm_env/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py:375: Found unsupported keys in the optimizer configuration: {'frequency'}

  | Name  | Type           | Params
-----------------------------------------
0 | model | GPTRClassifier | 20.5 M
-----------------------------------------
20.5 M    Trainable params
0         Non-trainable params
20.5 M    Tot

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/home/lcadame/miniconda3/envs/ddpm_env/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
