# Expanding Transformer Architecture

In [1]:
import os
from copy import deepcopy
from typing import *

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from numpy import cos, sin
from torch import Tensor

from grok.transformer import MultiHeadAttention, LayerNorm, FFN, Transformer

## Expanding Embedding Layer

In [2]:
# Toy setup: a 100-token vocabulary, four token ids to look up, and a random
# 3-dimensional embedding table. The last expression displays the looked-up rows.
vocab_len = 100
x = th.tensor([5, 11, 6, 66])
embedding_weight = th.rand(vocab_len, 3)
F.embedding(x, embedding_weight)

tensor([[0.1377, 0.7205, 0.4607],
        [0.1675, 0.6693, 0.7980],
        [0.7938, 0.8521, 0.8825],
        [0.0472, 0.5890, 0.8053]])

In [3]:
# Widen the embedding table from 3 to 4 columns: start from zeros and copy the
# existing weights into the leading (top-left) region, so the new column is 0.
embedding_weight_exp = th.zeros((vocab_len, 4))
# One slice(None, dim) per axis of the old table, i.e. "everything the old table covers".
size = [slice(dim) for dim in embedding_weight.shape]
print(size)
# Index with a tuple: passing a *list* of slices relies on legacy
# sequence-as-tuple indexing rather than explicit multi-axis indexing.
embedding_weight_exp[tuple(size)] = embedding_weight
F.embedding(x, embedding_weight_exp)

[slice(None, 100, None), slice(None, 3, None)]


tensor([[0.1377, 0.7205, 0.4607, 0.0000],
        [0.1675, 0.6693, 0.7980, 0.0000],
        [0.7938, 0.8521, 0.8825, 0.0000],
        [0.0472, 0.5890, 0.8053, 0.0000]])

## Expanding Embedding Layer with positional encoding 

In [4]:
def gen_position_encoding(context_len: int, d_model: int) -> th.Tensor:
    """Build a sinusoidal position-encoding table.

    Row ``pos`` holds, for each channel ``i`` in ``range(d_model)``:
    ``sin(pos / 10000 ** (i / d_model))`` when ``i`` is even, and
    ``cos(pos / 10000 ** ((i - 1) / d_model))`` when ``i`` is odd.

    Args:
        context_len: number of positions (rows) to generate.
        d_model: number of channels (columns) per position.

    Returns:
        A ``(context_len, d_model)`` tensor. The values come from NumPy's
        ``sin``/``cos``, so the result is float64 (see the demo output below).
    """

    def channel_value(pos: int, i: int):
        # Odd channels reuse the exponent of the even channel just before them.
        if i % 2 == 0:
            return sin(pos / (10000 ** (i / d_model)))
        return cos(pos / (10000 ** ((i - 1) / d_model)))

    per_position = []
    for pos in range(context_len):
        per_position.append(
            th.tensor([channel_value(pos, i) for i in range(d_model)])
        )

    # Positions are stacked as columns and then transposed back into rows.
    columns = th.stack(per_position, dim=1)
    return columns.T  # type: ignore
gen_position_encoding(5, 4)

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0100,  1.0000],
        [ 0.9093, -0.4161,  0.0200,  0.9998],
        [ 0.1411, -0.9900,  0.0300,  0.9996],
        [-0.7568, -0.6536,  0.0400,  0.9992]], dtype=torch.float64)

In [5]:
def embed(indices: th.Tensor, embedding_weight: th.Tensor, position_encoding: th.Tensor) -> th.Tensor:
    """Look up token embeddings and add the positional encoding.

    Args:
        indices: token ids; the size of the last dimension is used as the
            sequence length.
        embedding_weight: embedding table passed to ``F.embedding``.
        position_encoding: table whose first ``indices.shape[-1]`` rows are added
            to the looked-up embeddings.

    Returns:
        ``position_encoding[:seq_len] + F.embedding(indices, embedding_weight)``.
        The sum follows normal dtype promotion, so a float64 encoding added to
        float32 embeddings yields float64.
    """
    seq_len = indices.shape[-1]
    positional = position_encoding[:seq_len, :]  # type: ignore
    token_vectors = F.embedding(indices, embedding_weight)
    return positional + token_vectors

In [6]:
# Baseline: the 3-d embedding table plus a matching 3-d positional encoding.
embed(x, embedding_weight, position_encoding=gen_position_encoding(10, 3))

tensor([[ 0.1377,  1.7205,  0.4607],
        [ 1.0090,  1.2096,  0.8002],
        [ 1.7031,  0.4359,  0.8868],
        [ 0.1883, -0.4010,  0.8117]], dtype=torch.float64)

In [7]:
# Expanded table (4 columns) with a 4-d positional encoding. Channels 0 and 1 of
# the encoding do not depend on d_model, so only the later columns change.
embed(x, embedding_weight_exp, position_encoding=gen_position_encoding(10, 4))

tensor([[ 0.1377,  1.7205,  0.4607,  1.0000],
        [ 1.0090,  1.2096,  0.8080,  1.0000],
        [ 1.7031,  0.4359,  0.9025,  0.9998],
        [ 0.1883, -0.4010,  0.8353,  0.9996]], dtype=torch.float64)

## Expanding Head

## Adding Decoder Block

## Head Expansion logic

In [8]:
def new_emb(self, i):
    """Return ``self.embedding(i)`` for the token ids ``i``, with nothing added."""
    return self.embedding(i)

# Monkeypatch: every Transformer instance now uses `new_emb` as its `embed` method.
# NOTE(review): presumably this bypasses the positional encoding the library
# version adds — confirm against grok.transformer.
Transformer.embed = new_emb

In [9]:
# Small reference model. The trailing expression displays the per-head width
# (d_model // n_heads) that knowledge_transfer later uses as the old head size.
net1 = Transformer(d_model=12, n_heads=3, n_layers=1)
net1.d_model // net1.n_heads

4

In [10]:
# Persist the whole module (not just its state_dict): knowledge_transfer reads
# n_layers / n_heads / d_model off the loaded object.
# th.save does not create missing parent directories, so make sure the target
# folder exists; otherwise a fresh checkout fails here on "Run All".
os.makedirs("./checkpoints", exist_ok=True)
th.save(net1, "./checkpoints/net1.th")

In [11]:
def _copy_into_leading_corner(dst: th.Tensor, src: th.Tensor) -> None:
    """Write ``src`` into the leading (index-0) corner of ``dst``, in place."""
    # Use a tuple of slices: a *list* of slices relies on legacy
    # sequence-as-tuple indexing.
    dst[tuple(slice(dim) for dim in src.shape)] = src


def knowledge_transfer(net2: th.nn.Module, old_state_path: str):
    """Initialise ``net2`` from a smaller, previously saved model.

    The old model is loaded from ``old_state_path`` and its parameters are copied
    into the leading corner of the corresponding (larger) ``net2`` parameters;
    everything not covered by an old value is set to zero.

    Per state-dict key:

    * ``"position_encoding"``, ``"self_attn_mask"`` and any key containing a
      ``self_attn_norm`` / ``ffn_norm`` component keep ``net2``'s current values.
    * Keys containing ``attn_heads`` are treated as per-head weights. The old
      heads are viewed as one concatenated run of rows (``dk_old`` rows each);
      new head ``h`` receives rows ``[h * dk_new, (h + 1) * dk_new)`` of that
      run, restricted to old heads that exist. Layers absent from the old model
      stay zero.
    * Every other key is zeroed and, if the old model has the same key,
      overwritten in its leading corner with the old tensor.

    Assumes attention keys look like
    ``<a>.<b>.<layer>.<c>.<d>.<head>.<...>`` (layer index at position 2, head
    index at position 5 after splitting on ".") — TODO confirm against
    grok.transformer.

    Args:
        net2: the new model; modified in place via ``load_state_dict``. Must
            expose ``d_model`` and ``n_heads``.
        old_state_path: path to a file written with ``th.save(model, path)``; the
            loaded object must expose ``n_layers``, ``n_heads``, ``d_model`` and
            ``state_dict()``.
    """
    # NOTE(security): this unpickles a full module object, which can execute
    # arbitrary code — only load checkpoints you created yourself. Recent
    # PyTorch releases may additionally require ``weights_only=False`` here.
    net1 = th.load(old_state_path)
    old_state = net1.state_dict()
    n_layers_old = net1.n_layers
    n_head_old = net1.n_heads

    dk_old = net1.d_model // net1.n_heads
    dk_new = net2.d_model // net2.n_heads

    new_state = net2.state_dict()
    updated_state = deepcopy(new_state)
    for k in new_state:
        parts = k.split(".")
        if k == "position_encoding" or k == "self_attn_mask":
            continue
        if "self_attn_norm" in parts or "ffn_norm" in parts:
            continue

        # Everything else starts from zero and is then partially overwritten.
        updated_state[k] = th.zeros_like(new_state[k])

        if "attn_heads" not in parts:
            if k in old_state:
                _copy_into_leading_corner(updated_state[k], old_state[k])
            continue

        layer_idx = int(parts[2])
        if layer_idx >= n_layers_old:
            continue  # layer has no counterpart in the old model

        head_idx = int(parts[5])
        # Row range of this new head within the concatenation of all old heads.
        start = head_idx * dk_new
        stop = start + dk_new
        first_old = start // dk_old
        last_old = min((stop - 1) // dk_old, n_head_old - 1)

        pieces = []
        for prev_head_idx in range(first_old, last_old + 1):
            old_parts = parts.copy()
            old_parts[5] = str(prev_head_idx)
            k_old = ".".join(old_parts)
            # Clip the global range to the rows owned by this old head.
            offset = prev_head_idx * dk_old
            lo = max(start - offset, 0)
            hi = min(stop - offset, dk_old)
            pieces.append(old_state[k_old][lo:hi])

        # Written back for every case, including a new head that falls entirely
        # inside one old head.
        if pieces:
            _copy_into_leading_corner(updated_state[k], th.cat(pieces))

    net2.load_state_dict(updated_state)

# Larger target model (one more layer, one more head, wider d_model), initialised
# from the checkpoint of the smaller model saved earlier.
net2 = Transformer(d_model=16, n_heads=4, n_layers=2)
knowledge_transfer(net2, old_state_path="./checkpoints/net1.th")

In [12]:
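# Debug helper (disabled). `updated_state` is local to knowledge_transfer and is
# not available here; inspect the transferred weights through net2 instead:
# for k, v in net2.state_dict().items():
#     print(k, v.shape)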

In [13]:
# Single-token probe. Note this rebinds `x` (previously the 4-token example);
# the cells below use this value.
x = th.tensor([11])
em1, em2 = net1.embed(x), net2.embed(x)
em1, em2

(tensor([[-0.6082, -1.1562, -0.4722, -0.1791, -0.6773, -0.7933, -0.5872, -0.8640,
           0.3931, -0.2593, -0.6099, -1.1041]], grad_fn=<EmbeddingBackward0>),
 tensor([[-0.6082, -1.1562, -0.4722, -0.1791, -0.6773, -0.7933, -0.5872, -0.8640,
           0.3931, -0.2593, -0.6099, -1.1041,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>))

In [14]:
# Self-attention of the first block in both models, fed with each model's own
# embedding of the probe token.
(
    net1.decoder.blocks[0].self_attn(em1, em1, em1),
    net2.decoder.blocks[0].self_attn(em2, em2, em2),
)

((tensor([[-0.2708, -0.0279,  0.2348, -0.1655,  0.0838,  0.4526,  0.1671, -0.2010,
           -0.0152, -0.0036,  0.0667, -0.0103]], grad_fn=<MmBackward0>),
  [],
  []),
 (tensor([[-0.2708, -0.0279,  0.2348, -0.1655,  0.0838,  0.4526,  0.1671, -0.2010,
           -0.0152, -0.0036,  0.0667, -0.0103,  0.0000,  0.0000,  0.0000,  0.0000]],
         grad_fn=<MmBackward0>),
  [],
  []))

In [15]:
# Full first decoder block in both models (attention + feed-forward + norms).
(
    net1.decoder.blocks[0](em1),
    net2.decoder.blocks[0](em2),
)

((tensor([[-0.9175, -1.2313,  0.6247,  0.4132,  0.1563,  0.8593,  0.1814, -1.0971,
            2.1925,  0.4381, -0.1889, -1.4307]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []),
 (tensor([[-1.1534, -1.4880,  0.3027,  0.1785, -0.1695,  0.5351,  0.0983, -1.5038,
            1.8673,  0.2039, -0.5342, -1.7418,  0.8512,  0.8512,  0.8512,  0.8512]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []))

In [16]:
# Second block of net2 — it has no counterpart in net1, so knowledge_transfer
# left its non-norm weights at zero. Shown next to its input for comparison.
(
    em2,
    net2.decoder.blocks[1](em2),
)

(tensor([[-0.6082, -1.1562, -0.4722, -0.1791, -0.6773, -0.7933, -0.5872, -0.8640,
           0.3931, -0.2593, -0.6099, -1.1041,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>),
 (tensor([[-0.4085, -1.6820, -0.0927,  0.5884, -0.5691, -0.8388, -0.3598, -1.0030,
            1.9183,  0.4021, -0.4127, -1.5609,  1.0047,  1.0047,  1.0047,  1.0047]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []))

In [17]:
# End-to-end forward pass of both models on the probe token.
(
    net1(x),
    net2(x),
)

((tensor([[-0.1523, -0.0998,  0.0358,  ...,  0.0609, -0.1605,  0.0148]],
         grad_fn=<MmBackward0>),
  [],
  []),
 (tensor([[-0.3205, -0.1772, -0.0632,  ...,  0.1257, -0.1621, -0.3319]],
         grad_fn=<MmBackward0>),
  [],
  []))

## Layer Normalization

In [18]:
# Stand-alone LayerNorm over 16 features (the width of net2's embeddings),
# used below to study normalisation in isolation.
norm = nn.LayerNorm(normalized_shape=16)

In [19]:
# The zero-padded embedding next to its LayerNorm-ed version.
(
    em2,
    norm(em2),
)

(tensor([[-0.6082, -1.1562, -0.4722, -0.1791, -0.6773, -0.7933, -0.5872, -0.8640,
           0.3931, -0.2593, -0.6099, -1.1041,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>),
 tensor([[-0.4085, -1.6820, -0.0927,  0.5884, -0.5691, -0.8388, -0.3598, -1.0030,
           1.9182,  0.4021, -0.4127, -1.5609,  1.0047,  1.0047,  1.0047,  1.0047]],
        grad_fn=<NativeLayerNormBackward0>))

In [20]:
# Display the scale parameter of the first block's attention LayerNorm in net2.
# knowledge_transfer skips `self_attn_norm` keys, so this is net2's own value.
net2.decoder.blocks[0].self_attn_norm.weight

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       requires_grad=True)

In [21]:
# `norm_w` was never defined, so this cell raised NameError on a fresh run.
# NOTE(review): bound here to the LayerNorm weight displayed in the previous
# cell — replace it if a different (e.g. masked) weight vector was intended.
norm_w = net2.decoder.blocks[0].self_attn_norm.weight
em2, F.layer_norm(em2, (16,), norm_w)

## Extra