In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext
import wandb
from torchtext.data import RawField, ReversibleField, LabelField
from torchtext.datasets import WikiText2

In [None]:
# Hyper-parameters and environment settings for the experiments in this notebook.
config = {
    # Model sizes.
    'embed_dim': 512,
    'key_dim': 64,
    'num_hidden_nodes': 300,

    # Data loading.
    'batch_size': 400,
    #'dataset': 'imagenette2-320',
    'datadir': '/home/apower/data/text/wikitext-2',
    'load_workers': os.cpu_count(),

    # Regularisation / initialisation.
    'dropout': 0.1,
    'init_gain': 5,
    'initializer': None,

    # Optimisation.
    'learning_rate': 0.1,
    'max_epochs': 1000,
    'optimizer': 'SGD',
    'random_seed': 1,
    'training_loops': 4,

    # Hardware.
    'cuda_device_ids': [0, 1, 2],
}

# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Values that also live in `config` are read from it so the two cannot drift apart.
learning_rate = config['learning_rate']
num_hidden_nodes = config['num_hidden_nodes']
data_dir = config['datadir']

embedding_dim = 300
bptt_len = 5


In [1]:
ls

my-language-model.ipynb  my-transformer.ipynb  rnn.py  Untitled.ipynb  wandb/


In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over one unbatched sequence.

    Every head owns its own query/key/value projection of shape
    ``(embedding_dim, k_d)`` with ``k_d = embedding_dim // attention_heads``.
    The per-head results are concatenated back to ``embedding_dim`` features.

    Args:
        embedding_dim: Size of each input (and output) vector.
        attention_heads: Number of attention heads; must divide ``embedding_dim``.

    Raises:
        ValueError: If ``embedding_dim`` is not divisible by ``attention_heads``.
    """

    def __init__(self, embedding_dim=512, attention_heads=8):
        super().__init__()
        if embedding_dim % attention_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by "
                f"attention_heads ({attention_heads})"
            )
        # Integer division: tensor sizes must be ints, `/` would yield a float.
        k_d = embedding_dim // attention_heads

        # nn.Parameter registers the projections with the module so they are
        # trained by the optimizer, moved by .to()/.cuda() and saved in the
        # state dict. Initialisation stays standard normal.
        self.Wq = nn.Parameter(torch.randn(attention_heads, embedding_dim, k_d))
        self.Wk = nn.Parameter(torch.randn(attention_heads, embedding_dim, k_d))
        self.Wv = nn.Parameter(torch.randn(attention_heads, embedding_dim, k_d))
        # Normalise over the key axis (the last axis of the score tensor).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, in_vectors):
        """Attend every input vector to every other input vector.

        Args:
            in_vectors: Tensor of shape ``(num vectors, embedding_dim)``.

        Returns:
            Tensor of shape ``(num vectors, embedding_dim)``.
        """
        # (n, e) @ (heads, e, k_d) broadcasts to (heads, n, k_d).
        queries = torch.matmul(in_vectors, self.Wq)
        keys = torch.matmul(in_vectors, self.Wk)
        values = torch.matmul(in_vectors, self.Wv)
        k_d = keys.size(-1)

        scores = torch.matmul(queries, keys.transpose(-2, -1))  # (heads, n, n)
        # k_d is a Python int, so take the root with ** rather than torch.sqrt.
        normalized_scores = self.softmax(scores / (k_d ** 0.5))  # (heads, n, n)
        Zi = torch.matmul(normalized_scores, values)  # (heads, n, k_d)

        # Concatenate the heads along the feature axis. An explicit reshape
        # (instead of a bare squeeze) keeps the sequence axis even when n == 1.
        Z = Zi.permute(1, 0, 2).reshape(Zi.size(1), -1)  # (n, embedding_dim)
        return Z

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then a linear layer, each wrapped in
    a residual connection followed by layer normalisation.

    Args:
        embedding_dim: Size of each input (and output) vector.
        attention_heads: Number of heads handed to the self-attention layer.
    """

    def __init__(self, embedding_dim=512, attention_heads=8):
        super().__init__()

        self.attention = SelfAttention(embedding_dim=embedding_dim, attention_heads=attention_heads)
        self.ffnn = torch.nn.Linear(embedding_dim, embedding_dim)
        # LayerNorm is a module: it is built once from the feature size and
        # then applied to tensors in forward(). Each sub-layer gets its own
        # instance so the learned scale and bias are not shared.
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, in_vectors):
        """Transform ``in_vectors``; the output has the same shape as the input."""
        a1 = self.norm1(in_vectors + self.attention(in_vectors))
        a2 = self.norm2(a1 + self.ffnn(a1))
        return a2


In [None]:
class Encoder(nn.Module):
    """Stack of encoder blocks applied to a position-encoded sequence.

    Args:
        embedding_dim: Size of each vector flowing through the blocks.
        attention_heads: Number of attention heads in every block.
        num_blocks: Number of encoder blocks stacked in sequence.
        vocab_len: Vocabulary size. When given, ``forward`` looks integer
            token ids up in a learned embedding table. When ``None`` (the
            default) no table is created and ``forward`` expects vectors that
            are already embedded.
    """

    def __init__(self, embedding_dim=512, attention_heads=8, num_blocks=6, vocab_len=None):
        super().__init__()

        self.embedding_dim = embedding_dim
        # nn.Embedding (rather than a bare torch.randn tensor) registers the
        # table as a trainable parameter; its default init is also N(0, 1).
        self.embedding = nn.Embedding(vocab_len, embedding_dim) if vocab_len is not None else None

        blocks = [
            EncoderBlock(embedding_dim=embedding_dim, attention_heads=attention_heads)
            for _ in range(num_blocks)
        ]
        self.blocks = nn.Sequential(*blocks)

    def _positional_encoding(self, num_vectors, device=None, dtype=None):
        """Return sinusoidal position vectors of shape ``(num_vectors, embedding_dim)``.

        Even feature columns hold ``sin(pos / 10000^(2i/d))`` and odd columns
        hold the matching cosine.
        """
        dim = self.embedding_dim
        positions = torch.arange(num_vectors, dtype=torch.float32, device=device).unsqueeze(1)
        even_index = torch.arange(0, dim, 2, dtype=torch.float32, device=device)
        inv_freq = 10000.0 ** (-even_index / dim)
        angles = positions * inv_freq  # (num_vectors, ceil(dim / 2))

        encoding = torch.zeros(num_vectors, dim, device=device)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd dim there is one cosine column fewer than sine columns.
        encoding[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return encoding if dtype is None else encoding.to(dtype)

    def forward(self, tokens):
        """Encode one unbatched sequence.

        Args:
            tokens: Integer token ids of shape ``(num vectors,)`` when the
                encoder owns an embedding table, otherwise vectors of shape
                ``(num vectors, embedding_dim)``.

        Returns:
            Tensor of shape ``(num vectors, embedding_dim)``.
        """
        vectors = tokens if self.embedding is None else self.embedding(tokens)
        positional_vectors = self._positional_encoding(
            vectors.size(0), device=vectors.device, dtype=vectors.dtype
        )
        # Attention is order-agnostic, so position is injected additively.
        return self.blocks(vectors + positional_vectors)
    

In [None]:
class EncoderDecoderAttention(nn.Module):
    """Multi-head scaled dot-product attention from decoder vectors to encoder vectors.

    Queries are projected from the decoder-side vectors; keys and values are
    projected from the encoder output. Each head works on
    ``k_d = embedding_dim // attention_heads`` features.

    Args:
        embedding_dim: Size of each input (and output) vector.
        attention_heads: Number of attention heads; must divide ``embedding_dim``.

    Raises:
        ValueError: If ``embedding_dim`` is not divisible by ``attention_heads``.
    """

    def __init__(self, embedding_dim=512, attention_heads=8):
        # Subclassing nn.Module makes the instance callable (module(x, y)
        # dispatches to forward) and lets parent modules track its weights.
        super().__init__()
        if embedding_dim % attention_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by "
                f"attention_heads ({attention_heads})"
            )
        # Integer division: tensor sizes must be ints, `/` would yield a float.
        k_d = embedding_dim // attention_heads

        # Registered as parameters so they are trained, moved across devices
        # and saved. Initialisation stays standard normal.
        self.Wq = nn.Parameter(torch.randn(attention_heads, embedding_dim, k_d))
        self.Wk = nn.Parameter(torch.randn(attention_heads, embedding_dim, k_d))
        self.Wv = nn.Parameter(torch.randn(attention_heads, embedding_dim, k_d))
        # Normalise over the encoder (key) axis.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, in_vectors, encoder_vectors):
        """Attend every decoder-side vector to every encoder vector.

        Args:
            in_vectors: Tensor of shape ``(num vectors, embedding_dim)``.
            encoder_vectors: Tensor of shape ``(num encoder vectors, embedding_dim)``.

        Returns:
            Tensor of shape ``(num vectors, embedding_dim)``.
        """
        queries = torch.matmul(in_vectors, self.Wq)  # (heads, n, k_d)
        keys = torch.matmul(encoder_vectors, self.Wk)  # (heads, m, k_d)
        values = torch.matmul(encoder_vectors, self.Wv)  # (heads, m, k_d)
        k_d = keys.size(-1)

        scores = torch.matmul(queries, keys.transpose(-2, -1))  # (heads, n, m)
        # k_d is a Python int, so take the root with ** rather than torch.sqrt.
        normalized_scores = self.softmax(scores / (k_d ** 0.5))  # (heads, n, m)
        Zi = torch.matmul(normalized_scores, values)  # (heads, n, k_d)

        # Concatenate the heads along the feature axis. An explicit reshape
        # (instead of a bare squeeze) keeps the sequence axis even when n == 1.
        Z = Zi.permute(1, 0, 2).reshape(Zi.size(1), -1)  # (n, embedding_dim)
        return Z

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, encoder-decoder attention and a
    linear layer, each wrapped in a residual connection plus layer norm.

    The block consumes and returns an ``(in_vectors, encoder_vectors)`` tuple
    so several blocks can be chained with ``nn.Sequential``, which passes a
    single value from one module to the next.

    Args:
        embedding_dim: Size of each input (and output) vector.
        attention_heads: Number of heads handed to both attention layers.
    """

    def __init__(self, embedding_dim=512, attention_heads=8):
        super().__init__()

        # FIXME: mask out future self-attention
        self.self_attention = SelfAttention(embedding_dim=embedding_dim, attention_heads=attention_heads)
        self.enc_attention = EncoderDecoderAttention(embedding_dim=embedding_dim, attention_heads=attention_heads)
        self.ffnn = torch.nn.Linear(embedding_dim, embedding_dim)
        # LayerNorm is a module: built once from the feature size, applied in
        # forward(). One instance per sub-layer so weights are not shared.
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.norm3 = nn.LayerNorm(embedding_dim)

    def forward(self, inputs):
        """Run one decoding step over the sequence.

        Args:
            inputs: Tuple ``(in_vectors, encoder_vectors)``.

        Returns:
            Tuple ``(out_vectors, encoder_vectors)``; ``encoder_vectors`` is
            passed through unchanged for the next block.
        """
        # Python 3 has no tuple parameters, so unpack inside the body.
        in_vectors, encoder_vectors = inputs
        a1 = self.norm1(in_vectors + self.self_attention(in_vectors))
        a2 = self.norm2(a1 + self.enc_attention(a1, encoder_vectors))
        # The feed-forward layer transforms the same tensor it is added to.
        a3 = self.norm3(a2 + self.ffnn(a2))
        return (a3, encoder_vectors)


In [None]:
class Decoder(nn.Module):
    """Sequential stack of ``num_blocks`` decoder blocks.

    Args:
        embedding_dim: Vector size handed to every block.
        attention_heads: Head count handed to every block.
        num_blocks: How many decoder blocks to chain.
    """

    def __init__(self, embedding_dim=512, attention_heads=8, num_blocks=6):
        super().__init__()

        layers = [
            DecoderBlock(embedding_dim=embedding_dim, attention_heads=attention_heads)
            for _ in range(num_blocks)
        ]
        self.blocks = nn.Sequential(*layers)

    def forward(self, encoder_vectors):
        """Feed the encoder output through the block stack.

        The encoder output is used both as the initial decoder input and as
        the attention memory, so the stack receives it twice as a tuple.
        Whatever the final block returns is handed back as-is.
        """
        paired_input = (encoder_vectors, encoder_vectors)
        return self.blocks(paired_input)
    

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model: embed the input, encode it, then decode it.

    Args:
        vocab_len: Vocabulary size. When given, ``embed`` looks integer token
            ids up in a learned embedding table. When ``None`` (the default)
            no table is created and ``embed`` returns its input unchanged, so
            callers must supply vectors that are already embedded.
        embedding_dim: Size of each vector flowing through the model.
        attention_heads: Number of attention heads per block.
        num_blocks: Number of blocks in both the encoder and the decoder.
    """

    def __init__(self, vocab_len=None, embedding_dim=512, attention_heads=8, num_blocks=6):
        # Subclassing nn.Module registers the sub-modules (so .parameters(),
        # .to() and state_dict() see them) and makes the instance callable.
        super().__init__()
        self.embedding = nn.Embedding(vocab_len, embedding_dim) if vocab_len is not None else None
        self.encode = Encoder(embedding_dim=embedding_dim, attention_heads=attention_heads, num_blocks=num_blocks)
        self.decode = Decoder(embedding_dim=embedding_dim, attention_heads=attention_heads, num_blocks=num_blocks)

    def embed(self, words):
        """Map ``words`` to embedding vectors.

        Args:
            words: Integer token ids when an embedding table exists, otherwise
                a tensor of vectors that is returned as-is.

        Returns:
            Tensor whose last axis has ``embedding_dim`` features.
        """
        if self.embedding is None:
            return words
        embedded_vectors = self.embedding(words)
        return embedded_vectors

    def forward(self, words):
        """Embed ``words``, run the encoder, and return the decoder's output."""
        embedded_vectors = self.embed(words)
        return self.decode(self.encode(embedded_vectors))
        