# Training VanillaTCN on Shakespeare dataset for character-level prediction 

In [1]:
import time

import numpy as np
from tcn import VanillaTCN

In [30]:
# --- hyperparameters ---
# training
learning_rate = 0.001
rho_1, rho_2 = 0.9, 0.999  # Adam params
dropout_input_p_keep = 1
dropout_hidden_p_keep = 1

batch_size = 5
max_epochs = 100
early_stop_rel_tol = 0
data_size = 30  # limit data size for faster training, -1 for full data
# number of backprop steps = O(batch_size * data_size * epochs)

# model
copies = 2  # copies of kernel-dilation list, so final_depth = depth * copies, see below
depth = 3
kernel_size = 3
dilation_size = 2
hidden_size = 20
# num parameters = O(copies * depth * kernel_size * hidden_size^2)

# generation
seed_text = "To be, or not to be"
generation_length = 50

In [21]:
# --- data ---
file = "input.txt"
with open(file, "r") as f:
    text = f.read()
chars = ''.join(set(text))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
vocab_size = len(chars)
text = text[:data_size]

In [22]:
# --- initialize model params --- 
dilations = copies * [dilation_size**i for i in range(depth-1, -1, -1)]
kernel_sizes = copies * [kernel_size for _ in range(depth)]
hidden_sizes = [vocab_size] + [hidden_size for _ in range(copies * depth-1)]

After including Adam the model stopped learning very well. I thought this might be due to poor weight initialization so here I do a quick hyperparameter search over initial weight scales to see how it affects gradient magnitudes. 

In [23]:
# --- weight initialization hyperparameter search ---
minimum_weight_scale = 0.01
maximum_weight_scale = 100
runs = 20
scale_grad = {}

for run_idx in range(runs):
    weight_scale = np.exp(np.random.uniform(np.log(minimum_weight_scale), np.log(maximum_weight_scale)))
    
    # generate some data
    model = VanillaTCN(input_size=vocab_size, dilations=dilations,
                   kernel_sizes=kernel_sizes, hidden_sizes=hidden_sizes,
                   input_p_keep=dropout_input_p_keep, hidden_p_keep=dropout_hidden_p_keep)
    T_f = model.T_f
    inputs_batch = np.zeros((batch_size, vocab_size, T_f), dtype=np.float32)
    targets_batch = np.zeros((batch_size, vocab_size), dtype=np.float32)
    i = np.random.randint(T_f, len(text) - batch_size - 1)
    for b in range(batch_size):
        for t in range(T_f):
            inputs_batch[b, char_to_idx[text[i - T_f + t + b]], t] = 1
        targets_batch[b, char_to_idx[text[i + b]]] = 1
    
    # get gradients
    model.train_minibatch(inputs_batch, targets_batch,
                          rho_1=rho_1, rho_2=rho_2,
                          learning_rate=learning_rate)
    grads = model.dw
    scale_grad[weight_scale] = [(np.mean(np.abs(g)), np.std(np.abs(g))) for g in grads]
    print(weight_scale, [(np.mean(np.abs(g)), np.std(np.abs(g))) for g in grads])

0.026104827896618176 [(0.003359679, 0.008232633), (0.0044815876, 0.006521489), (0.0032658903, 0.0045701764), (0.003278421, 0.0033071951), (0.0034511718, 0.0028297934), (0.00061457267, 0.0016448857)]
0.05276683829540094 [(0.0038342564, 0.0078014946), (0.008722995, 0.010834441), (0.005951725, 0.007367826), (0.005474142, 0.005208902), (0.003949897, 0.003381496), (0.00067000854, 0.0017621431)]
1.6983074659463868 [(0.002761631, 0.008531206), (0.007836726, 0.010102911), (0.0043202764, 0.004953173), (0.0032497128, 0.0033437093), (0.0025253652, 0.0023581134), (0.0004771137, 0.0013385761)]
0.054060586007368336 [(0.003611302, 0.008964293), (0.006593332, 0.008332121), (0.005032525, 0.0062540867), (0.004380226, 0.0053018574), (0.0027632103, 0.0027429617), (0.00048393215, 0.0013281257)]
0.10393266587460642 [(0.0031713026, 0.008267327), (0.0068579405, 0.009661908), (0.005112388, 0.006214905), (0.003232685, 0.0037204595), (0.0023775585, 0.002088501), (0.00056070293, 0.0016528649)]
71.03761078350907 [

We see that the gradients (in particular note the standard deviation) are actually surprisingly insensitive to initial weight scale. So we just fix it to 0.1 for the rest of the training.

In [24]:
# model init
dilations = copies * [dilation_size**i for i in range(depth-1, -1, -1)]
kernel_sizes = copies * [kernel_size for _ in range(depth)]
hidden_sizes = [vocab_size] + [hidden_size for _ in range(copies * depth-1)]
model = VanillaTCN(input_size=vocab_size, dilations=dilations,
                   kernel_sizes=kernel_sizes, hidden_sizes=hidden_sizes,
                   input_p_keep=dropout_input_p_keep, hidden_p_keep=dropout_hidden_p_keep,
                   weight_scale=0.1)
T_f = model.T_f

# --- print info ---
print("Vocab size:", vocab_size)
print(f"proportion ensemble disconnected approx: {(1-dropout_hidden_p_keep)**sum(model.used_node_idx[2]):.0e}")
print("Receptive field:", T_f)
print("Parameters:", sum([w.size for w in model.weights]) + sum([b.size for b in model.biases]))
print(f"Expected number of total backprop evals: {batch_size * (len(text)-T_f) * max_epochs:.2e}")
# (depth-2)*(kernel_size*hidden_size**2+hidden_size)+2*kernel_size*hidden_size*vocab_size+vocab_size+hidden_size

Vocab size: 65
proportion ensemble disconnected approx: 0e+00
Receptive field: 29
Parameters: 12765
Expected number of total backprop evals: 5.50e+03


In [36]:
# --- training loop ---
t0 = time.time()
prev_loss = 1e6
end = len(text) - batch_size - 1
for e in range(max_epochs):
    for i in range(T_f, end+1):
        inputs_batch = np.zeros((batch_size, vocab_size, T_f), dtype=np.float32)
        targets_batch = np.zeros((batch_size, vocab_size), dtype=np.float32)
        for b in range(batch_size):
            for t in range(T_f):
                inputs_batch[b, char_to_idx[text[i - T_f + t + b]], t] = 1
            targets_batch[b, char_to_idx[text[i + b]]] = 1
        model.train_minibatch(inputs_batch, targets_batch,
                                  rho_1=rho_1, rho_2=rho_2,
                                  learning_rate=learning_rate)
    
    loss = model.train_minibatch(inputs_batch, targets_batch,
                                  return_loss=True)
    print(f"Epoch {e+1}/{max_epochs}, Loss (on final batch): {loss:.4f} "
          f"(i.e. avg. prob. assigned: {np.exp(-loss):.4f}) "
          f"time elapsed: {time.time() - t0:.0f}s eta: {(time.time() - t0)/(e+1)*(max_epochs - e - 1)/60:.1f}m")

    # --- generate text ---
    if len(seed_text) < T_f:
        test = " " * (T_f - len(seed_text)) + seed_text
    else:
        test = seed_text[-T_f:]
        print("Warning: seed_text length greater than receptive field size.")
    for i in range(generation_length):
        input_seq = np.zeros((vocab_size, T_f), dtype=np.float32)
        for t in range(T_f):
            input_seq[char_to_idx[test[-T_f + t]], t] = 1
        output_probs = model.forward_pass(input_seq, do_dropout=False)
        next_char_idx = np.random.choice(range(vocab_size), p=output_probs.ravel())
        test += idx_to_char[next_char_idx]

    print("Generated text:")
    print(test, '\n')

    # Stopped learning?
    if abs(prev_loss - loss)/max(loss, prev_loss, 1e-8) < early_stop_rel_tol:
        print("Converged, stopping training")
        break
    prev_loss = loss

29
30
31
32
33
34
Average gradient norm: 0.03242042288184166
Epoch 1/100, Loss (on final batch): 2.8847 (i.e. avg. prob. assigned: 0.0559) time elapsed: 0s eta: 0.4m
Generated text:
          To be, or not to be KC n dnd yyyn-nnnae yye ndyay an daeenedaaydyna a 

29
30
31
32
33
34
Average gradient norm: 0.03244110941886902
Epoch 2/100, Loss (on final batch): 2.8831 (i.e. avg. prob. assigned: 0.0560) time elapsed: 1s eta: 0.5m
Generated text:
          To be, or not to bennanaOf danK-dc  nendW y  se raaeQylyn'yeendfn yan 

29
30
31
32
33
34
Average gradient norm: 0.03248860687017441
Epoch 3/100, Loss (on final batch): 2.8816 (i.e. avg. prob. assigned: 0.0560) time elapsed: 1s eta: 0.6m
Generated text:
          To be, or not to benaynddd ayy:dgdy ay nyyn y fenddfynadyadnn e e Kny 

29
30
31
32
33
34
Average gradient norm: 0.032537173479795456
Epoch 4/100, Loss (on final batch): 2.8800 (i.e. avg. prob. assigned: 0.0561) time elapsed: 1s eta: 0.5m
Generated text:
          To be, or not t

KeyboardInterrupt: 