# Instructions
**This notebook contains two custom finetuned models called 2.7B-horni and 2.7B-horni-ln which are NOT official EleutherAI models or affiliated with them in any way beyond using their model as a basis. You can also select the original EleutherAI/gpt-neo-2.7B model in the model selection's dropdown menu.**

Go through each cell in this notebook one by one: read the options and descriptions, then press the play button to the left of the cell. You can skip the cell marked optional; don't skip any of the others. After running the "Play" cell, a small form appears underneath it, which you use to actually play. Check out this [guide](https://gitlab.com/nolialsea/novelaicontributions/-/blob/master/Guides/Using%20finetuneanon%27s%20Colab.md) for details.

To reset the state of your game, run the "Setup" cell again. Closing the notebook loses your progress, so if you want to keep your story, use the "history" action and copy your story into a text editor. You can also copy your author's note and memory from the output of the "info" action.

The most reliable way of loading the models is to store them in your Google Drive. In the model setup step, this notebook automatically downloads the selected model from Google Drive. If that succeeds, you can copy the file into your own drive in the optional step that follows. You can also download the files yourself and upload them to your drive.

* [2.7B-horni](https://mega.nz/file/6BNykLJb#B6gxK3TnCKBpeOF1DJMXwaLc_gcTcqMS0Lhzr1SeJmc) [[Google Drive](https://drive.google.com/file/d/1-Jj_hlyNCQxuSnK7FFBXREGnRSMI5MoF/view?usp=sharing)] 5GB, for NSFW styled output
* [2.7B-horni-ln](https://mega.nz/file/rQcWCTZR#tCx3Ztf_PMe6OtfgI95KweFT5fFTcMm7Nx9Jly_0wpg) [[Google Drive](https://drive.google.com/file/d/1M1JY459RBIgLghtWDRDXlD4Z5DAjjMwg/view?usp=sharing)] 5GB, for light novel styled output
* Torrent: magnet:?xt=urn:btih:31d956ff4a248dcf914b1b7e474cbac02d70d6a4&dn=gtp-neo-horni&tr=http%3A%2F%2Fopenbittorrent.com%3A80%2Fannounce

A tip from Anon about Google Drive: if you run into download quotas, create a folder in your own Google Drive, add a shortcut to the file in question to that folder, and then download the folder itself instead of the file.

If you have a GTX card with 8GB of VRAM or more, you may be able to run this notebook locally. If you integrate the models into Clover Edition, use the transformers fork installed in the setup cell of this notebook to avoid OOM errors, and use model.generate() so that repetition_penalty_range and repetition_penalty_slope are available. For running locally, remove the map_location arguments from the torch.load calls.
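The repetition_penalty_range setting limits the repetition penalty to the most recent tokens of the context. Below is a minimal pure-Python sketch of that idea, assuming the usual CTRL-style rule that Hugging Face's `RepetitionPenaltyLogitsProcessor` uses (positive logits are divided by the penalty, negative ones multiplied by it). `apply_repetition_penalty` is a hypothetical helper, not the fork's code, and it deliberately does not model repetition_penalty_slope.

```python
from typing import List, Sequence


def apply_repetition_penalty(
    logits: Sequence[float],
    context_ids: Sequence[int],
    penalty: float,
    penalty_range: int,
) -> List[float]:
    """Penalize every token id that occurs in the last `penalty_range` context tokens.

    Sketch only: repetition_penalty_slope is not modeled.
    """
    # A range of 0 (or less) means "penalize over the whole context".
    window = context_ids[-penalty_range:] if penalty_range > 0 else context_ids
    out = list(logits)
    for token_id in set(window):
        score = out[token_id]
        # Multiplying a negative logit or dividing a positive one both make the token less likely.
        out[token_id] = score * penalty if score < 0 else score / penalty
    return out


# Token 3 lies outside the 2-token window, so only tokens 0 and 1 are penalized.
print(apply_repetition_penalty([2.0, -1.0, 0.5, 3.0], [3, 0, 1], penalty=2.0, penalty_range=2))
# [1.0, -2.0, 0.5, 3.0]
```

A larger range catches repetitions further back in the story at the cost of also discouraging words that legitimately recur.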


In [None]:
#@title Setup
#@markdown Run this for setting up dependencies or resetting actions
!pip install git+https://github.com/finetuneanon/transformers@gpt-neo-dungeon-localattention2
#!wget -c http://ftp.us.debian.org/debian/pool/main/m/megatools/megatools_1.11.0~git20200404-1_amd64.deb -O megatools.deb
#!dpkg -i megatools.deb
!pip install gdown
!nvidia-smi

import os

from transformers import GPTNeoForCausalLM, AutoTokenizer
import tarfile
import codecs
import torch
import subprocess
import copy

from IPython.display import HTML, display
import ipywidgets as widgets

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

try:
  initialized += 1
except NameError:  # first run: define the default settings
  get_ipython().events.register('pre_run_cell', set_css)
  tail_free_sampling, top_k, top_p, temperature, number_generated_tokens, repetition_penalty, repetition_penalty_range, repetition_penalty_slope, number_show_last_actions = 0.95, 60, 0.9, 0.8, 40, 1.25, 300, 3.33, 15
  prevent_square_brackets, prevent_angle_brackets, prevent_curly_brackets = True, True, True
  enable_top_k, enable_top_p, enable_tfs = False, False, True
  bad_words_ids = None
  initialized = 0

actions = []
memory = ("", torch.zeros((1, 0)).long())
lmi = ["", torch.zeros((1, 0)).long()]
an = ("", torch.zeros((1, 0)).long())
an_depth = 3
history = None

In [None]:
#@title Model setup
#@markdown horni was finetuned for one epoch on about 800MB worth of random blocks of text from literotica. Do not use the horni model if you dislike NSFW outputs. horni-ln uses horni as a base and was finetuned for one epoch on 579MB of text from a light novel dataset.

print("Setting up model, this will take a few minutes")

model_name = "2.7B-horni" #@param ["2.7B-horni", "2.7B-horni-ln", "EleutherAI/gpt-neo-2.7B"]
model_gdrive = "/content/drive/MyDrive/gpt-neo-2.7B-horni.tar" #@param {type:"string"}
use_gdrive = False #@param {type:"boolean"}
#@markdown If you get download errors, the Google Drive downloads might be over their daily download quota. In that case, right-click, select "Interrupt execution", download the checkpoint from Mega yourself, upload it to your Google Drive, tick use_gdrive, set model_gdrive to the uploaded file's path (e.g. `/content/drive/MyDrive/gpt-neo-2.7B-horni-ln.tar`) and run the cell again.
#@markdown
#@markdown Warnings about certain attention bias parameters being uninitialized or about the google drive already having been mounted can be ignored.

custom_models = ["2.7B-horni", "2.7B-horni-ln"]

if use_gdrive:
  from google.colab import drive
  drive.mount('/content/drive')

#model_types = {"2.7B-horni": "https://mega.nz/file/6BNykLJb#B6gxK3TnCKBpeOF1DJMXwaLc_gcTcqMS0Lhzr1SeJmc",
#               "2.7B-horni-ln": "https://mega.nz/file/rQcWCTZR#tCx3Ztf_PMe6OtfgI95KweFT5fFTcMm7Nx9Jly_0wpg"}
model_types = {"2.7B-horni": "https://drive.google.com/uc?id=1-Jj_hlyNCQxuSnK7FFBXREGnRSMI5MoF",
               "2.7B-horni-ln": "https://drive.google.com/uc?id=1M1JY459RBIgLghtWDRDXlD4Z5DAjjMwg"}

model = None
tokenizer = None
pipeline = None
checkpoint = None

if not os.path.isdir("gpt-neo-"+model_name) and model_name in custom_models:
  if use_gdrive:
    tar = tarfile.open(model_gdrive, "r")
  else:
    model_url = model_types[model_name]
    print("Downloading:", model_url)
    #!megadl $model_url --no-ask-password
    !gdown $model_url
    tar = tarfile.open(model_name + ".tar", "r")
  tar.extractall()
  tar.close()

if model_name in custom_models:
  checkpoint = torch.load("gpt-neo-" + model_name + "/pytorch_model.bin", map_location="cpu")
  model = GPTNeoForCausalLM.from_pretrained("gpt-neo-" + model_name, state_dict=checkpoint).half().to("cpu").eval()
  for k in list(checkpoint.keys()):
    del checkpoint[k]
  del checkpoint
else:
  from transformers.file_utils import cached_path, WEIGHTS_NAME, hf_bucket_url
  archive_file = hf_bucket_url(model_name, filename=WEIGHTS_NAME)
  resolved_archive_file = cached_path(archive_file)
  checkpoint = torch.load(resolved_archive_file, map_location="cpu")
  for k in checkpoint.keys():
    checkpoint[k] = checkpoint[k].half()
  model = GPTNeoForCausalLM.from_pretrained(model_name, state_dict=checkpoint).half().to("cpu").eval()
  for k in list(checkpoint.keys()):
    del checkpoint[k]
  del checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
"""
if torch.cuda.get_device_properties(0).total_memory > 15000 * 1024 * 1024:
  print("Big GPU detected, using fp32")
  model = model.float()
"""    

In [None]:
#@title Copy downloaded model to google drive (optional)
#@markdown If the model checkpoint was downloaded automatically in the previous step, you can copy it to your google drive here for more reliable access in the future
gdrive_target = "/content/drive/MyDrive/gpt-neo-2.7B-horni.tar" #@param {type:"string"}
copy_model_file = False #@param {type:"boolean"}

if copy_model_file:
  from google.colab import drive
  drive.mount('/content/drive')
  model_tar = '/content/' + model_name + ".tar"
  !cp -v $model_tar $gdrive_target

In [None]:
#@title Sampling settings (DO NOT SKIP)
#@markdown You can modify sampling settings here. Don't forget to run the cell again after changing them. The number of generated tokens is subtracted from the 2048-token context window, so don't set it too high.
tail_free_sampling = 0.95 #@param {type:"number"}
top_k = 60 #@param {type:"number"}
top_p = 0.9 #@param {type:"number"}
temperature = 0.8 #@param {type:"number"}
number_generated_tokens = 40 #@param {type:"integer"}
repetition_penalty = 1.25 #@param {type:"number"}
repetition_penalty_range = 512 #@param {type:"number"}
repetition_penalty_slope = 3.33 #@param {type:"number"}
number_show_last_actions = 15 #@param {type:"integer"}

#@markdown If tail free sampling is enabled, top_p and top_k should probably not be used.
enable_tfs = True #@param {type:"boolean"}
enable_top_k = False #@param {type:"boolean"}
enable_top_p = False #@param {type:"boolean"}

if not enable_tfs:
  tail_free_sampling = None
if not enable_top_k:
  top_k = None
if not enable_top_p:
  top_p = None

#@markdown Temperature doesn't seem to behave the same way as in AI Dungeon (AID), so experiment with it. Even 0.5 can give good results.
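Tail-free sampling (the `tfs` argument) looks at how quickly the sorted token probabilities fall off and cuts the low-probability tail where that curve flattens out. The sketch below follows the general idea in plain Python: take the absolute second differences of the sorted probabilities, normalize them, and accumulate until the total exceeds `z`. `tail_free_filter` is a hypothetical helper; the fork's tensor implementation may differ in details such as edge handling.

```python
from typing import List, Sequence


def tail_free_filter(probs: Sequence[float], z: float = 0.95) -> List[int]:
    """Return the token ids kept by tail-free sampling, most likely first."""
    # Sort token ids by probability, highest first (ties keep their original order).
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    p = [probs[i] for i in order]
    if len(p) < 3:
        return order  # not enough points for a second difference
    # Discrete first and second differences of the sorted probability curve.
    d1 = [p[i] - p[i + 1] for i in range(len(p) - 1)]
    d2 = [abs(d1[i] - d1[i + 1]) for i in range(len(d1) - 1)]
    total = sum(d2)
    if total == 0:
        return order  # perfectly straight curve: no tail to cut
    kept = [order[0]]  # the most likely token is always kept
    cdf = 0.0
    for i, weight in enumerate(d2):
        cdf += weight / total
        if cdf > z:
            break  # everything from here on counts as tail
        kept.append(order[i + 1])
    return kept


probs = [0.1, 0.5, 0.05, 0.3, 0.05]
print(tail_free_filter(probs, z=0.95))  # [1, 3, 0]
print(tail_free_filter(probs, z=0.5))   # [1, 3]
```

Lower `z` values cut more aggressively; sampling then happens only among the kept ids after renormalizing their probabilities.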

In [None]:
#@title Prevent tokens like [], <> and {} from being generated
#thanks STARSTRUCK

prevent_square_brackets = True #@param {type:"boolean"}
prevent_angle_brackets = True #@param {type:"boolean"}
prevent_curly_brackets = True #@param {type:"boolean"}

vocab = tokenizer.get_vocab()
vocab_keys = vocab.keys()
bad_keys = list()
find_keys = lambda char : [key for key in vocab_keys if key.find(char) != -1]

if prevent_square_brackets:
  bad_keys.extend(find_keys("["))
  bad_keys.extend(find_keys("]"))

if prevent_angle_brackets:
  bad_keys.extend(find_keys("<"))
  bad_keys.extend(find_keys(">"))

if prevent_curly_brackets:
  bad_keys.extend(find_keys("{"))
  bad_keys.extend(find_keys("}"))

bad_words_ids = list()
bad_keys_final = list()
for key in bad_keys:
  if key == "<|endoftext|>" or key in bad_keys_final:
    continue
  bad_id = vocab[key]
  bad_words_ids.append([bad_id])
  bad_keys_final.append(key)

if len(bad_words_ids) < 1:
  bad_words_ids = None

#print(f"Bad keys: {bad_keys_final} (Count: {len(bad_keys)})")
#print(f"Bad ids: {bad_words_ids}")

In [None]:
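# Patch GPTNeoModel.forward so that the transformer blocks are kept in CPU memory
# and copied to the GPU in groups while the forward pass runs
number_of_parts = 32 # can also be 2 or 4; fewer parts keep more blocks on the GPU at once
import gc
from transformers import GPTNeoModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging
logger = logging.get_logger(__name__)  # used by the gradient checkpointing warning below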
def new_forward(
    self,
    input_ids=None,
    past_key_values=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    global number_of_parts
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, input_shape[-1])
    if position_ids is not None:
        position_ids = position_ids.view(-1, input_shape[-1])

    if past_key_values is None:
        past_length = 0
        past_key_values = tuple([None] * len(self.h))
    else:
        past_length = past_key_values[0][0].size(-2)
    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

    # Attention mask.
    if attention_mask is not None:
        assert batch_size > 0, "batch_size has to be defined and > 0"
        global_attention_mask = attention_mask.view(batch_size, -1)
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        global_attention_mask = global_attention_mask[:, None, None, :]

        # Since global_attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        global_attention_mask = global_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        global_attention_mask = (1.0 - global_attention_mask) * -10000.0
    else:
        global_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x num_headss x N x N
    # head_mask has shape n_layer x batch x num_headss x N x N
    head_mask = self.get_head_mask(head_mask, self.config.num_layers)

    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids)
    position_embeds = self.wpe(position_ids)
    hidden_states = inputs_embeds + position_embeds

    if token_type_ids is not None:
        token_type_embeds = self.wte(token_type_ids)
        hidden_states = hidden_states + token_type_embeds

    hidden_states = self.drop(hidden_states)

    output_shape = input_shape + (hidden_states.size(-1),)

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None

    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
        if number_of_parts == 2:
            if i == 0:
                cudastreams = {}
                for j in range(0,16):
                    cudastreams[j] = torch.cuda.Stream()
                    for param1,param2 in zip(self.h[j].parameters(),self.h[j+16].parameters()):
                        param1.data = param2.data
                        
                for j in range(0,16):
                    with torch.cuda.stream(cudastreams[j]):
                        for param1,param2 in zip(self.h[j].parameters(),self.extrastorage[j].parameters()):
                            param1.data.copy_(param2.data, non_blocking=True)
                        self.h[j].to("cuda", non_blocking=True)
                        
                torch.cuda.synchronize()
                del cudastreams
                
            if i == 16:
                cudastreams = {}
                for j in range(16,32):
                    cudastreams[j] = torch.cuda.Stream()
                    for param1,param2 in zip(self.h[j].parameters(),self.h[j-16].parameters()):
                        param1.data = param2.data
                        
                for j in range(16,32):  
                    with torch.cuda.stream(cudastreams[j]):
                        for param1,param2 in zip(self.h[j].parameters(),self.extrastorage[j].parameters()):
                            param1.data.copy_(param2.data, non_blocking=True)
                        self.h[j].to("cuda", non_blocking=True)
                torch.cuda.synchronize()
                del cudastreams
                
        if number_of_parts == 4:
            if i == 0:
                cudastreams = {}
                for j in range(0,8):
                    cudastreams[j] = torch.cuda.Stream()
                    for param1,param2 in zip(self.h[j].parameters(),self.h[j+24].parameters()):
                        param1.data = param2.data
                for j in range(0,8):
                    with torch.cuda.stream(cudastreams[j]):
                        for param1,param2 in zip(self.h[j].parameters(),self.extrastorage[j].parameters()):
                            param1.data.copy_(param2.data, non_blocking=True)
                        self.h[j].to("cuda", non_blocking=True)
                torch.cuda.synchronize()
                del cudastreams
                
            if i == 8:
                cudastreams = {}
                for j in range(8,16):
                    cudastreams[j] = torch.cuda.Stream()
                    for param1,param2 in zip(self.h[j].parameters(),self.h[j-8].parameters()):
                        param1.data = param2.data
                for j in range(8,16):
                    with torch.cuda.stream(cudastreams[j]):
                        for param1,param2 in zip(self.h[j].parameters(),self.extrastorage[j].parameters()):
                            param1.data.copy_(param2.data, non_blocking=True)
                        self.h[j].to("cuda", non_blocking=True)
                torch.cuda.synchronize()
                del cudastreams
                    
            if i == 16:
                cudastreams = {}
                for j in range(16,24):
                    cudastreams[j] = torch.cuda.Stream()
                    for param1,param2 in zip(self.h[j].parameters(),self.h[j-8].parameters()):
                        param1.data = param2.data
                for j in range(16,24):
                    with torch.cuda.stream(cudastreams[j]):
                        for param1,param2 in zip(self.h[j].parameters(),self.extrastorage[j].parameters()):
                            param1.data.copy_(param2.data, non_blocking=True)
                        self.h[j].to("cuda", non_blocking=True)
                torch.cuda.synchronize()
                del cudastreams
                
            if i == 24:
                cudastreams = {}
                for j in range(24,32):
                    cudastreams[j] = torch.cuda.Stream()
                    for param1,param2 in zip(self.h[j].parameters(),self.h[j-8].parameters()):
                        param1.data = param2.data
                for j in range(24,32):
                    with torch.cuda.stream(cudastreams[j]):
                        for param1,param2 in zip(self.h[j].parameters(),self.extrastorage[j].parameters()):
                            param1.data.copy_(param2.data, non_blocking=True)
                        self.h[j].to("cuda", non_blocking=True)
                torch.cuda.synchronize()
                del cudastreams
                
        if number_of_parts == 32:
            
            if i == 0:
                for param1,param2 in zip(self.h[i].parameters(),self.h[31].parameters()):
                    param1.data = param2.data
                    
                for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()):
                    param1.data = param2.data.to("cuda", non_blocking=True)
                self.h[0].to("cuda", non_blocking=True)
                    
                    
            if i >= 1:
                for param1,param2 in zip(self.h[i].parameters(),self.h[i-1].parameters()):
                    param1.data = param2.data
                    
                for param1,param2 in zip(self.h[i].parameters(),self.extrastorage[i].parameters()):
                    param1.data.copy_(param2.data, non_blocking=True)
                self.h[i].to("cuda", non_blocking=True)
        
        attn_type = self.config.attention_layers[i]
        attn_mask = global_attention_mask if attn_type == "global" else attention_mask

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if getattr(self.config, "gradient_checkpointing", False) and self.training:

            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                    "`use_cache=False`..."
                )
                use_cache = False

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, use_cache, output_attentions)

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                None,
                attn_mask,
                head_mask[i],
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attn_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        hidden_states = outputs[0]
        if use_cache is True:
            presents = presents + (outputs[1],)

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

    hidden_states = self.ln_f(hidden_states)

    hidden_states = hidden_states.view(*output_shape)
    # Add last hidden state
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )
GPTNeoModel.forward = new_forward
model.eval().half().to("cpu")
model.transformer.wte.to("cuda")
model.transformer.wpe.to("cuda")
model.transformer.ln_f.to("cuda")
model.lm_head.to("cuda")
torch.cuda.empty_cache()
for param in model.transformer.wte.parameters():
    param.requires_grad = False
for param in model.transformer.wpe.parameters():
    param.requires_grad = False
for i in range(32):
    for param in model.transformer.h[i].parameters():
        param.requires_grad = False
for param in model.transformer.ln_f.parameters():
    param.requires_grad = False
for param in model.lm_head.parameters():
    param.requires_grad = False
setattr(model.transformer,"extrastorage",None)
model.transformer.extrastorage = copy.deepcopy(model.transformer.h)
smalltensor = torch.tensor(0).to("cuda")
for j in range(32):
    for param1 in model.transformer.h[j].parameters():
        param1.data = smalltensor
gc.collect()
torch.cuda.empty_cache()
model.transformer.extrastorage.to("cpu")
for i in range(32):
    for param in model.transformer.extrastorage[i].parameters():
        param.requires_grad = False
        param.data = param.data.pin_memory()  # pin_memory() returns a pinned copy; it does not pin in place
gc.collect()
torch.cuda.empty_cache()
if number_of_parts == 2:
    for j in range(16,32):
        for param1,param2 in zip(model.transformer.h[j].parameters(),model.transformer.extrastorage[j].parameters()):
            param1.data = param2.data.to("cuda", non_blocking=True)
        model.transformer.h[j].to("cuda", non_blocking=True)  
    print("number_of_parts = 4" )
    
if number_of_parts == 4:
    for j in range(24,32):
        for param1,param2 in zip(model.transformer.h[j].parameters(),model.transformer.extrastorage[j].parameters()):
            param1.data = param2.data.to("cuda", non_blocking=True)
        model.transformer.h[j].to("cuda", non_blocking=True)  
    print("number_of_parts = 4" )
    
if number_of_parts == 32:
    for param1,param2 in zip(model.transformer.h[31].parameters(),model.transformer.extrastorage[31].parameters()):
        param1.data = param2.data.to("cuda", non_blocking=True)
    model.transformer.h[31].to("cuda", non_blocking=True)  
    print("number_of_parts = 32" )

In [None]:
#@title Basic sampling

#@markdown Use this cell if you just want to sample from the model in a free form way.

basic_prompt = "The rays of the evening sun falling in through the window bathed the room in a soft, warm light" #@param {type:"string"}

ids = tokenizer(basic_prompt, return_tensors="pt").input_ids.to("cpu")
n_ids = ids.shape[1]
if n_ids < 1:
  n_ids = 1
  ids = torch.tensor([[tokenizer.eos_token_id]])
max_length = n_ids + number_generated_tokens
torch.cuda.empty_cache()
basic_output = model.generate(
    ids.long().cuda(),
    do_sample=True,
    min_length=max_length,
    max_length=max_length,
    temperature=temperature,
    tfs = tail_free_sampling,
    top_k = top_k,
    top_p = top_p,
    repetition_penalty = repetition_penalty,
    repetition_penalty_range = repetition_penalty_range,
    repetition_penalty_slope = repetition_penalty_slope,
    use_cache=True,
    bad_words_ids=bad_words_ids,
    pad_token_id=tokenizer.eos_token_id
).long().to("cpu")
torch.cuda.empty_cache()

print(tokenizer.decode(basic_output[0]))
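`max_length` in `generate()` counts the prompt tokens too, so the cell adds `number_generated_tokens` to the prompt length. Setting `min_length` to the same value keeps the end-of-text token from ending generation early, so exactly that many new tokens are produced. An empty prompt is replaced by a single end-of-text token. A plain-Python sketch of that bookkeeping (`plan_generation` is a hypothetical name):

```python
from typing import List, Sequence, Tuple

GPT2_EOS_TOKEN_ID = 50256  # id of <|endoftext|> in the GPT-2 tokenizer


def plan_generation(
    prompt_ids: Sequence[int],
    number_generated_tokens: int,
    eos_token_id: int = GPT2_EOS_TOKEN_ID,
) -> Tuple[List[int], int]:
    """Return the ids to feed generate() and the shared min_length/max_length value."""
    ids = list(prompt_ids) or [eos_token_id]  # empty prompt: start from end-of-text
    return ids, len(ids) + number_generated_tokens


print(plan_generation([], 40))         # ([50256], 41)
print(plan_generation([1, 2, 3], 40))  # ([1, 2, 3], 43)
```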

# Using the play function

If your prompt starts with a letter, try putting a space or newline in front.

* **generate** adds your prompt as an action and generates more output
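* **continue** generates more output
* **edit** copies the last action into the prompt field and sets action to replace
* **replace** replaces the last action with the prompt and generates more; use this to edit
* **undo** removes the last action
* **retry** removes the last action and generates new output in its place
* **info** outputs the last model input (LMI), its token count, the memory and the author's note with its depth
* **history** outputs all actions so far
* **memory** sets memory to the text in the prompt field
* **authorsnote** sets author's note to the text in the prompt field
* **andepth** sets the depth of the author's note (roughly how many actions back from the end it is inserted) to the number in the prompt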
* **tokenize** tokenizes the text in the prompt field and outputs the number of tokens
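The story is kept as a plain list of actions (each a piece of text plus its tokens), and most actions simply edit the end of that list before asking the model to continue. The sketch below models only that list handling, with strings instead of (text, tokens) pairs; `fake_model` stands in for the real `model.generate()` call, and all function names are hypothetical.

```python
from typing import Callable, List

Model = Callable[[str], str]


def do_continue(actions: List[str], model: Model) -> None:
    """Append the model's continuation of the story, if it produced any text."""
    text = model("".join(actions))
    if text:
        actions.append(text)


def do_generate(actions: List[str], prompt: str, model: Model) -> None:
    """Add the prompt as one action per line (keeping newlines), then continue."""
    actions.extend(prompt.splitlines(True))
    do_continue(actions, model)


def do_undo(actions: List[str]) -> None:
    if actions:
        actions.pop()


def do_retry(actions: List[str], model: Model) -> None:
    do_undo(actions)
    do_continue(actions, model)


def do_replace(actions: List[str], prompt: str, model: Model) -> None:
    do_undo(actions)
    do_generate(actions, prompt, model)


def fake_model(context: str) -> str:
    return " More."


story: List[str] = []
do_generate(story, "Hello", fake_model)  # ['Hello', ' More.']
do_retry(story, fake_model)              # ['Hello', ' More.'] (new output replaces the old one)
do_undo(story)                           # ['Hello']
do_replace(story, "Hi\nthere", fake_model)
print(story)  # ['Hi\n', 'there', ' More.']
```

Note that undo and retry remove the last entry whether it came from you or from the model, which is why a multi-line prompt may need several undos.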

In [None]:
#@title Play

action_type = "generate"
prompt = ""
need_refresh = True

action_types = ["generate", "continue", "edit", "replace", "undo", "retry", "memory", "authorsnote", "andepth", "info", "history", "tokenize"]

def assemble():
  remaining = (2048 - number_generated_tokens + 1) - memory[1].shape[1] - an[1].shape[1]
  n_actions = len(actions)
  n_ctx = 0
  back_i = n_actions
  for i in range(n_actions):
      i_action = n_actions - i - 1
      n_tok = actions[i_action][1].shape[1]
      if remaining > n_ctx + n_tok:
        n_ctx += n_tok
        back_i = i_action
      else:
        break
  lmi[0], lmi[1] = memory[0], memory[1]
  start = False
  if n_actions - back_i - 1 < an_depth:
    start = True
  while back_i < n_actions:
    if start or n_actions - back_i - 1 == an_depth:
      lmi[0] += an[0]
      lmi[1] = torch.cat([lmi[1].cpu(), an[1].cpu()], 1).long()
      start = False
    lmi[0] += actions[back_i][0]
    lmi[1] = torch.cat([lmi[1].cpu(), actions[back_i][1].cpu()], 1).long()
    back_i += 1

def clear_output():
  with out:
    IPython.display.clear_output()

def set_action(change):
  global action_type
  action_type = change.new

def set_prompt(change):
  global prompt
  prompt = change.new

@torch.no_grad()
def play(do_action=None):
  global memory, need_refresh, an, an_depth, action_type, history
  an_updated = False
  memory_updated = False
  if do_action is not None:
    action = do_action
    action_type = do_action
  else:
    action = action_type
  with out:
    if prompt in action_types:
      action = prompt
    else:
      if action == "edit":
        if len(actions) > 0:
          input.value = actions[-1][0]
        else:
          input.value = ""
        dropdown.value = "replace"
        return
      if action == "replace":
        if len(actions) > 0:
          actions.pop()
        need_refresh = True
        action = "generate"
      if action == "generate":
        text = prompt
        if len(text) > 0:
          for line in text.splitlines(True):
            tokens = tokenizer(line, return_tensors="pt").input_ids.to("cpu")
            actions.append((line, tokens))
        action = "continue"
      if action == "info":
        clear_output()
        print("LMI: " + lmi[0])
        print("LMI tokens: " + str(lmi[1].shape[1]))
        print("Memory: " + memory[0])
        print("Author's note: " + an[0])
        print("Author's note depth: " + str(an_depth))
        need_refresh = True
      if action == "history":
        clear_output()
        print("".join([action[0] for action in actions]), end="")
        need_refresh = False
      if action == "retry":
        if len(actions) > 0:
          actions.pop()
        need_refresh = True
        action = "continue"
      if action == "undo":
        if len(actions) > 0:
          actions.pop()
        assemble()
        clear_output()
        print("".join([action[0] for action in actions[-number_show_last_actions:]]), end="")
        need_refresh = False
      if action == "memory":
        if prompt == "":
          memory = ("", torch.zeros((1, 0)).long())
          text = ""
        else:
          text = codecs.decode(prompt + "\n", "unicode-escape")
          tokens = tokenizer(text, return_tensors="pt").input_ids.to("cpu")
          memory = (text, tokens)
        clear_output()
        print("Memory: " + text)
        memory_updated = True
      if action == "authorsnote":
        if prompt == "":
          an = ("", torch.zeros((1, 0)).long())
          text = ""
        else:
          text = "\n[Author's note: " + codecs.decode(prompt, "unicode-escape") + "]\n"
          tokens = tokenizer(text, return_tensors="pt").input_ids.to("cpu")
          an = (text, tokens)
        clear_output()
        print("Author's note: " + text)
        an_updated = True
      if action == "andepth":
        clear_output()
        try:
          an_depth = int(codecs.decode(prompt + "\n", "unicode-escape"))
        except:
          pass
        print("Author's note depth: " + str(an_depth))
        an_updated = True
      if action == "tokenize":
        text = codecs.decode(prompt, "unicode-escape")
        tokens = tokenizer(text, return_tensors="pt").input_ids.to("cpu")
        clear_output()
        print("Tokens: " + str(tokens.shape[1]))
        print(tokens[0])
        need_refresh = True
      if action == "continue":
        assemble()
        ids = lmi[1].cuda()
        n_ids = ids.shape[1]
        if n_ids < 1:
          n_ids = 1
          ids = torch.tensor([[tokenizer.eos_token_id]])
        max_length = number_generated_tokens + n_ids
        #ids[:, :] = 13
        torch.cuda.empty_cache()
        clear_output()
        gen_tokens = model.generate(
            ids.long().cuda(),
            do_sample=True,
            min_length=max_length,
            max_length=max_length,
            temperature=temperature,
            tfs = tail_free_sampling,
            top_k = top_k,
            top_p = top_p,
            repetition_penalty = repetition_penalty,
            repetition_penalty_range = repetition_penalty_range,
            repetition_penalty_slope = repetition_penalty_slope,
            use_cache=True,
            bad_words_ids=bad_words_ids,
            pad_token_id=tokenizer.eos_token_id
        ).long()
        stop_tokens = [0, 13, 30, 526, 764, 1701, 2474, 5145, 5633]
        for i in reversed(range(len(gen_tokens[0]))):
          if i < n_ids:
            gen_tokens = gen_tokens[0]
            break
          if gen_tokens[0][i] in stop_tokens:
            gen_tokens = gen_tokens[0][:i+1]
            break
        gen_text = tokenizer.decode(gen_tokens[n_ids:])
        if len(gen_text) > 0:
          actions.append((gen_text, gen_tokens[n_ids:].unsqueeze(0).cpu()))
        print("".join([action[0] for action in actions[-number_show_last_actions:]]), end="")
        torch.cuda.empty_cache()
        need_refresh = False
    if history is not None:
      if history:
        with out_history:
          IPython.display.clear_output()
          print("".join([action[0] for action in actions]), end="")
        with out_history2:
          IPython.display.clear_output()
      else:
        with out_history2:
          IPython.display.clear_output()
          print("".join([action[0] for action in actions]), end="")
        with out_history:
          IPython.display.clear_output()
      if an_updated:
        with out_an:
          IPython.display.clear_output()
          if len(an[0]) > 0:
            print("AN depth: " + str(an_depth) + "\n" + an[0], end="")
      if memory_updated:
        with out_memory:
          IPython.display.clear_output()
          print(memory[0], end="")
      history = not history

import ipywidgets as widgets
import IPython.display
out = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
dropdown = widgets.Dropdown(options=action_types, value=action_type, description='Action:', disabled=False)
dropdown.observe(set_action, 'value')
button = widgets.Button(description='[selected action]', disabled=False)
button.on_click(lambda _: play(dropdown.value))
generate_button = widgets.Button(description='Generate', disabled=False)
generate_button.on_click(lambda _: play("generate"))
edit_button = widgets.Button(description='Edit', disabled=False)
edit_button.on_click(lambda _: play("edit"))
retry_button = widgets.Button(description='Retry', disabled=False)
retry_button.on_click(lambda _: play("retry"))
continue_button = widgets.Button(description='Continue', disabled=False)
continue_button.on_click(lambda _: play("continue"))
undo_button = widgets.Button(description='Undo', disabled=False)
undo_button.on_click(lambda _: play("undo"))
dropdown_hbox = widgets.HBox([dropdown, button])
hbox = widgets.HBox([generate_button, edit_button, retry_button, continue_button, undo_button])
input = widgets.Textarea(value='', placeholder='', description='Input:', disabled=False, rows=4, layout={"width": "1280px"})
input.observe(set_prompt, 'value')

display(out, dropdown_hbox, hbox, input)
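
When the **Continue** action finishes, the Play cell walks backward from the end of the generated sequence and truncates it just after the last token found in `stop_tokens`, so the appended text tends to end on a complete sentence. If no stop token appears after the prompt, the whole output is kept. The rule can be sketched in plain Python as follows, assuming lists instead of tensors; the helper name `trim_at_last_stop` and the example token IDs are hypothetical.

```python
from typing import List, Sequence


def trim_at_last_stop(tokens: Sequence[int], n_prompt: int, stop_tokens: Sequence[int]) -> List[int]:
    """Keep the prompt plus generated tokens up to and including the last stop token.

    Only positions at or after n_prompt are searched (stop tokens inside the
    prompt are ignored); if none of them is a stop token, the whole sequence
    is returned unchanged.
    """
    stops = set(stop_tokens)
    for i in range(len(tokens) - 1, n_prompt - 1, -1):
        if tokens[i] in stops:
            return list(tokens[: i + 1])
    return list(tokens)


# Hypothetical example: 2 prompt tokens followed by generated tokens, 13 acting as a stop token.
seq = [464, 3290, 318, 257, 13, 383, 3797]
trimmed = trim_at_last_stop(seq, n_prompt=2, stop_tokens=[0, 13, 30])
print(trimmed)      # [464, 3290, 318, 257, 13]
print(trimmed[2:])  # [318, 257, 13], the part that would be appended as the new action
```

Using a `set` for the stop tokens keeps each membership check constant-time; the notebook's tensor version checks against a short list, which behaves the same for nine IDs.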

In [None]:
#@title History
#@markdown Run this cell to show an auto-updating listing of the full story, along with the current memory and author's note.

history = True
out_history = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
out_history2 = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
out_memory = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
out_an = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
display(out_history, out_history2, out_memory, out_an)
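
This cell creates two story output areas, `out_history` and `out_history2`. On each refresh, `play()` prints the story into one of them, clears the other, and then flips `history`, so consecutive updates alternate between the two widgets. The toggle can be sketched without ipywidgets; the `DoubleBuffer` class below is hypothetical, with plain strings standing in for the output widgets.

```python
from typing import List


class DoubleBuffer:
    """Hypothetical stand-in for the out_history / out_history2 widget pair."""

    def __init__(self) -> None:
        self.panes: List[str] = ["", ""]  # index 0 ~ out_history, index 1 ~ out_history2
        self.history = True               # same starting value the History cell sets

    def refresh(self, story: str) -> None:
        if self.history:
            self.panes[0] = story  # print the story into out_history
            self.panes[1] = ""     # clear out_history2
        else:
            self.panes[1] = story  # print the story into out_history2
            self.panes[0] = ""     # clear out_history
        self.history = not self.history  # the next refresh uses the other pane


buf = DoubleBuffer()
buf.refresh("Once upon a time")
print(buf.panes)  # ['Once upon a time', '']
buf.refresh("Once upon a time, a cat")
print(buf.panes)  # ['', 'Once upon a time, a cat']
```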

In [None]:
#@title Attention display
#@markdown If you don't know what this is, you can ignore this cell. attn_head_combination selects how attention weights are combined across attention heads and then across layers.

visualizer_prompt = "Before his eyes was an orange cat with stripes. Running his fingers through its soft fur, he admired" #@param {type:"string"}
max_attentions = 8 #@param {type:"integer"}
attn_head_combination = "mean" #@param ["mean", "max"]
use_lmi_as_prompt = False #@param {type:"boolean"}

if use_lmi_as_prompt:
  ids = lmi[1]
else:
  ids = tokenizer(visualizer_prompt, return_tensors="pt").input_ids.to("cpu")

n_ids = ids.shape[1]
if n_ids < 1:
  n_ids = 1
  ids = torch.tensor([[tokenizer.eos_token_id]])

max_length = n_ids + number_generated_tokens
torch.cuda.empty_cache()
basic_output = model.generate(
    ids.long().cuda(),
    do_sample=True,
    min_length=max_length,
    max_length=max_length,
    temperature=temperature,
    tfs = tail_free_sampling,
    top_k = top_k,
    top_p = top_p,
    repetition_penalty = repetition_penalty,
    repetition_penalty_range = repetition_penalty_range,
    repetition_penalty_slope = repetition_penalty_slope,
    use_cache=True,
    bad_words_ids=bad_words_ids,
    pad_token_id=tokenizer.eos_token_id,
    return_dict_in_generate=True,
    output_attentions=True
)
torch.cuda.empty_cache()
attentions = basic_output["attentions"]
basic_output = basic_output["sequences"].cpu()

print("Prompt: " + (tokenizer.decode(ids[0]) if use_lmi_as_prompt else visualizer_prompt))
print()

def combine(attentions):
  if attn_head_combination == "mean":
    attentions = attentions.mean(0)
  else:
    attentions = attentions.max(0)[0]
  return attentions

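import torch.nn.functional as F
import numpy as np
# attentions: one tuple per generated token, each holding one tensor per layer of
# shape (batch=1, heads, query_len, key_len). With use_cache=True, query_len is the
# prompt length on the first step and 1 afterwards, so [0, :, -1] selects the
# newest position's attention over all earlier tokens.
torch.set_printoptions(sci_mode=False)
for i in range(len(attentions)):
  layer_attn = []
  for j in range(len(attentions[i])):
    # Combine the heads of layer j for the newest query position.
    layer_attn.append(combine(attentions[i][j][0, :, -1].float().cpu()))
  layer_attn = torch.stack(layer_attn)
  # Combine across layers, then pick the most-attended positions; k is capped so
  # short prompts (fewer than max_attentions tokens) don't make topk fail.
  token_attn = combine(layer_attn)
  prob, topk = token_attn.topk(min(max_attentions, token_attn.shape[0]))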
  top_tokens = []
  for top in topk:
    decoded = tokenizer.decode(torch.tensor([basic_output[0][top]]))
    top_tokens.append(f"{decoded!r}@{top} ({token_attn[top]:.4f})")
  print("Token: " + repr(tokenizer.decode(torch.tensor([basic_output[0][n_ids + i]]))))
  print(", ".join(top_tokens))
  print()

print("Output: " + tokenizer.decode(basic_output[0]))
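
For each generated token, the Attention display cell takes the newest position's attention row from every layer, reduces it across heads with `attn_head_combination`, reduces the per-layer results again with the same operation, and reports the `max_attentions` highest-scoring positions. The sketch below shows that two-stage reduction for a single generated token, assuming nested lists shaped `layers x heads x key_len` in place of tensors and a sort in place of `torch.topk`. The helper names `combine` and `top_positions` and the toy weights are hypothetical.

```python
from typing import List


def combine(rows: List[List[float]], mode: str = "mean") -> List[float]:
    """Reduce equal-length rows element-wise, like tensor.mean(0) or tensor.max(0)[0]."""
    if mode == "mean":
        return [sum(col) / len(col) for col in zip(*rows)]
    return [max(col) for col in zip(*rows)]


def top_positions(attn: List[List[List[float]]], k: int, mode: str = "mean") -> List[int]:
    """attn[layer][head][pos] is the newest token's attention to position pos."""
    per_layer = [combine(heads, mode) for heads in attn]  # combine heads within each layer
    token_attn = combine(per_layer, mode)                 # then combine across layers
    k = min(k, len(token_attn))                           # never ask for more positions than exist
    return sorted(range(len(token_attn)), key=lambda p: token_attn[p], reverse=True)[:k]


# Toy example: 2 layers x 2 heads x 3 key positions.
attn = [
    [[0.1, 0.2, 0.7], [0.3, 0.3, 0.4]],
    [[0.6, 0.2, 0.2], [0.5, 0.1, 0.4]],
]
print(top_positions(attn, k=2, mode="mean"))  # [2, 0]
print(top_positions(attn, k=8, mode="max"))   # [2, 0, 1]
```

With `mode="max"`, a single head that focuses sharply on one token is enough to surface it, while `mode="mean"` favors positions that many heads and layers attend to.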