<a href="https://colab.research.google.com/github/Rishit-dagli/Invariant-Attention/blob/main/example/invariant_attention_example_ipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Invariant Attention Example

This notebook shows how to use the `invariant-attention` Python package. Invariant Point Attention (IPA) was used for coordinate refinement in the structure module of AlphaFold2, described in the paper *Highly accurate protein structure prediction with AlphaFold*. IPA is a form of attention that acts on a set of frames and is invariant under global Euclidean transformations of those frames.
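For reference, the attention weights in the AlphaFold2 supplementary information (Algorithm 22) combine three terms: a scalar query-key product, a bias $b_{ij}^h$ projected from the pairwise representation, and a squared-distance penalty between query and key points placed in each residue's frame $T_i = (R_i, \vec{t}_i)$, where $T_i \circ \vec{x} = R_i \vec{x} + \vec{t}_i$. Here $c$ is the scalar query-key dimension and $\gamma^h$ is a learned positive per-head weight. This is the formula as written in the paper; the package's exact scaling constants may differ.

$$
a_{ij}^{h} = \operatorname{softmax}_j\!\left( w_L \left( \frac{1}{\sqrt{c}}\, {\mathbf{q}_i^{h}}^{\top} \mathbf{k}_j^{h} + b_{ij}^{h} - \frac{\gamma^{h} w_C}{2} \sum_{p} \left\lVert T_i \circ \vec{\mathbf{q}}_i^{hp} - T_j \circ \vec{\mathbf{k}}_j^{hp} \right\rVert^2 \right) \right),
\qquad w_L = \sqrt{\tfrac{1}{3}}, \quad w_C = \sqrt{\tfrac{2}{9 N_{\text{query points}}}}
$$

If every frame is moved by the same global transformation $T_\text{g} = (R_\text{g}, \vec{t}_\text{g})$, each point difference becomes

$$
(T_\text{g} \circ T_i) \circ \vec{\mathbf{q}}_i^{hp} - (T_\text{g} \circ T_j) \circ \vec{\mathbf{k}}_j^{hp} = R_\text{g}\left( T_i \circ \vec{\mathbf{q}}_i^{hp} - T_j \circ \vec{\mathbf{k}}_j^{hp} \right),
$$

so the translation cancels and the rotation preserves the norm. The logits, and therefore the weights $a_{ij}^h$, are unchanged.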

If you find this useful, please consider giving the [repo](https://github.com/Rishit-dagli/Invariant-Attention/) a ⭐.

In [1]:
!pip install invariant-attention

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting invariant-attention
  Downloading Invariant_Attention-0.1.0-py3-none-any.whl (15 kB)
Collecting einops~=0.3.0
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops, invariant-attention
Successfully installed einops-0.3.2 invariant-attention-0.1.0


## Setup

In [2]:
import tensorflow as tf
from einops import repeat
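The rotations in the examples below are built with `repeat(tf.eye(3), "... -> b n ...", b=1, n=256)`, which copies one 3×3 identity matrix into every (batch, position) slot, giving a `(1, 256, 3, 3)` tensor of identity frames. As a sketch of what that pattern does, here is the same construction with NumPy broadcasting (NumPy is an assumption here, not a dependency of the package). In TensorFlow, `tf.broadcast_to(tf.eye(3), (1, 256, 3, 3))` builds the equivalent tensor without einops.

```python
import numpy as np

batch, seq_len = 1, 256

# "... -> b n ..." prepends two new axes (b and n) in front of the existing
# 3x3 matrix axes and repeats the identity matrix along them.
rotations = np.broadcast_to(np.eye(3), (batch, seq_len, 3, 3))

# Zero translations: every residue frame sits at the origin.
translations = np.zeros((batch, seq_len, 3))

print(rotations.shape)     # (1, 256, 3, 3)
print(translations.shape)  # (1, 256, 3)
```

`np.broadcast_to` returns a read-only view rather than a copy, which is enough when the frames are only read.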

## Standalone IPA

In [3]:
from invariant_attention import InvariantPointAttention

In [4]:
attn = InvariantPointAttention(
    dim=64,  # single (and pairwise) representation dimension
    heads=8,  # number of attention heads
    scalar_key_dim=16,  # scalar query-key dimension
    scalar_value_dim=16,  # scalar value dimension
    point_key_dim=4,  # point query-key dimension
    point_value_dim=4,  # point value dimension
)

single_repr = tf.random.normal((1, 256, 64))  # (batch x seq x dim)
pairwise_repr = tf.random.normal((1, 256, 256, 64))  # (batch x seq x seq x dim)
mask = tf.ones((1, 256), dtype=tf.bool)  # (batch x seq)

rotations = repeat(
    tf.eye(3), "... -> b n ...", b=1, n=256
)
translations = tf.zeros((1, 256, 3))

attn_out = attn(
    single_repr,
    pairwise_repr,
    rotations=rotations,
    translations=translations,
    mask=mask,
)  # (1, 256, 64)
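To make the computation inside `attn` more concrete, here is a minimal single-head NumPy sketch of IPA's attention weights and point output, following the AlphaFold2 formulation (supplementary Algorithm 22). It is not the package's implementation: random arrays stand in for the learned linear projections, `gamma` stands in for the learned per-head point weight, and the scaling constants are taken from the paper. The helpers `apply_frames`, `ipa_single_head`, and `random_rotation` are hypothetical names for illustration. As in the example above, frames are given as per-residue rotations and translations, and the mask sends logits of padded keys to `-inf`. The last line checks the invariance property described at the top of the notebook on a small sequence.

```python
import numpy as np


def apply_frames(rotations, translations, points):
    """Map local points (n, p, 3) to global coordinates: T_i(x) = R_i x + t_i."""
    return np.einsum("nij,npj->npi", rotations, points) + translations[:, None, :]


def ipa_single_head(q, k, b, q_pts, k_pts, v_pts, rotations, translations, mask, gamma=1.0):
    """Single-head IPA weights and point output (sketch after AF2 Algorithm 22).

    q, k: (n, c) scalar queries/keys; b: (n, n) pair bias;
    q_pts, k_pts: (n, p, 3) query/key points; v_pts: (n, pv, 3) value points;
    mask: (n,) booleans, False marks padded positions.
    """
    c = q.shape[-1]
    w_l = np.sqrt(1.0 / 3.0)
    w_c = np.sqrt(2.0 / (9.0 * q_pts.shape[1]))

    q_glob = apply_frames(rotations, translations, q_pts)  # (n, p, 3)
    k_glob = apply_frames(rotations, translations, k_pts)  # (n, p, 3)
    # Squared distances between frame-mapped points, summed over points: (n, n)
    sq_dist = ((q_glob[:, None] - k_glob[None, :]) ** 2).sum(axis=(-1, -2))

    logits = w_l * (q @ k.T / np.sqrt(c) + b - 0.5 * gamma * w_c * sq_dist)
    logits = np.where(mask[None, :], logits, -np.inf)  # ignore padded keys
    logits -= logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    attn = np.exp(logits)
    attn /= attn.sum(axis=-1, keepdims=True)

    # Average value points in global coordinates, then map back into each local frame:
    # T_i^{-1}(y) = R_i^T (y - t_i)
    v_glob = apply_frames(rotations, translations, v_pts)  # (n, pv, 3)
    agg = np.einsum("ij,jpd->ipd", attn, v_glob)  # (n, pv, 3)
    local = np.einsum("nji,npj->npi", rotations, agg - translations[:, None, :])
    return attn, local


def random_rotation(rng):
    """Random proper rotation: QR of a Gaussian matrix with signs fixed so det = +1."""
    qmat, r = np.linalg.qr(rng.normal(size=(3, 3)))
    qmat = qmat * np.sign(np.diag(r))
    if np.linalg.det(qmat) < 0:
        qmat[:, 0] = -qmat[:, 0]
    return qmat


rng = np.random.default_rng(0)
n, c, p, pv = 8, 16, 4, 4
rots = np.stack([random_rotation(rng) for _ in range(n)])
trans = rng.normal(size=(n, 3))
args = dict(
    q=rng.normal(size=(n, c)), k=rng.normal(size=(n, c)), b=rng.normal(size=(n, n)),
    q_pts=rng.normal(size=(n, p, 3)), k_pts=rng.normal(size=(n, p, 3)),
    v_pts=rng.normal(size=(n, pv, 3)), mask=np.ones(n, dtype=bool),
)
attn, out = ipa_single_head(rotations=rots, translations=trans, **args)

# Apply the same global rigid motion to every frame: R_i -> G R_i, t_i -> G t_i + s
g, s = random_rotation(rng), rng.normal(size=3)
attn_g, out_g = ipa_single_head(rotations=g @ rots, translations=trans @ g.T + s, **args)
print(np.allclose(attn, attn_g), np.allclose(out, out_g))  # True True
```

The point output is invariant too: because each row of weights sums to 1, the global motion passes through the weighted average and is undone by mapping back into the (equally moved) local frame.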

## Running an IPA Block

In [5]:
from invariant_attention import IPABlock

In [6]:
block = IPABlock(
    dim=64,
    heads=8,
    scalar_key_dim=16,
    scalar_value_dim=16,
    point_key_dim=4,
    point_value_dim=4,
)

seq = tf.random.normal((1, 256, 64))
pairwise_repr = tf.random.normal((1, 256, 256, 64))
mask = tf.ones((1, 256), dtype=tf.bool)

rotations = repeat(tf.eye(3), "... -> b n ...", b=1, n=256)
translations = tf.zeros((1, 256, 3))

block_out = block(
    seq,
    pairwise_repr=pairwise_repr,
    rotations=rotations,
    translations=translations,
    mask=mask,
)

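# Project to 6 numbers per residue: 3 quaternion components (b, c, d) and a 3D translation,
# as in the AlphaFold2 backbone update, where the quaternion's real part is fixed to 1
updates = tf.keras.layers.Dense(6)(block_out)
quaternion_update, translation_update = tf.split(
    updates, num_or_size_splits=2, axis=-1
)  # (1, 256, 3), (1, 256, 3)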
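The `Dense(6)` head above mirrors the backbone update of AlphaFold2 (supplementary Algorithm 23): three numbers $(b, c, d)$ define a non-unit quaternion $(1, b, c, d)$, which is normalized and turned into a rotation matrix, and the resulting update $(R_\text{update}, \vec{t}_\text{update})$ is composed with the current frame as $T_i \leftarrow T_i \circ T_\text{update}$. Below is a NumPy sketch of that step for a single residue. The functions `quaternion_update_to_rotation` and `apply_backbone_update` are hypothetical helpers, not part of `invariant-attention`. To try them on the notebook tensors, take one residue, e.g. `quaternion_update[0, 0].numpy()`.

```python
import numpy as np


def quaternion_update_to_rotation(bcd):
    """Rotation matrix from the non-unit quaternion (1, b, c, d), normalized first."""
    a, b, c, d = np.concatenate([[1.0], bcd]) / np.sqrt(1.0 + np.dot(bcd, bcd))
    return np.array([
        [a*a + b*b - c*c - d*d, 2*(b*c - a*d),         2*(b*d + a*c)],
        [2*(b*c + a*d),         a*a - b*b + c*c - d*d, 2*(c*d - a*b)],
        [2*(b*d - a*c),         2*(c*d + a*b),         a*a - b*b - c*c + d*d],
    ])


def apply_backbone_update(rotation, translation, quaternion_update, translation_update):
    """Compose the frame (R, t) with the update: (R, t) o (R_u, t_u) = (R R_u, R t_u + t)."""
    r_update = quaternion_update_to_rotation(quaternion_update)
    new_rotation = rotation @ r_update
    new_translation = rotation @ translation_update + translation
    return new_rotation, new_translation


# Identity frame plus a small update for one residue
rot, trans = apply_backbone_update(
    np.eye(3), np.zeros(3),
    quaternion_update=np.array([0.1, 0.0, 0.0]),
    translation_update=np.array([1.0, 0.0, 0.0]),
)
print(np.allclose(rot @ rot.T, np.eye(3)), np.isclose(np.linalg.det(rot), 1.0))  # True True
print(trans)  # [1. 0. 0.]
```

Fixing the real part to 1 means a zero output from the head is exactly the identity update, so an untrained head leaves the frames where they are.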

## Running an IPATransformer

In [7]:
from invariant_attention import IPATransformer

In [8]:
seq = tf.random.normal((1, 256, 32))
pairwise_repr = tf.random.normal((1, 256, 256, 32))
mask = tf.ones((1, 256), dtype=tf.bool)
translations = tf.zeros((1, 256, 3))

model = IPATransformer(
    dim=32,
    depth=2,
    num_tokens=None,
    predict_points=False,
    detach_rotations=True,
)

outputs = model(
    single_repr=seq,
    translations=translations,
    quaternions=tf.random.normal((1, 256, 4)),  # random, unnormalized; for demonstration only
    pairwise_repr=pairwise_repr,
    mask=mask,
)  # (1, 256, 32), (1, 256, 3), (1, 256, 4)
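The call above passes unnormalized random quaternions purely for demonstration. AlphaFold2's structure module instead starts every residue at the identity frame: rotation $(1, 0, 0, 0)$ as a quaternion and zero translation. Here is a NumPy sketch that builds identity quaternions with the `(1, 256, 4)` shape used above and, as an alternative, random quaternions normalized to unit length. It assumes the real part comes first, matching AlphaFold2's $(1, b, c, d)$ convention. This notebook does not show whether `IPATransformer` normalizes its input quaternions, so normalizing them first is a conservative choice. Convert the arrays with `tf.convert_to_tensor(...)` before passing them to `model`.

```python
import numpy as np

batch, seq_len = 1, 256

# Identity rotation for every residue (assuming real-part-first (w, x, y, z) order)
identity_quaternions = np.tile(
    np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), (batch, seq_len, 1)
)
identity_translations = np.zeros((batch, seq_len, 3), dtype=np.float32)

# Alternative: random rotations, scaled to unit norm so the standard
# quaternion-to-matrix formula yields an orthonormal matrix
rng = np.random.default_rng(0)
random_quaternions = rng.normal(size=(batch, seq_len, 4)).astype(np.float32)
random_quaternions /= np.linalg.norm(random_quaternions, axis=-1, keepdims=True)

print(identity_quaternions.shape, random_quaternions.shape)  # (1, 256, 4) (1, 256, 4)
```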