In [1]:
import tensorflow as tf
from einops.layers.tensorflow import Rearrange
from einops import rearrange, repeat

2023-01-06 03:50:37.547872: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-06 03:50:40.481650: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-06 03:50:40.481683: I tensorflow/compiler/xla/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2023-01-06 03:50:47.614196: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directo

In [2]:
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def max_neg_value(t):
    return -tf.experimental.numpy.finfo(t.dtype).max

In [3]:
class InvariantPointAttention(tf.keras.layers.Layer):
    def __init__(
        self,
        dim,
        heads = 8,
        scalar_key_dim = 16,
        scalar_value_dim = 16,
        point_key_dim = 4,
        point_value_dim = 4,
        pairwise_repr_dim = None,
        require_pairwise_repr = True,
        eps = 1e-8,
        **kwargs
    ):
        super(InvariantPointAttention, self).__init__(**kwargs)

        self.eps = eps
        self.heads = heads
        self.require_pairwise_repr = require_pairwise_repr

        # num attention contributions

        num_attn_logits = 3 if require_pairwise_repr else 2

        # qkv projection for scalar attention (normal)

        self.scalar_attn_logits_scale = (num_attn_logits * scalar_key_dim) ** -0.5

        self.to_scalar_k = tf.keras.layers.Dense(scalar_key_dim * heads, use_bias = False)
        self.to_scalar_k = tf.keras.layers.Dense(scalar_key_dim * heads, use_bias = False)
        self.to_scalar_v = tf.keras.layers.Dense(scalar_value_dim * heads, use_bias = False)

        point_weight_init_value = tf.math.log(tf.math.exp(tf.ones((heads,))) - 1.)
        self.point_weights = tf.Variable(point_weight_init_value)

        self.point_attn_logits_scale = ((num_attn_logits * point_key_dim) * (9 / 2)) ** -0.5

        self.to_point_q = tf.keras.layers.Dense(point_key_dim * heads * 3, use_bias = False)
        self.to_point_k = tf.keras.layers.Dense(point_key_dim * heads * 3, use_bias = False)
        self.to_point_v = tf.keras.layers.Dense(point_value_dim * heads * 3, use_bias = False)

        pairwise_repr_dim = default(pairwise_repr_dim, dim) if require_pairwise_repr else 0

        if require_pairwise_repr:
            self.pairwise_attn_logits_scale = num_attn_logits ** -0.5

            self.to_pairwise_attn_bias = tf.keras.Sequential(
                tf.keras.layers.Dense(heads),
                Rearrange('b ... h -> (b h) ...')
            )
        
        self.to_out = tf.keras.layers.Dense(dim)
    
    def call(
        self,
        single_repr,
        pairwise_repr = None,
        *,
        rotations,
        translations,
        mask = None
    ):
        x, b, h, eps, require_pairwise_repr = single_repr, single_repr.shape[0], self.heads, self.eps, self.require_pairwise_repr
        assert not (require_pairwise_repr and not exists(pairwise_repr)), 'pairwise representation must be given as second argument'

        q_scalar, k_scalar, v_scalar = self.to_scalar_q(x), self.to_scalar_k(x), self.to_scalar_v(x)

        q_point, k_point, v_point = self.to_point_q(x), self.to_point_k(x), self.to_point_v(x)

        q_scalar, k_scalar, v_scalar = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q_scalar, k_scalar, v_scalar))
        q_point, k_point, v_point = map(lambda t: rearrange(t, 'b n (h d c) -> (b h) n d c', h = h, c = 3), (q_point, k_point, v_point))

        rotations = repeat(rotations, 'b n r1 r2 -> (b h) n r1 r2', h = h)
        translations = repeat(translations, 'b n c -> (b h) n () c', h = h)

        q_point = tf.einsum('b n d c, b n c r -> b n d r', q_point, rotations) + translations
        k_point = tf.einsum('b n d c, b n c r -> b n d r', k_point, rotations) + translations
        v_point = tf.einsum('b n d c, b n c r -> b n d r', v_point, rotations) + translations

        attn_logits_scalar = tf.einsum('b i d, b j d -> b i j', q_scalar, k_scalar) * self.scalar_attn_logits_scale

        if require_pairwise_repr:
            attn_logits_pairwise = self.to_pairwise_attn_bias(pairwise_repr) * self.pairwise_attn_logits_scale
        
        point_qk_diff = rearrange(q_point, 'b i d c -> b i () d c') - rearrange(k_point, 'b j d c -> b () j d c')
        point_dist = (point_qk_diff ** 2).sum(dim = (-1, -2))

        point_weights = tf.math.softplus(self.point_weights)
        point_weights = repeat(point_weights, 'h -> (b h) () ()', b = b)

        attn_logits_points = -0.5 * (point_dist * point_weights * self.point_attn_logits_scale)

        attn_logits = attn_logits_scalar + attn_logits_points

        if require_pairwise_repr:
            attn_logits = attn_logits + attn_logits_pairwise
        
        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
            mask = repeat(mask, 'b i j -> (b h) i j', h = h)
            mask_value = max_neg_value(attn_logits)
            attn_logits = tf.where(mask, mask_value, attn_logits)
        
        attn = tf.nn.softmax(attn_logits, axis = -1)

        results_scalar = tf.einsum('b i j, b j d -> b i d', attn, v_scalar)
        attn_with_heads = rearrange(attn, '(b h) i j -> b h i j', h = h)

        if require_pairwise_repr:
            results_pairwise = tf.einsum('b h i j, b i j d -> b h i d', attn_with_heads, pairwise_repr)

        results_points = tf.einsum('b i j, b j d c -> b i d c', attn, v_point)

        results_points = tf.einsum('b n d c, b n c r -> b n d r', results_points - translations, rotations.transpose(-1, -2))
        results_points_norm = tf.math.sqrt(tf.math.square(results_points).sum(axis = -1) + eps)

        results_scalar = rearrange(results_scalar, '(b h) n d -> b n (h d)', h = h)
        results_points = rearrange(results_points, '(b h) n d c -> b n (h d c)', h = h)
        results_points_norm = rearrange(results_points_norm, '(b h) n d -> b n (h d)', h = h)

        results = (results_scalar, results_points, results_points_norm)

        if require_pairwise_repr:
            results_pairwise = rearrange(results_pairwise, 'b h n d -> b n (h d)', h = h)
            results = (*results, results_pairwise)
        
        results = tf.concat(results, axis = -1)
        return self.to_out(results)

In [5]:
class FeedForward(tf.keras.layers.Layer):
    def __init__(self, dim, mult = 1., num_layers = 2, activation = tf.keras.layers.ReLU, **kwargs):
        super(FeedForward, self).__init__(**kwargs)
        self.mult = mult
        self.num_layers = num_layers
        self.activation = activation
        dim_hidden = dim * mult

        self.layers = []
        for ind in range(num_layers):
            is_first = ind == 0
            is_last  = ind == (num_layers - 1)
            dim_in   = dim if is_first else dim_hidden
            dim_out  = dim if is_last else dim_hidden

            self.layers.append(tf.keras.layers.Dense(dim_out))

            if is_last:
                continue

            self.layers.append(activation())
        self.layers = tf.keras.Sequential(self.layers)
    
    def call(self, inputs):
        return self.layers(inputs)

In [6]:
attn = InvariantPointAttention(
    dim = 64,                  # single (and pairwise) representation dimension
    heads = 8,                 # number of attention heads
    scalar_key_dim = 16,       # scalar query-key dimension
    scalar_value_dim = 16,     # scalar value dimension
    point_key_dim = 4,         # point query-key dimension
    point_value_dim = 4        # point value dimension
)

single_repr   = tf.random.normal((1, 256, 64))      # (batch x seq x dim)
pairwise_repr = tf.random.normal((1, 256, 256, 64)) # (batch x seq x seq x dim)
mask          = tf.ones((1, 256), dtype = tf.bool) # # (batch x seq)

rotations     = repeat(tf.eye(3), '... -> b n ...', b = 1, n = 256) # (batch x seq x rot1 x rot2) - example is identity
translations  = tf.zeros((1, 256, 3)) # translation, also identity for example

attn_out = attn(
    single_repr,
    pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

2023-01-06 04:16:17.831806: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-01-06 04:16:17.839674: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
2023-01-06 04:16:17.850602: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (codespaces-c2e541): /proc/driver/nvidia/version does not exist
2023-01-06 04:16:17.958452: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-06 04:16:19.168158: W tensorflow/tsl/framework/cpu

AttributeError: Exception encountered when calling layer 'invariant_point_attention' (type InvariantPointAttention).

'InvariantPointAttention' object has no attribute 'to_scalar_q'

Call arguments received by layer 'invariant_point_attention' (type InvariantPointAttention):
  • single_repr=tf.Tensor(shape=(1, 256, 64), dtype=float32)
  • pairwise_repr=tf.Tensor(shape=(1, 256, 256, 64), dtype=float32)
  • rotations=tf.Tensor(shape=(1, 256, 3, 3), dtype=float32)
  • translations=tf.Tensor(shape=(1, 256, 3), dtype=float32)
  • mask=tf.Tensor(shape=(1, 256), dtype=bool)

Bad pipe message: %s [b"\x05\xb1\x17H\xa5:TCe\xa6\xfc\x06\x86\xde\xce\xb4\xa3\x8f\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00<\x005\x00/\x00\x9a\x00\x99\xc0\x07\xc0\x11\x00\x96\x00\x05\x00\xff\x01\x00\x00j\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x000\x00.\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x03\x03\x02\x03\x03\x01\x02\x01\x03\x02\x02\x02\x04\x02\x05\x02"]
Bad pipe message: %s [b"\xc4\xe3\xad\xbe\xb5cA\xd7\xa1\x10\x7fW\xb2 \xf8\x17\x92e\x00\x00\xf4\xc00\xc0,\xc0(\xc0$\xc0\x14\x