In [59]:
import torch
import torch.nn as nn
import numpy as np
import os
import gensim
from gensim.utils import simple_preprocess

In [60]:
import nltk
from nltk import sent_tokenize
# Fetch the 'punkt_tab' resource into the local nltk_data directory; the call is a
# no-op when the package is already up to date.
# NOTE(review): presumably this is the data sent_tokenize relies on — confirm against
# the installed NLTK version.
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\cyborg\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [61]:
# Build the training corpus: one list of tokens per sentence, gathered from every
# regular file that sits directly inside the 'data' directory.
story = []
# os.listdir() returns entries in arbitrary order; sorting keeps the sentence order
# (and therefore the corpus handed to Word2Vec) stable from run to run.
for filename in sorted(os.listdir('data')):
    path = os.path.join('data', filename)
    # Skip sub-directories (e.g. '.ipynb_checkpoints'); open() would raise on them.
    if not os.path.isfile(path):
        continue
    # latin-1 maps every byte to a code point, so decoding can never fail.
    with open(path, encoding='latin-1') as f:
        corpus = f.read()
    # Split the text into sentences, then tokenize each sentence into a word list.
    story.extend(simple_preprocess(sent) for sent in sent_tokenize(corpus))

In [62]:
# Number of tokenized sentences collected in `story`.
len(story)

145020

In [63]:
# Create the Word2Vec model without passing a corpus; the vocabulary is built and
# the weights are trained explicitly in the cells that follow.
# Parameter meanings below are taken from gensim's Word2Vec documentation.
model = gensim.models.Word2Vec(
    # Maximum distance between the current and the predicted word in a sentence.
    window=10,
    # Words whose total frequency is below this value are ignored.
    min_count=2,
    # Number of worker threads used for training.
    workers=4
)

In [64]:
# Build the model's vocabulary from the tokenized sentences before training.
model.build_vocab(story)

In [65]:
# Train on the same sentences, passing the model's own corpus_count as the number of
# examples and its own epochs setting as the number of passes.
model.train(story, total_examples=model.corpus_count, epochs=model.epochs)

(6571778, 8628190)

In [66]:
# Sanity check of the learned embeddings: list the words closest to 'daenerys'
# together with their similarity scores.
model.wv.most_similar('daenerys')

[('stormborn', 0.8051601648330688),
 ('targaryen', 0.7532665729522705),
 ('unburnt', 0.7423449754714966),
 ('myrcella', 0.723058819770813),
 ('princess', 0.7199079990386963),
 ('elia', 0.6827012896537781),
 ('dorne', 0.6826499104499817),
 ('martell', 0.6766705513000488),
 ('queen', 0.6484774947166443),
 ('aegon', 0.6469336748123169)]

In [67]:
class Head(nn.Module):
    """Single causal (masked) self-attention head.

    Projects the input to queries, keys and values with three bias-free linear
    layers, computes scaled dot-product attention in which position ``t`` may only
    attend to positions ``<= t``, and returns the attention-weighted values.

    Args:
        num_features: Size of the last dimension of the input; queries, keys,
            values and the output all keep this size.
    """

    def __init__(self, num_features):
        super().__init__()

        self.wq = nn.Linear(num_features, num_features, bias=False)
        self.wk = nn.Linear(num_features, num_features, bias=False)
        self.wv = nn.Linear(num_features, num_features, bias=False)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(..., seq_len, num_features)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # Scale by the key dimension (last axis). transpose(-2, -1) instead of `.T`
        # so that inputs with leading batch dimensions also work.
        energy = torch.matmul(q, k.transpose(-2, -1)) * q.shape[-1] ** -0.5

        # Mask future positions with an explicit boolean mask. Deriving the mask from
        # "entries equal to zero after tril" would also hide genuine zero scores in
        # the allowed region and can turn a whole row into -inf (softmax -> NaN).
        seq_len = x.shape[-2]
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        energy = energy.masked_fill(future, float('-inf'))

        attention = torch.softmax(energy, dim=-1)

        return torch.matmul(attention, v)

In [68]:
class MultiHeadAttention(nn.Module):
    """Several `Head` modules run side by side, merged by a linear projection.

    Every head receives the full input and produces ``num_features`` outputs; the
    per-head results are concatenated on the last axis and projected back to
    ``num_features`` by the bias-free layer ``wo``.

    Args:
        num_heads: Number of `Head` modules to create.
        num_features: Feature size of the input and of the final output.
    """

    def __init__(self, num_heads, num_features):
        super().__init__()

        head_modules = []
        for _ in range(num_heads):
            head_modules.append(Head(num_features=num_features))
        self.heads = nn.ModuleList(head_modules)

        # Input width is the concatenation of all head outputs.
        self.wo = nn.Linear(num_heads * num_features, num_features, bias=False)

    def forward(self, x):
        """Run every head on ``x``, concatenate the results and project them."""
        per_head = tuple(head(x) for head in self.heads)
        concatenated = torch.cat(per_head, dim=-1)
        return self.wo(concatenated)


In [69]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        num_features: Size of the input and output feature dimension.
        hidden_features: Width of the hidden layer. Defaults to 2048, the value
            that was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, num_features, hidden_features=2048):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(num_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_features),
        )

    def forward(self, x):
        """Apply the network to the last dimension of ``x``; the shape is preserved."""
        return self.model(x)

In [70]:
class PositionalEncodeing(nn.Module):
    """Adds fixed sinusoidal positional encodings to a sequence of embeddings.

    For position ``pos`` and pair index ``i`` in ``range(d_model // 2)``::

        theta          = pos / 10000 ** (2 * i / d_model)
        pe[pos, 2*i]   = sin(theta)
        pe[pos, 2*i+1] = cos(theta)

    When ``d_model`` is odd the last column is left at zero.

    Args:
        d_model: Feature size of the embeddings.
        total_tokens: Number of positions (rows) in the encoding table.
    """

    def __init__(self, d_model, total_tokens):
        super().__init__()

        self.d_model = d_model
        self.total_tokens = total_tokens

        # Build the table once, vectorised, instead of refilling it with a Python
        # double loop on every forward pass.
        table = torch.zeros((total_tokens, d_model))
        paired = 2 * (d_model // 2)
        # Angles are computed in float64 and cast on assignment to limit rounding
        # error for large positions.
        positions = torch.arange(total_tokens, dtype=torch.float64).unsqueeze(1)
        even_dims = torch.arange(0, paired, 2, dtype=torch.float64)
        theta = positions / (10000.0 ** (even_dims / d_model))
        table[:, 0:paired:2] = torch.sin(theta)
        # Cosine goes to the odd column 2*i + 1 (previously it always hit column 2).
        table[:, 1:paired:2] = torch.cos(theta)

        # A buffer follows the module across .to(device) calls; persistent=False keeps
        # it out of the state_dict, matching the earlier plain-attribute behaviour.
        self.register_buffer("pe", table, persistent=False)

    def forward(self, x):
        """Return ``x + pe``.

        ``x`` must broadcast against the ``(total_tokens, d_model)`` table.
        """
        return x + self.pe

In [71]:
class Block(nn.Module):
    """Transformer block: positional encoding, masked multi-head attention and a
    feed-forward network, each sub-layer followed by a residual add and LayerNorm.

    Args:
        num_heads: Number of attention heads.
        num_features: Feature size of the token embeddings.
        total_tokens: Number of positions the positional-encoding table covers.
        x: Optional token embeddings; stored as ``token_embedding`` only and not
            used by ``forward``.
    """

    def __init__(self, num_heads, num_features, total_tokens, x=None):
        super().__init__()

        self.token_embedding = x
        self.mha = MultiHeadAttention(num_heads, num_features)
        self.pe = PositionalEncodeing(num_features, total_tokens)
        self.ffwd = FeedForward(num_features)
        # One LayerNorm per sub-layer so each learns its own scale and shift. Both
        # start as identity-affine, so a fresh block matches a shared-norm one.
        self.ln = nn.LayerNorm(num_features)
        self.ln2 = nn.LayerNorm(num_features)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        # self.pe already returns x + positional encoding; adding x again would
        # double the embedding relative to the positional signal.
        x = self.pe(x)

        # Attention sub-layer: residual connection, then normalisation.
        x = self.ln(self.mha(x) + x)

        # Feed-forward sub-layer: residual connection, then normalisation.
        x = self.ln2(self.ffwd(x) + x)

        return x

In [72]:
# Feed the normalised Word2Vec vectors through one block, treating the whole
# vocabulary as a single sequence with one row per word.
embeddings = torch.tensor(model.wv.get_normed_vectors())
vocab_size, embed_dim = embeddings.shape

block = Block(2, embed_dim, vocab_size, embeddings)
x = block(embeddings)

In [None]:
# Inspect the block output without flooding the notebook: show its shape and the
# first few rows rather than lifting the global print threshold and dumping every
# element of the (vocabulary x features) tensor.
print(tuple(x.shape))
print(x[:5])