In [53]:
import torch
import torch.nn as nn

In [54]:
class config:
    """Shared hyperparameters referenced by the model definitions."""

    embedding_dim: int = 100  # token embedding width
    context_size: int = 100  # side length of the causal mask built by Attention
    

In [55]:
class WordEmbedding(nn.Module):
    """Map integer token ids to learned dense vectors.

    The lookup table is an ``nn.Embedding`` stored as ``self.embed``.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

    def forward(self, x):
        """Return the embedding row for every id in ``x`` (adds a trailing feature axis)."""
        vectors = self.embed(x)
        return vectors

In [56]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / embedding_dim)``.

    Args:
        embedding_dim: feature width of the inputs; odd widths are supported.
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, embedding_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embedding_dim there is one fewer odd column than even
        # column, so only the first embedding_dim // 2 frequencies fit here.
        # (Without the slice the assignment fails with a shape mismatch.)
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        # Leading singleton axis lets the table broadcast over the batch.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Expects ``x`` shaped ``(batch, seq, embedding_dim)``.

        Raises:
            ValueError: if ``seq`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

In [57]:
class Attention(nn.Module):
    """Single-head causal self-attention with an output projection and dropout.

    Args:
        embedding_dim: width of the input features; queries, keys, values and
            the output all keep this width.
        context_size: longest sequence the causal mask supports. When omitted,
            ``config.context_size`` is used (the previous fixed behaviour).
    """

    def __init__(self, embedding_dim, context_size=None):
        super(Attention, self).__init__()
        if context_size is None:
            context_size = config.context_size

        self.wq = nn.Linear(embedding_dim, embedding_dim)
        self.wk = nn.Linear(embedding_dim, embedding_dim)
        self.wv = nn.Linear(embedding_dim, embedding_dim)
        self.wo = nn.Linear(embedding_dim, embedding_dim)

        self.residual_dropout = nn.Dropout(0.1)
        self.attention_dropout = nn.Dropout(0.1)

        # Additive mask: -inf strictly above the diagonal, 0 on and below it,
        # so position i can only attend to positions <= i.
        mask = torch.full((1, context_size, context_size), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask)

    def forward(self, x):
        """Attend over ``x``; assumes ``x`` is ``(batch, seq, embedding_dim)``.

        Raises:
            ValueError: if ``seq`` is longer than the causal mask.
        """
        seq_len = x.size(1)
        max_len = self.mask.size(-1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds attention context size {max_len}"
            )

        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # Scaled dot-product scores: (batch, seq, seq).
        scores = torch.matmul(q, k.transpose(1, 2))
        scores = scores / (q.size(-1) ** 0.5)
        scores = scores + self.mask[:, :seq_len, :seq_len]
        scores = torch.softmax(scores, dim=-1).type_as(q)
        scores = self.attention_dropout(scores)
        attention = torch.matmul(scores, v)

        attention = self.wo(attention)
        return self.residual_dropout(attention)

In [58]:
class FeedForward(nn.Module):
    """Gated position-wise feed-forward layer.

    Computes ``dropout(ff3(relu(ff1(x)) * ff2(x)))``:
    - ``ff1`` produces the ReLU gate;
    - ``ff2`` produces the values being gated;
    - ``ff3`` projects the ``hidden_dim`` features back to ``embedding_dim``.
    """

    def __init__(self, embedding_dim, hidden_dim):
        super(FeedForward, self).__init__()
        self.ff1 = nn.Linear(embedding_dim, hidden_dim)
        self.ff2 = nn.Linear(embedding_dim, hidden_dim)
        self.ff3 = nn.Linear(hidden_dim, embedding_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # Both projections must be applied to x before gating; passing the
        # Linear modules themselves to relu/multiply raises a TypeError.
        gate = torch.relu(self.ff1(x))
        value = self.ff2(x)
        return self.dropout(self.ff3(gate * value))

In [59]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer.

    Applies LayerNorm -> attention, adds the result back to the input, then
    LayerNorm -> feed-forward, adding that result back as well.
    """

    def __init__(self, embedding_dim, hidden_dim):
        super().__init__()
        # Registration order is kept identical so parameter order and
        # state_dict layout do not change.
        self.attention = Attention(embedding_dim)
        self.feedforward = FeedForward(embedding_dim, hidden_dim)
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        attn_out = self.attention(self.layer_norm1(x))
        hidden = x + attn_out
        ff_out = self.feedforward(self.layer_norm2(hidden))
        return hidden + ff_out

In [60]:
class Transformer(nn.Module):
    """Token-level transformer that maps token ids to vocabulary logits.

    Pipeline:
    1. Word embedding, then sinusoidal positional encoding.
    2. A stack of ``TransformerBlock`` layers.
    3. A final LayerNorm.
    4. A bias-free linear projection to ``vocab_size`` logits.

    Args:
        vocab_size: number of token ids; also the width of the output logits.
        embedding_dim: model width.
        hidden_dim: hidden width of each block's feed-forward layer.
        num_layers: number of stacked blocks (defaults to the previous
            hard-coded depth of 6).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=6):
        super(Transformer, self).__init__()
        self.word_embedding = WordEmbedding(vocab_size, embedding_dim)
        self.positional_encoding = PositionalEncoding(embedding_dim)
        self.layers = nn.ModuleList(
            TransformerBlock(embedding_dim, hidden_dim) for _ in range(num_layers)
        )
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.out = nn.Linear(embedding_dim, vocab_size, bias=False)

    def forward(self, x):
        """Return logits; for ids shaped (batch, seq) presumably (batch, seq, vocab_size)."""
        x = self.word_embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x)
        x = self.layer_norm(x)
        return self.out(x)

In [61]:
# Model instance: 100-token vocabulary, 100-wide embeddings, 100-wide FFN hidden layer.
jarvis = Transformer(vocab_size=100, embedding_dim=100, hidden_dim=100)

In [62]:
# Cross-entropy over vocabulary logits, optimised with Adam (learning rate 1e-3).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=jarvis.parameters(), lr=1e-3)

In [63]:
def train(epoch, dataloader, model=None, loss_fn=None, opt=None):
    """Run ``epoch`` passes over ``dataloader``, taking one optimizer step per batch.

    Each step's loss is printed, as before, and also collected.

    Args:
        epoch: number of full passes over ``dataloader``.
        dataloader: iterable of ``(context, target)`` pairs. ``context`` is fed
            to the model. ``target`` holds one token id per output position;
            its shape is assumed to be (batch, seq) — TODO confirm against the
            data pipeline.
        model: module to train; defaults to the global ``jarvis``.
        loss_fn: loss module; defaults to the global ``criterion``.
        opt: optimizer; defaults to the global ``optimizer``.

    Returns:
        List of per-batch loss values, in training order.
    """
    model = jarvis if model is None else model
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt

    # Make sure dropout is active even if eval() was called earlier.
    model.train()
    losses = []
    for _ in range(epoch):
        for context, target in dataloader:
            opt.zero_grad()
            output = model(context)
            # CrossEntropyLoss reads class scores from dim 1, but the model
            # puts the vocabulary on the last dim. Flatten every position
            # into its own row so logits and targets line up.
            loss = loss_fn(output.reshape(-1, output.size(-1)), target.reshape(-1))
            loss.backward()
            opt.step()
            step_loss = loss.item()
            losses.append(step_loss)
            print(step_loss)
    return losses