In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load the precomputed probability table from the current working directory.
# NOTE(review): judging by the file name this is the n-gram model's output;
# the cell that produced the file is not part of this notebook view.
probs = np.load("ngram_probs.npy")
# Bare last expression: the notebook echoes the array's shape as cell output.
probs.shape

In [None]:
# The visualisation below is written for a 4-D table with 27 entries per axis;
# fail early (and say what was found) if the loaded file looks different.
assert probs.shape == (27, 27, 27, 27), f"unexpected probs shape: {probs.shape}"

# Collapse axes (0, 1) into rows and axes (2, 3) into columns so the whole
# table can be drawn as a single 729 x 729 heat map.
reshaped = probs.reshape(27**2, 27**2)

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(reshaped, cmap='hot', interpolation='nearest')
# The trailing ';' keeps the axis-limits tuple returned by `axis` out of the
# cell output, so only the figure is shown.
ax.axis('off');

In [12]:
import imageio
import os
import IPython
import warnings

# Suppress all warnings
# NOTE(review): with no category/module argument this silences every warning
# for the rest of the kernel session, not just for this cell. Presumably it is
# here to hide deprecation noise from the plotting / image-reading calls
# below -- confirm, and consider narrowing the filter.
warnings.filterwarnings('ignore')
def random_u32(state):
    """Advance the xorshift* generator and return the next 32-bit integer.

    Reference: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A

    Args:
        state: one-element mutable sequence holding the generator state as an
            int; ``state[0]`` is overwritten with the new state.

    Returns:
        An int in ``[0, 2**32)``.
    """
    mask64 = 0xFFFFFFFFFFFFFFFF  # plays the role of a C cast to uint64
    mask32 = 0xFFFFFFFF          # plays the role of a C cast to uint32

    x = state[0]
    x ^= (x >> 12) & mask64
    x ^= (x << 25) & mask64
    x ^= (x >> 27) & mask64
    state[0] = x

    # Scramble with the xorshift* multiplier and keep bits 32..63 of the product.
    return ((x * 0x2545F4914F6CDD1D) >> 32) & mask32

def random_f32(state):
    """Return a pseudo-random float in [0, 1), advancing ``state`` in place.

    Only the upper 24 bits of the 32-bit draw are used, so the result is a
    multiple of 2**-24.
    """
    top_24_bits = random_u32(state) >> 8
    return top_24_bits / 16777216.0  # 16777216 == 2**24

def sample_discrete(probs_, coinf):
    """Pick an index from a discrete distribution using a single coin flip.

    Walks the running sum of ``probs_`` and returns the first index at which
    that sum exceeds ``coinf``.

    Args:
        probs_: sized iterable of probabilities, one per outcome.
        coinf: the random draw to locate, normally in [0, 1).

    Returns:
        The selected index. If the running sum never exceeds ``coinf`` the
        last index, ``len(probs_) - 1``, is returned.
    """
    running_total = 0.0
    for index, weight in enumerate(probs_):
        running_total += weight
        if running_total > coinf:
            return index
    # Fallback in case of rounding errors: the sum fell short of coinf.
    last_index = len(probs_) - 1
    return last_index

# --- Vocabulary --------------------------------------------------------------
# The training text is read only to recover the character vocabulary; the
# `with` block makes sure the file handle is closed.
with open('../data/train.txt', 'r') as train_file:
    train_text = train_file.read()
assert all(c == '\n' or ('a' <= c <= 'z') for c in train_text)
uchars = sorted(set(train_text))  # unique characters we see in the input
char_to_token = {c: i for i, c in enumerate(uchars)}
token_to_char = {i: c for i, c in enumerate(uchars)}
# One bar is drawn per vocabulary entry, so the two sizes have to agree.
assert len(uchars) == probs.shape[-1], "vocabulary size does not match probs"

# --- Sampling configuration ---------------------------------------------------
seq_len = 4
assert probs.ndim == seq_len, "probs needs one axis per n-gram position"
rng_state = [1337]
iterations = 20
gif_path = 'iterations.gif'
visual_images = './visual_images/'

# Context window of the last `seq_len - 1` tokens. It starts at token ids
# (0, ..., 0, last id), i.e. the distribution probs[0][0][-1] for seq_len == 4,
# and from then on slides over the tokens that get sampled, so every step is
# conditioned on what was generated before it.
tape = [0] * (seq_len - 2) + [probs.shape[-2] - 1]

y_pos = np.arange(len(uchars))


def _save_distribution_frame(probs_, title, filename, selected=None):
    """Draw `probs_` as a bar chart, save it as a PNG and return the path.

    Args:
        probs_: 1-D probabilities, one per entry of `uchars`.
        title: figure title.
        filename: file name (no directory) to create inside `visual_images`.
        selected: optional ``(token, char)`` pair; when given, the chosen bar
            is marked with a dashed line and a text label.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(y_pos, probs_, align='center', alpha=0.5)
    ax.set_xticks(y_pos)
    ax.set_xticklabels(uchars)
    ax.set_ylabel('Probability')
    ax.set_title(title)
    if selected is not None:
        token, char = selected
        ax.axvline(x=token, color='r', linestyle='--', linewidth=1, label=f'Selected: {char}')
        ax.legend()
        ax.text(token, np.max(probs_) * 0.8, f'Selected: {char}', ha='center', va='bottom', color='blue', fontsize=12)
    path = os.path.join(visual_images, filename)
    fig.savefig(path)
    plt.close(fig)  # figures are created in a loop; release each one
    return path


os.makedirs(visual_images, exist_ok=True)
images = []       # decoded frames, in display order
gen_chars = []    # characters sampled so far
frame_paths = []  # temporary PNGs written by this run (removed afterwards)

try:
    for iter_idx in range(iterations):
        # Next-token distribution for the current context.
        probs_ = probs[tuple(tape)]

        # Frame 1: the distribution before sampling.
        img_path = _save_distribution_frame(
            probs_,
            f'Iteration {iter_idx + 1}: Probability Distribution of Characters',
            f'iteration_{iter_idx + 1}.png',
        )
        frame_paths.append(img_path)
        images.append(imageio.imread(img_path))

        # Sample the next token
        coinf = random_f32(rng_state)
        next_token = sample_discrete(probs_, coinf)
        next_char = token_to_char[next_token]

        # Update tape: append the new token and keep only the newest
        # `seq_len - 1` entries.
        tape.append(next_token)
        if len(tape) > seq_len - 1:
            tape = tape[1:]
        gen_chars.append(next_char)

        # Frame 2: the same distribution with the sampled character marked.
        generated = "".join(gen_chars)
        img_path = _save_distribution_frame(
            probs_,
            f'Iteration {iter_idx + 1}: Selection of Characters: "{generated}"',
            f'iteration_selection_{iter_idx + 1}.png',
            selected=(next_token, next_char),
        )
        frame_paths.append(img_path)
        images.append(imageio.imread(img_path))

    # Save images as a GIF.
    # NOTE(review): recent imageio releases interpret a GIF `duration` in
    # milliseconds (1000 -> one second per frame); older releases used
    # seconds -- confirm against the installed version.
    imageio.mimsave(gif_path, images, duration=1000)
finally:
    # Clean up only the temporary frames written above, even if something
    # failed part-way; unrelated files in the directory are left alone.
    for img_path in frame_paths:
        if os.path.exists(img_path):
            os.remove(img_path)
    try:
        os.rmdir(visual_images)
    except OSError:
        # The directory still holds files this run did not create; keep it.
        pass

# Clear any intermediate output first so the confirmation stays visible.
IPython.display.clear_output()
print(f'Animated GIF saved as {gif_path}')
