In [1]:
import collections
import itertools
import functools
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm_notebook

import json

In [7]:
class Encoder(nn.Module):
    def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        super(Encoder, self).__init__()

        self.embed_size         = embed_size
        self.device             = device

        self.word_embedding     = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers             = nn.ModuleList([TransformerBlock(embed_size, heads, dropout=dropout, forward_expansion=forward_expansion) for _ in range(num_layers)])
        self.dropout            = nn.Dropout(dropout)

    def forward(self, x, mask):

        N, seq_length = x.shape
        positions     = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out           = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        # In the Encoder the query, key, value are all the same, it's in the
        # decoder this will change. This might look a bit odd in this case.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [85]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        # Get number of training examples
        N = query.shape[0]

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        print(value_len, key_len, query_len)

        '''print('V0', values.shape)
        print('K0', keys.shape)
        print('Q0', query.shape)'''

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)
        '''print('V1', values.shape)
        print('K1', keys.shape)
        print('Q1', query.shape)'''

        values = self.values(values)  # (N, value_len, heads, head_dim)
        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
        queries = self.queries(query)  # (N, query_len, heads, heads_dim)

        '''print('V2', values.shape)
        print('K2', keys.shape)
        print('Q2', query.shape)'''

        # Einsum does matrix mult. for query*keys for each training example
        # with every other training example, don't be confused by einsum
        # it's just how I like doing matrix multiplication & bmm

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim),
        # keys shape: (N, key_len, heads, heads_dim)
        # energy: (N, heads, query_len, key_len)

        # Mask padded indices so their weights become 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Normalize energy values similarly to seq2seq + attention
        # so that they sum to 1. Also divide by scaling factor for
        # better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # out after matrix multiply: (N, query_len, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, query_len, embed_size)

        return out

In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

In [58]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()

        self.norm              = nn.LayerNorm(embed_size)
        self.attention         = SelfAttention(embed_size, heads=heads)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout           = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):

        attention = self.attention(x, x, x, trg_mask)
        print("decAtt", attention.shape)
        query     = self.dropout(self.norm(attention + x))
        print("decQue", query.shape)
        out       = self.transformer_block(value, key, query, src_mask)
        print("decOut", out.shape)

        return out

In [74]:
class Decoder(nn.Module):
    def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
        super(Decoder, self).__init__()
        self.device = device

        self.word_embedding     = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers  = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device) for _ in range(num_layers)])
        self.fc_out  = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        print('xx', x.shape)
        N, seq_length = x.shape
        positions     = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        print('posi', positions.shape)
        x             = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))
        print('xx', x.shape)

        for layer in self.layers:
            print('Newblock')
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)

        return out

In [61]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, embed_size=512,
                 num_layers=2, forward_expansion=2, heads=2, dropout=0, device="cpu", max_length=9):

        super(Transformer, self).__init__()

        self.encoder = Encoder(src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length)
        self.decoder = Decoder(trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length)

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device      = device

    def make_src_mask(self, src):
      
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # (N, 1, 1, src_len)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):

        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)

        return trg_mask.to(self.device)

    def forward(self, src, trg):

        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        # print('x', src.shape)

        enc_src  = self.encoder(src, src_mask)
        # print('encoder', enc_src.shape)

        print('===== DECODER =====')
        out      = self.decoder(trg, enc_src, src_mask, trg_mask)
        print('decoder', out.shape)

        return out

In [86]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = LanguageModel(vocab_size=10, max_seq_length=9, dim=512, n_layers=2, pad_token_id=0).to(device)

x   = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

out = model(x)
print(out.shape) 

cpu
9 9 9
9 9 9
torch.Size([2, 9, 10])


In [75]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

x   = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

src_pad_idx    = 0
trg_pad_idx    = 0
src_vocab_size = 10
trg_vocab_size = 10

print(trg[:, :-1])

model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
out = model(x, trg[:, :-1])
print(out.shape)

cpu
tensor([[1, 7, 4, 3, 5, 9, 2],
        [1, 5, 6, 2, 4, 7, 6]])
9 9 9
9 9 9
===== DECODER =====
xx torch.Size([2, 7])
posi torch.Size([2, 7])
xx torch.Size([2, 7, 512])
Newblock
7 7 7
decAtt torch.Size([2, 7, 512])
decQue torch.Size([2, 7, 512])
9 9 7
decOut torch.Size([2, 7, 512])
Newblock
7 7 7
decAtt torch.Size([2, 7, 512])
decQue torch.Size([2, 7, 512])
9 9 7
decOut torch.Size([2, 7, 512])
decoder torch.Size([2, 7, 10])
torch.Size([2, 7, 10])


In [3]:
def make_src_mask(src):
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        # (N, 1, 1, src_len)
    return src_mask

make_src_mask(x)

tensor([[[[ True,  True,  True,  True,  True,  True,  True,  True, False]]],


        [[[ True,  True,  True,  True,  True,  True,  True,  True,  True]]]])

In [4]:
def make_src_mask(src):
    src_mask = (src != 0)
        # (N, 1, 1, src_len)
    return src_mask[:, None, None, :]

make_src_mask(x)

tensor([[[[ True,  True,  True,  True,  True,  True,  True,  True, False]]],


        [[[ True,  True,  True,  True,  True,  True,  True,  True,  True]]]])

# Minha modificação

In [83]:
class LanguageModel(torch.nn.Module):

    def __init__(self, vocab_size: int, max_seq_length: int, dim: int, n_layers: int, pad_token_id: int):
        """
        Implements the Self-attention, decoder-only."

        Args:
            vocab_size (int): Size of the input vocabulary.
            max_seq_length (int): Size of the sequence to consider as context for prediction.
            dim (int): Dimension of the embedding layer for each word in the context.
            n_layers (int): number of self-attention layers.
            pad_token_id (int): id of the pad token that will be ignored in the attention.
        """
        # Escreva seu código aqui.

        super().__init__()

        self.vocab_size     = vocab_size
        self.max_seq_length = max_seq_length
        self.dim            = dim
        self.n_layers       = n_layers
        self.pad_token_id   = pad_token_id
        self.n_heads        = 2

        self.embedding_layer       = nn.Embedding(vocab_size,     dim, padding_idx=pad_token_id)
        self.positional_embeddings = nn.Embedding(max_seq_length, dim, padding_idx=pad_token_id)

        self.attention1  = SelfAttention(self.dim, self.n_heads)
        self.attention2  = SelfAttention(self.dim, self.n_heads)
        
        self.linear1 = nn.Linear(self.dim, self.vocab_size, bias=False)
        self.dropout = nn.Dropout(p=0.2)

        self.norm0 = nn.LayerNorm(dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(dim, 2*dim),
            nn.ReLU(),
            nn.Linear(2*dim, dim),
        )

    def triMask(self, inputs):

        N, tri_len = inputs.shape
        return torch.tril(torch.ones((tri_len, tri_len))).expand(N, 1, tri_len, tri_len)

    def forward(self, inputs):
        """
        Args:
            inputs is a LongTensor of shape (batch_size, max_seq_length)
            
        Returns:
            logits of shape (batch_size, vocab_size)
        """
        # Escreva seu código aqui.

        mask = inputs != self.pad_token_id
        j = self.triMask(inputs)

        x = self.dropout(self.embedding_layer(inputs) + self.positional_embeddings.weight) # B, L, D
        a = self.attention1(x, x, x, j)
        q = self.dropout(self.norm0(a + x))

        a = self.attention2(x, x, q, mask[:, None, None, :])

        x = self.dropout(self.norm1(a + q))
        forward = self.feed_forward(x)
        x = self.dropout(self.norm2(forward + x))


        o = self.linear1(x)
                  
        return o

In [42]:
class SelfAttentionmm(nn.Module):
    def __init__(self, dim, max_seq_length, n_heads, pad_token_id):
        super(SelfAttentionmm, self).__init__()

        self.pad_token_id   = pad_token_id
        self.max_seq_length = max_seq_length

        self.n_heads        = n_heads
        self.dim            = dim
        self.D_k            = dim//n_heads
        
        self.W_q = torch.nn.Linear(self.dim, self.dim, bias=False) # D, D
        self.W_k = torch.nn.Linear(self.dim, self.dim, bias=False) # D, D
        self.W_v = torch.nn.Linear(self.dim, self.dim, bias=False) # D, D
        self.W_o = torch.nn.Linear(self.dim, self.dim, bias=False) # D, D

        self.feed_forward = torch.nn.Sequential(torch.nn.LayerNorm(self.dim, eps=1e-6),
                                                nn.Linear(self.dim, self.dim),
                                                torch.nn.ReLU(),
                                                nn.Linear(self.dim, self.dim),
                                                nn.Dropout(p=0.2),
                                                torch.nn.LayerNorm(self.dim, eps=1e-6))


    def attention(self, Q, K, V, mask):   

        '''
        1 torch.Size([5, 2, 9, 9])
        2 torch.Size([5, 9])
        3 torch.Size([5, 2, 9, 9])
        4 torch.Size([5, 2, 9, 9])
        5 torch.Size([5, 2, 9, 128])
        6 torch.Size([5, 9, 2, 128])
        7 torch.Size([5, 9, 256])
        8 torch.Size([5, 9, 256])
        '''
        
        scores = torch.matmul(Q, K.transpose(-1, -2))/math.sqrt(self.D_k) # B, HEADS, L, L -> 1
        # print(scores)
        mask_expanded = mask.expand_as(scores)                        # B, HEADS, L, L -> 3
        
        scores.masked_fill_(~mask_expanded, float('-inf'))                # B, HEADS, L, L
        # print(scores)
        probs = F.softmax(scores, dim=-1)                                 # B, HEADS, L, L -> 4

        E = torch.matmul(probs, V)                                        # B, HEADS, L, D//HEADS -> 5
        E = E.transpose(1,2).contiguous()                                 # B, L, HEADS, D//HEADS -> 6
        E = E.reshape(mask.shape[0], self.max_seq_length, self.dim)       # B, L, D -> 7
        E = self.W_o(E)                                                   # B, L, D -> 8

        return E
        
    def forward(self, x, inputs, tri_mask):

        mask = inputs != self.pad_token_id

        print('xSelf', x.shape) # B, L, D

        q = self.W_q(x).reshape(mask.shape[0], self.max_seq_length, self.n_heads, self.D_k).transpose(1,2) # B, HEADS, L, D//HEADS -> torch.Size([5, 2, 9, 128])
        k = self.W_k(x).reshape(mask.shape[0], self.max_seq_length, self.n_heads, self.D_k).transpose(1,2) # B, HEADS, L, D//HEADS
        v = self.W_v(x).reshape(mask.shape[0], self.max_seq_length, self.n_heads, self.D_k).transpose(1,2) # B, HEADS, L, D//HEADS

        y = self.attention(q, k, v, mask, tri_mask)   # B, L, D
        y = self.W_o(x)                               # B, L, D


        # print('x', x.shape)
        # print('inputs', inputs.shape)
        X = x + y
        

        return self.feed_forward(X)

In [None]:
src_pad_idx    = 0
trg_pad_idx    = 0
src_vocab_size = 10
trg_vocab_size = 10

print(trg[:, :-1])

model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
out = model(x, trg[:, :-1])
print(out.shape)

# Attention

In [54]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()

        self.embed_size = embed_size
        self.heads      = heads
        self.head_dim   = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        # Get number of training examples
        N = query.shape[0]

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        print(value_len, key_len, query_len)

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)  # (N, value_len, heads, head_dim)
        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
        queries = self.queries(query)  # (N, query_len, heads, heads_dim)

        # Einsum does matrix mult. for query*keys for each training example
        # with every other training example, don't be confused by einsum
        # it's just how I like doing matrix multiplication & bmm

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim),
        # keys shape: (N, key_len, heads, heads_dim)
        # energy: (N, heads, query_len, key_len)

        # Mask padded indices so their weights become 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Normalize energy values similarly to seq2seq + attention
        # so that they sum to 1. Also divide by scaling factor for
        # better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # out after matrix multiply: (N, query_len, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, query_len, embed_size)

        return out

In [55]:
a = torch.rand((2, 9))
m = torch.rand((2, 9, 512))

teste = SelfAttention(embed_size=512, heads=2)
teste(m, m, m, make_src_mask(a))


9 9 9


tensor([[[-0.3335, -0.0593, -0.0261,  ...,  0.1508, -0.2988, -0.0069],
         [-0.3329, -0.0577, -0.0257,  ...,  0.1506, -0.2985, -0.0077],
         [-0.3319, -0.0588, -0.0263,  ...,  0.1503, -0.2989, -0.0079],
         ...,
         [-0.3331, -0.0596, -0.0260,  ...,  0.1512, -0.2997, -0.0085],
         [-0.3335, -0.0592, -0.0259,  ...,  0.1506, -0.2991, -0.0082],
         [-0.3335, -0.0578, -0.0255,  ...,  0.1506, -0.2982, -0.0072]],

        [[-0.2888, -0.1058,  0.1194,  ...,  0.0919, -0.2774, -0.0424],
         [-0.2871, -0.1064,  0.1195,  ...,  0.0916, -0.2778, -0.0417],
         [-0.2875, -0.1054,  0.1196,  ...,  0.0910, -0.2775, -0.0435],
         ...,
         [-0.2874, -0.1060,  0.1191,  ...,  0.0918, -0.2779, -0.0422],
         [-0.2868, -0.1061,  0.1193,  ...,  0.0912, -0.2775, -0.0421],
         [-0.2877, -0.1058,  0.1193,  ...,  0.0915, -0.2777, -0.0426]]],
       grad_fn=<AddBackward0>)

In [57]:
m.reshape(2, 9, 2, 256)

tensor([[[[0.0033, 0.9380, 0.4886,  ..., 0.5543, 0.9401, 0.5505],
          [0.2118, 0.9130, 0.8602,  ..., 0.4050, 0.2701, 0.2778]],

         [[0.0262, 0.5792, 0.0016,  ..., 0.1232, 0.8835, 0.2023],
          [0.3997, 0.9968, 0.8600,  ..., 0.4489, 0.6798, 0.3016]],

         [[0.1478, 0.5773, 0.1737,  ..., 0.6511, 0.0846, 0.4102],
          [0.4746, 0.3968, 0.5910,  ..., 0.8221, 0.6617, 0.5427]],

         ...,

         [[0.5062, 0.2764, 0.7213,  ..., 0.3328, 0.1283, 0.8582],
          [0.0637, 0.4163, 0.8025,  ..., 0.2134, 0.7268, 0.8714]],

         [[0.2920, 0.7578, 0.7532,  ..., 0.7742, 0.6510, 0.9469],
          [0.4326, 0.7127, 0.4063,  ..., 0.7641, 0.9083, 0.4737]],

         [[0.1933, 0.0821, 0.7895,  ..., 0.3318, 0.9747, 0.5574],
          [0.8157, 0.3432, 0.1347,  ..., 0.2997, 0.3680, 0.3851]]],


        [[[0.9211, 0.0728, 0.6304,  ..., 0.3999, 0.7939, 0.5588],
          [0.6313, 0.2807, 0.7249,  ..., 0.8839, 0.0532, 0.8162]],

         [[0.1017, 0.1403, 0.0193,  ..., 0.73

In [None]:
class SelfAttentionmm(nn.Module):
    def __init__(self, dim, max_seq_length, n_heads, pad_token_id):
        super(SelfAttentionmm, self).__init__()

        self.pad_token_id   = pad_token_id
        self.max_seq_length = max_seq_length

        self.n_heads        = n_heads
        self.dim            = dim
        self.D_k            = dim//n_heads
        
        self.W_q = torch.nn.Linear(self.dim, self.dim, bias=False) # D, D
        self.W_k = torch.nn.Linear(self.dim, self.dim, bias=False) # D, D
        self.W_v = torch.nn.Linear(self.dim, self.dim, bias=False) # D, D
        self.W_o = torch.nn.Linear(self.dim, self.dim, bias=False) # D, D

        self.feed_forward = torch.nn.Sequential(torch.nn.LayerNorm(self.dim, eps=1e-6),
                                                nn.Linear(self.dim, self.dim),
                                                torch.nn.ReLU(),
                                                nn.Linear(self.dim, self.dim),
                                                nn.Dropout(p=0.2),
                                                torch.nn.LayerNorm(self.dim, eps=1e-6))


    def attention(self, Q, K, V, mask, tri_mask):   

        '''
        1 torch.Size([5, 2, 9, 9])
        2 torch.Size([5, 9])
        3 torch.Size([5, 2, 9, 9])
        4 torch.Size([5, 2, 9, 9])
        5 torch.Size([5, 2, 9, 128])
        6 torch.Size([5, 9, 2, 128])
        7 torch.Size([5, 9, 256])
        8 torch.Size([5, 9, 256])
        '''
        
        scores = torch.matmul(Q, K.transpose(-1, -2))/math.sqrt(self.D_k) # B, HEADS, L, L -> 1
        # print(scores)
        new_mask      = mask[:, None, None, :] & tri_mask
        mask_expanded = new_mask.expand_as(scores)                        # B, HEADS, L, L -> 3
        
        scores.masked_fill_(~mask_expanded, float('-inf'))                # B, HEADS, L, L
        # print(scores)
        probs = F.softmax(scores, dim=-1)                                 # B, HEADS, L, L -> 4

        E = torch.matmul(probs, V)                                        # B, HEADS, L, D//HEADS -> 5
        E = E.transpose(1,2).contiguous()                                 # B, L, HEADS, D//HEADS -> 6
        E = E.reshape(mask.shape[0], self.max_seq_length, self.dim)       # B, L, D -> 7
        E = self.W_o(E)                                                   # B, L, D -> 8

        return E
        
    def forward(self, x, inputs, tri_mask):

        mask = inputs != self.pad_token_id

        q = self.W_q(x).reshape(mask.shape[0], self.max_seq_length, self.n_heads, self.D_k).transpose(1,2) # B, HEADS, L, D//HEADS -> torch.Size([5, 2, 9, 128])
        k = self.W_k(x).reshape(mask.shape[0], self.max_seq_length, self.n_heads, self.D_k).transpose(1,2) # B, HEADS, L, D//HEADS
        v = self.W_v(x).reshape(mask.shape[0], self.max_seq_length, self.n_heads, self.D_k).transpose(1,2) # B, HEADS, L, D//HEADS

        y = self.attention(q, k, v, mask, tri_mask)   # B, L, D
        y = self.W_o(x)                               # B, L, D


        # print('x', x.shape)
        # print('inputs', inputs.shape)
        X = x + y
        

        return self.feed_forward(X)

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, dim, heads):
        super(SelfAttention, self).__init__()

        self.dim        = dim
        self.heads      = heads
        self.head_dim   = dim // heads

        assert (self.head_dim * heads == dim), "Embedding size needs to be divisible by heads"

        self.Wv = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.Wk = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.Wq = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc = nn.Linear(dim, dim)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]

        v_len, k_len, q_len = values.shape[1], keys.shape[1], query.shape[1]
        print(v_len, k_len, q_len)

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, v_len, self.heads, self.head_dim)
        keys   = keys.reshape(N, k_len, self.heads, self.head_dim)
        query  = query.reshape(N, q_len, self.heads, self.head_dim)

        values  = self.Wv(values)  # (N, value_len, heads, head_dim)
        keys    = self.Wk(keys)    # (N, key_len, heads, head_dim)
        queries = self.Wq(query)   # (N, query_len, heads, heads_dim)

        # Einsum does matrix mult. for query*keys for each training example
        # with every other training example, don't be confused by einsum
        # it's just how I like doing matrix multiplication & bmm

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim),
        # keys shape: (N, key_len, heads, heads_dim)
        # energy: (N, heads, query_len, key_len)

        # Mask padded indices so their weights become 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Normalize energy values similarly to seq2seq + attention
        # so that they sum to 1. Also divide by scaling factor for
        # better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # out after matrix multiply: (N, query_len, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, query_len, embed_size)

        return out

In [70]:
word_embedding     = nn.Embedding(10, 2)

k =  torch.ones((2,7), dtype=torch.long)

word_embedding(k)

tensor([[[-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883]],

        [[-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883],
         [-0.2938, -0.6883]]], grad_fn=<EmbeddingBackward0>)

In [71]:
word_embedding     = nn.Linear(10, 2)

k =  torch.ones((2,7), dtype=torch.long)

word_embedding(k)

RuntimeError: ignored