In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import color
from color import magenta, green, red

In [3]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding is cut into ``heads`` chunks of size ``head_dim``. Every head
    applies the same bias-free ``head_dim -> head_dim`` projections to its
    chunk of the queries, keys and values (the projection weights are shared
    across heads). Each head attends independently. The head outputs are then
    concatenated back to ``embed_size`` and mixed by a final ``fc_out`` layer.

    Args:
        embed_size: Size of the model embedding; must be divisible by ``heads``.
        heads: Number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super().__init__()

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # The embedding is cut into `heads` equal chunks, one per head, so the
        # split must be exact. ("Attention Is All You Need" reports that several
        # smaller heads work better than one full-width head.)
        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        # nn.Linear acts on the last dimension. Once the input is reshaped to
        # (..., heads, head_dim), each projection maps head_dim -> head_dim and
        # the same weights are applied to every head.
        self.query_weights = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.key_weights = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.value_weights = nn.Linear(self.head_dim, self.head_dim, bias=False)

        # Mixes the concatenated head outputs back into a single embedding.
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, queries, keys, values, mask, testing_mode=False):
        """Attend from ``queries`` over ``keys``/``values``.

        These are the actual query/key/value inputs, not the projection weights.

        Args:
            queries: Tensor of shape (num_examples, query_len, embed_size).
            keys: Tensor of shape (num_examples, key_len, embed_size).
            values: Tensor of shape (num_examples, value_len, embed_size);
                value_len must equal key_len.
            mask: ``None`` or a tensor broadcastable to
                (num_examples, heads, query_len, key_len). Positions where
                ``mask == 0`` receive (effectively) zero attention weight.
            testing_mode: If True, print shape checks for the projected
                queries and for the output.

        Returns:
            Tensor of shape (num_examples, query_len, embed_size).
        """
        num_examples = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into self.heads pieces:
        # (num_examples, seq_len, embed_size) -> (num_examples, seq_len, heads, head_dim)
        queries = queries.reshape(num_examples, query_len, self.heads, self.head_dim)
        keys = keys.reshape(num_examples, key_len, self.heads, self.head_dim)
        values = values.reshape(num_examples, value_len, self.heads, self.head_dim)

        queries = self.query_weights(queries)
        keys = self.key_weights(keys)
        values = self.value_weights(values)

        if testing_mode:
            print(magenta('Testing Self Attention:'))
            self._report_shape('query', queries, [num_examples, query_len, self.heads, self.head_dim])

        # Q.K^T per head: (n, q, h, d) x (n, k, h, d) -> (n, h, q, k)
        # (n: examples, q: query length, k: key length, h: heads, d: head_dim)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # A huge negative score makes softmax give these positions ~0 weight.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy, dim=-1)  # normalise over key positions

        # Weighted sum of values. The key axis `k` must be the same index as the
        # values' sequence axis, so each query mixes exactly the values it
        # attended to. A separate index there would sum the two independently.
        out = torch.einsum("nhqk,nkhd->nqhd", [attention, values])
        out = out.reshape(num_examples, query_len, self.embed_size)  # concatenate heads
        out = self.fc_out(out)

        # The output sequence length follows the queries, not the values.
        if testing_mode:
            self._report_shape('output', out, [num_examples, query_len, self.embed_size])

        return out

    @staticmethod
    def _report_shape(label, tensor, expected):
        """Print whether ``tensor`` has shape ``expected`` (testing_mode helper)."""
        if list(tensor.shape) == list(expected):
            print(f'Size of {label} is', green('correct'))
        else:
            print(f'Size of {label} is', red('incorrect'))
            print(tensor.shape, red('does not match'), expected)
        

In [4]:
# TESTING SELF ATTENTION LAYERS

torch.manual_seed(0)  # reproducible random inputs / weight init

embed_size = 512
heads = 8
seq_length = 10
batch_size = 4

queries = torch.rand(batch_size, seq_length, embed_size)
keys = torch.rand(batch_size, seq_length, embed_size)
values = torch.rand(batch_size, seq_length, embed_size)

self_attention = SelfAttention(embed_size, heads)
# Call the module rather than .forward() so nn.Module hooks run. The trailing
# `;` stops the notebook from dumping the returned tensor.
self_attention(queries, keys, values, None, testing_mode=True);

# Projected queries should be [batch size, seq length, num heads, head dimension].
# The attention output should be [batch size, seq length, embed size].

[10;10;35mTesting Self Attention:[0m
Size of query is [10;10;32mcorrect[0m
Size of output is [10;10;32mcorrect[0m


In [5]:
# Experimenting with linear layers
# nn.Linear(in_features=5, out_features=4) stores a (4, 5) weight W and computes
# x @ W.T, so each 5-wide input row becomes a 4-wide output row.
# Note: despite its name, `weights` below is the *input* batch fed to the layer.

linear_layer = nn.Linear(5, 4, bias=False, dtype=torch.float64)
weights = torch.arange(20, dtype=torch.float64).reshape(4, 5)

new_layer = linear_layer(weights)  # the layer's output, not a new layer

print('linear layer\n', linear_layer.weight)
print('linear layer bias\n',linear_layer.bias) # should be none

print('\nweight matrix\n',weights)
print('new_layer layer\n' ,new_layer)

# Confirm the layer is just a matrix product with the transposed weight.
assert linear_layer.bias is None
assert torch.allclose(new_layer, weights @ linear_layer.weight.T)

linear layer
 Parameter containing:
tensor([[ 0.3166,  0.4056,  0.3759, -0.2346,  0.0049],
        [ 0.0269,  0.0996, -0.3733,  0.0618,  0.2073],
        [ 0.2639, -0.0843, -0.3288,  0.2310, -0.3859],
        [ 0.2381, -0.2076, -0.1744, -0.3358, -0.2447]], dtype=torch.float64,
       requires_grad=True)
linear layer bias
 None

weight matrix
 tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.]], dtype=torch.float64)
new_layer layer
 tensor([[  0.4730,   0.3675,  -1.5926,  -2.5428],
        [  4.8147,   0.4786,  -3.1133,  -6.1650],
        [  9.1564,   0.5898,  -4.6340,  -9.7872],
        [ 13.4982,   0.7009,  -6.1547, -13.4095]], dtype=torch.float64,
       grad_fn=<MmBackward0>)


In [6]:
# Figuring out .split()
# An int argument to .split is the *chunk size*: .split(4) on 12 elements gives
# 3 chunks of 4. (A list argument gives explicit section sizes; to request a
# number of chunks instead, use .chunk(3).)
arrange_1to12 = torch.arange(12)  # holds the values 0..11 despite its name
arrange_0to3, arrange_4to7, arrange_8to11 = arrange_1to12.split(4)

print(magenta('original tensor\n'), arrange_1to12)
print(magenta('\n3 partitions of tensor using.split:\n'),arrange_0to3,'\n' ,arrange_4to7,'\n' ,arrange_8to11)

# The chunks are views that tile the original tensor exactly.
assert torch.equal(torch.cat([arrange_0to3, arrange_4to7, arrange_8to11]), arrange_1to12)

[10;10;35moriginal tensor
[0m tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
[10;10;35m
3 partitions of tensor using.split:
[0m tensor([0, 1, 2, 3]) 
 tensor([4, 5, 6, 7]) 
 tensor([ 8,  9, 10, 11])


In [7]:
# Experimenting with torch reshape
# NOTE: this cell overwrites embed_size / heads / seq_length set in the
# self-attention test cell above.

num_examples = 10
seq_length = 5
embed_size = 10
heads = 5

x = torch.arange(num_examples*seq_length*embed_size).reshape(num_examples, seq_length, embed_size)
print(magenta('10 sentences, of 5 words, each with a 10 integer embedding\n'))
print(x.shape)

print(magenta('\nReshaping into 5 heads with a dimensionality of 2 instead of 10\n'))

x = x.reshape(num_examples, seq_length, heads, embed_size // heads)
print(x.shape)

# reshape is row-major, so only the last axis is split: each token's 10-wide
# embedding becomes 5 consecutive pairs, and head h gets elements [2h, 2h+1].
# Nothing moves between tokens or examples, and SelfAttention's head split
# relies on exactly that.
assert torch.equal(x[0, 0], torch.arange(embed_size).reshape(heads, embed_size // heads))

[10;10;35m10 sentences, of 5 words, each with a 10 integer embedding
[0m
torch.Size([10, 5, 10])
[10;10;35m
Reshaping into 5 heads with a dimensionality of 2 instead of 10
[0m
torch.Size([10, 5, 5, 2])
