In [1]:
import numpy as np
import tf_utils as tfu
import tensorflow as tf
import tensorflow.keras as keras
import torch
import torch.nn as nn

In [2]:
strategy = tfu.strategy.cpu()

2022-05-01 18:42:46.036965: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-05-01 18:42:46.037337: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-05-01 18:42:46.041597: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-05-01 18:42:46.041954: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-05-01 18:42:46.042706: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from S

In [164]:
@CustomObject
class CosFormerMultiHeadAttention(keras.layers.Layer):
    def __init__(self,
                 embed_dim,
                 num_heads,
                 causal=False,
                 act_fun="relu",
                 **kwargs):
        super(CosFormerMultiHeadAttention, self).__init__(**kwargs)
        
        assert embed_dim % num_heads == 0
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.causal = causal
        self.act_fun = act_fun
        
        self.q_proj = keras.layers.Dense(embed_dim, activation=act_fun)
        self.k_proj = keras.layers.Dense(embed_dim, activation=act_fun)
        self.v_proj = keras.layers.Dense(embed_dim)
        self.out_proj = keras.layers.Dense(embed_dim)
        
    def get_index(self, seq_len):
        return np.pi / 2 * tf.reshape(tf.range(1, seq_len + 1, dtype=tf.float32), (1, -1, 1))
    
    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))
        return tf.reshape(tf.transpose(x, (0, 2, 1, 3)), (batch_size*self.num_heads, -1, self.head_dim))
    
    def call(self, query, value=None, key=None, eps=1e-6):
        if key is None:
            key = query
        if value is None:
            value = query
        
        batch_size = tf.shape(query)[0]
        tgt_len = tf.shape(query)[1]
        src_len = tf.shape(key)[1]

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # multi-head reshape
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # cos transform
        m = tf.cast(tf.maximum(src_len, tgt_len), dtype=tf.float32)
        weight_index = self.get_index(m)
        q_ = tf.concat([q * tf.sin(weight_index[:, :tgt_len, :] / m), q * tf.cos(weight_index[:, :tgt_len, :] / m)], axis=2)
        k_ = tf.concat([k * tf.sin(weight_index[:, :src_len, :] / m), k * tf.cos(weight_index[:, :src_len, :] / m)], axis=2)

        if self.causal:
            kv_ = tf.einsum("nld,nlm->nldm", k_, v)
            kv_cum = tf.cumsum(kv_, axis=1)
            qkv = tf.einsum("nld,nldm->nlm", q_, kv_cum)
            k_cum = tf.cumsum(k_, axis=1)
            denom = tf.maximum(tf.einsum("nlm,nlm->nl", q_, k_cum), eps)
            attn_output = qkv / tf.expand_dims(denom, axis=2)
        else:
            kv_ = tf.einsum("nld,nlm->ndm", k_, v)
            z_ = 1 / tf.maximum(tf.einsum("nld,nd->nl", q_, tf.reduce_sum(k_, axis=1)), eps)
            attn_output = tf.einsum("nld,ndm,nl->nlm", q_, kv_, z_)
        
        attn_output = tf.reshape(attn_output, (batch_size, self.num_heads, -1, self.head_dim))
        attn_output = tf.reshape(tf.transpose(attn_output, (0, 2, 1, 3)), (batch_size, tgt_len, self.embed_dim))
        attn_output = self.out_proj(attn_output)
        return attn_output

In [230]:
def split_heads(x, batch_size, num_heads, depth):
    x = tf.reshape(x, (batch_size, -1, num_heads, depth))
    return tf.reshape(tf.transpose(x, (0, 2, 1, 3)), (batch_size*num_heads, -1, depth))

def get_index(seq_len):
    return np.pi / 2 * tf.reshape(tf.range(1, seq_len + 1, dtype=tf.float32), (1, -1, 1))

def call(query, num_heads, key=None, value=None, causal=False, eps=1e-6):
    if key is None:
        key = query
    if value is None:
        value = query

    batch_size, tgt_len, embed_dim = tf.shape(query)
    src_len = tf.shape(key)[1]
    head_dim = embed_dim // num_heads

    q = query
    k = key
    v = value
    
    # q = self.q_proj(query)
    # k = self.q_proj(key)
    # v = self.q_proj(value)

    # multi-head reshape
    q = split_heads(q, batch_size, num_heads, head_dim)
    k = split_heads(k, batch_size, num_heads, head_dim)
    v = split_heads(v, batch_size, num_heads, head_dim)

    # cos transform
    m = tf.cast(tf.maximum(src_len, tgt_len), dtype=tf.float32)
    weight_index = get_index(m)
    q_ = tf.concat([q * tf.sin(weight_index[:, :tgt_len, :] / m), q * tf.cos(weight_index[:, :tgt_len, :] / m)], axis=2)
    k_ = tf.concat([k * tf.sin(weight_index[:, :src_len, :] / m), k * tf.cos(weight_index[:, :src_len, :] / m)], axis=2)
    
    if causal:
        kv_ = tf.einsum("nld,nlm->nldm", k_, v)
        kv_cum = tf.cumsum(kv_, axis=1)
        qkv = tf.einsum("nld,nldm->nlm", q_, kv_cum)
        k_cum = tf.cumsum(k_, axis=1)
        denom = tf.maximum(tf.einsum("nlm,nlm->nl", q_, k_cum), eps)
        attn_output = qkv / tf.expand_dims(denom, axis=2)
    else:
        kv_ = tf.einsum("nld,nlm->ndm", k_, v)
        z_ = 1 / tf.maximum(tf.einsum("nld,nd->nl", q_, tf.reduce_sum(k_, axis=1)), eps)
        attn_output = tf.einsum("nld,ndm,nl->nlm", q_, kv_, z_)
        print(attn_output.shape)

    attn_output = tf.reshape(tf.transpose(tf.reshape(attn_output, (batch_size, num_heads, -1, head_dim)), (0, 2, 1, 3)), (batch_size, tgt_len, -1))
    return attn_output

def torch_get_index(seq_len):
    return np.pi / 2 * torch.arange(1, seq_len + 1).reshape(1, -1, 1)

def forward(
        query,
        num_heads,
        key=None,
        value=None,
        causal=False,
        eps = 1e-6,
    ):
    """Input shape: Sequence x Batch x Embedding
    Args:
        query (Tensor): `(L, N, E)` where L is the target sequence length, N is the batch size,
        E is the embedding dimension.
        key (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size,
        E is the embedding dimension.
        value (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size,
        E is the embedding dimension.
        attn_mask (Optional[Tensor], optional): typically used to implement causal attention, 
        where the mask prevents the attention from looking forward in time (default: None).
    """
    if key == None:
        key = query
    if value == None:
        value = query

    num_heads = num_heads
    tgt_len, bsz, embed_dim = query.size()
    src_len = key.size(0)
    head_dim = embed_dim // num_heads
    
    q = query
    k = key
    v = value

    # multihead reshape
    # (N * h, L, d)
    q = q.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    # (N * h, S, d)
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    # (N * h, S, d)
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    
    # cos transform
    m = max(src_len, tgt_len)
    # get index and send to cuda
    weight_index = torch_get_index(m).to(q)
    # (N * h, L, 2 * d)
    q_ = torch.cat([q * torch.sin(weight_index[:, :tgt_len, :] / m), q * torch.cos(weight_index[:, :tgt_len, :] / m)], dim=-1)
    # (N * h, S, 2 * d)
    k_ = torch.cat([k * torch.sin(weight_index[:, :src_len, :] / m), k * torch.cos(weight_index[:, :src_len, :] / m)], dim=-1)

    if causal:
        ## Need to improve speed!
        # (N * h, L, 2 * d) (N * h, L, d) -> (N * h, L, h, 2 * d, d)
        kv_ = torch.einsum("nld,nlm->nldm", k_, v)
        # (N * h, L, 2 * d, d) -> (N * h, L, 2 * d, d)
        kv_cum = torch.cumsum(kv_, dim=1)
        # (N * h, L, 2 * d) (N * h, L, 2 * d, d) -> (N * h, L, d)
        qkv = torch.einsum("nld,nldm->nlm", q_, kv_cum)
        # (N * h, L, 2 * d) -> (N * h, L, 2 * d)
        k_cum = torch.cumsum(k_, dim=1)
        # (N * h, L, 2 * d) (N * h, L, 2 * d) -> (N * h, L)
        denom = torch.clamp_min(torch.einsum("nlm,nlm->nl", q_, k_cum), eps)
        # (N * h, L, d) (N * h, L, 1) -> (N * h, L, d)
        attn_output = qkv / denom.unsqueeze(-1)
        # (N * h, L, d) -> (L, N * h, d) -> (L, N, E)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1)
    else:
        # (N * h, L, 2 * d) (N * h, L, d) -> (N * h, 2 * d, d)
        kv_ = torch.einsum('nld,nlm->ndm', k_, v)
        # (N * h, L, 2 * d) (N * h, 2 * d) -> (N * h, L)
        z_ = 1 / torch.clamp_min(torch.einsum('nld,nd->nl', q_, torch.sum(k_, axis=1)), eps)
        # (N * h, L, 2 * d) (N * h, d, 2 * d) (N * h, L) -> (N * h, L, d)
        attn_output = torch.einsum('nld,ndm,nl->nlm', q_, kv_, z_)
        # (N * h, L, d) -> (L, N * h, d) -> (L, N, E)
        print(attn_output.shape)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1)

    return attn_output

In [247]:
batch_size = 512
num_heads = 4
embed_dim = 128
length = 148
head_dim = embed_dim // num_heads

In [249]:
a = np.arange(embed_dim*batch_size*length, dtype=np.float32).reshape(batch_size, length, embed_dim)
a.shape

(512, 148, 128)

In [245]:
atf = tf.constant(a)
at = torch.Tensor(a).transpose(0, 1)
atf.shape, at.shape

(TensorShape([512, 148, 128]), torch.Size([148, 512, 128]))

In [246]:
np.max(call(atf, num_heads) - forward(at, num_heads).transpose(0, 1))

(2048, 148, 32)
torch.Size([2048, 148, 32])


30.0

In [169]:
call(atf, num_heads, causal=True) - forward(at, num_heads, causal=True).transpose(0, 1)

<tf.Tensor: shape=(2, 5, 20), dtype=float32, numpy=
array([[[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00, -1.9073486e-06, -1.9073486e-06,  0.0000000e+00,
          1.9073486e-06,  3.8146973e-06,  3.8146973e-06,  5.7220459e-06,
          5.7220459e-06,  1.9073486e-06,  0.0000000e+00,  1.9073486e-06,
         -3.8146973e-06,  0.0000000e+00,  0.0000000e+00,  1.9073486e-06,
          1.9073486e-06,  3.8146973e-06,  3.8146973e-06,  3.8146973e-06],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  3.8146973e-06,
          0.0000000e+00,  0.0000000e+00, -7.6293945e-06,  0.0000000e+00,
          0.0000000e+00,  0.0000000e+00,  3.8146973e-06,  3.8146973e-0

In [6]:
atf = tf.transpose(tf.reshape(a, (-1, batch_size*num_heads, head_dim)), (1, 0, 2))

In [7]:
weight_index = np.pi / 2 * tf.reshape(tf.range(1, length + 1, dtype=tf.float32), (1, -1, 1))
q_ = tf.concat([atf * tf.sin(weight_index[:, :length, :] / length), atf * tf.cos(weight_index[:, :length, :] / length)], axis=-1)

In [11]:
tf.maximum(tf.einsum("nld,nd->nl", q_, tf.reduce_sum(q_, axis=1)), 1e-6)
# tf.maximum(tf.einsum("nld,nd->nl", q_, tf.reduce_sum(q_, axis=1)), 1e-6)

<tf.Tensor: shape=(8, 5), dtype=float32, numpy=
array([[  2337.1118,  66363.09  , 152012.55  , 237462.38  , 299471.97  ],
       [  8728.424 ,  79309.57  , 171110.16  , 260841.7   , 323981.22  ],
       [ 16033.954 ,  93330.766 , 191337.83  , 285295.8   , 349404.72  ],
       [ 24253.697 , 108426.695 , 212695.5   , 310824.6   , 375742.5   ],
       [ 33387.67  , 124597.36  , 235183.25  , 337428.12  , 402994.44  ],
       [ 43435.855 , 141842.77  , 258801.03  , 365106.44  , 431160.62  ],
       [ 54398.266 , 160162.88  , 283548.8   , 393859.44  , 460240.97  ],
       [ 66274.89  , 179557.72  , 309426.66  , 423687.2   , 490235.6   ]],
      dtype=float32)>

In [12]:
at = torch.Tensor(a)
at = at.contiguous().view(-1, batch_size*num_heads, head_dim).transpose(0, 1)

In [13]:
weight_index = np.pi / 2 * torch.arange(1, length + 1).reshape(1, -1, 1)
q_ = torch.cat([at * torch.sin(weight_index[:, :length, :] / length), at * torch.cos(weight_index[:, :length, :] / length)], dim=-1)

In [14]:
torch.clamp_min(torch.einsum("nld,nd->nl", q_, torch.sum(q_, axis=1)), 1e-6)

tensor([[  2337.1118,  66363.0859, 152012.5469, 237462.3750, 299471.9375],
        [  8728.4238,  79309.5781, 171110.1562, 260841.7031, 323981.2188],
        [ 16033.9531,  93330.7656, 191337.8125, 285295.8125, 349404.6875],
        [ 24253.6992, 108426.7109, 212695.5156, 310824.5938, 375742.5000],
        [ 33387.6719, 124597.3672, 235183.2656, 337428.1250, 402994.4375],
        [ 43435.8555, 141842.7656, 258801.0312, 365106.4375, 431160.6250],
        [ 54398.2695, 160162.8906, 283548.8125, 393859.4375, 460240.9688],
        [ 66274.8906, 179557.7188, 309426.6562, 423687.1875, 490235.5938]])