In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import joblib
import matplotlib.pyplot as plt
from flair.data import Sentence
from flair.embeddings import WordEmbeddings
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
%matplotlib inline

# Run on the first GPU when one is present, otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float  # torch.float is float32; tensors below are cast to it

d_model = 300  # embedding size of fasttext models
output_lang = 'en'  # select language

# Test for Multi-Head Attention

In [3]:
# Shape walkthrough for one multi-head self-attention layer.
# The fused Q/K/V projection and the output projection W_O are real nn.Linear
# layers (not random placeholders), so every reshape/permute below is applied
# to data that actually flows through the layer, and the result is checked
# against a direct per-head reference computation further down.
torch.manual_seed(0)  # reproducible inputs and layer weights on re-run

batch_size = 16
seq_len = 8
num_heads = 4
e_dim = 6        # per-head size of q, k and v
model_size = 12  # size of the layer's input and output features

qkv_proj = nn.Linear(model_size, num_heads*3*e_dim).to(dtype)  # fused W_Q, W_K, W_V
out_proj = nn.Linear(num_heads*e_dim, model_size).to(dtype)    # W_O

print('input')
a = torch.randn(batch_size, seq_len, model_size).to(dtype)
print(a.shape)

print('\nreshape for linear qvk layer')
b = torch.reshape(a, (batch_size*seq_len, model_size))
print(b.shape)
print('\napply linear')
b = qkv_proj(b)
print(b.shape)

print('\nreshape to batch_size,seqlen,..')
c = torch.reshape(b, (batch_size, seq_len, num_heads*3*e_dim))
print(c.shape)

print('\nsplit into heads and seperate qvk')
# the projection's last axis is interpreted as [head][q|k|v][e_dim]
d = torch.reshape(c, (batch_size, seq_len, num_heads, 3, e_dim))
print(d.shape)

print('\npermute head to front for parallel processing')
d = d.permute(0, 2, 1, 3, 4)
print(d.shape)

print('\nextract q, k, v')
q = d[:, :, :, 0, :]
k = d[:, :, :, 1, :]
v = d[:, :, :, 2, :]
print('q',q.shape)
print('k',k.shape)
print('v',v.shape)

print('\nfuse batch and head dim for parallel processing')
# reshape (not view): q/k/v are non-contiguous slices of the permuted tensor
q = q.reshape(batch_size*num_heads, seq_len, e_dim)
k = k.reshape(batch_size*num_heads, seq_len, e_dim)
v = v.reshape(batch_size*num_heads, seq_len, e_dim)
print('q',q.shape)
print('k',k.shape)
print('v',v.shape)

print('\ntranspose k')
k = torch.transpose(k, 1, 2)
print('k',k.shape)

print('\nmultiply q and k + softmax')
# scaled dot-product attention: divide by sqrt(d_k) so the softmax does not
# saturate as the head size grows
qk = torch.bmm(q, k) / e_dim**0.5
qk = F.softmax(qk, dim=2)
print('qk',qk.shape)

print('\nmultiply with v')
qkv = torch.bmm(qk, v)
print(qkv.shape)

print('\nreshape to cat heads')
qkv = torch.reshape(qkv, (batch_size, num_heads, seq_len, e_dim))
print(qkv.shape)

print('\ncat all heads')
qkv = qkv.permute(0, 2, 1, 3)
qkv = torch.reshape(qkv, (batch_size, seq_len, num_heads*e_dim))
print(qkv.shape)

# Reference: the same attention computed on the 4-D (batch, head, seq, e_dim)
# layout without fusing batch and head. A mismatch means one of the
# reshapes/permutes above mixed up batches, heads or positions.
q_ref, k_ref, v_ref = d.unbind(dim=3)
attn_ref = F.softmax(q_ref @ k_ref.transpose(-2, -1) / e_dim**0.5, dim=-1) @ v_ref
attn_ref = attn_ref.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_heads*e_dim)
assert torch.allclose(qkv, attn_ref, atol=1e-5), 'head bookkeeping mismatch'

print('\nreshape to multiply with WO')
qkv = torch.reshape(qkv, (batch_size*seq_len, num_heads*e_dim))
print(qkv.shape)

print('\nmultiply...')
qkv = out_proj(qkv)
print(qkv.shape)

print('\nreshape for next layer')
qkv = torch.reshape(qkv, (batch_size, seq_len, model_size))
print(qkv.shape)
assert qkv.shape == a.shape  # output has the same (batch, seq, model_size) shape as the input

input
torch.Size([16, 8, 12])

reshape for linear qvk layer
torch.Size([128, 12])

apply linear
torch.Size([128, 72])

reshape to batch_size,seqlen,..
torch.Size([16, 8, 72])

split into heads and seperate qvk
torch.Size([16, 8, 4, 3, 6])

permute head to front for parallel processing
torch.Size([16, 4, 8, 3, 6])

extract q, k, v
q torch.Size([16, 4, 8, 6])
k torch.Size([16, 4, 8, 6])
v torch.Size([16, 4, 8, 6])

fuse batch and head dim for parallel processing
q torch.Size([64, 8, 6])
k torch.Size([64, 8, 6])
v torch.Size([64, 8, 6])

transpose k
k torch.Size([64, 6, 8])

multiply q and k + softmax
qk torch.Size([64, 8, 8])

multiply with v
torch.Size([64, 8, 6])

reshape to cat heads
torch.Size([16, 4, 8, 6])

cat all heads
torch.Size([16, 8, 24])

reshape to multiply with WO
torch.Size([128, 24])

multiply...
torch.Size([128, 12])

reshape for next layer
torch.Size([16, 8, 12])
