In [56]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from tqdm import tqdm
from google.colab import files

In [3]:
# Load the raw training corpus; every list entry is one line of the file with
# its line terminator kept.
with open("ptb.train.txt", mode="r") as corpus_file:
    lines = corpus_file.readlines()

In [4]:
def get_tokens():
  """Split every entry of the module-level ``lines`` list into characters.

  Returns:
    A list with one inner list per line; each inner list holds that line's
    individual characters in order.
  """
  return list(map(list, lines))

# Nested character lists for the whole corpus: one inner list per line.
token = get_tokens()

In [5]:
def flatten(tokens):
  """Concatenate a sequence of sequences into a single flat list.

  Args:
    tokens: iterable whose elements are themselves iterable.

  Returns:
    A new list with the inner elements in their original order.
  """
  flat = []
  for group in tokens:
    flat.extend(group)
  return flat

# Collapse the per-line character lists into one flat character stream and
# report how many characters the corpus contains.
tokens = flatten(token)
print(len(tokens))

5101619


In [6]:
def unique_char(tokens):
  """Return the distinct items of ``tokens`` in order of first appearance.

  Args:
    tokens: iterable of hashable items (here: single characters).

  Returns:
    A list containing each distinct item exactly once, ordered by where it
    first occurs in ``tokens``.
  """
  # dict keys are unique and preserve insertion order, so this de-duplicates
  # in a single pass instead of re-scanning a growing list for every token.
  return list(dict.fromkeys(tokens))


# Build the character vocabulary (a character's list position is used as its
# integer id further down) and show how many distinct characters there are.
uniq_tokens = unique_char(tokens)
print(len(uniq_tokens))

50


In [7]:
# Character -> integer id lookup; the id is the character's index in uniq_tokens.
vocab = {symbol: index for index, symbol in enumerate(uniq_tokens)}

In [8]:
# Encode the whole character stream as a flat list of integer ids.
numerical = [vocab[symbol] for symbol in tokens]

In [9]:
seq_length = 64
# One id is held back (the "- 1") so that a stream shifted by one position
# still has enough elements to fill the same number of sequences.
num_samples = (len(numerical) - 1) // seq_length
usable_tokens = num_samples * seq_length
# Cut the id stream into consecutive, non-overlapping rows of seq_length ids.
dataset = torch.tensor(numerical[:usable_tokens]).reshape(num_samples, seq_length)
dataset.shape

torch.Size([79712, 64])

In [10]:
batch_size = 32
num_batches = dataset.shape[0] // batch_size
# Drop the rows that do not fill a complete batch, then group the remaining
# rows into (num_batches, batch_size, seq_length).
full_batch_rows = dataset[:num_batches * batch_size]
train_iter = full_batch_rows.reshape(num_batches, batch_size, seq_length)
train_iter.shape

torch.Size([2491, 32, 64])

In [11]:
# Targets are the id stream shifted one position to the right, laid out exactly
# like train_iter. The slice length is derived from the batched shape
# (num_batches * batch_size rows) rather than from num_samples, so the reshape
# also succeeds when num_samples is not a multiple of batch_size.
num_label_tokens = num_batches * batch_size * seq_length
labels = torch.tensor(numerical[1:num_label_tokens + 1]).reshape(num_batches, batch_size, seq_length)
labels.shape

torch.Size([2491, 32, 64])

In [12]:
def textify(embedding):
    """Decode a sequence of vocabulary ids back into text.

    Args:
        embedding: iterable of values convertible with ``int()`` (for example a
            1-D tensor of ids); each one is an index into the module-level
            ``uniq_tokens`` list.

    Returns:
        The string formed by concatenating the looked-up characters in order.
    """
    # join() builds the string once instead of re-allocating it per character.
    return "".join(uniq_tokens[int(idx)] for idx in embedding)

In [13]:
# Sanity check: decode one input row and the matching label row; the label text
# should be the input text shifted left by one character.
print(textify(train_iter[10, 3]))
print(textify(labels[10, 3]))

ter business appears to depend heavily on the creativity and <un
er business appears to depend heavily on the creativity and <unk


In [14]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalisation.

    Dropout is applied to the sub-layer output ``y``, the result is added to
    the residual input ``x``, and the sum is layer-normalised over its last
    ``d_model`` features.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x, y):
        """Return ``LayerNorm(x + Dropout(y))``."""
        residual_sum = x + self.dropout(y)
        return self.ln(residual_sum)

In [15]:
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Linear.

    Args:
        ffn_hiddens: width of the hidden layer.
        d_model: size of the input and output feature dimension.
    """

    def __init__(self, ffn_hiddens, d_model):
        super().__init__()
        self.lin1 = nn.Linear(in_features=d_model, out_features=ffn_hiddens)
        self.act = nn.ReLU()
        self.lin2 = nn.Linear(in_features=ffn_hiddens, out_features=d_model)

    def forward(self, x):
        """Expand to ``ffn_hiddens``, apply ReLU, project back to ``d_model``."""
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)

In [63]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: feature size of the inputs and of the output.
        num_heads: number of attention heads; must divide ``d_model``.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        if d_model % num_heads != 0:
            # Otherwise the per-head view in forward() cannot be formed.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.d_model = d_model
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.query = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=True)
        self.output = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = 1 / math.sqrt(self.d_k)

    def _split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        batch_size, seq_len = x.shape[0], x.shape[1]
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).permute(0, 2, 1, 3)

    def forward(self, k, q, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Note the argument order: key first, then query, then value.

        Args:
            k: keys, (batch, kv_len, d_model).
            q: queries, (batch, q_len, d_model).
            v: values, (batch, kv_len, d_model).
            mask: optional 4-D tensor broadcastable to
                (batch, num_heads, q_len, kv_len) after slicing; positions
                where it equals 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size, q_len = q.shape[0], q.shape[1]
        # Keys/values may have a different length than the queries, so each
        # projection is reshaped using its own sequence length.
        kv_len = k.shape[1]

        Q = self._split_heads(self.query(q))
        K = self._split_heads(self.key(k))
        V = self._split_heads(self.value(v))

        scores = (Q @ K.permute(0, 1, 3, 2)) * self.scale
        if mask is not None:
            # The stored mask may be larger than the current sequence; crop it.
            scores = scores.masked_fill(mask[:, :, :q_len, :kv_len] == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        x = self.dropout(attn) @ V
        # Merge the heads back: (batch, heads, q_len, d_k) -> (batch, q_len, d_model).
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.d_model)
        return self.output(x)

In [31]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention and a feed-forward network, each
    wrapped in a residual Add & Norm step.

    Args:
        d_model: feature size of the block's input and output.
        num_heads: number of attention heads.
        ffn_hiddens: hidden width of the feed-forward network.
        dropout: dropout probability passed to the attention and AddNorm layers.
    """

    def __init__(self, d_model, num_heads, ffn_hiddens, dropout):
        super().__init__()
        # Pipeline: MultiHeadAttention -> AddNorm -> FeedForward -> AddNorm
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.addnorm1 = AddNorm(d_model, dropout)
        self.ffn = FeedForward(ffn_hiddens, d_model)
        self.addnorm2 = AddNorm(d_model, dropout)

    def forward(self, x, mask):
        """Run self-attention over ``x`` (with ``mask``), then the FFN."""
        attended = self.attention(x, x, x, mask=mask)
        x = self.addnorm1(x, attended)
        transformed = self.ffn(x)
        return self.addnorm2(x, transformed)

In [69]:
class GPT(nn.Module):
    """Decoder-only Transformer language model over integer token ids.

    Args:
        config: object exposing ``seq_len``, ``d_model``, ``vocab_size``,
            ``num_heads``, ``ffn_hiddens``, ``dropout_prob`` and ``num_blocks``.
    """

    def __init__(self, config):
        super().__init__()
        self.seq_len = config.seq_len
        self.d_model = config.d_model
        self.embedding = nn.Embedding(config.vocab_size, self.d_model)
        # Learned positional embedding, one vector per position up to seq_len.
        self.pos_encoding = nn.Embedding(config.seq_len, config.d_model)
        self.dec_blocks = nn.Sequential(*[DecoderBlock(self.d_model, config.num_heads, config.ffn_hiddens,
                                                       config.dropout_prob) for _ in range(config.num_blocks)])
        # The head must emit one logit per vocabulary entry; projecting to
        # d_model would make the softmax range over classes that do not exist.
        self.lin_head = nn.Linear(self.d_model, config.vocab_size)
        # Registered as a non-persistent buffer so the causal mask follows the
        # module across devices without being written to the state dict.
        self.register_buffer("mask", self.subsequent_mask(self.seq_len), persistent=False)
        self.apply(self.init_weights)

    def init_weights(self, module):
        """Initialise one sub-module; intended to be used via ``self.apply``."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def subsequent_mask(self, seq_len):
        """Return a (1, 1, seq_len, seq_len) boolean lower-triangular mask.

        Entry [.., i, j] is True when position i may attend to position j,
        i.e. j <= i, which hides future time steps.
        """
        mask = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).view(1, 1, seq_len, seq_len)
        return mask

    def configure_optimizers(self, train_config):
        """Build an AdamW optimizer with weight decay on Linear weights only.

        Biases and the weights of LayerNorm / Embedding modules are placed in a
        group with ``weight_decay=0``.

        Args:
            train_config: object exposing ``weight_decay``, ``learning_rate``
                and ``betas``.

        Returns:
            A ``torch.optim.AdamW`` instance with two parameter groups.

        Raises:
            ValueError: if a parameter lands in both groups or in neither.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            # recurse=False visits each parameter exactly once, at the module
            # that owns it, so the isinstance checks see the right module type.
            for pn, _ in m.named_parameters(recurse=False):
                fpn = '%s.%s' % (mn, pn) if mn else pn
                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        param_dict = dict(self.named_parameters())
        # Every parameter must be in exactly one group, otherwise it would be
        # optimised twice or silently left out of training.
        inter_params = decay & no_decay
        if inter_params:
            raise ValueError(
                "parameters %s made it into both decay/no_decay sets" % sorted(inter_params)
            )
        missing_params = param_dict.keys() - (decay | no_decay)
        if missing_params:
            raise ValueError(
                "parameters %s were not separated into either decay/no_decay set" % sorted(missing_params)
            )

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, x, target=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: integer token ids, (batch, seq) with seq <= ``seq_len``.
            target: optional integer ids with the same number of elements as
                ``x``; entries equal to -1 are ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` is (batch, seq, vocab_size) and
            ``loss`` is a scalar tensor, or ``None`` when ``target`` is None.
        """
        # Create the position ids on the input's device to avoid a mismatch.
        pos = torch.arange(0, x.shape[1], dtype=torch.long, device=x.device).unsqueeze(0)
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(pos) + x
        for blk in self.dec_blocks:
            x = blk(x, mask=self.mask)
        x = self.lin_head(x)
        loss = None
        if target is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), target.reshape(-1), ignore_index=-1)
        return x, loss

In [19]:
class GPTConfig:
    """Hyperparameters for the model and for training, as class-level defaults."""

    # --- model architecture ---
    vocab_size: int = 50        # number of distinct token ids
    seq_len: int = 64           # maximum sequence length
    d_model: int = 512          # embedding / hidden feature size
    num_heads: int = 8          # attention heads per block
    num_blocks: int = 6         # number of stacked decoder blocks
    ffn_hiddens: int = 48       # hidden width of each feed-forward network
    dropout_prob: float = 0.1   # dropout probability

    # --- optimisation ---
    num_epochs: int = 10
    learning_rate: float = 3e-4
    betas: tuple = (0.9, 0.95)
    weight_decay: float = 0.1
    grad_norm_clip: float = 1.0  # max gradient norm used for clipping

In [None]:
config = GPTConfig()
net = GPT(config)
optimizer = net.configure_optimizers(config)

net.train()
for epoch in range(config.num_epochs):
    epoch_loss = 0.0
    # zip() has no length, so pass the batch count explicitly; without it tqdm
    # can only show a counter instead of a progress bar with an ETA.
    progress = tqdm(zip(train_iter, labels), total=len(train_iter), desc=f"epoch {epoch}")
    for x, y in progress:
        optimizer.zero_grad()
        y_hat, loss = net(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), config.grad_norm_clip)
        optimizer.step()
        # .item() yields a plain float; accumulating the tensor itself would
        # keep autograd state from every step alive for the whole epoch.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        # Show the running loss on the progress bar instead of printing a line
        # per step, which floods the cell output.
        progress.set_postfix(loss=batch_loss)

    # Report the mean per-batch loss so epochs are directly comparable.
    print(f"epoch {epoch}: mean loss {epoch_loss / max(len(train_iter), 1):.4f}")