In [4]:
import os.path

from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
from tokenizer import Tokenizer
from file_dataset import FileDataset
from transformer_model import build_transformer

In [5]:
from ray.train.torch import TorchTrainer
from ray.train import SyncConfig
import tempfile
import ray.train.torch

In [6]:
def train_fn(dataset):
    """Per-worker training loop handed to Ray Train's ``TorchTrainer``.

    Builds the transformer, trains it for a fixed number of epochs to
    reproduce its own input token ids (the targets of the loss are the
    inputs), and after every epoch reports a metrics dict together with a
    ``model.pt`` checkpoint through ``ray.train.report``.

    The function must keep a single parameter: the trainer in this notebook
    passes the dataset object as ``train_loop_config``.

    Args:
        dataset: Dataset accepted by ``torch.utils.data.DataLoader``. Every
            batch must be a mapping with ``'input_ids'`` and
            ``'attention_mask'`` entries.

    Raises:
        ValueError: If the prepared data loader yields no batches, since no
            loss can be computed or reported in that case.
    """
    num_epochs = 10
    # NOTE(review): targets equal to 5 are excluded from the loss -- presumably
    # the padding token id; confirm against Tokenizer.
    ignored_target_id = 5

    tokenizer = Tokenizer()
    vocab_size = tokenizer.get_vocab_size()
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    trainloader = ray.train.torch.prepare_data_loader(dataloader)
    print("all OK")
    transformer = build_transformer(vocab_size=vocab_size,
                                    d_model=512,
                                    max_seq_len=128,
                                    d_ff=1024,
                                    dropout=0.1,
                                    n_layers=6,
                                    n_heads=8,
                                    factor=2)
    print("all OK")
    criterion = nn.CrossEntropyLoss(ignore_index=ignored_target_id)
    optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-4)

    model = ray.train.torch.prepare_model(transformer)
    print("all OK")

    # The worker context does not change between epochs, so look it up once.
    context = ray.train.get_context()
    is_distributed = context.get_world_size() > 1

    for epoch in range(num_epochs):
        if is_distributed:
            # Reseed the distributed sampler so each epoch shuffles differently.
            trainloader.sampler.set_epoch(epoch)

        loss = None
        for batch in trainloader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']

            # Forward pass
            outputs = model(input_ids, attention_mask)

            # Compute the loss (compare the output against the input ids)
            loss = criterion(outputs.view(-1, outputs.shape[-1]), input_ids.view(-1))

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if loss is None:
            # Without this guard the metrics line below would fail with an
            # opaque UnboundLocalError on an empty loader.
            raise ValueError(
                "The training data loader produced no batches; "
                "check that the dataset is not empty."
            )

        # Reported loss is the loss of the last batch of the epoch.
        metrics = {"loss": loss.item(), "epoch": epoch}

        # With several workers prepare_model wraps the network in
        # DistributedDataParallel, whose state_dict keys carry a "module."
        # prefix. Save the underlying module so the checkpoint loads into a
        # plain model regardless of the worker count.
        if isinstance(model, nn.parallel.DistributedDataParallel):
            unwrapped_model = model.module
        else:
            unwrapped_model = model

        with tempfile.TemporaryDirectory() as tmpdir:
            torch.save(unwrapped_model.state_dict(),
                       os.path.join(tmpdir, "model.pt")
                       )
            # Report inside the `with` block: the directory is deleted on exit.
            ray.train.report(metrics,
                             checkpoint=ray.train.Checkpoint.from_directory(tmpdir),
                             )
        if context.get_world_rank() == 0:
            print(metrics)


In [7]:
# Reload the dataset object that was pickled to "dataset.pt" with torch.save.
# weights_only=False unpickles arbitrary Python objects, so only use it on
# files you created yourself.
dataset = torch.load("dataset.pt", weights_only=False)

# One worker, CPU only.
scaling_config = ray.train.ScalingConfig(use_gpu=False, num_workers=1)

# [5] Configure the distributed training job; the loaded dataset is passed
# to the trainer as train_loop_config.
trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config=dataset,
    scaling_config=scaling_config
    # [5a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)

In [None]:
# Launch the configured training job and keep the object returned by fit().
result = trainer.fit()

2025-05-26 00:20:35,032	INFO worker.py:1879 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
2025-05-26 00:21:20,486	INFO tune.py:253 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `<FrameworkTrainer>(...)`.
2025-05-26 00:21:20,491	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-05-26 00:26:11 (running for 00:04:25.90)
Using FIFO scheduling algorithm.
Logical resource usage: 0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-05-26_00-20-32_790466_2079/artifacts/2025-05-26_00-21-46/TorchTrainer_2025-05-26_00-18-41/driver_artifacts
Number of trials: 1/1 (1 PENDING)




In [13]:
# Build the dataset from the given directory.
# NOTE(review): absolute, machine-specific path -- consider a configurable
# data directory so the notebook runs on other machines.
# NOTE(review): the meaning of the positional 'sha256' and 128 arguments is
# not visible here; 128 presumably matches max_seq_len=128 used when building
# the model -- confirm against FileDataset.
dataset = FileDataset("/Users/daniilogorodnikov/dataset/app", 'sha256', 128)

100%|██████████| 671/671 [00:36<00:00, 18.62it/s]


In [14]:
# Pickle the whole dataset object to "dataset.pt" in the working directory so
# it can be reloaded later with torch.load.
torch.save(dataset, "dataset.pt")

In [None]:
# Reload the pickled dataset; the result is not assigned, only displayed.
# weights_only=False unpickles arbitrary Python objects -- only load files
# from a trusted source.
torch.load("dataset.pt", weights_only=False)

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x1050452b0>>
Traceback (most recent call last):
  File "/Users/daniilogorodnikov/PycharmProjects/Notus/.venv/lib/python3.13/site-packages/ipykernel/ipkernel.py", line 790, in _clean_thread_parent_frames
    active_threads = {thread.ident for thread in threading.enumerate()}
KeyboardInterrupt: 


In [13]:
# Scratch check of how os.path.join combines a directory that already ends in
# a slash with a file name.
os.path.join("tmp/", "dataset.pt")

'tmp/dataset.pt'