In [None]:
import os
import sys

# Command-line arguments for initialize() (presumably parsed from sys.argv by main).
# "-log tb" ensures no W&B logging will be performed.
NOTEBOOK_ARGS = "-log tb -name tst -task gru_repeat -reset 1 -state_size 64 -load_pretrained_model ../checkpoints/gru_64_with_input.pth -var_analysis.no_input 0"
# Alternative runs:
# NOTEBOOK_ARGS = "-log tb -name tst -task gru_repeat -reset 1 -state_size 64 -load_pretrained_model ../checkpoints/gru_64-70k.pth -var_analysis.no_input 1"
# NOTEBOOK_ARGS = "-log tb -name tst -task gru_repeat -reset 1 -state_size 64 -load_pretrained_model ../checkpoints/gru_no_r-100k.pth -var_analysis.no_input 0"
# NOTEBOOK_ARGS = "-log tb -name tst -task gru_repeat -reset 1 -state_size 64 -load_pretrained_model save/gru_digit_store/checkpoint/model-70000.pth -var_analysis.no_input 1"
# NOTEBOOK_ARGS = "-log tb -name tst -task gru_repeat -reset 1 -state_size 64 -load_pretrained_model ../checkpoints/gru_no_z-100k.pth -var_analysis.no_input 0"

# Remember the kernel's original argv and working directory the first time this cell
# runs, so re-running it neither appends the arguments twice nor climbs another level.
_ORIG_ARGV = globals().get("_ORIG_ARGV", list(sys.argv))
_NOTEBOOK_DIR = globals().get("_NOTEBOOK_DIR", os.getcwd())

# Mutate in place so any existing reference to sys.argv sees the same list.
sys.argv[:] = _ORIG_ARGV + NOTEBOOK_ARGS.split(" ")

# Pretend we are in the main directory (the parent of the notebook's directory).
os.chdir(os.path.join(_NOTEBOOK_DIR, ".."))

In [None]:
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from main import initialize

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
plt.rcParams['figure.dpi'] = 200
plt.rcParams['savefig.dpi'] = 200

In [None]:
# Note: checkpoints have all arguments saved
# Build the helper and task from the arguments set up in the first cell (presumably
# read from sys.argv by initialize()).
helper, task = initialize()
# Presumably prepares task.train_loader, which the following cells iterate — TODO confirm.
task.create_data_fetcher()

In [None]:
# Iterator over the training loader; the next cell pulls one batch from it.
diter = iter(task.train_loader)

In [None]:
# Pull the next training batch and run it through the task's own preprocessing.
data = task.prepare_data(next(diter))

In [None]:
def get_overwrite(force, var, i, newval):
    """Pick the value of gate ``var`` at step ``i``.

    Returns ``force[var][i]`` when ``var`` is a key of ``force``; otherwise returns
    ``newval`` unchanged.
    """
    if var not in force:
        return newval
    return force[var][i]

def gru_run(model: torch.nn.GRU, x: torch.Tensor, h: Optional[torch.Tensor], lengths: torch.Tensor,
            force=None):
    """Unroll a single-layer ``torch.nn.GRU`` step by step, exposing its gates.

    Re-implements the GRU cell with the module's own weights:
        r  = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
        z  = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
        n  = tanh(W_in x + b_in + r * (W_hn h + b_hn))
        h' = (1 - z) * n + z * h
    so every gate can be recorded or overwritten.

    Args:
        model: single-layer GRU whose weights are used (with or without biases).
        x: input indexed as ``x[t]`` per time step; ``x.shape[1]`` is taken as the batch size.
        h: initial hidden state, or None for zeros of shape (batch, hidden).
        lengths: per-sequence lengths; at step ``t`` every sequence with ``t >= length``
            gets z = 1, so its hidden state is carried over unchanged.
        force: optional dict mapping "r", "z" and/or "n" to something indexable by time
            step; ``force[k][t]`` replaces the computed gate at step ``t``. A forced z is
            still subject to the length masking above.

    Returns:
        (hs, rs, zs, ns): hidden states and r/z/n gate values, each stacked over time on dim 0.
    """
    assert model.num_layers == 1
    if force is None:
        force = {}

    def pick(name, step, computed):
        # Use the forced value for this gate/step if one was supplied.
        return force[name][step] if name in force else computed

    w_ir, w_iz, w_in = model.weight_ih_l0.chunk(3, dim=0)
    w_hr, w_hz, w_hn = model.weight_hh_l0.chunk(3, dim=0)
    if model.bias:
        b_ir, b_iz, b_in = model.bias_ih_l0.chunk(3, dim=0)
        b_hr, b_hz, b_hn = model.bias_hh_l0.chunk(3, dim=0)
    else:
        # nn.GRU(bias=False) has no bias_* parameters; F.linear accepts bias=None.
        b_ir = b_iz = b_in = b_hr = b_hz = b_hn = None

    if h is None:
        h = torch.zeros(x.shape[1], w_ir.shape[0], device=x.device, dtype=x.dtype)

    hs, rs, zs, ns = [], [], [], []
    for t in range(x.shape[0]):
        x_t = x[t]
        r = torch.sigmoid(F.linear(x_t, w_ir, b_ir) + F.linear(h, w_hr, b_hr))
        z = torch.sigmoid(F.linear(x_t, w_iz, b_iz) + F.linear(h, w_hz, b_hz))
        r = pick("r", t, r)
        z = pick("z", t, z)
        # Past the end of a sequence z = 1, so h' = 0 * n + 1 * h keeps the state frozen.
        z = z.masked_fill((t >= lengths)[..., None], 1.0)
        n = torch.tanh(F.linear(x_t, w_in, b_in) + r * F.linear(h, w_hn, b_hn))
        n = pick("n", t, n)
        h = (1 - z) * n + z * h

        hs.append(h)
        rs.append(r)
        zs.append(z)
        ns.append(n)

    return torch.stack(hs, dim=0), torch.stack(rs, dim=0), torch.stack(zs, dim=0), torch.stack(ns, dim=0)


def encode_with_state(self, inp: torch.Tensor, in_len: torch.Tensor,
                      force=None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """Run the model's encoder through ``gru_run`` so its gate activations are returned.

    Meant to be bound as a method of the model (uses ``self.embedding`` and ``self.rnn``).

    Args:
        inp: input token ids; cast to long before embedding.
        in_len: per-sequence input lengths (steps beyond a length leave the state unchanged).
        force: optional gate overrides forwarded to ``gru_run``.

    Returns:
        (final_state, (states, rs, zs, ns)). ``final_state`` is the hidden state after the
        last time step, which, thanks to the length masking, equals each sequence's state
        after its own last token.
    """
    embedded = self.embedding(inp.long())
    # The encoder starts from the zero state gru_run builds when h is None.
    states, rs, zs, ns = gru_run(self.rnn, embedded, None, in_len, {} if force is None else force)
    return states[-1], (states, rs, zs, ns)

def decode_with_state(self, encoded_state: torch.Tensor, outp: torch.Tensor, out_len: torch.Tensor,
                      force=None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """Run the model's decoder through ``gru_run`` so its gate activations are returned.

    Meant to be bound as a method of the model (uses ``self.embedding``, ``self.rnn``,
    ``self.fc``, ``self.no_input``, ``self.no_input_token`` and ``self.sos_token``).

    Args:
        encoded_state: initial decoder hidden state (the encoder's final state).
        outp: target token ids; assumed time-major (time, batch) — TODO confirm.
        out_len: per-sequence target lengths (steps beyond a length leave the state unchanged).
        force: optional gate overrides forwarded to ``gru_run``.

    Returns:
        (self.fc(states), (states, rs, zs, ns)).
    """
    if self.no_input:
        # Hide the targets: the decoder then only sees SOS followed by no-input tokens.
        outp = torch.full_like(outp, self.no_input_token)

    # Teacher forcing: drop the last step and prepend SOS on dim 0 (for a 2-D outp),
    # so step t is fed outp[t - 1].
    dec_in = F.pad(outp[:-1], (0, 0, 1, 0), value=self.sos_token)
    embedded = self.embedding(dec_in.long())

    states, rs, zs, ns = gru_run(self.rnn, embedded, encoded_state, out_len, {} if force is None else force)
    return self.fc(states), (states, rs, zs, ns)

# Bind the gate-exposing functions as methods of the loaded model, so `self` refers
# to task.model and they can use its embedding / rnn / fc submodules.
task.model.encode_with_state = encode_with_state.__get__(task.model)
task.model.decode_with_state = decode_with_state.__get__(task.model)

# Handle on the reference gru_run, presumably so it can be restored after experimenting
# with a redefined gru_run (not used in the cells shown here).
gru_run_orig = gru_run

In [None]:
def run_model(data, enc_force=None, dec_force=None):
    """Encode ``data["in"]`` and decode ``data["out"]`` in eval mode, without gradients.

    Args:
        data: batch dict with "in", "in_len", "out" and "out_len" entries.
        enc_force: optional gate overrides for the encoder (see ``gru_run``).
        dec_force: optional gate overrides for the decoder (see ``gru_run``).

    Returns:
        (out, (enc_states, dec_states)): decoder outputs plus the (states, rs, zs, ns)
        tuples recorded for the encoder and the decoder.
    """
    # None instead of a shared mutable {} default; the model methods receive a dict.
    enc_force = {} if enc_force is None else enc_force
    dec_force = {} if dec_force is None else dec_force

    task.set_eval()
    with torch.no_grad():
        state, enc_states = task.model.encode_with_state(data["in"], data["in_len"], force=enc_force)
        out, dec_states = task.model.decode_with_state(state, data["out"], data["out_len"], force=dec_force)

    return out, (enc_states, dec_states)

In [None]:
# Run the model on the batch loaded above, with no gates forced.
out, states = run_model(data)

In [None]:
# Verify if the network is repeating correctly: a sequence counts as correct when the
# argmax prediction matches the target at every position before its length.
predicted = out.argmax(dim=-1)
positions = torch.arange(data["out"].shape[0], device=out.device)
is_padding = positions[:, None] >= data["out_len"][None]
ok_mask = (data["out"] == predicted) | is_padding
seq_ok = ok_mask.all(dim=0)
seq_ok.float().mean()

In [None]:
def plot_states(data, state, bi=0):
    """Plot h, z, r and n over time for one sequence of a batch as a 2x2 grid of heatmaps.

    Encoder and decoder steps are concatenated along the x axis; padding is dropped.

    Args:
        data: batch dict with "in", "in_len", "out", "out_len", indexed as [time, batch].
        state: (enc_states, dec_states) as returned by ``run_model``; each is (states, rs, zs, ns).
        bi: index of the sequence within the batch.

    Returns:
        The matplotlib Figure (left open so callers can lay it out and save it).
    """
    enc_states, dec_states = state
    in_len = data["in_len"][bi]
    out_len = data["out_len"][bi]

    # Trim each recorded tensor to the sequence's real length, then join encoder + decoder.
    states, rs, zs, ns = (
        torch.cat([enc[:in_len, bi], dec[:out_len, bi]], dim=0)
        for enc, dec in zip(enc_states, dec_states)
    )

    in_tokens = task.train_set.in_vocabulary(data["in"][:in_len, bi].cpu().numpy().tolist())
    out_tokens = task.train_set.out_vocabulary(data["out"][:out_len, bi].cpu().numpy().tolist())

    # Label each column with the token fed at that step: decoder step t receives out[t-1],
    # with "S" (SOS) first. (With no_input the decoder actually sees no-input tokens.)
    tick_pos = list(range(len(in_tokens) + len(out_tokens)))
    tick_labels = in_tokens + ["S"] + out_tokens[:-1]

    fig, axs = plt.subplots(2, 2, figsize=(10, 10))
    # (axes, title, values, vmin): tanh-valued h and n span [-1, 1], sigmoid gates [0, 1].
    panels = [
        (axs[0, 0], "h[t]", states, -1),
        (axs[0, 1], "z[t]", zs, 0),
        (axs[1, 0], "rs[t]", rs, 0),
        (axs[1, 1], "ns[t]", ns, -1),
    ]
    for ax, title, values, vmin in panels:
        ax.set_title(title)
        ax.imshow(values.T.cpu().numpy(), aspect="auto", cmap="viridis", vmin=vmin, vmax=1)
        ax.set_xticks(tick_pos)
        ax.set_xticklabels(tick_labels)

    return fig

In [None]:
# Trailing ";" keeps the returned Figure from being displayed a second time
# (the inline backend already renders figures created in this cell).
plot_states(data, states, 1);


In [None]:
# ";" avoids a duplicate display of the returned Figure.
plot_states(data, states, 2);

In [None]:
# ";" avoids a duplicate display of the returned Figure.
plot_states(data, states, 3);

In [None]:
def create_input(input: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Turn a token string into a batch of size one.

    Returns:
        (inp, in_len): ``inp`` holds the token ids with a trailing batch dimension of 1
        (shape (len, 1), assuming ``in_vocabulary`` returns a flat list of ids — TODO
        confirm); ``in_len`` is a one-element tensor with the number of tokens.
    """
    token_ids = task.train_set.in_vocabulary(input)
    inp = torch.tensor(token_ids, device=helper.device).unsqueeze(1)
    in_len = torch.tensor([len(token_ids)], device=helper.device)
    return inp, in_len

In [None]:
def run_model_on_str(input: str, enc_force=None, dec_force=None):
    """Feed a single token string through the model and visualise its gates.

    The target sequence is the input itself. Prints the predicted tokens and a 0/1
    per-position correctness vector, then plots all gates.

    Args:
        input: token string understood by ``task.train_set.in_vocabulary``.
        enc_force: optional encoder gate overrides (see ``gru_run``).
        dec_force: optional decoder gate overrides (see ``gru_run``).

    Returns:
        (fig, (data, states)): the gate figure and the inputs needed to re-plot it.
    """
    # None instead of a shared mutable {} default.
    enc_force = {} if enc_force is None else enc_force
    dec_force = {} if dec_force is None else dec_force

    inp, in_len = create_input(input)
    data = {"in": inp, "in_len": in_len, "out": inp, "out_len": in_len}
    out, states = run_model(data, enc_force, dec_force)

    predicted = out.argmax(dim=-1).squeeze(1)
    print(task.train_set.out_vocabulary(predicted.cpu().numpy()))
    # NOTE(review): compares output-vocabulary ids with input-vocabulary ids; assumes the
    # two vocabularies share indices — confirm.
    print((predicted == inp.squeeze(1)).int().cpu().numpy())

    return plot_states(data, states), (data, states)

In [None]:
# ";" hides the repr of the returned (fig, (data, states)) tuple, which would dump
# whole tensors; the figure is still rendered by the inline backend.
run_model_on_str("a b 1 c a d f a");

In [None]:
# ";" hides the repr of the returned (fig, (data, states)) tuple; the figure still renders.
run_model_on_str("4 c a f g 6 1 3");

In [None]:
# Save gate plots for a few random inputs to states/<i>.png.
# Tokens are drawn uniformly from the whole input vocabulary (including any special
# tokens it may contain — TODO confirm this is intended). A seeded local generator
# makes the sequences reproducible without touching the global torch RNG.
N_RANDOM_PLOTS = 10
RANDOM_SEQ_LEN = 8
PLOT_SEED = 0

os.makedirs("states", exist_ok=True)
plot_rng = torch.Generator().manual_seed(PLOT_SEED)
for i in range(N_RANDOM_PLOTS):
    seq = torch.randint(0, len(task.train_set.in_vocabulary), (RANDOM_SEQ_LEN,), generator=plot_rng)
    inp = " ".join(task.train_set.in_vocabulary(seq.cpu().numpy()))
    fig, _ = run_model_on_str(inp)
    fig.tight_layout()
    fig.savefig(f"states/{i}.png")
    # Close each figure so the loop does not accumulate open figures / inline output.
    plt.close(fig)

In [None]:
# Keep (data, states) of this run in s1 for later inspection / plotting.
_, s1 = run_model_on_str("a b 1 c a d f a")

In [None]:
# Gate-transplant experiment: run "4 c a f g 6 1 3" with its encoder update gates z
# forced to the ones recorded in s1 (s1[1] = states, [0] = encoder, [2] = zs). The s1
# input of the previous cell also has 8 tokens, so force["z"][t] exists for every step.
# NOTE(review): the override value here was originally missing (`{"z":}` does not parse);
# this transplant is one plausible choice — adjust as needed.
run_model_on_str("4 c a f g 6 1 3", {"z": s1[1][0][2]});

In [None]:
# A 7-token input; overwrites s1.
_, s1 = run_model_on_str("a c f d 1 g 3")


In [None]:
# Seven copies of the same token; overwrites s1.
_, s1 = run_model_on_str("a a a a a a a")


In [None]:
# "a b 1 c a d f" followed by 2; s1 from this run is the one plotted by plot_gate below.
_, s1 = run_model_on_str("a b 1 c a d f 2")

In [None]:
def plot_gate(data, state, bi=0):
    """Plot the update gate over time for one sequence as a single heatmap with a colorbar.

    Args:
        data: batch dict with "in", "in_len", "out", "out_len", indexed as [time, batch].
        state: (enc_states, dec_states) as returned by ``run_model``; each is (states, rs, zs, ns).
        bi: index of the sequence within the batch.

    Returns:
        The matplotlib Figure (left open so callers can lay it out and save it).
    """
    enc_states, dec_states = state
    in_len = data["in_len"][bi]
    out_len = data["out_len"][bi]

    # Only z (index 2 of each (states, rs, zs, ns) tuple) is plotted here.
    _, _, enc_z, _ = enc_states
    _, _, dec_z, _ = dec_states
    zs = torch.cat([enc_z[:in_len, bi], dec_z[:out_len, bi]], dim=0)

    in_tokens = task.train_set.in_vocabulary(data["in"][:in_len, bi].cpu().numpy().tolist())
    out_tokens = task.train_set.out_vocabulary(data["out"][:out_len, bi].cpu().numpy().tolist())
    # Decoder step t is labelled with the token it is fed: "S" (SOS) first, then out[t-1].
    tick_pos = list(range(len(in_tokens) + len(out_tokens)))
    tick_labels = in_tokens + ["S"] + out_tokens[:-1]

    fig, ax = plt.subplots(1, 1, figsize=[3.2, 3])
    ax.set_title("$z_t$")
    # Shows 1 - z: in gru_run h' = (1 - z) * n + z * h, so 1 - z is the weight on the new
    # candidate n. The "$z_t$" title presumably follows a convention where z_t scales
    # the update — confirm against the write-up this figure is for.
    im = ax.imshow(1 - zs.T.cpu().numpy(), aspect="auto", cmap="viridis", vmin=0, vmax=1)
    ax.set_xticks(tick_pos)
    ax.set_xticklabels(tick_labels)
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("$i_t$")
    ax.set_ylabel("$z_t [j]$")
    ax.set_yticks([])

    return fig

In [None]:
# s1 holds (data, states) of the most recent run_model_on_str call that stored it;
# plot its update gate and save it as a vector figure.
gate_data, gate_states = s1
fig = plot_gate(gate_data, gate_states)
fig.tight_layout()
fig.savefig("gate.pdf", bbox_inches="tight")