In [74]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.data as data
import math
from typing import Sequence
import numpy as np

# Basic Training

In [80]:
class XORDataset(data.Dataset):
    """Synthetic, noisy XOR classification data.

    Every sample starts as one of the four corners of the unit square; its
    label is 1.0 when exactly one coordinate of that corner is 1 and 0.0
    otherwise. Gaussian noise is then added to the coordinates.

    Args:
        size: Number of samples to generate.
        std: Standard deviation of the Gaussian noise added to the inputs.
    """

    def __init__(self, size, std = 0.1):
        self.size, self.std = size, std
        self.data = self.label = None
        self.generate_data_points()

    def generate_data_points(self):
        """Draw a fresh set of samples into ``self.data`` and ``self.label``."""
        corners = torch.randint(0, 2, (self.size, 2), dtype = torch.float32)
        # Labels are derived from the clean corners, before the noise is applied.
        self.label = (corners.sum(dim = 1) == 1).to(torch.float32)
        self.data = corners + self.std * torch.randn_like(corners)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at ``idx``."""
        return self.data[idx], self.label[idx]

    def __len__(self):
        """Return the number of samples."""
        return self.size

In [98]:
class CustomDataloader:
    """Minimal batch loader over a map-style dataset.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        batch_size: Number of samples per batch (the last batch may be smaller).
        shuffle: If true, a new random sample order is drawn on every ``iter()``.
        collate_fn: Callable that merges a list of samples into a batch;
            falls back to ``default_collate`` when omitted.
    """

    def __init__(self, dataset, batch_size, shuffle = True, collate_fn = None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.collate_fn = collate_fn or default_collate
        self.dataset_len = len(dataset)
        self.num_batches = math.ceil(self.dataset_len / batch_size)
        self.batches = None

    def __iter__(self):
        """Rebuild the per-batch index tensors and return an iterator over them."""
        order = torch.randperm(self.dataset_len) if self.shuffle else torch.arange(self.dataset_len)
        self.batches = [
            order[start : start + self.batch_size]
            for start in range(0, self.dataset_len, self.batch_size)
        ]
        return _DataIterator(self)

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.num_batches

class _DataIterator:
    def __init__(self, loader):
        self.loader = loader
        self.current_batch = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.current_batch == len(self.loader.batches):
            raise StopIteration
        batch_indices = self.loader.batches[self.current_batch]
        batch_data = [self.loader.dataset[i] for i in batch_indices]
        batch = self.loader.collate_fn(batch_data)
        self.current_batch += 1
        return batch

def default_collate(batch_data):
    """Merge a list of samples into a batch.

    Args:
        batch_data: Non-empty sequence of samples that all share one structure.

    Returns:
        * tensors are stacked along a new leading dimension;
        * Python numbers (``bool``/``int``/``float``) become a 1-D tensor;
        * strings/bytes are returned as a plain list;
        * other sequences (e.g. ``(input, label)`` tuples) are collated
          field by field, giving a list with one batch per field.

    Raises:
        TypeError: If the element type is not supported. The previous
            behaviour was to return ``None`` silently.
    """
    elem = batch_data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(batch_data, 0)
    if isinstance(elem, (bool, int, float)):
        return torch.tensor(batch_data)
    # Strings are Sequences too; recursing into them would never terminate.
    if isinstance(elem, (str, bytes)):
        return list(batch_data)
    if isinstance(elem, Sequence):
        return [default_collate(field) for field in zip(*batch_data)]
    raise TypeError(f"default_collate: unsupported element type {type(elem).__name__}")

In [100]:
class LinearModel(nn.Module):
    """Two-layer perceptron: linear -> activation -> linear.

    Args:
        num_inputs: Size of each input sample.
        num_hidden: Width of the hidden layer.
        num_classes: Number of output units (raw logits).
        act_fn: Activation module applied to the hidden layer.
    """

    def __init__(self, num_inputs, num_hidden, num_classes, act_fn = nn.Tanh()):
        super().__init__()
        self.linear1 = nn.Linear(num_inputs, num_hidden)
        self.linear2 = nn.Linear(num_hidden, num_classes)
        self.act_fn = act_fn

    def forward(self, x):
        hidden = self.act_fn(self.linear1(x))
        return self.linear2(hidden)

In [102]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.mps.manual_seed(seed)

In [104]:
# Train the small MLP on the noisy XOR data and print the training accuracy (%) after every epoch.
epochs = 100
set_seed(42)
train_dataset = XORDataset(2500)
train_data_loader = CustomDataloader(train_dataset, batch_size = 128, shuffle = True)
model = LinearModel(2,4,1)
# is_available must be *called*: the bare function object is always truthy, which
# selected 'mps' unconditionally, even on machines without an MPS backend.
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
print(device)
loss_fn = nn.BCEWithLogitsLoss() # one class
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1)
model.train()
for epoch in range(epochs):
    sum_correct = 0
    sum_data_points = 0
    for data_input, data_label in train_data_loader:
        data_input = data_input.to(device)
        data_label = data_label.to(device)
        # The model emits one logit per sample; flatten (N, 1) -> (N,) to match the labels.
        pred = model(data_input).flatten()
        loss = loss_fn(pred, data_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        sum_data_points += len(data_input)
        # A logit > 0 corresponds to a predicted probability > 0.5.
        sum_correct += ((pred > 0) == data_label).sum()
    print(np.round((100*sum_correct/sum_data_points).cpu().item(), 2))

mps
50.32
48.44
44.04
48.4
49.52
47.92
47.8
49.04
50.8
50.32
50.32
50.96
53.76
51.08
51.6
50.84
52.04
55.28
50.76
52.08
50.76
50.76
50.8
50.88
51.04
51.0
53.68
56.08
64.08
75.52
89.12
97.04
99.36
99.84
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0
100.0


# Activations

In [112]:
class Activation(nn.Module):
    """Base class for the hand-written activations; records the class name in ``self.name``."""
    def __init__(self):
        super().__init__()
        self.name = self.__class__.__name__

class Identity(Activation):
    """f(x) = x."""
    def forward(self, x):
        return x

class ReLU(Activation):
    """f(x) = max(0, x)."""
    def forward(self, x):
        return x * (x > 0)

class Tanh(Activation):
    """f(x) = (e^x - e^-x) / (e^x + e^-x)."""
    def forward(self, x):
        # Rewritten in terms of e^{-2|x|} so that large |x| cannot overflow to inf/inf = nan:
        # tanh(x) = sign(x) * (1 - e^{-2|x|}) / (1 + e^{-2|x|}).
        shrink = torch.expm1(-2 * torch.abs(x))  # e^{-2|x|} - 1, always in (-1, 0]
        return torch.sign(x) * (-shrink) / (shrink + 2)

class Sigmoid(Activation):
    """f(x) = 1 / (1 + e^-x)."""
    def forward(self, x):
        return 1/(1 + torch.exp(-x))

class SiLU(Activation):
    """f(x) = x * sigmoid(x)."""
    def forward(self, x):
        return x/(1 + torch.exp(-x))

class SoftPlus(Activation):
    """f(x) = log(1 + e^x)."""
    def forward(self, x):
        # Fixed: the original called the undefined name `nptorchexp`.
        # log(1 + e^x) = max(x, 0) + log(1 + e^{-|x|}), which does not overflow for large x.
        return torch.clamp(x, min = 0) + torch.log1p(torch.exp(-torch.abs(x)))

class ELU(Activation):
    """f(x) = x for x >= 0, alpha * (e^x - 1) otherwise.

    Args:
        alpha: Scale of the negative branch.
    """
    def __init__(self, alpha = 1):
        super().__init__()
        self.alpha = alpha
    def forward(self, x):
        # Fixed: the original referenced a bare `alpha` (NameError) instead of self.alpha.
        # Clamping keeps the branch that torch.where discards finite for large positive x.
        negative_branch = self.alpha * torch.expm1(torch.clamp(x, max = 0))
        return torch.where(x < 0, negative_branch, x)

# Initializations

In [None]:
class LinearModel(nn.Module):
    """Two-layer perceptron with explicit weight initialisation.

    Args:
        num_inputs: Size of each input sample.
        num_hidden: Width of the hidden layer.
        num_classes: Number of output units (raw logits).
        act_fn: Activation module applied to the hidden layer.
        init: ``"xavier"`` (default) draws weights from N(0, 2 / (fan_in + fan_out));
            ``"kaiming"`` draws them from N(0, 2 / fan_in). Biases are always zero.

    Raises:
        ValueError: If ``init`` is not one of the two supported schemes.
    """

    def __init__(self, num_inputs, num_hidden, num_classes, act_fn = nn.Tanh(), init = "xavier"):
        super().__init__()
        if init not in ("xavier", "kaiming"):
            raise ValueError(f"init must be 'xavier' or 'kaiming', got {init!r}")
        self.linear1 = nn.Linear(num_inputs, num_hidden)
        self.linear2 = nn.Linear(num_hidden, num_classes)
        self.act_fn = act_fn
        self.init = init
        self._init_weights()

    def _init_weights(self):
        """Zero the biases and draw the linear weights from the selected normal distribution.

        The previous version initialised every weight twice: a Kaiming draw that
        was immediately overwritten by a Xavier draw (both commented "kaiming").
        The Xavier result was the one that took effect, so it stays the default.
        """
        # Only nn.Linear layers are touched, so an activation with its own
        # parameters (e.g. nn.PReLU) cannot break the fan-in/fan-out lookup.
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            fan_out, fan_in = module.weight.shape
            if self.init == "kaiming":
                std = math.sqrt(2 / fan_in)
            else:
                std = math.sqrt(2 / (fan_in + fan_out))
            nn.init.normal_(module.weight, std = std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        x = self.linear1(x)
        x = self.act_fn(x)
        x = self.linear2(x)
        return x

# Optimizations

In [120]:
class OptimizerTemplate:
    """Base class for the hand-written optimizers.

    Subclasses implement ``update_param(p)``, which is called once per
    parameter that has a gradient, with autograd disabled.

    Args:
        params: Iterable of parameters to optimise (e.g. ``model.parameters()``).
        lr: Learning rate.
    """

    def __init__(self, params, lr):
        # Materialise the iterable: model.parameters() is a generator, and storing it
        # directly left nothing to iterate after the first zero_grad()/step() call.
        self.params = list(params)
        self.lr = lr

    def zero_grad(self):
        """Reset the gradient of every parameter to zero (in place)."""
        for p in self.params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

    @torch.no_grad()
    def step(self):
        """Apply ``update_param`` to every parameter that currently has a gradient."""
        for p in self.params:
            if p.grad is not None:
                self.update_param(p)

    def update_param(self, p):
        """Update a single parameter in place; must be provided by subclasses."""
        raise NotImplementedError

In [124]:
class SGD(OptimizerTemplate):
    """Plain stochastic gradient descent: ``p <- p - lr * grad``."""
    def update_param(self, p):
        p_update = -self.lr*p.grad
        # Fixed: tensors have no `_add` method; the in-place add is `add_`.
        p.add_(p_update)

In [126]:
class SGDM(OptimizerTemplate):
    """SGD with momentum: ``m <- grad + momentum * m``, ``p <- p - lr * m``.

    Args:
        params: Iterable of parameters to optimise.
        lr: Learning rate.
        momentum: Decay factor applied to the running momentum buffer.
    """
    def __init__(self, params, lr, momentum = 0.0):
        # Materialise first: the iterable is needed both by the base class and for the state dict.
        params = list(params)
        # Fixed: the base constructor requires (params, lr); it was called with no arguments.
        super().__init__(params, lr)
        self.beta1 = momentum
        self.p_to_momentum = {p : torch.zeros_like(p.data) for p in params}

    def update_param(self, p):
        self.p_to_momentum[p] = p.grad + self.beta1*self.p_to_momentum[p]
        p_update = -self.lr*self.p_to_momentum[p]
        # Fixed: `_add` does not exist; use the in-place `add_`.
        p.add_(p_update)

In [None]:
class AdaGrad(OptimizerTemplate):
    """AdaGrad: scales each step by the root of the accumulated squared gradients.

    Args:
        params: Iterable of parameters to optimise.
        lr: Learning rate.
        epsilon: Small constant added under the square root for numerical stability.
    """
    def __init__(self, params, lr, epsilon = 1e-8):
        params = list(params)
        # Fixed: the base constructor requires (params, lr).
        super().__init__(params, lr)
        self.epsilon = epsilon
        self.p_to_sq_grad = {p : torch.zeros_like(p.data) for p in params}

    # Fixed: the hook invoked by OptimizerTemplate.step() is `update_param`; the
    # original `update_params` spelling was never called.
    def update_param(self, p):
        self.p_to_sq_grad[p].add_(p.grad**2)
        p_update = -self.lr*p.grad/torch.sqrt(self.p_to_sq_grad[p] + self.epsilon)
        # Fixed: `_add` does not exist; use the in-place `add_`.
        p.add_(p_update)

In [None]:
class RMSProp(OptimizerTemplate):
    """RMSProp: scales each step by the root of a running average of squared gradients.

    ``v <- beta2 * v + (1 - beta2) * grad^2``, ``p <- p - lr * grad / sqrt(v + epsilon)``.

    Args:
        params: Iterable of parameters to optimise.
        lr: Learning rate.
        epsilon: Small constant added under the square root for numerical stability.
        beta2: Decay factor of the running average of squared gradients.
    """
    def __init__(self, params, lr, epsilon = 1e-8, beta2 = 0.999):
        params = list(params)
        # Fixed: the base constructor requires (params, lr).
        super().__init__(params, lr)
        self.epsilon = epsilon
        self.beta2 = beta2
        self.p_to_sq_grad = {p : torch.zeros_like(p.data) for p in params}

    # Fixed: the hook invoked by OptimizerTemplate.step() is `update_param`.
    # NOTE(review): the previous body ignored `lr` and scaled the step by a running
    # average of squared updates, which is the AdaDelta rule. It is replaced here
    # by the RMSProp rule that the class name promises.
    def update_param(self, p):
        self.p_to_sq_grad[p] = self.beta2*self.p_to_sq_grad[p] + (1-self.beta2)*(p.grad**2)
        p_update = -self.lr*p.grad/torch.sqrt(self.p_to_sq_grad[p] + self.epsilon)
        # Fixed: `_add` does not exist; use the in-place `add_`.
        p.add_(p_update)

In [None]:
class AdaDelta(OptimizerTemplate):
    """AdaDelta: step size derived from running averages of squared gradients and squared updates.

    ``v <- beta2 * v + (1 - beta2) * grad^2``
    ``delta = -sqrt(u + epsilon) / sqrt(v + epsilon) * grad``
    ``u <- beta2 * u + (1 - beta2) * delta^2``
    ``p <- p + lr * delta``

    Args:
        params: Iterable of parameters to optimise.
        lr: Multiplier applied to the AdaDelta update; 1.0 gives the classic rule.
        epsilon: Small constant added under the square roots for numerical stability.
        beta2: Decay factor of both running averages.
    """
    def __init__(self, params, lr, epsilon = 1e-8, beta2 = 0.999):
        params = list(params)
        # Fixed: the base constructor requires (params, lr).
        super().__init__(params, lr)
        self.epsilon = epsilon
        self.beta2 = beta2
        self.p_to_delta = {p : torch.zeros_like(p.data) for p in params}
        self.p_to_sq_grad = {p : torch.zeros_like(p.data) for p in params}

    # Fixed: the hook invoked by OptimizerTemplate.step() is `update_param`.
    # NOTE(review): the previous body was the RMSProp rule (no running average of
    # updates). It is replaced here by the AdaDelta rule that the class name promises.
    def update_param(self, p):
        self.p_to_sq_grad[p] = self.beta2*self.p_to_sq_grad[p] + (1-self.beta2)*(p.grad**2)
        ada_lr = torch.sqrt(self.p_to_delta[p] + self.epsilon)
        p_update = -ada_lr*p.grad/torch.sqrt(self.p_to_sq_grad[p] + self.epsilon)
        # The running average tracks the unscaled update; lr only scales what is applied.
        self.p_to_delta[p] = self.beta2*self.p_to_delta[p] + (1-self.beta2)*(p_update**2)
        p.add_(self.lr*p_update)

In [None]:
class Adam(OptimizerTemplate):
    """Adam: bias-corrected running averages of the gradient and its square.

    Args:
        params: Iterable of parameters to optimise.
        lr: Learning rate.
        epsilon: Small constant added to the denominator for numerical stability.
        beta1: Decay factor of the first-moment (gradient) average.
        beta2: Decay factor of the second-moment (squared gradient) average.
    """
    def __init__(self, params, lr, epsilon = 1e-8, beta1 = 0.9, beta2 = 0.999):
        params = list(params)
        # Fixed: the base constructor requires (params, lr).
        super().__init__(params, lr)
        self.epsilon = epsilon
        self.beta1 = beta1
        self.beta2 = beta2
        self.p_to_num_updates = {p : 0 for p in params}
        self.p_to_mom = {p : torch.zeros_like(p.data) for p in params}
        self.p_to_sq_grad = {p : torch.zeros_like(p.data) for p in params}

    # Fixed: the hook invoked by OptimizerTemplate.step() is `update_param`.
    def update_param(self, p):
        self.p_to_mom[p] = self.beta1*self.p_to_mom[p] + (1-self.beta1)*(p.grad)
        self.p_to_sq_grad[p] = self.beta2*self.p_to_sq_grad[p] + (1-self.beta2)*(p.grad**2)
        self.p_to_num_updates[p] += 1
        # Bias correction: both averages start at zero and are biased towards it early on.
        mom = self.p_to_mom[p] / (1 - self.beta1**self.p_to_num_updates[p])
        sq_grad = self.p_to_sq_grad[p] / (1 - self.beta2**self.p_to_num_updates[p])
        p_update = -self.lr*mom/(torch.sqrt(sq_grad) + self.epsilon)
        # Fixed: `_add` does not exist; use the in-place `add_`.
        p.add_(p_update)

# U-Net

In [None]:
class UNet(nn.Module):
    """Small convolutional U-Net with one conv + one strided-conv downsample per level.

    Args:
        input_channels: Number of channels of the input image.
        output_channels: Number of channels of the output image.
        base_channels: Channel count that the multipliers are applied to.
        channel_multipliers: Per-level channel multipliers, from the highest
            resolution level to the lowest.

    Raises:
        ValueError: If ``channel_multipliers`` is empty.
    """

    def __init__(
        self,
        input_channels = 3,
        output_channels = 3,
        base_channels = 64,
        channel_multipliers = (1,2,4)
    ):
        super().__init__()
        channels_list = [m*base_channels for m in channel_multipliers]
        if not channels_list:
            raise ValueError("channel_multipliers must contain at least one level")
        input_layers = []
        output_layers = []
        for i in range(len(channels_list)):
            input_channel_num = input_channels if i == 0 else channels_list[i-1]
            # Fixed: the nn.Conv2d keyword is `kernel_size`, not `kernel`.
            input_layers.append(nn.Conv2d(input_channel_num, channels_list[i], kernel_size = 3, padding = 1))
            # Downsample
            input_layers.append(nn.Conv2d(channels_list[i], channels_list[i], kernel_size = 3, padding = 1, stride = 2))
        # Fixed: the builtin is `reversed`; `reverse` is undefined.
        for i in reversed(range(len(channels_list))):
            output_channel_num = output_channels if i == 0 else channels_list[i-1]
            # Input is the upsampled features concatenated with the same-level skip: 2 * channels.
            output_layers.append(nn.Conv2d(channels_list[i]*2, output_channel_num, kernel_size = 3, padding = 1))
        self.input_blocks = nn.Sequential(*input_layers)
        # Fixed: a 1x1 convolution with padding=1 grows the feature map by 2 pixels per side.
        self.middle_blocks = nn.Conv2d(channels_list[-1], channels_list[-1], kernel_size = 1)
        self.output_blocks = nn.Sequential(*output_layers)

    def forward(self, x):
        """Map ``(N, input_channels, H, W)`` to ``(N, output_channels, H, W)``."""
        skips = []
        for idx, block in enumerate(self.input_blocks):
            x = block(x)
            # input_blocks alternates [conv, downsample, conv, downsample, ...]. Only the
            # conv (pre-downsample) outputs match the decoder's resolution and channel count.
            if idx % 2 == 0:
                skips.append(x)
        x = self.middle_blocks(x)
        for block in self.output_blocks:
            skip = skips.pop()
            # Resize to the skip's exact size so odd input sizes also line up.
            x = F.interpolate(x, size = tuple(skip.shape[-2:]))
            x = torch.cat([x, skip], dim = 1)
            x = block(x)
        return x

# Mixture of Experts (MoE)

In [152]:
class MOE(nn.Module):
    """Token-level mixture of linear experts with noisy top-k routing.

    Args:
        embed_dim: Size of each token embedding.
        num_experts: Number of expert layers.
        topk: Number of experts each token is routed to.

    Raises:
        ValueError: If ``topk`` is not in ``[1, num_experts]``.
    """

    def __init__(self, embed_dim, num_experts, topk):
        # Fixed: nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        if not 1 <= topk <= num_experts:
            raise ValueError("topk must be between 1 and num_experts")
        self.num_experts = num_experts
        self.topk = topk
        self.router = nn.Linear(embed_dim, num_experts)
        self.noise_router = nn.Linear(embed_dim, num_experts)
        # Fixed: a plain Python list hides the experts from parameters(), .to() and state_dict().
        self.experts = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(num_experts)])

    def forward(self, x):
        """Route every token of ``x`` (batch, seq_len, embed_dim) through its top-k experts."""
        b, seq_len, embed_dim = x.shape
        num_tokens = b*seq_len
        logits = self.router(x)
        noise_logits = self.noise_router(x)
        # Fixed: the noise must have the shape of the routing logits (..., num_experts), not of x.
        # NOTE(review): the usual noisy top-k formulation scales the noise with softplus;
        # the original softmax scaling is kept as written.
        noise = torch.randn_like(logits) * torch.softmax(noise_logits, dim = -1)
        logits = logits + noise
        values, indices = torch.topk(logits, self.topk, dim = -1)
        # Experts outside the top-k get -inf, so the softmax gives them zero weight.
        masked_logits = torch.full_like(logits, -torch.inf)
        masked_logits = masked_logits.scatter(-1, indices, values)
        # Fixed: the routing probabilities have num_experts columns, not embed_dim.
        probs = torch.softmax(masked_logits, dim = -1).reshape(num_tokens, self.num_experts)
        indices = indices.reshape(num_tokens, self.topk)
        x = x.reshape(num_tokens, embed_dim)
        output = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            # Fixed: reduce over the top-k dimension to get one flag per token.
            expert_mask = (indices == i).any(dim = -1)
            if expert_mask.any():
                # Fixed: run the expert on the selected tokens only (it was applied to all of x).
                filter_x = expert(x[expert_mask])
                filter_probs = probs[expert_mask, i][:, None]
                output[expert_mask] += filter_x * filter_probs
        return output.reshape(b, seq_len, embed_dim)

# Dot-Product Attention / Multi-Head Self-Attention / Cross-Attention / Grouped-Query Attention

In [156]:
def scaled_dot_product(q, k, v, mask = None):
    """Scaled dot-product attention over 4-D ``(batch, heads, length, dim)`` tensors.

    Args:
        q: Queries.
        k: Keys; the last dimension sets the scaling factor.
        v: Values.
        mask: Optional tensor broadcastable to the score matrix; positions
            where it equals 0 are excluded from attention.

    Returns:
        The attention-weighted sum of the values for every query.
    """
    scale = k.shape[-1] ** 0.5
    scores = torch.einsum('bhij,bhkj->bhik', q, k) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -torch.inf)
    weights = torch.softmax(scores, dim = -1)
    return torch.einsum('bhij,bhjk->bhik', weights, v)

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused QKV projection.

    Args:
        embed_dim: Size of each token embedding; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        max_context_length: Longest sequence the stored causal mask covers.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the fused QKV projection has a bias term.
    """

    def __init__(self, embed_dim, num_heads, max_context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim is indivisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3*embed_dim, bias = qkv_bias)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the future positions a query may not see.
        future = torch.ones(max_context_length, max_context_length).triu(diagonal = 1)
        self.register_buffer('causal_mask', future)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, seq_len, embed_dim)."""
        batch_size, seq_len, embed_dim = x.shape
        fused = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        # Each of q, k, v has shape (batch, heads, seq_len, head_dim).
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.einsum('bhij,bhkj->bhik', q, k) / (self.head_dim ** 0.5)
        blocked = self.causal_mask[:seq_len,:seq_len] == 1
        weights = torch.softmax(scores.masked_fill(blocked, -torch.inf), dim = -1)
        weights = self.dropout(weights)
        context = torch.einsum('bhij,bhjk->bhik', weights, v)
        merged = context.transpose(1,2).reshape(batch_size, seq_len, embed_dim)
        return self.o_proj(merged)

In [None]:
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from ``x``, keys and values from ``y``.

    Args:
        embed_dim: Embedding size of ``x`` and of the output; must be divisible by ``num_heads``.
        cross_dim: Embedding size of the context sequence ``y``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the Q/K/V projections have a bias term.
    """

    def __init__(self, embed_dim, cross_dim, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim is indivisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = nn.Dropout(dropout)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias = qkv_bias)
        self.k_proj = nn.Linear(cross_dim, embed_dim, bias = qkv_bias)
        self.v_proj = nn.Linear(cross_dim, embed_dim, bias = qkv_bias)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t):
        """Reshape (batch, length, embed_dim) to (batch, heads, length, head_dim)."""
        batch_size, length, _ = t.shape
        return t.reshape(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, y):
        """Attend from ``x`` (batch, seq_len, embed_dim) to ``y`` (batch, ctx_len, cross_dim).

        Returns a tensor of shape (batch, seq_len, embed_dim).
        """
        # Fixed: these sizes were used further down but never defined.
        batch_size, seq_len, embed_dim = x.shape
        # Fixed: keys and values come from the context y (they were projected from x), and the
        # projections must be split into heads before the 4-D einsum contractions.
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(y))
        v = self._split_heads(self.v_proj(y))
        attention = torch.einsum('bhij,bhkj->bhik', q, k)
        attention = attention / (self.head_dim ** 0.5)
        attention = torch.softmax(attention, dim = -1)
        attention = self.dropout(attention)
        values_added = torch.einsum('bhij,bhjk->bhik', attention, v)
        values_added = values_added.transpose(1,2).reshape(batch_size, seq_len, embed_dim)
        values_added = self.o_proj(values_added)
        return values_added

In [164]:
class GroupedQueryAttention(nn.Module):
    """Causal grouped-query attention: several query heads share each key/value head.

    Args:
        embed_dim: Size of each token embedding; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; must divide ``num_heads``.
        max_context_length: Longest sequence the stored causal mask covers.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the fused QKV projection has a bias term.
    """

    # Fixed: `qkv_bias` was used below but was not a parameter; added with a default.
    def __init__(self, embed_dim, num_heads, num_kv_heads, max_context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim is indivisible by num_heads"
        assert num_heads % num_kv_heads == 0, "num_heads is indivisible by num_kv_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_heads
        self.q_to_k_ratio = num_heads // num_kv_heads
        self.dropout = nn.Dropout(dropout)
        # Fixed: the fused projection must hold Q plus *both* K and V (only one slice was allocated).
        self.qkv = nn.Linear(embed_dim, embed_dim + 2*self.head_dim*num_kv_heads, bias = qkv_bias)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        # Fixed: nn.Module has no `register_mask`; non-trainable tensors are registered as buffers.
        self.register_buffer(
            "causal_mask", torch.triu(torch.ones(max_context_length, max_context_length), diagonal = 1)
        )

    def forward(self, x):
        """Apply causal grouped-query attention to ``x`` of shape (batch, seq_len, embed_dim)."""
        batch_size, seq_len, embed_dim = x.shape
        qkv = self.qkv(x)
        q = qkv[:,:,:embed_dim].reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
        # Fixed: K and V only have num_kv_heads heads (they were reshaped with num_heads).
        k, v = qkv[:,:,embed_dim:].reshape(batch_size, seq_len, 2, self.num_kv_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # Repeat each K/V head so that every query head has a partner.
        k = k.repeat_interleave(self.q_to_k_ratio, dim = 1)
        v = v.repeat_interleave(self.q_to_k_ratio, dim = 1)
        attention = torch.einsum('bhij,bhkj->bhik', q, k)
        attention = attention / (self.head_dim ** 0.5)
        attention = attention.masked_fill(self.causal_mask[:seq_len,:seq_len] == 1, -torch.inf)
        attention = torch.softmax(attention, dim = -1)
        attention = self.dropout(attention)
        delta_x = torch.einsum('bhij,bhjk->bhik', attention, v)
        delta_x = delta_x.transpose(1,2).reshape(batch_size, seq_len, embed_dim)
        delta_x = self.o_proj(delta_x)
        # Fixed: the original returned the undefined name `values_added`.
        return delta_x

# Transformer

In [166]:
# Hyperparameters consumed by the GPT model and its sub-modules.
GPT_CONFIG_124M = dict(
    vocab_size = 50257,      # Vocabulary size
    context_length = 1024,   # Context length
    emb_dim = 768,           # Embedding dimension
    n_heads = 12,            # Number of attention heads
    n_layers = 12,           # Number of layers
    drop_rate = 0.1,         # Dropout rate
    qkv_bias = False,        # Query-Key-Value bias
    std = 0.02,              # Std of the normal weight initialisation
)

In [None]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Args:
        cfg: Mapping with the keys ``vocab_size``, ``context_length``, ``emb_dim``,
            ``n_layers``, ``drop_rate`` and ``std`` (plus whatever
            ``TransformerBlock`` reads from it).
    """

    def __init__(self, cfg):
        # Fixed: `_init__` (single leading underscore) does not exist on nn.Module.
        super().__init__()
        self.tok_embedding = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_embedding = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.dropout = nn.Dropout(cfg["drop_rate"])
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.final_layer = nn.Linear(cfg["emb_dim"], cfg["vocab_size"])
        # Weight tying: the token embedding and the output projection share one matrix.
        self.tok_embedding.weight = self.final_layer.weight

        # init
        for pn, p in self.named_parameters():
            if pn.endswith('proj.weight'):
                # Projection weights get a smaller std, scaled down by sqrt(2 * n_layers).
                nn.init.normal_(p, std = cfg["std"]/((2*cfg["n_layers"])**0.5))
            elif pn.endswith('weight'):
                nn.init.normal_(p, std = cfg["std"])
            elif pn.endswith('bias'):
                nn.init.zeros_(p)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq_len) to logits (batch, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the size of the position embedding table.
        """
        batch_size, seq_len = x.shape
        if seq_len > self.pos_embedding.num_embeddings:
            raise ValueError("sequence is longer than the configured context_length")
        tok_embed = self.tok_embedding(x)
        # Fixed: positions are 0..seq_len-1; the original looked the position table up
        # with the token ids themselves.
        positions = torch.arange(seq_len, device = x.device)
        pos_embed = self.pos_embedding(positions)
        x = tok_embed + pos_embed
        x = self.dropout(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        # Fixed: the layer is registered as `final_layer`, not `final_linear`.
        logits = self.final_layer(x)
        return logits

In [170]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift.

    Args:
        embed_dim: Size of the last (normalised) dimension.
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, embed_dim, eps = 1e-5):
        # Fixed: `_init__` (single leading underscore) does not exist on nn.Module.
        super().__init__()
        self.shift = nn.Parameter(torch.zeros(embed_dim))
        self.scale = nn.Parameter(torch.ones(embed_dim))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim = -1, keepdim = True)
        # Biased (population) variance: divide by N rather than N - 1.
        var = x.var(dim = -1, keepdim = True, unbiased = False)
        norm_x = (x-mean) / torch.sqrt(var + self.eps)
        return norm_x*self.scale + self.shift

In [172]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual connection.

    Args:
        cfg: Mapping with at least ``emb_dim`` and ``drop_rate``; it is also passed
            on to ``MultiHeadAttention`` and ``FeedForward``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadAttention(cfg)
        self.ffn = FeedForward(cfg)
        # Fixed: the config key is "emb_dim"; "embed_dim" raised KeyError.
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.dropout = nn.Dropout(cfg["drop_rate"])

    # Fixed: the parameter was named `cfg` while the body operates on the input tensor `x`.
    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq_len, emb_dim)."""
        orig = x
        x = self.norm1(x)
        x = self.attn(x)
        x = self.dropout(x)
        x = x + orig

        orig = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = self.dropout(x)
        x = x + orig

        return x

In [None]:
# Hyperparameters consumed by the GPT model and its sub-modules.
GPT_CONFIG_124M = dict(
    vocab_size = 50257,      # Vocabulary size
    context_length = 1024,   # Context length
    emb_dim = 768,           # Embedding dimension
    n_heads = 12,            # Number of attention heads
    n_layers = 12,           # Number of layers
    drop_rate = 0.1,         # Dropout rate
    qkv_bias = False,        # Query-Key-Value bias
    std = 0.02,              # Std of the normal weight initialisation
)

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention configured from a ``cfg`` mapping.

    Args:
        cfg: Mapping with the keys ``emb_dim``, ``n_heads``, ``drop_rate``,
            ``qkv_bias`` and ``context_length``.
    """

    def __init__(self, cfg):
        super().__init__()
        assert cfg["emb_dim"] % cfg["n_heads"] == 0, "emb_dim is indivisible by n_heads"
        self.n_heads = cfg["n_heads"]
        self.head_dim = cfg["emb_dim"] // cfg["n_heads"]
        self.dropout = nn.Dropout(cfg["drop_rate"])
        self.qkv = nn.Linear(cfg["emb_dim"], 3*cfg["emb_dim"], bias = cfg["qkv_bias"])
        self.o_proj = nn.Linear(cfg["emb_dim"], cfg["emb_dim"])
        # Ones strictly above the diagonal mark the future positions a query may not see.
        self.register_buffer(
            "causal_mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal = 1)
        )

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, seq_len, emb_dim)."""
        batch_size, seq_len, embed_dim = x.shape
        qkv = self.qkv(x)
        # Each of q, k, v has shape (batch, heads, seq_len, head_dim).
        q, k, v = qkv.reshape(batch_size, seq_len, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        attn = torch.einsum('bhij,bhkj->bhik', q, k)
        attn = attn / (k.shape[-1]**0.5)
        # Fixed: slice the mask to the actual sequence length; the full
        # (context_length, context_length) mask only fits inputs of exactly that length.
        attn = attn.masked_fill(self.causal_mask[:seq_len, :seq_len] == 1, -torch.inf)
        attn = torch.softmax(attn, dim = -1)
        attn = self.dropout(attn)
        delta_x = torch.einsum('bhij,bhjk->bhik', attn, v)
        delta_x = delta_x.transpose(1,2).reshape(batch_size, seq_len, embed_dim)
        delta_x = self.o_proj(delta_x)
        return delta_x

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the embedding size, apply GELU, project back.

    Args:
        cfg: Mapping with the key ``emb_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.linear1 = nn.Linear(cfg["emb_dim"], 4*cfg["emb_dim"])
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4*cfg["emb_dim"], cfg["emb_dim"])

    def forward(self, x):
        return self.proj(self.gelu(self.linear1(x)))

## GELU/SwiGLU

In [175]:
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: f(x) = x * sigmoid(x)."""
    def forward(self, x):
        gate = torch.sigmoid(x)
        return torch.mul(x, gate)

class FeedForward(nn.Module):
    """SwiGLU feed-forward block: ``proj(silu(linear1(x)) * linear2(x))``.

    Args:
        cfg: Mapping with the key ``emb_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.linear1 = nn.Linear(cfg["emb_dim"], 4*cfg["emb_dim"])
        self.linear2 = nn.Linear(cfg["emb_dim"], 4*cfg["emb_dim"])
        self.silu = SiLU()
        self.proj = nn.Linear(4*cfg["emb_dim"], cfg["emb_dim"])

    def forward(self, x):
        x_fc1 = self.linear1(x)
        # Fixed: the second branch must use linear2. It reused linear1, so linear2 was
        # never trained and the layer computed silu(h) * h.
        x_fc2 = self.linear2(x)
        x = self.silu(x_fc1) * x_fc2
        x = self.proj(x)
        return x

## RMSNorm

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale.

    Args:
        embed_dim: Size of the last (normalised) dimension.
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, embed_dim, eps = 1e-05):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(embed_dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(dim = -1, keepdim = True)
        normalized = x / torch.sqrt(mean_square + self.eps)
        return normalized * self.scale

## Pre/Post Layer Norm

In [179]:
class TransformerBlock(nn.Module):
    """Transformer block offering both pre-norm and post-norm residual layouts.

    Args:
        cfg: Mapping with at least ``emb_dim`` and ``drop_rate``; it is also passed
            on to ``MultiHeadAttention`` and ``FeedForward``. An optional boolean
            ``pre_norm`` (default ``True``) selects which layout ``forward`` uses.
    """

    def __init__(self, cfg):
        super().__init__()
        # Fixed: the attention class is `MultiHeadAttention`; `MultiHeadAttn` is undefined.
        self.att = MultiHeadAttention(cfg)
        self.ffn = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.dropout = nn.Dropout(cfg["drop_rate"])
        self.pre_norm = cfg.get("pre_norm", True)

    def forward(self, x):
        """Dispatch to the configured layout.

        Added because the class defined no ``forward``: calling the module or
        placing it in ``nn.Sequential`` raised NotImplementedError.
        """
        return self.forward_pre(x) if self.pre_norm else self.forward_post(x)

    def forward_pre(self, x):
        """Pre-norm layout: normalise the input of each sub-layer, then add the residual."""
        orig_x = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.dropout(x)
        x = x + orig_x

        orig_x = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = self.dropout(x)
        x = x + orig_x

        return x

    def forward_post(self, x):
        """Post-norm layout: add the residual first, then normalise the sum."""
        orig_x = x
        x = self.att(x)
        x = self.dropout(x)
        x = x + orig_x
        x = self.norm1(x)

        orig_x = x
        x = self.ffn(x)
        x = self.dropout(x)
        x = x + orig_x
        x = self.norm2(x)

        return x

## Positional Encoding (PE)

In [None]:
def generate_sinusoidal_pos_emb(ctx_len, emb_dim):
    """Build the fixed sinusoidal positional-encoding table.

    ``pe[pos, 2i] = sin(pos / 10000^(2i / emb_dim))`` and
    ``pe[pos, 2i + 1] = cos(pos / 10000^(2i / emb_dim))``.

    Args:
        ctx_len: Number of positions (rows).
        emb_dim: Embedding size (columns); may be odd.

    Returns:
        Float tensor of shape ``(ctx_len, emb_dim)``.
    """
    pe = torch.zeros(ctx_len, emb_dim)
    position = torch.arange(ctx_len, dtype = torch.float)[:, None]
    # Fixed: the original referenced the undefined names `emd_dim` and `embed_dim`.
    div_term = torch.tensor(10000.0).pow(torch.arange(0, emb_dim, 2).float()/emb_dim)[None, :]
    angles = position / div_term
    pe[:, 0::2] = torch.sin(angles)
    # Fixed: cosines belong in the odd columns (they used to overwrite the sines in the even
    # ones). For an odd emb_dim there is one odd column fewer than even columns.
    pe[:, 1::2] = torch.cos(angles[:, :emb_dim // 2])
    # Fixed: the table was never returned.
    return pe

In [196]:
# Scratch check: with emb_dim = 5, arange(0, 5, 2) yields 3 frequency terms, so the result has 3 elements.
torch.tensor(10000.0).pow(torch.arange(0, 5, 2).float()/5).shape

torch.Size([3])