In [None]:
%git clone https://github.com/RPegoud/ember.git
%pip install -e ./ember

In [None]:
from hydra import compose, initialize
import torch
from ember import Logger, HFTokenizer, Transformer
from datasets import load_dataset
from torch.utils.data import DataLoader
import lightning as L
import os
# Environment values must be strings, hence "false" rather than False.
# NOTE(review): presumably this disables the Hugging Face tokenizers' own
# parallelism so it does not conflict with DataLoader worker processes --
# confirm against the tokenizers documentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Load the training configuration through Hydra's Compose API.
from hydra.core.global_hydra import GlobalHydra

# `initialize` raises if Hydra's global state is already set up, which happens
# whenever this cell is executed a second time in the same kernel. Clearing it
# first makes the cell re-runnable; on a fresh kernel the call is a no-op.
GlobalHydra.instance().clear()

# Relative path: expects the `ember` checkout to sit next to this notebook.
initialize(version_base=None, config_path="ember/configs/llm")
cfg = compose(config_name="train.yaml")

In [None]:
# Create ember's logger and hand it the composed Hydra config.
# NOTE(review): where the config ends up (stdout, file, experiment tracker)
# depends on ember.Logger, which is not visible from this notebook.
logger = Logger()
logger.log_config(cfg)


class Collator:
    """Collate function that tokenizes the ``"text"`` field of dataset rows.

    Intended to be passed as ``collate_fn`` to a ``DataLoader``: each call
    receives the list of rows making up one batch and returns whatever the
    wrapped tokenizer produces for their texts.
    """

    def __init__(self, tokenizer):
        """Store the tokenizer.

        Args:
            tokenizer: Callable invoked as ``tokenizer(texts, mode="train")``
                with ``texts`` being a list of strings.
        """
        self.tokenizer = tokenizer

    def __call__(self, batch: list[dict]) -> torch.Tensor:
        """Tokenize one batch of rows.

        Args:
            batch: Rows of the dataset; each must be a mapping with a
                ``"text"`` key (rows are indexed with ``row["text"]``, so
                plain strings are not accepted).

        Returns:
            The tokenizer's output for the batch texts, unchanged.
            NOTE(review): annotated as a tensor; confirm against the
            tokenizer's ``mode="train"`` return type.

        Raises:
            KeyError: If a row has no ``"text"`` key.
        """
        texts = [row["text"] for row in batch]
        return self.tokenizer(texts, mode="train")

In [None]:
# --- Tokenizer, collator and model ------------------------------------------
tokenizer = HFTokenizer(cfg.tokenizer.path)
collator = Collator(tokenizer)
model = Transformer(
    # Vocabulary size and padding id come from the tokenizer so the model
    # always matches the tokenizer that produces its inputs.
    vocab_size=tokenizer.vocab_size,
    model_dim=cfg.model.model_dim,
    hidden_dim=cfg.model.hidden_dim,
    attention=cfg.model.attention,
    n_attn_blocks=cfg.model.n_attn_blocks,
    learning_rate=cfg.model.learning_rate,
    pad_token_id=tokenizer.pad_token_id,
)

# --- Data -------------------------------------------------------------------
ds = load_dataset(cfg.hparams.data.dataset, split=cfg.hparams.data.split)
num_workers = cfg.hparams.data.num_workers
# NOTE(review): no `shuffle=` is passed, so batches follow the dataset order;
# confirm this is intended for training.
train_loader = DataLoader(
    ds,
    batch_size=cfg.hparams.data.batch_size,
    num_workers=num_workers,
    # DataLoader raises ValueError when persistent_workers=True is combined
    # with num_workers=0, so only keep workers alive when there are any.
    persistent_workers=num_workers > 0,
    collate_fn=collator,
)

# --- Training ---------------------------------------------------------------
trainer = L.Trainer(
    max_epochs=cfg.hparams.trainer.max_epochs,
    precision=cfg.hparams.trainer.precision,
    gradient_clip_val=cfg.hparams.trainer.gradient_clip_val,
    accumulate_grad_batches=cfg.hparams.trainer.accumulate_grad_batches,
    log_every_n_steps=cfg.hparams.trainer.log_every_n_steps,
)
trainer.fit(model=model, train_dataloaders=train_loader)