In [32]:
from datasets import load_dataset

# Load from Huggingface datasets module
data = load_dataset("cnn_dailymail", "3.0.0")
train = data["train"]["highlights"]
valid = data["validation"]["highlights"]

Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
100%|██████████| 3/3 [00:00<00:00, 150.29it/s]


In [33]:
STOP_TOKEN = 2
START_TOKEN = 1
UNK_TOKEN = 0
MAX_INPUT_CHARS = 192

In [34]:
import re

def split_words(text):
    return list(text)

def clean_text(text):
    return re.sub(r'[^a-zA-Z0-9_\-\n ]', '', text)

def compute_tokens(texts, min_count=100):
    token_counts = {}
    for text in texts:
        words = split_words(clean_text(text))
        for word in words:
            if word not in token_counts:
                token_counts[word] = 0
            token_counts[word] += 1

    token_map = {"<unk>": UNK_TOKEN, "<s>": START_TOKEN, "</s>": STOP_TOKEN}
    token_counter = 3
    for k, v in token_counts.items():
        if v >= min_count:
            token_map[k] = token_counter
            token_counter += 1
    return token_map

def tokenize(text, token_map):
    # Add end token to stop generation
    words = split_words(clean_text(text))
    list_tokens = [token_map.get(w, UNK_TOKEN) for w in words] + [STOP_TOKEN]
    tokens = jnp.array(list_tokens)
    # Remove leading zeros to avoid sequence issues
    tokens = jnp.trim_zeros(tokens, 'f')
    return tokens

def reverse_tokenize(tokens, reverse_map):
    words = map(lambda x: reverse_map[x], tokens)
    return " ".join(words)

def encode(tokens, token_map, max_seq_len):
    mat = np.zeros((max_seq_len, len(token_map)))
    for i in range(tokens.shape[0]):
        mat[i,tokens[i]] = 1
    for i in range(tokens.shape[0], max_seq_len):
        # Add padding token if sequence is too short
        mat[i,STOP_TOKEN] = 1
    return mat

train = [t for t in train if len(t) < MAX_INPUT_CHARS]
valid = [v for v in valid if len(v) < MAX_INPUT_CHARS]

token_map = compute_tokens(train + valid, min_count=max(2, len(train) / 1000))
reverse_token_map = {v:k for k,v in token_map.items()}

In [35]:
print(len(token_map))
print(len(train))

68
35508


In [36]:
import numpy as np
import math
import jax.numpy as jnp
from jax import grad, jit, vmap, tree_map, value_and_grad, debug
from statistics import mean

def init_params(layer_conf):
    layers = []
    for i in range(len(layer_conf)):
        if layer_conf[i]["type"] == "input":
            continue
        elif layer_conf[i]["type"] == "encoder":
            np.random.seed(0)
            k = 1/math.sqrt(layer_conf[i]["hidden"])
            input_weights = np.random.rand(layer_conf[i-1]["units"], layer_conf[i]["hidden"]) * 2 * k - k

            hidden_weights = np.random.rand(layer_conf[i]["hidden"], layer_conf[i]["hidden"]) * 2 * k - k
            hidden_bias = np.random.rand(1, layer_conf[i]["hidden"]) * 2 * k - k

            output_weights = np.random.rand(layer_conf[i]["hidden"], layer_conf[i]["output"]) * 2 * k - k
            output_bias = np.random.rand(1, layer_conf[i]["output"]) * 2 * k - k

            layers.append(
                [[input_weights], [hidden_weights, hidden_bias], [output_weights, output_bias]]
            )
    return layers

In [37]:
def encoder_fwd(params, prev_hidden, x_example):
    [i_weight], [h_weight, h_bias], [o_weight, o_bias] = params
    input_x = x_example @ i_weight
    hidden_x = input_x + prev_hidden @ h_weight + h_bias
    # Activation.  tanh avoids outputs getting larger and larger.
    hidden_x = jnp.tanh(hidden_x)

    # Output layer
    output_x = hidden_x @ o_weight + o_bias
    return hidden_x, output_x

In [38]:
def decoder_fwd(params, prev_hidden, context):
    [c_weight], [h_weight, h_bias], [o_weight, o_bias] = params
    input_x = context @ c_weight
    hidden_x = input_x + prev_hidden @ h_weight + h_bias
    # Activation.  tanh avoids outputs getting larger and larger.
    hidden_x = jnp.tanh(hidden_x)

    # Activation
    output_x = hidden_x @ o_weight + o_bias
    return hidden_x, output_x

In [39]:
def softmax(preds):
    tol = 1e-6
    preds = jnp.exp(preds - jnp.max(preds))
    return preds / (jnp.sum(preds) - tol)

def log_loss(encoder_params, prev_hidden, x, y):
    hidden, output = encoder_fwd(encoder_params, prev_hidden, x)
    output = softmax(output)
    tol = 1e-6
    cross_entropy = jnp.multiply(y, jnp.log(output + tol))
    loss = -jnp.sum(cross_entropy)
    return loss, {"hidden": hidden, "output": output}

In [40]:
def add_pytrees(pytree1, pytree2, x_ended):
  return tree_map(lambda pt1, pt2: pt1 + (pt2 * x_ended), pytree1, pytree2)

@jit
def update(params, x, y, lr):
    encoder_params = params[0]
    # Initialize with hidden shape
    prev_hidden = jnp.zeros((1, 512))
    losses = []
    total_grad = None
    for j in range(x.shape[0]):
        [loss, state], grad = value_and_grad(log_loss, has_aux=True)(encoder_params, prev_hidden, x[j,:], y[j,:])
        prev_hidden = state["hidden"]
        losses.append(loss)

        # If the sequence is over, we won't contribute to the gradient.  This happens when we detect a stop token.
        x_ended = (np.argmax(x[j,:]) != STOP_TOKEN).astype(int)
        if total_grad is None:
            total_grad = grad
        else:
            total_grad = add_pytrees(total_grad, grad, x_ended)

    params[0] = tree_map(
        lambda param, grad: param - lr * grad, encoder_params, total_grad
    )
    return params, losses

def batch_update(params, x, y, lr):
    return vmap(update, in_axes=[None, 0, 0, None])(params, x, y, lr)

In [41]:
tokenized = [tokenize(train[z], token_map) for z in range(len(train))]
max_seq_len = max([len(t) for t in tokenized])

train_x = np.stack([encode(tokenized[z], token_map, max_seq_len) for z in range(len(train))])
train_y = np.zeros(train_x.shape)
train_y[:,:-1,:] = train_x[:,1:,:]
train_y[:,-1,STOP_TOKEN] = 1

In [None]:
from tqdm.auto import tqdm
epochs = 100
lr = 1e-3
batch_size = 1

layer_conf = [
    {"type":"input", "units": len(token_map)},
    {"type": "encoder", "hidden": 512, "output": len(token_map)}
]
params = init_params(layer_conf)

In [46]:
for i in range(epochs):
    sequence_len = 7
    print(f"Epoch {i}")
    epoch_loss = []
    for j in tqdm(range(0, len(train), batch_size)):
        batch_inds = list(range(j, j+batch_size))
        seq_x = train_x[j,:,:]
        seq_y = train_y[j,:,:]
        #params, loss = batch_update(params, seq_x, seq_y, lr)
        params, loss = update(params, seq_x, seq_y, lr)
        epoch_loss.append(mean([float(jnp.mean(l)) for l in loss]))

    print(f"Epoch {i} loss: {mean(epoch_loss)}")

Epoch 0


100%|██████████| 35508/35508 [15:39<00:00, 37.78it/s]


Epoch 0 loss: 2.848861174764411
Epoch 1


100%|██████████| 35508/35508 [16:04<00:00, 36.81it/s]


Epoch 1 loss: 2.8117060742297233
Epoch 2


100%|██████████| 35508/35508 [15:40<00:00, 37.75it/s]


Epoch 2 loss: 2.793207994907865
Epoch 3


100%|██████████| 35508/35508 [15:57<00:00, 37.09it/s]


Epoch 3 loss: 2.8062659272911934
Epoch 4


  1%|▏         | 492/35508 [00:13<15:32, 37.55it/s]


KeyboardInterrupt: 

In [47]:
gen_length = 60
start_text = train[1]
tokenized = tokenize(start_text, token_map)
encoded = encode(tokenized, token_map, max_seq_len)

prev_hidden = jnp.zeros((1, 512))
out_seq = []
out_text = ""
for j in range(encoded.shape[0]):
    hidden, output = encoder_fwd(params[0], prev_hidden, encoded[j,:])
    prev_hidden = hidden
    output = softmax(output)
    out_ind = int(np.argmax(output))
    out_seq.append(out_ind)
    out_text += reverse_token_map[out_ind] + ""

In [48]:
out_text

'rosident \natt aays \nhry Mtownaitl bectee ooncer \ntd aitgowrmbsf toosi wptoet ry oes been artieec por sudteow auys \nwuow daaseng tf Tutoember 10 sotl be tercess d ty tirdd\natocgndat          '