<a href="https://colab.research.google.com/github/Quinnybob/sparse-depth-transformer/blob/main/sparseDepthTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

# -------------------------------
# Semantic Scoring Module
# -------------------------------
class TokenSemanticScorer(nn.Module):
    """Map each token embedding to a single score between 0 and 1.

    The score is a learned linear projection of the embedding passed
    through a sigmoid.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # One output unit: a single logit per token.
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Score every vector along the last axis of ``x``.

        ``x`` has ``embed_dim`` features on its last axis; the result has
        the same leading shape with that axis removed.
        """
        logits = self.fc(x)
        probabilities = torch.sigmoid(logits)
        # Drop the size-1 feature axis left behind by the projection.
        return probabilities.squeeze(-1)

# -------------------------------
# Mini Transformer Block
# -------------------------------
class MiniTransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention plus a 4x-wide MLP.

    Both sub-layers are wrapped in a residual connection, and LayerNorm is
    applied to the input of each sub-layer rather than to its output.

    Args:
        embed_dim: Size of the feature axis of the input and output.
        num_heads: Number of attention heads. ``embed_dim`` must be
            divisible by it (enforced by ``nn.MultiheadAttention``).
            Defaults to 2.
    """

    def __init__(self, embed_dim, num_heads=2):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        # batch_first=True: inputs are laid out as (batch, seq, embed_dim).
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads=num_heads, batch_first=True
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, seq, embed_dim)``.

        Returns a new tensor of the same shape; ``x`` is not modified.
        """
        normed = self.ln1(x)
        # The attention weights are never used, so skip computing them;
        # this lets PyTorch take its optimized attention path.
        attn_output, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_output
        return x + self.ff(self.ln2(x))

# -------------------------------
# Sparse Depth Transformer with Hard Skipping
# -------------------------------
class SparseDepthTransformer(nn.Module):
    """Transformer whose tokens stop being updated at a score-dependent depth.

    Each token receives one semantic score in [0, 1], computed once from its
    input embedding. Layer ``i`` of ``N`` only updates tokens whose score is
    above ``i / N``; the remaining tokens keep their previous hidden state.

    NOTE: every layer is still evaluated on the full sequence and the result
    is merged with a mask, so no compute is saved and masked-out tokens are
    still attended to by the tokens that are updated.

    NOTE: the scores are only used inside a ``>`` comparison, so the scorer's
    parameters receive no gradient from the returned logits.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_dim: Hidden size.
        num_layers: Number of transformer blocks.
        max_len: Longest sequence the learned positional table supports.
            Defaults to 512.

    Attributes:
        latest_layer_usage: ``None`` until the first forward pass, then a
            ``(batch, seq)`` tensor counting how many layers updated each
            token during the most recent call.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, max_len=512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim))
        self.semantic_scorer = TokenSemanticScorer(embed_dim)
        self.layers = nn.ModuleList(
            [MiniTransformerBlock(embed_dim) for _ in range(num_layers)]
        )
        self.ln_final = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
        self.latest_layer_usage = None

    def forward(self, x):
        """Return logits of shape ``(batch, seq, vocab_size)``.

        Args:
            x: Integer token ids of shape ``(batch, seq)``.

        Raises:
            ValueError: If ``seq`` exceeds the positional table length.
        """
        _, seq_len = x.size()
        max_len = self.pos_embed.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum supported "
                f"length {max_len}"
            )

        x = self.embed(x) + self.pos_embed[:, :seq_len, :]
        semantic_scores = self.semantic_scorer(x)  # shape: (B, T)
        layer_usage = torch.zeros_like(semantic_scores)

        num_layers = len(self.layers)
        for i, layer in enumerate(self.layers):
            threshold = i / num_layers
            keep_mask = semantic_scores > threshold  # shape: (B, T)
            layer_usage += keep_mask.to(layer_usage.dtype)

            # Scores are fixed and thresholds only grow, so once no token
            # qualifies none will qualify for any deeper layer either.
            if not keep_mask.any():
                break

            # The block does not modify its input, so no defensive copy is
            # needed. The (B, T, 1) mask broadcasts over the feature axis.
            x = torch.where(keep_mask.unsqueeze(-1), layer(x), x)

        self.latest_layer_usage = layer_usage

        x = self.ln_final(x)
        return self.head(x)

# -------------------------------
# Baseline Transformer (No Sparsity)
# -------------------------------
class BaselineTransformer(nn.Module):
    """Dense reference model: every token passes through every layer.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_dim: Hidden size.
        num_layers: Number of transformer blocks.
        max_len: Longest sequence the learned positional table supports.
            Defaults to 512.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, max_len=512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim))
        self.layers = nn.ModuleList(
            [MiniTransformerBlock(embed_dim) for _ in range(num_layers)]
        )
        self.ln_final = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return logits of shape ``(batch, seq, vocab_size)``.

        Args:
            x: Integer token ids of shape ``(batch, seq)``.

        Raises:
            ValueError: If ``seq`` exceeds the positional table length.
        """
        _, seq_len = x.size()
        max_len = self.pos_embed.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum supported "
                f"length {max_len}"
            )

        x = self.embed(x) + self.pos_embed[:, :seq_len, :]
        for layer in self.layers:
            x = layer(x)
        x = self.ln_final(x)
        return self.head(x)

# -------------------------------
# Benchmark Function
# -------------------------------
def benchmark_model(model, tokens, use_cuda=False, warmup_runs=1):
    """Time a single no-grad forward pass of ``model`` on ``tokens``.

    Args:
        model: Module called as ``model(tokens)``. When ``use_cuda`` is true
            it is moved to the GPU in place.
        tokens: Input tensor passed to the model.
        use_cuda: Run on the current CUDA device and record peak memory.
            Otherwise run on CPU with intra-op parallelism limited to one
            thread for the duration of the call.
        warmup_runs: Untimed forward passes executed first so that one-off
            start-up costs are excluded from the measurement. Defaults to 1;
            pass 0 to time the very first call.

    Returns:
        A dict with keys ``"time_sec"`` (float), ``"output_shape"``,
        ``"max_memory_MB"`` (float, or ``None`` on CPU) and
        ``"avg_layers_per_token"`` (mean of ``model.latest_layer_usage`` if
        the model exposes it, else the string ``"N/A"``).
    """
    previous_threads = None
    if use_cuda:
        model = model.cuda()
        tokens = tokens.cuda()
    else:
        previous_threads = torch.get_num_threads()
        torch.set_num_threads(1)

    try:
        with torch.no_grad():
            for _ in range(warmup_runs):
                model(tokens)

            if use_cuda:
                # CUDA kernels are launched asynchronously: drain pending
                # work so it is not billed to the timed region, and start
                # the peak-memory counter from the current allocation.
                torch.cuda.synchronize()
                torch.cuda.reset_peak_memory_stats()

            start = time.perf_counter()
            output = model(tokens)
            if use_cuda:
                # Wait for the kernels to finish, not just to be queued.
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - start
    finally:
        # Do not leave the process-wide thread setting changed.
        if previous_threads is not None:
            torch.set_num_threads(previous_threads)

    mem = torch.cuda.max_memory_allocated() / 1e6 if use_cuda else None
    layer_usage = getattr(model, "latest_layer_usage", None)
    return {
        "time_sec": elapsed,
        "output_shape": output.shape,
        "max_memory_MB": mem,
        "avg_layers_per_token": layer_usage.mean().item() if layer_usage is not None else "N/A"
    }

# -------------------------------
# Main Test Script
# -------------------------------
if __name__ == "__main__":
    # Toy configuration shared by both models.
    vocab_size, embed_dim, num_layers = 5000, 64, 6
    use_cuda = torch.cuda.is_available()

    # Random token ids: a batch of 2 sequences, 20 tokens each.
    tokens = torch.randint(low=0, high=vocab_size, size=(2, 20))

    # Depth-sparse model first ...
    print("=== Sparse Depth Transformer ===")
    sparse_model = SparseDepthTransformer(vocab_size, embed_dim, num_layers)
    sparse_stats = benchmark_model(sparse_model, tokens, use_cuda=use_cuda)
    print(sparse_stats)

    # ... then the dense baseline on the very same input.
    print("\n=== Baseline Transformer ===")
    baseline_model = BaselineTransformer(vocab_size, embed_dim, num_layers)
    baseline_stats = benchmark_model(baseline_model, tokens, use_cuda=use_cuda)
    print(baseline_stats)


=== Sparse Depth Transformer ===
{'time_sec': 0.004933357238769531, 'output_shape': torch.Size([2, 20, 5000]), 'max_memory_MB': 23.094784, 'avg_layers_per_token': 3.575000047683716}

=== Baseline Transformer ===
{'time_sec': 0.003687143325805664, 'output_shape': torch.Size([2, 20, 5000]), 'max_memory_MB': 27.005952, 'avg_layers_per_token': 'N/A'}
