## Exercício: Modelo de Linguagem com auto-atenção

Este exercício é similar ao da aula passada, mas iremos agora treinar uma rede neural *com auto-atenção* para prever a próxima palavra de um texto, data as palavras anteriores como entrada.

Na camada de auto-atenção, deve-se implementar (vide slide 34):
- Embeddings de posição
- Projeções lineares (WQ, WK, WV, WO)
- Camada de feed forward (2-layer MLP)

Instrucões:
- É necessário fazer duas implementações da camada de auto-atenção: uma usando laços (ineficiente, mas fácil de entender) e outra matricial (eficiente mas difícil de entender). Usar slide 36 como referência.

- Fazer um assert para garantir que o resultado das duas implementações é exatamente igual.

- No treinamento, usar apenas a implementação matricial.

In [314]:
import string
from collections import Counter
from typing import List, Dict, Union, Tuple
import random
import os
import time
import abc

import numpy as np
from numpy.testing import assert_raises, assert_array_equal, assert_array_almost_equal
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import tqdm
import matplotlib.pyplot as plt


In [28]:
vocab_size = 3000 # Quantidade de palavras no vocabulário
context_size = 4 # n palavras de entrada. O target é a próxima palavra
embed_dim = 64 # Tamanho do feature vector de cada palavra
hidden_units = 300 # Quantidade de unidades na camada escondida
epochs = 10 # Quantidade de epochs que serão treinadas
lr = 5e-1 # Taxa de treinamento
weight_decay = 1e-3 # Regularização
batch_size = 32


## Faz download e carrega o dataset

In [2]:
if not os.path.isfile("67724.txt.utf-8"):
    !curl -LO https://www.gutenberg.org/ebooks/67724.txt.utf-8

if not os.path.isfile("67725.txt.utf-8"):
    !curl -LO https://www.gutenberg.org/ebooks/67725.txt.utf-8

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100   304  100   304    0     0    450      0 --:--:-- --:--:-- --:--:--   452

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0

 11  364k   11 42594    0     0  31946      0  0:00:11  0:00:01  0:00:10 31946
100  364k  100  364k    0     0   183k      0  0:00:01  0:00:01 --:--:--  495k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100   304  100   304    0     0    533      0 --

In [4]:
text = open("67724.txt.utf-8","r", encoding="utf8").read()
text += open("67725.txt.utf-8","r", encoding="utf8").read()

paragraphs = text.split("\n\n")
len(paragraphs)

4969

In [5]:
def clean_text(text:str) -> str:
    '''
    Clean the text, changing upper case and setting numbers to 999
    '''
    
    text = text.lower()
    old_text = text.split()
    new_text = []

    for j in range(len(old_text)):
        word = old_text[j] 
        if word.isdigit():
            word = "999"
        elif len(word) > 1 and word[0] in string.punctuation:
            old_text.insert(j+1, word[1:])
            word = word[0]
        elif word[-1] in string.punctuation and len(word) > 1:
            old_text.insert(j+1, word[:-1])
            old_text.insert(j+2, word[-1])
            
            word = ""
        
        if len(word) > 0:
            new_text.append(word)
    
    return " ".join(new_text)

In [6]:
cleaned_paragraphs = [paragraph.replace("\n", " ") for paragraph in paragraphs if paragraph.strip()]

#Paper:
#ponctuation -> keep (separado das outras palavras, "pontuação," -> "pontuação"+",")
#numeric -> special symbol (colocando todos como 999 para convergir para o mesmo símbolo)
#upper -> lower
#proper nouns -> special symbol (difícil identificar, ignorado)
#rare words -> special symbol (feito na parte de encoding)

for i in range(len(cleaned_paragraphs)):
    cleaned_paragraphs[i] = clean_text(cleaned_paragraphs[i])

print("SAMPLE ----------------")
print(cleaned_paragraphs[0])
print("---------------------")

print(len(cleaned_paragraphs))

SAMPLE ----------------
﻿the project gutenberg ebook of o guarany : romance brazileiro , vol . 999 ( of 999 ) this ebook is for the use of anyone anywhere in the united states and most other parts of the world at no cost and with almost no restrictions whatsoever . you may copy it , give it away or re-use it under the terms of the project gutenberg license included with this ebook or online at www.gutenberg.org . if you are not located in the united states , you
---------------------
4892


In [7]:
del paragraphs, text

## Análise do dataset

In [9]:

def count_words(texts:List[str]) -> Counter:
    word_counts = Counter()
    for text in texts:
        word_counts.update(text.split(" "))
    return word_counts

word_counts = count_words(cleaned_paragraphs)

len(word_counts)

11470

## Criando um vocabulário

In [12]:
most_frequent_words = [word for word, count in word_counts.most_common(vocab_size)]
vocab = {word: i for i, word in enumerate(most_frequent_words, 1)}

In [13]:
def encode_sentence(sentence:Union[str,List[str]], vocab:Dict) -> List[int]:
    if isinstance(sentence, list):
        words = sentence
    else:
        words = sentence.split(" ")
    
    return [vocab.get(word, 0) for word in words]

In [14]:
inverse_vocab = list(vocab.keys())

In [15]:
def decode_sentence(encoding, inverse_vocab):
    result = []

    for encoding_i in encoding:
        if encoding_i == 0:
            result.append("???")
        else:
            result.append(inverse_vocab[encoding_i-1])

    return result

In [16]:
del word_counts, most_frequent_words

## Classe do dataset

In [17]:
def create_sequences(texts:List[str], context_size:int, 
                     vocab:Dict) -> Tuple[List[List[int]], List[int]]:
    '''
    Generates
    '''
    x_all = []
    y_all = []

    for paragraph in texts:
        start = 0
        end = context_size

        paragraph = encode_sentence(paragraph, vocab)

        while end < len(paragraph):
            x = paragraph[start:end]
            y = paragraph[end]

            if not ( 0 in x or 0 == y):
                x_all.append(x)
                y_all.append(y)

            start += 1
            end += 1
    return x_all, y_all

In [19]:
x_all, y_all = create_sequences(cleaned_paragraphs, context_size, vocab)
assert len(x_all) == len(y_all)

In [20]:
#Embaralhando para evitar viés
indexes = list(range(len(x_all)))
random.shuffle(indexes)

x_all = np.array(x_all)
y_all = np.array(y_all)

x_all = x_all[indexes]
y_all = y_all[indexes]

In [21]:
size_all = len(x_all)

cut1 = int(0.6*size_all)
cut2 = int(0.8*size_all)

x_train = x_all[0:cut1]
y_train = y_all[0:cut1]

x_val = x_all[cut1:cut2]
y_val = y_all[cut1:cut2]

x_test = x_all[cut2:]
y_test = y_all[cut2:]

In [22]:
n_train = len(x_train)
n_val = len(x_val)
n_test = len(x_test)

In [23]:
print("Treino:", n_train)
print("Validação:", n_val)
print("Teste:", n_test)

Treino: 27219
Validação: 9073
Teste: 9074


In [24]:
assert n_train+n_val+n_test == size_all

In [25]:
class TextPredictDataset(Dataset):
    def __init__(self, x_data:List[int], y_data:List[int]):
        self._x_data = torch.tensor(x_data)-1
        self._y_data = torch.tensor(y_data, dtype=torch.int64)-1
        
        if len(x_data) != len(y_data):
            raise ValueError(f"x_data and y_data must have same size. ({len(x_data)} ≠ {len(y_data)})")
        
        self._size = len(x_data)

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        return self._x_data[idx], self._y_data[idx]


In [26]:
train_data = TextPredictDataset(x_train, y_train)
val_data = TextPredictDataset(x_val, y_val)
test_data = TextPredictDataset(x_test, y_test)

In [29]:
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

In [30]:
sample_batch = next(iter(train_loader))

## Model

In [161]:
class SingleHeadAttentionBase(torch.nn.Module, abc.ABC):
    def __init__(self, embed_dim:int) -> None:
        super().__init__()
        
        #d_model = dv = dk = embed_dim
        #h = 1

        self.wQ = torch.Tensor(embed_dim, embed_dim)
        self.wK = torch.Tensor(embed_dim, embed_dim)
        self.wV = torch.Tensor(embed_dim, embed_dim)
        self.w0 = torch.Tensor(embed_dim, embed_dim)

        self.wQ = torch.nn.Parameter(self.wQ)
        self.wK = torch.nn.Parameter(self.wK)
        self.wV = torch.nn.Parameter(self.wV)
        self.w0 = torch.nn.Parameter(self.w0)

        self.dk_root = torch.sqrt(torch.tensor(embed_dim, dtype=torch.float32))

        for w in [self.wQ, self.wK, self.wV, self.w0]:
            torch.nn.init.xavier_uniform_(w)
    
    @abc.abstractmethod
    def forward(self, query:torch.Tensor, key:torch.Tensor, value:torch.Tensor) -> torch.Tensor:
        ...
        

In [246]:
class SingleHeadAttention(SingleHeadAttentionBase):

    def __init__(self, embed_dim: int) -> None:
        super().__init__(embed_dim)

    def forward(self, query:torch.Tensor, key:torch.Tensor, value:torch.Tensor) -> torch.Tensor:
        print(self.wQ.shape, query.shape)
        
        Q = query @ self.wQ
        K = key @ self.wK
        V = value @ self.wV


        scores = Q @ K.permute(0,2,1)
        scores /= self.dk_root
        probs = torch.softmax(scores, dim=-1)
        E = probs @ V

        result = E @ self.w0

        return result


In [247]:
class SingleHeadAttentionLoop(SingleHeadAttentionBase):
    def __init__(self, embed_dim: int) -> None:
        super().__init__(embed_dim)


    def forward(self, query:torch.Tensor, key:torch.Tensor, value:torch.Tensor) -> torch.Tensor:
        batch_size = query.shape[0]
        sequence_size = query.shape[1]
        
        result = torch.empty_like(query)
        scores = torch.empty(sequence_size, device=query.device)

        for batch_index in range(batch_size):
            for word_index in range(sequence_size):
                xq = query[batch_index, word_index]
                q = xq @ self.wQ
                
                for key_index in range(sequence_size):
                    xk = key[batch_index][key_index]
                    k = xk @ self.wK
                    score = q @ k.T
                    scores[key_index] = score
                
                probs = torch.softmax(scores, dim=-1)

                e = 0
                for xv, p in zip(value[batch_index], probs):
                    v = xv @ self.wV
                    e += v*p
                
                e = e @ self.w0

                result[batch_index, word_index] = e

        return result
            

In [328]:
test_embed_dim = 5
matrix_version = SingleHeadAttention(test_embed_dim).eval()
loop_version = SingleHeadAttentionLoop(test_embed_dim).eval()
torch_version = torch.nn.MultiheadAttention(test_embed_dim, num_heads=1, bias=False, batch_first=True).eval()

In [329]:
wQ = matrix_version.wQ
wK = wQ
wV = wQ
w0 = wQ

#wQ = torch.nn.Parameter(torch.ones((test_embed_dim, test_embed_dim)))
#wK = torch.nn.Parameter(torch.ones((test_embed_dim, test_embed_dim)))
#wV = torch.nn.Parameter(torch.ones((test_embed_dim, test_embed_dim)))
#w0 = torch.nn.Parameter(torch.ones((test_embed_dim, test_embed_dim)))

#wQ = torch.nn.Parameter(torch.eye(test_embed_dim))
#wK = torch.nn.Parameter(torch.eye(test_embed_dim))
#wV = torch.nn.Parameter(torch.eye(test_embed_dim))
#w0 = torch.nn.Parameter(torch.eye(test_embed_dim))

matrix_version.wQ = wQ
matrix_version.wK = wK
matrix_version.wV = wV
matrix_version.w0 = w0

loop_version.wQ = wQ
loop_version.wK = wK
loop_version.wV = wV
loop_version.w0 = w0

#torch_version.q_proj_weight = wQ
#torch_version.k_proj_weight = wK
#torch_version.v_proj_weight = wV
torch_version.in_proj_weight = torch.nn.Parameter(torch.concat((wQ, wK, wV)))
torch_version.out_proj.weight = w0



In [330]:
test_data = torch.ones((2, 3, test_embed_dim))

In [331]:
test_data = torch.rand(2, 3, test_embed_dim)

In [332]:
result_matrix = matrix_version(test_data, test_data, test_data)
result_loop = loop_version(test_data, test_data, test_data)
result_torch, _ = torch_version(test_data, test_data, test_data, need_weights=False)

torch.Size([5, 5]) torch.Size([2, 3, 5])


In [333]:
result_matrix = result_matrix.detach()
result_loop = result_loop.detach()
result_torch = result_torch.detach()

In [334]:
assert result_matrix.shape == result_torch.shape
assert result_loop.shape == result_torch.shape

In [335]:
assert_array_almost_equal(result_matrix, result_loop)

AssertionError: 
Arrays are not almost equal to 6 decimals

Mismatched elements: 30 / 30 (100%)
Max absolute difference: 0.07272851
Max relative difference: 1.7289221
 x: array([[[-0.737511,  0.418121,  0.302359, -0.0113  , -0.870395],
        [-0.745576,  0.393842,  0.322722,  0.004232, -0.834106],
        [-0.74752 ,  0.365238,  0.344328,  0.020905, -0.786164]],...
 y: array([[[-0.73273 ,  0.462243,  0.268469, -0.037401, -0.943123],
        [-0.75386 ,  0.414645,  0.310067, -0.005806, -0.875633],
        [-0.757841,  0.357597,  0.353191,  0.027468, -0.780088]],...

In [336]:
assert_array_almost_equal(result_matrix, result_torch)

AssertionError: 
Arrays are not almost equal to 6 decimals

Mismatched elements: 30 / 30 (100%)
Max absolute difference: 1.1894937
Max relative difference: 17.108639
 x: array([[[-0.737511,  0.418121,  0.302359, -0.0113  , -0.870395],
        [-0.745576,  0.393842,  0.322722,  0.004232, -0.834106],
        [-0.74752 ,  0.365238,  0.344328,  0.020905, -0.786164]],...
 y: array([[[ 0.061579, -0.509062, -0.083738, -0.38707 ,  0.319099],
        [ 0.097306, -0.442378, -0.020034, -0.456165,  0.321708],
        [ 0.122666, -0.391349,  0.026576, -0.504543,  0.322033]],...

In [316]:
assert_array_almost_equal(result_loop, result_torch)

AssertionError: 
Arrays are not almost equal to 6 decimals

Mismatched elements: 30 / 30 (100%)
Max absolute difference: 0.05720511
Max relative difference: 0.13257349
 x: array([[[0.692497, 0.579445, 0.719626, 0.339527, 0.599161],
        [0.778973, 0.521562, 0.735156, 0.404873, 0.521556],
        [0.82875 , 0.48351 , 0.725223, 0.488702, 0.483208]],...
 y: array([[[0.723061, 0.558729, 0.724087, 0.36514 , 0.572077],
        [0.761729, 0.532796, 0.73083 , 0.394851, 0.537442],
        [0.785138, 0.515186, 0.727293, 0.431497, 0.519029]],...

In [None]:

class LanguageModel(torch.nn.Module):
    """TODO: implementar o modelo de linguagem"""

In [None]:
model = ...

In [None]:
# sample = next(iter(train_loader))
input = sample[0]
target = sample[1]

In [None]:
output = model(input)

In [None]:
output.argmax(dim=1)

tensor([4842, 2163, 7516, 2652, 6373, 7429, 8003, 3759, 1768, 7740, 2595, 1859,
        3189, 8049, 5727, 6132])

In [None]:
target

tensor([   2,    3,    4,   37,    3,  215,   71,  411, 1263,  355,   87, 3653,
         584,  980,    1,    7])

## Training

In [None]:
# Verifica se há uma GPU disponível e define o dispositivo para GPU se possível, caso contrário, usa a CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
epochs = 10
lr = """TODO""""
criterion = """TODO CrossEntropy""""

optimizer = """TODO: AdamW ou outro""""

model.to(device)

"""TODO: Implemente o loop de treinamento. Em cada época, calcule e imprima a loss no dataset de validação""""

## Avaliação

In [None]:
""" TODO: calcule a perplexidade final no dataset de validação """

## Exemplo de uso

In [None]:
text = ""

def generate_text(model, vocab, text, max_length):
    """TODO: implemente a função para gerar texto até atingir o max_length"""

context = 5
max_length= 10
generate_text(text, max_length)