Add guards to SD imports #9158

Merged · 2 commits · May 14, 2024

@@ -43,7 +43,14 @@
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.contrib.group_norm import GroupNorm
+
+try:
+    from apex.contrib.group_norm import GroupNorm
+
+    OPT_GROUP_NORM = True
+except Exception:
+    print('Fused optimized group norm has not been installed.')
+    OPT_GROUP_NORM = False
 
 
 def conv_nd(dims, *args, **kwargs):

CodeQL check notice (raised on the added print statement here and in each file below): Use of a print statement at module level — the print statement may execute during import.
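
For context (not part of the diff): a minimal sketch of how downstream code could consume the OPT_GROUP_NORM flag set by this guard, falling back to PyTorch's built-in GroupNorm when the fused apex kernel is absent. The build_group_norm helper below is hypothetical, not code from the PR.

import torch.nn as nn

try:
    from apex.contrib.group_norm import GroupNorm  # fused, apex-optimized kernel

    OPT_GROUP_NORM = True
except Exception:
    OPT_GROUP_NORM = False


def build_group_norm(num_groups: int, num_channels: int, eps: float = 1e-5) -> nn.Module:
    """Return the fused apex GroupNorm when installed, otherwise torch.nn.GroupNorm."""
    if OPT_GROUP_NORM:
        return GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps)
    return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps)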

nemo/collections/multimodal/modules/stable_diffusion/attention.py (17 changes: 14 additions & 3 deletions)

@@ -17,17 +17,20 @@
 import torch
 import torch.nn.functional as F
-from apex.contrib.group_norm import GroupNorm
 from einops import rearrange, repeat
 from torch import einsum, nn
 from torch._dynamo import disable
 
 if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1":
     from nemo.gn_native import GroupNormNormlization as GroupNorm
 else:
-    from apex.contrib.group_norm import GroupNorm
+    try:
+        from apex.contrib.group_norm import GroupNorm
 
-from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP
+        OPT_GROUP_NORM = True
+    except Exception:
+        print('Fused optimized group norm has not been installed.')
+        OPT_GROUP_NORM = False
 
 
 from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint
 from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
@@ -37,6 +40,14 @@
 from nemo.core import adapter_mixins
 from nemo.utils import logging
 
+try:
+    from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP
+
+    HAVE_TE = True
+
+except (ImportError, ModuleNotFoundError):
+    HAVE_TE = False
+
 
 def check_cuda():
     if not torch.cuda.is_available():
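
For context (not part of the diff): a sketch of how the HAVE_TE flag could gate Transformer Engine usage so the attention module still builds when transformer_engine is not installed. The build_qkv_projection helper and its plain-PyTorch fallback are assumptions, not the module's actual code.

import torch.nn as nn

try:
    from transformer_engine.pytorch.module import LayerNormLinear

    HAVE_TE = True
except (ImportError, ModuleNotFoundError):
    HAVE_TE = False


def build_qkv_projection(hidden_size: int, proj_size: int, use_te: bool = True) -> nn.Module:
    """Fused LayerNorm + Linear from TE when available, else plain PyTorch layers."""
    if use_te and HAVE_TE:
        return LayerNormLinear(hidden_size, proj_size, bias=False)
    return nn.Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, proj_size, bias=False))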

@@ -17,12 +17,19 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.contrib.group_norm import GroupNorm
 from einops import rearrange
 
 from nemo.collections.multimodal.modules.stable_diffusion.attention import LinearAttention
 from nemo.collections.multimodal.parts.stable_diffusion.utils import instantiate_from_config
 
+try:
+    from apex.contrib.group_norm import GroupNorm
+
+    OPT_GROUP_NORM = True
+except Exception:
+    print('Fused optimized group norm has not been installed.')
+    OPT_GROUP_NORM = False
+
 
 def get_timestep_embedding(timesteps, embedding_dim):
     """
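
For context (not part of the diff): with these guards in place, importing the changed modules on a machine without apex or transformer_engine should no longer hard-fail at import time. A hypothetical smoke test, using the one module path this PR names explicitly:

import importlib

# attention.py is the only file named in this capture; the other changed files follow the same pattern.
mod = "nemo.collections.multimodal.modules.stable_diffusion.attention"
try:
    importlib.import_module(mod)
    print(f"imported {mod}")
except Exception as err:
    print(f"failed to import {mod}: {err}")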

@@ -26,10 +26,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-# FP8 related import
-import transformer_engine
-from apex.contrib.group_norm import GroupNorm
-
 from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer
 from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import (
     avg_pool_nd,
@@ -45,6 +41,23 @@
 )
 from nemo.utils import logging
 
+try:
+    # FP8 related import
+    import transformer_engine
+
+    HAVE_TE = True
+
+except (ImportError, ModuleNotFoundError):
+    HAVE_TE = False
+
+try:
+    from apex.contrib.group_norm import GroupNorm
+
+    OPT_GROUP_NORM = True
+except Exception:
+    print('Fused optimized group norm has not been installed.')
+    OPT_GROUP_NORM = False
+
 
 def convert_module_to_dtype(module, dtype, enable_norm_layers=False):
     # Convert module parameters to dtype
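
For context (not part of the diff): a sketch of gating FP8 execution on the HAVE_TE flag so the forward pass simply runs in the default precision when transformer_engine is missing. fp8_autocast and DelayedScaling are real Transformer Engine APIs, but the helper below is an assumption, not the model's actual FP8 wiring.

import contextlib

import torch

try:
    from transformer_engine.common.recipe import DelayedScaling
    from transformer_engine.pytorch import fp8_autocast

    HAVE_TE = True
except (ImportError, ModuleNotFoundError):
    HAVE_TE = False


def forward_with_optional_fp8(model: torch.nn.Module, x: torch.Tensor, use_fp8: bool = False):
    """Enter TE's fp8_autocast (affects TE modules inside `model`) only when TE is installed."""
    if use_fp8 and HAVE_TE:
        ctx = fp8_autocast(enabled=True, fp8_recipe=DelayedScaling())
    else:
        ctx = contextlib.nullcontext()
    with ctx:
        return model(x)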

@@ -29,11 +29,18 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.contrib.group_norm import GroupNorm
 from einops import repeat
 from torch._dynamo import disable
 from torch.cuda.amp import custom_bwd, custom_fwd
 
+try:
+    from apex.contrib.group_norm import GroupNorm
+
+    OPT_GROUP_NORM = True
+except Exception:
+    print('Fused optimized group norm has not been installed.')
+    OPT_GROUP_NORM = False
+
 
 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
     if schedule == "linear":
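
For reference, the CodeQL notice above flags these module-level print calls. A hedged alternative (not what the PR does) is to route the message through Python's logging framework so the application can filter it, rather than writing unconditionally to stdout whenever the module is imported:

import logging

logger = logging.getLogger(__name__)

try:
    from apex.contrib.group_norm import GroupNorm

    OPT_GROUP_NORM = True
except Exception:
    # Still runs at import time, but goes through the logging framework
    # instead of printing straight to stdout.
    logger.warning('Fused optimized group norm has not been installed.')
    OPT_GROUP_NORM = False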