DeepSpeedPlugin cpu_checkpointing flag not forwarded to deepspeed correctly #10874

Closed · jona-0 opened this issue Dec 1, 2021 · 3 comments · Fixed by #10899
Labels: bug (Something isn't working), strategy: deepspeed

jona-0 (Contributor) commented Dec 1, 2021

🐛 Bug

We expect that when the cpu_checkpointing flag is set, GPU memory usage stays constant during the forward pass, since each layer's activations should be offloaded to the CPU (see https://www.deepspeed.ai/docs/config-json/#activation-checkpointing), but it does not.

I suspect this is due to a typo in the DeepSpeed plugin's config handling: we write cpu_checkpointing into the config but then try to read checkpoint_in_cpu back out of it.

To Reproduce

Run the script below once with --cpu_checkpointing and once with --checkpoint_in_cpu. Observe that cpu_checkpointing does not change GPU memory usage, but checkpoint_in_cpu does.

import argparse

import deepspeed
import torch
import torch.nn.functional as F
from deepspeed.ops.adam import DeepSpeedCPUAdam
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.utilities.seed import seed_everything
from torch.utils.data import DataLoader, Dataset

deepspeed_checkpoint = deepspeed.checkpointing.checkpoint
seed_everything(42)


class MemoryMonitor:
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.gpu_memory = 0
        self.max_gpu_memory = 0
        
    def update(self):
        self.gpu_memory = torch.cuda.memory_allocated()
        self.max_gpu_memory = max(self.gpu_memory, self.max_gpu_memory)
        
    def print_memory(self, msg):
        print(f"{msg}",
        f" GPU: {self.gpu_memory * 1e-9:0.1f}GB / {self.max_gpu_memory * 1e-9:0.1f}GB")


class RandomDataset(Dataset):
    def __init__(self, n_samples, dim_1, dim_2):
        self.n_samples = n_samples
        self.data = torch.randn(n_samples, dim_1, dim_2)
        
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return self.n_samples
        
    
class LinearWithGPUStats(torch.nn.Linear):
    def __init__(self, *args, name=None, mem_mon=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name
        self.mem_mon = mem_mon
        
    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        self.mem_mon.update()
        self.mem_mon.print_memory(msg=self.name)
        return out


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.mem_mon = MemoryMonitor()
        self.batch_size = 256
        self.dim_1 = 256
        self.n_input_dim = 4
        self.n_hidden_dim = 2000
        self.n_layers = 10
        self.output_dim = 2

        self.in_layer = LinearWithGPUStats(self.n_input_dim, self.n_hidden_dim, name="linear_in", mem_mon=self.mem_mon)
        self.hidden_layers = torch.nn.ModuleList(
            [LinearWithGPUStats(self.n_hidden_dim, self.n_hidden_dim, name=f"linear_{i}", mem_mon=self.mem_mon) for i in range(self.n_layers)]
        )
        self.out_layer = LinearWithGPUStats(self.n_hidden_dim, self.output_dim, name="linear_out", mem_mon=self.mem_mon)
        
    def train_dataloader(self):
        return DataLoader(RandomDataset(self.batch_size, self.dim_1, self.n_input_dim), batch_size=self.batch_size)
        
    def forward(self, x):
        x = self.in_layer(x)
        for layer in self.hidden_layers:
            x = deepspeed_checkpoint(layer, x)
        return self.out_layer(x)    
    
    def training_step(self, batch, batch_idx):
        print("\n\ntraining step", batch_idx)
        self.mem_mon.reset()
        loss = self(batch).sum()
        return {"loss": loss}
        
    def configure_optimizers(self):
        return DeepSpeedCPUAdam(self.parameters(), lr=0.001)
        
        
def run(args):
    model = BoringModel()
    
    dsp = DeepSpeedPlugin(
        stage=3,
        offload_parameters=True,
        offload_optimizer=True,
        cpu_checkpointing=args.cpu_checkpointing,
        partition_activations=True,
    )
    if args.checkpoint_in_cpu:
        dsp.config["activation_checkpointing"]["checkpoint_in_cpu"] = args.checkpoint_in_cpu
    
    trainer = Trainer(
        max_epochs=1,
        gpus=1,
        precision=16,
        strategy=dsp,
    )
    
    trainer.fit(model)
    model.mem_mon.update()
    model.mem_mon.print_memory("training_step")
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_in_cpu", action='store_true', default=False)
    parser.add_argument("--cpu_checkpointing", action='store_true', default=False)
    args = parser.parse_args()
    run(args)

Expected behavior

cpu_checkpointing should enable CPU offloading of the activation checkpoints and therefore roughly constant GPU memory usage between layers, but GPU memory usage instead increases layer by layer:

!python minimal.py --cpu_checkpointing
….
linear_in  GPU: 0.3GB / 0.3GB
linear_0  GPU: 0.8GB / 0.8GB
linear_1  GPU: 1.1GB / 1.1GB
linear_2  GPU: 1.3GB / 1.3GB
linear_3  GPU: 1.6GB / 1.6GB
linear_4  GPU: 1.9GB / 1.9GB
linear_5  GPU: 2.1GB / 2.1GB
linear_6  GPU: 2.4GB / 2.4GB
linear_7  GPU: 2.6GB / 2.6GB
linear_8  GPU: 2.9GB / 2.9GB
linear_9  GPU: 3.2GB / 3.2GB
linear_out  GPU: 2.9GB / 3.2GB

If we instead run with the --checkpoint_in_cpu workaround (which is what the suggested fix would enable via cpu_checkpointing), GPU memory usage is constant between layers.

!python minimal.py --checkpoint_in_cpu
….
linear_in  GPU: 0.3GB / 0.3GB
linear_0  GPU: 0.5GB / 0.5GB
linear_1  GPU: 0.5GB / 0.5GB
linear_2  GPU: 0.5GB / 0.5GB
linear_3  GPU: 0.5GB / 0.5GB
linear_4  GPU: 0.5GB / 0.5GB
linear_5  GPU: 0.5GB / 0.5GB
linear_6  GPU: 0.5GB / 0.5GB
linear_7  GPU: 0.5GB / 0.5GB
linear_8  GPU: 0.5GB / 0.5GB
linear_9  GPU: 0.5GB / 0.5GB
linear_out  GPU: 0.3GB / 0.5GB

Environment

  • CUDA:
    • GPU:
      • A100-SXM-80GB
    • available: True
    • version: 11.1
  • Packages:
    • numpy: 1.20.3
    • pyTorch_debug: False
    • pyTorch_version: 1.10.0+cu111
    • pytorch-lightning: 1.5.4
    • tqdm: 4.62.3
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.10
  • Any other relevant information:
    • deepspeed: 0.5.7

Suggested fix

I think the minimal solution here is to change
checkpoint_in_cpu=checkpoint_config.get("checkpoint_in_cpu"),
to
checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
on line 530 of pytorch_lightning/plugins/training_type/deepspeed.py.

As far as I can tell, nothing in the Lightning codebase sets checkpoint_in_cpu in the config, so this looks like a typo. DeepSpeed is confusing here because the flag you set in the DeepSpeed config is cpu_checkpointing, while the corresponding argument to deepspeed.checkpointing.configure is called checkpoint_in_cpu.
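
For context, the mismatch looks roughly like this. The configure call below is paraphrased from the plugin code around line 530 together with the linked DeepSpeed docs, so treat the exact argument list as an illustrative sketch rather than the definitive source:

# Key that Lightning writes into the config (and that DeepSpeed's own JSON
# config documents for activation checkpointing):
#     config["activation_checkpointing"]["cpu_checkpointing"] = True
# Keyword that deepspeed.checkpointing.configure expects when the flag is read back:
checkpoint_config = config["activation_checkpointing"]
deepspeed.checkpointing.configure(
    mpu_=None,
    partition_activations=checkpoint_config.get("partition_activations"),
    contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
    checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),  # currently .get("checkpoint_in_cpu"), which is never set
    profile=checkpoint_config.get("profile"),
)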

I am happy and keen to raise a PR with this fix and add some tests*, but wanted to run the approach past someone before opening it, as I am not familiar with this codebase.

* I think the test would look something like: train two BoringModels with multiple layers, activation checkpointing, and an on_before_backward hook that records GPU memory usage. The model trained with cpu_checkpointing should have significantly lower peak GPU memory usage. A rough sketch follows below.
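
Roughly like this, building on the reproduction script above; the helper names and the 0.5 threshold are illustrative, not the test that was actually added:

def _peak_gpu_memory(cpu_checkpointing):
    # Train the BoringModel from the script above for a single batch and
    # record peak GPU memory just before the backward pass.
    recorded = {}

    class ProbedModel(BoringModel):
        def on_before_backward(self, loss):
            recorded["peak"] = torch.cuda.max_memory_allocated()

    torch.cuda.reset_peak_memory_stats()
    trainer = Trainer(
        fast_dev_run=True,
        gpus=1,
        precision=16,
        strategy=DeepSpeedPlugin(
            stage=3,
            offload_parameters=True,
            offload_optimizer=True,
            partition_activations=True,
            cpu_checkpointing=cpu_checkpointing,
        ),
    )
    trainer.fit(ProbedModel())
    return recorded["peak"]


def test_cpu_checkpointing_reduces_peak_gpu_memory():
    baseline = _peak_gpu_memory(cpu_checkpointing=False)
    offloaded = _peak_gpu_memory(cpu_checkpointing=True)
    # The 0.5 factor is an illustrative threshold; the point is a clear reduction.
    assert offloaded < 0.5 * baseline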

cc @SeanNaren @awaelchli

awaelchli (Contributor) commented Dec 2, 2021

Just waiting to see what @SeanNaren says, but to me it looks like your observations are right!

I am happy and keen to raise a PR with this fix and add some tests*, but wanted to run the approach past someone before opening it, as I am not familiar with this codebase.

That would be awesome! Feel free to bring this in, high-five!

As for the tests, imo a simple test that the deepspeed.initialize call receives the config with the correct setting from us should be sufficient, as the performance and memory usage will be a direct consequence of the third-party library functioning correctly (which is tested on their side). A sketch of what such a test could look like follows below.
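
One way to express that kind of check is sketched below: rather than patching deepspeed.initialize, it wraps deepspeed.checkpointing.configure (which is where the flag was being dropped) and inspects the keyword it receives. It reuses the BoringModel and plugin settings from the reproduction script, assumes the plugin passes checkpoint_in_cpu as a keyword argument, and uses an illustrative test name:

from unittest import mock

def test_cpu_checkpointing_flag_is_forwarded(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        gpus=1,
        precision=16,
        strategy=DeepSpeedPlugin(
            stage=3,
            offload_parameters=True,
            offload_optimizer=True,
            partition_activations=True,
            cpu_checkpointing=True,
        ),
    )
    # Wrap the real call so training still runs, then inspect what it received.
    with mock.patch(
        "deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure
    ) as configure:
        trainer.fit(BoringModel())
    # With the bug present this would be None; with the fix it should be True.
    assert configure.call_args.kwargs["checkpoint_in_cpu"] is True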

jona-0 (Contributor, Author) commented Dec 6, 2021

Just wondering whether I should wait for another comment on this issue or open the PR now (currently in draft state).
I will assume I should open it tomorrow, but wanted to check in case there is something else I should be doing first.

SeanNaren (Contributor) commented:

Thanks a lot @jona-0, this looks great, and apologies for the delay! Please open the PR for reviews :)

A strange case for sure (it does seem that the variable names differ on the DeepSpeed side: https://www.deepspeed.ai/docs/config-json/#activation-checkpointing). The PR you've opened will fix the issue :)
