In [None]:
import os
import sys

# Guard the setup so that re-running this cell is harmless: without the guard
# every re-execution appends the arguments to sys.argv again and climbs one
# more directory level.
if not globals().get("_NOTEBOOK_SETUP_DONE", False):
    # Command-line style arguments for the project code.
    # Presumably consumed by main.initialize() further down — TODO confirm.
    sys.argv += "-log tb -name tst -reset 1 -state_size 64 -load_pretrained_model ../checkpoints/gru_64_with_input.pth -var_analysis.no_input 0".split(" ")

    # Pretend we are in the main directory
    os.chdir("..")

    _NOTEBOOK_SETUP_DONE = True

In [None]:
from typing import Tuple

import torch
import torch.nn.functional as F

from main import initialize

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
# Use 150 DPI both for figures rendered in the notebook and for saved figures.
plt.rcParams.update({'figure.dpi': 150, 'savefig.dpi': 150})

In [None]:
# Note: checkpoints have all arguments saved
# initialize() comes from the project's main module. It is called without
# explicit arguments and returns the (helper, task) pair that the cells below
# use (helper.device, task.model, task.train_set, task.set_eval).
# Presumably it picks up the sys.argv entries added in the first cell — TODO confirm.
helper, task = initialize()

In [None]:
def plot_states(bi, enc_states, dec_states, data, diff=False):
    """Plot the recurrent states of one batch element and their step-to-step change.

    Draws a figure with two stacked heat maps. The top one shows the encoder
    states followed by the decoder states; the bottom one shows the difference
    between every state and its predecessor (the first step is compared
    against zeros).

    The axes are addressed through the explicit Axes API instead of
    ``plt.axes(ax)``: passing an existing Axes to ``plt.axes`` was deprecated
    in Matplotlib 3.0.

    Args:
        bi: Index of the batch element to plot.
        enc_states: 4-D tensor of encoder states with time on dim 0 and batch
            on dim 2 (the layout produced by ``encode_with_state``).
        dec_states: Decoder states in the same layout as ``enc_states``.
        data: Mapping with the keys ``"in"``, ``"in_len"``, ``"out"`` and
            ``"out_len"``. The token tensors are indexed as ``[time, batch]``
            and the length tensors as ``[batch]``.
        diff: Unused. Kept so existing callers remain valid.
    """
    in_len = data["in_len"][bi]
    out_len = data["out_len"][bi]

    # Keep only the valid time steps and put the decoder trajectory after the encoder one.
    all_states = torch.cat([enc_states[:in_len], dec_states[:out_len]], dim=0)
    # (time, d1, batch, d3) -> (time, batch, d1 * d3)
    all_states = all_states.permute(0, 2, 1, 3).flatten(2)

    # Shift the trajectory one step along time (zero-filled first step) and subtract.
    diff_states = all_states - F.pad(all_states[:-1], (0, 0, 0, 0, 1, 0), value=0)

    in_tokens = task.train_set.in_vocabulary(data["in"][:in_len, bi].cpu().numpy().tolist())
    out_tokens = task.train_set.out_vocabulary(data["out"][:out_len, bi].cpu().numpy().tolist())

    # One tick per time step: the input tokens, then "S" followed by the output
    # tokens shifted right by one step. "S" presumably stands for the
    # start-of-sequence token fed to the decoder first — TODO confirm.
    tick_positions = list(range(len(in_tokens) + len(out_tokens)))
    tick_labels = in_tokens + ["S"] + out_tokens[:-1]

    fig, axs = plt.subplots(2)
    for ax, values in zip(axs, (all_states, diff_states)):
        ax.imshow(values[:, bi].T.cpu().numpy(), aspect="auto", cmap="viridis")
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)

    plt.show()

In [None]:
def encode_with_state(self, inp: torch.Tensor, in_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the encoder one time step at a time and record every recurrent state.

    Args:
        inp: Token indices. Dim 0 is iterated as the time axis, and the values
            are cast to ``long`` before the embedding lookup.
        in_len: Sequence lengths. They are reshaped to ``(1, -1, 1)`` so they
            broadcast against the recurrent state, which assumes the state is
            laid out as (layers, batch, hidden) — TODO confirm for ``self.rnn``.

    Returns:
        A ``(final_state, states)`` tuple. ``states`` stacks the per-step
        states along a new leading dimension, and ``final_state`` is its last
        entry.
    """
    x = self.embedding(inp.long())

    state = None
    states = []

    for i in range(x.shape[0]):
        _, new_state = self.rnn(x[i:i + 1], state)
        if state is None:
            # The first step has no previous state to fall back to.
            state = new_state
        else:
            # Sequences that have already ended (i >= in_len) keep their previous state.
            still_running = (i < in_len).view(1, -1, 1)
            state = torch.where(still_running, new_state, state)
        states.append(state)

    states = torch.stack(states, dim=0)
    return states[-1], states

def decode_with_state(self, encoded_state: torch.Tensor, outp: torch.Tensor, out_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the decoder step by step on given targets and record every state.

    The decoder input at each step is the target of the previous step, with
    ``self.sos_token`` fed at the first step.

    Args:
        encoded_state: Initial recurrent state for the decoder.
        outp: Target token indices with time on dim 0. When ``self.no_input``
            is truthy, every entry is replaced by ``self.no_input_token``
            before the shift.
        out_len: Unused. Kept so the signature mirrors ``encode_with_state``.

    Returns:
        A ``(output, states)`` tuple. ``output`` is ``self.fc`` applied to the
        per-step RNN outputs concatenated along dim 0, and ``states`` stacks
        the per-step recurrent states along a new leading dimension.
    """
    if self.no_input:
        outp = torch.full_like(outp, self.no_input_token)

    # Shift the targets one step along dim 0, dropping the last target and
    # putting the SOS token in front.
    x = F.pad(outp[:-1], (0, 0, 1, 0), value=self.sos_token)
    x = self.embedding(x.long())

    out_seq = []
    states = []
    state = encoded_state

    for i in range(x.shape[0]):
        out, state = self.rnn(x[i:i + 1], state)
        states.append(state)
        out_seq.append(out)

    out = torch.cat(out_seq, dim=0)
    states = torch.stack(states, dim=0)
    return self.fc(out), states

# Attach both helpers to the loaded model as bound methods, so they can be
# called as task.model.encode_with_state(...) / task.model.decode_with_state(...).
setattr(task.model, "encode_with_state", encode_with_state.__get__(task.model))
setattr(task.model, "decode_with_state", decode_with_state.__get__(task.model))

In [None]:
def create_input(input: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Turn an input string into model-ready tensors for a batch of one.

    Args:
        input: The input sequence as text. It is passed unchanged to
            ``task.train_set.in_vocabulary``; the call sites use
            space-separated tokens.

    Returns:
        An ``(inp, in_len)`` tuple on ``helper.device``. ``inp`` holds the
        vocabulary output with a batch dimension of size 1 added at dim 1, and
        ``in_len`` is a one-element tensor holding the sequence length.
    """
    token_ids = task.train_set.in_vocabulary(input)
    inp = torch.tensor(token_ids, device=helper.device).unsqueeze(1)
    in_len = torch.tensor([len(token_ids)], device=helper.device)

    return inp, in_len

In [None]:
def run_model(input: str):
    """Run the model on one input string, print its predictions and plot the states.

    The input sequence is also passed as the decoder target. Two lines are
    printed: the predicted tokens, and a 0/1 vector marking the positions where
    the prediction equals the input. The encoder and decoder states are then
    drawn with ``plot_states``.
    """
    tokens, n_tokens = create_input(input)

    task.set_eval()
    with torch.no_grad():
        final_enc_state, enc_states = task.model.encode_with_state(tokens, n_tokens)
        logits, dec_states = task.model.decode_with_state(final_enc_state, tokens, n_tokens)

        # Most likely token per step; dim 1 is the batch dimension of size 1.
        predicted = logits.argmax(dim=-1).squeeze(1)
        print(task.train_set.out_vocabulary(predicted.cpu().numpy()))

        matches = predicted == tokens.squeeze(1)
        print(matches.int().cpu().numpy())

        batch = {"in": tokens, "in_len": n_tokens, "out": tokens, "out_len": n_tokens}
        plot_states(0, enc_states, dec_states, batch)

In [None]:
# Four repeated "b" tokens followed by a single "c".
run_model("b b b b c")

In [None]:
# A mixed sequence of "a" and "b" tokens.
run_model("b a a b b b")


In [None]:
def plot_diff(i1: str, i2: str):
    """Plot the element-wise difference between the state trajectories of two inputs.

    Each input is encoded and then decoded with itself as the target. For each
    one, the predicted tokens and a boolean match vector are printed. The
    encoder and decoder states are concatenated over time, flattened per step
    and subtracted (first input minus second). The result is shown as a heat
    map with the colour range fixed to [-1, 1].

    The subtraction requires both inputs to produce trajectories of the same shape.
    """
    encoded_inputs = [create_input(text) for text in (i1, i2)]
    trajectories = []

    task.set_eval()
    with torch.no_grad():
        for tokens, n_tokens in encoded_inputs:
            final_enc_state, enc_states = task.model.encode_with_state(tokens, n_tokens)
            logits, dec_states = task.model.decode_with_state(final_enc_state, tokens, n_tokens)

            # Drop the size-1 batch dimension, then flatten everything but time.
            trajectory = torch.cat([enc_states, dec_states], dim=0).squeeze(-2)
            trajectories.append(trajectory.permute(0, 2, 1).flatten(1))

            predicted = logits.argmax(dim=-1).squeeze(1)
            print(task.train_set.out_vocabulary(predicted.cpu().numpy()))
            print((predicted == tokens.squeeze(1)).cpu().numpy())

    plt.figure()
    delta = trajectories[0] - trajectories[1]
    plt.imshow(delta.T.cpu().numpy(), aspect="auto", cmap="viridis", vmin=-1, vmax=1)

    # One "first/second" label per token pair: encoder steps first, then "S"
    # and the same labels shifted right by one step.
    pair_labels = [f"{a}/{b}" for a, b in zip(i1.split(), i2.split())]
    tick_text = pair_labels + ["S"] + pair_labels[:-1]
    plt.xticks(range(len(tick_text)), tick_text)
    plt.colorbar()

In [None]:
# The two inputs differ only in the third token ("a" vs "b").
plot_diff("a a a a a", "a a b a a")

In [None]:
# The two inputs differ in the third and fourth tokens.
plot_diff("a a a c a", "a a b a a")

In [None]:
# A constant sequence of five "a" tokens.
run_model("a a a a a")

In [None]:
# A constant sequence of five "b" tokens.
run_model("b b b b b")

In [None]:
# A palindromic sequence: "a b c d" followed by its mirror image "c b a".
run_model("a b c d c b a")

In [None]:
# The inputs differ in each of the first four tokens and share the trailing four "a" tokens.
plot_diff("a b c d a a a a", "c a d b a a a a")