1 change: 1 addition & 0 deletions generative/networks/blocks/__init__.py
@@ -12,3 +12,4 @@
from __future__ import annotations

from .selfattention import SABlock
from .transformerblock import TransformerBlock
54 changes: 31 additions & 23 deletions generative/networks/blocks/selfattention.py
@@ -30,6 +30,7 @@ class SABlock(nn.Module):
qkv_bias: bias term for the qkv linear layer.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
with_cross_attention: Whether to use cross attention for conditioning.
"""

def __init__(
@@ -40,8 +41,16 @@ def __init__(
qkv_bias: bool = False,
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)
self.causal = causal
self.sequence_length = sequence_length
self.with_cross_attention = with_cross_attention

if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
@@ -52,50 +61,49 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

# output projection
self.out_proj = nn.Linear(hidden_size, hidden_size)
# key, query, value projections for all heads, but in a batch
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
# key, query, value projections
self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)

# regularization
self.drop_weights = nn.Dropout(dropout_rate)
self.drop_output = nn.Dropout(dropout_rate)

self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)
self.causal = causal
self.sequence_length = sequence_length
# output projection
self.out_proj = nn.Linear(hidden_size, hidden_size)

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
self.mask = torch.tril(torch.ones(sequence_length, sequence_length)).view(
1, 1, sequence_length, sequence_length
self.register_buffer(
"causal_mask",
torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
)
else:
self.mask = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size)

if self.sequence_length is not None and t != self.sequence_length:
raise ValueError("sequence length should be equal to the one specified in the SABlock constructor.")

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
query, key, value = self.qkv(x).split(self.hidden_size, dim=2)
key = key.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
query = self.to_q(x)

kv = context if context is not None else x
_, kv_t, _ = kv.size()
key = self.to_k(kv)
value = self.to_v(kv)

query = query.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
value = value.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
key = key.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
value = value.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)

# manual implementation of attention
attention_scores = (query @ key.transpose(-2, -1)) * self.scale

if self.causal:
attention_scores = attention_scores.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = self.drop_weights(attention_probs)
y = attention_probs @ value # (b, nh, t, t) x (b, nh, t, hs) -> (b, nh, t, hs)
y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs)
y = y.transpose(1, 2).contiguous().view(b, t, c) # re-assemble all head outputs side by side

y = self.out_proj(y)
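For reference, a minimal sketch of how the reworked block behaves, using illustrative sizes (hidden_size=64, 8 heads, sequence length 16). Note that the to_k/to_v projections in this diff map hidden_size to hidden_size, so the conditioning context must share the block's embedding width:

import torch

from generative.networks.blocks.selfattention import SABlock

# Causal self-attention: queries, keys and values are all projected from x.
self_attn = SABlock(hidden_size=64, num_heads=8, dropout_rate=0.0, causal=True, sequence_length=16)
x = torch.randn(2, 16, 64)                    # (batch, sequence_length, hidden_size)
print(self_attn(x).shape)                     # torch.Size([2, 16, 64])

# Cross-attention: keys and values are projected from the conditioning context instead.
cross_attn = SABlock(hidden_size=64, num_heads=8, dropout_rate=0.0, with_cross_attention=True)
context = torch.randn(2, 5, 64)               # context length may differ, but the last dim must equal hidden_size
print(cross_attn(x, context=context).shape)   # torch.Size([2, 16, 64])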
88 changes: 88 additions & 0 deletions generative/networks/blocks/transformerblock.py
@@ -0,0 +1,88 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import torch.nn as nn
from monai.networks.blocks.mlp import MLPBlock

from generative.networks.blocks.selfattention import SABlock


class TransformerBlock(nn.Module):
"""
A transformer block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
num_heads: number of attention heads.
dropout_rate: fraction of the input units to drop.
qkv_bias: apply bias term for the qkv linear layer
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
with_cross_attention: Whether to use cross attention for conditioning.
"""

def __init__(
self,
hidden_size: int,
mlp_dim: int,
num_heads: int,
dropout_rate: float = 0.0,
qkv_bias: bool = False,
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
) -> None:
super().__init__()
self.with_cross_attention = with_cross_attention

if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")

if hidden_size % num_heads != 0:
raise ValueError("hidden_size should be divisible by num_heads.")

self.norm1 = nn.LayerNorm(hidden_size)
self.attn = SABlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=causal,
sequence_length=sequence_length,
)

self.norm2 = None
self.cross_attn = None
if self.with_cross_attention:
self.norm2 = nn.LayerNorm(hidden_size)
self.cross_attn = SABlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
with_cross_attention=with_cross_attention,
causal=False,
)

self.norm3 = nn.LayerNorm(hidden_size)
self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)

def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
x = x + self.attn(self.norm1(x))
if self.with_cross_attention:
x = x + self.cross_attn(self.norm2(x), context=context)
x = x + self.mlp(self.norm3(x))
return x
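A minimal usage sketch of the block above, assuming illustrative sizes (hidden_size=64, mlp_dim=256) and a conditioning context with the same embedding width as the input:

import torch

from generative.networks.blocks.transformerblock import TransformerBlock

block = TransformerBlock(
    hidden_size=64,
    mlp_dim=256,
    num_heads=8,
    dropout_rate=0.0,
    causal=True,
    sequence_length=16,
    with_cross_attention=True,
)
x = torch.randn(2, 16, 64)        # (batch, sequence_length, hidden_size)
context = torch.randn(2, 4, 64)   # conditioning tokens, same embedding width as x
out = block(x, context=context)   # residual self-attention -> cross-attention -> MLP
print(out.shape)                  # torch.Size([2, 16, 64])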
74 changes: 52 additions & 22 deletions generative/networks/nets/transformer.py
@@ -11,20 +11,32 @@

from __future__ import annotations

import importlib.util

import torch
import torch.nn as nn

if importlib.util.find_spec("x_transformers") is not None:
from x_transformers import Decoder, TransformerWrapper
from generative.networks.blocks.transformerblock import TransformerBlock

has_x_transformers = True
else:
has_x_transformers = False
__all__ = ["DecoderOnlyTransformer"]


__all__ = ["DecoderOnlyTransformer"]
class AbsolutePositionalEmbedding(nn.Module):
"""Absolute positional embedding.

Args:
max_seq_len: Maximum sequence length.
embedding_dim: Dimensionality of the embedding.
"""

def __init__(self, max_seq_len: int, embedding_dim: int) -> None:
super().__init__()
self.max_seq_len = max_seq_len
self.embedding_dim = embedding_dim
self.embedding = nn.Embedding(max_seq_len, embedding_dim)

def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size, seq_len = x.size()
positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1)
return self.embedding(positions)


class DecoderOnlyTransformer(nn.Module):
@@ -37,6 +49,7 @@ class DecoderOnlyTransformer(nn.Module):
attn_layers_depth: Number of attention layers.
attn_layers_heads: Number of attention heads.
with_cross_attention: Whether to use cross attention for conditioning.
embedding_dropout_rate: Dropout rate for the embedding.
"""

def __init__(
@@ -47,27 +60,44 @@ def __init__(
attn_layers_depth: int,
attn_layers_heads: int,
with_cross_attention: bool = False,
embedding_dropout_rate: float = 0.0,
) -> None:
super().__init__()
self.num_tokens = num_tokens
self.max_seq_len = max_seq_len
self.attn_layers_dim = attn_layers_dim
self.attn_layers_depth = attn_layers_depth
self.attn_layers_heads = attn_layers_heads
self.with_cross_attention = with_cross_attention

self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim)
self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim)
self.embedding_dropout = nn.Dropout(embedding_dropout_rate)

if has_x_transformers:
self.model = TransformerWrapper(
num_tokens=self.num_tokens,
max_seq_len=self.max_seq_len,
attn_layers=Decoder(
dim=self.attn_layers_dim,
depth=self.attn_layers_depth,
heads=self.attn_layers_heads,
cross_attend=with_cross_attention,
),
)
else:
raise ImportError("x-transformers is not installed.")
self.blocks = nn.ModuleList(
[
TransformerBlock(
hidden_size=attn_layers_dim,
mlp_dim=attn_layers_dim * 4,
num_heads=attn_layers_heads,
dropout_rate=0.0,
qkv_bias=False,
causal=True,
sequence_length=max_seq_len,
with_cross_attention=with_cross_attention,
)
for _ in range(attn_layers_depth)
]
)

self.to_logits = nn.Linear(attn_layers_dim, num_tokens)

def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
return self.model(x, context=context)
tok_emb = self.token_embeddings(x)
pos_emb = self.position_embeddings(x)
x = self.embedding_dropout(tok_emb + pos_emb)

for block in self.blocks:
x = block(x, context=context)

return self.to_logits(x)
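A stand-alone sketch of the refactored network, mirroring the sizes used by the conditioned test later in this diff (token ids in [0, 10), sequence length 16, attn_layers_dim=8 so the context width matches):

import torch

from generative.networks.nets.transformer import DecoderOnlyTransformer

model = DecoderOnlyTransformer(
    num_tokens=10,
    max_seq_len=16,
    attn_layers_dim=8,
    attn_layers_depth=2,
    attn_layers_heads=2,
    with_cross_attention=True,
    embedding_dropout_rate=0.0,
)
# With causal blocks the input must be exactly max_seq_len tokens long
# (see the sequence-length check in SABlock.forward).
tokens = torch.randint(0, 10, (1, 16))   # (batch, max_seq_len) integer token ids
context = torch.randn(1, 4, 8)           # conditioning; last dim must equal attn_layers_dim
logits = model(tokens, context=context)
print(logits.shape)                      # torch.Size([1, 16, 10]) -> per-token logits over the vocabulary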
7 changes: 0 additions & 7 deletions tests/test_selfattention.py
@@ -54,13 +54,6 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=12, num_heads=4, dropout_rate=0.4, causal=True, sequence_length=None)

def test_wrong_sequence_length(self):
net = SABlock(hidden_size=16, num_heads=4, dropout_rate=0.0, causal=True, sequence_length=6)
with self.assertRaises(ValueError):
with eval_mode(net):
result = net(torch.randn((2, 4, 16)))
self.assertEqual(result.shape, (2, 4, 16))


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions tests/test_transformer.py
@@ -35,6 +35,7 @@ def test_conditioned_models(self):
attn_layers_depth=2,
attn_layers_heads=2,
with_cross_attention=True,
embedding_dropout_rate=0,
)
with eval_mode(net):
net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8))