In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import numpy as np

In [286]:
from torch.utils.data import IterableDataset, DataLoader, Dataset

class VocabDataset(Dataset):
    def __init__(self, file_patch: str, seq_len: int):
        f = open(file_patch, 'r')
        corpus = f.read()
        f.close()
        self.seq_len = seq_len
        self.vocab = sorted(list(set(corpus)))
        self.vocab_size = len(self.vocab)
        self.stoi = dict(zip(self.vocab, range(self.vocab_size)))

        self.encoder = lambda c: [self.stoi[ci] for ci in c]
        self.decoder = lambda x: [self.vocab[xi] for xi in x]
        self.data = torch.tensor(self.encoder(corpus))

    def __getitem__(self, index):
        x = self.data[index: index + self.seq_len]
        y = self.data[index+1: index + self.seq_len+1]
        return x, y

    def __len__(self):
        return len(self.data) - self.seq_len

def preprocess_batch(batch: torch.Tensor):
    # batch (batch_size, seq_len+1)
    x = batch.reshape(batch.shape[0], 1, batch.shape[1])
    x = x.repeat_interleave(batch.shape[1]-1, dim=1)
    x = x[:, :, :-1].reshape(-1, x.shape[1])
    y = batch[:, 1:].reshape(-1)
    # tril = torch.tril(torch.ones_like(x)).reshape(-1, x.shape[1])
    
    # output: (batch_size, seq_len), (batch_size, ), (batch_size, seq_len)
    return x, y

vocab = VocabDataset("../datasets/shakespeare.txt", seq_len=8)
loader = DataLoader(vocab, batch_size=4, shuffle=False)

for i, (x, y) in enumerate(loader):
    
    # print(batch.shape)
    # x, y = preprocess_batch(batch)
    print(x.shape, y.shape)
    print(vocab.decoder(x[0]), vocab.decoder(y[0]))
    
    if i == 0: break


torch.Size([4, 8]) torch.Size([4, 8])
['F', 'i', 'r', 's', 't', ' ', 'C', 'i'] ['i', 'r', 's', 't', ' ', 'C', 'i', 't']


In [292]:
class DotProductAttention(nn.Module):
    
    def __init__(self, dropout_rate: float=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.attention_weight = None
        
    def forward(self, query: torch.Tensor, keys: torch.Tensor, vals: torch.Tensor):
        # query/keys: (batch_size, seq_len, d)
        # vals: (batch_size, seq_len, key_dim)
        _, _, seq_len, d = query.shape
        
        # mask: (1, 1, seq_len, seq_len)
        mask = torch.tril(torch.ones((seq_len, seq_len))) #.reshape(1, 1, seq_len, seq_len)
        
        presoftmax = query @ keys.transpose(-2, -1) / d**0.5
        presoftmax = presoftmax.masked_fill(mask == 0, float('-inf'))
        self.attention_weight = F.softmax(presoftmax, dim=-1)
        # self.attention_weight = F.softmax(torch.bmm(query, keys.transpose(-2, -1)) / d, dim=-1)
        # out: (batch_size, seq_len, key_dim)
        return self.dropout(self.attention_weight @ vals)

q = torch.randn(4, 1, 8, 2)
k = torch.randn(4, 1, 8, 2)
v = torch.randn(4, 1, 8, 2)

attn = DotProductAttention()
out = attn(q, k, v)
out.shape, attn.attention_weight.shape

(torch.Size([4, 1, 8, 2]), torch.Size([4, 1, 8, 8]))

In [293]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, n_heads: int, n_hidden: int, n_out: int, dropout_rate: float=0.1, bias=False):
        super().__init__()
        self.n_heads = n_heads
        self.n_hidden = n_hidden
        self.W_q = nn.LazyLinear(n_hidden, bias=bias)
        self.W_k = nn.LazyLinear(n_hidden, bias=bias)
        self.W_v = nn.LazyLinear(n_hidden, bias=bias)
        self.W_o = nn.LazyLinear(n_out)
        self.attention = DotProductAttention(dropout_rate)
    
    def transpose_QKV(self, X: torch.Tensor):
        X = X.reshape(*X.shape[:2], self.n_heads, -1)
        X = X.permute(0, 2, 1, 3) # (batch_size, n_heads, seq_len, n_hidden/n_heads)
        return X
    
    def forward(self, X: torch.Tensor):
        
        Q = self.transpose_QKV(self.W_q(X))
        K = self.transpose_QKV(self.W_k(X))
        V = self.transpose_QKV(self.W_v(X))
        # Q, K, V: (batch_size, n_heads, seq_len, n_hidden/n_heads)
        
        out = self.attention(Q, K, V)
        out = out.reshape(out.shape[0], out.shape[2], -1) # (batch_size, seq_len, n_hidden*n_heads)
        
        return self.W_o(out)
        
mha = MultiHeadAttention(n_heads=5, n_hidden=20, n_out=16)
X = torch.randn(16, 8, 2)
mha(X).shape

torch.Size([16, 8, 16])

In [294]:
class TransformerDecoderBlock(nn.Module):
    
    def __init__(self, vocab_size, n_heads=5, n_hidden=20, n_out=16, ffn_n_hidden=16, dropout=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_out)
        self.pos_emb = nn.Embedding(vocab_size, n_out)
        
        self.mha = MultiHeadAttention(n_heads, n_hidden, n_out, dropout)
        self.norm1 = nn.LayerNorm(n_out)
        self.ffn = nn.Sequential(nn.LazyLinear(ffn_n_hidden), nn.ReLU(), nn.LazyLinear(n_out))
        self.norm2 = nn.LayerNorm(n_out)
        
    def forward(self, X: torch.Tensor):
        i = torch.arange(X.shape[1])
        
        X = self.emb(X) + self.pos_emb(i)
        
        X = self.norm1(X + self.mha(X))
        X = self.norm2(X + self.ffn(X))
        return X

tdb = TransformerDecoderBlock(100, n_out=2)
net = nn.Sequential(tdb, nn.LazyLinear(100), nn.Softmax(dim=-1))
X = torch.ones((16, 8)).type(torch.int32)
net(X).shape

torch.Size([16, 8, 100])

In [313]:
dataset = VocabDataset("../datasets/shakespeare.txt", seq_len=8)

train_indices = list(range(0, int(len(dataset) * 0.8)))
val_indices = list(range(0, int(len(dataset) * 0.2)))

train_set = torch.utils.data.Subset(dataset, train_indices)
val_set = torch.utils.data.Subset(dataset, val_indices)

train_loader = DataLoader(train_set, batch_size=128, shuffle=False)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False)

tdb = TransformerDecoderBlock(
    vocab_size=dataset.vocab_size,
    n_heads=5,
    n_hidden=20,
    n_out=2,
    ffn_n_hidden=50
)
net = nn.Sequential(tdb, nn.LazyLinear(dataset.vocab_size), nn.Softmax(dim=-1))

optim = torch.optim.Adam(net.parameters(), lr=0.1)

for _ in range(10):
    
    for i, (x, y) in enumerate(train_loader):
    
        # print(x.shape, y.shape)
        y_pred = net(x)
        y_pred = y_pred.transpose(1, 2)
        
        loss = F.cross_entropy(y_pred, y)
        
        optim.zero_grad()
        loss.backward()
        optim.step()
    
    print(loss.item())


4.137979030609131
3.99735426902771
3.99735426902771
3.99735426902771
3.99735426902771


KeyboardInterrupt: 