### CS 5223 - Grad Project
Tushaar Gangavarapu - TG352 and Lucas Molter - LM865

Professor: David Bindel

May 19th 2023

Cornell University

#### Introduction

The backbone of recent (and important) NLP developments has been the Attention matrix, first mentioned at the 2017 paper "Attention is All You Need". To explain it very briefly, this approach has allowed NLP models represent words as a weighted average 

### togepi

toeplitz-based generative pretraining

In [1]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


---

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchinfo import summary

import math
from prettytable import PrettyTable

from dataclasses import dataclass

In [3]:
def print_params(module, print_vals=True):
    params_table = PrettyTable(['module', 'num_params', 'requires_grad'])
    total_trainable_params = 0
    for name, param in module.named_parameters():
        params_table.add_row([name, param.numel(), param.requires_grad])
        if param.requires_grad:
            total_trainable_params = total_trainable_params + param.numel()
    print(params_table)
    if total_trainable_params > 1e6:
        print(f'total trainable params: {(total_trainable_params / 1e6):0.2f}M')
    else:
        print(f'total trainable params: {total_trainable_params}')

#### config

In [4]:
@dataclass
class TestTogepiConfig:
    # embedding
    vocab_size = 10  # includes special tokens ([PAD], [MASK], [CLS], [SEP]) 
    padding_idx = 0
    max_position_embeddings = 7  # includes proxy for padding token; max_length = max_position_embeddings - 1
    pad_position = 0
    num_token_types = 3  # includes padding token type
    pad_token_type = 0
    embedding_dim = 4
    embedding_dropout_proba = 0.1
    
    # attention
    causal_attn = True  # for generative pre-training
    num_attn_heads = 2
    attn_actn = 'gelu'
    sparse_dens = 0.3
    attn_dropout_proba = 0.1

test_config = TestTogepiConfig()
test_config.vocab_size

10

#### embedding

In [5]:
class Embedding(nn.Module):
    def __init__(self, config):
        super().__init__()

        self._padding_idx = config.padding_idx
        self._pad_position = config.pad_position
        self._pad_token_type = config.pad_token_type

        self.tok_emb = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.embedding_dim, padding_idx=config.padding_idx)
        self.pos_emb = nn.Embedding(num_embeddings=config.max_position_embeddings, embedding_dim=config.embedding_dim, padding_idx=config.pad_position)
        self.type_emb = nn.Embedding(num_embeddings=config.num_token_types, embedding_dim=config.embedding_dim, padding_idx=config.pad_token_type)

        nn.init.xavier_uniform_(self.tok_emb.weight.data)
        self.tok_emb.weight.data[self._padding_idx] = torch.zeros(config.embedding_dim)
        nn.init.xavier_uniform_(self.pos_emb.weight.data)
        self.tok_emb.weight.data[self._pad_position] = torch.zeros(config.embedding_dim)
        nn.init.xavier_uniform_(self.type_emb.weight.data)
        self.tok_emb.weight.data[self._pad_token_type] = torch.zeros(config.embedding_dim)

        self.layer_norm = nn.LayerNorm(normalized_shape=config.embedding_dim, eps=1e-12)
        self.dropout = nn.Dropout(p=config.embedding_dropout_proba)

    def forward(self, input_ids, token_type_ids=None, padding_mask=None):
        # input_ids: (batch_size, max_length)
        # padding_mask: (batch_size, max_length)
        max_length = input_ids.shape[1]
        # assert(max_length == self.pos_emb.num_embeddings - 1)
        if padding_mask is None:
            # 1: no pad, 0: pad
            padding_mask = torch.where(input_ids == self._padding_idx, 0, 1)

        # position_ids: (batch_size, max_length)
        # assert(self._pad_position == 0)
        position_ids = torch.arange(max_length, dtype=torch.long, device=input_ids.device) + 1  # assuming zero is reserved for pad position
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        position_ids = position_ids.masked_fill(padding_mask == 0, self._pad_position)

        # token_type_ids: (batch_size, max_length)
        if token_type_ids is None:
            # assert(self._pad_token_type == 0)
            token_type_ids = torch.ones_like(input_ids)  # assuming zero is reserved for pad position
        token_type_ids = token_type_ids.masked_fill(padding_mask == 0, self._pad_token_type)
        
        token_embeddings = self.tok_emb(input_ids)
        position_embeddings = self.pos_emb(position_ids)
        token_type_embeddings = self.type_emb(token_type_ids)
        
        return self.dropout(self.layer_norm(token_embeddings + position_embeddings + token_type_embeddings))

test_input_ids = torch.tensor([[1, 2, 3, 4, 0, 0], [3, 4, 5, 6, 7, 8]])
test_emb_obj = Embedding(test_config)
test_emb = test_emb_obj(test_input_ids)
print_params(test_emb_obj)
test_emb, test_emb.shape

+-------------------+------------+---------------+
|       module      | num_params | requires_grad |
+-------------------+------------+---------------+
|   tok_emb.weight  |     40     |      True     |
|   pos_emb.weight  |     28     |      True     |
|  type_emb.weight  |     12     |      True     |
| layer_norm.weight |     4      |      True     |
|  layer_norm.bias  |     4      |      True     |
+-------------------+------------+---------------+
total trainable params: 88


(tensor([[[-1.7468,  1.0223, -0.1759,  0.9004],
          [ 0.7343,  1.2877, -0.0000, -1.6012],
          [-1.7097,  1.3248, -0.1124,  0.4973],
          [ 1.1672,  1.0530, -1.0768, -1.1434],
          [ 0.8496,  0.2151, -1.8746,  0.8099],
          [ 0.8496,  0.2151, -1.8746,  0.8099]],
 
         [[-1.5472,  1.3251, -0.5070,  0.7291],
          [ 0.5841,  1.3160, -0.2221, -1.6780],
          [-1.1318, -0.9334,  1.6043,  0.4609],
          [ 0.5997,  1.4710, -1.4034, -0.6673],
          [-0.0000,  1.4685, -0.9210,  0.6706],
          [-1.6175,  0.0000,  0.6469, -0.3610]]], grad_fn=<MulBackward0>),
 torch.Size([2, 6, 4]))

#### attention

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        assert(config.embedding_dim % config.num_attn_heads == 0)

        self._num_heads = config.num_attn_heads
        self._per_head_dim = config.embedding_dim // config.num_attn_heads
        max_length = config.max_position_embeddings - 1

        self.wq = nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim)
        self.wk = nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim)
        self.wv = nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim)
        nn.init.xavier_normal_(self.wq.weight.data)
        nn.init.xavier_normal_(self.wk.weight.data)
        nn.init.xavier_normal_(self.wv.weight.data)

        self._causal = config.causal_attn
        if config.causal_attn:
            self.register_buffer('causal_attn_mask', torch.tril(torch.ones(max_length, max_length)).view(1, 1, max_length, max_length))

        self.wo = nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim)
        nn.init.xavier_normal_(self.wo.weight.data)

        self.layer_norm = nn.LayerNorm(normalized_shape=config.embedding_dim, eps=1e-12)
        self.dropout = nn.Dropout(p=config.attn_dropout_proba)
        self.softmax = nn.Softmax(dim=-1)
    
    def _extend_padding_mask(self, padding_mask, embeddings):
        # padding_mask: (batch_size, max_length)
        if padding_mask is None:
            padding_mask = torch.ones(embeddings.shape[0], embeddings.shape[1])

        extended_padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
        extended_padding_mask = extended_padding_mask.to(dtype=embeddings.dtype)  # amp/fp16 compatibility
        extended_padding_mask = (1 - extended_padding_mask) * -1e4
        return extended_padding_mask

    def forward(self, embeddings, padding_mask=None):
        batch_size = embeddings.shape[0]
        max_length = embeddings.shape[1]
        embedding_dim = embeddings.shape[2]

        # embeddings: (batch_size, max_length, embedding_dim)
        # attn_mask: 1 = non-pad, 0 = pad
        # projected_*: (batch_size, max_length, num_heads * per_head_dim)
        projected_query = self.wq(embeddings)
        projected_key = self.wk(embeddings)
        projected_value = self.wv(embeddings)

        sliced_projected_query = projected_query.view(batch_size, max_length, self._num_heads, self._per_head_dim).permute(0, 2, 1, 3)
        sliced_projected_key_tr = projected_query.view(batch_size, max_length, self._num_heads, self._per_head_dim).permute(0, 2, 3, 1)
        sliced_projected_value = projected_query.view(batch_size, max_length, self._num_heads, self._per_head_dim).permute(0, 2, 1, 3)

        # attn_mat: (batch_size, num_heads, max_length, max_length)
        # attn_mat: QK' / sqrt(d)
        # attn_mask: set [pad] tok attn values to -inf
        attn_mat = torch.matmul(sliced_projected_query, sliced_projected_key_tr) / np.power(embedding_dim, 0.5)
        attn_mat = attn_mat + self._extend_padding_mask(padding_mask=padding_mask, embeddings=embeddings)
        if self._causal:
            attn_mat.masked_fill_(self.causal_attn_mask[:, :, :max_length, :max_length] == 0, -1e4)
        # attn_probs: (batch_size, num_heads, max_length, max_length)
        attn_probs = self.softmax(attn_mat)
        attn_probs = self.dropout(attn_probs)

        # ctx_vectors: (batch_size, num_heads, max_length, per_head_dim)
        #    .permute: (batch_size, max_length, num_heads, per_head_dim)
        #    .view   : (batch_size, max_length, num_heads * per_head_dim)
        ctx_vectors = torch.matmul(attn_probs, sliced_projected_value).permute(0, 2, 1, 3).contiguous().view(batch_size, max_length, -1)
        attn_output = self.wo(ctx_vectors)
        attn_output = self.dropout(attn_output)

        return self.layer_norm(attn_output + embeddings), attn_probs

test_mha_obj = MultiHeadAttention(test_config)
test_padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])
test_mha_emb, test_mha_filters = test_mha_obj(test_emb, padding_mask=test_padding_mask)
print_params(test_mha_obj)
test_mha_emb, test_mha_emb.shape

+-------------------+------------+---------------+
|       module      | num_params | requires_grad |
+-------------------+------------+---------------+
|     wq.weight     |     16     |      True     |
|      wq.bias      |     4      |      True     |
|     wk.weight     |     16     |      True     |
|      wk.bias      |     4      |      True     |
|     wv.weight     |     16     |      True     |
|      wv.bias      |     4      |      True     |
|     wo.weight     |     16     |      True     |
|      wo.bias      |     4      |      True     |
| layer_norm.weight |     4      |      True     |
|  layer_norm.bias  |     4      |      True     |
+-------------------+------------+---------------+
total trainable params: 88


(tensor([[[-1.2223,  0.5883, -0.6726,  1.3067],
          [-0.2960,  1.3438,  0.3586, -1.4064],
          [-1.4819, -0.0752,  0.2372,  1.3199],
          [-0.1796,  1.5546, -0.1372, -1.2378],
          [-1.0302,  0.7668, -0.9444,  1.2078],
          [-1.0423,  0.9359, -0.9548,  1.0612]],
 
         [[-1.1850, -0.0741, -0.3184,  1.5776],
          [-0.9913,  1.0017,  0.9983, -1.0087],
          [ 1.1699, -1.5675,  0.4169, -0.0194],
          [-1.1786,  1.5858, -0.1229, -0.2843],
          [-1.2998,  1.1983, -0.6086,  0.7100],
          [ 0.0582, -1.6466,  0.9034,  0.6850]]],
        grad_fn=<NativeLayerNormBackward0>),
 torch.Size([2, 6, 4]))

In [7]:
class TogepiMultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        assert(config.embedding_dim % config.num_attn_heads == 0)

        self._num_heads = config.num_attn_heads
        self._per_head_dim = config.embedding_dim // config.num_attn_heads
        max_length = config.max_position_embeddings - 1  # one position reserved for pad position
        self._training_max_length = max_length

        # out_features: (num_heads * per_head_dim)
        self.pre_proj = nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim)
        self.pre_sparse_proj = nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim)
        nn.init.xavier_normal_(self.pre_proj.weight.data)
        nn.init.xavier_normal_(self.pre_sparse_proj.weight.data)

        # randomly initialize point-spread functions, one per head
        # psf: [tok_weight, [tok_-1_weights, tok_-2_weight, ...], [..., tok_+2_weight, tok_+1_weight]]
        self.toeplitz_psfs = nn.Parameter(torch.randn(self._num_heads, 2 * max_length - 1, self._per_head_dim))
        self.attn_actn = F.gelu if config.attn_actn == 'gelu' else F.relu
        self.post_conv_proj = nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim)
        nn.init.xavier_normal_(self.toeplitz_psfs.data)
        nn.init.xavier_normal_(self.post_conv_proj.weight.data)
        
        num_nonzero = int(max_length * max_length * config.sparse_dens)
        sparse_idxs = torch.randint(0, max_length, (num_nonzero, 2))
        sparse_vals = torch.randn(num_nonzero)
        self.sparse = nn.Parameter(torch.sparse_coo_tensor(sparse_idxs.t(), sparse_vals.abs(), size=(max_length, max_length)).to_dense())

        self._causal = config.causal_attn
        if config.causal_attn:
            # causal_psf_mask: ignore the tokens appearing ahead of the current token.
            self.register_buffer('causal_psf_mask', torch.tensor([1] + [1] * (max_length - 1) + [0] * (max_length - 1)).unsqueeze(0).unsqueeze(2))
            self.register_buffer('causal_sparse_mask', torch.tril(torch.ones(max_length, max_length)))
        
        self.layer_norm = nn.LayerNorm(normalized_shape=config.embedding_dim, eps=1e-12)
        self.dropout = nn.Dropout(p=config.attn_dropout_proba)

    def forward(self, embeddings, padding_mask=None, softmax_psf_weights=True):
        # embeddings: (batch_size, max_length, embedding_dim)
        # padding_mask: (batch_size, max_length)
        batch_size = embeddings.shape[0]
        max_length = embeddings.shape[1]
        embedding_dim = embeddings.shape[2]

        # expanded_padding_mask: (batch_size, max_length, 1)
        # 1: no pad, 0: pad
        expanded_padding_mask = None
        if padding_mask is not None:
            expanded_padding_mask = padding_mask.unsqueeze(2)

        # pre_proj_emb: (batch_size, max_length, num_heads * per_head_dim)
        pre_proj_emb = self.pre_proj(embeddings)
        if padding_mask is not None:
            pre_proj_emb.masked_fill_(expanded_padding_mask == 0, 0)
        # padded_embeddings: (batch_size, 2 * max_length - 1, embedding_dim)
        # F.pad: pad=(padding_left, padding_right, padding_top, padding_bottom)
        pre_proj_padded_embeddings = F.pad(pre_proj_emb, pad=(0, 0, 0, max_length - 1), mode='constant')
        # pre_proj_padded_embeddings: (batch_size, num_heads, 2 * max_length - 1, per_head_dim)
        pre_proj_padded_embeddings = pre_proj_padded_embeddings.view(batch_size, 2 * max_length - 1, self._num_heads, self._per_head_dim).permute(0, 2, 1, 3)

        psfs_weights = self.toeplitz_psfs.data
        if self._causal:
            if self._training_max_length == max_length:
                psfs_weights.masked_fill_(self.causal_psf_mask == 0, 0)
            else:
                # at inference time, the max_length changes per prompt
                causal_psf_mask = torch.tensor([1] + [1] * (max_length - 1) + [0] * (max_length - 1)).unsqueeze(0).unsqueeze(2)
                psfs_weights.masked_fill_(causal_psf_mask == 0, 0)
        if softmax_psf_weights:
            psfs_weights = F.softmax(psfs_weights, dim=1)
        psfs_fft = torch.fft.fftn(psfs_weights, dim=(1, 2))
        emb_fft = torch.fft.fftn(pre_proj_padded_embeddings, dim=(2, 3))
        # conv_output: (batch_size, num_heads, max_length, per_head_dim)
        conv_output = torch.real(torch.fft.ifftn(psfs_fft * emb_fft, dim=(2, 3))[:, :, :max_length, :])
        # conv_output: (batch_size, max_length, num_heads * per_head_dim)
        conv_output = self.attn_actn(conv_output).permute(0, 2, 1, 3).reshape(batch_size, max_length, -1)
        conv_emb = self.post_conv_proj(conv_output)
        
        
        sparse_data = self.sparse.data
        if self._causal:
            sparse_data.masked_fill_(self.causal_sparse_mask[:max_length, :max_length] == 0, 0)
        pre_sparse_emb = self.pre_sparse_proj(pre_proj_emb)
        if padding_mask is not None:
            pre_sparse_emb.masked_fill_(expanded_padding_mask == 0, 0)
        sparse_emb = torch.matmul(sparse_data, pre_sparse_emb)

        togepi_emb = self.dropout(conv_emb + sparse_emb)
        return self.layer_norm(togepi_emb + embeddings)
        
test_togepi_mha_obj = TogepiMultiHeadAttention(test_config)
test_padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])
test_togepi_mha_emb = test_togepi_mha_obj(test_emb, padding_mask=test_padding_mask)
print_params(test_togepi_mha_obj)
test_togepi_mha_emb, test_togepi_mha_emb.shape

+------------------------+------------+---------------+
|         module         | num_params | requires_grad |
+------------------------+------------+---------------+
|     toeplitz_psfs      |     44     |      True     |
|         sparse         |     36     |      True     |
|    pre_proj.weight     |     16     |      True     |
|     pre_proj.bias      |     4      |      True     |
| pre_sparse_proj.weight |     16     |      True     |
|  pre_sparse_proj.bias  |     4      |      True     |
| post_conv_proj.weight  |     16     |      True     |
|  post_conv_proj.bias   |     4      |      True     |
|   layer_norm.weight    |     4      |      True     |
|    layer_norm.bias     |     4      |      True     |
+------------------------+------------+---------------+
total trainable params: 148


(tensor([[[-1.6801,  0.7827,  0.1640,  0.7334],
          [ 0.3760,  1.4663, -0.8449, -0.9974],
          [-1.5970,  1.1562,  0.1317,  0.3091],
          [ 0.9828,  0.9007, -1.4156, -0.4679],
          [ 0.7972,  0.1035, -1.6645,  0.7637],
          [-0.1368,  0.4456, -1.5210,  1.2122]],
 
         [[-1.4156,  1.1995, -0.4087,  0.6248],
          [ 0.2022,  1.5489, -0.7595, -0.9916],
          [-0.6960, -1.0655,  1.5246,  0.2369],
          [ 0.3313,  1.1090, -1.6210,  0.1807],
          [ 0.1517,  1.2346, -1.5568,  0.1705],
          [-1.0827,  0.1097, -0.5957,  1.5687]]],
        grad_fn=<NativeLayerNormBackward0>),
 torch.Size([2, 6, 4]))

---
#### speed tests

##### *sparse vs. dense*

In [37]:
def create_sparse_mat(sparse_dens=0.3, max_length=512):
    num_nonzero = int(max_length * max_length * sparse_dens)
    sparse_idxs = torch.randint(0, max_length, (num_nonzero, 2))
    sparse_vals = torch.randn(num_nonzero)
    return torch.sparse_coo_tensor(sparse_idxs.t(), sparse_vals.abs(), size=(max_length, max_length))

def create_emb(batch_size=32, max_length=512, embedding_dim=768):
    return torch.randn(batch_size, max_length, embedding_dim)

def sparse_matmul(sparse_mat, emb):
    # sparse_mat: (max_length, max_length) 
    batch_size, max_length, embedding_dim = emb.shape
    return torch.sparse.mm(sparse_mat, emb.permute(1, 0, 2).reshape(max_length, -1)).view(max_length, batch_size, -1).permute(1, 0, 2)

def sparse_to_dense_matmul(sparse_mat, emb):
    # sparse_mat: (max_length, max_length)
    return torch.matmul(sparse_mat.to_dense(), emb)

def dense_matmul(dense_mat, emb):
    return torch.matmul(dense_mat, emb)

In [10]:
sparse_dens = 0.3
max_length = 512
batch_size = 32 
embedding_dim = 768

sparse_mat = create_sparse_mat(sparse_dens=sparse_dens, max_length=max_length)
dense_mat = sparse_mat.to_dense()
emb = create_emb(batch_size=batch_size, max_length=max_length, embedding_dim=embedding_dim)

In [11]:
%%timeit
sparse_matmul(sparse_mat, emb)

317 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit
sparse_to_dense_matmul(sparse_mat, emb)

25 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
%%timeit
dense_matmul(dense_mat, emb)

24.9 ms ± 215 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


##### *bert-attention vs. togepi-attention*

In [38]:
# https://aclanthology.org/2021.emnlp-main.831.pdf
@dataclass
class SpeedTestConfig:
    # embedding
    vocab_size = 30522
    padding_idx = 0
    max_position_embeddings = 1024 + 1 #L
    pad_position = 0
    num_token_types = 3
    pad_token_type = 0
    embedding_dim = 2048 #d
    embedding_dropout_proba = 0.1
    
    # attention
    causal_attn = True  # for generative pre-training
    num_attn_heads = 16
    attn_actn = 'gelu'
    sparse_dens = 0.3
    attn_dropout_proba = 0.1

    # training
    batch_size = 64
    
test_speed_config = SpeedTestConfig()
test_speed_config.vocab_size

30522

In [39]:
test_input_ids = torch.randint(low=0, high=test_speed_config.max_position_embeddings - 1, 
                               size=(test_speed_config.batch_size, test_speed_config.max_position_embeddings - 1))
test_input_ids.shape

torch.Size([64, 1024])

In [40]:
test_emb_obj = Embedding(test_speed_config)
test_emb = test_emb_obj(test_input_ids)
print_params(test_emb_obj)
test_emb.shape

+-------------------+------------+---------------+
|       module      | num_params | requires_grad |
+-------------------+------------+---------------+
|   tok_emb.weight  |  62509056  |      True     |
|   pos_emb.weight  |  2099200   |      True     |
|  type_emb.weight  |    6144    |      True     |
| layer_norm.weight |    2048    |      True     |
|  layer_norm.bias  |    2048    |      True     |
+-------------------+------------+---------------+
total trainable params: 64.62M


torch.Size([64, 1024, 2048])

In [41]:
test_mha_obj = MultiHeadAttention(test_speed_config)
print_params(test_mha_obj)

test_togepi_mha_obj = TogepiMultiHeadAttention(test_speed_config)
print_params(test_togepi_mha_obj)

+-------------------+------------+---------------+
|       module      | num_params | requires_grad |
+-------------------+------------+---------------+
|     wq.weight     |  4194304   |      True     |
|      wq.bias      |    2048    |      True     |
|     wk.weight     |  4194304   |      True     |
|      wk.bias      |    2048    |      True     |
|     wv.weight     |  4194304   |      True     |
|      wv.bias      |    2048    |      True     |
|     wo.weight     |  4194304   |      True     |
|      wo.bias      |    2048    |      True     |
| layer_norm.weight |    2048    |      True     |
|  layer_norm.bias  |    2048    |      True     |
+-------------------+------------+---------------+
total trainable params: 16.79M
+------------------------+------------+---------------+
|         module         | num_params | requires_grad |
+------------------------+------------+---------------+
|     toeplitz_psfs      |  4192256   |      True     |
|         sparse         |  104

In [42]:
%%timeit
test_mha_emb, test_mha_filters = test_mha_obj(test_emb)

1min 14s ± 6.16 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [43]:
%%timeit
test_togepi_mha_emb = test_togepi_mha_obj(test_emb)

16 s ± 419 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# https://jakevdp.github.io/PythonDataScienceHandbook/01.07-timing-and-profiling.html
!pip install memory_profiler
%load_ext memory_profiler

In [70]:
%memit
test_mha_emb, test_mha_filters = test_mha_obj(test_emb)

peak memory: 3453.45 MiB, increment: 0.00 MiB


In [69]:
%%memit
test_togepi_mha_emb = test_togepi_mha_obj(test_emb)

peak memory: 4036.32 MiB, increment: 1178.54 MiB
