In [1]:
import os.path

from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
from tokenizer import Tokenizer
from file_dataset import FileDataset
from transformer_model import build_transformer

In [2]:
from ray.train.torch import TorchTrainer
from ray.train import SyncConfig
import tempfile
import ray.train.torch

In [3]:
def train_fn(tmpdir):
    """Per-worker training loop run by Ray Train's ``TorchTrainer``.

    Builds the tokenizer, dataset and transformer, then trains for 10 epochs
    using the batch's own ``input_ids`` as the targets. After every epoch the
    last batch loss and a ``model.pt`` checkpoint are reported to Ray Train.

    Args:
        tmpdir: Unused. Per the Ray Train docs, a one-argument training
            function receives ``train_loop_config`` (an empty dict when none
            is configured), so the name is misleading. It is kept only so the
            signature stays unchanged.
    """
    tokenizer = Tokenizer()
    VOCAB_SIZE = tokenizer.get_vocab_size()
    dataset = FileDataset("/Users/daniilogorodnikov/dataset/app", 'sha256', 128)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    trainloader = ray.train.torch.prepare_data_loader(dataloader)

    # The model must exist before the optimizer is created: the original
    # referenced `transformer` before assigning it (UnboundLocalError).
    transformer = build_transformer(
        vocab_size=VOCAB_SIZE,
        d_model=512,
        max_seq_len=128,
        d_ff=1024,
        dropout=0.1,
        n_layers=6,
        n_heads=8,
        factor=2,
    )
    model = ray.train.torch.prepare_model(transformer)

    # NOTE(review): ignore_index=5 is presumably the padding token id -
    # confirm against Tokenizer.
    criterion = nn.CrossEntropyLoss(ignore_index=5)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    context = ray.train.get_context()
    is_distributed = context.get_world_size() > 1

    for epoch in range(10):
        if is_distributed:
            # Re-seed the distributed sampler so each epoch gets a new shuffle.
            trainloader.sampler.set_epoch(epoch)

        loss = None
        for batch in trainloader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']

            # Forward pass
            outputs = model(input_ids, attention_mask)

            # Compute the loss by comparing the output against the input:
            # flatten everything but the last output axis and score it
            # against the flattened input ids.
            loss = criterion(outputs.view(-1, outputs.shape[-1]), input_ids.view(-1))

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report the last batch's loss; NaN if this worker saw no batches
        # (the original raised NameError on an empty loader).
        epoch_loss = loss.item() if loss is not None else float("nan")
        metrics = {"loss": epoch_loss, "epoch": epoch}

        # prepare_model may wrap the network in DistributedDataParallel; save
        # the underlying module so checkpoint keys carry no "module." prefix.
        if isinstance(model, nn.parallel.DistributedDataParallel):
            model_to_save = model.module
        else:
            model_to_save = model

        # A separate name avoids shadowing the `tmpdir` parameter.
        with tempfile.TemporaryDirectory() as checkpoint_dir:
            torch.save(model_to_save.state_dict(),
                       os.path.join(checkpoint_dir, "model.pt"))
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(checkpoint_dir),
            )
        if context.get_world_rank() == 0:
            print(metrics)


In [4]:
# Resources for the run: two workers, CPU only.
scaling_config = ray.train.ScalingConfig(use_gpu=False, num_workers=2)

# Build the distributed training job around `train_fn`.
# On a multi-node cluster, also pass a run config whose persistent storage
# is reachable from every worker node, e.g.
#     run_config=ray.train.RunConfig(storage_path="s3://...")
trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=scaling_config,
)

In [5]:
# Run the training job built above and keep what `fit()` returns in `result`.
result = trainer.fit()

2025-05-25 14:59:08,975	INFO worker.py:1888 -- Started a local Ray instance.
2025-05-25 14:59:09,595	INFO tune.py:253 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `<FrameworkTrainer>(...)`.
2025-05-25 14:59:09,597	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-05-25 14:59:09 (running for 00:00:00.13)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-05-25_14-59-06_331250_54916/artifacts/2025-05-25_14-59-09/TorchTrainer_2025-05-25_14-59-06/driver_artifacts
Number of trials: 1/1 (1 PENDING)




[36m(TorchTrainer pid=54954)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=54954)[0m - (node_id=fce7f84855974765396ce14e761d727c7ff489cd5a60b2a887f44415, ip=127.0.0.1, pid=54959) world_rank=0, local_rank=0, node_rank=0
[36m(TorchTrainer pid=54954)[0m - (node_id=fce7f84855974765396ce14e761d727c7ff489cd5a60b2a887f44415, ip=127.0.0.1, pid=54958) world_rank=1, local_rank=1, node_rank=0
[36m(RayTrainWorker pid=54959)[0m Setting up process group for: env:// [rank=0, world_size=2]
  0%|          | 0/780 [00:00<?, ?it/s]
  1%|          | 4/780 [00:00<00:45, 17.23it/s]


== Status ==
Current time: 2025-05-25 14:59:14 (running for 00:00:05.15)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-05-25_14-59-06_331250_54916/artifacts/2025-05-25_14-59-09/TorchTrainer_2025-05-25_14-59-06/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




  0%|          | 0/780 [00:00<?, ?it/s]
 14%|█▍        | 110/780 [00:05<00:21, 31.40it/s][32m [repeated 48x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m


== Status ==
Current time: 2025-05-25 14:59:19 (running for 00:00:10.22)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-05-25_14-59-06_331250_54916/artifacts/2025-05-25_14-59-09/TorchTrainer_2025-05-25_14-59-06/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




 26%|██▌       | 204/780 [00:10<00:50, 11.32it/s][32m [repeated 42x across cluster][0m


== Status ==
Current time: 2025-05-25 14:59:24 (running for 00:00:15.28)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-05-25_14-59-06_331250_54916/artifacts/2025-05-25_14-59-09/TorchTrainer_2025-05-25_14-59-06/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




 32%|███▏      | 249/780 [00:12<00:37, 14.16it/s]
2025-05-25 14:59:26,971	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/daniilogorodnikov/ray_results/TorchTrainer_2025-05-25_14-59-06' in 0.0028s.


== Status ==
Current time: 2025-05-25 14:59:26 (running for 00:00:17.34)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-05-25_14-59-06_331250_54916/artifacts/2025-05-25_14-59-09/TorchTrainer_2025-05-25_14-59-06/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




2025-05-25 14:59:27,239	INFO tune.py:1041 -- Total run time: 17.64 seconds (17.33 seconds for the tuning loop).
Resume training with: <FrameworkTrainer>.restore(path="/Users/daniilogorodnikov/ray_results/TorchTrainer_2025-05-25_14-59-06", ...)
