In [1]:
import numpy as onp
import jax.numpy as jnp
from jax import jit, vmap, grad, nn, random
import os
import urllib
import sys
from typing import List, Tuple, Dict, Any


In [2]:
# using te corpus from Carpathy
file_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"


def download(url: str) -> None:
    file_name = url.split("/")[-1]
    download_path = file_name

    if os.path.exists(download_path):
        print("Already downloaded!")

    else:
        # ============================================ download
        print("Downloading, sit tight!")

        def _progress(count, block_size, total_size):
            sys.stdout.write(
                f"\r>> Downloading {file_name} {float(count * block_size) / float(total_size) * 100.0}%")
            sys.stdout.flush()

        file_path, _ = urllib.request.urlretrieve(
            url, download_path, _progress)
        print()
        print(
            f"Successfully downloaded {file_name} {os.stat(file_path).st_size} bytes")
        

download(url=file_url)


Already downloaded!


In [3]:
def read_all_text(file_path: str) -> List[str]:
    assert os.path.exists(file_path)

    with open(file_path, "r") as f:
        data = f.readlines()

    # remove new line escape sequences
    data = [d for d in data if d != "\n"]
    data = [d.replace("\n", "") for d in data]

    return data


all_text = read_all_text("input.txt")


In [4]:
all_text[:5]


['First Citizen:',
 'Before we proceed any further, hear me speak.',
 'All:',
 'Speak, speak.',
 'First Citizen:']

In [5]:
def tokenize_text(text: str) -> list:
    s = text.split(" ")
    return s


def batch_tokenize_text(texts: List[str]) -> List[List[str]]:
    return [tokenize_text(text) for text in texts]


The ":" punctuation mark is a problem. The corpus isn't exactly cleaned for a typical ngram purpose.

In [6]:
tokenize_text(all_text[2])


['All:']

In [7]:
tokenized_sents = batch_tokenize_text(all_text)
tokenized_sents[:1]


[['First', 'Citizen:']]

In [8]:
from tqdm.auto import tqdm


def create_vocab(tokenized_sentences=tokenized_sents):
    vocabulary = dict()  # word to count mapping

    for _, sentence in tqdm(enumerate(tokenized_sentences), total=len(tokenized_sentences)):
        for token in sentence:
            if token in vocabulary.keys():
                vocabulary[token] += 1.0
            else:
                vocabulary[token] = 0.0

    vocabulary["[OOV]"] = 0.0
    return vocabulary


vocabulary = create_vocab()


  0%|          | 0/32777 [00:00<?, ?it/s]

In [9]:
vocab_size = len(list(vocabulary.keys()))
vocab_size


25672

In [10]:
from functools import reduce


def get_total_word_count(vocabulary):
    total = reduce(lambda start, values: start +
                   jnp.sum(values), vocabulary.values(), 0)
    return total


total_tokens = get_total_word_count(vocabulary)
print(total_tokens)


176998.0


In [11]:
def unigram_probabilities(vocabulary, total_tokens, vocab_size, smoothing="laplace"):
    probabilities = dict()  # unigram -> probability
    for k, v in vocabulary.items():
        if smoothing:
            probabilities[k] = (v + 1) / (float(total_tokens) + vocab_size)

    return probabilities


unigram_probs = unigram_probabilities(vocabulary, total_tokens, vocab_size)


In [12]:
def get_sent_probs(sentence, unigram_probs=unigram_probs) -> jnp.ndarray:
    if not isinstance(sentence, list):
        # tokenize
        tokens = tokenize_text(sentence)
    else:
        tokens = sentence

    sentence_probs = []
    for tok in tokens:
        if tok in unigram_probs.keys():
            sentence_probs.append(unigram_probs[tok])
        else:
            sentence_probs.append(unigram_probs["[OOV]"])

    return jnp.array(sentence_probs)


In [13]:
def perplexity(prob):
    product_prob = jnp.prod(prob)
    return jnp.power(product_prob, (-1 / prob.shape[0]))


In [14]:
test_inputs = [
    "First Citizen:",
    "Manners maketh a man.",
    "Where's the ghost, Othello?"
]


def evaluate(test_inputs=test_inputs) -> None:
    for _, ti in tqdm(enumerate(test_inputs), total=len(test_inputs)):
        probs = get_sent_probs(ti)
        p = perplexity(probs)

        print(f"Input : {ti}\nPerplexity : {p}\nProbabilities : {probs}\n")


evaluate(test_inputs)


  0%|          | 0/3 [00:00<?, ?it/s]

Input : First Citizen:
Perplexity : 1335.49560546875
Probabilities : [0.00115952 0.00048354]

Input : Manners maketh a man.
Perplexity : 12557.001953125
Probabilities : [4.9341293e-06 4.9341293e-06 1.2878078e-02 1.2828737e-04]

Input : Where's the ghost, Othello?
Perplexity : 9271.228515625
Probabilities : [6.9077811e-05 2.6826862e-02 1.4802388e-05 4.9341293e-06]



In [15]:
from tqdm.auto import trange
from functools import partial

"""
https://math.stackexchange.com/questions/966466/what-is-the-difference-between-multinomial-and-categorical-distribution

"""
# not having a static shape can throw jit errors
# specify trials as a static_argnum
@partial(jit, static_argnums=2)
def categorical_trial(key, probabilities, trials):
    logits = random.categorical(key, probabilities, axis=-1, shape=[trials, ])
    return logits


def generate(n_tokens, unigram_probs=unigram_probs, trials=50):
    # randomise seed on each call
    seed = int(onp.random.randint(1, 2000))
    # jax prng key for random
    master_key = random.PRNGKey(seed)
    # one trial, key for each token in n_tokens
    # from each trial get the argmax
    master_key, *subkeys = random.split(master_key, n_tokens + 1)
    
    # collect words and probabilities    
    words = list()
    probs = list()


    for k, v in unigram_probs.items():
        words.append(k)
        probs.append(v)

    assert len(words) == len(probs)
    
    # convert to array
    probs = jnp.array(probs)
    
    # vmap is faster than a loop
    # convert subkeys to array first
    subkeys = jnp.array(subkeys)
    
    sampled = vmap(categorical_trial, in_axes=[0, None, None])(subkeys, probs, trials)
    sampled = jnp.argmax(sampled, axis=-1)
    
    
    # word indexes
    indexes = sampled.tolist()
    return " ".join(words[int(i)] for i in indexes)


In [16]:
for _ in range(10):
    print(generate(15))
    print()


Resolved. than to enemy than chief the a and Speak, all famish? price. All: people.

people. chief Let our our Is't know't, know't. to enemy resolved. you a All: All:

All: further, the chief chief price. and chief rather at proceed than know corn know't,

price. are people. we'll We know't, is and Marcius resolved die kill a hear First

the him, famish? to know't, our You and proceed Citizen: kill than resolved. We know

enemy Is't hear to Marcius Speak, him, know't. kill Let All: further, Is't speak. to

people. resolved. enemy chief Caius people. enemy resolved and know't. we'll than resolved a the

we'll Citizen: resolved famish? any rather proceed enemy corn to are First corn proceed First,

Marcius die people. us know't, know have at a than Before First Resolved. enemy resolved

All: to own Citizen: chief any enemy you is you We a to famish? You

