# Medium Article Generator
**Author:** Matheus Oliveira de Souza - *Artificial Intelligence Developer*

**References:**
- [Attention Is All You Need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
- [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)

**Dataset:**
- [190k+ Medium Articles](https://www.kaggle.com/datasets/fabiochiusano/medium-articles)

---

# [1] Notebook setup

## [1.1] Import

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# OpenAI Tokenizer
import tiktoken

import os
import re
import csv
import json
import nltk
import numpy as np
from collections import Counter
from nltk.corpus import stopwords
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

import matplotlib.pyplot as plt
from tqdm import tqdm

import gc
from typing import Union, Generator

## [1.2] Root path

In [None]:
# Absolute path of the current working directory: ``dirname`` of the current-directory
# marker is the empty string, which ``abspath`` resolves against the working directory.
ROOT_PATH = os.path.abspath(
    path=os.path.dirname(p=os.curdir),
)

# [2] Utils

## [2.1] Tokenizer

### [2.1.1] TikToken

In [None]:
class TikTokenizer:
    """``cl100k_base`` BPE tokenizer extended with SOS / EOS / UNK special tokens.

    Everything is defined at class level, so the class is used directly through its
    classmethods and never needs to be instantiated.
    """

    # Base encoding whose split pattern, merge ranks and special tokens are reused below.
    cl100k_base = tiktoken.get_encoding(encoding_name="cl100k_base")
    SOS = "<|sos|>"
    EOS = "<|eos|>"
    UNK = "<|unk|>"
    # NOTE(review): ids 100264-100266 are presumably unused by ``cl100k_base`` - verify
    # against the base special-token table if the base encoding is ever changed.
    enc = tiktoken.Encoding(
        name="cl100k_im",
        pat_str=cl100k_base._pat_str,
        mergeable_ranks=cl100k_base._mergeable_ranks,
        special_tokens={
            **cl100k_base._special_tokens,
            SOS: 100264,
            EOS: 100265,
            UNK: 100266,
        }
    )

    @classmethod
    def vocab_size(cls) -> int:
        """Return the size of vocabulary."""
        # Token ids are zero-based, so the number of ids is the largest id plus one.
        return cls.enc.max_token_value + 1

    @classmethod
    def encode(cls, text: str) -> list[int]:
        """Convert a sequence of characters into a sequence of numbers/indices.

        The SOS, EOS and UNK markers found in ``text`` are encoded as single special tokens.
        """
        return cls.enc.encode(text=text, allowed_special={cls.SOS, cls.EOS, cls.UNK})

    @classmethod
    def decode(cls, tokens: list[int]) -> str:
        """Convert a sequence of numbers/indices into a sequence of characters."""
        return cls.enc.decode(tokens=tokens)

### [2.1.2] Custom Tokenizer

In [None]:
class Tokenizer:
    """Compact tokenizer that remaps tiktoken ids onto a smaller, dataset-specific id range."""

    def __init__(self, vocab: dict[int, str], lookup_vocab: dict[int, int]) -> None:
        """Build the tokenizer from its two lookup tables.

        Args:
            vocab: custom id -> token text.
            lookup_vocab: tiktoken id -> custom id.

        Keys and values are cast with ``int`` so tables reloaded from JSON
        (where every key is a string) work as well as freshly built ones.
        """
        self.vocab = {int(k): v for k, v in vocab.items()}
        self.lookup = {int(k): int(v) for k, v in lookup_vocab.items()}
        # Custom id of the UNK marker; ``None`` when the marker is missing from ``lookup_vocab``.
        self.unk = self.lookup.get(TikTokenizer.encode(text=TikTokenizer.UNK)[0])

    def __len__(self) -> int:
        """Return the size of vocabulary."""
        return len(self.vocab)

    def encode(self, text: str) -> list[int]:
        """Convert a sequence of characters into a sequence of numbers/indices.

        tiktoken ids that are absent from the lookup table are replaced by the UNK id.
        """
        tik_tokens = TikTokenizer.encode(text=text)
        custom_tokens = [self.lookup.get(tk, self.unk) for tk in tik_tokens]
        return custom_tokens

    def decode(self, tokens: list[int], apply_join: bool = True) -> Union[str, list[str]]:
        """Convert a sequence of numbers/indices into a sequence of characters.

        Returns a single concatenated string when ``apply_join`` is true, otherwise
        the list of per-token strings (ids missing from the vocabulary map to ``None``).
        """
        string = [self.vocab.get(tk) for tk in tokens]
        if apply_join:
            return "".join(string)
        return string

## [2.2] Functions

In [None]:
def save_losses(losses: list[float], dir_path: str, filename: str) -> None:
    """Write the loss history to ``dir_path/filename`` as JSON under the ``"losses"`` key."""
    target = os.path.join(dir_path, filename)
    payload = {"losses": losses}
    with open(file=target, mode="w") as handle:
        json.dump(obj=payload, fp=handle)

def save_accuracies(accuracies: list[float], dir_path: str, filename: str) -> None:
    """Write the accuracy history to ``dir_path/filename`` as JSON under the ``"accuracies"`` key."""
    target = os.path.join(dir_path, filename)
    payload = {"accuracies": accuracies}
    with open(file=target, mode="w") as handle:
        json.dump(obj=payload, fp=handle)

def save_train_stats(
        best_loss: float,
        curr_iter: int,
        last_save: int,
        overfitting: int,
        last_train_loss: float,
        last_valid_loss: float,
        dir_path: str,
        filename: str
    ) -> None:
    """Write the training bookkeeping values to ``dir_path/filename`` as one JSON object.

    Each argument (except the two path pieces) is stored under a key equal to its own name.
    """
    stats = {
        "best_loss": best_loss,
        "curr_iter": curr_iter,
        "last_save": last_save,
        "overfitting": overfitting,
        "last_train_loss": last_train_loss,
        "last_valid_loss": last_valid_loss,
    }
    target = os.path.join(dir_path, filename)
    with open(file=target, mode="w") as handle:
        json.dump(obj=stats, fp=handle)

def save_vocab(vocab: dict, dir_path: str, filename: str) -> None:
    """Serialize ``vocab`` as JSON into ``dir_path/filename`` (non-string keys become strings)."""
    target = os.path.join(dir_path, filename)
    with open(file=target, mode="w") as handle:
        json.dump(obj=vocab, fp=handle)

def load_json(file_path: str) -> dict:
    """Parse the JSON document stored at ``file_path`` and return the decoded object."""
    with open(file=file_path, mode="r") as handle:
        content = json.load(fp=handle)
    return content

# [3] Dataset

## [3.1] Load dataset

In [None]:
# The encoding is pinned instead of relying on the platform default (assumes the Kaggle
# export is UTF-8 - confirm), and ``newline=""`` lets the csv module itself interpret the
# line breaks embedded inside quoted article bodies.
with open(file=os.path.join(ROOT_PATH, "medium_articles.csv"), mode="r", encoding="utf-8", newline="") as file_buffer:
  reader = csv.reader(file_buffer)
  next(reader)  # discard the first row (presumably the column header)
  dataset = list(reader)

In [None]:
# Number of rows currently held in ``dataset``.
print(f"{len(dataset)=}")

In [None]:
# Keep only the first, second and last column of every row (unpacked later as
# title, text and an ignored third value), then release the raw rows.
title_article_dataset = [
    [columns[0], columns[1], columns[-1]]
    for columns in dataset
]
del dataset
gc.collect()

In [None]:
# One big string: every article is wrapped as SOS + title + "\n\n\t" + text + EOS.
full_dataset = "".join(
    f"{TikTokenizer.SOS}{headline}\n\n\t{body}{TikTokenizer.EOS}"
    for headline, body, _ in title_article_dataset
)

## [3.2] Clear dataset

### [3.2.1] Unique characters

In [None]:
# Sorted string made of every distinct character that occurs in the dataset.
dataset_characters = "".join(sorted(set(full_dataset)))
print(f"{dataset_characters=}")
print(f"{len(dataset_characters)=}")

#### [3.2.1.1] Removing non-ASCII characters

In [None]:
# Matches any single character outside the 7-bit ASCII range (code points above 0x7F).
non_utf8_characters_filter = re.compile(r"[^\x00-\x7F]")

In [None]:
# Every non-ASCII occurrence in the dataset, plus the distinct characters among them.
non_utf8_characters_list = non_utf8_characters_filter.findall(full_dataset)
non_utf8_characters = "".join(set(non_utf8_characters_list))
print(f"{len(non_utf8_characters_list)=}")
print(f"{non_utf8_characters=}")

In [None]:
# Drop every non-ASCII character, then free the intermediate objects.
clean_full_dataset = non_utf8_characters_filter.sub("", full_dataset)
del full_dataset
del non_utf8_characters
del non_utf8_characters_list
gc.collect()

In [None]:
# Sorted string of the distinct characters left after the non-ASCII removal.
clean_dataset_characters = "".join(sorted(set(clean_full_dataset)))
print(f"{clean_dataset_characters=}")
print(f"{len(clean_dataset_characters)=}")

### [3.2.2] Checking topics

In [None]:
# Splits the corpus on the literal EOS marker. The marker contains "|", which is the
# alternation operator in a regular expression, so it has to be escaped: unescaped, the
# pattern would match "<", "eos" or ">" anywhere in the text.
title_filter = re.compile(pattern=re.escape(TikTokenizer.EOS))

In [None]:
# One entry per chunk found between the separators matched by ``title_filter``.
title_list = title_filter.split(clean_full_dataset)

In [None]:
# English stopwords from NLTK, extended with punctuation marks, digits and the empty string.
nltk.download(info_or_id="stopwords")
stop_words = [
    *set(stopwords.words("english")),
    ":", "?", "(", ")", "-", ".", "&", ",", "—", "’", "…",
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
    "",
]

#### [3.2.2.1] Removing stopwords from titles

In [None]:
# Membership tests against a list are linear in its length; build the set once, outside the loop.
stop_word_set = set(stop_words)

clean_title_list = []
for content in tqdm(iterable=title_list):
    # A well-formed chunk is exactly "title", "\n\n\t", "article"; anything else is skipped.
    title = content.replace(TikTokenizer.SOS, "").split(sep="\n\n\t")
    if len(title) != 2:
        continue
    title = title[0]
    tokens = TikTokenizer.encode(text=title)
    # Decode token by token so each BPE piece can be filtered on its own.
    str_tokens = [TikTokenizer.decode(tokens=[tk]).strip() for tk in tokens]
    # Keep only pieces longer than three characters that are not stopwords.
    clean_title_list.append(" ".join(stk for stk in str_tokens if len(stk) > 3 and stk not in stop_word_set))

In [None]:
# Word frequencies over all cleaned titles.
count_title_words = Counter(" ".join(clean_title_list).split(sep=" "))
# Titles that were emptied by the filtering leave "" entries behind; the default keeps
# this from raising ``KeyError`` when no such entry exists.
count_title_words.pop("", None)
top_count_title_words = count_title_words.most_common(n=50)

In [None]:
# Bar chart of the most frequent title words, one bar per word.
plt.figure(figsize=(15, 5))
for label, occurrences in top_count_title_words:
    plt.bar(x=label, height=occurrences)
# Rotating the tick labels once all bars exist gives the same final figure
# as rotating them after every bar.
plt.xticks(rotation=80)
plt.show()

### [3.2.3] Removing links

In [None]:
# Matches http(s) URLs. The match stops at whitespace AND at "<": the SOS/EOS markers are
# glued to the article text without a space, so a plain ``\S+`` would run through
# "<|eos|><|sos|>" whenever an article ends with a link and merge two articles into one.
link_pattern = re.compile(pattern=r"https?://[^\s<]+")

In [None]:
# Strip every match of ``link_pattern`` from the corpus.
clean_full_dataset = link_pattern.sub("", clean_full_dataset)

In [None]:
# Size of the corpus, in characters, after the link removal.
print(f"{len(clean_full_dataset)=} characters")

### [3.2.4] Removing bad articles

In [None]:
# Cut the corpus back into chunks with ``title_filter`` and report how many were found.
splitted_dataset = title_filter.split(clean_full_dataset)
print(f"{len(splitted_dataset)=}")

#### [3.2.4.1] Removing empty and too-short articles

In [None]:
# Matches a run of two or more consecutive newline characters.
line_filter = re.compile(r"(\n){2,}")

In [None]:
# Compiled once here instead of being re-resolved on every iteration.
trailing_space_filter = re.compile(pattern=r" {1,}\n")

valid_text = []
for content in tqdm(iterable=splitted_dataset):
    # Split on the FIRST title/body separator only, so an article whose body happens to
    # contain "\n\n\t" is kept instead of being silently discarded.
    title, separator, article = content.replace(TikTokenizer.SOS, "").partition("\n\n\t")
    if not separator:
        continue
    # Remove trailing spaces first: lines holding only spaces then become empty lines
    # and are collapsed by ``line_filter`` as well.
    article = trailing_space_filter.sub(repl="\n", string=article)
    article = line_filter.sub(repl="\n", string=article)
    article = article.strip() + "\n"
    if len(article) > 1000: # character limit of article
        title = title.strip()
        valid_text.append([title, article])
print(f"{len(valid_text)=}")

### [3.2.5] Checking article sizes

In [None]:
# Character length of every retained article, measured once and reused for both extremes.
article_lengths = [len(body) for _, body in valid_text]
print(f"Longest : {max(article_lengths)}")
print(f"Shortest : {min(article_lengths)}")

### [3.2.6] Final dataset

In [None]:
# Rebuild the corpus from the retained [title, article] pairs.
# NOTE(review): the separator here is "\t\n\n" while the raw corpus was assembled with
# "\n\n\t" - confirm which of the two is intended before retraining.
clean_full_dataset = "".join(
    TikTokenizer.SOS + "\t\n\n".join(pair) + TikTokenizer.EOS
    for pair in valid_text
).strip()
print(f"{len(clean_full_dataset)=} characters")

## [3.3] Free memory

In [None]:
# Drop the intermediate objects of the cleaning stage and reclaim their memory.
del valid_text
del count_title_words, top_count_title_words
del clean_title_list, title_list
gc.collect()

# [4] Model

## [4.1] Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    The table is laid out as ``[all sine columns | all cosine columns]`` (concatenated,
    not interleaved) and stored in the ``pe`` buffer with shape ``(1, context, emb_dim)``.
    """

    def __init__(self, context: int, emb_dim: int) -> None:
        super().__init__()

        # Column vector of positions 0 .. context - 1.
        positions = torch.arange(end=context).float().reshape(shape=[context, 1])

        # Both halves use the even dimension indices 0, 2, 4, ... as exponent numerators.
        sin_index = torch.arange(start=0, end=emb_dim, step=2).float()
        cos_index = torch.arange(start=1, end=emb_dim, step=2).float() - 1

        sin_half = torch.sin(positions / torch.pow(10_000, exponent=sin_index / emb_dim))
        cos_half = torch.cos(positions / torch.pow(10_000, exponent=cos_index / emb_dim))

        table = torch.cat(tensors=[sin_half, cos_half], dim=1)
        self.register_buffer(name="pe", tensor=table.expand(1, -1, -1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encoding of its first ``T`` positions; ``x`` must be 3-D."""
        _, seq_len, _ = x.size()
        return x + self.pe[:, :seq_len, :]

## [4.2] Feed Forward

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: ``emb_dim -> ff_dim -> emb_dim`` with ReLU and dropout in between."""

    def __init__(self, emb_dim: int, ff_dim: int, dropout_rate: float = 0.2) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(in_features=emb_dim, out_features=ff_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear_2 = nn.Linear(in_features=ff_dim, out_features=emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, activate, apply dropout, then project back to ``emb_dim``."""
        hidden = self.dropout(self.relu(self.linear_1(x)))
        return self.linear_2(hidden)

## [4.3] Attention

In [None]:
class Attention(nn.Module):
    """Single causal self-attention head.

    Args:
        emb_dim: size of the incoming embeddings.
        head_dim: size of the query / key / value projections.
        context: longest sequence the causal mask is built for.
        dropout_rate: dropout applied to the attention weights.
    """

    def __init__(self, emb_dim: int, head_dim: int, context: int, dropout_rate: float) -> None:
        super().__init__()

        self.query = nn.Linear(in_features=emb_dim, out_features=head_dim, bias=False)
        self.key = nn.Linear(in_features=emb_dim, out_features=head_dim, bias=False)
        self.value = nn.Linear(in_features=emb_dim, out_features=head_dim, bias=False)
        self.dropout = nn.Dropout(p=dropout_rate)

        # Lower-triangular matrix: position t may only look at positions <= t.
        ones = torch.ones(size=[context, context], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attended values, shape ``(B, T, head_dim)``, for ``x`` of shape ``(B, T, emb_dim)``."""

        B, T, C = x.size()
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Scaled dot-product attention divides by sqrt(d_k), the size of the key
        # projection (head_dim) - not by the width of the input embedding.
        # NOTE: weights trained with the former emb_dim-based scale will behave
        # differently with this scale.
        head_dim = K.size(dim=-1)
        QK = Q @ K.transpose(-2, -1) * head_dim**-0.5
        # Future positions get -inf so the softmax assigns them zero weight.
        attention = QK.masked_fill(self.mask[:T,:T] == 0, float("-inf"))
        attention = F.softmax(input=attention, dim=-1)

        attention = self.dropout(attention)

        out = attention @ V

        return out


class MultiAttention(nn.Module):
    """Multi-head causal self-attention built from ``emb_dim // head_dim`` independent heads.

    Raises:
        ValueError: if ``head_dim`` is not positive or ``emb_dim`` is not a multiple of it.
    """

    def __init__(self, emb_dim: int, head_dim: int, context: int, dropout_rate: float) -> None:
        super().__init__()
        # The concatenated heads feed a Linear(emb_dim, emb_dim); any remainder would only
        # surface later as a shape mismatch inside ``forward``, so it is rejected here.
        if head_dim <= 0 or emb_dim % head_dim != 0:
            raise ValueError(f"emb_dim ({emb_dim}) must be a positive multiple of head_dim ({head_dim})")
        n_heads = emb_dim // head_dim
        self.attention = nn.ModuleList(modules=[Attention(emb_dim=emb_dim, head_dim=head_dim, context=context, dropout_rate=dropout_rate) for _ in range(n_heads)])
        self.linear = nn.Linear(in_features=emb_dim, out_features=emb_dim)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every head on ``x``, concatenate on the last axis, project and apply dropout."""
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        out = self.linear(out)
        out = self.dropout(out)
        return out

## [4.4] Decoder

In [None]:
class DecoderLayer(nn.Module):
    """Transformer decoder block: attention then feed-forward, each normalized on input and wrapped in a residual."""

    def __init__(self, emb_dim: int, head_dim: int, context: int, ff_dim: int, dropout_rate: float) -> None:
        super().__init__()
        self.attention = MultiAttention(
            emb_dim=emb_dim,
            head_dim=head_dim,
            context=context,
            dropout_rate=dropout_rate,
        )
        self.feed_forward = FeedForward(
            emb_dim=emb_dim,
            ff_dim=ff_dim,
            dropout_rate=dropout_rate,
        )
        self.norm_1 = nn.LayerNorm(normalized_shape=emb_dim)
        self.norm_2 = nn.LayerNorm(normalized_shape=emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both sub-layers; LayerNorm runs before each one, the residual is added after."""
        attended = self.attention(self.norm_1(x)) + x
        return self.feed_forward(self.norm_2(attended)) + attended


class Decoder(nn.Module):
    """Stack of ``n_layers`` independent copies of a prototype decoder layer."""

    def __init__(self, n_layers: int, decoder: "DecoderLayer") -> None:
        """Build the stack.

        Args:
            n_layers: number of layers in the stack.
            decoder: prototype layer; it is deep-copied, never used directly.
        """
        import copy

        super().__init__()
        # Repeating the very same instance would make every layer share one set of weights,
        # so each position in the stack receives its own deep copy of the prototype.
        # The copies start from identical values until the weights are (re-)initialized.
        self.layers = nn.Sequential(*[copy.deepcopy(decoder) for _ in range(n_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through every layer in order."""
        return self.layers(x)

## [4.5] Article Generator

In [None]:
class ArticleGenerator(nn.Module):
    """Decoder-only transformer language model with a greedy text generator.

    Pipeline: token embedding -> positional encoding -> ``n_layers`` decoder layers ->
    LayerNorm -> linear projection to vocabulary logits.
    """

    def __init__(
            self,
            n_layers: int,
            vocab_size: int,
            emb_dim: int,
            head_dim: int,
            context: int,
            ff_dim: int,
            dropout_rate: float,
            device: str,
            tokenizer: Tokenizer
        ) -> None:
        super().__init__()
        self.ctx = context
        # NOTE(review): assumes EOS owns the largest id of the vocabulary - verify against how
        # the vocabulary is built. The attribute is not read anywhere inside this class.
        self.eos = vocab_size - 1
        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim)
        self.pe = PositionalEncoding(context=context, emb_dim=emb_dim)
        self.layers = nn.Sequential(*[DecoderLayer(emb_dim=emb_dim, head_dim=head_dim, context=context, ff_dim=ff_dim, dropout_rate=dropout_rate) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(normalized_shape=emb_dim)
        self.out = nn.Linear(in_features=emb_dim, out_features=vocab_size)

        # Device the prompt tensor is moved to, and tokenizer used by ``generate``.
        self.dev = device
        self.tokenizer = tokenizer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape ``(B, T)`` to logits of shape ``(B, T, vocab_size)``."""
        x = self.emb(x)
        x = self.pe(x)
        x = self.layers(x)
        x = self.norm(x)
        x = self.out(x)
        return x

    @torch.no_grad()
    def predict_next_token(self, x: torch.Tensor) -> torch.Tensor:
        """Return, per batch row, the most likely token following ``x``.

        Side effect: switches the module to evaluation mode.
        """
        self.eval()
        x = self(x)
        token = x.argmax(dim=-1)[:,-1]
        return token

    def generate(self, text: str = "<|sos|>", max_len: Union[int, None] = None) -> Generator[str, None, None]:
        """Greedily continue ``text``, yielding the decoded text of one new token at a time.

        Args:
            text: prompt to continue.
            max_len: maximum number of tokens to produce; ``None`` generates without limit.

        Raises:
            ValueError: if the prompt encodes to no token at all.
        """
        max_len = float("inf") if max_len is None else max_len
        prompt_tokens = self.tokenizer.encode(text=text)
        if not prompt_tokens:
            raise ValueError("the prompt must contain at least one token")
        count = 0
        x = torch.tensor(data=prompt_tokens, requires_grad=False).unsqueeze(dim=0).to(device=self.dev)
        while count < max_len:
            # Keep only the most recent ``ctx`` tokens. Dropping a single token per step is
            # not enough when the prompt itself is longer than the context window: the
            # input would then be longer than the positional-encoding table.
            x = x[:, -self.ctx:]
            token = self.predict_next_token(x)
            yield self.tokenizer.decode(tokens=[token.item()])
            x = torch.cat(tensors=[x, token.unsqueeze(dim=0)], dim=1)
            count += 1

# [5] Training

## [5.1] Hyper parameters

In [None]:
# model
N_LAYERS = 6
EMB_DIM = 768
HEAD_DIM = 64
FF_DIM = EMB_DIM * 4
DROPOUT = 0.05
WINDOW_CONTEXT = 512

# train
# NOTE(review): SEED is declared but no random generator is seeded with it in the code
# visible here - confirm whether seeding was intended.
SEED = 1234
EPOCHS = 1000
EARLY_STOP = 10
TRAIN_ITERATIONS = 400
VALID_ITERATIONS = 200
TRAIN_SPLIT = 0.8
LR = 2e-4
WEIGHT_DECAY = 0
BATCH_SIZE = 24

# Folder receiving the vocabulary, statistics and weight files.
WEIGHT_FOLDER = os.path.join(ROOT_PATH, "exp")
# ``exist_ok`` replaces the check-then-create pair, which could fail if the folder
# appeared between the two calls, and also creates missing parent folders.
os.makedirs(name=WEIGHT_FOLDER, exist_ok=True)

## [5.2] Dataset

In [None]:
class ArticleDataset(Dataset):
    """Random-window dataset over one long token sequence.

    Every item is a pair ``(x, y)`` of ``context`` tokens where ``y`` is ``x`` shifted one
    token to the right. Windows are drawn at random, so the requested index is ignored
    and the reported length only fixes how many items make up one epoch.
    """

    def __init__(self, articles: list[int], context: int, n_iter: int, batch_size: int) -> None:
        """Store the token stream and the sampling bounds.

        Raises:
            ValueError: if ``articles`` holds fewer than ``context + 1`` tokens.
        """
        super().__init__()

        self.len = int(n_iter * batch_size)
        self.x = articles
        self.ctx = context
        # Largest start index for which both the input and the shifted target fit.
        self.limit = len(articles) - self.ctx - 1
        if self.limit < 0:
            raise ValueError(f"need at least context + 1 = {self.ctx + 1} tokens, got {len(articles)}")

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:

        # ``high`` is exclusive, hence the + 1: without it the last valid window is never
        # sampled and a corpus of exactly ``context + 1`` tokens cannot be sampled at all.
        index_content = np.random.randint(low=0, high=self.limit + 1, size=1).item()

        x = self.x[index_content:index_content+self.ctx]
        y = self.x[index_content+1:index_content+self.ctx+1]

        np_x = np.asarray(a=x, dtype=np.int64)
        np_y = np.asarray(a=y, dtype=np.int64)

        t_x = torch.tensor(data=np_x, requires_grad=False).long()
        t_y = torch.tensor(data=np_y, requires_grad=False).long()

        return t_x, t_y

### [5.2.1] Defining tokenizer

In [None]:
# Sorted distinct tiktoken ids used by the corpus; UNK is appended so it always has an id.
corpus_with_unk = clean_full_dataset + TikTokenizer.UNK
unique_tokens = sorted(set(TikTokenizer.encode(text=corpus_with_unk)))
del corpus_with_unk
print(f"{len(unique_tokens)=}")

In [None]:
# custom_vocab : compact id -> token text
# vocab_mapper : tiktoken id -> compact id
# Compact ids are handed out in the order of the sorted tiktoken ids.
custom_vocab = {}
vocab_mapper = {}
for idx, unique_tk in enumerate(tqdm(iterable=unique_tokens)):
    custom_vocab[idx] = TikTokenizer.decode(tokens=[unique_tk])
    vocab_mapper[unique_tk] = idx

del unique_tokens
gc.collect()

In [None]:
# Persist both lookup tables next to the weights so the tokenizer can be rebuilt later.
for table, table_filename in ((custom_vocab, "vocab.json"), (vocab_mapper, "mapper.json")):
    save_vocab(vocab=table, dir_path=WEIGHT_FOLDER, filename=table_filename)

In [None]:
# Compact tokenizer backed by the two tables built above.
tokenizer = Tokenizer(
    vocab=custom_vocab,
    lookup_vocab=vocab_mapper,
)

In [None]:
# Encode the whole corpus into compact ids, then release the text version.
clean_full_tk_dataset = tokenizer.encode(clean_full_dataset)
del clean_full_dataset
gc.collect()

### [5.2.2] Splitting dataset

In [None]:
# Index of the first validation token: the leading TRAIN_SPLIT fraction is used for training.
n_tokens = len(clean_full_tk_dataset)
train_split = int(n_tokens * TRAIN_SPLIT)
print(f"{train_split=}")

In [None]:
# Contiguous split: head of the token stream for training, tail for validation.
train_set, valid_set = (
    clean_full_tk_dataset[:train_split],
    clean_full_tk_dataset[train_split:],
)
del clean_full_tk_dataset
gc.collect()

### [5.2.3] Creating dataset

In [None]:
# Size of the compact vocabulary.
print(f"{len(tokenizer)=}")

In [None]:
# Number of tokens in each split.
print(f"{len(train_set)=}")
print(f"{len(valid_set)=}")

In [None]:
# Random-window datasets; each epoch holds n_iter * batch_size samples.
train_dataset = ArticleDataset(
    articles=train_set,
    context=WINDOW_CONTEXT,
    n_iter=TRAIN_ITERATIONS,
    batch_size=BATCH_SIZE,
)
valid_dataset = ArticleDataset(
    articles=valid_set,
    context=WINDOW_CONTEXT,
    n_iter=VALID_ITERATIONS,
    batch_size=BATCH_SIZE,
)

In [None]:
# Show two samples: the input window and its target shifted by one token.
for _ in range(2):
    # Only the first item of a fresh pass over the dataset is inspected.
    x, y = next(iter(train_dataset))
    print(f"{x.shape=}")
    print(f"{y.shape=}")
    print("-" * 50)
    print(f"x: {tokenizer.decode(tokens=x.tolist())}")
    print(f"\ny: {tokenizer.decode(tokens=y.tolist())}")

    print("=" * 100)

In [None]:
# Batch the datasets; with BATCH_SIZE samples per batch one pass over a loader
# yields as many batches as the dataset's ``n_iter``.
train_loader, valid_loader = (
    DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE),
    DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE),
)

## [5.3] Device

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(device="cuda")
else:
    device = torch.device(device="cpu")
print(f"{device=}")

## [5.4] Instantiating model

In [None]:
# Build the network, then wrap it so batches can be split across several GPUs.
model = nn.DataParallel(  # multiple GPUs
    module=ArticleGenerator(
        n_layers=N_LAYERS,
        vocab_size=len(tokenizer),
        emb_dim=EMB_DIM,
        head_dim=HEAD_DIM,
        context=WINDOW_CONTEXT,
        ff_dim=FF_DIM,
        dropout_rate=DROPOUT,
        device=device,
        tokenizer=tokenizer,
    )
)
gc.collect()

In [None]:
# Number of trainable parameters.
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
# Memory footprint: both terms must be expressed in bytes before being added, so every
# parameter count is multiplied by the size of one element, exactly like the buffers.
parameter_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
model_size = (parameter_size + buffer_size) / 1024**2
print(f"{n_parameters=}")
print(f"{model_size=:.2f}Mb")

In [None]:
def init_weights(module: nn.Module):
    """Xavier-normal initialisation for Linear and Embedding weights; Linear biases are zeroed.

    Intended for ``nn.Module.apply``; every other module type is left untouched.
    """
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return
    nn.init.xavier_normal_(tensor=module.weight)
    # Only Linear layers carry a bias, and only when built with ``bias=True``.
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(tensor=module.bias)

# Re-initialise every sub-module, then move the model to the selected device
# (both calls return the module itself, so they can be chained).
model.apply(init_weights).to(device=device)

### [5.4.1] Model loss function

In [None]:
def model_metric(yhat: torch.Tensor, y: torch.Tensor, tokenizer: Tokenizer):
    """Compute the training loss and a BLEU-based score for one batch.

    Args:
        yhat: logits; the first two axes are batch and position, the last one the classes.
        y: target token ids, one per batch row and position.
        tokenizer: used to turn ids into token strings for the BLEU computation.

    Returns:
        ``(loss, acc)``: the cross-entropy over every position as a tensor, and the
        smoothed sentence-BLEU of the greedy predictions averaged over the batch
        as a plain Python number.
    """
    n_rows, n_positions, _ = yhat.size()

    # Cross-entropy expects (samples, classes) logits against (samples,) targets.
    flat_logits = yhat.view(n_rows * n_positions, -1)
    flat_targets = y.view(-1)
    loss = F.cross_entropy(input=flat_logits, target=flat_targets)

    greedy_ids = yhat.argmax(dim=-1).tolist()
    target_ids = y.tolist()

    smoothing = SmoothingFunction().method1
    bleu_total = 0
    for hyp_ids, ref_ids in zip(greedy_ids, target_ids):
        hypothesis = tokenizer.decode(tokens=hyp_ids, apply_join=False)
        reference = tokenizer.decode(tokens=ref_ids, apply_join=False)
        bleu_total += sentence_bleu(references=[reference], hypothesis=hypothesis, smoothing_function=smoothing)
    acc = bleu_total / n_rows

    return loss, acc

### [5.4.2] Checking loss

In [None]:
# Reference value: the loss of a model that gives every class the same probability.
expected_loss = -np.log(1 / len(tokenizer))
print(f"{expected_loss=}")

# Sanity check on a single validation batch of the freshly initialised model.
model.eval()
with torch.no_grad():
    x, y = next(iter(valid_loader))
    x, y = x.to(device=device), y.to(device=device)
    print(f"{x.shape=}")
    print(f"{y.shape=}")
    yhat = model(x)
    loss, acc = model_metric(yhat=yhat, y=y, tokenizer=tokenizer)
    print(f"{loss=}")
    print(f"{acc=}")

## [5.5] Train config

In [None]:
# Flip to True to continue a previous run from the files stored in WEIGHT_FOLDER.
resume_training = False

if not resume_training:
    # Fresh run: start at the first epoch with empty histories.
    curr_iter = 1
    last_save = 0
    best_loss = np.inf

    last_train_loss = 0
    last_valid_loss = 0
    overfitting = 0

    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []
else: # continue training
    stats = load_json(file_path=os.path.join(WEIGHT_FOLDER, "stats.json"))
    curr_iter = stats.get("curr_iter")
    last_save = stats.get("last_save")
    best_loss = stats.get("best_loss")
    last_train_loss = stats.get("last_train_loss")
    last_valid_loss = stats.get("last_valid_loss")
    overfitting = stats.get("overfitting")

    # Per-batch histories written at the end of every epoch.
    train_losses = load_json(file_path=os.path.join(WEIGHT_FOLDER, "train_losses.json")).get("losses")
    valid_losses = load_json(file_path=os.path.join(WEIGHT_FOLDER, "valid_losses.json")).get("losses")
    train_accuracies = load_json(file_path=os.path.join(WEIGHT_FOLDER, "train_accuracies.json")).get("accuracies")
    valid_accuracies = load_json(file_path=os.path.join(WEIGHT_FOLDER, "valid_accuracies.json")).get("accuracies")

    # Weights of the most recently finished epoch (not necessarily the best ones).
    model.load_state_dict(torch.load(f=os.path.join(WEIGHT_FOLDER, "last_weights.pt"), map_location=device))

## [5.6] Optimizer

In [None]:
# Adam over every model parameter, configured from the hyper-parameter cell.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

## [5.7] Train loop

In [None]:
# Main loop: one training pass and one validation pass per epoch, with checkpointing of
# the best weights and an early stop driven by the ``overfitting`` counter.
while curr_iter < EPOCHS + 1:

    print(f"EPOCH {curr_iter}/{EPOCHS} | Overfitting {overfitting}/{EARLY_STOP} | Best valid loss {best_loss} | Last savement : {last_save}")
    # Written before the epoch starts, so the file always describes the state to resume from.
    save_train_stats(
        best_loss=best_loss,
        curr_iter=curr_iter,
        last_save=last_save,
        overfitting=overfitting,
        last_train_loss=last_train_loss,
        last_valid_loss=last_valid_loss,
        dir_path=WEIGHT_FOLDER,
        filename="stats.json"
    )

    if overfitting == EARLY_STOP:
        break

    model.train()
    train_loss = 0
    train_acc = 0
    train_tqdm = tqdm(iterable=train_loader)
    for x, y in train_tqdm:

        x, y = x.to(device=device), y.to(device=device)
        yhat = model(x)

        loss, acc = model_metric(yhat=yhat, y=y, tokenizer=tokenizer)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        curr_loss = loss.cpu().item()
        # ``model_metric`` returns the score as a plain Python number, which has no
        # ``.cpu()`` method; ``float`` accepts it as is.
        curr_acc = float(acc)
        train_losses.append(curr_loss)
        train_accuracies.append(curr_acc)
        train_loss += curr_loss
        train_acc += curr_acc

        train_tqdm.set_description(desc=f"loss : {curr_loss} - accuracy : {acc}")
    # One pass over the loader yields TRAIN_ITERATIONS batches, so these are per-batch means.
    train_loss /= TRAIN_ITERATIONS
    train_acc /= TRAIN_ITERATIONS
    print(f"[train] loss : {train_loss} - accuracy : {train_acc}")

    model.eval()
    valid_tqdm = tqdm(iterable=valid_loader)
    valid_loss = 0
    valid_acc = 0
    with torch.no_grad():
        for x, y in valid_tqdm:

            x, y = x.to(device=device), y.to(device=device)
            yhat = model(x)

            loss, acc = model_metric(yhat=yhat, y=y, tokenizer=tokenizer)

            curr_loss = loss.cpu().item()
            curr_acc = float(acc)  # plain Python number, see the training pass
            valid_losses.append(curr_loss)
            valid_accuracies.append(curr_acc)
            valid_loss += curr_loss
            valid_acc += curr_acc

            valid_tqdm.set_description(desc=f"loss : {curr_loss} - accuracy : {acc}")
        valid_loss /= VALID_ITERATIONS
        valid_acc /= VALID_ITERATIONS
        print(f"[valid] loss : {valid_loss} - accuracy : {valid_acc}")

    if best_loss > valid_loss:
        # New best validation loss: checkpoint and reset the early-stop counter.
        best_loss = valid_loss
        last_save = curr_iter
        torch.save(obj=model.state_dict(), f=os.path.join(WEIGHT_FOLDER, "weights.pt"))
        overfitting = 0
    elif last_train_loss > train_loss and valid_loss > last_valid_loss:
        # Training loss went down while validation loss went up.
        overfitting += 1
    elif overfitting:
        # Neither a new best nor a diverging epoch: relax the counter by one.
        overfitting -= 1

    curr_iter += 1
    last_train_loss = train_loss
    last_valid_loss = valid_loss

    # End-of-epoch snapshot: full histories plus the latest weights, used when resuming.
    save_losses(losses=train_losses, dir_path=WEIGHT_FOLDER, filename="train_losses.json")
    save_losses(losses=valid_losses, dir_path=WEIGHT_FOLDER, filename="valid_losses.json")
    save_accuracies(accuracies=train_accuracies, dir_path=WEIGHT_FOLDER, filename="train_accuracies.json")
    save_accuracies(accuracies=valid_accuracies, dir_path=WEIGHT_FOLDER, filename="valid_accuracies.json")
    torch.save(obj=model.state_dict(), f=os.path.join(WEIGHT_FOLDER, "last_weights.pt"))
    print("=" * 100)

# [6] Test

## [6.1] Loading model

In [None]:
# Reload the two tokenizer tables saved during training (JSON turns their keys into strings).
vocab_path = os.path.join(WEIGHT_FOLDER, "vocab.json")
mapper_path = os.path.join(WEIGHT_FOLDER, "mapper.json")
custom_vocab = load_json(file_path=vocab_path)
vocab_mapper = load_json(file_path=mapper_path)

In [None]:
# Rebuild the tokenizer from the reloaded tables.
tokenizer = Tokenizer(
    vocab=custom_vocab,
    lookup_vocab=vocab_mapper,
)

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(device="cuda")
else:
    device = torch.device(device="cpu")
print(f"{device=}")

In [None]:
# Same architecture and wrapper as in training, so the saved state dict keys line up.
model = nn.DataParallel(
    module=ArticleGenerator(
        n_layers=N_LAYERS,
        vocab_size=len(tokenizer),
        emb_dim=EMB_DIM,
        head_dim=HEAD_DIM,
        context=WINDOW_CONTEXT,
        ff_dim=FF_DIM,
        dropout_rate=DROPOUT,
        device=device,
        tokenizer=tokenizer,
    )
)
gc.collect()

In [None]:
# Restore the checkpoint with the best validation loss.
best_weights = torch.load(f=os.path.join(WEIGHT_FOLDER, "weights.pt"), map_location=device)
model.load_state_dict(best_weights)

## [6.2] Generation

In [None]:
# Stream 100 generated tokens; ``model.module`` reaches the network inside DataParallel.
token_stream = model.module.generate(max_len=100)
for i, token in enumerate(iterable=token_stream):
    print(token, end="")
    # NOTE(review): the condition is also true for i == 0, so the first line break comes
    # right after the first token - confirm this is the intended layout.
    if not i % 30: # break line
        print()