In [2]:
import numpy as onp
import jax.numpy as jnp
from jax import jit, vmap, grad, nn, random


In [3]:
from typing import List, Tuple, Dict, Any


In [4]:
import os 
import urllib
import sys


# using te corpus from Carpathy
file_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"


def download(url: str) -> None:
    file_name = url.split("/")[-1]
    download_path = file_name

    if os.path.exists(download_path):
        print("Already downloaded!")

    else:
        # ============================================ download
        print("Downloading, sit tight!")

        def _progress(count, block_size, total_size):
            sys.stdout.write(
                f"\r>> Downloading {file_name} {float(count * block_size) / float(total_size) * 100.0}%")
            sys.stdout.flush()

        file_path, _ = urllib.request.urlretrieve(
            url, download_path, _progress)
        print()
        print(
            f"Successfully downloaded {file_name} {os.stat(file_path).st_size} bytes")
        

download(url=file_url)


Already downloaded!


In [5]:
def read_all_text(file_path: str) -> List[str]:
    assert os.path.exists(file_path)

    with open(file_path, "r") as f:
        data = f.readlines()

    # remove new line escape sequences
    data = [d for d in data if d != "\n"]
    data = [d.replace("\n", "") for d in data]

    return data


all_text = read_all_text("input.txt")


In [6]:
all_text[:5]


['First Citizen:',
 'Before we proceed any further, hear me speak.',
 'All:',
 'Speak, speak.',
 'First Citizen:']

In [7]:
def tokenize_text(text: str) -> list:
    s = text.split(" ")
    return s


def batch_tokenize_text(texts: List[str]) -> List[List[str]]:
    return [tokenize_text(text) for text in texts]


The ":" punctuation mark is a problem. The corpus isn't exactly cleaned for a typical ngram purpose.

In [8]:
tokenize_text(all_text[2])


['All:']

In [9]:
tokenized_sents = batch_tokenize_text(all_text)
tokenized_sents[:1]


[['First', 'Citizen:']]

In [10]:
from tqdm.auto import tqdm


def create_vocab(tokenized_sentences=tokenized_sents):
    vocabulary = dict()  # word to count mapping

    for _, sentence in tqdm(enumerate(tokenized_sentences), total=len(tokenized_sentences)):
        for token in sentence:
            if token in vocabulary.keys():
                vocabulary[token] += 1.0
            else:
                vocabulary[token] = 0.0

    vocabulary["[OOV]"] = 0.0
    return vocabulary


vocabulary = create_vocab()


  0%|          | 0/32777 [00:00<?, ?it/s]

In [11]:
vocab_size = len(list(vocabulary.keys()))
vocab_size


25672

In [12]:
from functools import reduce


def get_total_word_count(vocabulary):
    total = reduce(lambda start, values: start +
                   jnp.sum(values), vocabulary.values(), 0)
    return total


total_tokens = get_total_word_count(vocabulary)
print(total_tokens)


176998.0


In [13]:
def unigram_probabilities(vocabulary, total_tokens, vocab_size, smoothing="laplace"):
    probabilities = dict()  # unigram -> probability
    for k, v in vocabulary.items():
        if smoothing:
            probabilities[k] = (v + 1) / (float(total_tokens) + vocab_size)

    return probabilities


unigram_probs = unigram_probabilities(vocabulary, total_tokens, vocab_size)


In [14]:
def get_sent_probs(sentence, unigram_probs=unigram_probs) -> jnp.ndarray:
    if not isinstance(sentence, list):
        # tokenize
        tokens = tokenize_text(sentence)
    else:
        tokens = sentence

    sentence_probs = []
    for tok in tokens:
        if tok in unigram_probs.keys():
            sentence_probs.append(unigram_probs[tok])
        else:
            sentence_probs.append(unigram_probs["[OOV]"])

    return jnp.array(sentence_probs)


In [15]:
def perplexity(prob):
    product_prob = jnp.prod(prob)
    return jnp.power(product_prob, (-1 / prob.shape[0]))


In [16]:
test_inputs = [
    "First Citizen:",
    "Manners maketh a man.",
    "Where's the ghost, Othello?"
]


def evaluate(test_inputs=test_inputs) -> None:
    for _, ti in tqdm(enumerate(test_inputs), total=len(test_inputs)):
        probs = get_sent_probs(ti)
        p = perplexity(probs)

        print(f"Input : {ti}\nPerplexity : {p}\nProbabilities : {probs}\n")


evaluate(test_inputs)


  0%|          | 0/3 [00:00<?, ?it/s]

Input : First Citizen:
Perplexity : 1335.49560546875
Probabilities : [0.00115952 0.00048354]

Input : Manners maketh a man.
Perplexity : 12557.001953125
Probabilities : [4.9341293e-06 4.9341293e-06 1.2878078e-02 1.2828737e-04]

Input : Where's the ghost, Othello?
Perplexity : 9271.228515625
Probabilities : [6.9077811e-05 2.6826862e-02 1.4802388e-05 4.9341293e-06]



In [67]:
# jax prng key for random
n_gens = 5
seed = int(onp.random.randint(1, 2000))
master_key = random.PRNGKey(seed)
master_key, *subkeys = random.split(master_key, n_gens + 1)


@jit
def sample_words(key, probabilities):
    logits = random.categorical(key, probabilities, axis=-1, shape=[15, ])
    return logits


def generate(key, n_tokens, unigram_probs=unigram_probs):    
    words = list()
    probs = list()
    indexes = list()

    for k, v in unigram_probs.items():
        words.append(k)
        probs.append(v)

    assert len(words) == len(probs)
    
    probs = jnp.array(probs)
    sampled = sample_words(key, probs)
    indexes = sampled.tolist()
    
        
    return " ".join(words[int(i)] for i in indexes)


# ======================
for idx in range(n_gens):
    out = generate(subkeys[idx], 15)
    print(out)


parts mutinous see: limit. little, returns. Caliban! constrains shreds slackness. buds malign thaw weather, counsel;
sacrament, shrill-voiced widower They, changed, treasons blindly pains teeth, dissentious health. Buckingham. spite. births: estates,--
virtue: Angelo? Scrivener: unaccustom'd point lamentation; strew grave, wherefore? zeal, exchange! on! defied disturbing fringed
their thee; high dugs spell, GLOUCESTER: fairest-boding virtue! people. Account lord's patience! king holiday, fill.
vouch'd, ean: Contend summer's friends,' whet heavy? here,--this, ordnance forgiven solace: urging excellence; bad! Became
