In [None]:
!pip install wandb

In [None]:
import string
import math
import wandb
import json
import torch
import torch.functional as F
import torch.distributions as dist
from torch.optim.lr_scheduler import _LRScheduler
from torch import optim, nn, arange
from torch.utils.data import Dataset, DataLoader

In [None]:
# Log in to Weights & Biases up front; run_with_mlp_layer() further down
# reports its training metrics through wandb.init / wandb.log / wandb.finish.
wandb.login()

In [None]:
# --------------------------------------------------
# GENERIC TRANSFORMER PARTS
# --------------------------------------------------
#
# Exactly the same in a MoE and a Dense transformer
# - Causal Mask
# - SelfAttention

In [None]:
# Vanilla causal attention mask + hard-alibi
# ------------------------------------------
# Hard Alibi (https://arxiv.org/pdf/2402.01032.pdf) is a variant of
# the Alibi position encoding (https://arxiv.org/pdf/2108.12409.pdf)
# where Alibi's slow linear decay over the attention is replaced by
# a discrete decay such that
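# - the first head sees a window of one token (only the current token)
# - the second, a window of two (the current token and one token back)
# - the third, four, and so on (head i sees the last 2**i tokens)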
# ------
# Init:
# - config["max_length"] = maximum length sequence fed into it
# Forward:
# - takes [b, nh, t, t] (or anything that can be viewed as [-1, nh, t, t])
# - returns [-1, nh, t, t]
#
class HardAlibiCausalMask(nn.Module):
    """Additive causal attention mask with a per-head look-back window.

    Head ``i`` may attend to the current position and the ``2**i - 1``
    positions before it; every other position gets ``-inf`` added to its
    attention score (allowed positions get ``0`` added).

    Config keys read:
        max_length: largest sequence length the mask is pre-computed for.
        num_heads:  number of attention heads (one mask per head).
    """

    def __init__(self, config):
        super().__init__()
        self.max_seq_len = config["max_length"]
        self.n_heads = config["num_heads"]

        offsets = self.get_relative_positions(self.max_seq_len)
        # One [max_len, max_len] additive mask per head: 0 inside the window,
        # -inf outside it (including every future position).
        per_head = [
            torch.where(
                (offsets <= 0) & (offsets > -(2**head)),
                torch.zeros(offsets.shape),
                float("-inf"),
            )
            for head in range(self.n_heads)
        ]
        # Buffer (not a parameter): follows .to(device) but is never trained.
        self.register_buffer("values", torch.stack(per_head, dim=0))

    def get_relative_positions(self, seq_len):
        """Return a [seq_len, seq_len] matrix whose entry (row, col) is col - row."""
        positions = torch.arange(seq_len)
        return positions[None, :] - positions[:, None]

    def forward(self, att):
        """Add the mask to attention scores viewed as [-1, n_heads, t, t]."""
        rows, cols = att.size()[-2:]
        assert rows == cols, "attention must be square"
        assert rows <= self.max_seq_len, "attention must be smaller than max_seq_length"
        per_head_att = att.view(-1, self.n_heads, rows, rows)
        # Crop the pre-computed mask to the actual sequence length; it
        # broadcasts over the leading (batch) dimension.
        return per_head_att + self.values[:, :rows, :rows]

def test():
    """Smoke-test HardAlibiCausalMask on a random [2, 4, 5, 5] score tensor."""
    mask = HardAlibiCausalMask({"max_length": 6, "num_heads": 4})
    scores = torch.rand(2, 4, 5, 5)
    masked = mask(scores)
    # To get a feel for how hard-alibi works, uncomment:
    # print(masked)

    # Head 0 keeps the diagonal (each token sees itself) ...
    assert masked[0, 0, 0, 0] > 0
    assert masked[0, 0, 1, 1] > 0
    # ... and blocks the future.
    assert masked[0, 0, 0, 1] == float("-inf")


test()

In [None]:
# Vanilla multi-headed self-attention implementation
# ------------------------------------------------
# Init:
# - config["hidden_dim"] = int, dimension of input
# - config["num_heads"] = int, number of heads, head_size = hidden_dim // num_heads
# - config["max_length"] = int, sets masks size for the CausalMask
# Forward:
# - takes [b, t, c]
# - outputs [b, t, c]
class SelfAttention(nn.Module):
    """Multi-headed self-attention masked by HardAlibiCausalMask.

    Config keys read:
        hidden_dim: channel size ``c`` of the ``[b, t, c]`` input.
        num_heads:  number of heads; each head works on ``c // num_heads`` channels.
        (plus whatever HardAlibiCausalMask reads from the same config)
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config["hidden_dim"]
        self.num_heads = config["num_heads"]
        width = self.hidden_dim
        self.k = nn.Linear(width, width)
        self.q = nn.Linear(width, width)
        self.v = nn.Linear(width, width)
        self.p = nn.Linear(width, width)  # Projection from heads to output
        self.mask = HardAlibiCausalMask(config)

    def forward(self, x, layer_past=None):
        """Map ``[b, t, c]`` to ``[b, t, c]``; ``layer_past`` is accepted but unused."""
        b, t, c = x.size()
        head_dim = c // self.num_heads

        def split_heads(projected):
            # [b, t, c] -> [b, t, nh, hs] -> [b, nh, t, hs]
            return projected.view(b, t, self.num_heads, head_dim).transpose(1, 2)

        keys = split_heads(self.k(x))
        queries = split_heads(self.q(x))
        values = split_heads(self.v(x))

        # Multi-headed attention: [b, nh, t, hs] @ [b, nh, hs, t] -> [b, nh, t, t]
        scores = queries @ keys.transpose(-2, -1)
        scores = scores / math.sqrt(head_dim)
        weights = torch.softmax(self.mask(scores), dim=-1)

        # [b, nh, t, t] @ [b, nh, t, hs] -> [b, nh, t, hs]
        mixed = weights @ values

        # [b, nh, t, hs] -> [b, t, nh, hs] -> [b, t, c]
        merged = mixed.transpose(1, 2).contiguous().view(b, t, c)

        return self.p(merged)

In [None]:
# --------------------------------------------------
# MLP vs MOE
# --------------------------------------------------
#
# - MLP
# - UnitCenteredNoise (used in SwitchMoE)
# - SwitchMoE

In [None]:
# MLP module
# -------------------------------
# Note: Signature of init here matches init for MoE
#
# Init:
# - config["hidden_dim"] = dimension of c in [b, t, c] input
# - index = int index
# - scaling = amount by which to scale weights
# Forward:
# - tuple (tensor, aux_loss)
# -- tensor is [b, t, c]
# -- aux_loss is a scalar added to cross-entropy loss
#
class MLP(nn.Module):
    """Feed-forward block: Linear(c -> 4c), ReLU, Linear(4c -> c).

    Args:
        config:  dict; only ``config["hidden_dim"]`` is read here.
        index:   accepted so the signature matches SwitchMoE; not used.
        scaling: factor applied once, at construction, to the freshly
                 initialised weights and biases of both linear layers.
    """

    def __init__(self, config, index, scaling=1.0):
        super().__init__()
        self.hidden_dim = config["hidden_dim"]
        width = self.hidden_dim
        self.seq = nn.Sequential(
            nn.Linear(width, width * 4),
            nn.ReLU(),
            nn.Linear(width * 4, width),
        )

        # Rescale the initial parameters in place, outside of autograd.
        with torch.no_grad():
            for linear in (self.seq[0], self.seq[2]):
                linear.weight.data.mul_(scaling)
                linear.bias.data.mul_(scaling)

    def forward(self, inp_tuple):
        """Take ``(x, aux_loss)``; return ``(seq(x), aux_loss)`` with aux_loss untouched."""
        hidden, aux_loss = inp_tuple
        return self.seq(hidden), aux_loss

In [None]:
# Elementwise multiplies: x * (1 +- eps)
# -------------------------------
# Init:
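# - scaling = total width of the uniform noise interval: each element is
#             multiplied by a value drawn from [1 - scaling/2, 1 + scaling/2)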
# Forward:
# - any tensor
#
class UnitCenteredNoise(nn.Module):
    """Multiply the input element-wise by uniform noise centred on 1.

    In training mode every element of ``x`` is multiplied by an independent
    sample from ``[1 - scaling/2, 1 + scaling/2)``. In eval mode the input is
    returned unchanged (the very same tensor object).

    Args:
        scaling: total width of the noise interval around 1.
    """

    def __init__(self, scaling=0.02):
        super().__init__()
        self.scaling = scaling
        # Lower edge of the noise interval.
        self.base = 1 - (scaling * 0.5)

    def forward(self, x):
        if not self.training:
            return x

        # Sample directly on x's device and (for floating inputs) in x's dtype:
        # no CPU -> device copy on every forward pass, and low-precision
        # inputs are not silently promoted by a float32 noise tensor.
        noise_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        noise = torch.rand(x.shape, device=x.device, dtype=noise_dtype)
        noise_centered = (noise * self.scaling) + self.base
        return x * noise_centered

In [None]:
# Actual Switch Mixture-of-Experts Layer
# -------------------------------
# Note: Signature of init here matches init for MLP
#
# Init:
# - config["hidden_dim"] = Dimension of c in [b, t, c] input
# - config["num_experts"] = A layer indexed array of how many
#                           experts are per layer
# - config["init_moe_scaling"] = Amount by which to scale
#                                weights in MoE experts. If you don't
#                                make them smaller, then it doesn't 
#                                learn as well for mysterious reasons.
#                                
# Forward:
# - tuple (tensor, aux_loss)
# -- tensor is [b, t, c]
# -- aux_loss is a scalar added to cross-entropy loss
class SwitchMoE(nn.Module):
    """Top-1 ("switch") mixture-of-experts feed-forward layer.

    Every token is routed to exactly one expert MLP, chosen as the arg-max of
    a noisy softmax gate. A load-balancing auxiliary loss is added to the
    running ``aux_loss`` that travels alongside the activations.

    Args:
        config: dict; reads ``hidden_dim``, ``num_experts[index]`` and
                ``init_moe_scaling`` (passed to each expert as ``scaling``).
        index:  position of this layer in the stack; selects the expert count.

    Forward:
        takes ``(x, aux_loss)`` with ``x`` of shape ``[b, t, c]`` and returns
        ``(y, aux_loss + balance_loss)`` with ``y`` of shape ``[b, t, c]``.
    """

    def __init__(self, config, index):
        super().__init__()

        self.hidden_dim = hd = config["hidden_dim"]
        self.num_experts = num_experts = config["num_experts"][index]
        self.moe_scaling = moe_scaling = config["init_moe_scaling"]

        # `expert_index` deliberately does not reuse the name `index`
        # (the layer index) to avoid confusing the two.
        self.experts = nn.ModuleList([
            MLP(config, index=expert_index, scaling=moe_scaling)
            for expert_index in range(num_experts)
        ])

        self.gate = nn.Sequential(
            nn.Linear(hd, num_experts),
            UnitCenteredNoise(scaling=0.02),
            nn.Softmax(dim=-1),
        )

    def forward(self, xx):
        inp, aux_loss = xx
        b, t, c = inp.shape

        # Flatten to one row per token: [b, t, c] -> [b * t, c]
        tokens = inp.reshape(b * t, c)

        gate_probs = self.gate(tokens)  # [b * t, num_experts]
        _, top_idx = torch.topk(gate_probs, 1, dim=-1)  # [b * t, 1]
        expert_idx = top_idx.squeeze(-1)  # [b * t]
        one_hot = torch.nn.functional.one_hot(
            expert_idx, num_classes=self.num_experts
        )  # [b * t, num_experts]

        # Auxiliary load-balancing loss:
        #   f = fraction of tokens dispatched to each expert
        #   P = fraction of total gate probability given to each expert
        f = one_hot.sum(dim=0)
        f = f / f.sum()
        P = gate_probs.sum(dim=0)
        P = P / P.sum()
        extra_aux_loss = (P * f).sum() * self.num_experts

        # NOTE(review): expert outputs are not multiplied by the gate
        # probability, so the gate only receives gradient through the
        # auxiliary loss. Confirm this is intended.
        output = torch.zeros_like(tokens)
        for i, expert in enumerate(self.experts):
            rows = expert_idx == i  # [b * t] boolean row mask
            # Experts are called even when no token is routed to them, so
            # every expert parameter still takes part in the backward pass.
            # The second tuple element is ignored by the experts' output use.
            expert_out, _ = expert((tokens[rows], aux_loss))
            # Each token belongs to exactly one expert, so a plain assignment
            # is what is wanted here (the previous `=+` was exactly that).
            output[rows] = expert_out

        return output.reshape(b, t, c), extra_aux_loss + aux_loss
        
        
def test():
    """Smoke-test SwitchMoE on random input, on GPU when one is available."""
    # With perfectly uniform routing over n experts, sum(P * f) is 1 / n, so
    # the auxiliary loss (which multiplies by n) bottoms out at 1.
    a = torch.ones(7) / torch.ones(7).sum()
    b = torch.ones(7) / torch.ones(7).sum()
    assert torch.isclose((a * b).sum() * 7, torch.tensor(1.0))

    # Fall back to CPU so the notebook still runs top-to-bottom without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    m = torch.randn(16, 32, 48, device=device)
    nm = SwitchMoE(
        {"hidden_dim": 48, "num_experts": [8, 8, 8], "init_moe_scaling": 1.0}, 1
    ).to(device)
    out, aux = nm((m, torch.zeros([1], device=device)))

    assert out.shape == m.shape
    assert bool(torch.isfinite(out).all())
    assert bool(torch.isfinite(aux).all())


test()

In [None]:
# --------------------------------------------------
# STANDARD TRANSFORMER: WHOLE THING
# --------------------------------------------------

In [None]:
from torch import nn

class TransformerBlock(nn.Module):
    """One transformer layer: self-attention followed by an MLP or a SwitchMoE.

    ``config["layer_types"][index]`` picks the feed-forward flavour:
    ``"mlp"`` for a dense MLP, ``"switch"`` for a SwitchMoE.

    Forward takes and returns a ``(x, extra_loss)`` tuple so that blocks can
    be chained inside ``nn.Sequential`` while accumulating the auxiliary loss.
    """

    def __init__(self, config, index):
        super().__init__()
        self.hidden_dim = config["hidden_dim"]
        self.attention = SelfAttention(config)

        # Choose the feed-forward implementation for this layer.
        layer_type = config["layer_types"][index]
        if layer_type == "mlp":
            ff_class = MLP
        elif layer_type == "switch":
            ff_class = SwitchMoE
        else:
            raise Exception("Invalid layer type")

        self.ff = ff_class(config, index)
        self.norm1 = nn.LayerNorm([self.hidden_dim])
        self.norm2 = nn.LayerNorm([self.hidden_dim])

    def forward(self, input_tuple):
        hidden, extra_loss = input_tuple

        # Residual connection around the (normalised) attention output.
        hidden = hidden + self.norm1(self.attention(hidden))

        # Residual connection around the (normalised) feed-forward output;
        # the feed-forward layer may add to the auxiliary loss.
        ff_out, extra_loss = self.ff((hidden, extra_loss))
        hidden = hidden + self.norm2(ff_out)

        return hidden, extra_loss

In [None]:
class Transformer(nn.Module):
    """Character-level transformer built from a stack of TransformerBlocks.

    Config keys read here: ``input_dim`` (also used as the output size),
    ``hidden_dim``, ``max_length``, ``layer_types`` (one entry per layer) and
    ``num_experts`` (one entry per layer). The whole config is also handed to
    every TransformerBlock.

    Forward:
        takes integer token ids of shape ``[b, t]`` and returns
        ``(logits, aux_loss)`` with logits of shape ``[b, t, input_dim]`` and
        ``aux_loss`` the scalar accumulated by the layers of the stack.
    """

    def __init__(self, config):
        super().__init__()
        self.input_dim = self.output_dim = config["input_dim"]
        self.hidden_dim = hd = config["hidden_dim"]
        self.max_length = config["max_length"]
        self.layer_types = config["layer_types"]
        self.layer_num = layer_num = len(config["layer_types"])
        self.num_experts = config["num_experts"]  # note this is an array

        self.char_embed = nn.Embedding(self.input_dim, hd)

        self.stack = nn.Sequential(*[
            TransformerBlock(config, x) for x in range(layer_num)
        ])
        self.layer_norm = nn.LayerNorm(hd)
        self.to_prob = nn.Linear(hd, self.output_dim)

    def forward(self, x):
        # Initial input processing: token ids -> embeddings
        inp = self.char_embed(x)

        # Start the auxiliary loss as a float scalar on the same device and
        # dtype as the activations, rather than an int64 CPU tensor.
        aux_loss = inp.new_zeros(())
        hidden, aux_loss = self.stack((inp, aux_loss))

        # LayerNorm and Linear both act on the last dimension only, so no
        # flatten / unflatten of the batch and time dimensions is needed.
        logits = self.to_prob(self.layer_norm(hidden))

        return logits, aux_loss

    def get_param_groups(self, base_learning_rate):
        """Build one optimizer param group per parameter.

        Parameters that live under ``stack.<i>.ff.experts`` get a learning
        rate of ``base_learning_rate / sqrt(num_experts[i])``; every other
        parameter uses ``base_learning_rate``.
        """
        param_groups = []
        for name, param in self.named_parameters():
            lr = base_learning_rate

            if "stack" in name and "ff.experts" in name:
                for i in range(self.layer_num):
                    if f"stack.{i}.ff.experts" in name:
                        lr = base_learning_rate / math.sqrt(self.num_experts[i])
                        break
                # If no layer prefix matched, the parameter keeps the base
                # learning rate instead of being silently left out.

            param_groups.append({"params": param, "lr": lr})

        return param_groups
        

In [None]:
# Warms up to learning rate over 'steps_up'
# Cools down to LR * gamma over 'steps_down'
class LinearLRDecayWithWarmup(_LRScheduler):
    """Linear warm-up followed by a linear decay to ``base_lr * gamma``.

    Args:
        optimizer:  wrapped optimizer.
        steps_up:   number of steps over which the rate climbs from 0 to
                    ``base_lr``. ``0`` disables warm-up (starts at ``base_lr``).
        steps_down: number of steps over which the rate then falls linearly
                    from ``base_lr`` to ``base_lr * gamma``.
        gamma:      final learning-rate multiplier, held after the decay.
        last_epoch: forwarded to the base scheduler.
    """

    def __init__(self, optimizer, steps_up, steps_down, gamma, last_epoch=-1):
        # Set before super().__init__ because the base class calls get_lr().
        self.steps_up = steps_up
        self.steps_down = steps_down
        self.gamma = gamma
        super().__init__(optimizer, last_epoch)

    def _multiplier(self, step):
        """Return the factor applied to every base learning rate at ``step``."""
        if step > self.steps_down + self.steps_up:
            return self.gamma

        if step > self.steps_up:
            percentage_there = (step - self.steps_up) / self.steps_down
            return (1 - percentage_there) + (percentage_there * self.gamma)

        # Without a warm-up phase there is nothing to ramp: avoid dividing by 0.
        if self.steps_up <= 0:
            return 1.0

        return step / self.steps_up

    def get_lr(self):
        mult = self._multiplier(self.last_epoch)
        return [base_lr * mult for base_lr in self.base_lrs]


In [None]:
# Loads a text file, tokenizes it extreeemely simply,
# and just shoves it into memory as a PyTorch Tensor
#
# I've been using a ~200 mb scrape of Gutenberg
# books to test, but anything of similar size should
# work fine
class TextDataset(Dataset):
    """Character-level dataset of fixed-length chunks from one text file.

    Characters that are not in ``string.printable`` are dropped; every
    remaining character becomes its index in ``string.printable``. The token
    stream is cut into consecutive, non-overlapping chunks of ``ctx_len``
    tokens, and a trailing partial chunk is discarded.

    Args:
        file_name: path of a UTF-8 text file.
        ctx_len:   number of tokens per item (must be positive).

    Attributes:
        data: ``[num_chunks, ctx_len]`` uint8 tensor of token ids.

    Raises:
        ValueError: if ``ctx_len`` is not positive.
    """

    def __init__(self, file_name, ctx_len):
        if ctx_len <= 0:
            raise ValueError("ctx_len must be a positive integer")
        self.ctx_len = ctx_len

        # Load the text
        with open(file_name, 'r', encoding='utf-8') as file:
            text = file.read()
        print("Read file ", file_name)

        # string.printable is pure ASCII, so anything that cannot be encoded
        # as ASCII can be dropped straight away.
        vocab = string.printable.encode("ascii")
        ascii_bytes = text.encode("ascii", errors="ignore")
        print("Filtered file ")

        # Tokenize by character in one pass: a 256-entry table maps each
        # printable byte to its index in string.printable, and all remaining
        # (non-printable) byte values are deleted.
        to_token = bytearray(256)
        for token_id, byte in enumerate(vocab):
            to_token[byte] = token_id
        dropped = bytes(b for b in range(256) if b not in vocab)
        token_bytes = ascii_bytes.translate(bytes(to_token), dropped)
        print("Tokenized file")

        # Create chunks of ctx_len. Integer division keeps a final chunk that
        # is exactly full and discards only a partial one.
        num_chunks = len(token_bytes) // ctx_len
        if num_chunks:
            flat = torch.frombuffer(bytearray(token_bytes), dtype=torch.uint8)
            # clone() so the dataset owns its memory independently of the buffer
            self.data = flat[: num_chunks * ctx_len].view(num_chunks, ctx_len).clone()
        else:
            # Empty or too-short file: a valid dataset with zero items.
            self.data = torch.empty((0, ctx_len), dtype=torch.uint8)
        print("Chunked", len(self.data), ctx_len)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Token ids are stored as uint8 (string.printable has 100 symbols)
        # and widened to int64 on access; the conversion returns a fresh copy.
        return self.data[idx].to(torch.long)

In [None]:
# Every key a run config must contain (verify_config rejects configs whose
# key set differs from this list in either direction).
CONFIG_KEYS = [
    # MODEL DETAILS
    #
    # List with one entry per layer, each either 'mlp' or 'switch'
    "layer_types",
    # List of ints, for how many experts go in IF it is a moe -- does nothing otherwise
    "num_experts",
    # Maximum length of input
    "max_length",
    # Channel size of the activations
    "hidden_dim",
    # How many input and output classes we have
    "input_dim",
    # Number of attention heads
    "num_heads",
    # Scaling factor for MoE expert weights
    "init_moe_scaling",

    # TRAINING DETAILS
    #
    # Learning rate
    "lr",
    # How long to warm up from 0 lr
    "lr_steps_up",
    # How long to cool down from lr to lr * gamma
    "lr_steps_down",
    # Final learning rate = lr * gamma
    "lr_gamma", 
    # Weight of the auxiliary (load-balancing) loss added to the cross-entropy loss
    "aux_loss_weight",
    # Text file to train on; it is iterated over `epochs` times
   
    "file",
    "epochs",
    "batch_size",
    # Print the running loss every `log_interval` batches
    "log_interval",
    # Device the model and batches are moved to
    "device",
]

def config_to_run_name(config):
    """Build a run name from hidden size, layer count, learning rate and file.

    The four ``<label>_<value>`` parts are joined with a double underscore.
    """
    labelled_values = (
        ("hd_", config["hidden_dim"]),
        ("layers_", len(config["layer_types"])),
        ("lr_", config["lr"]),
        ("file_", config["file"]),
    )
    return "__".join(label + str(value) for label, value in labelled_values)

def verify_config(config):
    """Check a run config before anything expensive is built.

    Verifies that the config has exactly the keys in CONFIG_KEYS, that
    ``layer_types`` and ``num_experts`` have one entry per layer, that every
    layer type is known, that switch layers have a positive integer expert
    count, and that ``hidden_dim`` splits evenly across ``num_heads``.

    Raises:
        AssertionError: with a message naming the offending keys or values.
    """
    keys_mandatory = set(CONFIG_KEYS)
    keys_used = set(config.keys())

    # Report missing and unexpected keys separately so the message says
    # which of the two problems actually occurred.
    keys_missing = sorted(keys_mandatory - keys_used, key=str)
    keys_unknown = sorted(keys_used - keys_mandatory, key=str)
    assert not keys_missing, f"Missing required config keys: {keys_missing}"
    assert not keys_unknown, f"Unrecognized config keys: {keys_unknown}"

    layer_types = config["layer_types"]
    num_experts = config["num_experts"]
    assert len(layer_types) == len(num_experts), (
        f"layer_types has {len(layer_types)} entries but "
        f"num_experts has {len(num_experts)}"
    )

    for layer, (layer_type, experts) in enumerate(zip(layer_types, num_experts)):
        assert layer_type in ("mlp", "switch"), (
            f"Layer {layer}: invalid layer type {layer_type!r}, "
            "expected 'mlp' or 'switch'"
        )
        if layer_type == "switch":
            assert isinstance(experts, int) and experts > 0, (
                f"Layer {layer}: a switch layer needs a positive integer "
                f"num_experts, got {experts!r}"
            )

    # The attention layer gives each head hidden_dim // num_heads channels
    # and reshapes accordingly, which only works for an exact split.
    assert config["hidden_dim"] % config["num_heads"] == 0, (
        "hidden_dim must be divisible by num_heads"
    )

In [None]:
def run_with_mlp_layer(config):
    """Train a Transformer described by ``config`` and log the run to wandb.

    The config is validated with verify_config, the model is built and moved
    to ``config["device"]``, the text file ``config["file"]`` is chunked into
    sequences of ``max_length + 1`` tokens (inputs and next-token targets),
    and the model is trained for ``config["epochs"]`` passes over the data
    with Adam and a warm-up / linear-decay learning-rate schedule.

    The caller's ``config`` dict is not modified.
    """
    verify_config(config)

    # Work on a copy: "total_params" is added below, and writing it into the
    # caller's dict would make the same dict fail verify_config on a re-run.
    config = dict(config)

    model = Transformer(config).to(config["device"])

    config["total_params"] = tp = sum(p.numel() for p in model.parameters())
    print(f"Total Model Parameters: {(tp / 1000000.0):.2f}m")
    print("Config: ", json.dumps(config, indent=4))

    # Load data; one extra token per chunk so inputs and targets can be
    # offset by one position.
    dataset = TextDataset(config["file"], ctx_len=config["max_length"] + 1)
    dataloader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)

    # Define the loss function, optimizer, param groups
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.get_param_groups(config["lr"]))
    scheduler = LinearLRDecayWithWarmup(
        optimizer,
        steps_up=config["lr_steps_up"],
        steps_down=config["lr_steps_down"],
        gamma=config["lr_gamma"]
    )

    wandb.init(
        project="switch-transformer-hard-alibi",
        config=config,
        name=config_to_run_name(config)
    )

    try:
        # Training loop
        model.train()
        loss_sum = 0.0
        loss_count = 0
        for epoch in range(config["epochs"]):

            for i, batch in enumerate(dataloader):

                inputs = batch[:, :-1].to(config["device"])
                targets = batch[:, 1:].to(config["device"])

                outputs, loss_kl = model(inputs)
                loss_ce = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
                loss = loss_ce + (loss_kl * config["aux_loss_weight"])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Log plain Python numbers rather than tensors.
                loss_ce_value = loss_ce.item()
                wandb.log({
                    "loss": loss.item(),
                    "loss_normal": loss_ce_value,
                    "loss_kl": float(loss_kl),
                    "lr": scheduler.get_last_lr()[0]
                })

                # Only the cross-entropy part is printed; the auxiliary loss
                # is a means to an end. Average over the number of batches
                # actually accumulated since the last print.
                loss_sum += loss_ce_value
                loss_count += 1
                if i % config["log_interval"] == 0:
                    print(i, ", loss: ", loss_sum / loss_count)
                    loss_sum = 0.0
                    loss_count = 0

            print(f"Epoch {epoch+1} Done")
    finally:
        # Close the wandb run even if training is interrupted or fails.
        wandb.finish()


In [None]:
# Example of a simple dense Transformer with ~20m parameters
#
# Every layer is an "mlp", so the "num_experts" entries are placeholders
# (only switch layers read them), "init_moe_scaling" is not used by any
# layer, and the auxiliary loss stays at zero. These keys are still present
# because verify_config requires the full key set.
run_with_mlp_layer({
    "layer_types": ["mlp","mlp","mlp","mlp","mlp","mlp","mlp"],
    "num_experts": [None, None, None, None, None, None, None],
    "max_length": 512,
    "hidden_dim": 512,
    "num_heads": 8,
    "input_dim": len(string.printable),
    "lr_steps_down": 4500,
    "lr_steps_up": 500,
    "lr_gamma": 0.1,
    "lr": 0.0004,
    "aux_loss_weight": 0.08,
    "init_moe_scaling": 0.0625,
    "file": "two.txt",
    "epochs": 1,
    "batch_size": 64,
    "log_interval": 500,
    "device": "cuda",
})

In [None]:
# Example of a MoE Transformer with ~75m parameters
#
# Layers 3 and 5 (0-indexed) are "switch" layers with 12 experts each; the
# other layers are dense MLPs, whose "num_experts" entries are placeholders.
run_with_mlp_layer({
    "layer_types": ["mlp","mlp","mlp","switch","mlp","switch","mlp"],
    "num_experts": [None, None, None, 12, None, 12, None],
    "max_length": 512,
    "hidden_dim": 512,
    "num_heads": 8,
    "input_dim": len(string.printable),
    "lr_steps_down": 4500,
    "lr_steps_up": 500,
    "lr_gamma": 0.1,
    "lr": 0.0004,
    "aux_loss_weight": 0.06,
    "init_moe_scaling": 0.0625,
    "file": "two.txt",
    "epochs": 1,
    "batch_size": 64,
    "log_interval": 500,
    "device": "cuda",
})