In [None]:
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import numpy as np
from scipy.sparse import coo_matrix

# Model Configurations

In [None]:
@dataclass
class ModelArgs:
    """Hyper-parameters shared by the modules in this notebook.

    Declared as a dataclass so individual values can be overridden at
    construction time, e.g. ``ModelArgs(n_layers=2, dim=128)``. Calling
    ``ModelArgs()`` still yields the defaults below, and attributes can
    still be reassigned after construction (``args.device = ...``).
    """

    # Hidden / embedding size; head_dim = dim // n_heads.
    dim: int = 768
    # Number of encoder blocks in the stacked encoder.
    n_layers: int = 12
    # Number of query heads.
    n_heads: int = 12
    # Size of the segment (token-type) embedding table.
    n_segment: int = 3
    # Key/value heads for grouped-query attention; None means "same as n_heads".
    n_kv_heads: Optional[int] = None
    # NOTE(review): multiple_of / ffn_dim_multiplier are not read by the modules shown
    # here; the feed-forward width is hard-coded to 4 * dim.
    multiple_of: int = 64
    ffn_dim_multiplier: Optional[float] = 2.0
    # NOTE(review): not currently passed to the nn.LayerNorm layers (they use PyTorch's default eps).
    norm_eps: float = 1e-12
    # AdamW learning rate used by main().
    learning_rate: float = 5e-4
    # Token vocabulary size (30522 is the bert-base-uncased WordPiece vocabulary size).
    vocab_size: int = 30522
    # NOTE(review): training-loop settings below are not read by any code shown here.
    eval_iters: int = 100
    # Whether the Q/K/V projections carry a bias term.
    qkv_bias: bool = False
    block_size: int = 16
    max_iters: int = 1000
    eval_interval: int = 200
    # Dropout probability used for embeddings, attention weights and residual branches.
    drop_out: float = 0.1
    # Capacity of the attention K/V cache along the batch axis.
    max_batch_size: int = 8
    # Maximum sequence length: size of the position table, attention mask and K/V cache.
    max_seq_len: int = 768
    context_length: int = 12
    num_experts: int = 4
    top_k: int = 2
    # Device on which the attention caches and mask are first allocated.
    # The default is evaluated once, when the class is defined.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'



# BERT Embeddings

In [None]:
class BERTEmbedding(nn.Module):
    """Sum of token, segment and learned absolute-position embeddings, followed by dropout.

    Args:
        args: model configuration; reads ``vocab_size``, ``n_segment``,
            ``max_seq_len``, ``dim`` and ``drop_out``.
    """

    def __init__(self,args:ModelArgs):
        super().__init__()
        self.tok_embed = nn.Embedding(args.vocab_size, args.dim)
        self.seg_embed = nn.Embedding(args.n_segment, args.dim)
        self.pos_embed = nn.Embedding(args.max_seq_len, args.dim)

        self.drop = nn.Dropout(args.drop_out)
        # Registered as a buffer so ``model.to(device)`` moves it together with the
        # embedding weights (a plain tensor attribute stays on the CPU and breaks the
        # lookup on GPU). Non-persistent keeps it out of ``state_dict``, as before.
        self.register_buffer('pos_inp', torch.arange(args.max_seq_len), persistent=False)

    def forward(self, seq, seg):
        """Embed a batch of token ids.

        Args:
            seq: integer token ids with positions along the last dim
                (presumably ``(batch, seq_len)`` — confirm with callers).
            seg: integer segment ids. Either one id per token (same shape as
                ``seq``) or one id per sample (shape ``seq.shape[:-1]``, e.g.
                ``(batch,)``), which is repeated across every position.

        Returns:
            Tensor of shape ``(*seq.shape, dim)``.

        Raises:
            ValueError: if the sequence is longer than ``max_seq_len``.
        """
        seq_len = seq.shape[-1]
        max_len = self.pos_inp.shape[0]
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {max_len}")

        if seg.shape != seq.shape:
            # One segment id per sample: broadcast it over the position axis.
            seg = seg.unsqueeze(-1).expand_as(seq)

        # Only the first seq_len positions are needed; the (seq_len, dim) result
        # broadcasts over any leading batch dims of the token embeddings.
        positions = self.pos_inp[:seq_len]
        embed_val = self.tok_embed(seq) + self.seg_embed(seg) + self.pos_embed(positions)
        return self.drop(embed_val)

In [None]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Build the unit-magnitude complex rotations used by rotary position embeddings.

    Entry ``[m, i]`` of the result is ``exp(1j * m * theta ** (-2i / head_dim))``.

    Args:
        head_dim: per-head channel count; must be even.
        seq_len: number of positions ``m = 0 .. seq_len - 1``.
        device: device on which to allocate the tensors.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(seq_len, head_dim // 2)``.
    """
    assert head_dim % 2 == 0, "Head dimension must be divisible by 2."
    even_channels = torch.arange(0, head_dim, 2, device=device).float()
    inverse_freq = 1.0 / (theta ** (even_channels / head_dim))
    positions = torch.arange(seq_len, device=device)
    angles = torch.outer(positions, inverse_freq).float()
    unit_radius = torch.ones_like(angles, device=device)
    return torch.polar(unit_radius, angles)


In [None]:
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    """Rotate each (even, odd) channel pair of ``x`` by a per-position complex phase.

    Args:
        x: tensor whose last dim has even size; assumed ``(batch, seq_len, n_heads, head_dim)``
            given how ``freqs_complex`` is broadcast below — confirm with callers.
        freqs_complex: complex phases, presumably ``(seq_len, head_dim // 2)``
            as produced by ``precompute_theta_pos_frequencies``.
        device: device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # Interpret channels (2i, 2i+1) as the real/imag parts of one complex number.
    channel_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(channel_pairs)
    # Insert singleton axes at dim 0 and dim 2 so the phases broadcast over batch and heads.
    phases = freqs_complex.unsqueeze(0).unsqueeze(2)
    rotated_pairs = torch.view_as_real(as_complex * phases)
    flattened = rotated_pairs.reshape(*x.shape).to(device)
    return flattened.type_as(x)


In [None]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis (dim 2).

    Used for grouped-query attention, where several query heads share one K/V head.
    Heads are repeated in place (h0, h0, ..., h1, h1, ...), not tiled.

    Args:
        x: 4-D tensor ``(batch, seq_len, n_kv_heads, head_dim)``; any other rank
            raises ``ValueError`` from the shape unpacking below.
        n_rep: repetition factor.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a new tensor of shape
        ``(batch, seq_len, n_kv_heads * n_rep, head_dim)``.
    """
    # Unpacking validates that x is 4-D even on the n_rep == 1 fast path.
    _batch, _seq, _kv_heads, _head_dim = x.shape
    if n_rep == 1:
        return x
    return torch.repeat_interleave(x, n_rep, dim=2)


# Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings, optional
    grouped-query attention (``n_kv_heads < n_heads``), a causal mask and a
    key/value cache of capacity ``(max_batch_size, max_seq_len)``.

    NOTE(review): the mask is causal (``triu(..., diagonal=1)``): each position
    attends only to itself and earlier positions. Standard BERT uses
    bidirectional attention — confirm this is intended for this encoder.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.head_q = args.n_heads
        self.head_kv = args.n_kv_heads if args.n_kv_heads is not None else args.n_heads
        # Number of query heads that share each key/value head.
        self.rep = self.head_q // self.head_kv
        self.head_dim = args.dim // args.n_heads

        self.W_q = nn.Linear(args.dim, self.head_dim * self.head_q, bias=args.qkv_bias)
        self.W_k = nn.Linear(args.dim, self.head_dim * self.head_kv, bias=args.qkv_bias)
        self.W_v = nn.Linear(args.dim, self.head_dim * self.head_kv, bias=args.qkv_bias)
        self.out_proj = nn.Linear(args.dim, args.dim)

        # Buffers (not plain attributes) so `.to(device)` / `.half()` move them with
        # the module. Non-persistent: they stay out of state_dict, as before.
        self.register_buffer(
            'k_cache',
            torch.zeros(args.max_batch_size, args.max_seq_len, self.head_kv, self.head_dim, device=args.device),
            persistent=False,
        )
        self.register_buffer(
            'v_cache',
            torch.zeros(args.max_batch_size, args.max_seq_len, self.head_kv, self.head_dim, device=args.device),
            persistent=False,
        )

        # 1 above the diagonal = "key is in the future of the query" -> masked out.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(args.max_seq_len, args.max_seq_len, device=args.device), diagonal=1)
        )

        self.drop_out = nn.Dropout(args.drop_out)

    def forward(self, x: torch.Tensor, start_pos: int = 0):
        """Attend over the cached prefix ``[0, start_pos)`` plus the new tokens in ``x``.

        Args:
            x: ``(batch, seq_len, dim)`` input.
            start_pos: absolute position of ``x[:, 0]``. ``0`` (the default)
                means no prefix: attention is computed within ``x`` only. A positive
                value reuses the first ``start_pos`` cached keys/values written by
                earlier calls (incremental decoding).

        Returns:
            ``(batch, seq_len, dim)`` tensor.

        Raises:
            ValueError: if ``start_pos`` is negative, ``start_pos + seq_len`` exceeds
                ``max_seq_len``, or ``batch`` exceeds ``max_batch_size``.
        """
        batch, seq_len, _ = x.shape
        end_pos = start_pos + seq_len
        max_batch, max_len = self.k_cache.shape[:2]
        if start_pos < 0 or end_pos > max_len:
            raise ValueError(f"positions [{start_pos}, {end_pos}) do not fit in max_seq_len={max_len}")
        if batch > max_batch:
            raise ValueError(f"batch size {batch} exceeds max_batch_size={max_batch}")

        # Rotary phases for the absolute positions start_pos .. end_pos - 1.
        freqs_complex = precompute_theta_pos_frequencies(self.head_dim, end_pos, device=x.device)[start_pos:]

        query = self.W_q(x).view(batch, seq_len, self.head_q, self.head_dim)
        k = self.W_k(x).view(batch, seq_len, self.head_kv, self.head_dim)
        v = self.W_v(x).view(batch, seq_len, self.head_kv, self.head_dim)

        query = apply_rotary_embeddings(query, freqs_complex, device=x.device)
        k = apply_rotary_embeddings(k, freqs_complex, device=x.device)

        # Store detached copies. The cache outlives this call, and writing graph-tracked
        # tensors into it in place chains every forward pass into one autograd graph,
        # so the next backward() walks into graphs whose buffers were already freed.
        self.k_cache[:batch, start_pos:end_pos] = k.detach()
        self.v_cache[:batch, start_pos:end_pos] = v.detach()

        # Current tokens use the live k/v (keeps gradients flowing to W_k / W_v);
        # only the earlier prefix is read back from the cache.
        key = torch.cat([self.k_cache[:batch, :start_pos].to(k.dtype), k], dim=1)
        value = torch.cat([self.v_cache[:batch, :start_pos].to(v.dtype), v], dim=1)

        key = repeat_kv(key, self.rep)
        value = repeat_kv(value, self.rep)

        query = query.transpose(1, 2)   # (batch, head_q, seq_len, head_dim)
        key = key.transpose(1, 2)       # (batch, head_q, end_pos, head_dim)
        value = value.transpose(1, 2)   # (batch, head_q, end_pos, head_dim)

        attention_scores = query @ key.transpose(2, 3)   # (batch, head_q, seq_len, end_pos)

        # Rows are the absolute query positions, columns every key position seen so far.
        mask_bool = self.mask[start_pos:end_pos, :end_pos].bool()
        attention_scores = attention_scores.masked_fill(mask_bool, float('-inf'))

        attention_weights = F.softmax(attention_scores / key.shape[-1] ** 0.5, dim=-1)
        attention_weights = self.drop_out(attention_weights)

        context_vec = attention_weights @ value
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.out_proj(context_vec)


# Encoder Block

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BertEncoderBlock(nn.Module):
    """One post-norm transformer encoder block.

    ``h = LayerNorm(x + Dropout(SelfAttention(x)))``
    ``out = LayerNorm(h + Dropout(Linear(GELU(Linear(h)))))`` with a ``4 * dim`` hidden width.
    """

    def __init__(self,args:ModelArgs):
        super().__init__()
        self.hidden_size = args.dim
        self.num_heads = args.n_heads
        self.intermediate_size = 4*args.dim

        self.self_attention = MultiHeadAttention(args)

        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention_dropout = nn.Dropout(args.drop_out)

        self.intermediate_dense = nn.Linear(args.dim, self.intermediate_size)
        self.output_dense = nn.Linear(self.intermediate_size, args.dim)

        self.activation = F.gelu

        self.ffn_norm = nn.LayerNorm(args.dim)
        self.ffn_dropout = nn.Dropout(args.drop_out)

    def forward(self, x, attention_mask=None):
        """Apply the block.

        Args:
            x: ``(batch, seq_len, dim)`` hidden states (the rank is required by the
                attention module, which unpacks three dims).
            attention_mask: not supported; must be ``None``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            NotImplementedError: if ``attention_mask`` is given, since the attention
                module only applies its own built-in mask.
        """
        if attention_mask is not None:
            raise NotImplementedError(
                "BertEncoderBlock does not support attention_mask: "
                "MultiHeadAttention only applies its built-in mask"
            )
        # The notebook's MultiHeadAttention takes (x, start_pos) and returns a single
        # tensor. The previous call used nn.MultiheadAttention's (q, k, v, attn_mask=...)
        # signature and unpacked a tuple, which raised TypeError on every forward.
        attention_output = self.self_attention(x, start_pos=0)
        attention_output = self.attention_dropout(attention_output)
        attention_output = self.attention_norm(x + attention_output)

        intermediate_output = self.activation(self.intermediate_dense(attention_output))
        ffn_output = self.output_dense(intermediate_output)
        ffn_output = self.ffn_dropout(ffn_output)
        return self.ffn_norm(attention_output + ffn_output)


# Encoder Layer (stack of encoder blocks)

In [None]:
import torch
import torch.nn as nn

class BertEncoderLayer(nn.Module):
    """The full encoder: ``args.n_layers`` encoder blocks applied in sequence,
    followed by a final LayerNorm over the last block's output."""

    def __init__(self, args:ModelArgs):
        super().__init__()
        self.num_layers = args.n_layers
        blocks = [BertEncoderBlock(args) for _ in range(self.num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.layer_norm = nn.LayerNorm(args.dim)

    def forward(self, x, attention_mask=None):
        """Run every block in order (each receives the same ``attention_mask``), then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, attention_mask)
        return self.layer_norm(hidden)


# Model

In [None]:
class BERT(nn.Module):
    """BERT-style encoder: embeddings followed by the stacked ``n_layers`` encoder."""

    def __init__(self,
                 args:ModelArgs):
        super().__init__()
        self.embedding = BERTEmbedding(args)
        self.encoder_layer = BertEncoderLayer(args)
        # NOTE(review): a stand-alone extra block that forward() no longer uses. It is
        # kept only so the parameter set / state_dict keys (and the reported parameter
        # count) stay unchanged. Remove it once existing checkpoints no longer need it.
        self.encoder_block = BertEncoderBlock(args)

    def forward(self, seq, seg):
        """Embed ``seq``/``seg`` and run the result through every encoder block.

        Args:
            seq: token ids, forwarded to ``BERTEmbedding``.
            seg: segment ids, forwarded to ``BERTEmbedding``.

        Returns:
            Final hidden states from the stacked encoder.
        """
        out = self.embedding(seq, seg)
        # Previously only the single extra block ran, so the n_layers stack in
        # encoder_layer was built (and counted) but never used.
        return self.encoder_layer(out)

In [None]:
import torch

def main(detect_anomaly: bool = False):
    """Build the model on the available device, print its parameter count
    (in millions), and create an AdamW optimizer for it.

    Args:
        detect_anomaly: enable ``torch.autograd.set_detect_anomaly``. It is a slow,
            process-global debugging aid, so it is now opt-in (it used to be always on).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = ModelArgs()
    args.device = device
    if detect_anomaly:
        torch.autograd.set_detect_anomaly(True)
    # Build from the configured `args`. Previously a fresh ModelArgs() was passed, which
    # ignored the device set above; the duplicated `.to(device)` is also gone.
    model = BERT(args).to(device)

    print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')
    # TODO: no training loop yet; the optimizer is created for the upcoming loop.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)


if __name__ == "__main__":
    main()


116.146944 M parameters
