# Test multi-node, multi-GPU training with DeepSpeed and Ray
Tokenized datasets are cached under `data/main_cache`. 

In [None]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import DeepSpeedStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.data.data_manager import LightningDataModule
from ray.train.lightning import RayLightningTrainer
from ray.train import ScalingConfig
# LightningModule for training your Hugging Face LLM
class LLMModel(pl.LightningModule):
    """Lightning wrapper that trains a Hugging Face causal language model.

    Args:
        model_name: Identifier handed to ``AutoModelForCausalLM.from_pretrained``.
        learning_rate: Learning rate used for the AdamW optimizer.
    """

    def __init__(self, model_name, learning_rate=1e-5):
        super().__init__()
        # Stores the constructor arguments on ``self.hparams``;
        # ``configure_optimizers`` reads ``learning_rate`` from there.
        self.save_hyperparameters()
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return the language-modeling loss for one batch.

        The inputs double as the labels. Positions where ``attention_mask``
        is 0 (padding) get the label -100, the ignore index of the Hugging
        Face causal-LM loss, so padding no longer contributes to the loss.

        Args:
            input_ids: Token-id tensor (presumably ``(batch, seq)`` -- confirm
                against the data module).
            attention_mask: Tensor of the same shape with 0 at padded
                positions, or ``None`` to use every position as a label.

        Returns:
            The ``loss`` attribute of the model output.
        """
        labels = input_ids
        if attention_mask is not None:
            # masked_fill returns a new tensor; input_ids itself is untouched.
            labels = input_ids.masked_fill(attention_mask == 0, -100)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for ``batch``.

        ``batch`` is unpacked as keyword arguments into ``forward``, so it
        must be a mapping whose keys are exactly ``input_ids`` and
        ``attention_mask`` (assumed to be what the data module yields --
        TODO confirm).
        """
        loss = self(**batch)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters of this module."""
        # NOTE(review): if the DeepSpeed config enables optimizer CPU offload,
        # DeepSpeed may expect its own CPU Adam implementation instead of
        # torch.optim.AdamW -- verify against the DeepSpeed config in use.
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)


def train_func(config):
    """Per-worker training loop: fine-tune the LLM with Lightning + DeepSpeed.

    Intended to be run by Ray Train once on every worker. Ray already starts
    one process per GPU, so the Lightning trainer is built with Ray's
    DeepSpeed strategy and cluster environment and ``devices="auto"`` instead
    of asking Lightning to spawn processes for every visible GPU.

    Args:
        config: Optional mapping of overrides (``None`` or ``{}`` keeps the
            defaults). Recognized keys and their defaults:
            ``model_name`` ("gpt2"), ``batch_size`` (2), ``max_length`` (512),
            ``max_epochs`` (3), ``ds_config`` ("../config/ds_config.json";
            a path or an already-loaded dict).
    """
    # Local import so this function is self-contained; these are Ray Train's
    # adapters for running a Lightning trainer inside a Ray worker.
    # NOTE(review): Ray builds these on whichever Lightning package it finds
    # first; confirm it matches the `pytorch_lightning` imported by this file.
    from ray.train.lightning import (
        RayDeepSpeedStrategy,
        RayLightningEnvironment,
        prepare_trainer,
    )

    config = config or {}
    model_name = config.get("model_name", "gpt2")  # replace with your desired LLM

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dm = LightningDataModule(
        tokenizer=tokenizer,
        dataset_configs={"wikitext": {}},
        batch_size=config.get("batch_size", 2),
        max_length=config.get("max_length", 512),
    )

    model = LLMModel(model_name=model_name)

    # DeepSpeed settings come from the JSON file (per the original note it is
    # expected to configure ZeRO-3 with CPU offload). A relative path is
    # resolved against the worker's working directory, which may differ from
    # the driver's; pass an absolute path or a dict via config["ds_config"]
    # if the default is not found.
    strategy = RayDeepSpeedStrategy(
        config=config.get("ds_config", "../config/ds_config.json")
    )

    trainer = pl.Trainer(
        strategy=strategy,
        plugins=[RayLightningEnvironment()],
        accelerator="gpu",
        # One process per Ray worker; Ray decides which GPU this worker uses.
        devices="auto",
        precision=16,
        max_epochs=config.get("max_epochs", 3),
    )
    # Validates/adjusts the trainer for execution inside a Ray Train worker.
    trainer = prepare_trainer(trainer)

    trainer.fit(model, datamodule=dm)


# Main entry-point (Ray handles the distributed training orchestration)
if __name__ == "__main__":
    # TorchTrainer is Ray Train's documented entry point for running a
    # per-worker training function (including PyTorch Lightning loops).
    # NOTE(review): the `RayLightningTrainer` name imported at the top of this
    # file, and the `lightning_config` / `trainer_init_config` arguments that
    # were passed to it, do not appear in Ray's public API as far as this
    # reviewer can tell -- verify and drop that import if it fails.
    from ray.train.torch import TorchTrainer

    scaling_config = ScalingConfig(
        num_workers=8,  # total number of GPUs across your nodes
        use_gpu=True,
        resources_per_worker={"CPU": 8, "GPU": 1},
    )

    trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        scaling_config=scaling_config,
        run_config=None,
    )

    # Launches train_func on every worker and blocks until training finishes.
    trainer.fit()
