Additional model - CrossAttentionTransformerEncoder (#251)
* docu changes

* initial commit for dti branch

* improved NDict's print_tree func

* &

* made the print keys optional

* fixed ndict's print_tree

* fixed mypy & added documentation

* changed naming from 'print_keys' to 'print_values' + added docu

* addressed inline comment on documentation

* fixed one line

* minor fix

* adding CAT first commit

* minor progress

* minor

* flake fix

* minor changes

* documenting

* unittest checks all contexts

* clean unnecessary comments

* docu and clean

* more detailing

* more detailing

* added support in cls token for CAT

* docu and style

* added blog post ref

* added random batch size to test

* canceled the support in cls_tokens

* renaming to CrossAttentionTransformerEncoder

* fix param in docu

* revert changes in head1d

* detailed documentation for the model

* added support for kwargs in the model

* fixed default values + docu + tiny change in unittest

* changed the default kwargs values

---------

Co-authored-by: Sagi Polaczek <sagi.polaczek@ibm.com>
SagiPolaczek and Sagi Polaczek committed Feb 9, 2023
1 parent 7f510ea commit 18aa0d3
Showing 5 changed files with 238 additions and 7 deletions.
159 changes: 155 additions & 4 deletions fuse/dl/models/backbones/backbone_transformer.py
@@ -1,14 +1,16 @@
import torch
from torch import nn
from typing import Optional
from vit_pytorch.vit import Transformer as _Transformer
from vit_pytorch.vit import repeat
from x_transformers import Encoder, CrossAttender, TransformerWrapper


class Transformer(nn.Module):
"""
Transformer backbone.
Gets a [batch_size, num_tokens, token_dim] shaped tensor
Returns a [batch_size, num_tokens + 1, token_dim] shaped tensor, where the first token is the CLS token
Returns a [batch_size, num_tokens + num_cls_tokens, token_dim] shaped tensor, where the first tokens are the CLS tokens
"""

def __init__(
@@ -27,7 +29,7 @@ def __init__(
super().__init__()
self.num_cls_tokens = num_cls_tokens
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens + num_cls_tokens, token_dim))
self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, token_dim))
self.cls_tokens = nn.Parameter(torch.randn(1, num_cls_tokens, token_dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = _Transformer(
dim=token_dim, depth=depth, heads=heads, dim_head=dim_head, mlp_dim=mlp_dim, dropout=dropout
@@ -36,12 +38,161 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
:param x: [batch_size, num_tokens, token_dim] shaped tensor
:return: [batch_size, num_tokens + 1, token_dim] shaped tensor, where the first token is the CLS token
:return: [batch_size, num_tokens + num_cls_tokens, token_dim] shaped tensor, where the first tokens are the CLS tokens
"""
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, "1 a d -> b a d", b=b)
cls_tokens = repeat(self.cls_tokens, "1 a d -> b a d", b=b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.dropout(x)
x = self.transformer(x)
return x


class CrossAttentionTransformerEncoder(nn.Module):
"""
CrossAttentionTransformerEncoder backbone model based on the x-transformers library.
Input:
two sequences 'seq_a, seq_b' with shapes [batch_size, len(seq_a)], [batch_size, len(seq_b)] respectively.
Output:
features tensor with shape [batch_size, seq_len, output_dim] (seq_len depends on the chosen context)
Architecture:
The model consists of the following blocks:
* Two encoding layers - one for each sequence
* One or two cross attention (1) layers - depending on the user's request
* One linear layer
Forward Pass:
-> Receives two sequences as inputs, assuming both of them are already tokenized.
-> Passes each sequence through its own encoder (extracting features)
-> Performs cross attention between the two encoded features, using one of the sequences as context
(also supports using both as context:
* performs cross attention twice, each time using a different sequence as the context
* concatenates both outputs into one vector)
-> Applies a linear layer to the last encoded vector
Building Blocks:
In this architecture we use three components from the "x-transformers" library.
Here is a short summary - for more information, consider checking the source code.
* TransformerWrapper - wraps an attention layer (in our case the Encoder) and applies token & positional embedding.
* Encoder - self attention layers. In our case it gets the embedding from the wrapper.
* CrossAttender - cross attention layers. In our case it gets the embedding from the encoders.
(1) see the following blog post for more info regarding cross attention in transformers:
https://vaclavkosar.com/ml/cross-attention-in-transformer-architecture
"""

def __init__(
self,
emb_dim: int,
num_tokens_a: int,
num_tokens_b: int,
max_seq_len_a: int,
max_seq_len_b: int,
depth_a: int = 6,
depth_b: int = 6,
depth_cross_attn: int = 6,
heads_a: int = 9,
heads_b: int = 9,
output_dim: Optional[int] = None,
context: str = "seq_b",
kwargs_wrapper_a: Optional[dict] = None,
kwargs_wrapper_b: Optional[dict] = None,
kwargs_encoder_a: Optional[dict] = None,
kwargs_encoder_b: Optional[dict] = None,
kwargs_cross_attn: Optional[dict] = None,
):
"""
:param emb_dim: inner model dimension
:param num_tokens_a: number of tokens of the first sequence
:param num_tokens_b: number of tokens of the second sequence
:param max_seq_len_a: the maximum length of the first sequence
:param max_seq_len_b: the maximum length of the second sequence
:param depth_a: first sequence encoder's depth
:param depth_b: second sequence encoder's depth
:param depth_cross_attn: cross attender(s)' depth
:param heads_a: number of attention heads for the first sequence's encoder
:param heads_b: number of attention heads for the second sequence's encoder
:param output_dim: (optional) model's output dimension. If not given, emb_dim will be used as the default.
:param context: which sequence will be used as context in the cross attention module:
"seq_a": the first sequence will be used as a context
"seq_b": the second sequence will be used as a context
"both": will use two cross attention modules to take each one of the sequences as a context to the other one.
:param kwargs_wrapper_a: optional - additional arguments for sequence a's TransformerWrapper object
:param kwargs_wrapper_b: optional - additional arguments for sequence b's TransformerWrapper object
:param kwargs_encoder_a: optional - additional arguments for sequence a's Encoder object
:param kwargs_encoder_b: optional - additional arguments for sequence b's Encoder object
:param kwargs_cross_attn: optional - additional arguments for the CrossAttender object(s)
"""
super().__init__()

if output_dim is None:
output_dim = emb_dim

assert context in ["seq_a", "seq_b", "both"]
self._context = context

# init sequences' encoders
self.enc_a = TransformerWrapper(
num_tokens=num_tokens_a,
max_seq_len=max_seq_len_a,
**kwargs_wrapper_a if kwargs_wrapper_a else dict(),
attn_layers=Encoder(
dim=emb_dim, depth=depth_a, heads=heads_a, **kwargs_encoder_a if kwargs_encoder_a else dict()
),
)
self.enc_b = TransformerWrapper(
num_tokens=num_tokens_b,
max_seq_len=max_seq_len_b,
**kwargs_wrapper_b if kwargs_wrapper_b else dict(),
attn_layers=Encoder(
dim=emb_dim, depth=depth_b, heads=heads_b, **kwargs_encoder_b if kwargs_encoder_b else dict()
),
)

# cross attention module(s)
if self._context in ["seq_a", "seq_b"]:
self.cross_attn = CrossAttender(
dim=emb_dim, depth=depth_cross_attn, **kwargs_cross_attn if kwargs_cross_attn else dict()
)

else: # both
self.cross_attn_a_as_context = CrossAttender(
dim=emb_dim, depth=depth_cross_attn, **kwargs_cross_attn if kwargs_cross_attn else dict()
)
self.cross_attn_b_as_context = CrossAttender(
dim=emb_dim, depth=depth_cross_attn, **kwargs_cross_attn if kwargs_cross_attn else dict()
)

self.last_linear = nn.Linear(emb_dim, output_dim)

def forward(self, xa: torch.Tensor, xb: torch.Tensor) -> torch.Tensor:
"""
assumes input sequences are already tokenized
:param xa: tensor with shape [batch_size, seq_len_a]
:param xb: tensor with shape [batch_size, seq_len_b]
:return: raw embeddings
"""
# encoding stage
enc_xa = self.enc_a(xa, return_embeddings=True)
enc_xb = self.enc_b(xb, return_embeddings=True)

# cross attention stage
if self._context == "seq_a":
x = self.cross_attn(enc_xb, context=enc_xa)

elif self._context == "seq_b":
x = self.cross_attn(enc_xa, context=enc_xb)

else:
x_acb = self.cross_attn_b_as_context(enc_xa, context=enc_xb)
x_bca = self.cross_attn_a_as_context(enc_xb, context=enc_xa)
x = torch.cat((x_acb, x_bca), dim=1)

# linear layer
x = self.last_linear(x)
return x
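
Below is a minimal usage sketch of the new encoder, using only the constructor and forward() shown in this diff. The parameter values are hypothetical, and the printed shape assumes context="both" (the two query sequences are concatenated along the token dimension before the final linear layer).

import torch
from fuse.dl.models.backbones.backbone_transformer import CrossAttentionTransformerEncoder

# hypothetical configuration: vocabulary sizes, sequence lengths and dimensions are arbitrary
model = CrossAttentionTransformerEncoder(
    emb_dim=128,
    num_tokens_a=10000,
    num_tokens_b=20000,
    max_seq_len_a=512,
    max_seq_len_b=1024,
    output_dim=256,
    context="both",  # cross-attend in both directions and concatenate the outputs
)

# two batches of already-tokenized sequences (integer token ids)
xa = torch.randint(0, 10000, (4, 100))
xb = torch.randint(0, 20000, (4, 200))

out = model(xa, xb)
print(out.shape)  # expected: torch.Size([4, 300, 256]) -- 100 + 200 query tokens, output_dim=256
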
2 changes: 1 addition & 1 deletion fuse/dl/models/heads/heads_1D.py
@@ -31,7 +31,7 @@ def __init__(
head_name: str = "head_0",
mode: str = None, # "classification" or "regression"
conv_inputs: Sequence[Tuple[str, int]] = None,
num_outputs: int = 2, # num classes in case of classification
num_outputs: int = 2,
append_features: Optional[Sequence[Tuple[str, int]]] = None,
layers_description: Sequence[int] = (256,),
append_layers_description: Sequence[int] = tuple(),
80 changes: 80 additions & 0 deletions fuse/dl/tests/test_cat.py
@@ -0,0 +1,80 @@
"""
(C) Copyright 2023 IBM Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Created on Jan 11, 2023
"""

import torch
import unittest
import random
from fuse.dl.models.backbones.backbone_transformer import CrossAttentionTransformerEncoder


class TestCrossAttentionTransformerEncoder(unittest.TestCase):
def test_all_contexts(self) -> None:
"""
test cross attention transformer for each of the three context options: "seq_a", "seq_b" and "both"
"""

# model parameters
model_params = {
"emb_dim": 128,
"num_tokens_a": 10000,
"num_tokens_b": 20000,
"max_seq_len_a": 512,
"max_seq_len_b": 1024,
"output_dim": 256,
"kwargs_wrapper_a": dict(emb_dropout=0.1),
"kwargs_wrapper_b": dict(emb_dropout=0.1),
"kwargs_encoder_a": dict(layer_dropout=0.1),
"kwargs_cross_attn": dict(cross_attn_tokens_dropout=0.1),
}

# test for each context case
for context in ["seq_a", "seq_b", "both"]:
model_params["context"] = context
self.validate_model_with_params(model_params)

def validate_model_with_params(self, model_params: dict) -> None:
"""
Basic validation for the CrossAttentionTransformerEncoder model
:param model_params: A dictionary of the model's parameters to validate
"""

# init model
model = CrossAttentionTransformerEncoder(**model_params)

# init random sequences that don't exceed the max sequence lengths
seq_a_len = random.randint(0, model_params["max_seq_len_a"])
seq_b_len = random.randint(0, model_params["max_seq_len_b"])
batch_size = random.randint(1, 10)
s1 = torch.randint(0, model_params["num_tokens_a"], (batch_size, seq_a_len))
s2 = torch.randint(0, model_params["num_tokens_b"], (batch_size, seq_b_len))

# processing sample
output = model(s1, s2)

# validation
assert output.shape[0] == batch_size
if output.shape[-1] != model_params["output_dim"]:
raise Exception(
f"Expected output dimension to be {model_params['output_dim']}, but got: {output.shape[-1]}. Used model parameters: {model_params}."
)


if __name__ == "__main__":
unittest.main()
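
Since the new test module ends with unittest.main(), it can be run on its own - for example (assuming the fuse package is importable in the current environment):

python fuse/dl/tests/test_cat.py
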
3 changes: 2 additions & 1 deletion fuse/requirements.txt
@@ -28,4 +28,5 @@ omegaconf
nibabel
vit-pytorch
lifelines
clearml
clearml
x-transformers
1 change: 0 additions & 1 deletion fuseimg/data/ops/tests/test_ops.py
@@ -22,7 +22,6 @@ def test_basic_1(self):
pipeline = PipelineDefault(
"test_pipeline",
[
# (op_normalize_against_self, {} ),
(OpClip(), dict(key="data.input.img", clip=(-0.5, 3.0))),
(OpToRange(), dict(key="data.input.img", from_range=(-0.5, 3.0), to_range=(-3.5, 3.5))),
],
