diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py
index 7c0370c6..b59b78bb 100644
--- a/generative/networks/blocks/selfattention.py
+++ b/generative/networks/blocks/selfattention.py
@@ -11,12 +11,20 @@
 
 from __future__ import annotations
 
+import importlib.util
 import math
 
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
 
+if importlib.util.find_spec("xformers") is not None:
+    import xformers.ops as xops
+
+    has_xformers = True
+else:
+    has_xformers = False
+
 
 class SABlock(nn.Module):
     """
@@ -31,6 +39,7 @@ class SABlock(nn.Module):
         causal: whether to use causal attention.
         sequence_length: if causal is True, it is necessary to specify the sequence length.
         with_cross_attention: Whether to use cross attention for conditioning.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """
 
     def __init__(
@@ -42,6 +51,7 @@ def __init__(
         causal: bool = False,
         sequence_length: int | None = None,
         with_cross_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -51,9 +61,11 @@ def __init__(
         self.causal = causal
         self.sequence_length = sequence_length
         self.with_cross_attention = with_cross_attention
+        self.use_flash_attention = use_flash_attention
 
         if not (0 <= dropout_rate <= 1):
             raise ValueError("dropout_rate should be between 0 and 1.")
+        self.dropout_rate = dropout_rate
 
         if hidden_size % num_heads != 0:
             raise ValueError("hidden size should be divisible by num_heads.")
@@ -61,6 +73,9 @@ def __init__(
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")
 
+        if use_flash_attention and not has_xformers:
+            raise ValueError("use_flash_attention is True but xformers is not installed.")
+
         # key, query, value projections
         self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
         self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
@@ -91,20 +106,42 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch
         key = self.to_k(kv)
         value = self.to_v(kv)
 
-        query = query.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
-        key = key.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
-        value = value.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
+        query = query.view(b, t, self.num_heads, c // self.num_heads)  # (b, t, nh, hs)
+        key = key.view(b, kv_t, self.num_heads, c // self.num_heads)  # (b, kv_t, nh, hs)
+        value = value.view(b, kv_t, self.num_heads, c // self.num_heads)  # (b, kv_t, nh, hs)
+
+        if self.use_flash_attention:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+            y = xops.memory_efficient_attention(
+                query=query,
+                key=key,
+                value=value,
+                scale=self.scale,
+                p=self.dropout_rate,
+                attn_bias=xops.LowerTriangularMask() if self.causal else None,
+            )
+
+        else:
+            query = query.transpose(1, 2)  # (b, nh, t, hs)
+            key = key.transpose(1, 2)  # (b, nh, kv_t, hs)
+            value = value.transpose(1, 2)  # (b, nh, kv_t, hs)
+
+            # manual implementation of attention
+            query = query * self.scale
+            attention_scores = query @ key.transpose(-2, -1)
+
+            if self.causal:
+                attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
 
-        # manual implementation of attention
-        attention_scores = (query @ key.transpose(-2, -1)) * self.scale
+            attention_probs = F.softmax(attention_scores, dim=-1)
+            attention_probs = self.drop_weights(attention_probs)
+            y = attention_probs @ value  # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs)
 
-        if self.causal:
-            attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+            y = y.transpose(1, 2)  # (b, nh, t, hs) -> (b, t, nh, hs)
 
-        attention_probs = F.softmax(attention_scores, dim=-1)
-        attention_probs = self.drop_weights(attention_probs)
-        y = attention_probs @ value  # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs)
-        y = y.transpose(1, 2).contiguous().view(b, t, c)  # re-assemble all head outputs side by side
+        y = y.contiguous().view(b, t, c)  # re-assemble all head outputs side by side
 
         y = self.out_proj(y)
         y = self.drop_output(y)
diff --git a/generative/networks/blocks/transformerblock.py b/generative/networks/blocks/transformerblock.py
index fd034c7e..ae8cb962 100644
--- a/generative/networks/blocks/transformerblock.py
+++ b/generative/networks/blocks/transformerblock.py
@@ -32,6 +32,7 @@ class TransformerBlock(nn.Module):
         causal: whether to use causal attention.
         sequence_length: if causal is True, it is necessary to specify the sequence length.
         with_cross_attention: Whether to use cross attention for conditioning.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """
 
     def __init__(
@@ -44,6 +45,7 @@ def __init__(
         causal: bool = False,
         sequence_length: int | None = None,
         with_cross_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         self.with_cross_attention = with_cross_attention
         super().__init__()
@@ -62,6 +64,7 @@ def __init__(
            qkv_bias=qkv_bias,
            causal=causal,
            sequence_length=sequence_length,
+            use_flash_attention=use_flash_attention,
        )
 
        self.norm2 = None
@@ -75,6 +78,7 @@ def __init__(
            qkv_bias=qkv_bias,
            with_cross_attention=with_cross_attention,
            causal=False,
+            use_flash_attention=use_flash_attention,
        )
 
        self.norm3 = nn.LayerNorm(hidden_size)
diff --git a/generative/networks/nets/transformer.py b/generative/networks/nets/transformer.py
index 0ea838af..dc961cc6 100644
--- a/generative/networks/nets/transformer.py
+++ b/generative/networks/nets/transformer.py
@@ -50,6 +50,7 @@ class DecoderOnlyTransformer(nn.Module):
         attn_layers_heads: Number of attention heads.
         with_cross_attention: Whether to use cross attention for conditioning.
         embedding_dropout_rate: Dropout rate for the embedding.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """
 
     def __init__(
@@ -61,6 +62,7 @@ def __init__(
         attn_layers_heads: int,
         with_cross_attention: bool = False,
         embedding_dropout_rate: float = 0.0,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.num_tokens = num_tokens
@@ -85,6 +87,7 @@ def __init__(
                    causal=True,
                    sequence_length=max_seq_len,
                    with_cross_attention=with_cross_attention,
+                    use_flash_attention=use_flash_attention,
                )
                for _ in range(attn_layers_depth)
            ]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 71ccc476..5b19e5d4 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -55,4 +55,3 @@ optuna
 git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
 lpips==0.1.4
 xformers==0.0.16
-x-transformers==1.8.1
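
Usage note (not part of the patch): a minimal sketch of how the new use_flash_attention flag on SABlock could be exercised. It assumes xformers==0.0.16 is installed (per requirements-dev.txt) and a CUDA device, since the fused memory-efficient kernels target GPU tensors; the hidden size, head count, and tensor shapes below are illustrative only.

import torch

from generative.networks.blocks.selfattention import SABlock

device = torch.device("cuda")  # xformers' fused attention kernels expect CUDA tensors

# Causal self-attention block routed through xops.memory_efficient_attention
# instead of the manual softmax(QK^T)V path.
block = SABlock(
    hidden_size=64,
    num_heads=4,
    dropout_rate=0.1,
    causal=True,
    sequence_length=16,
    use_flash_attention=True,  # raises ValueError at construction if xformers is missing
).to(device)

x = torch.randn(2, 16, 64, device=device)  # (batch, sequence, hidden_size)
y = block(x)  # same shape as the input: (2, 16, 64)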