In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

import plotly.express as px
import pandas as pd

In [2]:
# torch.set_default_device('cuda')
# Pick the best available accelerator, falling back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    # torch.backends.mps.is_available() is the documented availability probe.
    # `hasattr(torch, 'mps')` alone does not guarantee that
    # `torch.mps.is_available` exists, so that spelling can raise AttributeError.
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Set to True to dump every intermediate tensor of the routing walkthrough below.
DEBUG = False

# Mixture of depths dynamic computation optimisation

In [3]:
def print_tensor(name: str, value: torch.Tensor) -> None:
    """Print a labelled tensor: its name, a rule, its shape, then its values.

    Handy for checking that each tensor operation acts on the intended values.

    Args:
        name: Label printed on the first line.
        value: Tensor to show; annotated as `torch.Tensor` (the class) because
            `torch.tensor` is the factory function, not a type.
    """
    print(f'{name}\n{"-"*20}\n{value.shape}\n{value}\n')

## 1. Figuring out the calculations

In [7]:
torch.manual_seed(42)  # torch.manual_seed is the same function as torch.random.manual_seed

## Network hyperparameters
# Batch size, sequence length, hidden / embedding dimension.
# S and h are small, distinct and > 1 so shapes are easy to tell apart when debugging.
B, S, h = 256, 4, 3

## MoD hyperparameters
# C is the layer capacity: C tokens per sequence go through the decoder layer /
# self attention, the remaining S - C tokens only take the residual connection.
C = 2 # Must be 1 < C <= S
# Smaller C buys compute efficiency, larger C preserves evaluation performance;
# the paper's ablation studies report a pareto-optimal trade-off between the two.

## Batch inputs
X = torch.randn(B, S, h, device=device)

## Network subsection
# Full network is simply [EncoderLayer(), PositionalEncoding(), [DecoderLayer(), ...], LMHead()]
# Input to the network is an integer tensor of shape [Batch length, Sequence length, 1]
# Output is a float tensor of shape [Batch length, Sequence length, Vocab size]*
## *Though we generally only care about the output at the last sequence index so can be optimised to 
## [Batch length, 1, Vocab size]

## Standard transformer decoder block layer pseudocode (see figure 1 in paper)
# def layer(x: torch.Tensor) -> torch.Tensor:
#     x_in = x.clone()
#     x = LayerNorm(x)
#     x = SelfAttention(x)
#     x_in = x_in + x  # Residual connection: element-wise add, not a concat
#     x = LayerNorm(x_in)
#     x = Linear(x)
#     x = GLU(x) # ReLU originally but modern SoTA use GLU derivatives, i.e. SwiGLU for llama2
#     # See https://arxiv.org/abs/2002.05202 (check the conclusion :p).
#     x = Linear(x)
#     return x_in + x  # Second residual connection, around the feed-forward
def layer(x_):
    """No-op stand-in for the decoder layer: hands back its input, for visual debugging."""
    return x_


# Tokens either skip the layer computation (their embedding is carried over unchanged),
# or go through it and are added back onto their own residual, scaled by their router weight
router_l = nn.Linear(h, 1, device=device) # Routing head of current layer
X_l = X.clone() # Inputs to current decoder layer


## Mixture of depths layer computations, # def mod_layer(X_l: torch.Tensor) -> torch.Tensor:
R_l = router_l(X_l) # Router activations, [B, S, 1]
r_l, r_l_i = R_l.topk(C, dim=1, sorted=True) # Find top C tokens to pass through network, both [B, C, 1]

# topk returns the picks in score order, but the masked_select below returns tokens in
# sequence order. Re-sort the indices by position and permute the weights the same way,
# otherwise each weight is multiplied with the wrong token.
# squeeze(-1) (not a bare squeeze) so B == 1 or C == 1 does not lose a dimension.
r_l_i, _pos_order = r_l_i.squeeze(-1).sort(dim=1) # [B, C], ascending positions
r_l = r_l.gather(1, _pos_order.unsqueeze(-1)) # [B, C, 1], aligned with r_l_i

# Create a multihot mask from the list of routing indices (r_l_i); built directly on
# X_l's device instead of round-tripping through the CPU-only torch.BoolTensor type
X_tilde_mask = torch.zeros(B, S, device=X_l.device).scatter_(1, r_l_i, 1.).bool().unsqueeze(-1) # [B, S, 1]
# Select the token embeddings from the layer input (X_l) based on the ranking indices (r_l_i) via the mask
# (masked_select works like boolean indexing in pandas, but in higher dimensional space)
X_tilde = torch.masked_select(X_l, X_tilde_mask).view(B, C, h)

# Notice how the r_l term is in the X_routed operation, this allows backprop to train the router layer weights
X_routed = r_l * layer(X_tilde) + X_tilde
X_unrouted = torch.masked_select(X_l, ~X_tilde_mask).view(B, S-C, h) # The remainder of tokens for skip  connection

# Write the routed tokens back into the slots they came from, so X_l1 has the same
# token order as X_l. Concatenating (routed, unrouted) would permute the sequence.
X_l1 = X_l.masked_scatter(X_tilde_mask, X_routed) # [B, S, h]


# Dump every intermediate of the routing walkthrough when DEBUG is switched on.
if DEBUG:
    print(f"Batch size:       {B}\nSequence length:  {S}\nHidden dimension: {h}", end="\n\n")
    for _label, _tensor in (
        ("Layer inputs (X_l)", X_l),
        ("R_l", R_l),
        ("r_l", r_l),
        ("r_l_i", r_l_i),
        ("X_tilde_mask", X_tilde_mask),
        ("X_tilde", X_tilde),
        ("X_routed", X_routed),
        ("Layer outputs (X_l1)", X_l1),
    ):
        print_tensor(_label, _tensor)

## 2. Wrapping in a basic GPT like transformer architecture

In [41]:
class SwiGLU(nn.Module):
    """Gated feed-forward layer: `transform(linear(x) * silu(gate(x)))`.

    Args:
        dim: Size of the last input dimension (and of the output).
        hidden_dim: Width of the two parallel projections.
    """
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(dim, hidden_dim)     # value branch
        self.gate = nn.Linear(dim, hidden_dim)       # gating branch
        self.transform = nn.Linear(hidden_dim, dim)  # projection back to `dim`

    def forward(self, x):
        # The "Swi" in SwiGLU is the Swish / SiLU activation on the gate branch.
        # A plain sigmoid here would make this an ordinary GLU.
        return self.transform(self.linear(x) * F.silu(self.gate(x)))


# Smoke test: SwiGLU must map [B, S, n_embed] back to the same shape.
B = 2
S = 10
n_embed = 768
n_head = 12
X = torch.randn(B, S, n_embed, device=device)

swiglu = SwiGLU(dim=n_embed, hidden_dim=4*n_embed).to(device)
output = swiglu(X)

print("Output Shape:", output.shape)  # Should match the input shape [B, S, n_embed]
# Only drop the tensors this cell created. Also deleting `layer` (which this cell
# never defines) makes the cell raise NameError whenever it is run a second time.
del X, output
torch.cuda.empty_cache()

Output Shape: torch.Size([2, 10, 768])


In [65]:
torch.manual_seed(42)  # same function as torch.random.manual_seed; fixes the RNG for this cell


class NewGELU(nn.Module):
    """Tanh-approximated GELU.

    Several GELU variants exist; per the original note, this is the one used by OpenAI.
    """
    def forward(self, input):
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        inner = input + 0.044715 * torch.pow(input, 3.0)
        squashed = torch.tanh(math.sqrt(2.0 / math.pi) * inner)
        return 0.5 * input * (1.0 + squashed)

class SwiGLU(nn.Module):
    """Gated feed-forward layer: `transform(linear(x) * silu(gate(x)))`.

    Args:
        dim: Size of the last input dimension (and of the output).
        hidden_dim: Width of the two parallel projections.
    """
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(dim, hidden_dim)     # value branch
        self.gate = nn.Linear(dim, hidden_dim)       # gating branch
        self.transform = nn.Linear(hidden_dim, dim)  # projection back to `dim`

    def forward(self, x):
        # SwiGLU gates with Swish / SiLU; a plain sigmoid gate is an ordinary GLU.
        return self.transform(self.linear(x) * F.silu(self.gate(x)))



class MLP(nn.Module):
    """Feed-forward sub-layer with a 4x hidden expansion.

    Args:
        n_embed: Input and output width.
        swiglu: If True (default) use the SwiGLU layer, otherwise a
            Linear -> NewGELU -> Linear stack.
    """

    def __init__(self, n_embed: int, swiglu: bool=True):
        super().__init__()
        if swiglu:
            self.feed_forward = SwiGLU(n_embed, 4 * n_embed)
        else:
            # nn.Sequential takes the modules as separate positional arguments;
            # handing it a single list raises TypeError at construction time.
            self.feed_forward = nn.Sequential(
                nn.Linear(n_embed, 4 * n_embed),
                NewGELU(),
                nn.Linear(4 * n_embed, n_embed),
            )

    def forward(self, x):
        return self.feed_forward(x)

class DecoderBlock(nn.Module):

    def __init__(self, n_embed: int, h_head: int):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embed)
        self.attn = nn.MultiheadAttention(embed_dim=n_embed, num_heads=n_head)
        self.ln_2 = nn.LayerNorm(n_embed)
        self.mlp = MLP(n_embed)

    def forward(self, x):
        x_t = self.ln_1(x) # keep layer norm full precision

        with torch.autocast('cuda', torch.bfloat16):
            x = x + self.attn(x_t, x_t, x_t, need_weights=False)[0]
            x = x + self.mlp(self.ln_2(x))
        return x

# Smoke test: push one random batch through a single DecoderBlock.
B, S = 2, 10
n_embed, n_head = 768, 12
X = torch.randn(B, S, n_embed, device=device)

layer = DecoderBlock(n_embed, n_head).to(device)
# layer = nn.TransformerDecoderLayer(d_model=h, nhead=nhead, activation=F.glu, batch_first=True)

out = layer(X)
print_tensor("input", X)
print_tensor("output", out)

# Release everything this cell allocated.
del X, layer, out
torch.cuda.empty_cache()

input
--------------------
torch.Size([2, 10, 768])
tensor([[[ 0.1940,  2.1614, -0.1721,  ..., -0.6821,  0.7974, -0.8484],
         [ 0.8574,  0.4992,  0.1359,  ..., -1.0946, -1.1731, -0.4472],
         [-1.2405, -0.1784,  0.4220,  ...,  0.9749,  0.5732,  1.3790],
         ...,
         [ 0.5391,  0.4001,  1.0236,  ..., -1.0001, -1.1552, -0.6955],
         [ 1.7624, -1.2477,  1.2913,  ...,  0.0937, -0.7382,  0.2442],
         [-0.1452, -2.5195,  1.4329,  ...,  0.5431,  0.6378,  0.3826]],

        [[-0.4443,  1.8408,  1.2662,  ...,  0.6949,  0.3456, -0.6991],
         [ 0.0722,  0.7010,  2.5725,  ...,  0.8067,  0.1701,  0.7526],
         [-0.1311, -1.1856,  0.5139,  ...,  0.2042, -1.1267, -0.0607],
         ...,
         [-0.5234, -1.6324, -0.5524,  ..., -0.5802,  0.4249, -1.7961],
         [-2.5483,  0.1865,  1.3971,  ..., -0.1381, -0.5499, -1.7784],
         [ 0.9432,  0.6944, -0.2105,  ..., -1.5116, -0.4728, -1.5385]]],
       device='cuda:0')

output
--------------------
torch.Size(

In [22]:
class MoDDecoderLayer(nn.Module):
    """Mixture-of-Depths wrapper around a decoder block.

    A linear router scores every token. The `C` highest-scoring tokens of each
    sequence are run through the wrapped block; the rest pass through unchanged.
    The output has the same shape and token order as the input.

    Args:
        n_embed: Embedding width.
        n_head: Number of attention heads, forwarded to DecoderBlock.
        C: Capacity, i.e. how many tokens per sequence go through the block.
    """
    def __init__(self, n_embed: int, n_head: int, C: int):
        super().__init__()
        self.h = n_embed
        self.C = C
        self.router = nn.Linear(n_embed, 1)
        self.layer = DecoderBlock(n_embed, n_head)

    def forward(self, X) -> torch.Tensor:
        """Route the top-C tokens of X ([B, S, h]) through the block; returns [B, S, h]."""
        B, S, h = X.shape
        k = min(self.C, S)  # never ask topk for more tokens than the sequence holds

        R = self.router(X)  # Router activations, [B, S, 1]
        r, r_i = R.topk(k, dim=1, sorted=False)  # both [B, k, 1]

        # Keep the selected tokens in sequence order and permute their weights to match,
        # so each weight multiplies its own token. squeeze(-1) rather than squeeze()
        # so B == 1 or k == 1 does not collapse a dimension.
        r_i, order = r_i.squeeze(-1).sort(dim=1)  # [B, k]
        r = r.gather(1, order.unsqueeze(-1))      # [B, k, 1]

        token_idx = r_i.unsqueeze(-1).expand(-1, -1, h)  # [B, k, h]
        X_tilde = X.gather(1, token_idx)                 # selected tokens, [B, k, h]

        # The wrapped block already adds its own residual, so subtract the input to
        # recover the block's update before scaling it by the router weight.
        # The r factor is what lets backprop train the router.
        block_out = self.layer(X_tilde)
        X_routed = X_tilde + r * (block_out - X_tilde)

        # Put the routed tokens back where they came from; unrouted tokens are
        # carried over unchanged, and X itself is not modified.
        return X.scatter(1, token_idx, X_routed)


# Smoke test: push one random batch through a MoDDecoderLayer with capacity C.
B, S, C = 2, 10, 5
n_embed, n_head = 768, 12
X = torch.randn(B, S, n_embed, device=device)

layer = MoDDecoderLayer(n_embed, n_head, C).to(device)
# layer = nn.TransformerDecoderLayer(d_model=h, nhead=nhead, activation=F.glu, batch_first=True)

out = layer(X)
print_tensor("input", X)
print_tensor("output", out)

# `out` is left in scope; only the input and the module are dropped.
del X, layer
torch.cuda.empty_cache()

input
--------------------
torch.Size([2, 10, 768])
tensor([[[-1.6291,  1.3206,  1.5349,  ...,  0.6512,  0.2013, -0.8091],
         [-0.6031,  0.7363,  1.3001,  ..., -0.9398,  0.1413, -0.2003],
         [ 1.2939,  1.2201,  0.1389,  ...,  0.2965, -0.7913, -0.6119],
         ...,
         [-0.4256,  2.3303, -1.0826,  ..., -0.6241,  3.0210, -2.6685],
         [ 2.5794,  1.4523,  0.7480,  ...,  0.0672,  0.0753, -1.0983],
         [ 0.7906, -0.3601,  3.4060,  ..., -0.0181,  0.6491, -1.7173]],

        [[ 1.5160, -1.0064, -0.0648,  ...,  0.0188,  0.1854, -1.1085],
         [ 0.6835,  1.1110, -0.4112,  ...,  0.3753, -1.6994,  1.1767],
         [-0.7526,  0.2568,  0.6397,  ..., -3.7622,  0.1098,  0.4100],
         ...,
         [-1.9752, -0.6943,  0.3540,  ..., -0.1650,  1.4039,  0.8064],
         [-0.2743, -0.2561,  0.4554,  ...,  0.2385, -0.4198, -1.0945],
         [ 0.4451,  0.5967,  1.3188,  ...,  0.5855, -0.0664,  0.9142]]],
       device='cuda:0')

output
--------------------
torch.Size(

In [None]:
# TODO: find pytorch glu implementation https://pytorch.org/cppdocs/api/function_namespaceat_1aa10cf0aaff07f0a75dbfa51f168f563d.html#_CPPv4N2at3gluERKN2at6TensorE7int64_t
# -- looks like I need to search the codebase as pytorch uses a compiled function https://github.com/pytorch/pytorch/blob/7c23fed12c24ad7d635b0aa7af08449c9510375c/torch/nn/functional.py#L1514

In [None]:
# Scratch: print the built-in help text for nn.Module.
help(nn.Module)

In [None]:
import torch
import torch.nn as nn

class CustomModule(nn.Module):
    """Holds a learnable (1, L, D) table and looks rows up per batch element."""

    def __init__(self, L, D):
        super().__init__()
        # Table dimensions, kept for reference.
        self.L = L
        self.D = D
        # Randomly initialised, trainable table of shape (1, L, D).
        self.data = nn.Parameter(torch.randn(1, L, D))

    def forward(self, indices):
        """Return the (N, M, D) tensor of table rows picked by the (N, M) `indices`."""
        N, M = indices.shape  # also enforces that `indices` is 2-D

        # View the single table as one copy per batch element: (N, L, D).
        table = self.data.expand(N, -1, -1)

        # Column vector of batch positions; broadcasts against `indices` to (N, M).
        rows = torch.arange(N, device=indices.device).unsqueeze(1)

        return table[rows, indices]

# Example usage
# Example usage: N batches of M random row indices into an (L, D) table.
N, M, L, D = 3, 4, 5, 2
indices = torch.randint(0, L, (N, M))

model = CustomModule(L, D)
result = model(indices)

print("Result Shape:", result.shape)  # Should be (N, M, D)


Result Shape: torch.Size([3, 4, 2])


## 3. Training on fake data

## 4. Training on sample real data
TinyShakespeare + tiktoken tokenisation à la Karpathy

## 5. Evaluation / ablations

One idea for choosing C values is to train with no C constraint and look at the activation distributions.

In [None]:
# Merge the per-index one-hot vectors into a single multihot row per batch element
# (i.e. 2x2x1 rather than 2x2x4).
torch.zeros(r_l_i.squeeze().size(0), S).scatter(1, r_l_i.squeeze(), 1.)

In [None]:
# Scratch re-run of the routing maths with every intermediate printed.
# NOTE(review): depends on X_l, router_l, layer, B, S, C and h left behind by earlier
# cells; several later cells rebind or delete some of these names.
R_l = router_l(X_l)  # the routing head is named `router_l`; no visible cell defines `router`
r_l, r_l_i = R_l.topk(C, dim=1, sorted=True)
# Sort the picks by sequence position and permute the weights to match, so r_l lines up
# with the position-ordered rows that masked_select returns. squeeze(-1) keeps the
# batch and capacity dims even when either is 1.
r_l_i, _pos_order = r_l_i.squeeze(-1).sort(dim=1) # Make shape [Batch, C]
r_l = r_l.gather(1, _pos_order.unsqueeze(-1))

# Build the mask on X_l's device; a CPU-only mask cannot index a tensor on an accelerator.
X_tilde_mask = torch.zeros(r_l_i.size(0), S, device=X_l.device).scatter_(1, r_l_i, 1.).bool().unsqueeze(-1)
X_tilde = torch.masked_select(X_l, X_tilde_mask).view(B, C, h)

X_routed = r_l * layer(X_tilde) + X_tilde
X_unrouted = torch.masked_select(X_l, ~X_tilde_mask).view(B, S-C, h) # The remainder of tokens for skip  connection

# Write routed tokens back into their original slots so the token order is preserved.
X_l1 = X_l.masked_scatter(X_tilde_mask, X_routed)


print_tensor("X_l", X_l)
print_tensor("R_l", R_l)
print_tensor("r_l", r_l)
print_tensor("r_l_i", r_l_i)
print_tensor("X_tilde_mask", X_tilde_mask)
print_tensor("X_tilde", X_tilde)
print_tensor("X_routed", X_routed)

In [None]:
# Join routed and unrouted tokens along the sequence dimension (torch.concat is an alias of torch.cat).
torch.cat((X_routed, X_unrouted), dim=1)

In [None]:
# Toy example: turn one row of class indices into a 15-wide multihot row with scatter_.
labels = torch.tensor([[1, 4, 1, 0, 5, 2]])  # already carries the leading batch dim, shape [1, 6]
target = torch.zeros(labels.size(0), 15)
target.scatter_(1, labels, 1.)
labels.shape, target.shape

In [None]:
X_l.size()  # same torch.Size as X_l.shape

In [None]:
# Check that gathering with topk's indices reproduces topk's values.
scores = torch.randn(10, 6000)
idx = scores.topk(k=6000, dim=1, sorted=True)
out = scores.gather(1, idx.indices)
idx.values.shape

In [None]:
# NOTE(review): exploratory scratch from here to the end of the notebook. Several cells
# use names that no visible cell defines at top level (i, x, x_tilde, x_routed, R, minc,
# pbr, top_c), and earlier cells `del X`, so these cells depend on leftover kernel state.
r_l, r_l_i, X_l[0, :], X_tilde[0, :]

In [None]:
# NOTE(review): `.r_l_i` is an attribute lookup on X_l; presumably indexing by r_l_i was intended (confirm).
X_l.r_l_i[0]

In [None]:
# NOTE(review): adds one more trailing dim to the mask and views with (S + 1) - C rather than S - C (confirm intent).
torch.masked_select(X_l, ~X_tilde_mask.unsqueeze(-1)).view(B, (S + 1) - C, h)

In [None]:
# NOTE(review): 384 elements cannot be viewed as (2, 2, 32), which holds 128.
torch.randn(384).view(2, 2, 32)

In [None]:
X_l.shape

In [None]:
r_l.shape, X_tilde.shape

In [None]:
X_tilde * r_l.unsqueeze(-1)

In [None]:
# NOTE(review): incomplete expression; this cell is a syntax error as written.
r_l *

In [None]:
X_l.shape, i.shape, i

In [None]:
x_tilde.shape, x_tilde

In [None]:
torch.masked_select(X_l, x_tilde.unsqueeze(-1)).view(B, C, h)

In [None]:
x_tilde.type(torch.BoolTensor)

In [None]:
F.one_hot(i, S).shape, x_routed.shape

In [None]:
x.scatter(-1, i, x)
# NOTE(review): `X.clone` is referenced but not called (no parentheses).
X.clone

In [None]:
# NOTE(review): scatter is given only a dim and an index here, with no src / value argument.
X.scatter(-1, x_tilde)

In [None]:
R.shape, x.shape, minc.shape

In [None]:
torch.gt(R, minc.unsqueeze(-1))

In [None]:
i.shape

In [None]:
pbr.shape

In [None]:
X.shape

In [None]:
R.shape

In [None]:
R[-1] > pbr

In [None]:
1  - (C/S) # routing the top 75th percentile

In [None]:
R.topk(C, dim=1)

In [None]:
top_c.indices.squeeze()

In [None]:
X.topk(1, dim=2).values.min()

In [None]:
torch.concat((X,X), dim=1).shape