In [1]:
from lightning import LightningModule, Trainer
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import ray
from ray import tune, air, init
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
from ray.air.integrations.wandb import WandbLoggerCallback
import os
import dotenv

In [2]:
class SimpleNN(LightningModule):
    def __init__(self, input_size, hidden_size, learning_rate):
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)
        )
        self.learning_rate = learning_rate
        self.train_loss = 0
        self.val_loss = 0
        self.train_predictions = []
        self.train_targets = []
        self.val_predictions = []
        self.val_targets = []
        self.rmse_train = 0

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.train_loss = loss

        self.train_predictions.append(y_hat.detach())
        self.train_targets.append(y.detach())

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.val_loss = loss

        self.val_predictions.append(y_hat.detach())
        self.val_targets.append(y.detach())

        self.log("val_loss", loss)

        return {"val_loss": loss}

    def on_train_epoch_end(self):
        all_preds_train = torch.cat(self.train_predictions)
        all_targets_train = torch.cat(self.train_targets)

        self.rmse_train = torch.sqrt(
            nn.functional.mse_loss(all_preds_train, all_targets_train)
        )
        # if self.i == 9:
        #     print(f"rmse_train in train {self.i}: {rmse_train}", flush=True)
        self.train_predictions.clear()
        self.train_targets.clear()

    def on_validation_epoch_end(self):
        all_preds_val = torch.cat(self.val_predictions)
        all_targets_val = torch.cat(self.val_targets)

        rmse_val = torch.sqrt(nn.functional.mse_loss(all_preds_val, all_targets_val))
        self.log("rmse_val", rmse_val)

        self.log("rmse_train", self.rmse_train)

        self.val_predictions.clear()
        self.val_targets.clear()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


def train_nn(config, train_data, val_data):
    model = SimpleNN(
        input_size=10, hidden_size=config["hidden_size"], learning_rate=config["lr"]
    )

    # Generate dummy data

    train_loader = DataLoader(train_data, batch_size=800)
    val_loader = DataLoader(val_data, batch_size=800)

    checkpoint_dir = ray.train.get_context().get_trial_dir()
    os.makedirs(checkpoint_dir, exist_ok=True)

    trainer = Trainer(
        max_epochs=10,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=1,
        val_check_interval=1,
        callbacks=[  # ModelCheckpoint(
            #     monitor="rmse_val",
            #     mode="min",
            #     save_top_k=1,
            #     save_weights_only=False,
            #     dirpath=checkpoint_dir,
            #     filename="{epoch}-{val_rmse:.2f}",
            # ),
            TuneReportCheckpointCallback(
                metrics={
                    "loss": "val_loss",
                    "rmse_train": "rmse_train",
                    "rmse_val": "rmse_val",
                },
                filename="checkpoint",
                # train_end:
                # rmse_train & val_loss missing
                # val_rmse correct, train_rmse wrong
                # validation_end:
                # rmse_train & val_loss missing
                # val_rmse correct, train_rmse wrong
                #
                # validation_epoch_end:
                # val_loss, rmse_train & rmse_val missing
                # val_rmse & train_rmse wrong
                on="validation_end",
                # ! added
                save_checkpoints=True,
            )
        ],
        # deterministic=True,
    )

    trainer.fit(model, train_loader, val_loader)

In [3]:
config = {"lr": tune.loguniform(1e-4, 1e-1), "hidden_size": tune.choice([32, 64, 128])}

context = init(
    address="local",
    include_dashboard=False,
    ignore_reinit_error=True,
)

torch.manual_seed(42)
X = torch.randn(1000, 10)
y = torch.sum(X, dim=1, keepdim=True)

train_data = TensorDataset(X[:800], y[:800])
val_data = TensorDataset(X[800:], y[800:])

dotenv.load_dotenv()
api_key = os.getenv("WANDB_API_KEY")
entity = os.getenv("WANDB_ENTITY")
callback = [WandbLoggerCallback(api_key=api_key, entity=entity, project="ray_debug")]
tuner = tune.Tuner(
    tune.with_parameters(train_nn, train_data=train_data, val_data=val_data),
    tune_config=tune.TuneConfig(metric="rmse_val", mode="min", num_samples=2),
    param_space=config,
    run_config=air.RunConfig(
        callbacks=callback,
    ),
)

results = tuner.fit()

0,1
Current time:,2024-08-23 21:56:39
Running for:,00:00:13.90
Memory:,10.4/16.0 GiB

Trial name,status,loc,hidden_size,lr,iter,total time (s),loss,rmse_train,rmse_val
train_nn_c7318_00000,TERMINATED,127.0.0.1:14560,128,0.00115049,10,1.42068,8.79886,3.01525,2.96629
train_nn_c7318_00001,TERMINATED,127.0.0.1:14561,32,0.02213,10,1.39823,0.384783,1.02456,0.620309


[36m(train_nn pid=14561)[0m GPU available: False, used: False
[36m(train_nn pid=14561)[0m TPU available: False, using: 0 TPU cores
[36m(train_nn pid=14561)[0m HPU available: False, using: 0 HPUs
[36m(train_nn pid=14561)[0m `Trainer(val_check_interval=1)` was configured so validation will run after every batch.
[36m(train_nn pid=14561)[0m Missing logger folder: /private/tmp/ray/session_2024-08-23_21-56-19_562368_14499/artifacts/2024-08-23_21-56-25/train_nn_2024-08-23_21-56-25/working_dirs/train_nn_c7318_00001_1_hidden_size=32,lr=0.0221_2024-08-23_21-56-26/lightning_logs
[36m(train_nn pid=14561)[0m 
[36m(train_nn pid=14561)[0m   | Name  | Type       | Params | Mode 
[36m(train_nn pid=14561)[0m ---------------------------------------------
[36m(train_nn pid=14561)[0m 0 | model | Sequential | 385    | train
[36m(train_nn pid=14561)[0m ---------------------------------------------
[36m(train_nn pid=14561)[0m 385       Trainable params
[36m(train_nn pid=14561)[0m 0   

Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] 
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.93it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 56.56it/s][A


[36m(train_nn pid=14561)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/adamova/ray_results/train_nn_2024-08-23_21-56-25/train_nn_c7318_00001_1_hidden_size=32,lr=0.0221_2024-08-23_21-56-26/checkpoint_000000)


[36m(train_nn pid=14561)[0m 
Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 75.60it/s, v_num=0]       [A
[36m(train_nn pid=14561)[0m 
Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 110.83it/s, v_num=0]       [A
[36m(train_nn pid=14560)[0m 
[36m(train_nn pid=14561)[0m 
Epoch 3: 100%|██████████| 1/1 [00:00<00:00, 14.04it/s, v_num=0]        [A
[36m(train_nn pid=14560)[0m 
Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 14.07it/s, v_num=0]        [A
Epoch 4:   0%|          | 0/1 [00:00<?, ?it/s, v_num=0]        
Epoch 3:   0%|          | 0/1 [00:00<?, ?it/s, v_num=0]        
[36m(train_nn pid=14561)[0m 
[36m(train_nn pid=14560)[0m 
[36m(train_nn pid=14561)[0m 
[36m(train_nn pid=14560)[0m 
Epoch 6: 100%|██████████| 1/1 [00:00<00:00, 11.53it/s, v_num=0]


[36m(train_nn pid=14561)[0m `Trainer.fit` stopped: `max_epochs=10` reached.
2024-08-23 21:56:39,933	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/adamova/ray_results/train_nn_2024-08-23_21-56-25' in 0.0137s.


[36m(train_nn pid=14561)[0m 
Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  8.86it/s, v_num=0]        [A
[36m(train_nn pid=14560)[0m 
[36m(train_nn pid=14560)[0m 


[36m(_WandbLoggingActor pid=14564)[0m wandb: Currently logged in as: adamovanja (ritme). Use `wandb login --relogin` to force relogin
[36m(train_nn pid=14560)[0m GPU available: False, used: False
[36m(train_nn pid=14560)[0m TPU available: False, using: 0 TPU cores
[36m(train_nn pid=14560)[0m HPU available: False, using: 0 HPUs
[36m(train_nn pid=14560)[0m `Trainer(val_check_interval=1)` was configured so validation will run after every batch.
[36m(train_nn pid=14560)[0m Missing logger folder: /private/tmp/ray/session_2024-08-23_21-56-19_562368_14499/artifacts/2024-08-23_21-56-25/train_nn_2024-08-23_21-56-25/working_dirs/train_nn_c7318_00000_0_hidden_size=128,lr=0.0012_2024-08-23_21-56-26/lightning_logs
[36m(_WandbLoggingActor pid=14564)[0m wandb: wandb version 0.17.7 is available!  To upgrade, please run:
[36m(_WandbLoggingActor pid=14564)[0m wandb:  $ pip install wandb --upgrade
[36m(_WandbLoggingActor pid=14564)[0m wandb: Tracking run with wandb version 0.17.2
[36m(

In [4]:
best_result = results.get_best_result("rmse_val", "min", scope="all")
print(f"Best trial config: {best_result.config}")
best_rmse_train = best_result.metrics["rmse_train"]
best_rmse_val = best_result.metrics["rmse_val"]
print(f"Best trial final train rmse: {best_rmse_train}")
print(f"Best trial final validation rmse: {best_rmse_val}")

Best trial config: {'lr': 0.022130015984348436, 'hidden_size': 32}
Best trial final train rmse: 1.0245614051818848
Best trial final validation rmse: 0.6203088164329529


In [5]:
# checkpoint_dir = best_result.checkpoint.to_directory()
# best_result.checkpoint
checkpoint_dir = best_result.checkpoint.path
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")

checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])

{'input_size': 10, 'hidden_size': 32, 'learning_rate': 0.022130015984348436}


In [6]:
model = SimpleNN.load_from_checkpoint(checkpoint_path)

# train
# model.eval()
# train_data = TensorDataset(X[:800], y[:800])
rmse_train_recalc = torch.sqrt(nn.functional.mse_loss(model(X[:800]), y[:800])).item()
# rmse_train_recalc = root_mean_squared_error(model(X[:800]).detach().numpy(), y[:800].detach().numpy())
print(rmse_train_recalc)

# val
# model.eval()
rmse_val_calc = torch.sqrt(nn.functional.mse_loss(model(X[800:]), y[800:])).item()
# rmse_val_calc = root_mean_squared_error(model(X[800:]).detach().numpy(), y[800:].detach().numpy())
print(rmse_val_calc)

0.6273286938667297
0.6203088164329529


In [7]:
# train_loader = DataLoader(train_data, batch_size=32)
# val_loader = DataLoader(val_data, batch_size=32)

# model.eval()

# train_preds = []
# train_targets = []
# for batch in train_loader:
#     x, y = batch
#     preds = model(x)
#     train_preds.append(preds)
#     train_targets.append(y)

# train_preds = torch.cat(train_preds)
# train_targets = torch.cat(train_targets)
# rmse_train_recalc = torch.sqrt(nn.functional.mse_loss(train_preds, train_targets)).item()
# print(rmse_train_recalc)
# val_preds = []
# val_targets = []
# for batch in val_loader:
#     x, y = batch
#     preds = model(x)
#     val_preds.append(preds)
#     val_targets.append(y)

# val_preds = torch.cat(val_preds)
# val_targets = torch.cat(val_targets)
# rmse_val_calc = torch.sqrt(nn.functional.mse_loss(val_preds, val_targets)).item()
# print(rmse_val_calc)