In [None]:
import sys

# Hint:
#   `sys.executable` is used for shell commands so that they run with the same
#   interpreter as this kernel (needed for a Python venv on Windows 10).

print(f'python: {sys.executable}')
!{sys.executable} -V

In [None]:
# Install the notebook's dependencies with the interpreter that runs this kernel.
# Only numpy is pinned (1.24.1); the other packages are left unpinned.
# NOTE(review): nltk is installed here but not imported in the cells shown — confirm it is needed.
!{sys.executable} -m pip install -qq --upgrade pip
!{sys.executable} -m pip install -qq numpy==1.24.1 datasets nltk matplotlib
# `-f` gives pip the PyTorch wheel listing as an additional place to look for torch builds.
!{sys.executable} -m pip install -qq torch -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import numpy as np

from typing import List
from datasets import load_dataset
from collections import Counter

from tqdm import tqdm

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load the 'santhosh/english-malayalam-names' dataset through `datasets.load_dataset`.
dataset = load_dataset('santhosh/english-malayalam-names')

In [None]:
# Take the 'en' column of the 'train' split; this is the raw material for the name list.
long_en = dataset['train']['en']

In [None]:
def select_uniq_freq_names(names: List[str]):
    """Collect the leading names that occur more than once in ``names``.

    Each entry is reduced to its first whitespace-separated word, and that word
    to the part before its first ``'.'``. Entries that are empty, contain only
    whitespace, or whose first word starts with ``'.'`` yield no name and are
    skipped.

    Args:
        names: Raw strings; only the leading word of each one is used.

    Returns:
        The distinct names seen at least twice, in order of first appearance.
    """
    names_counter = Counter()
    for name_and_other in names:
        words = name_and_other.split()
        if not words:
            continue

        # str.split('.') always returns at least one element, but that element
        # is '' when the word starts with a dot, so test the value itself
        # instead of the list length (which can never be zero).
        name = words[0].split('.')[0]
        if not name:
            continue

        names_counter[name] += 1

    return [name for name, freq in names_counter.items() if freq > 1]

In [None]:
class Tokenizer:
    """Character-level tokenizer whose vocabulary is learned from a list of strings.

    The vocabulary holds every distinct character of the training lines in order
    of first appearance, followed by the special tokens ``<bos>``, ``<eos>``,
    ``<pad>`` and ``<unk>`` (always the last four ids, in that order).
    """

    def __init__(self, lines: List[str]):
        self.id2token, self.token2id = self.train(lines)

        # The special tokens sit at the tail of the vocabulary, so their ids are
        # fixed offsets from its size.
        vocab_size = len(self.id2token)
        self.bos_id = vocab_size - 4
        self.eos_id = vocab_size - 3
        self.pad_id = vocab_size - 2
        self.unk_id = vocab_size - 1

        self.bos = self.id2token[-4]
        self.eos = self.id2token[-3]

    def train(self, uniq_names: List[str]):
        """Build the ``id -> token`` list and the inverse ``token -> id`` mapping.

        Returns:
            A ``(id2token, token2id)`` pair covering every character of
            ``uniq_names`` plus the four special tokens.
        """
        # Counting symbol by symbol keeps the keys in order of first appearance.
        tokens_counter = Counter(symbol for name in uniq_names for symbol in name)

        id2token = list(tokens_counter) + ['<bos>', '<eos>', '<pad>', '<unk>']
        token2id = {token: index for index, token in enumerate(id2token)}

        # Sanity check: the two mappings must be exact inverses of each other.
        assert all(token2id[token] == index for index, token in enumerate(id2token))

        return id2token, token2id

    def _encode(self, text: str) -> List[int]:
        """Map each character of ``text`` to its id, using ``unk_id`` for unseen ones."""
        return [self.token2id.get(symbol, self.unk_id) for symbol in text]

    def as_matrix(self, x: List[str]) -> torch.Tensor:
        """Encode strings as ``<bos> ... <eos>`` rows, right-padded with ``pad_id``."""
        rows = [
            torch.LongTensor([self.bos_id] + self._encode(text) + [self.eos_id])
            for text in x
        ]
        return pad_sequence(rows, batch_first=True, padding_value=self.pad_id)

    def as_bos_and_indexes(self, x: str) -> torch.Tensor:
        """Encode one string as a 1-D tensor starting with ``<bos>`` and without ``<eos>``."""
        return torch.LongTensor([self.bos_id] + self._encode(x))

    def decode(self, x: torch.Tensor) -> List[str]:
        """Turn a 2-D batch of ids back into one space-separated token string per row."""
        assert x.ndim == 2, 'pass batch of ids'
        return [self._decode_line(row) for row in x.tolist()]

    def _decode_line(self, tokens: List[int]) -> str:
        """Join the tokens of one row with single spaces, dropping padding."""
        visible = (self.id2token[token_id] for token_id in tokens if token_id != self.pad_id)
        return ' '.join(visible).strip()

    def __len__(self):
        """Vocabulary size, special tokens included."""
        return len(self.id2token)

In [None]:
from collections import namedtuple

# Result of a forward pass: `loss` is None when no labels were supplied.
ClassifierOutput = namedtuple('ClassifierOutput', ['loss', 'logits'])


class GenNetwork(torch.nn.Module):
    """Character-level language model: embedding -> GRU -> linear head.

    Args:
        tokenizer: Object providing ``len()`` (the vocabulary size) and a
            ``pad_id`` attribute; it is stored as ``self.tokenizer``.
        embedding_dim: Size of the token embeddings.
        hidden_size: Size of the GRU hidden state.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, tokenizer, embedding_dim: int = 32, hidden_size: int = 64, num_layers: int = 1):
        super().__init__()
        self.tokenizer = tokenizer
        vocab_size = len(tokenizer)

        self.embeddings = torch.nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        self.gru = torch.nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.classifier_head = torch.nn.Linear(hidden_size, vocab_size)
        # Padding positions never contribute to the loss.
        self.cross = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_id)

    def _compute_lm_loss(self, predictions: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Next-token cross-entropy.

        The logits at position ``t`` are scored against the token at ``t + 1``,
        so the last prediction and the first token are dropped.
        """
        predictions = predictions[:, :-1, :].contiguous()
        labels = input_ids[:, 1:].contiguous()
        return self.cross(predictions.view(-1, predictions.size(2)), labels.view(-1))

    def forward(self, input_ids, labels=None):
        """Run the model on a batch of token ids.

        Args:
            input_ids: Batch-first integer tensor of token ids.
            labels: Optional token ids of the same layout as ``input_ids``;
                when given, the language-modelling loss is computed from them.

        Returns:
            ``ClassifierOutput(loss, logits)`` where ``logits`` holds one score
            per vocabulary entry for every input position and ``loss`` is
            ``None`` if ``labels`` was not supplied.
        """
        emb = self.embeddings(input_ids)
        gru_out, _ = self.gru(emb)
        logits = self.classifier_head(gru_out)

        loss = None if labels is None else self._compute_lm_loss(logits, labels)
        return ClassifierOutput(loss=loss, logits=logits)

In [None]:
def iterate_batches(vocab, x_data, shuffle: bool = False, batch_size: int = 32, device=device):
    """Yield batches of token ids built from ``x_data``.

    Args:
        vocab: Tokenizer-like object whose ``as_matrix`` turns a batch of
            strings into a tensor.
        x_data: Sequence of strings to batch.
        shuffle: Visit the samples in a random order (uses ``np.random``).
        batch_size: Maximum number of samples per batch; the last batch may be
            smaller.
        device: Device the encoded batch is moved to.

    Yields:
        ``{'input_ids': ids, 'labels': ids}`` — the same tensor under both keys.
    """
    samples = np.array(x_data)

    order = np.arange(len(samples))
    if shuffle:
        np.random.shuffle(order)

    for start in range(0, len(order), batch_size):
        stop = start + batch_size
        token_ids = vocab.as_matrix(samples[order[start:stop]]).to(device)
        yield dict(input_ids=token_ids, labels=token_ids)

# https://stackoverflow.com/questions/5283649/plot-smooth-line-with-pyplot
def smooth(scalars: List[float], weight: float) -> List[float]:  # Weight between 0 and 1
    """Exponentially smooth a series of values.

    Each output is ``previous_output * weight + (1 - weight) * value``, seeded
    with the first value of the series.

    Args:
        scalars: Values to smooth, in order.
        weight: Share of the previous smoothed value kept at each step; larger
            values give a smoother curve.

    Returns:
        A list with one smoothed value per input value; empty when ``scalars``
        is empty.
    """
    # len() rather than truthiness so that array-like inputs are accepted too.
    if len(scalars) == 0:
        return []

    last = scalars[0]  # Seed with the first value so the curve starts on the data
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)

    return smoothed

In [None]:
# For fast system testing
is_low_dataset = False
en = long_en[:20000] if is_low_dataset else long_en

In [None]:
# Keep only the names that occur more than once, then learn the character
# vocabulary from exactly those names.
uniq_en_names = select_uniq_freq_names(en)
vocab = Tokenizer(uniq_en_names)

In [None]:
# Build the model on the selected device and optimise all of its parameters
# with Adam, using the optimizer's default hyperparameters.
net = GenNetwork(vocab).to(device)
optimizer = torch.optim.Adam(net.parameters())
net  # bare expression: the cell output shows the module structure

In [None]:
# Switch for the progress bar and the live loss plot shown during training.
is_visual_enable = True

if is_visual_enable:
    # Only imported when visuals are on, so a headless run does not need them.
    import matplotlib.pyplot as plt
    from IPython.display import clear_output

In [None]:
# Loss of every optimisation step, in order; used for the live plot.
history = []

batch_size = 64
num_epoch = 40

progressbar = None
if is_visual_enable:
    # One tick per optimisation step. Every epoch yields
    # ceil(len(uniq_en_names) / batch_size) batches (the last one may be
    # partial), so round up per epoch instead of flooring the grand total.
    batches_per_epoch = (len(uniq_en_names) + batch_size - 1) // batch_size
    progressbar = tqdm(total=num_epoch * batches_per_epoch, desc='training')

net.train()
for epoch in range(num_epoch):
    it = enumerate(iterate_batches(net.tokenizer, uniq_en_names,
                                   batch_size=batch_size, shuffle=True))
    for idx, batch in it:
        model_out = net(**batch)

        optimizer.zero_grad()
        loss = model_out.loss
        loss.backward()
        history.append(loss.item())

        optimizer.step()
        if progressbar is not None:
            progressbar.update(1)
            # Redraw the smoothed log-loss curve every 100 batches of an epoch.
            if (idx + 1) % 100 == 0:
                clear_output(wait=True)
                plt.plot(np.log(smooth(history, 0.95)), label='loss')
                plt.xlabel('optimisation step')
                plt.ylabel('log of smoothed loss')
                plt.show()

if progressbar is not None:
    # Flush the final state of the bar and release it.
    progressbar.close()

In [None]:
@torch.no_grad()
def generate_name_greedy_decoding(net, first_sym='', max_len=30, device=device):
    """Extend ``first_sym`` one character at a time, always taking the most likely one.

    Args:
        net: Model exposing ``tokenizer`` and returning an object with
            ``logits`` when called on a batch of token ids.
        first_sym: Prefix the generated name starts with (may be empty).
        max_len: Maximum number of characters appended to the prefix.
        device: Device the encoded prefix is moved to before the forward pass.

    Returns:
        The prefix followed by the generated characters; generation stops early
        when ``<eos>`` is the most likely next token.
    """
    tokenizer = net.tokenizer
    # <bos>, <pad> and <unk> are not characters of a name. Without masking they
    # can be selected, and their literal text would be appended to the name.
    banned_ids = [tokenizer.bos_id, tokenizer.pad_id, tokenizer.unk_id]

    name = first_sym
    for _ in range(max_len):
        name_indexes = tokenizer.as_matrix([name]).to(device)
        output = net(name_indexes)
        # Position 0 is <bos>, so index len(name) is the last prefix character,
        # whose logits score the next character.
        logits = output.logits[0][len(name)].clone()
        logits[banned_ids] = float('-inf')

        # Softmax is monotonic, so the arg-max of the logits is the arg-max of
        # the probabilities.
        idx = torch.argmax(logits).item()
        if idx == tokenizer.eos_id:
            break

        name += tokenizer.id2token[idx]

    return name


@torch.no_grad()
def generate_name_casual_sampling(net, first_sym='', temp=1, max_len=30, device=device):
    """Extend ``first_sym`` by sampling each next character from the model's distribution.

    Args:
        net: Model exposing ``tokenizer`` and returning an object with
            ``logits`` when called on a batch of token ids.
        first_sym: Prefix the generated name starts with (may be empty).
        temp: Softmax temperature; the logits are divided by it before sampling.
        max_len: Maximum number of characters appended to the prefix.
        device: Device the encoded prefix is moved to before the forward pass.

    Returns:
        The prefix followed by the sampled characters; generation stops early
        when ``<eos>`` is sampled. Randomness comes from ``np.random``.
    """
    tokenizer = net.tokenizer
    # <bos>, <pad> and <unk> are not characters of a name. Without masking they
    # can be sampled, and their literal text would be appended to the name.
    banned_ids = [tokenizer.bos_id, tokenizer.pad_id, tokenizer.unk_id]

    name = first_sym
    for _ in range(max_len):
        name_indexes = tokenizer.as_matrix([name]).to(device)
        output = net(name_indexes)
        # Position 0 is <bos>, so index len(name) is the last prefix character,
        # whose logits score the next character. The division creates a new
        # tensor, so masking does not touch the model output.
        logits = output.logits[0][len(name)] / temp
        logits[banned_ids] = float('-inf')

        probs = F.softmax(logits, dim=-1).cpu().numpy()
        idx = np.random.choice(len(probs), p=probs)
        if idx == tokenizer.eos_id:
            break

        name += tokenizer.id2token[idx]

    return name


@torch.no_grad()
def generate_name_top_p(net, top_p, first_sym='', temp=1, max_len=30, device=device):
    """Extend ``first_sym`` with nucleus (top-p) sampling.

    At every step only the most likely tokens whose cumulative probability
    first reaches ``top_p`` are kept; the next character is sampled from that
    set after renormalising it.

    Args:
        net: Model exposing ``tokenizer`` and returning an object with
            ``logits`` when called on a batch of token ids.
        top_p: Cumulative probability mass to keep.
        first_sym: Prefix the generated name starts with (may be empty).
        temp: Softmax temperature; the logits are divided by it before sampling.
        max_len: Maximum number of characters appended to the prefix.
        device: Device the encoded prefix is moved to before the forward pass.

    Returns:
        The prefix followed by the sampled characters; generation stops early
        when ``<eos>`` is sampled. Randomness comes from ``np.random``.
    """
    tokenizer = net.tokenizer
    # <bos>, <pad> and <unk> are not characters of a name. Without masking they
    # can be sampled, and their literal text would be appended to the name.
    banned_ids = [tokenizer.bos_id, tokenizer.pad_id, tokenizer.unk_id]

    name = first_sym
    for _ in range(max_len):
        name_indexes = tokenizer.as_matrix([name]).to(device)
        output = net(name_indexes)
        # Position 0 is <bos>, so index len(name) is the last prefix character,
        # whose logits score the next character. The division creates a new
        # tensor, so masking does not touch the model output.
        logits = output.logits[0][len(name)] / temp
        logits[banned_ids] = float('-inf')

        probs = F.softmax(logits, dim=-1).cpu().numpy()
        prob_args = np.argsort(-probs)  # token ids, most likely first

        # Grow the nucleus until it holds at least `top_p` of the mass; if that
        # is never reached, the whole vocabulary is kept.
        prob_acc = 0
        for last_idx in range(len(prob_args)):
            prob_acc += probs[prob_args[last_idx]]
            if prob_acc >= top_p:
                break
        last_idx = min(last_idx + 1, len(prob_args))

        # Renormalise the kept probabilities so they sum to one.
        probs = probs[prob_args[:last_idx]] / prob_acc
        arg_idx = np.random.choice(last_idx, p=probs)
        idx = prob_args[arg_idx]
        if idx == tokenizer.eos_id:
            break

        name += tokenizer.id2token[idx]

    return name


@torch.no_grad()
def generate_name_top_k(net, top_k: int, first_sym='', temp=1, max_len=30, device=device):
    """Extend ``first_sym`` by sampling each next character from the ``top_k`` most likely ones.

    Args:
        net: Model exposing ``tokenizer`` and returning an object with
            ``logits`` when called on a batch of token ids.
        top_k: Number of candidates kept at every step; capped at the
            vocabulary size.
        first_sym: Prefix the generated name starts with (may be empty).
        temp: Softmax temperature; the logits are divided by it before sampling.
        max_len: Maximum number of characters appended to the prefix.
        device: Device the encoded prefix is moved to before the forward pass.

    Returns:
        The prefix followed by the sampled characters; generation stops early
        when ``<eos>`` is sampled. Randomness comes from ``np.random``.
    """
    tokenizer = net.tokenizer
    top_k = min(top_k, len(tokenizer))
    # <bos>, <pad> and <unk> are not characters of a name. Without masking they
    # can be sampled, and their literal text would be appended to the name.
    banned_ids = [tokenizer.bos_id, tokenizer.pad_id, tokenizer.unk_id]

    name = first_sym
    for _ in range(max_len):
        name_indexes = tokenizer.as_matrix([name]).to(device)
        output = net(name_indexes)
        # Position 0 is <bos>, so index len(name) is the last prefix character,
        # whose logits score the next character. The division creates a new
        # tensor, so masking does not touch the model output.
        logits = output.logits[0][len(name)] / temp
        logits[banned_ids] = float('-inf')

        probs = F.softmax(logits, dim=-1).cpu().numpy()
        prob_args = np.argsort(-probs)  # token ids, most likely first

        # Keep the k best candidates and renormalise them so they sum to one.
        probs = probs[prob_args[:top_k]]
        probs /= np.sum(probs)

        arg_idx = np.random.choice(top_k, p=probs)
        idx = prob_args[arg_idx]
        if idx == tokenizer.eos_id:
            break

        name += tokenizer.id2token[idx]

    return name

## Greedy decoding

In [None]:
# Greedy decoding takes the arg-max at every step and draws no random numbers,
# so each of the five passes is expected to print the same four names.
for i in range(5):
    print(*(generate_name_greedy_decoding(net, prefix) for prefix in ('A', 'Ai', 'Lu', 'M')))

## Random sampling (`generate_name_casual_sampling`)

In [None]:
# Draw five samples that continue the prefix 'Muh' at temperature 1.
for i in range(5):
    print(generate_name_casual_sampling(net, first_sym='Muh', temp=1))

## Top-p

In [None]:
# Draw five samples that continue the prefix 'M', keeping 0.8 of the probability mass at each step.
for i in range(5):
    print(generate_name_top_p(net, 0.8, first_sym='M', temp=1))

## Top-k

In [None]:
# Draw five samples that continue the prefix 'Lucy', choosing among the 20 most likely tokens at each step.
for i in range(5):
    print(generate_name_top_k(net, 20, first_sym='Lucy', temp=1))