Merged
59 changes: 48 additions & 11 deletions generative/networks/blocks/selfattention.py
@@ -11,12 +11,20 @@

from __future__ import annotations

import importlib.util
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

if importlib.util.find_spec("xformers") is not None:
Collaborator: Currently, if the code does not find xformers but use_flash_attention is set to True, the code errors out. I think we need to set use_flash_attention=False in the init if has_xformers=False, and ideally raise a warning, too.

Collaborator (Author): Added an error message for the case where the user wants to use flash attention but xformers is not installed.
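For illustration, a minimal sketch of the fallback the reviewer suggests (hypothetical code, not part of this PR, which instead raises a ValueError in __init__):

```python
import warnings

# Hypothetical fallback sketched from the review comment: degrade gracefully
# instead of raising when xformers is unavailable. Example values stand in for
# the __init__ arguments.
use_flash_attention = True
has_xformers = False

if use_flash_attention and not has_xformers:
    warnings.warn(
        "use_flash_attention=True but xformers is not installed; "
        "falling back to the manual attention implementation."
    )
    use_flash_attention = False
```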
    import xformers.ops as xops

    has_xformers = True
else:
    has_xformers = False


class SABlock(nn.Module):
"""
@@ -31,6 +39,7 @@ class SABlock(nn.Module):
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
with_cross_attention: Whether to use cross attention for conditioning.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
@@ -42,6 +51,7 @@ def __init__(
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -51,16 +61,21 @@
self.causal = causal
self.sequence_length = sequence_length
self.with_cross_attention = with_cross_attention
self.use_flash_attention = use_flash_attention

if not (0 <= dropout_rate <= 1):
    raise ValueError("dropout_rate should be between 0 and 1.")
self.dropout_rate = dropout_rate

if hidden_size % num_heads != 0:
    raise ValueError("hidden size should be divisible by num_heads.")

if causal and sequence_length is None:
    raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not has_xformers:
    raise ValueError("use_flash_attention is True but xformers is not installed.")

# key, query, value projections
self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
@@ -91,20 +106,42 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch
key = self.to_k(kv)
value = self.to_v(kv)

-query = query.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
-key = key.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
-value = value.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
+query = query.view(b, t, self.num_heads, c // self.num_heads)  # (b, t, nh, hs)
+key = key.view(b, kv_t, self.num_heads, c // self.num_heads)  # (b, kv_t, nh, hs)
+value = value.view(b, kv_t, self.num_heads, c // self.num_heads)  # (b, kv_t, nh, hs)

-# manual implementation of attention
-attention_scores = (query @ key.transpose(-2, -1)) * self.scale
-
-if self.causal:
-    attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
-
-attention_probs = F.softmax(attention_scores, dim=-1)
-attention_probs = self.drop_weights(attention_probs)
-y = attention_probs @ value  # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs)
-y = y.transpose(1, 2).contiguous().view(b, t, c)  # re-assemble all head outputs side by side
+if self.use_flash_attention:
+    query = query.contiguous()
+    key = key.contiguous()
+    value = value.contiguous()
+    y = xops.memory_efficient_attention(
+        query=query,
+        key=key,
+        value=value,
+        scale=self.scale,
+        p=self.dropout_rate,
+        attn_bias=xops.LowerTriangularMask() if self.causal else None,
+    )
+
+else:
+    query = query.transpose(1, 2)  # (b, nh, t, hs)
+    key = key.transpose(1, 2)  # (b, nh, kv_t, hs)
+    value = value.transpose(1, 2)  # (b, nh, kv_t, hs)
+
+    # manual implementation of attention
+    query = query * self.scale
+    attention_scores = query @ key.transpose(-2, -1)
+
+    if self.causal:
+        attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+
+    attention_probs = F.softmax(attention_scores, dim=-1)
+    attention_probs = self.drop_weights(attention_probs)
+    y = attention_probs @ value  # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs)
+
+    y = y.transpose(1, 2)  # (b, nh, t, hs) -> (b, t, nh, hs)
+
+y = y.contiguous().view(b, t, c)  # re-assemble all head outputs side by side

y = self.out_proj(y)
y = self.drop_output(y)
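For reference, a quick usage sketch of the new flag. This is illustrative only: the hidden_size and num_heads arguments are taken from the constructor shown above with arbitrary values, and running with use_flash_attention=True assumes xformers is installed and a CUDA device is available.

```python
import torch

from generative.networks.blocks.selfattention import SABlock

# Illustrative values only: hidden_size must be divisible by num_heads.
block = SABlock(hidden_size=64, num_heads=8, use_flash_attention=True).cuda()

x = torch.randn(2, 16, 64, device="cuda")  # (batch, sequence length, hidden_size)
y = block(x)  # same shape as x; routed through xops.memory_efficient_attention
```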
4 changes: 4 additions & 0 deletions generative/networks/blocks/transformerblock.py
@@ -32,6 +32,7 @@ class TransformerBlock(nn.Module):
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
with_cross_attention: Whether to use cross attention for conditioning.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
@@ -44,6 +45,7 @@ def __init__(
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
use_flash_attention: bool = False,
) -> None:
self.with_cross_attention = with_cross_attention
super().__init__()
@@ -62,6 +64,7 @@ def __init__(
qkv_bias=qkv_bias,
causal=causal,
sequence_length=sequence_length,
use_flash_attention=use_flash_attention,
)

self.norm2 = None
@@ -75,6 +78,7 @@ def __init__(
qkv_bias=qkv_bias,
with_cross_attention=with_cross_attention,
causal=False,
use_flash_attention=use_flash_attention,
)

self.norm3 = nn.LayerNorm(hidden_size)
3 changes: 3 additions & 0 deletions generative/networks/nets/transformer.py
@@ -50,6 +50,7 @@ class DecoderOnlyTransformer(nn.Module):
attn_layers_heads: Number of attention heads.
with_cross_attention: Whether to use cross attention for conditioning.
embedding_dropout_rate: Dropout rate for the embedding.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
@@ -61,6 +62,7 @@ def __init__(
attn_layers_heads: int,
with_cross_attention: bool = False,
embedding_dropout_rate: float = 0.0,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.num_tokens = num_tokens
@@ -85,6 +87,7 @@ def __init__(
causal=True,
sequence_length=max_seq_len,
with_cross_attention=with_cross_attention,
use_flash_attention=use_flash_attention,
)
for _ in range(attn_layers_depth)
]
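Similarly, a hedged sketch for DecoderOnlyTransformer with the new flag. Sizes are arbitrary; attn_layers_dim is an assumed parameter name not visible in this hunk, the other names appear in the hunks above, and the flash attention path again assumes xformers plus a CUDA device.

```python
import torch

from generative.networks.nets.transformer import DecoderOnlyTransformer

# Illustrative sizes only; causal attention is used internally, so inputs are
# token ids of length at most max_seq_len.
model = DecoderOnlyTransformer(
    num_tokens=1000,
    max_seq_len=64,
    attn_layers_dim=128,  # assumed name, not visible in this hunk
    attn_layers_depth=4,
    attn_layers_heads=8,
    use_flash_attention=True,
).cuda()

tokens = torch.randint(0, 1000, (2, 64), device="cuda")  # (batch, sequence length)
logits = model(tokens)  # assumed to return (batch, sequence length, num_tokens)
```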
1 change: 0 additions & 1 deletion requirements-dev.txt
@@ -55,4 +55,3 @@ optuna
git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
lpips==0.1.4
xformers==0.0.16
-x-transformers==1.8.1