🐛 Bug
We expect GPU memory usage to stay constant during the forward pass when the cpu_checkpointing flag is set, since each layer's activations should be offloaded to the CPU (see https://www.deepspeed.ai/docs/config-json/#activation-checkpointing), but it does not.
I suspect this is due to a typo in DeepSpeedConfig: we set cpu_checkpointing but try to read checkpoint_in_cpu.
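For reference, the activation checkpointing section of the DeepSpeed config documented at that link has roughly the shape below (key names are taken from the DeepSpeed docs; the values here are only illustrative):

# Shape of the "activation_checkpointing" config section per the DeepSpeed docs;
# values are illustrative. Note the key is "cpu_checkpointing", not "checkpoint_in_cpu".
deepspeed_config = {
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,
        "contiguous_memory_optimization": False,
        "synchronize_checkpoint_boundary": False,
    }
}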
To Reproduce
Run once with --cpu_checkpointing and once with --checkpoint_in_cpu. Observe that cpu_checkpointing does not change GPU memory usage, but checkpoint_in_cpu does (reproduction script below):
# Reproduction script. Run once with --cpu_checkpointing and once with
# --checkpoint_in_cpu to compare the per-layer GPU memory usage printed below.
import argparse

import deepspeed
import torch
import torch.nn.functional as F
from deepspeed import checkpointing as deepspeed_checkpointing
from deepspeed.ops.adam import DeepSpeedCPUAdam
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.utilities.seed import seed_everything
from torch.utils.data import DataLoader, Dataset

deepspeed_checkpoint = deepspeed_checkpointing.checkpoint

seed_everything(42)


class MemoryMonitor:
    """Tracks the current and peak allocated GPU memory."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.gpu_memory = 0
        self.max_gpu_memory = 0

    def update(self):
        self.gpu_memory = torch.cuda.memory_allocated()
        self.max_gpu_memory = max(self.gpu_memory, self.max_gpu_memory)

    def print_memory(self, msg):
        print(f"{msg}",
              f" GPU: {self.gpu_memory * 1e-9:0.1f}GB / {self.max_gpu_memory * 1e-9:0.1f}GB")


class RandomDataset(Dataset):
    def __init__(self, n_samples, dim_1, dim_2):
        self.n_samples = n_samples
        self.data = torch.randn(n_samples, dim_1, dim_2)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.n_samples


class LinearWithGPUStats(torch.nn.Linear):
    """Linear layer that prints GPU memory usage after every forward call."""

    def __init__(self, *args, name=None, mem_mon=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name
        self.mem_mon = mem_mon

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        self.mem_mon.update()
        self.mem_mon.print_memory(msg=self.name)
        return out


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.mem_mon = MemoryMonitor()
        self.batch_size = 256
        self.dim_1 = 256
        self.n_input_dim = 4
        self.n_hidden_dim = 2000
        self.n_layers = 10
        self.output_dim = 2
        self.in_layer = LinearWithGPUStats(self.n_input_dim, self.n_hidden_dim, name="linear_in", mem_mon=self.mem_mon)
        self.hidden_layers = torch.nn.ModuleList(
            [LinearWithGPUStats(self.n_hidden_dim, self.n_hidden_dim, name=f"linear_{i}", mem_mon=self.mem_mon) for i in range(self.n_layers)]
        )
        self.out_layer = LinearWithGPUStats(self.n_hidden_dim, self.output_dim, name="linear_out", mem_mon=self.mem_mon)

    def train_dataloader(self):
        return DataLoader(RandomDataset(self.batch_size, self.dim_1, self.n_input_dim), batch_size=self.batch_size)

    def forward(self, x):
        x = self.in_layer(x)
        for layer in self.hidden_layers:
            # Checkpoint every hidden layer with DeepSpeed's activation checkpointing.
            x = deepspeed_checkpoint(layer, x)
        return self.out_layer(x)

    def training_step(self, batch, batch_idx):
        print("\n\ntraining step", batch_idx)
        self.mem_mon.reset()
        loss = self(batch).sum()
        return {"loss": loss}

    def configure_optimizers(self):
        return DeepSpeedCPUAdam(self.parameters(), lr=0.001)


def run(args):
    model = BoringModel()
    dsp = DeepSpeedPlugin(
        stage=3,
        offload_parameters=True,
        offload_optimizer=True,
        cpu_checkpointing=args.cpu_checkpointing,
        partition_activations=True,
    )
    if args.checkpoint_in_cpu:
        # Work around the suspected typo by setting the key the plugin actually reads.
        dsp.config["activation_checkpointing"]["checkpoint_in_cpu"] = args.checkpoint_in_cpu
    trainer = Trainer(
        max_epochs=1,
        gpus=1,
        precision=16,
        strategy=dsp,
    )
    trainer.fit(model)
    model.mem_mon.update()
    model.mem_mon.print_memory("training_step")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_in_cpu", action="store_true", default=False)
    parser.add_argument("--cpu_checkpointing", action="store_true", default=False)
    args = parser.parse_args()
    run(args)
Expected behavior
cpu_checkpointing should enable CPU offloading of the checkpointed activations and therefore keep GPU memory usage constant between layers, but we can see that GPU memory usage increases as the forward pass moves through the layers. If we run with the suggested fix below, GPU memory usage stays constant between layers.
Suggested fix
I think the minimal solution here is to change:
checkpoint_in_cpu=checkpoint_config.get("checkpoint_in_cpu"),
to
checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
on line 530 of pytorch_lightning/plugins/training_type/deepspeed.py.
As far as I can tell, nothing in the Lightning codebase sets checkpoint_in_cpu in the config, so this looks like a typo. DeepSpeed is confusing here because the flag you set in the DeepSpeed config is cpu_checkpointing, while the corresponding argument to deepspeed.checkpointing.configure is called checkpoint_in_cpu.
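To make the name clash concrete, here is a rough, paraphrased sketch (not the exact Lightning source) of how that part of deepspeed.py forwards the config to DeepSpeed once the fix is applied:

# Paraphrased sketch, not the exact Lightning source: the Lightning-side config
# uses the DeepSpeed JSON key "cpu_checkpointing", while the DeepSpeed API
# argument is named "checkpoint_in_cpu".
checkpoint_config = config.get("activation_checkpointing", {})
deepspeed.checkpointing.configure(
    mpu_=None,
    partition_activations=checkpoint_config.get("partition_activations"),
    checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),  # the suggested fix
)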
I am happy and keen to raise a PR with this fix and add some tests*, but wanted to run this approach past someone before opening the PR, as I am not familiar with this codebase.
* I think the test would look something like: train two BoringModels with multiple layers, activation checkpointing, and an on_before_backward hook that records GPU memory usage. The model trained with cpu_checkpointing should show significantly lower peak GPU memory usage.
cc @SeanNaren @awaelchli
Just waiting to see what @SeanNaren says, but to me it looks like your observations are right!
I am happy and keen to raise a PR with this fix and add some tests*, but wanted to run this approach past someone before opening the PR, as I am not familiar with this codebase.
That would be awesome! Feel free to bring this in, high-five!
As for the tests, IMO a simple test that the deepspeed.initialize call receives the config with the correct setting from us should be sufficient, since performance and memory usage are a direct consequence of the third-party library functioning correctly (which is tested in their library).
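A rough sketch of what that kind of config-level check could look like (the test name and exact assertions are assumptions, not the real test suite):

# Hypothetical test sketch: check that the user-facing flag ends up under the
# DeepSpeed config key that (after the fix) is forwarded as checkpoint_in_cpu.
from pytorch_lightning.plugins import DeepSpeedPlugin


def test_deepspeed_cpu_checkpointing_in_config():
    plugin = DeepSpeedPlugin(stage=3, partition_activations=True, cpu_checkpointing=True)
    checkpoint_config = plugin.config["activation_checkpointing"]
    assert checkpoint_config["cpu_checkpointing"]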
Just wondering if I should wait for another comment on this issue or open the PR now (currently in draft state).
I will assume I should open it tomorrow, but wanted to check in case there was something else I should be doing first.