Running ray with pytorch lightning in slurm job causes failure with error "ValueError: signal only works in main thread" #3651

Closed
rashindrie opened this issue Sep 25, 2020 · 16 comments
Labels
3rd party (Related to a 3rd-party) · bug (Something isn't working) · help wanted (Open to be worked on) · waiting on author (Waiting on user action, correction, or update)

Comments

@rashindrie

🐛 Bug

I followed the instructions at https://docs.ray.io/en/master/tune/tutorials/tune-pytorch-lightning.html to integrate Ray with PyTorch Lightning. However, when I submitted a SLURM job to run the tuning, I got the following error:
ValueError: signal only works in main thread

I reported the same problem to the Ray project at ray/issues/10995 and was suggested a hack to work around it (setting the SLURM_JOB_NAME environment variable to "bash"; see the comments below).

Could we look for a way to disable the SLURM detection in PyTorch Lightning itself, so that external parties do not have to hack their way around it?

To Reproduce

Steps to reproduce the behavior:

  1. Set up Ray by following the instructions here.
  2. Submit a slurm job to run the tuning
  3. See error
ray.tune.error.TuneError: Trial raised an exception. Traceback:
ray::ImplicitFunc.train() (pid=4432, ip=172.26.92.190)
  File "/home/user/.local/lib/python3.7/site-packages/ray/tune/function_runner.py", line 227, in run
    self._entrypoint()
  File "/home/user/.local/lib/python3.7/site-packages/ray/tune/function_runner.py", line 290, in entrypoint
    self._status_reporter.get_checkpoint())
  File "/home/user/.local/lib/python3.7/site-packages/ray/tune/function_runner.py", line 497, in _trainable_func
    output = train_func(config)
  File "tune.py", line 261, in train_run
    trainer.fit(model)
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1073, in fit
    results = self.accelerator_backend.train(model)
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_backend.py", line 51, in train
    results = self.trainer.run_pretrain_routine(model)
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1184, in run_pretrain_routine
    self.register_slurm_signal_handlers()
  File "/home/user/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/training_io.py", line 240, in register_slurm_signal_handlers
    signal.signal(signal.SIGUSR1, self.sig_handler)
  File "/usr/local/easybuild-2019/easybuild/software/mpi/gcc/8.3.0/openmpi/3.1.4/python/3.7.4/lib/python3.7/signal.py", line 47, in signal
    handler = _signal.signal(_enum_to_int(signalnum), _enum_to_int(handler))
ValueError: signal only works in main thread

Code sample

slurm script

#!/bin/bash

#SBATCH --gres=gpu:4
#SBATCH --nodes=1
#SBATCH --ntasks=1


# load necessary modules #
module purge
module load scikit-learn/0.21.3-python-3.7.4
module load python/3.7.4

python -u tune.py &> "tune_output.txt"

tune.py

# Imports reconstructed for readability; they were omitted from the original snippet.
from functools import partial
from os.path import join

import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.pytorch_lightning import TuneReportCallback

# Note: SegNet is the author's own model; its definition/import is not shown in the original.


class CustomDataSet(Dataset):
    def __init__(self, csv_file, img_dir, transform):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_loc = join(self.img_dir, self.data.name[idx])
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        label = self.data.label[idx]

        return tensor_image, label


class ExperimentAE(pl.LightningModule):
    def __init__(self,
                 params: dict,
                 **kwargs) -> None:
        super(ExperimentAE, self).__init__()

        self.params = params
        self.model = SegNet()

    def forward(self, z):
        return self.decoder(z)

    def _run_step(self, x):
        x_hat, z = self.model(x)
        return x_hat, z

    def generate(self, x):
        return self._run_step(x)[0]

    def step(self, batch, batch_idx):
        x, y = batch
        self.curr_device = x.device
        x_hat, z = self._run_step(x)
        loss = F.mse_loss(x_hat, x, reduction='mean')
        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        train_loss = self.step(batch, batch_idx)["loss"]
        logs = {"ptl/train_loss": train_loss}
        return {"loss": train_loss, "log": logs}

    def validation_step(self, batch, batch_idx):
        val_loss = self.step(batch, batch_idx)
        return {"val_loss": val_loss["loss"]}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}
        
    def configure_optimizers(self):
        optims = []

        optimizer = optim.Adam(self.model.parameters(),
                               lr=self.params['lr'],
                               # weight_decay=self.params['weight_decay']
                               )
        optims.append(optimizer)
        return optims

    def train_dataloader(self):
        transform = self.data_transforms(train=True)
        img_dir = "/data/brca/"
        train_csv = "/original/train.csv"

        dataset = CustomDataSet(train_csv, img_dir, transform=transform)
        loader = DataLoader(dataset, shuffle=True, batch_size=self.params['batch_size'], num_workers=4)

        self.num_train_imgs = len(dataset)

        return loader

    def val_dataloader(self):
        transform = self.data_transforms(train=False)

        img_dir = "/data/brca/"
        val_csv = "/original/validation.csv"

        val_dataset = CustomDataSet(val_csv, img_dir, transform=transform)
        self.valid_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=self.params['batch_size'],
                                           num_workers=4)

        self.num_val_imgs = len(self.valid_dataloader)

        return self.valid_dataloader

    def test_dataloader(self):
        pass

    def data_transforms(self, train=True):
        if train:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
            ])
            return transform_train

        else:
            transform_val = transforms.Compose([
                transforms.ToTensor(),
            ])
            return transform_val


def train_run(config_params, num_epochs=10, num_gpus=1):
    model = ExperimentAE(params=config_params)

    trainer = Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        logger=TensorBoardLogger(
            save_dir=tune.get_trial_dir(), name="", version="."),
        progress_bar_refresh_rate=0,
        callbacks=[
            TuneReportCallback(
                {
                    "loss": "val_loss",
                },
                on="validation_end")
        ])

    trainer.fit(model)


def tune_run(num_samples=20, num_epochs=10, gpus_per_trial=1):
    tune_config = {
        "lr": tune.loguniform(1e-4, 1e-5, 1e-3),
        "batch_size": tune.choice([2, 4, 8]),
        # 'weight_decay': 0.0,
        'scheduler_gamma': tune.choice([1, 0.95, 0.9, 0.85, 0.6]),
    }

    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=10,
        grace_period=1,
        reduction_factor=2)

    reporter = CLIReporter(
        parameter_columns=["lr", "batch_size"],
        metric_columns=["loss", "training_iteration"]
    )
    tune.run(
        partial(
            train_run,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial
        ),
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        config=tune_config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="tune_segnet_v1"
    )


if __name__ == "__main__":
    tune_run(num_samples=20, num_epochs=10, gpus_per_trial=1)

Expected behavior

The ray tune program to run properly in a slurm environment.

Environment

  • CUDA:
    • GPU:
      • Tesla V100-SXM2-16GB
      • Tesla V100-SXM2-16GB
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.17.3
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0
    • pytorch-lightning: 0.9.0
    • tqdm: 4.46.0
    • ray: 0.8.7
    • tensorflow: 2.1.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.7.4
    • version: #1 SMP Tue May 26 15:05:43 EDT 2020

Additional context

@rashindrie added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Sep 25, 2020
@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@edenlightning added the 3rd party (Related to a 3rd-party) label on Oct 2, 2020
@edenlightning modified the milestones: 0.9.x, 1.0 on Oct 2, 2020
@williamFalcon modified the milestones: 1.0, 1.1 on Oct 5, 2020
@edenlightning
Contributor

hey @rashindrie! Would you mind upgrading to 1.0.2 to see if the issue persists?

@edenlightning added the waiting on author (Waiting on user action, correction, or update) label on Oct 19, 2020
@edenlightning modified the milestones: 1.1, 1.0.3 on Oct 19, 2020
@rashindrie
Author

Hi @edenlightning

Sure, will try that.
Give me some time to get back to you. Currently overloaded by some other tasks.

@edenlightning modified the milestones: 1.0.x, 1.0.7 on Nov 10, 2020
@Borda modified the milestones: 1.0.7, 1.0.x on Nov 11, 2020
@edenlightning removed this from the 1.0.x milestone on Nov 13, 2020
@edenlightning
Contributor

Closing this for now, feel free to reopen!

@yorickvanzweeden

@edenlightning I have the same behaviour. The 'hack' fixes it.

I am running a Python script on a SLURM cluster.

Environment

CUDA:
    GPU:
        GeForce RTX 2080 TI
        GeForce RTX 2080 TI
    available: True
    version: 10.2
Packages:
    numpy: 1.19.2
    pyTorch_debug: False
    pyTorch_version: 1.7.1
    pytorch-lightning: 1.1.3
    tqdm: 4.55.1
    ray: 1.1.0
System:
    OS: Linux
    architecture:
        64bit
        ELF
    processor: x86_64
    python: 3.7.9

@import-antigravity

Hi, I'm having this issue as well. I don't like that you have to basically circumvent the normal functionality of the code in order to get it to work...

@import-antigravity

This is still an issue in 1.2.1

@import-antigravity

Update: calling ray.init() causes this error even with the "hack", effectively making distributed computing impossible. You can only use Ray Tune on a single node.

@jacobdanovitch

Can confirm setting os.environ['SLURM_JOB_NAME'] = 'bash' works well as suggested in the issue on the Ray repo (and can be seen here). Could be worth noting in the Slurm Cluster part of the documentation.
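
A minimal sketch of that workaround, assuming it is applied before the Trainer is constructed (i.e. at the top of the Tune trainable, since each trial runs in its own worker process):

import os

# Make Lightning's SLURM detection treat the job as an interactive ("bash")
# session so it skips registering the SLURM signal handlers (SIGUSR1) that
# fail outside the main thread.
os.environ["SLURM_JOB_NAME"] = "bash"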

@import-antigravity

Can confirm setting os.environ['SLURM_JOB_NAME'] = 'bash' works well as suggested in the issue on the Ray repo (and can be seen here). Could be worth noting in the Slurm Cluster part of the documentation.

Does it work correctly when you’re running a distributed job across multiple nodes?

@jacobdanovitch

Does it work correctly when you’re running a distributed job across multiple nodes?

I've never been able to get Ray properly working on multiple nodes on my SLURM cluster, nothing to do with lightning. The init script they provide fails 9/10 times when trying to start workers unfortunately, not sure if it's to do with Ray or the cluster itself.

@import-antigravity

I've never been able to get Ray properly working on multiple nodes on my SLURM cluster, nothing to do with lightning. The init script they provide fails 9/10 times when trying to start workers unfortunately, not sure if it's to do with Ray or the cluster itself.

I mean maybe it's a problem with your code 🤣

The hack works for me as well on a single node, but not on multiple nodes. Also, I'll say again that I don't think the officially supported solution to this problem should be to change the job name to circumvent PTL's SLURM detection.

@jacobdanovitch

I mean maybe it's a problem with your code 🤣

Not to get off topic but it's not actually my code, it's just the sbatch script they provide. A raylet exits unexpectedly, but that's before anything Lightning-related is invoked so probably unrelated.

I agree it would make sense to have a way to interface with the connectors a little more directly.


@bw4sz

bw4sz commented Apr 6, 2021

I ran into this problem with dask + SLURM. The hack described above works if it is run on each worker process, as in the sketch below. I also needed to set num_workers to 0 for the data loaders. I hope this helps the next person.
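
A hypothetical sketch of applying that workaround on every dask worker; the scheduler connection details and helper name are placeholders, not taken from the original comment:

import os

from dask.distributed import Client


def _disable_lightning_slurm_detection():
    # Lightning treats SLURM_JOB_NAME == "bash" as an interactive session
    # and skips registering its SLURM signal handlers.
    os.environ["SLURM_JOB_NAME"] = "bash"


# Hypothetical connection; adapt to however your cluster exposes the scheduler.
client = Client(scheduler_file="scheduler.json")

# Apply the environment tweak in every worker process before training starts.
client.run(_disable_lightning_slurm_detection)

# As noted above, the DataLoaders inside the workers also needed num_workers=0:
# DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)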

@henrique

This still happens on master on SLURM with Ray, and probably with any library that spawns processes.
Perhaps it would be nice to add an additional check like threading.current_thread() is threading.main_thread() before setting the signal handler? A rough sketch of such a guard is below.
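
A rough sketch of that guard (not the actual Lightning code; register_slurm_handler and the handler argument are illustrative):

import signal
import threading


def register_slurm_handler(handler):
    # signal.signal() may only be called from the main thread of the main
    # interpreter; Ray trainables run elsewhere, which is what raises
    # "ValueError: signal only works in main thread".
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGUSR1, handler)
    # else: skip registration, since SLURM auto-resubmission cannot work here anyway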
