<a href="https://colab.research.google.com/github/Debangshu93/LLama-Ensemble/blob/main/Loading_Model_Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import math

Helper Functions

In [None]:
# Hyper-parameters shared by every module below (values match Llama-2-7B plus the
# LoRA-ensemble settings used in this notebook).
config = dict(
    vocab_size=32000,      # rows of the token-embedding table / width of the output projection
    batch_size=2,          # batch dimension the per-layer KV caches are allocated for
    context_window=1024,   # number of positions each KV cache holds
    d_model=4096,          # hidden size; head_dim = d_model // n_heads
    n_heads=32,            # attention heads per layer
    ensembles=3,           # ensemble members per layer (must agree with the adapter count in Ensemble)
    multiple_of=256,       # FFN hidden size is rounded up to a multiple of this
    n_layers=32,           # number of EncoderBlocks
    device='cuda',         # device for the KV caches and the RoPE table
)

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Build the rotary-embedding (RoPE) rotation table.

    Returns a complex64 tensor of shape ``(seq_len, head_dim // 2)`` whose entry
    ``[m, i]`` is the unit complex number ``exp(1j * m * theta ** (-2 * i / head_dim))``.

    Args:
        head_dim: per-head feature size; must be even.
        seq_len: number of positions to tabulate.
        device: device on which the table is built.
        theta: RoPE base.
    """
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"
    exponents = torch.arange(0, head_dim, 2).float() / head_dim  # 2i / head_dim
    inv_freq = 1.0 / (theta ** exponents).to(device)              # (head_dim / 2,)
    positions = torch.arange(seq_len, device=device)              # 0 .. seq_len - 1
    angles = torch.outer(positions, inv_freq).float()             # (seq_len, head_dim / 2)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    """Rotate consecutive feature pairs of ``x`` by the RoPE angles in ``freqs_complex``.

    Args:
        x: as called from SelfAttention, a 5-D ``(batch, ensembles, seq_len, n_heads, head_dim)``
            tensor with even ``head_dim``.
        freqs_complex: complex ``(seq_len, head_dim // 2)`` rotation table.
        device: device the result is moved to.

    Returns:
        A tensor with ``x``'s shape and dtype; the rotation itself is computed in float32.
    """
    # Pair up the last axis: (..., head_dim) -> complex (..., head_dim / 2).
    pairs = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # (seq_len, head_dim/2) -> (1, 1, seq_len, 1, head_dim/2): broadcast over batch, ensembles, heads.
    rotation = freqs_complex[None, None, :, None]
    rotated = torch.view_as_real(pairs * rotation).reshape(*x.shape)
    return rotated.type_as(x).to(device)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    ``x`` must be 5-D ``(batch, ensembles, seq_len, n_kv_heads, head_dim)``. The result is
    ``(batch, ensembles, seq_len, n_kv_heads * n_rep, head_dim)`` with the copies of each head
    placed next to each other. When ``n_rep == 1`` the input tensor itself is returned.
    """
    b, e, s, kv_heads, dim = x.shape  # unpacking also rejects non-5-D input, even for n_rep == 1
    if n_rep == 1:
        return x
    widened = x.unsqueeze(4).expand(b, e, s, kv_heads, n_rep, dim)
    return widened.reshape(b, e, s, kv_heads * n_rep, dim)

LoRA Llama-2

In [None]:
class LoraLayer(nn.Module):
  """Low-rank (LoRA) additive update for one weight matrix.

  Holds ``lora_B`` of shape ``(features_in, rank)`` and ``lora_A`` of shape
  ``(rank, features_out)``. For a weight ``W`` of shape ``(features_out, features_in)``,
  ``forward(W)`` returns ``W + (lora_B @ lora_A).T * alpha / rank``. It returns ``W``
  itself when ``enabled`` is False. ``lora_B`` starts at zero, so a freshly built layer
  leaves ``W`` numerically unchanged.
  """
  def __init__(self, features_in, features_out, name = "None", rank = 1, alpha = 1):
    super().__init__()

    self.lora_A = nn.Parameter(torch.zeros((rank, features_out)))
    self.lora_B = nn.Parameter(torch.zeros((features_in, rank)))
    nn.init.normal_(self.lora_A, mean = 0, std = 1)
    self.name = name

    self.scale = alpha/rank
    self.enabled = True

  def forward(self, original_weights):
    if not self.enabled:
      return original_weights
    # lora_B @ lora_A is (features_in, features_out); transpose it into the weight's
    # (features_out, features_in) layout. A .view() would only reinterpret the memory,
    # which for non-square weights scrambles the update and destroys its low rank.
    delta = torch.matmul(self.lora_B, self.lora_A).T
    return original_weights + delta * self.scale

class Ensemble(nn.Module):
  """Apply several independent LoRA adapters to one shared weight matrix.

  ``forward(W)`` returns the adapted copies stacked on a new axis 2. A 2-D weight of
  shape ``(out, in)`` therefore becomes ``(out, in, num_ensebles)``.

  The adapters are registered as ``adapter_1`` ... ``adapter_<num_ensebles>`` so that
  state-dict keys follow that naming. The parameter keeps its historical spelling
  (``num_ensebles``) for caller compatibility.
  """
  def __init__(self, features_in, features_out, num_ensebles=3):
    super().__init__()
    if num_ensebles < 1:
      raise ValueError(f"num_ensebles must be >= 1, got {num_ensebles}")
    self.num_ensembles = num_ensebles
    for i in range(1, num_ensebles + 1):
      name = f"adapter_{i}"
      self.add_module(name, LoraLayer(features_in, features_out, name))

  def forward(self, original_weights):
    adapted = [getattr(self, f"adapter_{i}")(original_weights) for i in range(1, self.num_ensembles + 1)]
    # Equivalent to unsqueeze(2) on each copy followed by concatenation along dim 2.
    return torch.stack(adapted, dim=2)


class Linear(nn.Module):
  """Linear layer whose weight is expanded into a LoRA ensemble at every call.

  The base ``weight`` ``(features_out, features_in)`` and optional ``bias``
  ``(features_out,)`` are created on the ``meta`` device, i.e. without storage. Real
  tensors must be assigned (e.g. ``load_state_dict(..., assign=True)``) before use.
  ``adapters`` produces one adapted weight per ensemble member, and the input is
  multiplied by each of them. The single bias is shared by all members.
  """
  def __init__(self, features_in, features_out, bias = True):
    super().__init__()

    self.weight = nn.Parameter(torch.zeros((features_out, features_in), device='meta'))
    self.is_bias = bias
    if self.is_bias:
      self.bias = nn.Parameter(torch.zeros(features_out, device='meta'))
      nn.init.normal_(self.bias, mean=0, std=1)
    nn.init.normal_(self.weight, mean=0, std=1)
    self.adapters = Ensemble(features_in, features_out)

  def forward(self, x):
    """Project ``x`` through every ensemble member.

    In this model ``x`` is ``(batch, e, seq_len, features_in)``; ``e`` must be 1 or the
    ensemble count N so it broadcasts. Returns ``(batch, N, seq_len, features_out)``.
    """
    # adapters(weight) is (out, in, N); transpose(0, 2) gives (N, in, out), one W^T per
    # member, which matmul broadcasts against x's leading dims. Kept as a local (not an
    # attribute) so the N full-size weight copies are freed as soon as this call returns.
    weights = self.adapters(self.weight).transpose(0, 2)
    out = torch.matmul(x, weights)
    if self.is_bias:
      out = out + self.bias
    return out


class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-feature gain.

    Computes ``weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``. The normalisation runs in
    float32 and is cast back to ``x``'s dtype before the gain is applied. ``weight`` is
    created on the ``meta`` device (no storage), so real values must be assigned before use.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, device='meta'))

    def _norm(self, x: torch.Tensor):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor):
        normalised = self._norm(x.float()).type_as(x)
        return self.weight * normalised


class SelfAttention(nn.Module):
    """Multi-head self-attention with RoPE and a per-ensemble key/value cache.

    All four projections are LoRA-ensemble ``Linear`` layers. ``forward`` takes a 4-D
    ``x`` of shape ``(batch, e, seq_len, d_model)``. The projected activations are viewed as
    ``config['ensembles']`` members, so ``e`` must be 1 or ``config['ensembles']``. The output
    is ``(batch, config['ensembles'], seq_len, d_model)``.

    ``config['n_kv_heads']`` is optional. When absent, every query head has its own
    key/value head (``n_rep == 1``). The caches are plain tensors, not registered buffers,
    so ``module.to(...)`` does not move them; they are created directly on ``config['device']``.
    """

    def __init__(self, config):
        super().__init__()

        self.n_kv_heads = config.get('n_kv_heads', config['n_heads'])
        self.n_heads_q = config['n_heads']
        if self.n_heads_q % self.n_kv_heads != 0:
            raise ValueError(f"n_heads ({self.n_heads_q}) must be a multiple of n_kv_heads ({self.n_kv_heads})")
        self.n_rep = self.n_heads_q // self.n_kv_heads
        self.head_dim = config['d_model'] // config['n_heads']
        self.ensembles = config['ensembles']
        self.device = config['device']

        self.wq = Linear(config['d_model'], config['n_heads'] * self.head_dim, bias=False)
        self.wk = Linear(config['d_model'], self.n_kv_heads * self.head_dim, bias=False)
        self.wv = Linear(config['d_model'], self.n_kv_heads * self.head_dim, bias=False)
        self.wo = Linear(config['n_heads'] * self.head_dim, config['d_model'], bias=False)

        # (batch, ensembles, context_window, n_kv_heads, head_dim), allocated straight on
        # the target device rather than on the CPU and then copied.
        cache_shape = (config['batch_size'], self.ensembles, config['context_window'], self.n_kv_heads, self.head_dim)
        self.cache_k = torch.zeros(cache_shape, device=self.device)
        self.cache_v = torch.zeros(cache_shape, device=self.device)

    @staticmethod
    def _causal_mask(seq_len: int, start_pos: int, device, dtype) -> torch.Tensor:
        """Return an additive ``(seq_len, start_pos + seq_len)`` causal mask.

        Entry ``(i, j)`` is ``-inf`` when key position ``j`` lies after the query's absolute
        position ``start_pos + i``, and ``0`` otherwise.
        """
        mask = torch.full((seq_len, start_pos + seq_len), float("-inf"), device=device, dtype=dtype)
        return torch.triu(mask, diagonal=start_pos + 1)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        batch_size, _, seq_len, _ = x.shape

        # Project and split the feature axis into heads: (B, E, S, heads, head_dim).
        xq = self.wq(x).view(batch_size, self.ensembles, seq_len, self.n_heads_q, self.head_dim)
        xk = self.wk(x).view(batch_size, self.ensembles, seq_len, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(batch_size, self.ensembles, seq_len, self.n_kv_heads, self.head_dim)
        xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)
        xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)

        # Store this step's keys/values, then read back every position up to the current one.
        end_pos = start_pos + seq_len
        self.cache_k[:batch_size, :self.ensembles, start_pos:end_pos] = xk
        self.cache_v[:batch_size, :self.ensembles, start_pos:end_pos] = xv
        keys = repeat_kv(self.cache_k[:batch_size, :self.ensembles, :end_pos], self.n_rep)
        values = repeat_kv(self.cache_v[:batch_size, :self.ensembles, :end_pos], self.n_rep)

        # Move heads in front of the sequence axis: (B, E, heads, S, head_dim).
        xq = xq.transpose(2, 3)
        keys = keys.transpose(2, 3)
        values = values.transpose(2, 3)

        scores = torch.matmul(xq, keys.transpose(3, 4)) / math.sqrt(self.head_dim)
        scores = scores.float()
        if seq_len > 1:
            # A multi-token prompt must not attend to positions after each query
            # (for a single token every cached position is in the past).
            scores = scores + self._causal_mask(seq_len, start_pos, scores.device, scores.dtype)
        scores = F.softmax(scores, dim=-1).type_as(xq)

        output = torch.matmul(scores, values)
        output = output.transpose(2, 3).contiguous().view(batch_size, self.ensembles, seq_len, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    """Gated feed-forward sub-layer: ``w2(silu(w1(x)) * w3(x))``.

    Each projection is a LoRA-ensemble ``Linear``. The hidden size comes from
    ``_hidden_dim``.
    """

    def __init__(self, config):
        super().__init__()

        hidden_dim = self._hidden_dim(config)
        self.w1 = Linear(config["d_model"], hidden_dim, bias=False)
        self.w2 = Linear(hidden_dim, config["d_model"], bias=False)
        self.w3 = Linear(config["d_model"], hidden_dim, bias=False)

    @staticmethod
    def _hidden_dim(config) -> int:
        """Return the feed-forward hidden size for ``config``.

        The computation has three steps:
        1. Start from ``4 * d_model`` and scale by 2/3.
        2. Optionally scale by ``config['ffn_dim_multiplier']``; a missing or None value
           means no extra scaling.
        3. Round up to the nearest multiple of ``config['multiple_of']``.

        ``d_model=4096, multiple_of=256`` gives 11008.
        """
        hidden_dim = int(2 * (4 * config["d_model"]) / 3)
        multiplier = config.get("ffn_dim_multiplier")
        if multiplier is not None:
            hidden_dim = int(multiplier * hidden_dim)
        multiple_of = config["multiple_of"]
        return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

    def forward(self, x: torch.Tensor):
        swish = F.silu(self.w1(x))
        x_V = self.w3(x)
        return self.w2(swish * x_V)

class EncoderBlock(nn.Module):
    """One pre-norm transformer layer.

    The layer runs two residual sub-steps:
    ``h = x + attention(attention_norm(x))``, then
    ``out = h + feed_forward(ffn_norm(h))``.
    """

    def __init__(self, config):
        super().__init__()

        self.n_heads = config["n_heads"]
        self.dim = config["d_model"]
        self.head_dim = config["d_model"] // config["n_heads"]

        self.attention = SelfAttention(config)
        self.feed_forward = FeedForward(config)

        self.attention_norm = RMSNorm(config["d_model"])
        self.ffn_norm = RMSNorm(config["d_model"])

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        # Invoke the submodules (not their .forward) so registered nn.Module hooks run.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        return h + self.feed_forward(self.ffn_norm(h))


class Transformer(nn.Module):
    """LoRA-ensemble Llama-style decoder.

    The pipeline is: token embedding -> ``n_layers`` EncoderBlocks -> RMSNorm -> output
    projection. ``forward`` expects 3-D token ids ``(batch, 1, seq_len)`` and a
    ``start_pos`` giving the position of the first token in the KV caches.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config['d_model']
        self.vocab_size = config['vocab_size']
        self.n_layers = config['n_layers']
        self.tok_embeddings = nn.Embedding(self.vocab_size, d_model, device='meta')
        self.layers = nn.ModuleList([EncoderBlock(config) for _ in range(self.n_layers)])
        self.norm = RMSNorm(d_model)
        self.output = Linear(d_model, self.vocab_size, bias=False)

        # RoPE table covering twice the context window, built on config['device'].
        head_dim = d_model // config['n_heads']
        self.freqs_complex = precompute_theta_pos_frequencies(head_dim, 2 * config['context_window'], device=config['device'])

    def forward(self, tokens: torch.Tensor, start_pos: int):
        _, _, seq_len = tokens.shape  # also enforces 3-D input
        hidden = self.tok_embeddings(tokens)
        freqs = self.freqs_complex[start_pos:start_pos + seq_len]
        for block in self.layers:
            hidden = block(hidden, start_pos, freqs)
        return self.output(self.norm(hidden))


Load the checkpoint weights into CPU memory (13 GB)

In [None]:
from pathlib import Path

# Directory holding the Llama-2 checkpoint shard(s) (*.pth).
checkpoints_dir = 'llama-2-7b'
checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
if not checkpoints:
    # Fail with a clear message instead of an IndexError on checkpoints[0].
    raise FileNotFoundError(f"No *.pth checkpoint found in {checkpoints_dir!r}")
# Only the first shard in sorted order is loaded.
ckpt_path = checkpoints[0]
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code while loading.
checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=True)

Initialize the Model (3 GB)

In [None]:
# Make bfloat16 (a 16-bit floating-point format) the default dtype *before* building the
# model. Every tensor created without an explicit dtype is then bfloat16: the KV caches,
# the LoRA adapter parameters and the meta-device weight placeholders. The RoPE table is
# explicitly computed in float32 and is not affected.
torch.set_default_dtype(torch.bfloat16)
model = Transformer(config)

Load the weights into the model and move it to the GPU (16 GB)

In [None]:
# strict=False because the checkpoint has no LoRA adapter parameters (and may carry
# extra entries). assign=True replaces the meta-device placeholders with the loaded tensors
# instead of copying into them.
load_result = model.load_state_dict(checkpoint, strict = False, assign = True)
# A base weight missing from the checkpoint would silently remain a storage-less meta
# tensor and only fail later with an opaque error, so check for it explicitly.
missing_base_weights = [k for k in load_result.missing_keys if ".adapters." not in k]
if missing_base_weights:
    raise RuntimeError(f"Checkpoint is missing base-model weights: {missing_base_weights[:10]}")
model = model.to(config['device'])

Run a test inference pass (64 GB)

In [None]:
# Random token ids shaped (batch, 1, seq_len=1), the 3-D layout Transformer.forward unpacks.
batch = torch.randint(0, 1000, (config['batch_size'], 1, 1)).to(config['device'])
with torch.no_grad():
    logits = model(batch, 0)
# Evaluated outside the `with` block so the notebook actually displays it.
# Expected: (batch_size, ensembles, 1, vocab_size), one slice per ensemble member.
logits.shape