<a href="https://colab.research.google.com/github/Debangshu93/LLama-Ensemble/blob/main/Big_Llama.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import math


In [None]:
class LoraLayer(nn.Module):
  """Low-rank (LoRA) additive update for one weight matrix.

  Learns lora_B with shape (features_in, rank) and lora_A with shape (rank, features_out).
  Their product has shape (features_in, features_out) and is scaled by alpha / rank.
  lora_B starts at zero, so a freshly built layer returns the weights unchanged.

  Args:
    features_in, features_out: dimensions of the layer whose weight is adapted.
    name: label stored on the module.
    rank: inner dimension of the low-rank factorisation.
    alpha: numerator of the update scale alpha / rank.
  """
  def __init__(self, features_in, features_out, name = "None", rank = 1, alpha = 1):
    super().__init__()

    self.lora_A = nn.Parameter(torch.zeros((rank, features_out)))
    self.lora_B = nn.Parameter(torch.zeros((features_in, rank)))
    nn.init.normal_(self.lora_A, mean = 0, std = 1)
    self.name = name

    self.scale = alpha/rank
    self.enabled = True

  def forward(self, original_weights):
    """Return original_weights plus the scaled low-rank update (or unchanged if disabled)."""
    if not self.enabled:
      return original_weights
    delta = torch.matmul(self.lora_B, self.lora_A)  # (features_in, features_out)
    if delta.shape != original_weights.shape and delta.T.shape == original_weights.shape:
      # Weight stored as (features_out, features_in), e.g. Linear.weight: transpose.
      # A .view() here would reinterpret memory and destroy the low-rank structure.
      delta = delta.T
    else:
      # Shapes already agree (including square weights); keep the previous reshape.
      delta = delta.view(original_weights.shape)
    return original_weights + delta * self.scale

In [None]:
class Ensemble(nn.Module):
  """Bank of three independent LoRA adapters applied to the same base weight.

  forward() returns the adapted copies stacked on a new axis at position 2.
  For a 2-D weight of shape (out, in) the result is (out, in, 3).

  NOTE(review): `num_ensebles` is accepted but ignored; exactly three adapters are
  always built. config['ensembles'] == 3 elsewhere in the notebook depends on that.
  """
  def __init__(self, features_in, features_out, num_ensebles=2):
    super().__init__()

    # Register adapter_1..adapter_3 as named submodules, in that order, so the
    # state_dict keys and parameter construction order are fixed.
    for adapter_name in ("adapter_1", "adapter_2", "adapter_3"):
      setattr(self, adapter_name, LoraLayer(features_in, features_out, adapter_name))

  def forward(self, original_weights):
    adapted = [
        self.adapter_1(original_weights),
        self.adapter_2(original_weights),
        self.adapter_3(original_weights),
    ]
    # Stacking on dim 2 == unsqueeze(2) on each copy followed by concatenation on dim 2.
    return torch.stack(adapted, dim = 2)

In [None]:
class Linear(nn.Module):
  """Linear layer that applies an ensemble of LoRA-adapted copies of one base weight.

  The base weight has shape (features_out, features_in), like nn.Linear.weight. On each
  forward pass `self.adapters` (an Ensemble) produces one adapted copy per adapter.
  The input is multiplied against every copy at once, and broadcasting adds an
  ensemble axis to the result.

  Args:
    features_in: size of the input's last dimension.
    features_out: size of the output's last dimension.
    bias: if True, add a learnable bias of shape (features_out,) shared by all copies.
  """
  def __init__(self, features_in, features_out, bias = True):
    super().__init__()

    # Allocate on the default device. A 'meta' tensor has no storage, so normal_ below
    # and any later state_dict copy_ into it would silently do nothing.
    self.weight = nn.Parameter(torch.zeros(features_out, features_in))
    self.is_bias = bias
    if self.is_bias :
      self.bias = nn.Parameter(torch.zeros(features_out))
      nn.init.normal_(self.bias, mean = 0, std = 1)
    nn.init.normal_(self.weight, mean = 0, std = 1)
    self.adapters =  Ensemble(features_in, features_out)

  def forward(self, x):
    """Project `x` with every ensemble copy of the weight.

    NOTE(review): in this notebook x is (batch, e_in, seq_len, features_in), where e_in
    is 1 or the ensemble count; confirm this for other callers. With the current
    Ensemble the result is (batch, n_adapters, seq_len, features_out).
    """
    # NOTE(review): stored on the module, so the last call's autograd graph stays
    # referenced until the next forward().
    self.parallel_weights = self.adapters(self.weight)
    # Reverse all axes, e.g. (out, in, n_adapters) -> (n_adapters, in, out). This is
    # what `.T` did, but `.T` on tensors that are not 2-D is deprecated.
    reversed_axes = tuple(reversed(range(self.parallel_weights.dim())))
    output = torch.matmul(x, self.parallel_weights.permute(reversed_axes))
    if self.is_bias:
      output = output + self.bias
    return output

In [None]:
# Model / test hyper-parameters shared by every cell below.
config = dict(
    vocab_size=65,      # size of the token-embedding table
    batch_size=4,
    context_window=8,   # maximum sequence length; also sizes the KV caches
    d_model=512,
    n_heads=16,         # head_dim = d_model // n_heads = 32
    ensembles=3,        # must match the three adapters built by Ensemble
    multiple_of=48,     # FFN hidden size is rounded up to a multiple of this
    n_layers=16,
)

Check that RMSNorm works on ensemble-shaped inputs

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    Statistics are taken over the last dimension only, so any leading dimensions
    (batch, ensemble, sequence, ...) pass through untouched.

    Args:
        dim: size of the normalised (last) dimension.
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor):
        # Normalise in float32, then cast back to the caller's dtype before the gain.
        normalised = self._norm(x.float()).type_as(x)
        return self.weight * normalised


layer = RMSNorm(dim=config["d_model"])  # reference RMSNorm shared by both checks below


In [None]:
batch = torch.rand((config["batch_size"], config["context_window"], config["d_model"]))  # (batch, seq, d_model)

In [None]:
# Reference output replicated along a new axis 1: (batch, 3, seq, d_model).
expected_out = torch.cat([layer(batch).unsqueeze(1)] * 3, dim = 1)

In [None]:
# Same batch replicated along a new ensemble axis: (batch, 3, seq, d_model).
adapter_batch = batch.unsqueeze(1).repeat(1, 3, 1, 1)

In [None]:
output = layer(x=adapter_batch)  # same RMSNorm, applied to the ensemble-shaped batch

In [None]:
# Tolerance-based comparison. Rounding both sides is not a tolerance test: values that
# straddle a rounding boundary compare unequal even when they differ by one ulp.
torch.allclose(expected_out, output)


tensor(True)


Rotary Positional Embeddings (RoPE)

Precomputing the rotation frequencies does not depend on the ensemble dimension, so this function requires no change.

In [None]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, theta: float = 10000.0):
    """Return the rotary-embedding factors exp(i * m * theta_k) as a complex tensor.

    Args:
        head_dim: per-head feature size; must be even because features rotate in pairs.
        seq_len: number of positions m = 0 .. seq_len - 1 to precompute.
        theta: base of the geometric frequency schedule.

    Returns:
        Complex tensor of shape (seq_len, head_dim // 2) with unit magnitude.
    """
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"
    even_dims = torch.arange(0, head_dim, 2).float()
    inv_freq = 1.0 / (theta ** (even_dims / head_dim))
    positions = torch.arange(seq_len)
    angles = torch.outer(positions, inv_freq).float()
    return torch.polar(torch.ones_like(angles), angles)

Compute the `freqs_complex` tensor used to test the rotary embeddings.

In [None]:
# Precompute twice the context window of positions, then keep only the first
# context_window rows for the fixed-length test inputs below.
freqs_complex = precompute_theta_pos_frequencies(
    config["d_model"] // config["n_heads"], config["context_window"] * 2
)[: config["context_window"]]

Test applying rotary embeddings to ensemble-shaped inputs

In [None]:
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor):
    """Rotate consecutive feature pairs of an ensemble-shaped tensor by position.

    Args:
        x: (batch, ensembles, seq_len, heads, head_dim); head_dim must be even.
        freqs_complex: presumably (seq_len, head_dim // 2), as produced by
            precompute_theta_pos_frequencies.

    Returns:
        Tensor with x's shape and dtype.
    """
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # (seq, head_dim/2) -> (1, 1, seq, 1, head_dim/2): broadcast over batch, ensemble, heads.
    rotation = freqs_complex.unsqueeze(0).unsqueeze(1).unsqueeze(3)
    rotated_pairs = torch.view_as_real(x_complex * rotation)
    return rotated_pairs.reshape(*x.shape).type_as(x)

In [None]:
def apply_vanilla_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor):
    """Rotate consecutive feature pairs of a (non-ensemble) query/key tensor by position.

    Args:
        x: (batch, seq_len, heads, head_dim); head_dim must be even.
        freqs_complex: presumably (seq_len, head_dim // 2), as produced by
            precompute_theta_pos_frequencies.

    Returns:
        Tensor with x's shape and dtype.
    """
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # (seq, head_dim/2) -> (1, seq, 1, head_dim/2): broadcast over batch and heads.
    rotation = freqs_complex.unsqueeze(0).unsqueeze(2)
    rotated_pairs = torch.view_as_real(x_complex * rotation)
    return rotated_pairs.reshape(*x.shape).type_as(x)

Build a fixed test input and its ensemble-replicated copy

In [None]:
xq = torch.rand((config['batch_size'], config['context_window'], config['n_heads'], config['d_model'] // config['n_heads']))
# Replicate along a new ensemble axis (dim 1): (batch, 3, seq, n_heads, head_dim).
xq_lora = xq.unsqueeze(1).repeat(1, 3, 1, 1, 1)

In [None]:
# Reference result replicated once per ensemble copy: (batch, 3, seq, n_heads, head_dim).
expected_out = torch.cat([apply_vanilla_rotary_embeddings(xq, freqs_complex).unsqueeze(1)] * 3, dim = 1)
output = apply_rotary_embeddings(xq_lora, freqs_complex)

In [None]:
# Tolerance-based comparison. Rounding both sides is not a tolerance test: values that
# straddle a rounding boundary compare unequal even when they differ by one ulp.
torch.allclose(expected_out, output)

tensor(True)


Key/Value Head Repetition for the KV Cache

In [None]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times (ensemble-shaped grouped-query attention).

    Args:
        x: (batch, ensembles, seq_len, n_kv_heads, head_dim).
        n_rep: copies of each head; 1 returns x itself.

    Returns:
        (batch, ensembles, seq_len, n_kv_heads * n_rep, head_dim). The copies of head h
        occupy indices h * n_rep .. h * n_rep + n_rep - 1.
    """
    b, e, s, h, d = x.shape
    if n_rep == 1:
        return x
    expanded = x.unsqueeze(4).expand(b, e, s, h, n_rep, d)
    return expanded.reshape(b, e, s, h * n_rep, d)

In [None]:
def repeat_vanilla_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times (grouped-query attention, no ensemble axis).

    Args:
        x: (batch, seq_len, n_kv_heads, head_dim).
        n_rep: copies of each head; 1 returns x itself.

    Returns:
        (batch, seq_len, n_kv_heads * n_rep, head_dim). The copies of head h occupy
        indices h * n_rep .. h * n_rep + n_rep - 1.
    """
    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # x is 4-D, so the new repeat axis goes at position 3. The previous
    # x[:, :, :, :, None, :] used five indices and raised IndexError.
    return (
        x[:, :, :, None, :]
        .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
    )

Self-Attention

In [None]:
class SelfAttention_NoLoRA(nn.Module):
    """Reference multi-head self-attention with rotary embeddings and a KV cache (no LoRA).

    Serves as the ground truth for the equivalence checks against the ensemble
    SelfAttention in this notebook.

    NOTE(review): no causal mask is applied, so every query attends to every cached
    position in [0, start_pos + seq_len).
    """
    def __init__(self, config):
        super().__init__()

        self.n_kv_heads = config['n_heads']
        self.n_heads_q = config['n_heads']
        # Query heads per key/value head; 1 here because n_kv_heads == n_heads_q.
        self.n_rep = self.n_heads_q // self.n_kv_heads
        self.head_dim = config['d_model'] // config['n_heads']

        self.wq = nn.Linear(config['d_model'], config['n_heads'] * self.head_dim, bias=False)
        self.wk = nn.Linear(config['d_model'], self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(config['d_model'], self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(config['n_heads'] * self.head_dim, config['d_model'], bias=False)

        # Non-persistent buffers: module.to(...) moves/casts the caches along with the
        # weights (plain tensor attributes would stay behind), while they stay out of
        # state_dict(), so weight copying and loading are unaffected.
        cache_shape = (config['batch_size'], config['context_window'], self.n_kv_heads, self.head_dim)
        self.register_buffer("cache_k", torch.zeros(cache_shape), persistent=False)
        self.register_buffer("cache_v", torch.zeros(cache_shape), persistent=False)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_complex: torch.Tensor
    ):
        """Attend over cached keys/values for positions [0, start_pos + seq_len).

        Args:
            x: (batch, seq_len, d_model) input.
            start_pos: cache index at which this chunk's keys/values are written;
                start_pos + seq_len must not exceed config['context_window'].
            freqs_complex: rotary factors for these positions, presumably
                (seq_len, head_dim // 2).

        Returns:
            (batch, seq_len, d_model) attention output.
        """
        batch_size, seq_len, _ = x.shape

        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)

        # Split into heads: (batch, seq_len, heads, head_dim).
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        # Rotary embeddings are applied to queries and keys only.
        xq = apply_vanilla_rotary_embeddings(xq, freqs_complex)
        xk = apply_vanilla_rotary_embeddings(xk, freqs_complex)

        # Write this chunk into the cache, then read back everything up to its end.
        self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
        self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv

        keys = self.cache_k[:batch_size, : start_pos + seq_len]
        values = self.cache_v[:batch_size, : start_pos + seq_len]

        keys = repeat_vanilla_kv(keys, self.n_rep)
        values = repeat_vanilla_kv(values, self.n_rep)

        # Heads before sequence, (batch, heads, seq, head_dim), so matmul runs per head.
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / np.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        output = torch.matmul(scores, values)
        # Merge heads back: (batch, seq_len, heads * head_dim).
        output = (output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))
        return self.wo(output)


layer = SelfAttention_NoLoRA(config)

In [None]:
class SelfAttention(nn.Module):
    """Ensemble multi-head self-attention: SelfAttention_NoLoRA with an extra ensemble axis.

    Every projection is an ensemble `Linear`, so queries, keys and values (and the KV
    cache) carry an ensemble axis right after the batch axis.

    NOTE(review): like the reference implementation, no causal mask is applied.
    """
    def __init__(self, config):
        super().__init__()

        self.n_kv_heads = config['n_heads']
        self.n_heads_q = config['n_heads']
        # Query heads per key/value head; 1 here because n_kv_heads == n_heads_q.
        self.n_rep = self.n_heads_q // self.n_kv_heads
        self.head_dim = config['d_model'] // config['n_heads']
        self.ensembles = config['ensembles']

        self.wq = Linear(config['d_model'], config['n_heads'] * self.head_dim, bias=False)
        self.wk = Linear(config['d_model'], self.n_kv_heads * self.head_dim, bias=False)
        self.wv = Linear(config['d_model'], self.n_kv_heads * self.head_dim, bias=False)
        self.wo = Linear(config['n_heads'] * self.head_dim, config['d_model'], bias=False)

        # Non-persistent buffers: module.to(...) moves/casts the caches along with the
        # weights (plain tensor attributes would stay behind), while they stay out of
        # state_dict(), so the notebook's weight-copy loops see the same keys.
        cache_shape = (config['batch_size'], self.ensembles, config['context_window'], self.n_kv_heads, self.head_dim)
        self.register_buffer("cache_k", torch.zeros(cache_shape), persistent=False)
        self.register_buffer("cache_v", torch.zeros(cache_shape), persistent=False)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_complex: torch.Tensor
    ):
        """Attend over cached keys/values for positions [0, start_pos + seq_len), per ensemble.

        Args:
            x: (batch, e_in, seq_len, d_model). The notebook passes e_in == 1 to the
                first block; the projections are expected to broadcast it to
                self.ensembles copies (the view() calls below require that).
            start_pos: cache index at which this chunk's keys/values are written.
            freqs_complex: rotary factors for these positions, presumably
                (seq_len, head_dim // 2).

        Returns:
            (batch, self.ensembles, seq_len, d_model) attention output.
        """
        batch_size, _, seq_len, _ = x.shape

        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)

        # Split into heads: (batch, ensembles, seq_len, heads, head_dim).
        xq = xq.view(batch_size, self.ensembles, seq_len, self.n_heads_q, self.head_dim)
        xk = xk.view(batch_size, self.ensembles, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(batch_size, self.ensembles, seq_len, self.n_kv_heads, self.head_dim)
        # Rotary embeddings are applied to queries and keys only.
        xq = apply_rotary_embeddings(xq, freqs_complex)
        xk = apply_rotary_embeddings(xk, freqs_complex)

        # Write this chunk into the cache, then read back everything up to its end.
        self.cache_k[:batch_size, : self.ensembles, start_pos : start_pos + seq_len] = xk
        self.cache_v[:batch_size, : self.ensembles, start_pos : start_pos + seq_len] = xv

        keys = self.cache_k[:batch_size, : self.ensembles, : start_pos + seq_len]
        values = self.cache_v[:batch_size, : self.ensembles, : start_pos + seq_len]

        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # Heads before sequence, (batch, ensembles, heads, seq, head_dim), so matmul
        # runs per ensemble and head.
        xq = xq.transpose(2, 3)
        keys = keys.transpose(2, 3)
        values = values.transpose(2, 3)
        scores = torch.matmul(xq, keys.transpose(3, 4)) / np.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        output = torch.matmul(scores, values)
        # Merge heads back: (batch, ensembles, seq_len, heads * head_dim).
        output = (output.transpose(2, 3).contiguous().view(batch_size, self.ensembles, seq_len, -1))
        return self.wo(output)


layer_lora = SelfAttention(config)

In [None]:
# Copy every non-LoRA tensor from the reference layer so both start from identical
# weights. copy_ writes through because state_dict() tensors share storage with the
# parameters. Build each state_dict once instead of twice per key.
lora_state = layer_lora.state_dict()
reference_state = layer.state_dict()
for key, tensor in lora_state.items():
  if "lora" not in key:
    tensor.copy_(reference_state[key])

In [None]:
batch = torch.rand((config['batch_size'], config['context_window'], config['d_model']))  # (batch, seq, d_model)

In [None]:
expected_output = layer(batch, 0, freqs_complex)
# One copy of the reference output per ensemble member: (batch, ensembles, seq, d_model).
expected_out = torch.cat([expected_output.unsqueeze(1)] * config['ensembles'], dim = 1)

output_lora = layer_lora(batch.unsqueeze(1), 0, freqs_complex)
# Tolerance-based comparison. The two layers use different matmul paths, so results can
# differ in the last float32 bits. Rounding both sides (previously to 2 decimals) both
# hides real differences and can flag values that straddle a rounding boundary.
torch.allclose(expected_out, output_lora, rtol = 1e-4, atol = 1e-5)

tensor(True)


Feed Forward Block

In [None]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward block built from ensemble Linear layers: w2(silu(w1(x)) * w3(x)).

    Hidden size is 2/3 of 4 * d_model, optionally scaled by config['ffn_dim_multiplier']
    and then rounded UP to a multiple of config['multiple_of'].
    """
    def __init__(self, config):
        super().__init__()

        hidden_dim = 4 * config["d_model"]
        hidden_dim = int(2 * hidden_dim / 3)
        # Optional Llama-style scaling; a config without the key leaves the size unchanged.
        ffn_dim_multiplier = config.get("ffn_dim_multiplier")
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round hidden_dim up to the next multiple of multiple_of.
        multiple_of = config["multiple_of"]
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = Linear(config["d_model"], hidden_dim, bias=False)
        self.w2 = Linear(hidden_dim, config["d_model"], bias=False)
        self.w3 = Linear(config["d_model"], hidden_dim, bias=False)

    def forward(self, x: torch.Tensor):
        # SwiGLU: SiLU-gated branch times the linear branch, then project back.
        swish = F.silu(self.w1(x))
        x_V = self.w3(x)
        x = swish * x_V
        x = self.w2(x)
        return x

layer_lora = FeedForward(config)

In [None]:
class FeedForward_NoLoRA(nn.Module):
    """Reference SwiGLU feed-forward block (plain nn.Linear): w2(silu(w1(x)) * w3(x)).

    Hidden size is 2/3 of 4 * d_model, optionally scaled by config['ffn_dim_multiplier']
    and then rounded UP to a multiple of config['multiple_of'].
    """
    def __init__(self, config):
        super().__init__()

        hidden_dim = 4 * config["d_model"]
        hidden_dim = int(2 * hidden_dim / 3)
        # Optional Llama-style scaling; a config without the key leaves the size unchanged.
        ffn_dim_multiplier = config.get("ffn_dim_multiplier")
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round hidden_dim up to the next multiple of multiple_of.
        multiple_of = config["multiple_of"]
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(config["d_model"], hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config["d_model"], bias=False)
        self.w3 = nn.Linear(config["d_model"], hidden_dim, bias=False)

    def forward(self, x: torch.Tensor):
        # SwiGLU: SiLU-gated branch times the linear branch, then project back.
        swish = F.silu(self.w1(x))
        x_V = self.w3(x)
        x = swish * x_V
        x = self.w2(x)
        return x

layer = FeedForward_NoLoRA(config=config)  # reference feed-forward block for the check below

In [None]:
# Copy every non-LoRA tensor from the reference layer so both start from identical
# weights. copy_ writes through because state_dict() tensors share storage with the
# parameters. Build each state_dict once instead of twice per key.
lora_state = layer_lora.state_dict()
reference_state = layer.state_dict()
for key, tensor in lora_state.items():
  if "lora" not in key:
    tensor.copy_(reference_state[key])

In [None]:
batch = torch.rand((config['batch_size'], config['context_window'], config['d_model']))  # (batch, seq, d_model)

In [None]:
expected_output = layer(batch)
# One copy of the reference output per ensemble member: (batch, ensembles, seq, d_model).
expected_out = torch.cat([expected_output.unsqueeze(1)] * config['ensembles'], dim = 1)

output_lora = layer_lora(batch.unsqueeze(1))
# Tolerance-based comparison. Rounding both sides to 6 decimals can report a mismatch
# for values that straddle a rounding boundary even when they differ by one ulp.
torch.allclose(expected_out, output_lora, rtol = 1e-4, atol = 1e-5)

tensor(False)


Encoder Block

In [None]:
class EncoderBlock(nn.Module):
    """One ensemble transformer block with pre-norm residual connections.

    h = x + attention(attention_norm(x));  out = h + feed_forward(ffn_norm(h)).
    With a size-1 ensemble axis in x, the first residual sum broadcasts it to the
    attention output's ensemble count.
    """

    def __init__(self, config):
        super().__init__()

        self.n_heads = config["n_heads"]
        self.dim = config["d_model"]
        self.head_dim = config["d_model"] // config["n_heads"]

        self.attention = SelfAttention(config)
        self.feed_forward = FeedForward(config)

        self.attention_norm = RMSNorm(config["d_model"])
        self.ffn_norm = RMSNorm(config["d_model"])

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        # Call the submodules rather than .forward() so registered hooks still run.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

block_lora = EncoderBlock(config)

In [None]:
class EncoderBlock_NoLoRA(nn.Module):
    """Reference transformer block with pre-norm residual connections (no LoRA).

    h = x + attention(attention_norm(x));  out = h + feed_forward(ffn_norm(h)).
    """

    def __init__(self, config):
        super().__init__()

        self.n_heads = config["n_heads"]
        self.dim = config["d_model"]
        self.head_dim = config["d_model"] // config["n_heads"]

        self.attention = SelfAttention_NoLoRA(config)
        self.feed_forward = FeedForward_NoLoRA(config)

        self.attention_norm = RMSNorm(config["d_model"])
        self.ffn_norm = RMSNorm(config["d_model"])

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        # Call the submodules rather than .forward() so registered hooks still run.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

block = EncoderBlock_NoLoRA(config)

In [None]:
# Copy every non-LoRA tensor from the reference block so both start from identical
# weights. copy_ writes through because state_dict() tensors share storage with the
# parameters. Build each state_dict once instead of twice per key.
lora_state = block_lora.state_dict()
reference_state = block.state_dict()
for key, tensor in lora_state.items():
  if "lora" not in key:
    tensor.copy_(reference_state[key])

In [None]:
batch = torch.rand((config['batch_size'], config['context_window'], config['d_model']))  # (batch, seq, d_model)

In [None]:
expected_output = block(batch, 0, freqs_complex)
# One copy of the reference output per ensemble member: (batch, ensembles, seq, d_model).
expected_out = torch.cat([expected_output.unsqueeze(1)] * config['ensembles'], dim = 1)

output_lora = block_lora(batch.unsqueeze(1), 0, freqs_complex)
# Tolerance-based comparison instead of equality after rounding, which can flag values
# that straddle a rounding boundary even when they differ by one ulp.
torch.allclose(expected_out, output_lora, rtol = 1e-4, atol = 1e-5)

tensor(True)


Transformer Model

In [None]:
class Transformer(nn.Module):
    """Ensemble Llama-style stack: token embeddings -> EncoderBlocks -> RMSNorm -> Linear head.

    forward() takes tokens of shape (batch, _, seq_len); the notebook passes
    batch.unsqueeze(1), i.e. a size-1 middle axis. The ensemble Linear layers give the
    logits an ensemble axis, so the output is expected to be
    (batch, ensembles, seq_len, vocab_size).
    """

    def __init__(self, config):
        super().__init__()
        self.vocab_size = config['vocab_size']
        self.n_layers = config['n_layers']
        self.tok_embeddings = nn.Embedding(config['vocab_size'], config['d_model'])

        self.layers = nn.ModuleList(EncoderBlock(config) for _ in range(config['n_layers']))

        self.norm = RMSNorm(config['d_model'])
        self.output = Linear(config['d_model'], config['vocab_size'], bias=False)

        # Rotary factors for up to 2 * context_window positions.
        self.freqs_complex = precompute_theta_pos_frequencies(config['d_model'] // config['n_heads'], config['context_window'] * 2)

    def forward(self, tokens: torch.Tensor, start_pos: int):
        _, _, seq_len = tokens.shape
        h = self.tok_embeddings(tokens)
        # freqs_complex is a plain tensor attribute, not a registered buffer, so
        # module.to(device) does not move it. Move the slice to the activations' device.
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(h.device)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        output = self.output(h)
        return output

model_lora = Transformer(config)

In [None]:
class Transformer_NoLoRA(nn.Module):
    """Reference Llama-style stack: token embeddings -> EncoderBlocks -> RMSNorm -> nn.Linear head.

    forward() takes tokens of shape (batch, seq_len) and returns logits of shape
    (batch, seq_len, vocab_size).
    """

    def __init__(self, config):
        super().__init__()
        self.vocab_size = config['vocab_size']
        self.n_layers = config['n_layers']
        self.tok_embeddings = nn.Embedding(config['vocab_size'], config['d_model'])

        self.layers = nn.ModuleList(EncoderBlock_NoLoRA(config) for _ in range(config['n_layers']))

        self.norm = RMSNorm(config['d_model'])
        self.output = nn.Linear(config['d_model'], config['vocab_size'], bias=False)

        # Rotary factors for up to 2 * context_window positions.
        self.freqs_complex = precompute_theta_pos_frequencies(config['d_model'] // config['n_heads'], config['context_window'] * 2)

    def forward(self, tokens: torch.Tensor, start_pos: int):
        _, seq_len = tokens.shape
        h = self.tok_embeddings(tokens)
        # freqs_complex is a plain tensor attribute, not a registered buffer, so
        # module.to(device) does not move it. Move the slice to the activations' device.
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(h.device)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        output = self.output(h)
        return output

model = Transformer_NoLoRA(config)

In [None]:
# Copy every non-LoRA tensor from the reference model so both start from identical
# weights. copy_ writes through because state_dict() tensors share storage with the
# parameters. Build each state_dict once instead of twice per key; with 16 layers the
# old loop rebuilt the full dict for every single key.
lora_state = model_lora.state_dict()
reference_state = model.state_dict()
for key, tensor in lora_state.items():
  if "lora" not in key:
    tensor.copy_(reference_state[key])

In [None]:
# Random token ids in [0, vocab_size), tied to config instead of a hard-coded 65.
batch = torch.randint(0, config['vocab_size'], (config['batch_size'], config['context_window']))

In [None]:
expected_output = model(batch, 0)
# One copy of the reference logits per ensemble member: (batch, ensembles, seq, vocab).
expected_out = torch.cat([expected_output.unsqueeze(1)] * config['ensembles'], dim = 1)

output_lora = model_lora(batch.unsqueeze(1), 0)


In [None]:
output_lora.nelement()  # total number of logits across batch, ensembles, positions and vocab

6240

In [None]:
# Number of elements that still differ after rounding both sides to 4 decimals. This is
# one vectorised comparison instead of a Python loop (and autograd op) per element.
count = int((torch.round(expected_out, decimals = 4) != torch.round(output_lora, decimals = 4)).sum())

In [None]:
count  # elements of output_lora that disagree with expected_out after 4-decimal rounding

12