In [2]:
import torch

from colt5_attention import (
    ConditionalRoutedFeedForward,
    ConditionalRoutedAttention,
    ConditionalRoutedTransformerBlock
)


# Single

In [None]:
# mock input, say it is 32768 length

tokens = torch.randn(2, 32768, 512)
mask = torch.ones(2, 32768).bool()  # can handle variable lengthed sequences

# feedforward

ff = ConditionalRoutedFeedForward(
    dim = 512,
    light_ff_mult = 0.5,      # hidden dimension ratio of light branch
    heavy_ff_mult = 4,        # hidden dimension ratio of heavy branch
    num_heavy_tokens = 1024   # heavy branch receives only 1024 routed tokens of 32768
)

ff_out = ff(tokens, mask = mask)  # (2, 32768, 512) - light and heavy branch summed

# attention

attn = ConditionalRoutedAttention(
    dim = 512,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 128,   # local attention receptive field for light
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 1024, # heavy branch receives only 1024 routed tokens of 32768
    num_heavy_tokens_kv = 1024 # heavy branch receives only 1024 routed tokens of 32768
)

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    light_dim_head = 64,
    light_heads = 8,
    light_window_size = 128,
    heavy_dim_head = 64,
    heavy_heads = 8,
    light_ff_mult = 0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens = 1024,
    num_heavy_attn_tokens_q = 1024,
    num_heavy_attn_tokens_kv = 1024
)

In [2]:
#time the for loop
import time
start = time.time()
for i in range(100):
    block_out = block(tokens, mask = mask) # (2, 32768, 512)
end = time.time()

print("Time taken for 100 iterations: ", end - start)

Time taken for 100 iterations:  117.16679191589355


input 10x smaller, loop 10x more

In [3]:
# mock input, say it is 32768 length

tokens = torch.randn(2, 3276, 512)
mask = torch.ones(2, 3276).bool()  # can handle variable lengthed sequences

# feedforward

ff = ConditionalRoutedFeedForward(
    dim = 512,
    light_ff_mult = 0.5,      # hidden dimension ratio of light branch
    heavy_ff_mult = 4,        # hidden dimension ratio of heavy branch
    num_heavy_tokens = 1024   # heavy branch receives only 1024 routed tokens of 32768
)

ff_out = ff(tokens, mask = mask)  # (2, 32768, 512) - light and heavy branch summed

# attention

attn = ConditionalRoutedAttention(
    dim = 512,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 128,   # local attention receptive field for light
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 102, # heavy branch receives only 1024 routed tokens of 32768
    num_heavy_tokens_kv = 102 # heavy branch receives only 1024 routed tokens of 32768
)

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    light_dim_head = 64,
    light_heads = 8,
    light_window_size = 128,
    heavy_dim_head = 64,
    heavy_heads = 8,
    light_ff_mult = 0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens = 102,
    num_heavy_attn_tokens_q = 102,
    num_heavy_attn_tokens_kv = 102
)

In [4]:
#time the for loop
import time
start = time.time()
for i in range(1000):   
    block_out = block(tokens, mask = mask) # (2, 3276, 512)
end = time.time()

print("Time taken for 1000 iterations: ", end - start)

Time taken for 1000 iterations:  116.82632446289062


In [3]:
# mock input, say it is 32768 length

tokens = torch.randn(2, 327, 512)
mask = torch.ones(2, 327).bool()  # can handle variable lengthed sequences

# feedforward

ff = ConditionalRoutedFeedForward(
    dim = 512,
    light_ff_mult = 0.5,      # hidden dimension ratio of light branch
    heavy_ff_mult = 4,        # hidden dimension ratio of heavy branch
    num_heavy_tokens = 32   # heavy branch receives only 1024 routed tokens of 32768
)

ff_out = ff(tokens, mask = mask)  # (2, 32768, 512) - light and heavy branch summed

# attention

attn = ConditionalRoutedAttention(
    dim = 512,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 128,   # local attention receptive field for light
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 32, # heavy branch receives only 1024 routed tokens of 32768
    num_heavy_tokens_kv = 32 # heavy branch receives only 1024 routed tokens of 32768
)

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    light_dim_head = 64,
    light_heads = 8,
    light_window_size = 128,
    heavy_dim_head = 64,
    heavy_heads = 8,
    light_ff_mult = 0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens = 32,
    num_heavy_attn_tokens_q = 32,
    num_heavy_attn_tokens_kv = 32
)

In [4]:
#time the for loop
import time
start = time.time()
for i in range(10000):   
    block_out = block(tokens, mask = mask) # (2, 327, 512)
end = time.time()

print("Time taken for 1000 iterations: ", end - start)

Time taken for 1000 iterations:  132.3905553817749


In [6]:
# mock input, say it is 32768 length

tokens = torch.randn(2, 109, 512)
mask = torch.ones(2, 109).bool()  # can handle variable lengthed sequences

# feedforward

ff = ConditionalRoutedFeedForward(
    dim = 512,
    light_ff_mult = 0.5,      # hidden dimension ratio of light branch
    heavy_ff_mult = 4,        # hidden dimension ratio of heavy branch
    num_heavy_tokens = 10   # heavy branch receives only 1024 routed tokens of 32768
)

ff_out = ff(tokens, mask = mask)  # (2, 32768, 512) - light and heavy branch summed

# attention

attn = ConditionalRoutedAttention(
    dim = 512,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 128,   # local attention receptive field for light
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 10, # heavy branch receives only 1024 routed tokens of 32768
    num_heavy_tokens_kv = 10 # heavy branch receives only 1024 routed tokens of 32768
)

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    light_dim_head = 64,
    light_heads = 8,
    light_window_size = 128,
    heavy_dim_head = 64,
    heavy_heads = 8,
    light_ff_mult = 0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens = 10,
    num_heavy_attn_tokens_q = 10,
    num_heavy_attn_tokens_kv = 10
)

In [7]:
#time the for loop
import time
start = time.time()
for i in range(30000):   
    block_out = block(tokens, mask = mask) # (2, 327, 512)
end = time.time()

print("Time taken for 1000 iterations: ", end - start)

Time taken for 1000 iterations:  247.19424271583557


# Batch

In [6]:
# mock input, say it is 32768 length

tokens = torch.randn(20, 3276, 512)
mask = torch.ones(20, 3276).bool()  # can handle variable lengthed sequences

# feedforward

ff = ConditionalRoutedFeedForward(
    dim = 512,
    light_ff_mult = 0.5,      # hidden dimension ratio of light branch
    heavy_ff_mult = 4,        # hidden dimension ratio of heavy branch
    num_heavy_tokens = 102   # heavy branch receives only 1024 routed tokens of 32768
)

ff_out = ff(tokens, mask = mask)  # (2, 32768, 512) - light and heavy branch summed

# attention

attn = ConditionalRoutedAttention(
    dim = 512,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 128,   # local attention receptive field for light
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 102, # heavy branch receives only 1024 routed tokens of 32768
    num_heavy_tokens_kv = 102 # heavy branch receives only 1024 routed tokens of 32768
)

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    light_dim_head = 64,
    light_heads = 8,
    light_window_size = 128,
    heavy_dim_head = 64,
    heavy_heads = 8,
    light_ff_mult = 0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens = 102,
    num_heavy_attn_tokens_q = 102,
    num_heavy_attn_tokens_kv = 102
)

In [7]:
#time the for loop
import time
start = time.time()
for i in range(10):
    block_out = block(tokens, mask = mask) # (2, 32768, 512)
end = time.time()

print("Time taken for 100 iterations: ", end - start)

Time taken for 100 iterations:  12.182886838912964


In [10]:
# mock input, say it is 32768 length

tokens = torch.randn(200, 327, 512)
mask = torch.ones(200, 327).bool()  # can handle variable lengthed sequences

# feedforward

ff = ConditionalRoutedFeedForward(
    dim = 512,
    light_ff_mult = 0.5,      # hidden dimension ratio of light branch
    heavy_ff_mult = 4,        # hidden dimension ratio of heavy branch
    num_heavy_tokens = 10   # heavy branch receives only 1024 routed tokens of 32768
)

ff_out = ff(tokens, mask = mask)  # (2, 32768, 512) - light and heavy branch summed

# attention

attn = ConditionalRoutedAttention(
    dim = 512,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 128,   # local attention receptive field for light
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 10, # heavy branch receives only 1024 routed tokens of 32768
    num_heavy_tokens_kv = 10 # heavy branch receives only 1024 routed tokens of 32768
)

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    light_dim_head = 64,
    light_heads = 8,
    light_window_size = 128,
    heavy_dim_head = 64,
    heavy_heads = 8,
    light_ff_mult = 0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens = 10,
    num_heavy_attn_tokens_q = 10,
    num_heavy_attn_tokens_kv = 10
)

In [11]:
#time the for loop
import time
start = time.time()
for i in range(10):
    block_out = block(tokens, mask = mask) # (2, 32768, 512)
end = time.time()

print("Time taken for 100 iterations: ", end - start)

Time taken for 100 iterations:  13.222659587860107


In [12]:
# mock input, say it is 32768 length

tokens = torch.randn(2000, 32, 512)
mask = torch.ones(2000, 32).bool()  # can handle variable lengthed sequences

# feedforward

ff = ConditionalRoutedFeedForward(
    dim = 512,
    light_ff_mult = 0.5,      # hidden dimension ratio of light branch
    heavy_ff_mult = 4,        # hidden dimension ratio of heavy branch
    num_heavy_tokens = 3  # heavy branch receives only 1024 routed tokens of 32768
)

ff_out = ff(tokens, mask = mask)  # (2, 32768, 512) - light and heavy branch summed

# attention

attn = ConditionalRoutedAttention(
    dim = 512,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 128,   # local attention receptive field for light
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 3, # heavy branch receives only 1024 routed tokens of 32768
    num_heavy_tokens_kv = 3 # heavy branch receives only 1024 routed tokens of 32768
)

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    light_dim_head = 64,
    light_heads = 8,
    light_window_size = 128,
    heavy_dim_head = 64,
    heavy_heads = 8,
    light_ff_mult = 0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens = 3,
    num_heavy_attn_tokens_q = 3,
    num_heavy_attn_tokens_kv = 3
)

In [13]:
#time the for loop
import time
start = time.time()
for i in range(10):
    block_out = block(tokens, mask = mask) # (2, 32768, 512)
end = time.time()

print("Time taken for 100 iterations: ", end - start)

Time taken for 100 iterations:  18.901021242141724
