In [None]:
from llm_trainer import LLMTrainer, create_dataset  # pip install llm_trainer
# Presumably prepares the training data; uncomment and run once before trainer.train():
# create_dataset(chunks_limit=1500)

In [None]:
import random
import torch
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMLMModel, xLSTMLMModelConfig
import numpy as np
from torch.nn import functional as F


def set_seed(seed: int = 42, deterministic: bool = False) -> None:
    """Seed every RNG the notebook relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy's legacy global RNG, and PyTorch
    (CPU generator plus every visible CUDA device).

    Args:
        seed: Value passed to each RNG.
        deterministic: If True, also ask cuDNN to select deterministic
            kernels and disable its autotuner. Seeding alone does not make
            cuDNN kernels bit-reproducible; this trades speed for
            repeatability. Defaults to False, which leaves cuDNN settings
            untouched (the original behavior).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every GPU rather than only the current device;
    # it is safely ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

SEED = 42
CONFIG_PATH = "xlstm_config.yaml"

# Seed before building the model so weight initialisation is reproducible.
set_seed(SEED)

# Build the model from its YAML config. resolve=True expands any ${...}
# interpolations before dacite validates the fields; strict=True makes dacite
# reject keys that xLSTMLMModelConfig does not define.
raw_cfg = OmegaConf.load(CONFIG_PATH)
cfg = from_dict(
    data_class=xLSTMLMModelConfig,
    data=OmegaConf.to_container(raw_cfg, resolve=True),
    config=DaciteConfig(strict=True),
)
xLSTM = xLSTMLMModel(cfg)

# Report the number of trainable parameters (requires_grad=True), in millions.
num_params = sum(p.numel() for p in xLSTM.parameters() if p.requires_grad)
print(f"Total Parameters: {num_params / 1e6:.2f}M")

In [None]:
# Wrap the model in a trainer; training itself is launched in the next cell.
# NOTE(review): model_returns_logits=True presumably tells LLMTrainer that the
# model's forward pass returns a raw logits tensor rather than an output
# object — confirm against the llm_trainer documentation.
trainer = LLMTrainer(model=xLSTM, model_returns_logits=True)

In [None]:
# Training hyperparameters, named so they are easy to find and tune.
MAX_STEPS = 3000
SAVE_EVERY_N_STEPS = 1000  # presumably the checkpoint interval — see llm_trainer docs
# NOTE(review): presumably should not exceed the model's context_length in
# xlstm_config.yaml — confirm.
CONTEXT_WINDOW = 512
MINI_BATCH_SIZE = 8

trainer.train(
    max_steps=MAX_STEPS,
    save_each_n_steps=SAVE_EVERY_N_STEPS,
    context_window=CONTEXT_WINDOW,
    MINI_BATCH_SIZE=MINI_BATCH_SIZE,
)