In [39]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from typing import Tuple

class Embedding(nn.Module):
    """Sum of a token embedding and a learned positional embedding, followed by dropout."""

    def __init__(self,
                 config,
                 vocab_size):
        """
            Builds the token and position lookup tables.
            Params:
                config(dict): configuration mapping; the keys read here are
                    "d_model" (embedding dimension of every token/position),
                    "context_length" (number of rows of the position table, i.e. the
                    longest sequence that can be embedded) and
                    "dropout" (dropout probability applied to the summed embedding)
                vocab_size(int): number of rows of the token table (# of unique tokens)
        """

        super().__init__()
        self.vocab_size = vocab_size
        self.token_embedding_table = nn.Embedding(num_embeddings=vocab_size,
                                                  embedding_dim=config["d_model"])
        self.position_embedding_table = nn.Embedding(num_embeddings=config["context_length"],
                                                     embedding_dim=config["d_model"])
        self.dropout = nn.Dropout(p=config["dropout"])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
            Params:
                x(torch.Tensor): integer token ids of shape [B, S]; S must not exceed
                    the "context_length" used to size the position table
            Returns:
                torch.Tensor: embeddings of shape [B, S, D]
        """
        # x => [B, S]
        B, S = x.shape
        token_emb = self.token_embedding_table(x) # [B, S, D]

        # Build the position ids on the same device as the input instead of relying
        # on a notebook-level `device` global that may not exist yet.
        positions = torch.arange(S, device=x.device)
        pos_emb = self.position_embedding_table(positions).unsqueeze(0) # [1, S, D]

        # Dropout is applied exactly once; applying it twice would drop more
        # activations than the configured probability and rescale them twice.
        return self.dropout(token_emb + pos_emb)



class AttentionHead(nn.Module):
    """One scaled dot-product attention head with its own query/key/value projections."""

    def __init__(self, config) -> None:
        """
            Params:
                config(dict): the keys read here are "d_model" (input width),
                    "head_dim" (width of this head's projections) and
                    "dropout" (dropout probability applied to the head output)
        """
        super().__init__()

        self.d_model = config["d_model"]
        self.head_dim = config["head_dim"]

        # Three independent linear maps from model width to head width.
        self.query = nn.Linear(self.d_model, self.head_dim)
        self.key = nn.Linear(self.d_model, self.head_dim)
        self.value = nn.Linear(self.d_model, self.head_dim)
        self.dropout = nn.Dropout(p=config["dropout"])

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask=None) -> torch.Tensor:
        """
            Params:
                query(torch.Tensor): [B, Q, D]
                key(torch.Tensor): [B, K, D]
                value(torch.Tensor): [B, K, D]
                mask: optional tensor broadcastable to [B, Q, K]; positions where it
                    equals 0 are excluded from attention
            Returns:
                torch.Tensor: [B, Q, HEAD_DIM]
        """
        q_proj = self.query(query)  # [B, Q, HEAD_DIM]
        k_proj = self.key(key)      # [B, K, HEAD_DIM]
        v_proj = self.value(value)  # [B, K, HEAD_DIM]

        # Raw similarity of every query position with every key position.
        scores = torch.matmul(q_proj, k_proj.transpose(1, 2))  # [B, Q, K]
        if mask is not None:
            # -inf scores become exactly zero probability after the softmax.
            scores = scores.masked_fill(mask==0, value=float("-inf"))

        scale = math.sqrt(self.head_dim)
        attn = F.softmax(scores/scale, dim=-1)

        # [B, Q, K] x [B, K, HEAD_DIM] => [B, Q, HEAD_DIM]
        context = torch.matmul(attn, v_proj)
        return self.dropout(context)


class MultiHeadAttention(nn.Module):
    """Runs several AttentionHead modules in parallel and projects their concatenation."""

    def __init__(self,
                 config) -> None:
        """
            Params:
                config(dict): the keys read here are "n_heads" (number of heads),
                    "head_dim" (output width of each head), "d_model" (output width of
                    the projection) and "dropout"; AttentionHead reads its own keys
                    from the same mapping
        """
        super().__init__()
        n_heads = config["n_heads"]
        self.sa_heads = nn.ModuleList([AttentionHead(config) for _ in range(n_heads)])
        # The concatenated heads have width n_heads * head_dim, so that is the input
        # width of the projection. This equals d_model in the usual configuration and
        # also keeps the layer valid when head_dim is chosen independently.
        self.proj = nn.Linear(n_heads * config["head_dim"], config["d_model"])
        self.dropout = nn.Dropout(p=config["dropout"])

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask=None) -> torch.Tensor:
        """
            Params:
                query, key, value(torch.Tensor): passed unchanged to every head
                mask: optional attention mask passed unchanged to every head
            Returns:
                torch.Tensor: projected attention output whose last dimension is d_model
        """
        # Each head yields [..., HEAD_DIM]; concatenating on the last axis gives
        # [..., N_HEADS * HEAD_DIM].
        out = torch.cat([h(query, key, value, mask) for h in self.sa_heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)


class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4 * d_model, ReLU, project back to d_model, dropout."""

    def __init__(self, config):
        """
            Params:
                config(dict): the keys read here are "d_model" (input/output width)
                    and "dropout" (probability of the trailing dropout layer)
        """
        super().__init__()
        width = config["d_model"]
        hidden = width*4

        layers = [
            nn.Linear(width, hidden),
            nn.ReLU(),
            nn.Linear(hidden, width),
            nn.Dropout(p=config["dropout"]),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``; the shape is unchanged."""
        return self.net(x)


class GPTDecoderBlock(nn.Module):
    """Pre-norm transformer block: self-attention and feed-forward, each with a residual."""

    def __init__(self, config) -> None:
        """
            Params:
                config(dict): forwarded to MultiHeadAttention and FeedForward;
                    "d_model" is read here to size both LayerNorm layers
        """
        super().__init__()
        self.mha = MultiHeadAttention(config)
        self.ff = FeedForward(config)
        self.ln_1 = nn.LayerNorm(normalized_shape=config["d_model"])
        self.ln_2 = nn.LayerNorm(normalized_shape=config["d_model"])

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        """
            Params:
                x(torch.Tensor): block input; its last dimension is normalized by
                    the LayerNorm layers
                mask: optional attention mask passed through to the attention layer
            Returns:
                torch.Tensor: tensor with the same shape as ``x``
        """
        # Normalize once and reuse the result as query, key and value; calling
        # ln_1 three times on the same input repeats identical work.
        normed = self.ln_1(x)
        x = x + self.mha(normed, normed, normed, mask)
        x = x + self.ff(self.ln_2(x))
        return x

class GPTDecoder(nn.Module):
    """A stack of GPTDecoderBlock layers applied one after another."""

    def __init__(self, config) -> None:
        """
            Params:
                config(dict): "n_decoders" sets the number of stacked blocks; the
                    mapping is also forwarded to every GPTDecoderBlock
        """
        super().__init__()
        depth = config["n_decoders"]
        self.blocks = nn.ModuleList([GPTDecoderBlock(config) for _ in range(depth)])

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        """Feed ``x`` through every block in order, passing the same ``mask`` to each."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, mask)
        return hidden

class PoemGPT(nn.Module):
    """Decoder-only language model: embedding -> stacked decoder blocks -> vocabulary logits."""

    def __init__(self, config, vocab_size) -> None:
        """
            Params:
                config(dict): "context_length" and "d_model" are read here; the
                    mapping is also forwarded to Embedding and GPTDecoder
                vocab_size(int): number of distinct tokens; sizes the output layer
        """
        super().__init__()
        self.context_length = config["context_length"]
        self.embedding = Embedding(config, vocab_size)
        self.gpt = GPTDecoder(config)
        self.lm_head = nn.Linear(config["d_model"], vocab_size)

    def forward(self,
                x: torch.Tensor,
                targets: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
            Params:
                x(torch.Tensor): token ids of shape [B, S]
                targets(torch.Tensor): optional next-token ids of shape [B, S]
            Returns:
                (logits, loss): when ``targets`` is None, logits have shape
                [B, S, VOCAB_SIZE] and loss is None; otherwise logits are flattened to
                [B*S, VOCAB_SIZE] and loss is the cross-entropy against the flattened
                targets
        """
        B, S = x.shape
        # x -> [B, S], targets -> [B, S]
        # Lower-triangular causal mask, built on the input's device so the forward
        # pass does not depend on a notebook-level `device` global.
        mask = torch.tril(torch.ones((S, S), device=x.device))

        x = self.embedding(x) # B, S, D_MODEL
        x = self.gpt(x, mask) # B, S, D_MODEL
        logits = self.lm_head(x) # B, S, VOCAB_SIZE

        if targets is None:
            loss = None
        else:
            logits = logits.view(B*S, -1)
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)
        return logits, loss


    @torch.no_grad()
    def generate(self, x:torch.Tensor=None, max_new_tokens: int=500) -> torch.Tensor:
        """
            Autoregressively samples ``max_new_tokens`` tokens.
            Params:
                x(torch.Tensor): optional prompt of token ids, shape [B, S]; when
                    omitted, a single sequence holding token id 0 is used
                max_new_tokens(int): number of tokens to append
            Returns:
                torch.Tensor: the prompt followed by the sampled tokens,
                shape [B, S + max_new_tokens]
        """
        if x is None:
            # Start on whatever device the model's parameters live on.
            x = torch.zeros((1, 1), dtype=torch.long, device=self.lm_head.weight.device) # B, S

        # Sample with dropout disabled, then put the module back in the mode the
        # caller left it in (gradients are already off via the decorator).
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent context_length tokens fit the position table.
                preds, _ = self(x[:, -self.context_length:]) # B, S, VOCAB_SIZE
                preds = preds[:, -1, :] # B, VOCAB_SIZE
                probs = F.softmax(preds, dim=-1)
                x_next = torch.multinomial(input=probs, num_samples=1) # B, 1
                x = torch.cat((x, x_next), dim=1) # B, S+1
        finally:
            self.train(was_training)

        return x


def create_causal_mask(sz, device=None):
    """
        Builds a lower-triangular matrix of ones used as a causal attention mask.
        Params:
            sz(int): side length of the square mask
            device: device to allocate the mask on; when omitted, the module-level
                `device` variable is used if it exists, otherwise "cpu"
        Returns:
            torch.Tensor: float tensor of shape [sz, sz] with ones on and below the
            diagonal and zeros above it
    """
    if device is None:
        # Existing callers pass only `sz` and expect the notebook's global device.
        device = globals().get("device", "cpu")
    return torch.tril(torch.ones((sz, sz), device=device))

In [40]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import random
from typing import Tuple, Dict
import json


# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

with open("./poem_gpt_config.json", "r") as f:
    config = json.load(f)

# Seed PyTorch's RNG so weight initialization, batch sampling and generation are
# repeatable across runs. An optional "seed" entry in the config overrides the
# arbitrary default.
torch.manual_seed(config.get("seed", 1337))

with open("./input.txt", "r", encoding="utf-8") as f:
    data = f.read()


# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(data))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}


def encode(s):
    """Map a string to the list of integer ids of its characters."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Map an iterable of integer ids back to the string they spell."""
    return "".join([itos[i] for i in l])


# The whole corpus as one 1-D tensor of token ids (integer dtype made explicit).
data = torch.tensor(encode(data), dtype=torch.long)

# First 90% of the tokens for training, the remaining 10% for validation.
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]


def get_random_batch(split: str="train") -> Tuple[torch.Tensor, torch.Tensor]:
    """
        Samples a random batch of input/target windows from one data split.
        Params:
            split(str): "train" selects train_data; any other value selects val_data
        Returns:
            (x_batch, y_batch): two tensors of shape [batch_size, block_size] on
            `device`; y_batch holds the same windows as x_batch shifted one token
            to the right
    """
    source = val_data if split != "train" else train_data

    n_windows = config["batch_size"]
    window = config["block_size"]

    # Random start offsets; the upper bound leaves room for the shifted target window.
    starts = torch.randint(0, len(source)-window, size=(n_windows,))

    inputs, labels = [], []
    for start in starts:
        inputs.append(source[start:start+window])
        labels.append(source[start+1:start+window+1])

    x_batch = torch.stack(inputs).to(device)
    y_batch = torch.stack(labels).to(device)
    return x_batch, y_batch


@torch.no_grad()
def eval_model() -> Dict[str, float]:
    """
        Estimates the mean loss of `poem_gpt` on the train and validation splits.

        For each split, config["eval_iters"] random batches are drawn and their
        losses averaged. The model is evaluated in eval mode and is returned to
        whichever mode it was in when this function was called.
        Returns:
            Dict[str, float]: {"train": mean_loss, "val": mean_loss}
    """
    n_iters = config["eval_iters"]
    losses = {}

    was_training = poem_gpt.training
    poem_gpt.eval()
    try:
        for split in ["train", "val"]:
            total = 0.0
            for _ in range(n_iters):
                x_batch, y_batch = get_random_batch(split)
                _, batch_loss = poem_gpt(x_batch, y_batch)
                total += batch_loss.item()

            losses[split] = total/n_iters
    finally:
        # Restore the caller's mode even if a batch raises.
        poem_gpt.train(was_training)

    return losses


def train_poem_gpt():
    """
        Trains `poem_gpt` with `optimizer` for config["train_iters"] steps.

        Every config["eval_interval"] steps (including step 0) the averaged train
        and validation losses from eval_model() are printed. Each step draws a
        random training batch and performs one optimizer update.
    """
    eval_interval = config["eval_interval"]

    # Make sure dropout is active even if the model was left in eval mode earlier.
    poem_gpt.train()

    # `step` rather than `iter`, so the builtin of that name is not shadowed.
    for step in range(config["train_iters"]):

        if step%eval_interval==0:
            losses = eval_model()
            print(f"iter {step} train_loss: {losses['train']} val_loss: {losses['val']}")

        x_batch, y_batch = get_random_batch()
        _, loss = poem_gpt(x_batch, y_batch)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()


# Build the model directly on the target device, then hand its parameters to AdamW.
poem_gpt = PoemGPT(config, vocab_size).to(device)
optimizer = torch.optim.AdamW(poem_gpt.parameters(), lr=config["learning_rate"])

# Run the full training loop; progress is printed at every evaluation interval.
train_poem_gpt()

iter 0 train_loss: 4.687802615165711 val_loss: 4.681776196956634
iter 500 train_loss: 2.1590825736522676 val_loss: 2.215250650644302
iter 1000 train_loss: 1.8130074471235276 val_loss: 1.9667431920766831
iter 1500 train_loss: 1.6225154322385789 val_loss: 1.8205379575490952
iter 2000 train_loss: 1.4965563523769378 val_loss: 1.7162453413009644
iter 2500 train_loss: 1.4177560430765153 val_loss: 1.668158510327339
iter 3000 train_loss: 1.3559211957454682 val_loss: 1.6206841945648194
iter 3500 train_loss: 1.335439379811287 val_loss: 1.6167381221055985
iter 4000 train_loss: 1.2825062787532806 val_loss: 1.5865983229875564
iter 4500 train_loss: 1.2475514125823974 val_loss: 1.5702516168355942


In [42]:
# Sample 500 new tokens from the default start token, drop that start token
# (index 0) and print the decoded text of the first (only) sequence.
generated = poem_gpt.generate(max_new_tokens=500)
print(decode(generated[0, 1:].tolist()))


PRINCE PEY:
Very well! I will be naked to take when.

MARIANA:
Now sword? and who have friend out of the place
Artend I of your vental, and am not pratise
He disture friends in Playets to the comprove.

LUCIO:
By you women, I do put on cominition.
And whils woratil orp we moforey any
Moris: y maingive, os wacking, waras, t, ouns
Olyonesig? fagonad! cin, t, s,
I; fo, foce lelinsts!-t tate nope's; war!
Trimply titaved ps, ge
Fingeivedy, bequbupe, po.
Trhokeringe at bous ot: bys; wined iooualy; on


In [43]:
# Pass the path rather than an open file object: torch.save then opens and closes
# the file itself, so no file handle is left open.
torch.save(obj=poem_gpt.state_dict(),
           f="./weights/poem_gpt_weights.pt")