In [1]:
!pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.7.1-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.1


In [2]:
!git clone https://github.com/Taeksu-Kim/LUNA_Linear_Unified_Nested_Attention.git

Cloning into 'LUNA_Linear_Unified_Nested_Attention'...
remote: Enumerating objects: 94, done.[K
remote: Counting objects: 100% (94/94), done.[K
remote: Compressing objects: 100% (91/91), done.[K
remote: Total 94 (delta 52), reused 7 (delta 2), pack-reused 0[K
Unpacking objects: 100% (94/94), done.


In [3]:
cd LUNA_Linear_Unified_Nested_Attention

/content/LUNA_Linear_Unified_Nested_Attention


In [4]:
from LUNA import *
from utils import Config

In [5]:
cd ..

/content


In [6]:
!git clone https://github.com/Taeksu-Kim/Transformer.git

Cloning into 'Transformer'...
remote: Enumerating objects: 88, done.[K
remote: Counting objects: 100% (88/88), done.[K
remote: Compressing objects: 100% (59/59), done.[K
remote: Total 88 (delta 36), reused 41 (delta 15), pack-reused 0[K
Unpacking objects: 100% (88/88), done.


In [7]:
cd  ./Transformer/PyTorch

/content/Transformer/PyTorch


In [8]:
from transformer import TransformerEncoder, Transformer

In [9]:
import random
from tqdm import tqdm

import torch
import torch.nn as nn

from torchinfo import summary

In [10]:
# Experiment 1 (encoder-only comparison): data and model hyper-parameters.
SEED = 42                  # fixed so the random batches built below are repeatable
batch_size = 8
embedding_dim = 768        # forwarded to the model config as 'd_model'
max_input_len = 512        # encoder sequence length ('max_enc_len')
max_dec_len = 32           # decoder sequence length
num_layer = 12             # used for both 'num_enc_layers' and 'num_dec_layers'
num_att_head = 12
feed_forward_dim = 1024
p_length = 128             # forwarded to the model config as 'p_length'
vocab_size = 32500
dynamic_projection = False
tie_key_value = False

# The notebook draws token ids with `random` and builds models with torch;
# seed both global RNGs so a Restart & Run All produces the same batches.
random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the first GPU when one is visible, otherwise run on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [11]:
# Random encoder batch: every row holds random non-zero token ids followed by
# 10 to 20 trailing zeros, so all rows have exactly `seq_len` entries.
seq_len = max_input_len

padded_rows = []
for row_idx in range(batch_size):
  n_tokens = random.randint(seq_len - 20, seq_len - 10)
  # Ids start at 1 because 0 is reserved for the padding appended below.
  token_ids = [random.randint(1, vocab_size - 1) for _ in range(n_tokens)]
  padded_rows.append(token_ids + [0] * (seq_len - n_tokens))

inputs = torch.tensor(padded_rows)

In [12]:
# Binary targets for the classifier: class 0 for the first half of the batch,
# class 1 for the rest. Integer arithmetic guarantees exactly `batch_size`
# labels even when `batch_size` is odd (int(batch_size/2) twice would drop one).
encoder_labels = torch.tensor(
    [0] * (batch_size // 2) + [1] * (batch_size - batch_size // 2)
)

In [13]:
# Random decoder batch built the same way as the encoder batch, but with rows of
# length `max_dec_len`: random non-zero ids followed by 10 to 20 zeros.
# NOTE(review): in the cells visible here this tensor is rebuilt before any model
# consumes it - confirm whether this cell is still needed.
seq_len = max_dec_len

label_rows = []
for row_idx in range(batch_size):
  n_tokens = random.randint(seq_len - 20, seq_len - 10)
  token_ids = [random.randint(1, vocab_size - 1) for _ in range(n_tokens)]
  label_rows.append(token_ids + [0] * (seq_len - n_tokens))

decoder_labels = torch.tensor(label_rows)

In [14]:
inputs.shape, encoder_labels.shape

(torch.Size([8, 512]), torch.Size([8]))

In [15]:
# Settings handed to both the LUNA and the vanilla models through one Config.
# Key spellings (e.g. 'drop_out_raito') are kept exactly as written; presumably
# they match the names the model code looks up - verify before renaming.
config_dict = dict(
    vocab_size=vocab_size,
    d_model=embedding_dim,
    max_enc_len=max_input_len,
    max_dec_len=max_dec_len,
    pad_id=0,
    use_decoder=True,
    init_std=0.02,
    norm_eps=1e-12,
    drop_out_raito=0.1,
    num_enc_layers=num_layer,
    num_dec_layers=num_layer,
    num_att_heads=num_att_head,
    feed_forward_dim=feed_forward_dim,
    p_length=p_length,
    p_drop_out_raito=0.3,
    dynamic_projection=dynamic_projection,
    tie_key_value=tie_key_value,
    decoder_only=False,
)

config = Config(config_dict)

In [16]:
luna_encoder = Luna_TransformerEncoder(config)

In [17]:
vanillar_encoder = TransformerEncoder(config)

In [18]:
summary(vanillar_encoder, input_data=[inputs])

Layer (type:depth-idx)                                       Output Shape              Param #
TransformerEncoder                                           [8, 512, 768]             --
├─Embedding: 1-1                                             [8, 512, 768]             24,960,000
├─ModuleList: 1-2                                            --                        --
│    └─TransformerEncoderLayer: 2-1                          [8, 512, 768]             --
│    │    └─AddNorm: 3-1                                     [8, 512, 768]             2,363,904
│    │    └─AddNorm: 3-2                                     [8, 512, 768]             1,576,192
│    └─TransformerEncoderLayer: 2-2                          [8, 512, 768]             --
│    │    └─AddNorm: 3-3                                     [8, 512, 768]             2,363,904
│    │    └─AddNorm: 3-4                                     [8, 512, 768]             1,576,192
│    └─TransformerEncoderLayer: 2-3                        

In [19]:
summary(luna_encoder, input_data=[inputs])

Layer (type:depth-idx)                                       Output Shape              Param #
Luna_TransformerEncoder                                      [8, 512, 768]             98,304
├─Embedding: 1-1                                             [8, 512, 768]             24,960,000
├─Dropout: 1-2                                               [8, 512, 768]             --
├─Dropout: 1-3                                               [8, 128, 768]             --
├─ModuleList: 1-4                                            --                        --
│    └─LunaTransformerEncoderLayer: 2-1                      [8, 512, 768]             --
│    │    └─LinearUnifiedNestedAttention: 3-1                [8, 128, 768]             4,134,144
│    │    └─LayerNorm: 3-2                                   [8, 128, 768]             1,536
│    │    └─LayerNorm: 3-3                                   [8, 512, 768]             1,536
│    │    └─PoswiseFeedForward: 3-4                          [8, 512, 

In [20]:
class simple_classfier(nn.Module):
    """Linear classification head applied to an encoder's first-position output.

    Args:
        encoder: module called as ``encoder(inputs)``. Element ``[0]`` of its
            return value is sliced with ``[:, 0]`` and must have
            ``config.d_model`` features in its last dimension.
        config: object exposing a ``d_model`` attribute.
        num_classes: number of output logits. Defaults to 2, the value that
            was previously hard-coded.
    """

    def __init__(self, encoder, config, num_classes=2):
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(config.d_model, num_classes)

    def forward(self, inputs, labels=None):
        """Return a 1-tuple ``(logits,)`` computed from position 0 of each sequence.

        ``labels`` is accepted but never read; it exists so the model can be
        invoked as ``model(inputs, labels)`` like the sequence-to-sequence models.
        """
        hidden = self.encoder(inputs)[0]
        # Only the first position of every sequence feeds the classifier.
        return (self.fc(hidden[:, 0]),)

In [21]:
vanillar_encoder_classifier = simple_classfier(vanillar_encoder, config)

In [22]:
luna_encoder_classifier = simple_classfier(luna_encoder, config)

In [23]:
def train(model, inputs, labels, num_steps=None):
  """Run repeated AdamW steps (lr=1e-5) on a single fixed batch.

  Args:
    model: module invoked as ``model(inputs, labels)``; element ``[0]`` of its
      return value is used as the logits.
    inputs: input tensor, moved to the notebook-level ``device``.
    labels: target tensor, moved to the notebook-level ``device``. It is passed
      to the model and also used as the cross-entropy target.
    num_steps: number of optimisation steps. When omitted, the notebook-level
      ``iter_steps`` is used, so existing three-argument calls behave as before.

  Returns:
    None. The model is updated in place and left in training mode.

  Note:
    The loss is plain ``nn.CrossEntropyLoss()``; no target id is ignored.
  """
  if num_steps is None:
    num_steps = iter_steps

  model.to(device)
  inputs = inputs.to(device)
  labels = labels.to(device)

  optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
  loss_fnc = nn.CrossEntropyLoss()

  # Mixed precision only makes sense on CUDA. The scaler protects fp16 gradients
  # from underflowing; with enabled=False both helpers are pass-throughs.
  use_amp = device.type == 'cuda'
  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

  model.train()
  for _ in tqdm(range(num_steps)):
    optimizer.zero_grad()

    with torch.cuda.amp.autocast(enabled=use_amp):
      logits = model(inputs, labels)[0]

      if logits.dim() == 2:
        # (batch, classes) logits against (batch,) targets.
        loss = loss_fnc(logits, labels)
      else:
        # Flatten every leading dimension so each position is one sample.
        num_class = logits.size(-1)
        loss = loss_fnc(logits.reshape(-1, num_class), labels.reshape(-1))

    # Backward, step and update go through the scaler, as autocast requires.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

In [24]:
# Number of optimisation steps per train() call. train() reads this
# module-level name, so it must be assigned before train() is invoked.
iter_steps = 300

In [25]:
train(vanillar_encoder_classifier, inputs, encoder_labels)

100%|██████████| 300/300 [01:06<00:00,  4.49it/s]


In [26]:
train(luna_encoder_classifier, inputs, encoder_labels)

100%|██████████| 300/300 [00:59<00:00,  5.04it/s]


In [28]:
# Experiment 2 (encoder-decoder comparison): data and model hyper-parameters.
# Only max_dec_len and p_length differ from the first configuration cell; the
# other values are repeated unchanged.
batch_size = 8
embedding_dim = 768
max_input_len = 512
max_dec_len = 128  # was 32 in the first configuration cell
num_layer = 12
num_att_head = 12
feed_forward_dim = 1024
p_length = 32  # was 128 in the first configuration cell
vocab_size = 32500
dynamic_projection = False
tie_key_value = False

# Prefer the first GPU when one is visible, otherwise run on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [29]:
# Rebuild the Config from the hyper-parameter names currently in scope.
# Key spellings (e.g. 'drop_out_raito') are kept exactly as written; presumably
# they match the names the model code looks up - verify before renaming.
config_dict = dict(
    vocab_size=vocab_size,
    d_model=embedding_dim,
    max_enc_len=max_input_len,
    max_dec_len=max_dec_len,
    pad_id=0,
    use_decoder=True,
    init_std=0.02,
    norm_eps=1e-12,
    drop_out_raito=0.1,
    num_enc_layers=num_layer,
    num_dec_layers=num_layer,
    num_att_heads=num_att_head,
    feed_forward_dim=feed_forward_dim,
    p_length=p_length,
    p_drop_out_raito=0.3,
    dynamic_projection=dynamic_projection,
    tie_key_value=tie_key_value,
    decoder_only=False,
)

config = Config(config_dict)

In [30]:
# Fresh random encoder batch for the encoder-decoder run: random non-zero token
# ids followed by 10 to 20 trailing zeros, `seq_len` entries per row.
seq_len = max_input_len

padded_rows = []
for row_idx in range(batch_size):
  n_tokens = random.randint(seq_len - 20, seq_len - 10)
  # Ids start at 1 because 0 is reserved for the padding appended below.
  token_ids = [random.randint(1, vocab_size - 1) for _ in range(n_tokens)]
  padded_rows.append(token_ids + [0] * (seq_len - n_tokens))

inputs = torch.tensor(padded_rows)

In [31]:
# Random decoder batch with rows of length `max_dec_len`: random non-zero token
# ids followed by 10 to 20 trailing zeros.
seq_len = max_dec_len

label_rows = []
for row_idx in range(batch_size):
  n_tokens = random.randint(seq_len - 20, seq_len - 10)
  token_ids = [random.randint(1, vocab_size - 1) for _ in range(n_tokens)]
  label_rows.append(token_ids + [0] * (seq_len - n_tokens))

decoder_labels = torch.tensor(label_rows)

In [32]:
inputs.shape, decoder_labels.shape

(torch.Size([8, 512]), torch.Size([8, 128]))

In [33]:
luna_transformer = Luna_Transformer(config)

In [34]:
vanillar_transformer = Transformer(config)

In [35]:
summary(vanillar_transformer, input_data=[inputs,decoder_labels])

Layer (type:depth-idx)                                            Output Shape              Param #
Transformer                                                       [8, 128, 32500]           --
├─TransformerEncoder: 1-1                                         [8, 512, 768]             --
│    └─Embedding: 2-1                                             [8, 512, 768]             24,960,000
│    └─ModuleList: 2-2                                            --                        --
│    │    └─TransformerEncoderLayer: 3-1                          [8, 512, 768]             3,940,096
│    │    └─TransformerEncoderLayer: 3-2                          [8, 512, 768]             3,940,096
│    │    └─TransformerEncoderLayer: 3-3                          [8, 512, 768]             3,940,096
│    │    └─TransformerEncoderLayer: 3-4                          [8, 512, 768]             3,940,096
│    │    └─TransformerEncoderLayer: 3-5                          [8, 512, 768]             3,940,096
│ 

In [36]:
summary(luna_transformer, input_data=[inputs,decoder_labels])

Layer (type:depth-idx)                                            Output Shape              Param #
Luna_Transformer                                                  [8, 128, 32500]           --
├─Luna_TransformerEncoder: 1-1                                    [8, 512, 768]             24,576
│    └─Embedding: 2-1                                             [8, 512, 768]             24,960,000
│    └─Dropout: 2-2                                               [8, 512, 768]             --
│    └─Dropout: 2-3                                               [8, 32, 768]              --
│    └─ModuleList: 2-4                                            --                        --
│    │    └─LunaTransformerEncoderLayer: 3-1                      [8, 512, 768]             5,713,408
│    │    └─LunaTransformerEncoderLayer: 3-2                      [8, 512, 768]             5,713,408
│    │    └─LunaTransformerEncoderLayer: 3-3                      [8, 512, 768]             5,713,408
│    │    └─

In [37]:
iter_steps = 300

In [38]:
train(vanillar_transformer, inputs, decoder_labels)

100%|██████████| 300/300 [02:02<00:00,  2.46it/s]


In [39]:
train(luna_transformer, inputs, decoder_labels)

100%|██████████| 300/300 [02:32<00:00,  1.97it/s]
