From f7528b98cb62520c14456152ef93498a6b02caf5 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Thu, 23 Mar 2023 22:41:22 +0000 Subject: [PATCH 1/4] Add flash attention Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/blocks/selfattention.py | 37 +++++++++++++++---- .../networks/blocks/transformerblock.py | 4 ++ generative/networks/nets/transformer.py | 3 ++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py index 7c0370c6..e318e55b 100644 --- a/generative/networks/blocks/selfattention.py +++ b/generative/networks/blocks/selfattention.py @@ -11,12 +11,22 @@ from __future__ import annotations +import importlib.util import math import torch import torch.nn as nn from torch.nn import functional as F +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops as xops + + has_xformers = True +else: + xformers = None + has_xformers = False + class SABlock(nn.Module): """ @@ -31,6 +41,7 @@ class SABlock(nn.Module): causal: whether to use causal attention. sequence_length: if causal is True, it is necessary to specify the sequence length. with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -42,6 +53,7 @@ def __init__( causal: bool = False, sequence_length: int | None = None, with_cross_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -51,6 +63,7 @@ def __init__( self.causal = causal self.sequence_length = sequence_length self.with_cross_attention = with_cross_attention + self.use_flash_attention = use_flash_attention if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") @@ -95,15 +108,25 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch key = key.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) value = value.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - # manual implementation of attention - attention_scores = (query @ key.transpose(-2, -1)) * self.scale + if self.use_flash_attention: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + y = xops.memory_efficient_attention( + query, key, value, attn_bias=xops.LowerTriangularMask() if self.causal else None + ) + + else: + # manual implementation of attention + attention_scores = (query @ key.transpose(-2, -1)) * self.scale + + if self.causal: + attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) - if self.causal: - attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.drop_weights(attention_probs) + y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.drop_weights(attention_probs) - y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) y = y.transpose(1, 2).contiguous().view(b, t, c) # re-assemble all head outputs side by side y = self.out_proj(y) diff --git a/generative/networks/blocks/transformerblock.py b/generative/networks/blocks/transformerblock.py index fd034c7e..ae8cb962 
100644 --- a/generative/networks/blocks/transformerblock.py +++ b/generative/networks/blocks/transformerblock.py @@ -32,6 +32,7 @@ class TransformerBlock(nn.Module): causal: whether to use causal attention. sequence_length: if causal is True, it is necessary to specify the sequence length. with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -44,6 +45,7 @@ def __init__( causal: bool = False, sequence_length: int | None = None, with_cross_attention: bool = False, + use_flash_attention: bool = False, ) -> None: self.with_cross_attention = with_cross_attention super().__init__() @@ -62,6 +64,7 @@ def __init__( qkv_bias=qkv_bias, causal=causal, sequence_length=sequence_length, + use_flash_attention=use_flash_attention, ) self.norm2 = None @@ -75,6 +78,7 @@ def __init__( qkv_bias=qkv_bias, with_cross_attention=with_cross_attention, causal=False, + use_flash_attention=use_flash_attention, ) self.norm3 = nn.LayerNorm(hidden_size) diff --git a/generative/networks/nets/transformer.py b/generative/networks/nets/transformer.py index 0ea838af..dc961cc6 100644 --- a/generative/networks/nets/transformer.py +++ b/generative/networks/nets/transformer.py @@ -50,6 +50,7 @@ class DecoderOnlyTransformer(nn.Module): attn_layers_heads: Number of attention heads. with_cross_attention: Whether to use cross attention for conditioning. embedding_dropout_rate: Dropout rate for the embedding. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -61,6 +62,7 @@ def __init__( attn_layers_heads: int, with_cross_attention: bool = False, embedding_dropout_rate: float = 0.0, + use_flash_attention: bool = False, ) -> None: super().__init__() self.num_tokens = num_tokens @@ -85,6 +87,7 @@ def __init__( causal=True, sequence_length=max_seq_len, with_cross_attention=with_cross_attention, + use_flash_attention=use_flash_attention, ) for _ in range(attn_layers_depth) ] From 32fe7067ea8efd9fd2f1aeab170af52c189e71fa Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 24 Mar 2023 07:26:45 +0000 Subject: [PATCH 2/4] Remove x-transformer from requirements-dev.txt Signed-off-by: Walter Hugo Lopez Pinaya --- requirements-dev.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 71ccc476..5b19e5d4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -55,4 +55,3 @@ optuna git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded lpips==0.1.4 xformers==0.0.16 -x-transformers==1.8.1 From 66d67e6d601472c049b3b810f55aa6d57788221c Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 25 Mar 2023 12:15:22 +0000 Subject: [PATCH 3/4] Fix difference in values Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/blocks/selfattention.py | 30 +++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py index e318e55b..91d60891 100644 --- a/generative/networks/blocks/selfattention.py +++ b/generative/networks/blocks/selfattention.py @@ -19,12 +19,10 @@ from torch.nn import functional as F if importlib.util.find_spec("xformers") is not None: - import xformers import xformers.ops as xops has_xformers = True else: - xformers = None has_xformers = False @@ -67,6 +65,7 @@ def __init__( if not (0 <= dropout_rate <= 1): raise 
ValueError("dropout_rate should be between 0 and 1.") + self.dropout_rate = dropout_rate if hidden_size % num_heads != 0: raise ValueError("hidden size should be divisible by num_heads.") @@ -74,6 +73,9 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + # key, query, value projections self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) @@ -104,21 +106,31 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch key = self.to_k(kv) value = self.to_v(kv) - query = query.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) - key = key.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - value = value.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) + query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) + key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) if self.use_flash_attention: query = query.contiguous() key = key.contiguous() value = value.contiguous() y = xops.memory_efficient_attention( - query, key, value, attn_bias=xops.LowerTriangularMask() if self.causal else None + query=query, + key=key, + value=value, + scale=self.scale, + p=self.dropout_rate, + attn_bias=xops.LowerTriangularMask() if self.causal else None, ) else: + query = query.transpose(1, 2) # (b, nh, t, hs) + key = key.transpose(1, 2) # (b, nh, kv_t, hs) + value = value.transpose(1, 2) # (b, nh, kv_t, hs) + # manual implementation of attention - attention_scores = (query @ key.transpose(-2, -1)) * self.scale + query = query * self.scale + attention_scores = query @ key.transpose(-2, -1) if self.causal: attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) @@ -127,7 +139,9 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch attention_probs = self.drop_weights(attention_probs) y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) - y = y.transpose(1, 2).contiguous().view(b, t, c) # re-assemble all head outputs side by side + y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) + + y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side y = self.out_proj(y) y = self.drop_output(y) From 8abd047e83b2bfc3f13dabf385b03460d3b6d6ab Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 25 Mar 2023 12:16:50 +0000 Subject: [PATCH 4/4] Add error Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/blocks/selfattention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py index 91d60891..b59b78bb 100644 --- a/generative/networks/blocks/selfattention.py +++ b/generative/networks/blocks/selfattention.py @@ -73,8 +73,8 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") - if causal and sequence_length is None: - raise ValueError("sequence_length is necessary for causal attention.") + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is 
True but xformers is not installed.") # key, query, value projections self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
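
A hedged usage sketch (not part of the patch series): it exercises the new `use_flash_attention` flag end to end through `DecoderOnlyTransformer`, whose constructor arguments are taken from the diff above. The import path and the `forward(x)` signature taking token indices are assumptions based on the files touched here, and the flash path is intended for a CUDA device with xformers installed; with `use_flash_attention=True` and no xformers, the guard added in patch 4/4 raises a `ValueError`.

```python
# Sketch only: assumes xformers is installed and a CUDA device is available.
import torch

from generative.networks.nets.transformer import DecoderOnlyTransformer  # assumed import path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DecoderOnlyTransformer(
    num_tokens=10,
    max_seq_len=16,
    attn_layers_dim=64,
    attn_layers_depth=2,
    attn_layers_heads=8,
    use_flash_attention=True,  # raises ValueError if xformers is missing (patch 4/4)
).to(device)

tokens = torch.randint(0, 10, (2, 16), device=device)  # (batch, sequence_length) token indices
logits = model(tokens)
print(logits.shape)  # expected: (2, 16, 10), per-token logits over the vocabulary
```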
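
The reshaping in patch 3/4 ("Fix difference in values") is a layout convention: `xops.memory_efficient_attention` consumes tensors shaped `(batch, seq_len, num_heads, head_dim)` and returns the same layout, whereas the manual softmax branch works on `(batch, num_heads, seq_len, head_dim)`. The minimal sketch below (not part of the patch; it assumes xformers and a CUDA device, and relies on the default scale of `1/sqrt(head_dim)`) checks that the two paths agree on random tensors.

```python
# Sketch only: compares the xformers memory-efficient path with a manual softmax attention.
import math

import torch
import xformers.ops as xops

device = torch.device("cuda")
b, s, h, d = 2, 16, 8, 8  # batch, sequence length, heads, head dimension

q = torch.randn(b, s, h, d, device=device)
k = torch.randn(b, s, h, d, device=device)
v = torch.randn(b, s, h, d, device=device)

# flash/memory-efficient path: inputs stay in (b, s, h, d)
y_flash = xops.memory_efficient_attention(q, k, v)  # default scale is 1/sqrt(d)

# manual path: transpose to (b, h, s, d) first, as the non-flash branch does
qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
scores = (qt * (1.0 / math.sqrt(d))) @ kt.transpose(-2, -1)
y_manual = (torch.softmax(scores, dim=-1) @ vt).transpose(1, 2)  # back to (b, s, h, d)

print(torch.allclose(y_flash, y_manual, atol=1e-3))  # expected: True, up to backend precision
```

For the causal case, the block passes `attn_bias=xops.LowerTriangularMask()`, which plays the same role as the `masked_fill` with the registered `causal_mask` on the manual path.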