In [1]:
# This notebook shows the EnhancedBrain's input, output, and state of mind all at once.
# I'll start with a rough version and make it cleaner over time.

In [3]:
from visual_transformer import *
from visual_transformer.enhanced_model import *

In [4]:
from IPython.display import display
from ipywidgets import widgets

import numpy as np
import matplotlib.pyplot as plt

import time


In [12]:
# Device that every input tensor in this notebook is moved to. Use the first CUDA GPU
# when one is available and fall back to the CPU otherwise, so the notebook still
# runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [5]:
def pre_imshow_numpy(torch_img, imshow=False):
    """Convert a torch image into a NumPy array laid out for ``plt.imshow``.

    The tensor is detached, copied to the CPU and its first axis is moved to the
    end (axes ``(1, 2, 0)``), i.e. a channel-first image becomes channel-last.

    Args:
        torch_img: 3-D image tensor. A 4-D tensor is treated as a batch and only
            its first image is converted (previously a 4-D input made the
            permutation raise).
        imshow: when True, also draw the array with ``plt.imshow``. The default
            is False because the callers in this notebook pass the returned
            array to ``plt.imshow`` themselves; drawing here as well would put
            the same image on the axes twice.

    Returns:
        numpy.ndarray: the permuted image.
    """
    clean = torch_img.detach().cpu()
    if clean.dim() == 4:
        # Batched input: keep only the first image so the 3-axis permute applies.
        clean = clean[0]
    array = torch.permute(clean, (1, 2, 0)).numpy()
    if imshow:
        plt.imshow(array)
    return array

In [6]:
# Shared, bordered output area: functions decorated with @output.capture() send
# their prints and figures here.
output = widgets.Output(layout=dict(border='1px solid black'))

In [7]:
# The brain instance that the wrappers in this notebook call, display and reset.
# NOTE(review): inputs are moved to `device` before being passed to the model, but the
# model itself is not explicitly moved anywhere in this notebook — confirm that
# EnhancedAgentBrain places its parameters on the matching device.
model = EnhancedAgentBrain()

In [9]:
# Supporting objects: the game that supplies image input and the text tokenizer.

from game import *

# Game environment; the size is set to 224 for compatibility with the size the
# brain expects. Note that this assigns onto the imported settings object itself.
game_settings = BIG_tool_use_advanced_2_5
game_settings.gameSize = 224
G = discreteGame(game_settings)

########

vocab_size = 10000
# The tokenizer is loaded from previously saved vocab/merges files. The call that
# saved an earlier version is kept for reference:
# tokenizer.save_model(".", "tokenizer/eng_sentences_tokenizer_vc10000")
tokenizer = ByteLevelBPETokenizer(
    "./text_pretraining_tokenizer/eng_sentences_tokenizer_vc10000_v2-vocab.json",
    "./text_pretraining_tokenizer/eng_sentences_tokenizer_vc10000_v2-merges.txt",
)
# Post-processor built from the "</s>" and "<s>" tokens paired with their ids.
tokenizer._tokenizer.post_processor = BertProcessing(
    *[(token, tokenizer.token_to_id(token)) for token in ("</s>", "<s>")]
)
# Encodings are cut to at most 32 ids; padding is switched on as well.
tokenizer.enable_truncation(max_length=32)
tokenizer.enable_padding()

In [10]:
# Text box holding the string that the wrappers tokenize and feed to the model;
# extend_wrapper also writes the extended string back into it.
T = widgets.Textarea(
    description='String:',
    placeholder='Type something',
    value='Hello World',
    disabled=False,
)

In [13]:
def get_image(game=None):
    """Return the current frame of a game as a batched tensor on `device`.

    Args:
        game: object exposing ``getData()``. Defaults to the notebook-level game
            ``G`` when omitted.

    Returns:
        torch.Tensor: the frame with a leading batch axis of size 1 and its last
        axis moved to position 1 (presumably H, W, C -> C, H, W — confirm against
        ``getData()``), made contiguous and moved to `device`.
    """
    if game is None:
        game = G
    # Read from the game that was asked for; reading from G unconditionally would
    # silently ignore the `game` argument.
    frame = torch.tensor(game.getData()).unsqueeze(0)
    frame = torch.permute(frame, (0, 3, 1, 2)).contiguous().to(device)
    return frame

In [15]:
# Image input override. While this is None, the wrappers take the image from the game
# via get_image(); assign a tensor here to feed a fixed image instead, for example:
# inp_tensor = get_image()
inp_tensor = None

In [28]:
# Scratch check of the tokenizer: encode a string that already contains literal
# "<s>" / "</s>" text and view the ids as a (1, seq_len) tensor on `device`.
tt = torch.tensor(tokenizer.encode("<s>Hello World!</s>").ids).unsqueeze(0).contiguous().to(device)
tt

tensor([[   0,   32,   87,   34, 5411, 4226,    5,   32,   19,   87,   34,    2]],
       device='cuda:0')

In [29]:
# Decode everything except the first and last id of the encoding above.
tokenizer.decode(tt[0][1:-1].cpu().numpy())

'<s>Hello World!</s>'

In [31]:
# Index of the literal end marker in the string (str.find returns -1 when it is
# absent); generate_wrapper uses the same test as its stop condition.
"<s>Hello World!</s>".find('</s>')

15

In [20]:
@output.capture()
def display_innards(b):
    """Print and plot each canvas currently held by the model.

    Output goes to the shared `output` widget. When the model's canvases object
    reports that it is empty, only a notice is printed.

    Args:
        b: not used by this function; the other wrappers pass their own `b`
            through when they call it.
    """
    print("Canvases:\n")
    canvases = model.canvases
    if canvases.is_empty():
        print("################\nCanvases object is empty, nothing to show\n################")
        return
    for idx in range(canvases.num_canvases):
        print(f"##########\nCanvas {idx}:\n")
        canvas_array = pre_imshow_numpy(canvases.tw[idx][:1])
        plt.imshow(canvas_array)
        plt.show()
            

In [23]:
@output.capture()
def forward_wrapper(b):
    """Run one forward pass of the model and show its image output and canvases.

    The output widget is cleared first. The image input is the global
    `inp_tensor` when it is set, otherwise the current game frame; the text
    input is the tokenized content of the text box `T`.

    Args:
        b: not used directly; forwarded to display_innards.
    """
    output.clear_output()
    if inp_tensor is not None:
        print("using global variable inp_tensor as image input\n\n")
        image_input = inp_tensor
    else:
        print("input tensor is None; using input from the game\n\n")
        image_input = get_image()
    token_ids = tokenizer.encode(T.value).ids
    text_input = torch.tensor(token_ids).unsqueeze(0).contiguous().to(device)
    # Only the second returned value (the image output) is used here.
    _unused, reconstruction = model(text_input, image_input, create_context=True, ret_images=True)
    print("output image:\n")
    first_image = pre_imshow_numpy(reconstruction[:1])
    plt.imshow(first_image)
    plt.show()
    display_innards(b)

In [32]:
# Wrapper that extends the input text by one predicted token, leaving the rest of
# what is displayed largely unchanged.

# Values passed to model.select (presumably a sampling temperature and a small
# epsilon for it — confirm against EnhancedAgentBrain.select).
temp=1.0
temp_eps = 1e-4
# NOTE(review): `ret_all` was referenced below without ever being defined. False is
# assumed because exactly one new token is appended per call — confirm the meaning of
# this argument against EnhancedAgentBrain.select.
ret_all = False

@output.capture()
def extend_wrapper(b):
    """Append one model-selected token to the text in `T` and show the results.

    Steps: clear the output widget, pick the image input (global `inp_tensor`
    if set, otherwise the current game frame), encode the text box content
    without its final token, run the model, select the next token from the
    returned logits, write the decoded, extended string back into `T`, then
    plot the output image and the canvases.

    Args:
        b: not used directly; forwarded to display_innards.
    """
    output.clear_output()
    if inp_tensor is None:
        print("input tensor is None; using input from the game\n\n")
        local_tensor = get_image()
    else:
        print("using global variable inp_tensor as image input\n\n")
        local_tensor = inp_tensor

    # in this case, we abridge the end-of-sentence token in order to continue the extension
    inp_ids = tokenizer.encode(T.value).ids[:-1]
    if len(inp_ids) > 32:
        inp_ids = inp_ids[-32:] # the rest should be in memory
    tt = torch.tensor(inp_ids).unsqueeze(0).contiguous().to(device)
    logits, recon = model(tt, local_tensor, return_full=False, create_context=True, ret_images=True)

    # Build the extended id tensor: the input ids followed by one new slot.
    # It must not be called `output`: assigning that name inside this function would
    # make it local and break the `output.clear_output()` call on the widget above.
    batch_size, seq_len = tt.size()
    extended = torch.zeros((batch_size, seq_len + 1), dtype=torch.long, device=tt.device)
    extended[:, :-1] += tt

    preds = model.select(logits, temp, ret_all, temp_eps)
    extended[:, -1] += preds

    T.value = tokenizer.decode(extended[0][1:].cpu().numpy()) # update the string, cut off start token
    print("updated input string (alse in text box):\n")
    print(T.value)

    print("output image:\n")
    plt.imshow(pre_imshow_numpy(recon[:1]))
    plt.show()
    display_innards(b)

In [33]:
@output.capture()
def soft_reset_wrapper(b):
    """Print a notice to the output widget and call ``model.soft_reset()``.

    Args:
        b: not used by this function.
    """
    print("soft reset (removing internal gradients)\n")
    model.soft_reset()

@output.capture()
def reset_wrapper(b):
    """Print a notice to the output widget and call ``model.reset()``.

    Args:
        b: not used by this function.
    """
    print("hard reset (clering memory and canvaases)\n")
    model.reset()

In [36]:
max_len = 32 # max len of the model input. Make this a text box or a selector or a knob.

@output.capture()
def generate_wrapper(b):
    """Extend the text in `T` one token at a time until a stop condition is met.

    Generation stops when the text contains the literal '</s>' marker, when the
    encoded text has reached `max_len` ids, or after `max_len` extension steps,
    whichever comes first.

    Args:
        b: not used directly; forwarded to extend_wrapper.
    """
    # The tokenizer in this notebook truncates encodings to 32 ids, so a length test
    # on `ids[:-1]` against 32 can never fail and the loop would only end if '</s>'
    # showed up in the text. Compare the full encoded length instead, and bound the
    # number of steps so the loop always terminates.
    for _ in range(max_len):
        if '</s>' in T.value:
            break
        if len(tokenizer.encode(T.value).ids) >= max_len:
            break
        extend_wrapper(b)