Skip to content

Commit

Permalink
【Hackathon 5th No.83】PaddleMIX ppdiffusers models模块功能升级同步HF (#322)
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Dec 14, 2023
1 parent 3ab12cf commit c4eb91e
Show file tree
Hide file tree
Showing 46 changed files with 8,799 additions and 2,365 deletions.
2 changes: 2 additions & 0 deletions ppdiffusers/ppdiffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
LVDMAutoencoderKL,
LVDMUNet3DModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
PriorTransformer,
T2IAdapter,
Expand All @@ -72,6 +73,7 @@
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
VQModel,
)
from .optimization import (
Expand Down
597 changes: 423 additions & 174 deletions ppdiffusers/ppdiffusers/loaders.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ppdiffusers/ppdiffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel

try:
Expand Down
94 changes: 87 additions & 7 deletions ppdiffusers/ppdiffusers/models/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,94 @@
# limitations under the License.

import paddle.nn as nn
import paddle.nn.functional as F

from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleLinear

def get_activation(act_fn):
if act_fn in ["swish", "silu"]:
return nn.Silu()
elif act_fn == "mish":
return nn.Mish()
elif act_fn == "gelu":
return nn.GELU()
ACTIVATION_FUNCTIONS = {
"swish": nn.Silu(),
"silu": nn.Silu(),
"mish": nn.Mish(),
"gelu": nn.GELU(),
"relu": nn.ReLU(),
}


def get_activation(act_fn: str) -> nn.Layer:
"""Helper function to get activation function from string.
Args:
act_fn (str): Name of activation function.
Returns:
nn.Layer: Activation function.
"""

act_fn = act_fn.lower()
if act_fn in ACTIVATION_FUNCTIONS:
return ACTIVATION_FUNCTIONS[act_fn]
else:
raise ValueError(f"Unsupported activation function: {act_fn}")


class GELU(nn.Layer):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
"""

def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
super().__init__()
self.proj = LoRACompatibleLinear(dim_in, dim_out)
self.approximate = approximate
self.approximate_bool = approximate == "tanh"

def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = F.gelu(hidden_states, approximate=self.approximate_bool)
return hidden_states


class GEGLU(nn.Layer):
r"""
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

self.proj = linear_cls(dim_in, dim_out * 2)

def forward(self, hidden_states, scale: float = 1.0):
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, axis=-1)
return hidden_states * F.gelu(gate)


class ApproximateGELU(nn.Layer):
r"""
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
[paper](https://arxiv.org/abs/1606.08415).
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)

def forward(self, x):
x = self.proj(x)
return x * F.sigmoid(1.702 * x)
Loading

0 comments on commit c4eb91e

Please sign in to comment.