In [19]:
import torch
from torch.nn.functional import softmax
from torch import nn

In [3]:
# Toy inputs for the unmasked attention demo: 100 query positions attend over
# 90 key/value positions, batch of 2, embedding width 64.
Q = torch.ones(2, 100, 64)
K = torch.ones(2, 90, 64)
V = torch.ones_like(K)


def attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
    '''
    Unmasked scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Every query position attends to every key position (no masking).

    Q: shape (batch, target sequence length, embedding dim)
    K: shape (batch, source sequence length, embedding dim)
    V: shape (batch, source sequence length, output embedding dim)

    Return: shape (batch, target sequence length, output embedding dim),
            i.e. the same shape as Q when V has the same embedding dim.
    '''
    # d_k is the width of the key vectors; its square root is the scale factor.
    scale = torch.tensor(K.shape[-1]).sqrt()
    # (batch, target, source): one raw score per query/key pair.
    scores = Q.bmm(K.transpose(1, 2)) / scale
    # Normalise over the source axis so each query's weights sum to 1.
    weights = softmax(scores, dim=2)
    return weights.bmm(V)

# Sanity check: the output should have Q's (batch, target, dim) shape.
attention(Q=Q, K=K, V=V).shape


torch.Size([2, 100, 64])

In [15]:




# Toy inputs for the masked attention demo: 20 query positions attend over
# 10 key/value positions, batch of 2, embedding width 64.
Q = torch.ones(2, 20, 64)
K = torch.ones(2, 10, 64)
V = torch.ones_like(K)

def masked_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
    '''
    Causally masked scaled dot-product attention.

    Computes softmax(mask(Q K^T / sqrt(d_k))) V, where mask() sets the score of
    every (target i, source j) pair with j > i to -inf, so that target
    position i can only attend to source positions 0..i.

    Q: shape (batch, target sequence length, embedding dim)
    K: shape (batch, source sequence length, embedding dim)
    V: shape (batch, source sequence length, output embedding dim)

    Return: shape (batch, target sequence length, output embedding dim),
            i.e. the same shape as Q when V has the same embedding dim.
    '''
    target_seq_len, source_seq_len = Q.shape[1], K.shape[1]

    # (batch, target, source) scaled scores; d_k is the key width.
    scores = torch.bmm(Q, K.transpose(1, 2)) / K.shape[-1] ** 0.5

    # True strictly above the diagonal == "source position is in the future".
    # Built on the scores' device so the function also works for non-CPU inputs.
    future = torch.triu(
        torch.ones((target_seq_len, source_seq_len), dtype=torch.bool, device=scores.device),
        diagonal=1,
    )

    # -inf scores become exactly 0 after the softmax. Column 0 is never masked,
    # so every row keeps at least one finite entry and no row turns into NaN.
    weights = softmax(scores.masked_fill(future, float("-inf")), dim=2)
    return torch.bmm(weights, V)



# Sanity check: the output should have Q's (batch, target, dim) shape.
result = masked_attention(Q=Q, K=K, V=V)
print(result.size())

torch.Size([2, 20, 10]) torch.Size([2, 20, 10]) torch.Size([20, 10])
torch.Size([2, 20, 64])


In [5]:
# from matplotlib import pyplot as plt

In [6]:
# plt.imshow(triangular.detach().numpy())

In [21]:
# Toy inputs for the multi-head demo: 4 heads of width 64 packed side by side
# along the last dimension; 20 query positions, 10 key/value positions.
num_heads = 4
Q = torch.ones(2, 20, num_heads * 64)
K = torch.ones(2, 10, num_heads * 64)
V = torch.ones_like(K)

def multihead_masked_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, num_heads: int):
    '''
    Implements multihead causally-masked attention on the matrices Q, K and V.

    The last dimension of each input is split into `num_heads` contiguous
    chunks of width `headsize`. Each head independently computes
    softmax(mask(Q_h K_h^T / sqrt(headsize))) V_h, where mask() sets the score
    of every (target i, source j) pair with j > i to -inf. The per-head
    outputs are concatenated back along the last dimension in head order.

    Q: shape (batch, target seq, nheads*headsize)
    K: shape (batch, source seq, nheads*headsize)
    V: shape (batch, source seq, nheads*headsize)
    num_heads: number of heads; must divide the last dimension of Q, K and V.

    Return: shape (batch, target seq, nheads*headsize)
    '''
    batch, target_seq_len = Q.shape[0:2]
    source_seq_len = K.shape[1]
    head_size = Q.shape[-1] // num_heads

    # Split the packed last dimension into (heads, head_size).
    Q = torch.reshape(Q, (batch, target_seq_len, num_heads, head_size))
    K = torch.reshape(K, (batch, source_seq_len, num_heads, head_size))
    V = torch.reshape(V, (batch, source_seq_len, num_heads, head_size))

    # b=batch, t=target, s=source, n=head, h=head_size.
    # scores: (batch, heads, target, source)
    scores = torch.einsum("btnh,bsnh->bnts", Q, K) / head_size ** 0.5

    # True strictly above the diagonal == "source position is in the future".
    # Built on the scores' device; broadcasts over the batch and head axes.
    future = torch.triu(
        torch.ones((target_seq_len, source_seq_len), dtype=torch.bool, device=scores.device),
        diagonal=1,
    )
    scores = scores.masked_fill(future, float("-inf"))

    # Normalise over the SOURCE axis (last), not over heads.
    weights = softmax(scores, dim=-1)

    # Weighted sum of the values, using the softmaxed weights.
    # result: (batch, target, heads, head_size)
    result = torch.einsum("bnts,bsnh->btnh", weights, V)
    return torch.reshape(result, (batch, target_seq_len, num_heads * head_size))



# Sanity check: the output should have Q's (batch, target, nheads*headsize) shape.
result = multihead_masked_attention(Q=Q, K=K, V=V, num_heads=num_heads)
print(result.size())


torch.Size([2, 20, 256])


In [41]:
class MultiheadMaskedAttention(nn.Module):
    '''
    Multi-head masked self-attention layer.

    Projects the input to packed Q, K and V with a single linear layer, runs
    `multihead_masked_attention` on them, and projects the concatenated head
    outputs back to `hidden_size`.

    Note that each head here has width `hidden_size` (the packed Q/K/V are
    `num_heads * hidden_size` wide each), not `hidden_size / num_heads`.
    '''
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        '''
        hidden_size: width of the input and output embeddings.
        num_heads: number of attention heads.
        '''
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # Input is the hidden state tiled 3x (see forward); output is Q, K and V
        # packed side by side, each num_heads*hidden_size wide.
        self.W_QKV = nn.Linear(hidden_size*3, num_heads*hidden_size*3)
        # Maps the concatenated head outputs back to the model width.
        self.W_O = nn.Linear(num_heads*hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        # W_QKV expects 3*hidden_size input features, so tile x along the last dim.
        x = x.repeat((1, 1, 3))
        # Cut the packed projection into three equal chunks. The chunk width must
        # come from this module's own num_heads, not from a notebook global.
        Q, K, V = torch.split(self.W_QKV(x), self.num_heads * self.hidden_size, dim=2)
        Z = multihead_masked_attention(Q, K, V, num_heads=self.num_heads)
        return self.W_O(Z)

# num_heads=4
# x = torch.ones((2,10,hidden_size:=64))       
# mma = MultiheadMaskedAttention(hidden_size=hidden_size, num_heads=num_heads)
# mma(x).shape



Q.shape=torch.Size([2, 10, 256]) K.shape=torch.Size([2, 10, 256]) V.shape=torch.Size([2, 10, 256])
Z.shape=torch.Size([2, 10, 256])


torch.Size([2, 10, 64])

In [23]:
# Small scratch tensor for trying out repeat/split: (batch=1, seq=2, dim=3).
x = torch.ones(1, 2, 3)


In [34]:
# Tile x three times along the last dim, then cut it back into three
# equal-width chunks -- the same repeat/split pattern used in
# MultiheadMaskedAttention.forward. torch.split needs the tensor and a chunk
# size; calling it with no arguments raises TypeError.
# NOTE: this cell reassigns x, so re-running it tiles x again.
x = x.repeat((1, 1, 3))
torch.split(x, x.shape[-1] // 3, dim=2)