Skip to content
This repository has been archived by the owner on May 14, 2024. It is now read-only.

CUDA out of memory error on the kohya native trainer XL #333

Open
domochevisk opened this issue Jan 16, 2024 · 0 comments
Open

CUDA out of memory error on the kohya native trainer XL #333

domochevisk opened this issue Jan 16, 2024 · 0 comments

Comments

@domochevisk
Copy link

I keep getting a CUDA out-of-memory error in the kohya native XL trainer, as if the GPU can't allocate enough memory for the task. I lowered the batch size to 1 and adjusted other settings, but the problem persists.

[Dataset 0]
loading image sizes.
100% 4/4 [00:00<00:00, 77314.36it/s]
make buckets
number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)
bucket 0: resolution (768, 1024), count: 1
bucket 1: resolution (832, 1024), count: 3
mean ar error (without repeats): 0.0
prepare accelerator
loading model for process 0/1
load Diffusers pretrained models: Linaqruf/animagine-xl, variant=fp16
The config attributes {'force_upcast': True} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.
The config attributes {'attention_type': 'default'} were passed to UNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.
U-Net converted to original U-Net
Disable Diffusers' xformers
Enable xformers for U-Net
number of models: 1
number of trainable parameters: 2567463684
prepare optimizer, data loader etc.
use Adafactor optimizer | {'scale_parameter': False, 'relative_step': False, 'warmup_init': False}
running training / 学習開始
num examples / サンプル数: 4
num batches per epoch / 1epochのバッチ数: 4
num epochs / epoch数: 250
batch size per device / バッチサイズ: 1
gradient accumulation steps / 勾配を合計するステップ数 = 1
total optimization steps / 学習ステップ数: 1000
steps: 0% 0/1000 [00:00<?, ?it/s]
epoch 1/250
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /content/kohya-trainer/sdxl_train.py:649 in │
│ │
│ 646 │ args = parser.parse_args() │
│ 647 │ args = train_util.read_config_from_file(args, parser) │
│ 648 │ │
│ ❱ 649 │ train(args) │
│ 650 │
│ │
│ /content/kohya-trainer/sdxl_train.py:455 in train │
│ │
│ 452 │ │ │ │ │
│ 453 │ │ │ │ # Predict the noise residual │
│ 454 │ │ │ │ with accelerator.autocast(): │
│ ❱ 455 │ │ │ │ │ noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_e │
│ 456 │ │ │ │ │
│ 457 │ │ │ │ target = noise │
│ 458 │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518 in _wrapped_call_impl │
│ │
│ 1515 │ │ if self._compiled_call_impl is not None: │
│ 1516 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
│ 1517 │ │ else: │
│ ❱ 1518 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1519 │ │
│ 1520 │ def _call_impl(self, *args, **kwargs): │
│ 1521 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527 in _call_impl │
│ │
│ 1524 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1525 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1526 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1527 │ │ │ return forward_call(*args, **kwargs) │
│ 1528 │ │ │
│ 1529 │ │ try: │
│ 1530 │ │ │ result = None │
│ │
│ /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:521 in forward │
│ │
│ 518 │ model_forward = ConvertOutputsToFp32(model_forward) │
│ 519 │ │
│ 520 │ def forward(*args, **kwargs): │
│ ❱ 521 │ │ return model_forward(*args, **kwargs) │
│ 522 │ │
│ 523 │ # To act like a decorator so that it can be popped when doing extract_model_from_pa │ │ 524 │ forward.__wrapped__ = model_forward │ │ │ │ /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:509 in __call__ │ │ │ │ 506 │ │ update_wrapper(self, model_forward) │ │ 507 │ │ │ 508 │ def __call__(self, *args, **kwargs): │ │ ❱ 509 │ │ return convert_to_fp32(self.model_forward(*args, **kwargs)) │ │ 510 │ │ │ 511 │ def __getstate__(self): │ │ 512 │ │ raise pickle.PicklingError( │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:16 in decorate_autocast │ │ │ │ 13 │ @functools.wraps(func) │ │ 14 │ def decorate_autocast(*args, **kwargs): │ │ 15 │ │ with autocast_instance: │ │ ❱ 16 │ │ │ return func(*args, **kwargs) │ │ 17 │ │ │ 18 │ decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in │ │ 19 │ return decorate_autocast │ │ │ │ /content/kohya-trainer/library/sdxl_original_unet.py:1081 in forward │ │ │ │ 1078 │ │ │ │ 1079 │ │ for module in self.output_blocks: │ │ 1080 │ │ │ h = torch.cat([h, hs.pop()], dim=1) │ │ ❱ 1081 │ │ │ h = call_module(module, h, emb, context) │ │ 1082 │ │ │ │ 1083 │ │ h = h.type(x.dtype) │ │ 1084 │ │ h = call_module(self.out, h, emb, context) │ │ │ │ /content/kohya-trainer/library/sdxl_original_unet.py:1066 in call_module │ │ │ │ 1063 │ │ │ │ if isinstance(layer, ResnetBlock2D): │ │ 1064 │ │ │ │ │ x = layer(x, emb) │ │ 1065 │ │ │ │ elif isinstance(layer, Transformer2DModel): │ │ ❱ 1066 │ │ │ │ │ x = layer(x, context) │ │ 1067 │ │ │ │ else: │ │ 1068 │ │ │ │ │ x = layer(x) │ │ 1069 │ │ │ return x │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518 in _wrapped_call_impl │ │ │ │ 1515 │ │ if self._compiled_call_impl is not None: │ │ 1516 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │ │ 1517 │ │ else: │ │ ❱ 1518 │ │ │ return self._call_impl(*args, **kwargs) │ │ 1519 │ │ │ 1520 │ def _call_impl(self, *args, **kwargs): │ │ 
1521 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527 in _call_impl │ │ │ │ 1524 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │ │ 1525 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │ │ 1526 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │ │ ❱ 1527 │ │ │ return forward_call(*args, **kwargs) │ │ 1528 │ │ │ │ 1529 │ │ try: │ │ 1530 │ │ │ result = None │ │ │ │ /content/kohya-trainer/library/sdxl_original_unet.py:723 in forward │ │ │ │ 720 │ │ │ │ 721 │ │ # 2. Blocks │ │ 722 │ │ for block in self.transformer_blocks: │ │ ❱ 723 │ │ │ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep │ │ 724 │ │ │ │ 725 │ │ # 3. Output │ │ 726 │ │ if not self.use_linear_projection: │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518 in _wrapped_call_impl │ │ │ │ 1515 │ │ if self._compiled_call_impl is not None: │ │ 1516 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │ │ 1517 │ │ else: │ │ ❱ 1518 │ │ │ return self._call_impl(*args, **kwargs) │ │ 1519 │ │ │ 1520 │ def _call_impl(self, *args, **kwargs): │ │ 1521 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527 in _call_impl │ │ │ │ 1524 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │ │ 1525 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │ │ 1526 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │ │ ❱ 1527 │ │ │ return forward_call(*args, **kwargs) │ │ 1528 │ │ │ │ 1529 │ │ try: │ │ 1530 │ │ │ result = None │ │ │ │ /content/kohya-trainer/library/sdxl_original_unet.py:644 in forward │ │ │ │ 641 │ │ │ │ │ │ 642 │ │ │ │ return custom_forward │ │ 643 │ │ │ │ │ ❱ 644 │ │ │ output = 
torch.utils.checkpoint.checkpoint(create_custom_forward(self.forwar │ │ 645 │ │ else: │ │ 646 │ │ │ output = self.forward_body(hidden_states, context, timestep) │ │ 647 │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/_compile.py:24 in inner │ │ │ │ 21 │ │ def inner(*args, **kwargs): │ │ 22 │ │ │ import torch._dynamo │ │ 23 │ │ │ │ │ ❱ 24 │ │ │ return torch._dynamo.disable(fn, recursive)(*args, **kwargs) │ │ 25 │ │ │ │ 26 │ │ return inner │ │ 27 │ else: │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:328 in _fn │ │ │ │ 325 │ │ │ dynamic_ctx = enable_dynamic(self.dynamic, self.export) │ │ 326 │ │ │ dynamic_ctx.__enter__() │ │ 327 │ │ │ try: │ │ ❱ 328 │ │ │ │ return fn(*args, **kwargs) │ │ 329 │ │ │ finally: │ │ 330 │ │ │ │ set_eval_frame(prior) │ │ 331 │ │ │ │ dynamic_ctx.__exit__(None, None, None) │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py:17 in inner │ │ │ │ 14 │ │ │ 15 │ @functools.wraps(fn) │ │ 16 │ def inner(*args, **kwargs): │ │ ❱ 17 │ │ return fn(*args, **kwargs) │ │ 18 │ │ │ 19 │ return inner │ │ 20 │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:451 in checkpoint │ │ │ │ 448 │ │ │ │ "Passing context_fnordebug` is only supported when " │
│ 449 │ │ │ │ "use_reentrant=False." │
│ 450 │ │ │ ) │
│ ❱ 451 │ │ return CheckpointFunction.apply(function, preserve, *args) │
│ 452 │ else: │
│ 453 │ │ gen = _checkpoint_without_reentrant_generator( │
│ 454 │ │ │ function, preserve, context_fn, determinism_check, debug, *args, **kwargs │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:539 in apply │
│ │
│ 536 │ │ if not torch._C._are_functorch_transforms_active(): │
│ 537 │ │ │ # See NOTE: [functorch vjp and autograd interaction] │
│ 538 │ │ │ args = _functorch.utils.unwrap_dead_wrappers(args) │
│ ❱ 539 │ │ │ return super().apply(*args, **kwargs) # type: ignore[misc] │
│ 540 │ │ │
│ 541 │ │ if cls.setup_context == _SingleLevelFunction.setup_context: │
│ 542 │ │ │ raise RuntimeError( │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:230 in forward │
│ │
│ 227 │ │ ctx.save_for_backward(*tensor_inputs) │
│ 228 │ │ │
│ 229 │ │ with torch.no_grad(): │
│ ❱ 230 │ │ │ outputs = run_function(*args) │
│ 231 │ │ return outputs │
│ 232 │ │
│ 233 │ @staticmethod
│ │
│ /content/kohya-trainer/library/sdxl_original_unet.py:640 in custom_forward │
│ │
│ 637 │ │ │ │
│ 638 │ │ │ def create_custom_forward(func): │
│ 639 │ │ │ │ def custom_forward(*inputs): │
│ ❱ 640 │ │ │ │ │ return func(*inputs) │
│ 641 │ │ │ │ │
│ 642 │ │ │ │ return custom_forward │
│ 643 │
│ │
│ /content/kohya-trainer/library/sdxl_original_unet.py:630 in forward_body │
│ │
│ 627 │ │ hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states │
│ 628 │ │ │
│ 629 │ │ # 3. Feed-forward │
│ ❱ 630 │ │ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states │
│ 631 │ │ │
│ 632 │ │ return hidden_states │
│ 633 │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518 in _wrapped_call_impl │
│ │
│ 1515 │ │ if self._compiled_call_impl is not None: │
│ 1516 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
│ 1517 │ │ else: │
│ ❱ 1518 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1519 │ │
│ 1520 │ def _call_impl(self, *args, **kwargs): │
│ 1521 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527 in _call_impl │
│ │
│ 1524 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1525 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1526 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1527 │ │ │ return forward_call(*args, **kwargs) │
│ 1528 │ │ │
│ 1529 │ │ try: │
│ 1530 │ │ │ result = None │
│ │
│ /content/kohya-trainer/library/sdxl_original_unet.py:574 in forward │
│ │
│ 571 │ │
│ 572 │ def forward(self, hidden_states): │
│ 573 │ │ for module in self.net: │
│ ❱ 574 │ │ │ hidden_states = module(hidden_states) │
│ 575 │ │ return hidden_states │
│ 576 │
│ 577 │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518 in _wrapped_call_impl │
│ │
│ 1515 │ │ if self._compiled_call_impl is not None: │
│ 1516 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
│ 1517 │ │ else: │
│ ❱ 1518 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1519 │ │
│ 1520 │ def _call_impl(self, *args, **kwargs): │
│ 1521 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527 in _call_impl │
│ │
│ 1524 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1525 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1526 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1527 │ │ │ return forward_call(*args, **kwargs) │
│ 1528 │ │ │
│ 1529 │ │ try: │
│ 1530 │ │ │ result = None │
│ │
│ /content/kohya-trainer/library/sdxl_original_unet.py:552 in forward │
│ │
│ 549 │ │ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) │
│ 550 │ │
│ 551 │ def forward(self, hidden_states): │
│ ❱ 552 │ │ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) │
│ 553 │ │ return hidden_states * self.gelu(gate) │
│ 554 │
│ 555 │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518 in _wrapped_call_impl │
│ │
│ 1515 │ │ if self._compiled_call_impl is not None: │
│ 1516 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
│ 1517 │ │ else: │
│ ❱ 1518 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1519 │ │
│ 1520 │ def _call_impl(self, *args, **kwargs): │
│ 1521 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527 in _call_impl │
│ │
│ 1524 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1525 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1526 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1527 │ │ │ return forward_call(*args, **kwargs) │
│ 1528 │ │ │
│ 1529 │ │ try: │
│ 1530 │ │ │ result = None │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:114 in forward │
│ │
│ 111 │ │ │ init.uniform_(self.bias, -bound, bound) │
│ 112 │ │
│ 113 │ def forward(self, input: Tensor) -> Tensor: │
│ ❱ 114 │ │ return F.linear(input, self.weight, self.bias) │
│ 115 │ │
│ 116 │ def extra_repr(self) -> str: │
│ 117 │ │ return f'in_features={self.in_features}, out_features={self.out_features}, bias= │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 MiB. GPU 0 has a total capacty of
14.75 GiB of which 7.06 MiB is free. Process 297668 has 14.74 GiB memory in use. Of the allocated
memory 14.28 GiB is allocated by PyTorch, and 332.24 MiB is reserved by PyTorch but unallocated. If
reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See
documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
steps: 0% 0/1000 [00:06<?, ?it/s]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /usr/local/bin/accelerate:8 in │
│ │
│ 5 from accelerate.commands.accelerate_cli import main │
│ 6 if __name__ == '__main__': │
│ 7 │ sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) │
│ ❱ 8 │ sys.exit(main()) │
│ 9 │
│ │
│ /usr/local/lib/python3.10/dist-packages/accelerate/commands/accelerate_cli.py:45 in main │
│ │
│ 42 │ │ exit(1) │
│ 43 │ │
│ 44 │ # Run │
│ ❱ 45 │ args.func(args) │
│ 46 │
│ 47 │
│ 48 if __name__ == "__main__": │
│ │
│ /usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py:918 in launch_command │
│ │
│ 915 │ elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMA │
│ 916 │ │ sagemaker_launcher(defaults, args) │
│ 917 │ else: │
│ ❱ 918 │ │ simple_launcher(args) │
│ 919 │
│ 920 │
│ 921 def main(): │
│ │
│ /usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py:580 in simple_launcher │
│ │
│ 577 │ process.wait() │
│ 578 │ if process.returncode != 0: │
│ 579 │ │ if not args.quiet: │
│ ❱ 580 │ │ │ raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) │
│ 581 │ │ else: │
│ 582 │ │ │ sys.exit(1) │
│ 583 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
CalledProcessError: Command '['/usr/bin/python3', 'sdxl_train.py',
'--sample_prompts=/content/fine_tune/config/sample_prompt.toml',
'--config_file=/content/fine_tune/config/config_file.toml']' returned non-zero exit status 1.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant