In [1]:
from datetime import datetime

import torch
import torch.nn as nn

import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.models import BaselineEncoder, Performer, PerformerAttention, InfinityFormerAttention

In [2]:
device = 'cuda'

hidden_size = 32
num_heads = 2
num_layers = 1
dim_feedforward = 4 * 32

bs = 4
L = 2 ** 13

x = torch.randn(bs, L, hidden_size, device=device)
att = torch.randn(bs, L, device=device) > 0.5

In [3]:
attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
model = BaselineEncoder(attention, hidden_size, num_layers, dim_feedforward).to(device)

start = datetime.now()
print(model(x, att)[0].shape)
print(datetime.now() - start)

torch.Size([4, 8192, 32])
0:00:00.720032


In [4]:
attention = PerformerAttention(hidden_dim=hidden_size, num_heads=num_heads)
model = BaselineEncoder(attention, hidden_size, num_layers, dim_feedforward).to(device)

start = datetime.now()
print(model(x, att)[0].shape)
print(datetime.now() - start)

torch.Size([4, 8192, 32])
0:00:00.029000


In [7]:
config = {
    'head_size': hidden_size  // num_heads, 
    'length': L, 
    'target_len': 70, 
    'attn_func': 'softmax', 
    'attn_num_basis': 100, 
    'attn_drop': 0.1, 
    'infinite_memory': True, 
    'n_layers': num_layers, 
    'n_heads': num_heads, 
    'd_model': hidden_size, 
    'mask': True, 
    'mask_type': 'cnn', 
    'kl_regularizer': True, 
    'sigma_0': 0, 
    'mu_0': 0,
    'share_mask': True,
    'device': 'cpu'
}

attention = InfinityFormerAttention(**config)
model = BaselineEncoder(attention, hidden_size, num_layers, dim_feedforward).to(device)

start = datetime.now()
print(model(x, att)[0].shape)
print(datetime.now() - start)

torch.Size([4, 8192, 32])
0:00:00.157002


In [8]:
config = {
    'dim': hidden_size,
    'depth': 4,
    'heads': num_heads,
    'dim_head': hidden_size
}

start = datetime.now()
Performer(**config).cuda()(x).shape
print(datetime.now() - start)

0:00:00.786000
