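Mount Google Drive, install the dependencies, and clone the LLaMA repository. The tokenizer and model weights are read from the Google Drive locations configured below.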

In [None]:
import os
import sys
from google.colab import drive

# Mount Google Drive at /gdrive so the tokenizer and weights stored there can
# be read by the cells below.
drive.mount('/gdrive')

#@markdown Location of tokenizer.
tokenizer_loc = '/gdrive/My Drive/tokenizer.model' #@param {type:"string"}

# @markdown Location of directory containing model weights / parameters.
weight_loc = '/gdrive/My Drive/7B/' #@param {type:"string"}

# Install the Python dependencies and fetch the reference LLaMA source code.
# NOTE(review): versions are not pinned, and `git clone` reports an error
# (without stopping the cell) if ./llama already exists from a previous run.
!pip install fairscale
!pip install sentencepiece
!git clone https://github.com/facebookresearch/llama.git

# Make the cloned `llama` package importable.
# NOTE(review): assumes the clone landed in /content (the working directory
# this notebook is expected to run from) - confirm.
sys.path.insert(0, '/content/llama/')

# Show which GPU (if any) is attached to this runtime.
!nvidia-smi

The 7B checkpoint is too large to fit into RAM in one piece, so it is split into two smaller files that the loader below reads one after the other. Run this cell if you have not split the 7B checkpoint yet. The results are saved to your 7B directory, so you should only ever need to run this cell once. You may need to restart the runtime afterward.

In [None]:
import torch

# Load the consolidated checkpoint onto the GPU rather than into host memory.
checkpoint = torch.load(
    os.path.join(weight_loc, 'consolidated.00.pth'), map_location="cuda"
)

# First half of the entries (in stored order) -> consolidated.00.00.pth
d1 = {key: checkpoint[key]
      for key in list(checkpoint)[:len(checkpoint) // 2]}
torch.save(d1, os.path.join(weight_loc, 'consolidated.00.00.pth'))
del d1

# Remaining entries -> consolidated.00.01.pth
d2 = {key: checkpoint[key]
      for key in list(checkpoint)[len(checkpoint) // 2:]}
torch.save(d2, os.path.join(weight_loc, 'consolidated.00.01.pth'))
del d2

# Drop the last reference so the tensors can be freed.
del checkpoint

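Define the LLaMA generator together with additional community-contributed sampling methods (adapted from KoboldAI and transformers), so that Kobold-style parameters such as repetition penalty, tail-free sampling (tfs), and typical sampling are available.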

In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU
# General Public License version 3.

from typing import List

import torch

from llama.tokenizer import Tokenizer
from llama.model import Transformer

class LLaMA:
    """Text generator pairing a LLaMA ``Transformer`` with its ``Tokenizer``.

    Sampling is done with the module-level ``sample_*`` helpers (top-p,
    tail-free, typical, temperature and repetition penalty).
    """

    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        tfs: float = 1.0,
        typical: float = 1.0,
        penalty_range: float = 1024,
        penalty_slope: float = 0.7,
        penalty: float = 1.1
    ) -> List[str]:
        """Generate a continuation for every prompt in ``prompts``.

        Args:
            prompts: Prompt strings; at most ``model.params.max_batch_size``.
            max_gen_len: Maximum number of tokens to generate per prompt.
            temperature: Sampling temperature. ``<= 0`` selects greedy
                (argmax) decoding and skips all samplers.
            top_p: Threshold passed to ``sample_top_p_actual``.
            tfs: Threshold passed to ``sample_tail_free``.
            typical: Threshold passed to ``sample_typical``.
            penalty_range: Number of most recent tokens subject to the
                repetition penalty.
            penalty_slope: Slope of the repetition penalty curve.
            penalty: Repetition penalty strength.

        Returns:
            One decoded string per prompt (prompt plus generated text), cut
            at the first EOS token if one was produced.
        """
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False)
                         for x in prompts]

        min_prompt_size = min(len(t) for t in prompt_tokens)
        max_prompt_size = max(len(t) for t in prompt_tokens)

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        # Token buffer pre-filled with pad_id; prompts are copied to the front.
        tokens = torch.full((bsz, total_len),
                            self.tokenizer.pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            # Only the tokens not yet fed to the model are passed in, together
            # with their offset (presumably the model caches earlier positions
            # - TODO confirm against llama.model).
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                # The samplers need the whole sequence so far, not just the
                # increment fed to the model; otherwise the repetition penalty
                # would only ever see the single most recent token.
                context_ids = tokens[:, :cur_pos]

                next_token_scores = sample_top_p_actual(context_ids, logits,
                                                        top_p)
                next_token_scores = sample_tail_free(context_ids,
                                                     next_token_scores, tfs)
                next_token_scores = sample_typical(context_ids,
                                                   next_token_scores, typical)
                next_token_scores = sample_temperature(context_ids,
                                                       next_token_scores,
                                                       temperature)
                next_token_scores = sample_advanced_repetition_penalty(
                    context_ids, next_token_scores, penalty_range,
                    penalty_slope, penalty)

                next_token_scores = torch.nn.functional.softmax(
                    next_token_scores, dim=-1)
                next_token = torch.multinomial(next_token_scores,
                                               num_samples=1).squeeze(1)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                # No EOS produced: keep the full sequence.
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded

# taken from Kobold and transformers so this stuff is AGPL I guess
def sample_temperature(input_ids, scores, tempt):
    """Scale ``scores`` by dividing them by the temperature ``tempt``.

    ``input_ids`` is accepted for signature parity with the other samplers
    and is not used.
    """
    return scores / tempt

def sample_typical(input_ids, scores, typical, filter_value = -float("Inf"),
                   min_tokens_to_keep = 1):
    """Typical sampling: keep the tokens whose surprisal is closest to the
    entropy of the distribution until their cumulative probability reaches
    ``typical``; all other entries are set to ``filter_value``.

    Args:
        input_ids: Unused; kept for signature parity with the other samplers.
        scores: 2-D tensor of logits (the mask is scattered along dim 1).
        typical: Probability mass to keep. Values ``>= 1.0`` disable the
            filter and return ``scores`` unchanged.
        filter_value: Value written into removed entries.
        min_tokens_to_keep: Minimum number of tokens that survive (at least 1).

    Returns:
        A tensor shaped like ``scores`` with filtered entries replaced.
    """
    # Keeping the whole probability mass means nothing has to be filtered.
    # (The guard previously tested `filter_value`, which is never >= 1.0 for
    # the default of -inf, so the filter ran even when it was "disabled".)
    if typical >= 1.0:
        return scores

    probs = scores.softmax(dim=-1)
    log_probs = probs.log()

    # sum(p * log p) is the negative entropy; nansum skips the 0 * -inf terms.
    neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)

    # Distance between each token's log-probability and the negative entropy.
    entropy_deviation = (neg_entropy - log_probs).abs()

    # Most "typical" tokens first.
    _, sorted_indices = torch.sort(entropy_deviation)
    sorted_probs = probs.gather(-1, sorted_indices)
    sorted_indices_to_remove = sorted_probs.cumsum(dim=-1) >= typical
    # Shift right by one so the token that crosses the threshold is kept.
    sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)

    min_tokens_to_keep = max(min_tokens_to_keep, 1)
    # Keep at least min_tokens_to_keep
    sorted_indices_to_remove[..., : min_tokens_to_keep] = 0

    # Map the mask from sorted order back to the original token order.
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)

def sample_top_p_actual(input_ids, scores, top_p, filter_value = -float("Inf"),
                        min_tokens_to_keep = 1):
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    if min_tokens_to_keep > 1:
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., -min_tokens_to_keep :] = 0

    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,
                                                         sorted_indices_to_remove)
    scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores

def sample_advanced_repetition_penalty(input_ids, scores, penalty_range,
                                       penalty_slope, penalty):
    """Penalise tokens that already occur in ``input_ids``.

    Scores of previously seen tokens are divided by the penalty when positive
    and multiplied by it otherwise. With a non-zero ``penalty_slope`` the
    penalty is graded over the window so that the most recent token receives
    the full ``penalty`` and the oldest one none.

    Args:
        input_ids: 2-D integer tensor of previously seen token ids.
        scores: 2-D tensor of logits, indexed along dim 1 by token id.
            Modified in place.
        penalty_range: Number of most recent tokens to penalise; ``0`` (or
            less) penalises all of ``input_ids`` with the flat penalty.
        penalty_slope: Shape of the graded curve; ``0`` applies a flat penalty.
        penalty: Penalty strength; ``1.0`` leaves ``scores`` untouched.

    Returns:
        ``scores`` (the same tensor object that was passed in).
    """
    penalty_range = int(penalty_range)
    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)

    # Nothing to do for a neutral penalty or when no tokens were seen yet.
    # (Previously a penalty of exactly 1.0 fell through and returned None.)
    if penalty == 1.0 or input_ids.shape[-1] == 0:
        return scores

    if penalty_range > 0:
        # Restrict the penalty to the most recent `penalty_range` tokens.
        if clipped_penalty_range < input_ids.shape[1]:
            input_ids = input_ids[..., -clipped_penalty_range:]

        # A one-token window has no curve to speak of (the formula would
        # divide 0 by 0 and produce NaN); the token gets the flat penalty.
        if penalty_slope != 0 and penalty_range > 1:
            # Positions mapped linearly onto [-1, 1], oldest to newest.
            _penalty = (torch.arange(penalty_range, dtype=scores.dtype,
                                     device=scores.device)/(penalty_range - 1)) * 2. - 1
            # Bend the line according to the slope.
            _penalty = (penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (penalty_slope - 1))
            # Rescale to [1, penalty].
            _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
            # Keep the part of the curve that lines up with the tokens we have.
            penalty = _penalty[..., -clipped_penalty_range:]

    score = torch.gather(scores, 1, input_ids)
    score = torch.where(score <= 0, score * penalty, score / penalty)
    scores.scatter_(1, input_ids, score)

    return scores

def sample_top_a(input_ids, scores, top_a, filter_value = -float("Inf"),
                 min_tokens_to_keep = 1):
    """Top-a filtering: remove every token whose probability is below
    ``top_a * max(probs) ** 2``.

    Args:
        input_ids: Unused; kept for signature parity with the other samplers.
        scores: 2-D tensor of logits (the mask is scattered along dim 1).
        top_a: Filter strength. Values ``<= 0`` disable the filter and return
            ``scores`` unchanged.
        filter_value: Value written into removed entries.
        min_tokens_to_keep: When greater than 1, that many of the most likely
            tokens are always kept.

    Returns:
        A tensor shaped like ``scores`` with filtered entries replaced.
    """
    # A non-positive threshold can never remove anything, so skip the sort.
    # (The guard previously tested `filter_value`, which is never >= 1.0 for
    # the default of -inf.)
    if top_a <= 0.0:
        return scores

    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    probs = sorted_logits.softmax(dim=-1)

    # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
    probs_max = probs[..., 0, None]
    sorted_indices_to_remove = probs < probs_max * probs_max * top_a

    if min_tokens_to_keep > 1:
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0

    # Map the mask from sorted order back to the original token order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,
                                                         sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)

def sample_tail_free(input_ids, scores, tfs, filter_value = -float("Inf"),
                     min_tokens_to_keep = 1):
    """Tail-free sampling: cut the low-probability tail where the second
    derivative of the sorted probability curve has accumulated ``tfs`` of its
    total mass.

    Args:
        input_ids: Unused; kept for signature parity with the other samplers.
        scores: 2-D tensor of logits (the mask is scattered along dim 1).
        tfs: Threshold on the normalised second-derivative CDF. Values
            ``>= 1.0`` disable the filter and return ``scores`` unchanged.
        filter_value: Value written into removed entries.
        min_tokens_to_keep: When greater than 1, that many of the most likely
            tokens are always kept.

    Returns:
        A tensor shaped like ``scores`` with filtered entries replaced.
    """
    # tfs >= 1 means "keep everything". (The guard previously tested
    # `filter_value`, which is never >= 1.0 for the default of -inf, so the
    # least likely token was removed even when the filter was "disabled".)
    if tfs >= 1.0:
        return scores

    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    probs = sorted_logits.softmax(dim=-1)

    # Compute second derivative normalized CDF
    d2 = probs.diff().diff().abs()
    normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
    normalized_d2_cdf = normalized_d2.cumsum(dim=-1)

    # Remove tokens with CDF value above the threshold (token with 0 are kept)
    sorted_indices_to_remove = normalized_d2_cdf > tfs

    # Centre the distribution around the cutoff as in the original implementation of the algorithm
    # (two diffs shorten the row by two: pad "keep" in front, "remove" behind)
    sorted_indices_to_remove = torch.cat(
        (
            torch.zeros(scores.shape[0], 1, dtype=torch.bool,
                        device=scores.device),
            sorted_indices_to_remove,
            torch.ones(scores.shape[0], 1, dtype=torch.bool,
                       device=scores.device),
        ),
        dim=-1,
    )

    if min_tokens_to_keep > 1:
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0

    # Map the mask from sorted order back to the original token order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,
                                                         sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)

Load model.

In [None]:
from typing import Tuple
import os
import sys
import torch
import time
import json

from pathlib import Path

from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer


# Environment for a single-process "distributed" run: one rank, world size 1,
# with the rendezvous address on localhost.
os.environ.update({
    'RANK': '0',
    'WORLD_SIZE': '1',
    'MP': '1',
    'MASTER_ADDR': '127.0.0.1',
    'MASTER_PORT': '2223',
})


def setup_model_parallel() -> Tuple[int, int]:
    """Initialise torch.distributed (gloo backend) and fairscale model
    parallelism from the ``LOCAL_RANK`` / ``WORLD_SIZE`` environment variables.

    Either variable defaults to ``-1`` when unset.

    Returns:
        ``(local_rank, world_size)`` as read from the environment.
    """
    env = os.environ
    local_rank, world_size = (
        int(env.get(name, -1)) for name in ("LOCAL_RANK", "WORLD_SIZE")
    )

    torch.distributed.init_process_group("gloo")
    initialize_model_parallel(world_size)
    # NOTE(review): with LOCAL_RANK unset this is set_device(-1), which the
    # PyTorch docs describe as a no-op - confirm this is intended.
    torch.cuda.set_device(local_rank)

    # Every process must draw the same random numbers.
    torch.manual_seed(1)
    return local_rank, world_size


# Disabled variant of `load`, kept for reference inside a bare string literal
# (it is never executed). It selects a single checkpoint file by local rank
# instead of loading the two split halves.
'''
def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int,
         max_seq_len: int, max_batch_size: int) -> LLaMA:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert (
        world_size == len(checkpoints)
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is
        {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(max_seq_len=max_seq_len,
                                      max_batch_size=max_batch_size,
                                      **params)
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args).cuda().half()
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
'''


def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int,
         max_seq_len: int, max_batch_size: int) -> LLaMA:
    """Build the model and tokenizer and load the split 7B checkpoint.

    Args:
        ckpt_dir: Directory containing ``params.json`` and the two split
            checkpoint files ``consolidated.00.00.pth`` and
            ``consolidated.00.01.pth``.
        tokenizer_path: Path to the tokenizer model file.
        local_rank: Unused; kept for interface compatibility.
        world_size: Unused; kept for interface compatibility.
        max_seq_len: Maximum sequence length passed to ``ModelArgs``.
        max_batch_size: Maximum batch size passed to ``ModelArgs``.

    Returns:
        A ready-to-use ``LLaMA`` generator.

    Raises:
        FileNotFoundError: If ``params.json`` or a checkpoint half is missing.
    """
    start_time = time.time()

    print("Loading")
    checkpoint_paths = [os.path.join(ckpt_dir, 'consolidated.00.00.pth'),
                        os.path.join(ckpt_dir, 'consolidated.00.01.pth')]
    # Fail fast, before the model is built, if the checkpoint was never split.
    missing = [path for path in checkpoint_paths if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(
            f"Missing split checkpoint file(s): {missing}")

    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(max_seq_len=max_seq_len,
                                      max_batch_size=max_batch_size,
                                      **params)
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    # Create the parameters directly as half-precision CUDA tensors. Always
    # restore the default tensor type, even if construction fails, so a
    # failed load does not leave the session creating CUDA half tensors.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    try:
        model = Transformer(model_args).cuda().half()
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)

    # Each file holds only part of the state dict, hence strict=False. Load
    # one half at a time and release it before reading the next.
    for checkpoint_path in checkpoint_paths:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint, strict=False)
        del checkpoint

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator

# @markdown Context size. Can be up to 2048, but Colab GPU doesn't always play well with high values.
max_seq_len = 1024 # @param {type:"number"}
# Only one prompt is generated at a time.
max_batch_size = 1

local_rank, world_size = setup_model_parallel()
# Silence stdout on every rank other than the first.
if local_rank > 0:
    sys.stdout = open(os.devnull, 'w')

# `weight_loc` and `tokenizer_loc` come from the setup cell at the top of the
# notebook. The generator and tokenizer are used by the GUI cell below.
generator = load(weight_loc, tokenizer_loc, local_rank, world_size,
                 max_seq_len, max_batch_size)
tokenizer = generator.tokenizer

Main GUI. If you change the presets, you'll have to re-run the cell for the changes to take effect.

In [1]:
import ipywidgets as widgets
from IPython.display import display
import time

# Sampling presets; read as globals by generate() below.
max_gen_len = 64 #@param {type:"number"}
temperature = 0.8 #@param {type:"number"}
top_p = 0.95 #@param {type:"number"}
tfs = 1.0 #@param {type:"number"}
typical = 1.0 #@param {type:"number"}
penalty_range = 1024 #@param {type:"number"}
penalty_slope = 0.7 #@param {type:"number"}
penalty = 1.1 #@param {type:"number"}

# Widgets: a large text area with a column of buttons on its right.
input_text_area = widgets.Textarea(placeholder='Enter a prompt...',
                                   layout=widgets.Layout(width='1200px',
                                                         height='600px'))
send_button = widgets.Button(description='Send')
undo_button = widgets.Button(description='Undo')
redo_button = widgets.Button(description='Redo')
retry_button = widgets.Button(description='Retry')
memory_button = widgets.ToggleButton(description='Memory')

hbox = widgets.HBox([input_text_area,
                     widgets.VBox([send_button, undo_button, redo_button,
                                  retry_button, memory_button])])
output = widgets.Output()

# Nothing has been generated yet, so there is nothing to undo/redo/retry.
undo_button.disabled = True
redo_button.disabled = True
retry_button.disabled = True

# Shared GUI state.
# listen_for_updates: when True, edits to the text area reset the history.
listen_for_updates = False
# cur_outputs: generated chunks, in order; cur_outputs_idx: index of the last
# chunk currently present in the text area (-1 means none).
cur_outputs = []
cur_outputs_idx = -1
# memory_text: text placed in front of every prompt; input_text: the story
# text stashed while the memory is being edited.
memory_text = ''
input_text = ''

def generate():
    """Generate a continuation of the text area's content.

    The prompt is the full memory text (followed by a newline) plus as many
    of the most recent input tokens as fit into the context window. One slot
    is reserved for the BOS token that ``generator.generate`` adds and
    ``max_gen_len`` slots for the generated tokens.

    Returns:
        Only the newly generated text (the prompt prefix is stripped).
    """
    # Tokens available for memory + input; never negative.
    budget = max(max_seq_len - 1 - max_gen_len, 0)

    if memory_text:
        mem_tokenized = tokenizer.encode(memory_text + '\n', bos=False, eos=False)
    else:
        mem_tokenized = []

    inp_tokenized = tokenizer.encode(input_text_area.value, bos=False, eos=False)
    num_inp_tokens = max(budget - len(mem_tokenized), 0)

    if num_inp_tokens > 0:
        # Whole memory plus the tail of the input.
        tokenized = mem_tokenized + inp_tokenized[-num_inp_tokens:]
    elif mem_tokenized and budget > 0:
        # The memory alone fills the window: keep only its tail.
        tokenized = mem_tokenized[-budget:]
    else:
        # No room for any context (or nothing to send).
        tokenized = []

    detokenized = tokenizer.decode(tokenized)
    completions = generator.generate([detokenized],
                                     max_gen_len=max_gen_len,
                                     temperature=temperature,
                                     top_p=top_p,
                                     tfs=tfs,
                                     typical=typical,
                                     penalty_range=penalty_range,
                                     penalty_slope=penalty_slope,
                                     penalty=penalty)

    # The decoded result starts with the prompt; return what follows it.
    return completions[0][len(detokenized):]

def on_update_input_text_area(change):
    """Observer for the text area: a change made while ``listen_for_updates``
    is set discards the generation history and disables undo/redo/retry."""
    global listen_for_updates, cur_outputs, cur_outputs_idx

    if not listen_for_updates:
        return

    cur_outputs, cur_outputs_idx = [], -1
    for button in (undo_button, redo_button, retry_button):
        button.disabled = True

def send():
    """Generate a continuation, append it to the text area and record it in
    the undo history (discarding any chunks that could have been redone).

    The text area and memory button are disabled while generating. They are
    re-enabled even if generation raises, so a failure cannot lock the GUI.
    The exception itself is propagated.
    """
    global listen_for_updates, cur_outputs, cur_outputs_idx

    input_text_area.disabled = True
    memory_button.disabled = True
    # Our own edit of the text area must not reset the history.
    listen_for_updates = False

    try:
        generation = generate()
        input_text_area.value += generation
        cur_outputs_idx += 1
        # Drop everything after the current position, then append.
        cur_outputs = cur_outputs[:cur_outputs_idx]
        cur_outputs.append(generation)

        undo_button.disabled = False
        redo_button.disabled = True
        retry_button.disabled = False
    finally:
        listen_for_updates = True
        memory_button.disabled = False
        input_text_area.disabled = False

def undo():
    """Remove the most recently applied generation from the end of the text
    area and step the history index back by one.

    Does nothing when there is no generation to undo.
    """
    global listen_for_updates, cur_outputs, cur_outputs_idx

    if cur_outputs_idx < 0:
        return

    listen_for_updates = False
    num_chars = len(cur_outputs[cur_outputs_idx])
    text = input_text_area.value
    # Slice by absolute position: `text[:-num_chars]` would wipe the whole
    # text when the generation being undone is empty (`[:-0]` == `[:0]`).
    input_text_area.value = text[:max(len(text) - num_chars, 0)]
    cur_outputs_idx -= 1

    if cur_outputs_idx == -1:
        # Back at the start of the history.
        undo_button.disabled = True
        retry_button.disabled = True
    if len(cur_outputs) > 0:
        redo_button.disabled = False

    listen_for_updates = True

def redo():
    """Re-append the generation that follows the current history position to
    the text area and advance the history index."""
    global listen_for_updates, cur_outputs, cur_outputs_idx

    listen_for_updates = False
    next_idx = cur_outputs_idx + 1
    input_text_area.value = input_text_area.value + cur_outputs[next_idx]
    cur_outputs_idx = next_idx

    at_newest = cur_outputs_idx == len(cur_outputs) - 1
    if at_newest:
        # Nothing left to redo.
        redo_button.disabled = True
    if cur_outputs:
        undo_button.disabled = False
        retry_button.disabled = False

    listen_for_updates = True

def send_button_clicked(b):
    """Click handler for the Send button; the button argument is ignored."""
    send()

def undo_button_clicked(b):
    """Click handler for the Undo button; the button argument is ignored."""
    undo()

def redo_button_clicked(b):
    """Click handler for the Redo button; the button argument is ignored."""
    redo()

def retry_button_clicked(b):
    """Click handler for the Retry button: undo the latest generation, then
    generate a fresh one in its place. The button argument is ignored."""
    undo()
    send()

def memory_button_clicked(b):
    """Observer for the Memory toggle.

    Toggled on: stash the story text, show the memory text for editing and
    disable every action button. Toggled off: store the edited memory,
    restore the story text and re-derive the button states from the history.
    """
    global listen_for_updates, cur_outputs, cur_outputs_idx, memory_text, \
           input_text

    if memory_button.value:
        # Entering memory mode.
        listen_for_updates = False
        for button in (send_button, undo_button, redo_button, retry_button):
            button.disabled = True
        input_text = input_text_area.value
        input_text_area.value = memory_text
        return

    # Leaving memory mode.
    memory_text = input_text_area.value
    input_text_area.value = input_text
    input_text = ''
    send_button.disabled = False
    nothing_to_undo = cur_outputs_idx < 0
    undo_button.disabled = nothing_to_undo
    redo_button.disabled = cur_outputs_idx >= len(cur_outputs) - 1
    retry_button.disabled = nothing_to_undo
    listen_for_updates = True

# Connect each widget to its handler. The toggle button and the text area are
# observed on their `value` trait rather than on clicks.
send_button.on_click(send_button_clicked)
undo_button.on_click(undo_button_clicked)
redo_button.on_click(redo_button_clicked)
retry_button.on_click(retry_button_clicked)
memory_button.observe(memory_button_clicked, names='value')
input_text_area.observe(on_update_input_text_area, names='value')

# Render the GUI followed by the output widget.
display(hbox, output)

HBox(children=(Textarea(value='', layout=Layout(height='600px', width='1200px'), placeholder='Enter a prompt..…

Output()