# Test tf behaviour

This notebook tests the behaviour of the TensorFlow code snippets from the research paper, to make sure that I am porting them to PyTorch correctly.

In [18]:
import tensorflow as tf

##### Resnet stem

###### Torch
self.conv1 = nn.Conv2d(3, 16, kernel_size=7, stride=2, padding=3, bias=False)

In [19]:
# TensorFlow counterpart of the torch stem conv quoted above: 16 filters, 7x7 kernel, stride 2.
# NOTE(review): the torch version sets bias=False, whereas this Keras layer keeps the
# default use_bias=True -- confirm whether the bias term should be dropped here too.
image_stem = tf.keras.layers.Conv2D(filters=16, 
                                    kernel_size=7,
                                    strides=2,
                                    padding='same')

#### shape_list

In [20]:
from tf_code.self_atten_utils_tf import shape_list

In [21]:
# Random stand-in for a batch of images in channels-last layout:
# [batch=10, height=60, width=60, channels=3].
exp_tensor = tf.random.normal(
    shape = [10, 60, 60, 3]
)

In [22]:
exp_tensor.shape

TensorShape([10, 60, 60, 3])

In [23]:
type(exp_tensor)

tensorflow.python.framework.ops.EagerTensor

In [24]:
result = shape_list(exp_tensor)
result

[10, 60, 60, 3]

In [25]:
type(result)

list

#### split_heads_2d

In [26]:
from tf_code.self_atten_utils_tf import split_heads_2d

We need to preprocess the data into the format expected by the `split_heads_2d` function first

###### explore workflow of split_heads

image -> conv -> conv -> split -> split_heads

In [27]:
# Extra 7x7 'same'-padded conv with 16 filters (stride left at its default).
# NOTE(review): this layer is not applied in any of the visible cells below --
# `image_stem` is called instead -- confirm whether it is still needed.
preproc_layer = tf.keras.layers.Conv2D(filters=16, 
                              kernel_size=7, padding='same')

In [28]:
con_result = image_stem(exp_tensor)

In [29]:
con_result.shape

TensorShape([10, 30, 30, 16])

In [30]:
split_frame = split_heads_2d(con_result, 4)

In [31]:
type(split_frame)

tensorflow.python.framework.ops.EagerTensor

In [32]:
split_frame.shape

TensorShape([10, 4, 30, 30, 4])

In [33]:
split_frame = split_heads_2d(con_result, 2)
split_frame.shape

TensorShape([10, 2, 30, 30, 8])

So we end up with batch / heads / h / w / (channels / heads), and we crash if the channels are not evenly divisible by the number of heads

#### combine_heads_2d

In [34]:
from tf_code.self_atten_utils_tf import combine_heads_2d

In [35]:
rejoin_frame = combine_heads_2d(split_frame)

In [36]:
# Round-trip check: combine_heads_2d should undo split_heads_2d exactly.
# Reduce the element-wise comparison to a single boolean instead of
# printing the full (10, 30, 30, 16) mask.
tf.reduce_all(tf.math.equal(rejoin_frame, con_result))

In [37]:
rejoin_frame.shape

TensorShape([10, 30, 30, 16])

#### rel_to_abs

In [38]:
from tf_code.self_atten_utils_tf import rel_to_abs

This function runs within `relative_logits_1d` that runs within the `relative_logits` function which takes `q` and the fixed H, W, Nh and dkh variables

Note: the cells in this section must be run after the Full Workflow section below, because they rely on `H`, `W`, `Nh`, `dkh` and the `split_q` tensor that are defined there

In [43]:
# Embedding table of shape (2*W - 1, dkh) used for the relative logits along the width axis.
# `W` and `dkh` are defined in the Full Workflow section further down.
# NOTE(review): the first positional argument of tf.random_normal_initializer is
# `mean`, so dkh**-0.5 sets the mean rather than the stddev -- confirm this matches
# the paper's intent before mirroring it in the PyTorch port.
rel_embeddings_w = tf.compat.v1.get_variable(
        'r_width', shape=(2*W - 1, dkh),
        initializer = tf.random_normal_initializer(dkh**-0.5))

In [44]:
type(rel_embeddings_w)

tensorflow.python.ops.resource_variable_ops.ResourceVariable

In [45]:
rel_embeddings_w.shape

TensorShape([59, 4])

In [66]:
split_q.shape

TensorShape([10, 4, 30, 30, 4])

In [67]:
# Contract the last axis `d` of split_q against every row `m` of the embedding table:
# [10, 4, 30, 30, 4] x [59, 4] -> [10, 4, 30, 30, 59].
rel_logits = tf.einsum('bhxyd,md->bhxym', split_q, rel_embeddings_w)

In [68]:
rel_logits.shape

TensorShape([10, 4, 30, 30, 59])

In [69]:
# Fold the heads axis into the height axis so the tensor becomes 4-D:
# [B, Nh, H, W, 2W-1] -> [B, Nh*H, W, 2W-1].
rel_logits_reshape = tf.reshape(
        rel_logits, [-1, Nh * H, W, 2 * W-1]
    )

In [70]:
rel_logits_reshape.shape

TensorShape([10, 120, 30, 59])

In [None]:
# debug from here

In [71]:
# rel_to_abs unpacks a 4-element shape, so it must be given the reshaped
# [B, Nh*H, W, 2W-1] tensor. Passing the 5-D `rel_logits` raised
# "ValueError: too many values to unpack (expected 4)".
rel_logits_rel_to_abs = rel_to_abs(rel_logits_reshape)

##### Full Workflow

We take the input image [batch, height, width, filters] after the stem, then make sure that dk and dv are divisible by the number of heads.

We conv the input batch into [batch, height, width, 2xdk+dv]

In [47]:
# take conv result from above
con_result.shape

TensorShape([10, 30, 30, 16])

In [48]:
_, H, W, _ = shape_list(con_result)

In [49]:
# Attention hyper-parameters: total value depth, total key depth, number of heads.
dv = 16
dk = 16
Nh = 4

# Every head receives an equal slice of the key/value channels, so both
# depths must divide evenly by the number of heads (// would silently truncate).
assert dk % Nh == 0, "dk must be divisible by Nh"
assert dv % Nh == 0, "dv must be divisible by Nh"

# Per-head key and value depths.
dkh = dk // Nh
dvh = dv // Nh

In [50]:
# conv to split frame size
# 1x1 conv that produces the concatenated key/query/value channels.
# The output is later split into widths [dk, dk, dv], so the filter count must be
# 2*dk + dv (the previous 2*dv + dk only matched because dk == dv).
conv_split = tf.keras.layers.Conv2D(filters=2*dk+dv, 
                              kernel_size=1, padding='same')

In [51]:
con_split_res = conv_split(con_result)


In [52]:
con_split_res.shape

TensorShape([10, 30, 30, 48])

In [53]:
# Slice the channel axis (axis 3) into three tensors of widths dk, dk and dv,
# bound to k, q and v in that order.
k, q, v = tf.split(con_split_res, [dk, dk, dv], axis=3)

In [54]:
k.shape

TensorShape([10, 30, 30, 16])

In [55]:
q.shape

TensorShape([10, 30, 30, 16])

In [56]:
v.shape

TensorShape([10, 30, 30, 16])

##### Run Split Heads

In [57]:
split_q = split_heads_2d(q, Nh)

In [58]:
split_q.shape

TensorShape([10, 4, 30, 30, 4])

In [59]:
split_k = split_heads_2d(k, Nh)

In [60]:
def flatten_hw(x, d):
    """Reshape `x` to [-1, Nh, H*W, d], merging the height and width axes.

    Reads the notebook-level `Nh`, `H` and `W` at call time.
    """
    target_shape = [-1, Nh, H * W, d]
    return tf.reshape(x, target_shape)

In [61]:
flattened_q = flatten_hw(split_q, dk//Nh)

In [62]:
flattened_q.shape

TensorShape([10, 4, 900, 4])

In [63]:
flattened_k = flatten_hw(split_k, dk//Nh)

In [64]:
# Per-head dot products between every pair of the H*W flattened positions:
# [10, 4, 900, 4] x [10, 4, 900, 4]^T -> [10, 4, 900, 900].
logits = tf.matmul(flattened_q, flattened_k,transpose_b=True)

In [65]:
logits.shape

TensorShape([10, 4, 900, 900])