# Expanding Transformer Architecture

In [1]:
import torch as th
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor
from numpy import cos, sin
import numpy as np
from grok.transformer import MultiHeadAttention, LayerNorm, FFN, Transformer
from typing import *
from copy import deepcopy

## Expanding Embedding Layer

In [2]:
vocab_len = 100
x = th.tensor([5,11,6, 66])
embedding_weight = th.rand((vocab_len,3))
F.embedding(x, embedding_weight)

tensor([[0.7974, 0.7866, 0.2839],
        [0.6512, 0.3038, 0.4062],
        [0.5381, 0.4823, 0.1273],
        [0.4329, 0.4239, 0.8615]])

In [3]:
embedding_weight_exp = th.zeros((vocab_len,4))
size = embedding_weight.shape
size = [slice(x) for x in size]
print(size)
embedding_weight_exp[size] = embedding_weight
F.embedding(x, embedding_weight_exp)

[slice(None, 100, None), slice(None, 3, None)]


tensor([[0.7974, 0.7866, 0.2839, 0.0000],
        [0.6512, 0.3038, 0.4062, 0.0000],
        [0.5381, 0.4823, 0.1273, 0.0000],
        [0.4329, 0.4239, 0.8615, 0.0000]])

## Expanding Embedding Layer with positional encoding 

In [4]:
def gen_position_encoding(context_len: int, d_model: int) -> th.Tensor:
        rows = [
            th.tensor(
                [
                    sin(pos / (10000 ** (i / d_model)))
                    if i % 2 == 0
                    else cos(pos / (10000 ** ((i - 1) / d_model)))
                    for i in range(d_model)
                ]
            )
            for pos in range(context_len)
        ]
        stack = th.stack(rows, dim=1)

        return stack.T  # type: ignore
gen_position_encoding(5,4)

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0100,  1.0000],
        [ 0.9093, -0.4161,  0.0200,  0.9998],
        [ 0.1411, -0.9900,  0.0300,  0.9996],
        [-0.7568, -0.6536,  0.0400,  0.9992]], dtype=torch.float64)

In [5]:
def embed(indices: th.Tensor, embedding_weight:th.Tensor, position_encoding:th.Tensor) -> th.Tensor:
        context_len = indices.shape[-1]
        pe = position_encoding[:context_len, :]  # type: ignore
        embedded = F.embedding(indices,embedding_weight)
        return pe + embedded

In [6]:
embed(x, embedding_weight, gen_position_encoding(10,3))

tensor([[ 0.7974,  1.7866,  0.2839],
        [ 1.4927,  0.8441,  0.4083],
        [ 1.4474,  0.0662,  0.1317],
        [ 0.5740, -0.5661,  0.8680]], dtype=torch.float64)

In [7]:
embed(x, embedding_weight_exp, gen_position_encoding(10,4))

tensor([[ 0.7974,  1.7866,  0.2839,  1.0000],
        [ 1.4927,  0.8441,  0.4162,  1.0000],
        [ 1.4474,  0.0662,  0.1473,  0.9998],
        [ 0.5740, -0.5661,  0.8915,  0.9996]], dtype=torch.float64)

## Expanding Head

## Adding Decoder Block

## Head Expansion logic

In [8]:
def new_emb(self, i):
    return self.embedding(i)

Transformer.embed = new_emb

In [9]:
net1 = Transformer(n_layers=1, n_heads=3, d_model=12)
net1.d_model//net1.n_heads

4

In [10]:
th.save(net1, "./checkpoints/net1.th")

In [11]:
def knowledge_transfer(net2:th.nn.Module, old_state_path:str):
    net1 = th.load(old_state_path)
    old_state = net1.state_dict()
    n_layers_old = net1.n_layers
    n_head_old = net1.n_heads

    dk_old = net1.d_model//net1.n_heads
    dk_new = net2.d_model//net2.n_heads

    new_state = net2.state_dict() 
    updated_state = deepcopy(new_state)
    for k in new_state:
        # print(k)
        if k == "position_encoding" or k == "self_attn_mask":
            continue
        elif "self_attn_norm" in k.split(".") or "ffn_norm" in k.split("."):
            continue
        elif "attn_heads" in k.split("."):
            updated_state[k] = th.zeros_like(new_state[k])
            weight_name = k.split(".")
            layer_idx = int(weight_name[2])
            if layer_idx < n_layers_old:
                head_idx = int(weight_name[5])   
                lst = [(i//dk_old, i%dk_old) for i in (head_idx*dk_new, head_idx*dk_new +dk_new)]
                w = []
                if lst[0][0] == lst[1][0]:
                    w.append(old_state[k][ lst[0][1]: lst[1][1], :])
                else:
                    for prev_head_idx in range(lst[0][0], lst[1][0]+1):
                        if not prev_head_idx < n_head_old:
                            continue
                        weight_name_old = weight_name.copy()
                        weight_name_old[5] = str(prev_head_idx)
                        k_old = ".".join(weight_name_old)

                        if prev_head_idx == lst[0][0]:
                            w_dash = old_state[k_old][lst[0][1]: , :]
                            # print(rng,w_dash.shape)
                            w.append(w_dash)

                        elif prev_head_idx == lst[1][0]:
                            w_dash = old_state[k_old][ :lst[1][1], :]
                            # print(rng, w_dash.shape)
                            w.append(w_dash)
                        else:
                            w.append(old_state[k_old])
                    if w:
                        final_old_w = th.cat(w)
                        dice = [slice(dim) for dim in final_old_w.shape]
                        updated_state[k][dice] = final_old_w
        else:
            updated_state[k] = th.zeros_like(new_state[k])
            if k in old_state:
                dice = [slice(dim) for dim in old_state[k].shape]
                updated_state[k][dice] = old_state[k]
        
    net2.load_state_dict(updated_state)

net2 = Transformer(n_layers=2, n_heads=4, d_model=16)
knowledge_transfer(net2, "./checkpoints/net1.th")

In [12]:
# for k in updated_state:
#     print(k, updated_state[k].shape)

In [13]:
x = th.tensor([11])
em1 = net1.embed(x)
em2 = net2.embed(x)
em1, em2

(tensor([[-0.3194,  0.5154,  0.7834,  0.5756,  0.2514,  0.6421, -0.3060, -2.4690,
           1.0348, -0.9342,  1.0119, -0.0656]], grad_fn=<EmbeddingBackward0>),
 tensor([[-0.3194,  0.5154,  0.7834,  0.5756,  0.2514,  0.6421, -0.3060, -2.4690,
           1.0348, -0.9342,  1.0119, -0.0656,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>))

In [14]:
net1.decoder.blocks[0].self_attn(em1,em1,em1), net2.decoder.blocks[0].self_attn(em2,em2,em2) 

((tensor([[-0.4170, -0.1111,  0.3454, -0.5814,  0.2336,  0.0348, -0.4388, -0.0494,
           -0.1477,  0.4032,  0.4214,  0.0969]], grad_fn=<MmBackward0>),
  [],
  []),
 (tensor([[-0.4170, -0.1111,  0.3454, -0.5814,  0.2336,  0.0348, -0.4388, -0.0494,
           -0.1477,  0.4032,  0.4214,  0.0969,  0.0000,  0.0000,  0.0000,  0.0000]],
         grad_fn=<MmBackward0>),
  [],
  []))

In [15]:
net1.decoder.blocks[0](em1), net2.decoder.blocks[0](em2)

((tensor([[-1.0726,  0.2880,  0.9357,  0.0589,  0.1053,  0.8309, -0.6380, -2.5448,
            0.6212, -0.1655,  1.3562,  0.2249]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []),
 (tensor([[-1.2233,  0.3458,  1.0938,  0.0832,  0.1330,  0.9774, -0.7215, -2.9246,
            0.7293, -0.1784,  1.5777,  0.2721, -0.0411, -0.0411, -0.0411, -0.0411]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []))

In [16]:
em2, net2.decoder.blocks[1](em2)

(tensor([[-0.3194,  0.5154,  0.7834,  0.5756,  0.2514,  0.6421, -0.3060, -2.4690,
           1.0348, -0.9342,  1.0119, -0.0656,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>),
 (tensor([[-0.4412,  0.5694,  0.8939,  0.6423,  0.2498,  0.7228, -0.4249, -3.0435,
            1.1983, -1.1855,  1.1705, -0.1339, -0.0545, -0.0545, -0.0545, -0.0545]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []))

In [17]:
net1(x), net2(x)

((tensor([[ 0.6222, -0.1207,  0.5258,  ...,  0.9854, -1.2260, -1.0245]],
         grad_fn=<MmBackward0>),
  [],
  []),
 (tensor([[ 0.7600, -0.0692,  0.5707,  ...,  1.2149, -1.4931, -1.2358]],
         grad_fn=<MmBackward0>),
  [],
  []))

## Layer Normalization

In [18]:
norm = th.nn.LayerNorm(16)

In [19]:
em2, norm(em2)

(tensor([[-0.3194,  0.5154,  0.7834,  0.5756,  0.2514,  0.6421, -0.3060, -2.4690,
           1.0348, -0.9342,  1.0119, -0.0656,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>),
 tensor([[-0.4412,  0.5694,  0.8939,  0.6423,  0.2498,  0.7228, -0.4249, -3.0435,
           1.1983, -1.1855,  1.1705, -0.1339, -0.0545, -0.0545, -0.0545, -0.0545]],
        grad_fn=<NativeLayerNormBackward0>))

In [20]:
net2.decoder.blocks[0].self_attn_norm.weight

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       requires_grad=True)

In [21]:
em2, F.layer_norm(em2, (16,), norm_w)

NameError: name 'norm_w' is not defined

## Extra