Error while saving the model in multi gpu scenario using torch.save #4120

Closed
shrinath-suresh opened this issue Oct 13, 2020 · 4 comments · Fixed by #4127

@shrinath-suresh

🐛 Bug

In PyTorch Lightning 0.9, saving the model with torch.save in a multi-GPU (DDP) scenario results in an error.

import pytorch_lightning as pl
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_iris
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset


class IrisClassification(pl.LightningModule):
    def __init__(self):
        super(IrisClassification, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 3)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)

        logs = {"loss": loss}
        return {"loss": loss, "log": logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        return {"val_loss": avg_loss}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        a, y_hat = torch.max(logits, dim=1)
        test_acc = accuracy_score(y_hat.cpu(), y.cpu())
        return {"test_loss": loss, "test_acc": torch.tensor(test_acc)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        logs = {"test_loss": avg_loss, "test_acc": avg_test_acc}
        return {
            "avg_test_loss": avg_loss,
            "avg_test_acc": avg_test_acc,
            "log": logs,
            "progress_bar": logs,
        }


if __name__ == "__main__":
    model = IrisClassification()

    trainer = pl.Trainer(max_epochs=5, gpus=2, distributed_backend="ddp")
    #trainer = pl.Trainer(max_epochs=5)

    iris = load_iris()
    df = iris.data
    target = iris["target"]

    data = torch.Tensor(df).float()
    labels = torch.Tensor(target).long()

    data_set = TensorDataset(data, labels)
    train_set, val_set = random_split(data_set, [130, 20])
    train_set, test_set = random_split(train_set, [110, 20])
    train_loader = DataLoader(dataset=train_set, batch_size=8, num_workers=32)
    test_loader = DataLoader(dataset=test_set, batch_size=8, num_workers=32)
    val_loader = DataLoader(dataset=val_set, batch_size=8, num_workers=32)

    trainer.fit(model, train_loader, val_loader)
    trainer.test(test_dataloaders=test_loader)
    torch.save(trainer.model, "model.pth")

To Reproduce

Run the script on a multi-GPU machine to reproduce the error.

    trainer = pl.Trainer(max_epochs=5, gpus=2, distributed_backend="ddp")

However, running the same script on CPU works fine (the model is saved).

trainer = pl.Trainer(max_epochs=5)

Expected behavior

When running with multiple GPUs, the model file should be dumped as model.pth.

Environment

  • CUDA:
    • GPU:
      • Tesla K80
      • Tesla K80
      • Tesla K80
      • Tesla K80
      • Tesla K80
      • Tesla K80
      • Tesla K80
      • Tesla K80
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.19.1
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0
    • pytorch-lightning: 0.9.0
    • tqdm: 4.50.2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.5
    • version: 25~18.04.1-Ubuntu SMP Fri Sep 11 12:03:04 UTC 2020

Additional context

Stack Trace:

Traceback (most recent call last):
  File "/home/ubuntu/ddp_test/iris.py", line 88, in <module>
    torch.save(trainer.model, "iris_gpu.pt")
  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/serialization.py", line 364, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/serialization.py", line 466, in _save
    pickler.dump(obj)
  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 452, in __getstate__
    self._check_default_group()
  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 475, in _check_default_group
    raise RuntimeError("DDP Pickling/Unpickling are only supported "
RuntimeError: DDP Pickling/Unpickling are only supported when using DDP with the default process group. That is, when you have called init_process_group and have not passed process_group argument to DDP constructor

The issue is specific to the DDP environment, as the script works fine on CPU.

I tried the same classification example with plain PyTorch 1.6 in a multi-GPU scenario, and the model is saved successfully.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class IrisClassification(nn.Module):
    def __init__(self, dev0, dev1):
        super(IrisClassification, self).__init__()
        self.fc1 = nn.Linear(4, 10).to(dev0)
        self.fc2 = nn.Linear(10, 3).to(dev1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(x, dim=1)
        return x


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


def demo_model_parallel(rank, world_size):
    # Dataset
    setup(rank, world_size)

    features, labels = load_iris(return_X_y=True)
    features_train, features_test, labels_train, labels_test = train_test_split(features, labels, random_state=42,
                                                                                shuffle=True)

    x_train, y_train = torch.from_numpy(features_train).float(), torch.from_numpy(labels_train).long()

    y_train = y_train.to(rank)
    x_train = x_train.to(rank)

    # create local model
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    model = IrisClassification(dev0, dev1).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()

    optimizer = optim.Adam(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(x_train)
    # backward pass
    loss_fn(outputs, y_train).backward()
    # update parameters
    optimizer.step()

    if rank == 0:
        print("Saving Model")
        torch.save(model, "model.pth")

    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    if n_gpus < 8:
        print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
    else:
        run_demo(demo_model_parallel, 2)

@shrinath-suresh added the bug and help wanted labels Oct 13, 2020
@SeanNaren self-assigned this Oct 13, 2020
@williamFalcon
Contributor

this is on 0.9. mind upgrading to 1.0?

@shrinath-suresh
Author

In 1.0 (#4114), this error is reproduced in both the CPU and multi-GPU versions.

Traceback (most recent call last):                                                                                                                                                                          
  File "iris.py", line 88, in <module>
    torch.save(trainer.model, "iris_gpu.pt")
  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/serialization.py", line 364, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/serialization.py", line 466, in _save
    pickler.dump(obj)
TypeError: 'NoneType' object is not callable

@SeanNaren
Contributor

SeanNaren commented Oct 13, 2020

Instead of torch.save(...), could you replace it with trainer.save_checkpoint('model.pth')?

https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html#manual-saving
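For example, a minimal sketch against the script above (the filenames are illustrative):

    # Hedged sketch: swap the direct torch.save call for Lightning's manual checkpointing.
    trainer.fit(model, train_loader, val_loader)
    trainer.test(test_dataloaders=test_loader)

    # Saves the LightningModule weights and trainer state without pickling the DDP wrapper.
    trainer.save_checkpoint("model.ckpt")

    # The checkpoint can later be restored without a Trainer:
    model = IrisClassification.load_from_checkpoint("model.ckpt")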

I'll continue to investigate why torch.save doesn't work as intended, but I'm getting more verbose output:

Traceback (most recent call last):
  File "iris.py", line 88, in <module>
    torch.save(trainer.model, 'test.pth')
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 364, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 466, in _save
    pickler.dump(obj)
TypeError: can't pickle _thread.lock objects

This is also on the same track as what @rohitgr7 is investigating in #4114.
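
As an interim sketch of a possible workaround (an assumption on my part, not a confirmed fix for this issue), saving only the state dict side-steps pickling the wrapped module entirely:

    # Hedged sketch: persist weights only, then rebuild the module when loading.
    torch.save(model.state_dict(), "model_state.pth")

    restored = IrisClassification()
    restored.load_state_dict(torch.load("model_state.pth"))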

@shrinath-suresh
Author

shrinath-suresh commented Oct 13, 2020

@SeanNaren Thank you very much for the alternative suggestion. We are using the mlflow.pytorch library to log the model into mlflow - https://www.mlflow.org/docs/latest/python_api/mlflow.pytorch.html . Unfortunately, the library dumps the entire model into mlflow using torch.save (and there is no option to save the state dict or a checkpoint).

In the DDP multi-GPU scenario, the same error - TypeError: can't pickle _thread.lock objects - is reproduced for us.
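
For reference, a minimal sketch of how the model is logged on our side (run names and paths are illustrative):

    import mlflow.pytorch

    # Hedged sketch: mlflow.pytorch.log_model serializes the full module with torch.save
    # internally, which is where the DDP pickling error surfaces for us.
    with mlflow.start_run():
        mlflow.pytorch.log_model(trainer.model, artifact_path="model")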
