In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class NestedFFN(nn.Module):
    def __init__(self, d_model, d_ff, num_granularities=4):
        super(NestedFFN, self).__init__()

        # Initialize FFN layers
        self.num_granularities = num_granularities
        self.d_model = d_model
        self.d_ff = d_ff

        # Create weight matrices for W1 and W2 with the largest size
        self.W1 = nn.Parameter(torch.randn(d_ff, d_model))
        self.W2 = nn.Parameter(torch.randn(d_ff, d_model))

        # Calculate the sizes of each granularity
        self.granularity_sizes = [d_ff // (2 ** i) for i in range(num_granularities)]
        print(self.granularity_sizes)

    def forward(self, x, granularity_level):
        assert 0 <= granularity_level < self.num_granularities, "Invalid granularity level"

        # m_i Number of neuron selected
        m_i = self.granularity_sizes[granularity_level]

        # Perform the FFN operation with the selected subset of weights
        hidden = F.gelu(x @ self.W1[:m_i, :].T)
        output = hidden @ self.W2[:m_i, :]

        return output

class TransformerLayer(nn.Module):
    def __init__(self, d_model, num_heads, nested_ffn, granularity_level, dropout=0.1):
        super(TransformerLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.granularity_level = granularity_level
        self.nested_ffn = nested_ffn
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        src2, _ = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout(src2)
        src = self.layernorm1(src)

        src2 = self.nested_ffn(src, self.granularity_level)
        src = src + self.dropout(src2)
        src = self.layernorm2(src)

        return src

class Transformer(nn.Module):
    def __init__(self, d_model, num_layers, num_heads, nested_ffn, num_granularities=4, dropout=0.1):
        super(Transformer, self).__init__()
        self.models = [ ]

        # We Stack l Layers with the same granularity_level
        # Creating M1, M2, ... , Mg
        for id in range(num_granularities):

          self.models.append( nn.ModuleList([
            TransformerLayer(d_model, num_heads, nested_ffn, id, dropout)
            for _ in range(num_layers)
          ]))

        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, granularity_level = 0):
      # So granularity_level indicates the model M_i that we want to use
        for layer in self.models[granularity_level]:
            src = layer(src, src_mask, src_key_padding_mask)
        src = self.layernorm(src)
        return src

In [None]:
# Hyperparameters
d_model = 512
d_ff = 2048
num_granularities = 4
num_layers = 6
num_heads = 8
dropout = 0.1
epochs = 10
learning_rate = 0.001


batch_size = 32
sequence_length = 10
num_batches = 100
inputs = [torch.randn(sequence_length, batch_size, d_model) for _ in range(num_batches)]
targets = [torch.randn(sequence_length, batch_size, d_model) for _ in range(num_batches)]

# Initialize the model
nested_ffn = NestedFFN(d_model, d_ff, num_granularities)
model = Transformer(d_model, num_layers, num_heads, nested_ffn, num_granularities=num_granularities, dropout=dropout)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    total_loss = 0.0
    for batch_idx in range(num_batches):
        input_batch = inputs[batch_idx]
        target_batch = targets[batch_idx]

        # Zero the gradients
        optimizer.zero_grad()

        # Compute the loss for each granularity level and combine them
        losses = []
        for granularity_level in range(num_granularities):
            output = model(input_batch, granularity_level=granularity_level)
            loss = criterion(output, target_batch)
            losses.append(loss)

        # Combine the losses
        combined_loss = sum(losses) / num_granularities

        # Backpropagation
        combined_loss.backward()

        # Update parameters
        optimizer.step()

        # Accumulate loss for reporting
        total_loss += combined_loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/num_batches}")


[2048, 1024, 512, 256]
Epoch 1/10, Loss: 1.8654756820201874
Epoch 2/10, Loss: 1.6407215428352355
Epoch 3/10, Loss: 1.4758982849121094
Epoch 4/10, Loss: 1.3556608510017396
Epoch 5/10, Loss: 1.268385728597641
Epoch 6/10, Loss: 1.2044523561000824
Epoch 7/10, Loss: 1.1572381377220153
Epoch 8/10, Loss: 1.1219669103622436
Epoch 9/10, Loss: 1.095331188440323
Epoch 10/10, Loss: 1.0746522784233092


In [None]:
# Output Nodes
d_model = 512
# Input Nodes
d_ff = 2048
num_granularities = 4
#Number of transformers for each granularity_level ( l of the paper )
num_layers = 2
num_heads = 8
granularity_level = 4

In [None]:
nested_ffn = NestedFFN(d_model, d_ff, num_granularities)

transformer = Transformer(d_model, num_layers, num_heads, nested_ffn, granularity_level)

x = torch.randn(10, 32, d_model)  # sequence length 10, batch size 32, d_model 512

output = transformer(x)

print(output.shape)

[2048, 1024, 512, 256]
2048
2048
torch.Size([10, 32, 512])
