# Expanding Transformer Architecture

In [22]:
import torch as th
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor
from numpy import cos, sin
import numpy as np
from grok.transformer import MultiHeadAttention, LayerNorm, FFN, Transformer
from typing import *
from copy import deepcopy

## Expanding Embedding Layer

In [23]:
vocab_len = 100
x = th.tensor([5,11,6, 66])
embedding_weight = th.rand((vocab_len,3))
F.embedding(x, embedding_weight)

tensor([[0.6228, 0.1856, 0.3963],
        [0.3907, 0.9098, 0.9340],
        [0.2448, 0.8799, 0.7597],
        [0.2656, 0.7177, 0.3176]])

In [24]:
embedding_weight_exp = th.zeros((vocab_len,4))
size = embedding_weight.shape
size = [slice(x) for x in size]
print(size)
embedding_weight_exp[size] = embedding_weight
F.embedding(x, embedding_weight_exp)

[slice(None, 100, None), slice(None, 3, None)]


tensor([[0.6228, 0.1856, 0.3963, 0.0000],
        [0.3907, 0.9098, 0.9340, 0.0000],
        [0.2448, 0.8799, 0.7597, 0.0000],
        [0.2656, 0.7177, 0.3176, 0.0000]])

## Expanding Embedding Layer with positional encoding 

In [25]:
def gen_position_encoding(context_len: int, d_model: int) -> th.Tensor:
        rows = [
            th.tensor(
                [
                    sin(pos / (10000 ** (i / d_model)))
                    if i % 2 == 0
                    else cos(pos / (10000 ** ((i - 1) / d_model)))
                    for i in range(d_model)
                ]
            )
            for pos in range(context_len)
        ]
        stack = th.stack(rows, dim=1)

        return stack.T  # type: ignore
gen_position_encoding(5,4)

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0100,  1.0000],
        [ 0.9093, -0.4161,  0.0200,  0.9998],
        [ 0.1411, -0.9900,  0.0300,  0.9996],
        [-0.7568, -0.6536,  0.0400,  0.9992]], dtype=torch.float64)

In [26]:
def embed(indices: th.Tensor, embedding_weight:th.Tensor, position_encoding:th.Tensor) -> th.Tensor:
        context_len = indices.shape[-1]
        pe = position_encoding[:context_len, :]  # type: ignore
        embedded = F.embedding(indices,embedding_weight)
        return pe + embedded

In [27]:
embed(x, embedding_weight, gen_position_encoding(10,3))

tensor([[ 0.6228,  1.1856,  0.3963],
        [ 1.2321,  1.4501,  0.9361],
        [ 1.1541,  0.4637,  0.7640],
        [ 0.4067, -0.2723,  0.3240]], dtype=torch.float64)

In [28]:
embed(x, embedding_weight_exp, gen_position_encoding(10,4))

tensor([[ 0.6228,  1.1856,  0.3963,  1.0000],
        [ 1.2321,  1.4501,  0.9440,  1.0000],
        [ 1.1541,  0.4637,  0.7797,  0.9998],
        [ 0.4067, -0.2723,  0.3476,  0.9996]], dtype=torch.float64)

## Expanding Head

## Adding Decoder Block

## Head Expansion logic

In [29]:
def new_emb(self, i):
    return self.embedding(i)

Transformer.embed = new_emb

In [30]:
net1 = Transformer(n_layers=1, n_heads=3, d_model=12)
net1.d_model//net1.n_heads

4

In [31]:
th.save(net1, "./checkpoints/net1.th")

In [32]:
def knowledge_transfer(net2:th.nn.Module, old_state_path:str):
    net1 = th.load(old_state_path)
    old_state = net1.state_dict()
    n_layers_old = net1.n_layers
    n_head_old = net1.n_heads

    dk_old = net1.d_model//net1.n_heads
    dk_new = net2.d_model//net2.n_heads

    new_state = net2.state_dict() 
    updated_state = deepcopy(new_state)
    for k in new_state:
        # print(k)
        if k == "position_encoding" or k == "self_attn_mask":
            continue
        elif "self_attn_norm" in k.split(".") or "ffn_norm" in k.split("."):
            continue
        elif "attn_heads" in k.split("."):
            updated_state[k] = th.zeros_like(new_state[k])
            weight_name = k.split(".")
            layer_idx = int(weight_name[2])
            if layer_idx < n_layers_old:
                head_idx = int(weight_name[5])   
                lst = [(i//dk_old, i%dk_old) for i in (head_idx*dk_new, head_idx*dk_new +dk_new)]
                w = []
                if lst[0][0] == lst[1][0]:
                    w.append(old_state[k][ lst[0][1]: lst[1][1], :])
                else:
                    for prev_head_idx in range(lst[0][0], lst[1][0]+1):
                        if not prev_head_idx < n_head_old:
                            continue
                        weight_name_old = weight_name.copy()
                        weight_name_old[5] = str(prev_head_idx)
                        k_old = ".".join(weight_name_old)

                        if prev_head_idx == lst[0][0]:
                            w_dash = old_state[k_old][lst[0][1]: , :]
                            # print(rng,w_dash.shape)
                            w.append(w_dash)

                        elif prev_head_idx == lst[1][0]:
                            w_dash = old_state[k_old][ :lst[1][1], :]
                            # print(rng, w_dash.shape)
                            w.append(w_dash)
                        else:
                            w.append(old_state[k_old])
                    if w:
                        final_old_w = th.cat(w)
                        dice = [slice(dim) for dim in final_old_w.shape]
                        updated_state[k][dice] = final_old_w
        else:
            updated_state[k] = th.zeros_like(new_state[k])
            if k in old_state:
                dice = [slice(dim) for dim in old_state[k].shape]
                updated_state[k][dice] = old_state[k]
        
    net2.load_state_dict(updated_state)

net2 = Transformer(n_layers=2, n_heads=4, d_model=16)
knowledge_transfer(net2, "./checkpoints/net1.th")

In [33]:
# for k in updated_state:
#     print(k, updated_state[k].shape)

In [34]:
x = th.tensor([11])
em1 = net1.embed(x)
em2 = net2.embed(x)
em1, em2

(tensor([[-0.9194, -0.1203, -0.2237, -0.1964, -0.1887, -0.7580, -1.2569,  0.3097,
           0.6622, -0.3703,  1.7225,  2.8508]], grad_fn=<EmbeddingBackward0>),
 tensor([[-0.9194, -0.1203, -0.2237, -0.1964, -0.1887, -0.7580, -1.2569,  0.3097,
           0.6622, -0.3703,  1.7225,  2.8508,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>))

In [35]:
net1.decoder.blocks[0].self_attn(em1,em1,em1), net2.decoder.blocks[0].self_attn(em2,em2,em2) 

((tensor([[ 0.1791,  0.4886, -0.2609,  0.1333, -0.2004,  0.1003, -0.2532, -0.0619,
            0.2898, -0.3489, -0.2533, -0.1746]], grad_fn=<MmBackward0>),
  [],
  []),
 (tensor([[ 0.1791,  0.4886, -0.2609,  0.1333, -0.2004,  0.1003, -0.2532, -0.0619,
            0.2898, -0.3489, -0.2533, -0.1746,  0.0000,  0.0000,  0.0000,  0.0000]],
         grad_fn=<MmBackward0>),
  [],
  []))

In [36]:
net1.decoder.blocks[0](em1), net2.decoder.blocks[0](em2)

((tensor([[-0.8877,  0.0699, -0.2731, -0.2229, -0.6534, -0.8711, -1.1307,  0.1385,
            0.6710, -0.6157,  1.3199,  2.4552]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []),
 (tensor([[-0.9711,  0.1167, -0.2735, -0.2102, -0.7058, -0.9796, -1.2583,  0.1975,
            0.8252, -0.6716,  1.5654,  2.8635, -0.1246, -0.1246, -0.1246, -0.1246]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []))

In [37]:
em2, net2.decoder.blocks[1](em2)

(tensor([[-0.9194, -0.1203, -0.2237, -0.1964, -0.1887, -0.7580, -1.2569,  0.3097,
           0.6622, -0.3703,  1.7225,  2.8508,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>),
 (tensor([[-1.0555, -0.2236, -0.3312, -0.3028, -0.2948, -0.8875, -1.4070,  0.2241,
            0.5910, -0.4839,  1.6949,  2.8697, -0.0984, -0.0984, -0.0984, -0.0984]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []))

In [38]:
net1(x), net2(x)

((tensor([[ 0.9414, -0.4361,  0.2419,  ...,  0.1841, -0.8693,  0.3868]],
         grad_fn=<MmBackward0>),
  [],
  []),
 (tensor([[ 1.1203, -0.5008,  0.3135,  ...,  0.1870, -1.0280,  0.5455]],
         grad_fn=<MmBackward0>),
  [],
  []))

## Layer Normalization

In [39]:
norm = th.nn.LayerNorm(16)

In [40]:
em2, norm(em2)

(tensor([[-0.9194, -0.1203, -0.2237, -0.1964, -0.1887, -0.7580, -1.2569,  0.3097,
           0.6622, -0.3703,  1.7225,  2.8508,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>),
 tensor([[-1.0555, -0.2236, -0.3312, -0.3028, -0.2948, -0.8875, -1.4070,  0.2241,
           0.5910, -0.4839,  1.6949,  2.8697, -0.0984, -0.0984, -0.0984, -0.0984]],
        grad_fn=<NativeLayerNormBackward0>))

In [41]:
net2.decoder.blocks[0].self_attn_norm.weight

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       requires_grad=True)

In [42]:
em2, F.layer_norm(em2, (16,), norm_w)

NameError: name 'norm_w' is not defined

## Extra

In [43]:
chkpnt = th.load("../log/checkpoints/128_4_2.pt")

In [44]:
chkpnt

{'epoch': 16,
 'global_step': 32,
 'pytorch-lightning_version': '1.4.2',
 'state_dict': OrderedDict([('transformer.position_encoding',
               tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
                         0.0000e+00,  1.0000e+00],
                       [ 8.4147e-01,  5.4030e-01,  7.6172e-01,  ...,  1.0000e+00,
                         1.1548e-04,  1.0000e+00],
                       [ 9.0930e-01, -4.1615e-01,  9.8705e-01,  ...,  1.0000e+00,
                         2.3096e-04,  1.0000e+00],
                       ...,
                       [ 1.2357e-01, -9.9234e-01,  1.3992e-01,  ...,  9.9998e-01,
                         5.4274e-03,  9.9999e-01],
                       [-7.6825e-01, -6.4014e-01, -6.6357e-01,  ...,  9.9998e-01,
                         5.5429e-03,  9.9998e-01],
                       [-9.5375e-01,  3.0059e-01, -9.9978e-01,  ...,  9.9998e-01,
                         5.6584e-03,  9.9998e-01]])),
              ('transformer.self_att

In [45]:
from argparse import Namespace

In [48]:
d = {"a":11}
Namespace(**d)

Namespace(a=11)