# Test tf behaviour

This notebook is to test the function of the code snippets from the research paper and make sure that I am recoding to pytorch properly

In [1]:
import tensorflow as tf

#### Resnet stem

###### Torch
self.conv1 = nn.Conv2d(3, 16, kernel_size=7, stride=2, padding=3, bias=False)

In [2]:
image_stem = tf.keras.layers.Conv2D(filters=16, 
                                    kernel_size=7,
                                    strides=2,
                                    padding='same')

#### shape_list

In [3]:
from tf_code.self_atten_utils_tf import shape_list

In [4]:
# B, H, W, Channels last
exp_tensor = tf.random.normal(
    shape = [10, 60, 60, 3]
)

In [5]:
exp_tensor.shape

TensorShape([10, 60, 60, 3])

In [6]:
type(exp_tensor)

tensorflow.python.framework.ops.EagerTensor

In [7]:
result = shape_list(exp_tensor)
result

[10, 60, 60, 3]

In [8]:
type(result)

list

#### split_heads_2d

In [9]:
from tf_code.self_atten_utils_tf import split_heads_2d

We need to preproc the data to the format expected by the split_heads_2d command first

###### explore workflow of split_heads

image -> conv -> conv -> split -> split_heads

In [10]:
preproc_layer = tf.keras.layers.Conv2D(filters=16, 
                              kernel_size=7, padding='same')

In [11]:
con_result = image_stem(exp_tensor)

In [12]:
con_result.shape

TensorShape([10, 30, 30, 16])

In [13]:
split_frame = split_heads_2d(con_result, 4)

In [14]:
type(split_frame)

tensorflow.python.framework.ops.EagerTensor

In [15]:
split_frame.shape

TensorShape([10, 4, 30, 30, 4])

In [16]:
split_frame = split_heads_2d(con_result, 2)
split_frame.shape

TensorShape([10, 2, 30, 30, 8])

so we end up with batch / channels / h / w / channels/heads and we crash if the channels/heads is not doable

#### combine_heads_2d

In [17]:
from tf_code.self_atten_utils_tf import combine_heads_2d

In [18]:
rejoin_frame = combine_heads_2d(split_frame)

In [19]:
tf.math.equal(rejoin_frame, con_result)

<tf.Tensor: id=51, shape=(10, 30, 30, 16), dtype=bool, numpy=
array([[[[ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         ...,
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True]],

        [[ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         ...,
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True]],

        [[ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         ...,
         [ True,  True

In [20]:
rejoin_frame.shape

TensorShape([10, 30, 30, 16])

### Pre-prep cells to setup next function

In [23]:
# take conv result from above
con_result.shape

TensorShape([10, 30, 30, 16])

In [24]:
_, H, W, _ = shape_list(con_result)

In [26]:
dv = 16
dk = 16
Nh = 4

dkh = dk // Nh
dvh = dv // Nh

#### rel_to_abs

In [21]:
from tf_code.self_atten_utils_tf import rel_to_abs

This function runs within `relative_logits_1d` that runs within the `relative_logits` function which takes `q` and the fixed H, W, Nh and dkh variables

Note this requires the Full Workflow section due to dkh and other static bits plus the split_q frame

In [27]:
rel_embeddings_w = tf.compat.v1.get_variable(
        'r_width', shape=(2*W - 1, dkh),
        initializer = tf.random_normal_initializer(dkh**-0.5))

In [28]:
type(rel_embeddings_w)

tensorflow.python.ops.resource_variable_ops.ResourceVariable

In [29]:
rel_embeddings_w.shape

TensorShape([59, 4])

In [30]:
split_q.shape

NameError: name 'split_q' is not defined

In [None]:
rel_logits = tf.einsum('bhxyd,md->bhxym', split_q, rel_embeddings_w)

In [None]:
rel_logits.shape

In [None]:
rel_logits_reshape = tf.reshape(
        rel_logits, [-1, Nh * H, W, 2 * W-1]
    )

In [None]:
rel_logits_reshape.shape

In [None]:
# debug from here

In [None]:
rel_logits_rel_to_abs = rel_to_abs(rel_logits)

##### Test 2 of rel to abs

In [None]:
# generate a tensor split by heads
conv_tensor = tf.random.normal(
    shape = [10, 30, 30, 16]
)

split_tensor = tf.random.normal(
    shape = [10, 4, 30, 30, 4]
)


In [None]:
rel_to_abs(conv_tensor)

##### Full Workflow

we take the input image [Batch, height, width, filters] after the stem then we make sure that dk and dv are divisible by the number of heads. 

We conv the inout batch into [batch, height, width, 2xdk+dv, 1]

##### Self Attention 2D layer

In [31]:
from tf_code.self_atten_layer_tf import relative_logits

In [32]:
# conv to split frame size
conv_split = tf.keras.layers.Conv2D(filters=2*dv+dk, 
                              kernel_size=1, padding='same')

In [33]:
con_split_res = conv_split(con_result)


In [34]:
con_split_res.shape

TensorShape([10, 30, 30, 48])

In [35]:
k, q, v = tf.split(con_split_res, [dk, dk, dv], axis=3)

In [36]:
# take the original input and split into q k v

In [37]:
k.shape

TensorShape([10, 30, 30, 16])

In [38]:
q.shape

TensorShape([10, 30, 30, 16])

In [39]:
v.shape

TensorShape([10, 30, 30, 16])

In [40]:
q *= dkh ** -0.5

##### Run Split Heads

In [41]:
split_q = split_heads_2d(q, Nh)

In [42]:
split_q.shape

TensorShape([10, 4, 30, 30, 4])

In [43]:
split_k = split_heads_2d(k, Nh)

In [44]:
flatten_hw = lambda x, d: tf.reshape(x, [-1, Nh, H*W, d])

In [45]:
flattened_q = flatten_hw(split_q, dk//Nh)

In [46]:
flattened_q.shape

TensorShape([10, 4, 900, 4])

In [47]:
flattened_k = flatten_hw(split_k, dk//Nh)

In [48]:
logits = tf.matmul(flattened_q, flattened_k,transpose_b=True)

In [49]:
logits.shape

TensorShape([10, 4, 900, 900])

##### relative_logits

In [50]:
### Test einsum
x = tf.random.normal(
    shape = [5,4], mean = dkh**-0.5
)
y = tf.random.normal(
    shape = [4,5], mean = dkh**-0.5
)

result = tf.einsum('xy,yz->xz', x,y)

In [51]:
result.shape

TensorShape([5, 5])

In [52]:
rel_embeddings_w = tf.random.normal(
    shape = [2*H-1, dkh], mean = dkh**-0.5
)

In [53]:
rel_embeddings_w.shape

TensorShape([59, 4])

In [54]:
from tf_code.self_atten_utils_tf import relative_logits_1d

In [55]:
# test einsum
print(split_q.shape)
print(rel_embeddings_w.shape)

rel_logits = tf.einsum('bhxyd,md->bhxym', split_q, rel_embeddings_w)
print(rel_logits.shape)

(10, 4, 30, 30, 4)
(59, 4)
(10, 4, 30, 30, 59)


In [56]:
rel_logits_w = relative_logits_1d(
        split_q, rel_embeddings_w, H, W, Nh, [0, 1, 2, 4, 3, 5]
    )

In [57]:
rel_logits_w.shape

TensorShape([10, 4, 900, 900])