
SIGSEGV on TPUVM pod size >= v3-128 #8358

Closed
zcain117 opened this issue Jul 10, 2021 · 12 comments

@zcain117 (Contributor)

🐛 Bug

Given a decently large input data size, PyTorch Lightning will crash on a v3-128 (and presumably any larger TPU pod too). The same code works fine on a v3-32. If I make the input image size smaller, the code also works on v3-128.

I am continuing to try to find more informative logs deep in the TPU stack about the segfault.

This crash does not happen with regular pytorch/xla. So I wanted to get a sense of:

  1. is pytorch-lightning expected to work on v3-128? Or maybe it's only intended to work on smaller TPU pod slices for now?
  2. is there any difference in the spawn behavior for v3-32 vs. v3-128 that would cause v3-128 to SIGSEGV but v3-32 works ok?
  3. can you think of any reason why v3-128 pod size + bigger data size = SIGSEGV but v3-128 pod size + smaller data size is able to train?

Repro:

Below is a repro script. IMAGE_SIZE=256 results in SIGSEGV on v3-128, but IMAGE_SIZE=28 is able to train successfully on v3-128:

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm


BATCH_SIZE = 6
IMAGE_SIZE = 256
NUM_CHANNELS = 3


class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(NUM_CHANNELS*IMAGE_SIZE*IMAGE_SIZE, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, NUM_CHANNELS*IMAGE_SIZE*IMAGE_SIZE)
        )

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # It is independent of forward.
        print("in training step...")
        x, y = batch
        # Does not seem to affect the segfault.
        #myzeros = torch.zeros([256, 8192], device=self.device)
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

train_loader = xu.SampleGenerator(
    data=(torch.zeros(BATCH_SIZE, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE),
          torch.zeros(BATCH_SIZE, dtype=torch.int64)),
    sample_count=1200000 // BATCH_SIZE // xm.xrt_world_size())
val_loader = xu.SampleGenerator(
    data=(torch.zeros(BATCH_SIZE, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE),
          torch.zeros(BATCH_SIZE, dtype=torch.int64)),
    sample_count=50000 // BATCH_SIZE // xm.xrt_world_size())

autoencoder = LitAutoEncoder()
trainer = pl.Trainer(tpu_cores=8)
trainer.fit(autoencoder, train_loader)

My repro steps:

  1. Create TPUVM pod slice:
gcloud alpha compute tpus tpu-vm create zcain-v3-128 --zone=europe-west4-a --project=tpu-pytorch --accelerator-type v3-128 --version=v2-alpha --metadata startup-script='#! /bin/bash
cd /usr/share/
git clone https://github.com/PyTorchLightning/pytorch-lightning.git
cd pytorch-lightning
pip3 install .
cd ..
git clone https://github.com/zcain117/taming-transformers-tpu.git -b debug-print-statements
cd taming-transformers-tpu/
pip3 install -r requirements.txt
EOF'
  2. SSH into TPUVM pod: gcloud alpha compute tpus tpu-vm ssh zcain-v3-128 --zone europe-west4-a --project=tpu-pytorch

  3. gcloud compute config-ssh

  4. python3 -m torch_xla.distributed.xla_dist --tpu=zcain-v3-128 -- python3 /usr/share/taming-transformers-tpu/repro.py

@zcain117 zcain117 added bug Something isn't working help wanted Open to be worked on labels Jul 10, 2021
@tgisaturday

I've been testing around with tpu_spawn.py here, especially with time.sleep(), but that wasn't the main cause of the problem.

https://github.com/tgisaturday/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/tpu_spawn.py#L173

If Lightning shows different behavior depending on the input image size, maybe this is a data-loader-related issue?

@kaushikb11 (Contributor)

is pytorch-lightning expected to work on v3-128? Or maybe it's only intended to work on smaller TPU pod slices for now?
is there any difference in the spawn behavior for v3-32 vs. v3-128 that would cause v3-128 to SIGSEGV but v3-32 works ok?

There is no behavioral difference on the Lightning side for training on v3-32 or v3-128.

can you think of any reason why v3-128 pod size + bigger data size = SIGSEGV but v3-128 pod size + smaller data size is able to train?

I need to dig into it more to find out the cause. Assigning myself to the issue.

@kaushikb11 kaushikb11 self-assigned this Jul 12, 2021
@tgisaturday

@zcain117 @kaushikb11
Update:
I've refactored the current codebase and removed all of the unnecessary code.
https://github.com/tgisaturday/dalle-lightning
Now the model runs much faster and without the long compilation time.
tgisaturday/dalle-lightning#1 (comment)

Here's what I've found:

  • Logging image results using the Lightning logger slows down the process and sometimes causes a SEGFAULT.
  • Using third-party libs like einops also brings an unwanted slowdown.
  • Using the latest script with python3 -m torch_xla.distributed.xla_dist --tpu=tpu-vm-pod-256 --restart-tpuvm-pod-server -- python3 dalle-lightning/train_vae.py --use_tpus --fake_data --model vqvae still fails, but adding --img_size 32 to use 32 x 32 input lets the code run well on v3-256.
  • Image size seems to be the main cause of the larger-pod problem. I couldn't find any difference between native torch-xla and Lightning regarding data loading yet.

I'll attach the latest error log from v3-256:
2021-07-14 09:36:10 10.164.0.60 [0] E0714 09:36:10.803151 779217 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
2021-07-14 09:36:10 10.164.0.60 [0] E0714 09:36:10.816310 779217 process_state.cc:771] RAW: Raising signal 15 with default behavior
2021-07-14 09:36:10 10.164.0.60 [0] Traceback (most recent call last):
2021-07-14 09:36:10 10.164.0.60 [0] File "dalle-lightning/train_vae.py", line 222, in <module>
2021-07-14 09:36:10 10.164.0.60 [0] trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
2021-07-14 09:36:10 10.164.0.60 [0] File "/home/taehoon.kim/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 522, in fit
2021-07-14 09:36:10 10.164.0.60 [0] self._run(model)
2021-07-14 09:36:10 10.164.0.60 [0] File "/home/taehoon.kim/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 883, in _run
2021-07-14 09:36:10 10.164.0.60 [0] self._dispatch()
2021-07-14 09:36:10 10.164.0.60 [0] File "/home/taehoon.kim/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 927, in _dispatch
2021-07-14 09:36:10 10.164.0.60 [0] self.accelerator.start_training(self)
2021-07-14 09:36:10 10.164.0.60 [0] File "/home/taehoon.kim/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 97, in start_training
2021-07-14 09:36:10 10.164.0.60 [0] self.training_type_plugin.start_training(trainer)
2021-07-14 09:36:10 10.164.0.60 [0] File "/home/taehoon.kim/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/tpu_spawn.py", line 267, in start_training
2021-07-14 09:36:10 10.164.0.60 [0] xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)
2021-07-14 09:36:10 10.164.0.60 [0] File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 388, in spawn
2021-07-14 09:36:10 10.164.0.60 [0] return torch.multiprocessing.start_processes(
2021-07-14 09:36:10 10.164.0.60 [0] File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
2021-07-14 09:36:10 10.164.0.60 [0] while not context.join():
2021-07-14 09:36:10 10.164.0.60 [0] File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 130, in join
2021-07-14 09:36:10 10.164.0.60 [0] raise ProcessExitedException(
2021-07-14 09:36:10 10.164.0.60 [0] torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGSEGV

Any ideas or progress?

@kaushikb11 (Contributor)

kaushikb11 commented Jul 14, 2021

The action item here would be to run a native XLA script and check whether it raises the same issue.

@tgisaturday

@kaushikb11 I've tested this script on v3-256 and it works with native torch-xla.
https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet.py

Here are the steps I took:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} --project ${PROJECT_ID} --worker=all \
  --command "git clone --recursive https://github.com/pytorch/xla.git"

python3 -m torch_xla.distributed.xla_dist --tpu=tpu-vm-pod-256 --restart-tpuvm-pod-server -- python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

Lightning seems to fail during new_process.

@tgisaturday

tgisaturday commented Jul 16, 2021

@zcain117 @kaushikb11 Fixed. Replacing xu.SampleGenerator with a native PyTorch Dataset class and using a LightningDataModule solves the OOM that was happening with a large pod + large image input size.

Here is an example of a proper DataModule with a fake-data generation option:
https://github.com/tgisaturday/dalle-lightning/blob/29d9bc153d81afae510ef5bf6c2035ddca14ae6c/pl_dalle/loader.py#L37
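
For reference, here is a minimal sketch of that idea (not the linked code itself; the names FakeImageDataset and FakeDataModule are made up for illustration): the Dataset synthesizes a fake sample inside __getitem__ instead of holding pre-built tensors, and a LightningDataModule hands out the DataLoaders.

import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


class FakeImageDataset(Dataset):
    """Generates a random sample per __getitem__ call; nothing is stored up front."""

    def __init__(self, num_samples, image_size=256, num_channels=3):
        self.num_samples = num_samples
        self.image_size = image_size
        self.num_channels = num_channels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = torch.randn(self.num_channels, self.image_size, self.image_size)
        label = torch.tensor(0, dtype=torch.int64)
        return image, label


class FakeDataModule(pl.LightningDataModule):

    def __init__(self, batch_size=6, image_size=256):
        super().__init__()
        self.batch_size = batch_size
        self.image_size = image_size

    def setup(self, stage=None):
        self.train_set = FakeImageDataset(1200000, self.image_size)
        self.val_set = FakeImageDataset(50000, self.image_size)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=4)

With something like this, the repro's last line would become trainer.fit(autoencoder, datamodule=FakeDataModule()).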

@awaelchli awaelchli added the accelerator: tpu Tensor Processing Unit label Jul 16, 2021
@zcain117 (Contributor, Author)

Great find! Glad it's working now

Do you think it was an OOM that was causing the SIGSEGV? I am curious why there is an OOM for larger TPU sizes.

The torch_xla SampleGenerator is here: https://github.com/pytorch/xla/blob/master/torch_xla/utils/utils.py#L44. You can see that each SampleGenerator stores data, which in this case is a tuple of two large tensors. Note that torchvision.datasets.FakeData does not store any data; it generates a new sample on each call to __getitem__. Also note that torchvision.datasets.FakeData actually gives a new random image on each call, whereas the torch_xla SampleGenerator just returns the same static tensors over and over.

With native torch_xla, you can see here that we wait to create the SampleGenerator until we're inside the train_imagenet method, and at that point in the code each bundle of 8 processes is already on its own machine. So in theory, with native torch_xla, moving from a v3-32 (which has 4 machines, each driving 8 cores) to a v3-128 (which has 16 machines, each driving 8 cores) would not change the memory behavior.
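
For illustration, a stripped-down sketch of the native torch_xla pattern described above (loosely modeled on test_train_mp_imagenet.py; the flags and training loop are elided): the SampleGenerator is only constructed inside the function that xmp.spawn runs, i.e. after each group of 8 processes is already on its own host.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

BATCH_SIZE, IMAGE_SIZE, NUM_CHANNELS = 6, 256, 3


def train_imagenet():
    # Built per spawned process, so each host only ever materializes the
    # generators for its own 8 workers, regardless of total pod size.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(BATCH_SIZE, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE),
              torch.zeros(BATCH_SIZE, dtype=torch.int64)),
        sample_count=1200000 // BATCH_SIZE // xm.xrt_world_size())
    device = xm.xla_device()
    # ... training loop over train_loader on `device` elided ...


def _mp_fn(index):
    train_imagenet()


if __name__ == '__main__':
    xmp.spawn(_mp_fn, nprocs=8)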

Is there some step in PyTorch Lightning setup where all the SampleGenerators would be on the same machine?

@tgisaturday

Is there some step in PyTorch Lightning setup where all the SampleGenerators would be on the same machine?

Maybe it's related to the DDP sampler? Lightning is supposed to automatically add a DDP sampler for multi-TPU training, and somehow it duplicates the data.
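
For context, a minimal sketch of what "adding a DDP sampler" usually amounts to in a hand-written torch_xla script (the small TensorDataset is just a stand-in); whether Lightning's automatic insertion behaves differently on larger pods is exactly the open question here.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import torch_xla.core.xla_model as xm

# A small stand-in dataset; in practice this would be the real training set.
dataset = TensorDataset(
    torch.randn(64, 3, 256, 256),
    torch.zeros(64, dtype=torch.int64))

# Each of the world_size processes sees a disjoint 1/world_size shard of the
# dataset instead of the whole thing.
sampler = DistributedSampler(
    dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

loader = DataLoader(dataset, batch_size=6, sampler=sampler, num_workers=4)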

@kaushikb11 (Contributor)

@tgisaturday Could you add the recent updates to the GitHub issue?

@tgisaturday

@kaushikb11 I'm testing out a few more things. I'll leave an update when they're finished.

@kaushikb11 (Contributor)

@tgisaturday We could close this issue, since we figured out it was a DataLoader issue. We could create a separate issue for the custom logging problem.

@tgisaturday

We could close this issue, since we figured out it was a DataLoader issue.

I'm also testing text data loading. I'll open another issue if it turns out to be a Lightning problem.
