In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.transformer as m
import matplotlib.pyplot as plt
import numpy as np

In [2]:
device = torch.device('cuda:0')

In [3]:
ff = m.FeedForward(512, 2048)
ff.to(device)
print(ff(torch.rand(50, 4, 512).to(device)).size())
print(ff.state_dict().keys())

torch.Size([50, 4, 512])
odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])


In [4]:
pe = m.PositionalEncoder(5000, 512)
pe.to(device)
print(pe(torch.rand((50, 4, 512)).to(device)).size())

torch.Size([50, 4, 512])


In [5]:
an = m.AddNorm(512, 0.1)
an.to(device)
print(an(torch.rand((50, 4, 512)).to(device), torch.rand((50, 4, 512)).to(device)).size())
print(an.state_dict().keys())

torch.Size([50, 4, 512])
odict_keys(['layernorm.weight', 'layernorm.bias'])


In [6]:
el = m.TransformerEncoderLayer(512, 8, 2048)
el.to(device)
print(el(torch.rand(50, 4, 512).to(device)).size())
print(el.state_dict().keys())

torch.Size([50, 4, 512])
odict_keys(['mha.in_proj_weight', 'mha.in_proj_bias', 'mha.out_proj.weight', 'mha.out_proj.bias', 'ff.linear1.weight', 'ff.linear1.bias', 'ff.linear2.weight', 'ff.linear2.bias', 'an1.layernorm.weight', 'an1.layernorm.bias', 'an2.layernorm.weight', 'an2.layernorm.bias'])


In [7]:
encoder = m.TransformerEncoder(512, 8, 4, 2048)
encoder.to(device)
print(encoder(torch.rand(50, 4, 512).to(device)).size())
print(encoder.state_dict().keys())

torch.Size([50, 4, 512])
odict_keys(['encoder_layers.encoder_layer1.mha.in_proj_weight', 'encoder_layers.encoder_layer1.mha.in_proj_bias', 'encoder_layers.encoder_layer1.mha.out_proj.weight', 'encoder_layers.encoder_layer1.mha.out_proj.bias', 'encoder_layers.encoder_layer1.ff.linear1.weight', 'encoder_layers.encoder_layer1.ff.linear1.bias', 'encoder_layers.encoder_layer1.ff.linear2.weight', 'encoder_layers.encoder_layer1.ff.linear2.bias', 'encoder_layers.encoder_layer1.an1.layernorm.weight', 'encoder_layers.encoder_layer1.an1.layernorm.bias', 'encoder_layers.encoder_layer1.an2.layernorm.weight', 'encoder_layers.encoder_layer1.an2.layernorm.bias', 'encoder_layers.encoder_layer2.mha.in_proj_weight', 'encoder_layers.encoder_layer2.mha.in_proj_bias', 'encoder_layers.encoder_layer2.mha.out_proj.weight', 'encoder_layers.encoder_layer2.mha.out_proj.bias', 'encoder_layers.encoder_layer2.ff.linear1.weight', 'encoder_layers.encoder_layer2.ff.linear1.bias', 'encoder_layers.encoder_layer2.ff.linea

In [8]:
decoder = m.TransformerDecoder(512, 8, 4, 2048)
decoder.to(device)
print(decoder(torch.rand(50, 4, 512).to(device), torch.rand(50, 4, 512).to(device)).size())
print(decoder.state_dict().keys())

torch.Size([50, 4, 512])
odict_keys(['decoder_layers.decoder_layer1.mhas1.in_proj_weight', 'decoder_layers.decoder_layer1.mhas1.in_proj_bias', 'decoder_layers.decoder_layer1.mhas1.out_proj.weight', 'decoder_layers.decoder_layer1.mhas1.out_proj.bias', 'decoder_layers.decoder_layer1.mhas2.in_proj_weight', 'decoder_layers.decoder_layer1.mhas2.in_proj_bias', 'decoder_layers.decoder_layer1.mhas2.out_proj.weight', 'decoder_layers.decoder_layer1.mhas2.out_proj.bias', 'decoder_layers.decoder_layer1.an1.layernorm.weight', 'decoder_layers.decoder_layer1.an1.layernorm.bias', 'decoder_layers.decoder_layer1.an2.layernorm.weight', 'decoder_layers.decoder_layer1.an2.layernorm.bias', 'decoder_layers.decoder_layer1.an3.layernorm.weight', 'decoder_layers.decoder_layer1.an3.layernorm.bias', 'decoder_layers.decoder_layer1.ff.linear1.weight', 'decoder_layers.decoder_layer1.ff.linear1.bias', 'decoder_layers.decoder_layer1.ff.linear2.weight', 'decoder_layers.decoder_layer1.ff.linear2.bias', 'decoder_layers.d

In [13]:
transformer = m.Transformer(5000, 512, 8, 6, 2048)
transformer.to(device)
inp = torch.rand(20, 4, 512).to(device)
tar = torch.randint(0, 1, (4, 20)).to(device)
inp_padd = torch.randint(0, 2, (4, 20)).type(torch.uint8).to(device)
tar_padd = m.create_padding_mask_from_data(tar).to(device)
tar_look = m.create_look_ahead_mask(tar.size(1)).to(device)
print(transformer(inp, tar, inp_key_padding_mask=inp_padd, tar_key_padding_mask=tar_padd, mem_key_padding_mask=inp_padd, tar_attn_mask=tar_look).size())
print(transformer.state_dict().keys())

torch.Size([4, 20, 5000])
odict_keys(['encoder.encoder_layers.encoder_layer1.mha.in_proj_weight', 'encoder.encoder_layers.encoder_layer1.mha.in_proj_bias', 'encoder.encoder_layers.encoder_layer1.mha.out_proj.weight', 'encoder.encoder_layers.encoder_layer1.mha.out_proj.bias', 'encoder.encoder_layers.encoder_layer1.ff.linear1.weight', 'encoder.encoder_layers.encoder_layer1.ff.linear1.bias', 'encoder.encoder_layers.encoder_layer1.ff.linear2.weight', 'encoder.encoder_layers.encoder_layer1.ff.linear2.bias', 'encoder.encoder_layers.encoder_layer1.an1.layernorm.weight', 'encoder.encoder_layers.encoder_layer1.an1.layernorm.bias', 'encoder.encoder_layers.encoder_layer1.an2.layernorm.weight', 'encoder.encoder_layers.encoder_layer1.an2.layernorm.bias', 'encoder.encoder_layers.encoder_layer2.mha.in_proj_weight', 'encoder.encoder_layers.encoder_layer2.mha.in_proj_bias', 'encoder.encoder_layers.encoder_layer2.mha.out_proj.weight', 'encoder.encoder_layers.encoder_layer2.mha.out_proj.bias', 'encoder.e

In [68]:
def accuracy_metrics(predict, target):
    matched_matrix = (torch.max(predict, dim=-1)[1] == target)
    mask = (target != 0)
    matched_matrix = matched_matrix * mask
    acc = matched_matrix.type(torch.float).sum() / mask.type(torch.float).sum()
    return acc

In [26]:
tar = torch.randint(0, 2, (2, 4))

In [48]:
tar

tensor([[1, 0, 0, 0],
        [1, 1, 0, 0]])

In [55]:
inp = torch.rand((2, 4, 8))

In [67]:
inp

tensor([[[0.3651, 1.0000, 0.7920, 0.4278, 0.1233, 0.6303, 0.1410, 0.1200],
         [0.0690, 0.8083, 0.9296, 0.5539, 0.1202, 0.9849, 0.7614, 0.8524],
         [0.0000, 0.2732, 0.2572, 0.8573, 0.4406, 0.6854, 0.4603, 0.3153],
         [0.3523, 0.3768, 0.4104, 0.3379, 0.7744, 0.9522, 0.8908, 0.5174]],

        [[0.3890, 1.0000, 0.0547, 0.2634, 0.5774, 0.8788, 0.6619, 0.2587],
         [0.2807, 1.0000, 0.3326, 0.0175, 0.4737, 0.7933, 0.9083, 0.4498],
         [0.7653, 0.2699, 0.0890, 0.0855, 0.9379, 0.1390, 0.3373, 0.9730],
         [0.5138, 0.5583, 0.4108, 0.0065, 0.8243, 0.6241, 0.3070, 0.0610]]])

In [69]:
a = accuracy_metrics(inp, tar)

In [70]:
a

tensor(1.)