In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F

import numpy as np
import math

# Notebook-wide runtime settings.
USE_GPU = True
dtype = torch.float32
print_every = 100

# Use CUDA only when it is both requested and actually available.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

print('using device:', device)

using device: cuda


In [17]:
def compute_qkv(X, W_q, W_k, W_v):
    """
    Project the input sequence into queries, keys and values.

    X: input sequence [batch_size, sequence_length, d_in]
    W_q: query weights [d_in, d_model]
    W_k: key weights [d_in, d_model]
    W_v: value weights [d_in, d_model]

    Returns: the tuple (Q, K, V), where each entry is X matrix-multiplied
    by the corresponding weight matrix.
    """
    # One matmul per projection, emitted in (query, key, value) order.
    return tuple(X @ weight for weight in (W_q, W_k, W_v))

def self_attention(Q, K, V):
    """
    Single-head scaled dot-product attention.

    Q: queries [..., query_length, d_model]
    K: keys [..., key_length, d_model]
    V: values [..., key_length, d_value]

    Scores are Q @ K^T divided by sqrt(d_model), where d_model is the size of
    Q's last dimension. Each query's scores are normalized over the keys
    (last axis) with a softmax, and the resulting weights mix the rows of V.

    Returns: [..., query_length, d_value]
    """
    d_model = Q.size(-1)
    scores = (Q @ K.transpose(-1, -2)) / math.sqrt(d_model)
    # torch.softmax normalizes each row over the key axis. A hand-rolled
    # exp / sum needs keepdim=True on the sum to broadcast per row, and it
    # overflows to inf/NaN for large scores; softmax avoids both problems.
    weights = torch.softmax(scores, dim=-1)
    return weights @ V

In [18]:
# Toy batch: two identical sequences, each made of two one-hot tokens.
X = torch.tensor(
    [[[1.0, 0.0], [0.0, 1.0]],
     [[1.0, 0.0], [0.0, 1.0]]],
    dtype=torch.float64,
)
# Query and key projections are the identity; only the value projection
# changes the features.
W_q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)
W_k = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)
W_v = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)

print(X.shape)

Q, K, V = compute_qkv(X, W_q, W_k, W_v)
output = self_attention(Q, K, V)

print(output)

torch.Size([2, 2, 2])
tensor([[[1.6605, 2.6605],
         [2.3395, 3.3395]],

        [[1.6605, 2.6605],
         [2.3395, 3.3395]]], dtype=torch.float64)


In [35]:
def multi_head_attention(Q, K, V, n_heads):
    """
    Scaled dot-product attention computed independently on n_heads heads.

    Q: queries [batch_size, query_length, d_model]
    K: keys [batch_size, key_length, d_model]
    V: values [batch_size, key_length, d_model]
    n_heads: number of heads; must be positive and divide d_model evenly

    The last dimension is cut into n_heads contiguous slices of width
    d_head = d_model // n_heads. Attention (scores scaled by sqrt(d_head),
    softmax over the keys) runs separately on each slice, and the per-head
    outputs are concatenated back in head order.

    Returns: [batch_size, query_length, d_model]
    Raises: ValueError if n_heads is not positive or does not divide d_model.
    """
    B, N, d_model = Q.shape
    if n_heads <= 0 or d_model % n_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
        )
    d_head = d_model // n_heads

    def split(m):
        # [B, L, d_model] -> [B, n_heads, L, d_head]. L is read from m itself
        # so that K and V may have a different sequence length than Q.
        m = m.view(B, m.shape[1], n_heads, d_head)
        return m.permute(0, 2, 1, 3)

    Q = split(Q)
    K = split(K)
    V = split(V)

    # math.sqrt yields a full-precision Python float; taking torch.sqrt of an
    # integer tensor would round the scale to float32 (or fail on old PyTorch).
    E = (Q @ K.transpose(-1, -2)) / math.sqrt(d_head)
    A = torch.softmax(E, dim=-1)
    out = A @ V                # [B, n_heads, N, d_head]
    out = out.transpose(1, 2)  # [B, N, n_heads, d_head]
    # Merging the last two axes concatenates the heads back into d_model.
    return out.reshape(B, N, d_model)

In [36]:
# Toy batch: two identical sequences, each made of two one-hot tokens.
X = torch.tensor(
    [[[1.0, 0.0], [0.0, 1.0]],
     [[1.0, 0.0], [0.0, 1.0]]],
    dtype=torch.float64,
)
# Query and key projections are the identity; only the value projection
# changes the features.
W_q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)
W_k = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)
W_v = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)

Q, K, V = compute_qkv(X, W_q, W_k, W_v)

# With d_model = 2, two heads means each head attends over a single feature.
n_heads = 2
output = multi_head_attention(Q, K, V, n_heads)
print(output)

tensor([[[1.5379, 3.0000],
         [2.0000, 3.4621]],

        [[1.5379, 3.0000],
         [2.0000, 3.4621]]], dtype=torch.float64)


In [45]:
def pos_encoding(data):
    """
    Add sinusoidal positional encodings to a batch of sequences.

    data: input tensor [batch_size, sequence_length, embed_dim]

    For position p and pair index k, with freq_k = 10000 ** (-2k / embed_dim):
        pe[p, 2k]     = sin(p * freq_k)
        pe[p, 2k + 1] = cos(p * freq_k)
    The same table is broadcast over the batch dimension. An odd embed_dim is
    supported: the last column is then a sin column with no cos partner.

    Returns: data + pe, shape [batch_size, sequence_length, embed_dim]
    """
    _, N, D = data.shape
    # Build the table on data's device so non-CPU inputs work. Use float64
    # when data is float64 so no precision is lost; otherwise use float32.
    pe_dtype = torch.float64 if data.dtype == torch.float64 else torch.float32
    pe = torch.zeros((1, N, D), dtype=pe_dtype, device=data.device)

    i = torch.arange(N, dtype=pe_dtype, device=data.device)[:, None]
    pows = torch.pow(
        10000, -torch.arange(0, D, 2, dtype=pe_dtype, device=data.device) / D
    )
    angles = i * pows  # [N, ceil(D / 2)]

    pe[0, :, 0::2] = torch.sin(angles)
    # There are only D // 2 odd columns, one fewer than the sin columns when
    # D is odd, so trim the angles to match.
    pe[0, :, 1::2] = torch.cos(angles[:, : D // 2])
    return data + pe

In [47]:
# Seed the RNG so the random input (and the printed result) is repeatable.
torch.manual_seed(231)

batch_size, sequence_length, embed_dim = 1, 2, 6
data = torch.randn((batch_size, sequence_length, embed_dim))

output = pos_encoding(data)
print(output)

tensor([[[-1.1106,  1.0014,  1.5280, -0.0778, -0.6964,  1.1455],
         [ 0.8125, -0.4303,  0.4981,  0.7320,  1.1380,  1.5331]]])
