<a href="https://colab.research.google.com/github/EvolventaAGG/GPTQ-for-LLaMa/blob/main/Copy_of_alpaca_4bit_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

First, install the CUDA extensions.

In [None]:
# Clone GPTQ-for-LLaMa at a fixed commit and build/install its CUDA extension.
# Lines starting with "#!" are shell commands that have been left disabled.
#!apt-get -y update
#!apt-get -y install python3.10-dev
#!python -m pip install --upgrade pip
!git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git
%cd 'GPTQ-for-LLaMa'
# Pin the checkout to one specific commit instead of whatever the default
# branch currently points at.
!git reset --hard 468c47c01b4fe370616747b6d69a2d3f48bab5e4
!python setup_cuda.py install
#!python test_kernel.py

Next, restart the runtime (but don't delete it). We'll need to do that in order for colab to be able to use the quant_cuda CPP extensions.

Afterward, return to this cell and execute it to clone the repo, install libraries and download your 4-bit LLaMA model of choice.

In [None]:
import sys
import torch
import quant_cuda

# Install the Python packages used by the later cells.
!pip install transformers
!pip install sentencepiece
# Colab form fields: URL of the 4-bit checkpoint to download, and the model size.
# NOTE(review): num_params is not referenced anywhere else in the visible
# notebook — confirm whether it is still needed.
weights_url = 'https://huggingface.co/elinas/alpaca-13b-lora-int4/resolve/main/alpaca-13b-4bit.pt' #@param {type:"string"}
num_params = "13b" #@param ["7b", "13b", "30b", "65b"]
# Fetch the checkpoint file with wget.
!wget {weights_url}
# NOTE(review): this installs transformers from the GitHub repository after the
# plain `pip install transformers` above — presumably to get a newer build;
# confirm whether the first install is still required.
!pip install git+https://github.com/huggingface/transformers

# Put the cloned repository first on the import path; presumably this is what
# lets a later cell import its `gptq`, `modelutils` and `quant` modules.
sys.path.insert(0, '/content/GPTQ-for-LLaMa/')
#!CUDA_VISIBLE_DEVICES=0 python llama_inference.py decapoda-research/llama-13b-hf --wbits 4 --load llama-13b-4bit.pt --text "It was the best of times, it was the worst of times"

Now execute this cell in order to load in the model. Additionally, you can specify your context size (if you're free tier and running 13B, you'll have to keep this pretty low or you may either run out of memory or have ridiculously slow generation times) and a flag denoting whether to load and split the model checkpoint in GPU VRAM before loading (also needed for free tier 13B).

In [None]:
import time

import torch
import torch.nn as nn

from gptq import *
from modelutils import *
from quant import *

from transformers import LlamaTokenizer

# Device the model and the generation token buffers are moved to.
DEV = torch.device('cuda:0')
#context_size = 1024 #@param {type:"number"}
# Read by load_quant(): when True, the checkpoint is first written out as two
# half-size files which are then loaded into the model one after the other.
split_checkpoint = True #@param {type:"boolean"}

def load_quant(model, checkpoint, wbits):
    """Build a quantized LLaMA model and load its weights from ``checkpoint``.

    Args:
        model: Name or path passed to ``LlamaConfig.from_pretrained``.
        checkpoint: Path of the quantized state-dict file.
        wbits: Bit width forwarded to ``make_quant``.

    Returns:
        The model in eval mode with the checkpoint weights loaded.

    Reads the module-level ``split_checkpoint`` flag: when it is true the
    checkpoint is first re-saved as two halves (``checkpoint + '0'`` and
    ``checkpoint + '1'``) which are then loaded one after the other.

    Side effects: replaces ``torch.nn.init.kaiming_uniform_``, ``uniform_``
    and ``normal_`` with no-ops for the rest of the process, and leaves the
    default dtype set to ``torch.float``.
    """
    # Imported explicitly so the function does not rely on ``transformers``
    # leaking into the namespace through one of the star-imports.
    import transformers
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # The weights are overwritten by the checkpoint, so random initialisation
    # would be wasted work.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    # NOTE: torch.load unpickles the file; only use checkpoints from a source
    # you trust.
    if split_checkpoint:
        print('Splitting checkpoint ...')
        ckpt = torch.load(checkpoint, map_location='cuda')

        entries = list(ckpt.items())
        half = len(entries) // 2
        torch.save(dict(entries[:half]), checkpoint + '0')
        torch.save(dict(entries[half:]), checkpoint + '1')

        # Drop every reference to the full checkpoint before the model is built.
        del entries
        del ckpt

    transformers.modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    try:
        model = LlamaForCausalLM(config)
    finally:
        # Restore the float default even if model construction fails.
        torch.set_default_dtype(torch.float)
    model = model.eval()

    layers = find_layers(model)
    # The output head is kept unquantized.
    layers.pop('lm_head', None)
    make_quant(model, layers, wbits)

    print('Loading model ...')
    if split_checkpoint:
        for i in range(2):
            ckpt = torch.load(checkpoint + str(i))
            # Each file holds only half of the keys, hence strict=False.
            model.load_state_dict(ckpt, strict=False)
            del ckpt
    else:
        # Load once and reuse; loading twice doubles the peak memory use.
        ckpt = torch.load(checkpoint)
        model.load_state_dict(ckpt)
        del ckpt
    print('Done.')

    #model.seqlen = context_size
    return model

# Build the 4-bit model from the checkpoint file and move it to the GPU.
# NOTE(review): the checkpoint file name is hard-coded here rather than derived
# from weights_url, and .cuda() followed by .to(DEV) looks redundant when DEV
# is cuda:0 — confirm.
model = load_quant('elinas/alpaca-13b-lora-int4','alpaca-13b-4bit.pt', 4).cuda()
model.to(DEV)
tokenizer = LlamaTokenizer.from_pretrained('elinas/alpaca-13b-lora-int4')

Define our token generation functions (both normal and generator).

In [None]:
#@title Samples
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def gen_next_tokens(model, tokenizer, tokenized, context_len, max_gen_len,
                    mask_id, temperature=0.8, top_p=0.95, tfs=1.0, typical=1.0,
                    penalty_range=1024, penalty_slope=0.7, penalty=1.1,
                    past_key_values=None):
    """Generate a continuation of ``tokenized`` and return it in one piece.

    Args:
        model: Causal LM; called with ``use_cache=True`` and expected to
            return an object with ``logits`` and ``past_key_values``.
        tokenizer: Only ``eos_token_id`` is read.
        tokenized: Prompt token ids; row 0 is used.
        context_len: Upper bound on prompt plus generated tokens.
        max_gen_len: Maximum number of tokens to generate.
        mask_id: Filler id for not-yet-generated slots of the buffer.
        temperature: Sampling temperature; ``0`` (or less) selects greedy
            decoding and skips every sampler.
        top_p, tfs, typical: Thresholds for the respective samplers.
        penalty_range, penalty_slope, penalty: Repetition-penalty settings.
        past_key_values: Optional cache to start from.

    Returns:
        ``(tokens, cache)``. ``tokens`` holds the prompt followed by the
        generated ids; when EOS is produced, the result is cut off just
        before the EOS token. ``cache`` is the ``past_key_values`` argument
        if one was supplied, otherwise the cache produced by the first
        forward pass.
    """
    prompt_len = tokenized.shape[1]
    total_len = min(context_len, max_gen_len + prompt_len)

    tokens = torch.full((1, total_len), mask_id).to(DEV)
    tokens[0, :prompt_len] = tokenized[0]

    # Bound before the loop so the return value is defined even when there is
    # no room to generate anything (the loop body never runs).
    output_past_key_values = past_key_values

    for cur_id in range(prompt_len, total_len):
        if past_key_values:
            # NOTE(review): on the first step with a caller-supplied cache this
            # feeds the last prompt token on top of that cache. If the cache
            # already covers the whole prompt (as the one returned by a
            # cache-less call appears to), that token is processed twice —
            # confirm, and trim the cache by one position if so.
            output = model(tokens[:, cur_id-1:cur_id], past_key_values=past_key_values, use_cache=True)
            logits = output.logits[:, 0, :]
        else:
            output = model(tokens[:, :cur_id], use_cache=True)
            logits = output.logits[:, cur_id-1, :]
            output_past_key_values = output.past_key_values

        past_key_values = output.past_key_values
        # Everything produced so far. The repetition penalty needs the whole
        # history; passing only the last token made penalty_range meaningless.
        input_ids = tokens[:, :cur_id]

        # Apply samplers - do greedy sampling if temperature is 0.
        if temperature > 0:
            next_token_scores = sample_top_p_actual(input_ids, logits,
                                                    top_p)
            next_token_scores = sample_tail_free(input_ids,
                                                 next_token_scores, tfs)
            next_token_scores = sample_typical(input_ids, next_token_scores,
                                               typical)
            next_token_scores = sample_temperature(input_ids,
                                                   next_token_scores,
                                                   temperature)
            next_token_scores = sample_advanced_repetition_penalty(input_ids,
                                                                   next_token_scores,
                                                                   penalty_range,
                                                                   penalty_slope,
                                                                   penalty)

            next_token_scores = torch.nn.functional.softmax(next_token_scores,
                                                            dim=-1)

            next_token = torch.multinomial(next_token_scores,
                                           num_samples=1).squeeze(1)
        else:
            next_token = torch.argmax(logits, axis=-1)[0]

        tokens[0, cur_id] = next_token
        if next_token.item() == tokenizer.eos_token_id:
            return tokens[:, :cur_id], output_past_key_values

    return tokens, output_past_key_values

def stm_next_tokens(model, tokenizer, tokenized, context_len, max_gen_len,
                    mask_id, temperature=0.8, top_p=0.95, tfs=1.0, typical=1.0,
                    penalty_range=1024, penalty_slope=0.7, penalty=1.1,
                    past_key_values=None):
    """Generate a continuation of ``tokenized``, yielding tokens as they come.

    Args:
        model: Causal LM; called with ``use_cache=True`` and expected to
            return an object with ``logits`` and ``past_key_values``.
        tokenizer: Only ``eos_token_id`` is read.
        tokenized: Prompt token ids; row 0 is used.
        context_len: Upper bound on prompt plus generated tokens.
        max_gen_len: Maximum number of tokens to generate.
        mask_id: Filler id for not-yet-generated slots of the buffer.
        temperature: Sampling temperature; ``0`` (or less) selects greedy
            decoding and skips every sampler.
        top_p, tfs, typical: Thresholds for the respective samplers.
        penalty_range, penalty_slope, penalty: Repetition-penalty settings.
        past_key_values: Optional cache to start from.

    Yields:
        ``(token, None)`` for every generated token — including the EOS
        token when one is produced — followed by exactly one terminal
        ``(None, cache)`` item. ``cache`` is the ``past_key_values`` argument
        if one was supplied, otherwise the cache produced by the first
        forward pass.
    """
    prompt_len = tokenized.shape[1]
    total_len = min(context_len, max_gen_len + prompt_len)

    tokens = torch.full((1, total_len), mask_id).to(DEV)
    tokens[0, :prompt_len] = tokenized[0]

    # Bound before the loop so the terminal item is defined even when there is
    # no room to generate anything (the loop body never runs).
    output_past_key_values = past_key_values

    for cur_id in range(prompt_len, total_len):
        if past_key_values:
            # NOTE(review): on the first step with a caller-supplied cache this
            # feeds the last prompt token on top of that cache. If the cache
            # already covers the whole prompt (as the one returned by a
            # cache-less call appears to), that token is processed twice —
            # confirm, and trim the cache by one position if so.
            output = model(tokens[:, cur_id-1:cur_id], past_key_values=past_key_values, use_cache=True)
            logits = output.logits[:, 0, :]
        else:
            output = model(tokens[:, :cur_id], use_cache=True)
            logits = output.logits[:, cur_id-1, :]
            output_past_key_values = output.past_key_values

        past_key_values = output.past_key_values
        # Everything produced so far. The repetition penalty needs the whole
        # history; passing only the last token made penalty_range meaningless.
        input_ids = tokens[:, :cur_id]

        # Apply samplers - do greedy sampling if temperature is 0.
        if temperature > 0:
            next_token_scores = sample_top_p_actual(input_ids, logits,
                                                    top_p)
            next_token_scores = sample_tail_free(input_ids,
                                                 next_token_scores, tfs)
            next_token_scores = sample_typical(input_ids, next_token_scores,
                                               typical)
            next_token_scores = sample_temperature(input_ids,
                                                   next_token_scores,
                                                   temperature)
            next_token_scores = sample_advanced_repetition_penalty(input_ids,
                                                                   next_token_scores,
                                                                   penalty_range,
                                                                   penalty_slope,
                                                                   penalty)

            next_token_scores = torch.nn.functional.softmax(next_token_scores,
                                                            dim=-1)

            next_token = torch.multinomial(next_token_scores,
                                           num_samples=1).squeeze(1)
        else:
            next_token = torch.argmax(logits, axis=-1)[0]

        tokens[0, cur_id] = next_token
        yield next_token, None

        if next_token.item() == tokenizer.eos_token_id:
            break

    # Terminal item: hands the cache back to the consumer.
    yield None, output_past_key_values

# taken from Kobold and transformers so this stuff is AGPL I guess
def sample_temperature(input_ids, scores, tempt):
    """Return ``scores`` divided by the temperature ``tempt``.

    ``input_ids`` is accepted only so the signature matches the other
    samplers; it is not used.
    """
    return scores / tempt

def sample_typical(input_ids, scores, typical, filter_value = -float("Inf"),
                   min_tokens_to_keep = 1):
    """Typical sampling: keep the tokens closest to the expected information.

    Tokens are ranked by how far their log-probability is from the negative
    entropy of the distribution; the closest ones are kept until their
    cumulative probability reaches ``typical``, the rest are set to
    ``filter_value``.

    Args:
        input_ids: Unused; present for signature parity with other samplers.
        scores: 2-D tensor of logits (rows are scattered along dim 1).
        typical: Probability mass to keep; ``>= 1.0`` disables the filter.
        filter_value: Value written to removed positions.
        min_tokens_to_keep: Minimum number of tokens kept (at least 1).

    Returns:
        The filtered scores, or ``scores`` itself when the filter is disabled.
    """
    # The threshold to test is ``typical``; the previous check looked at
    # ``filter_value`` and therefore never fired.
    if typical >= 1.0:
        return scores

    probs = scores.softmax(dim=-1)
    log_probs = probs.log()

    # nansum: zero-probability entries give 0 * -inf = nan and must not count.
    neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)

    entropy_deviation = (neg_entropy - log_probs).abs()

    _, sorted_indices = torch.sort(entropy_deviation)
    sorted_probs = probs.gather(-1, sorted_indices)
    sorted_indices_to_remove = sorted_probs.cumsum(dim=-1) >= typical
    # Shift by one so the token that crosses the threshold is itself kept; the
    # wrapped-around first slot is cleared by the keep rule below.
    sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)

    min_tokens_to_keep = max(min_tokens_to_keep, 1)
    # Keep at least min_tokens_to_keep
    sorted_indices_to_remove[..., : min_tokens_to_keep] = 0

    # Map the mask from sorted order back to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores

def sample_top_p_actual(input_ids, scores, top_p, filter_value = -float("Inf"),
                        min_tokens_to_keep = 1):
    """Nucleus (top-p) filtering of ``scores``.

    Removes the least likely tokens whose cumulative probability is at most
    ``1 - top_p`` by setting them to ``filter_value``.

    Args:
        input_ids: Unused; present for signature parity with other samplers.
        scores: 2-D tensor of logits (rows are scattered along dim 1).
        top_p: Probability mass to keep.
        filter_value: Value written to removed positions.
        min_tokens_to_keep: Number of most likely tokens that are never
            removed (at least 1).

    Returns:
        A filtered copy of ``scores``.
    """
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)

    # Always protect the most likely tokens (they sit at the end of the
    # ascending sort). Without this, top_p == 0 would mask every token and the
    # following softmax would yield NaNs.
    keep = max(min_tokens_to_keep, 1)
    sorted_indices_to_remove[..., -keep:] = 0

    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,
                                                         sorted_indices_to_remove)
    scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores

def sample_advanced_repetition_penalty(input_ids, scores, penalty_range,
                                       penalty_slope, penalty):
    """Penalise the scores of tokens that already occur in ``input_ids``.

    Positive scores are divided by the penalty and non-positive scores are
    multiplied by it. With a non-zero ``penalty_slope`` the penalty is ramped
    from 1 (oldest position in the range) up to ``penalty`` (newest position).

    Args:
        input_ids: 2-D tensor of previously seen token ids.
        scores: 2-D tensor of logits; **modified in place**.
        penalty_range: Number of most recent tokens to penalise; ``0`` (or
            less) means all of ``input_ids`` with a flat penalty.
        penalty_slope: Shape of the ramp; ``0`` means a flat penalty.
        penalty: Penalty strength; ``1.0`` disables the sampler.

    Returns:
        ``scores`` (the same tensor object).
    """
    # Previously the return statement lived inside the ``penalty != 1.0``
    # branch, so a disabled penalty returned None instead of the scores.
    if penalty == 1.0:
        return scores

    penalty_range = int(penalty_range)
    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)

    if penalty_range > 0:
        if clipped_penalty_range < input_ids.shape[1]:
            input_ids = input_ids[..., -clipped_penalty_range:]

        # A one-token range has no ramp to speak of, and computing one would
        # divide by ``penalty_range - 1 == 0``; use the flat penalty instead.
        if penalty_slope != 0 and penalty_range > 1:
            _penalty = (torch.arange(penalty_range, dtype=scores.dtype,
                                     device=scores.device)/(penalty_range - 1)) * 2. - 1
            _penalty = (penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (penalty_slope - 1))
            _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
            # Align the ramp with the (possibly shorter) token window.
            penalty = _penalty[..., -clipped_penalty_range:]

    score = torch.gather(scores, 1, input_ids)
    score = torch.where(score <= 0, score * penalty, score / penalty)
    scores.scatter_(1, input_ids, score)

    return scores

def sample_top_a(input_ids, scores, top_a, filter_value = -float("Inf"),
                 min_tokens_to_keep = 1):
    """Top-a filtering of ``scores``.

    Removes every token whose probability is below
    ``top_a * max_probability ** 2`` by setting it to ``filter_value``.

    Args:
        input_ids: Unused; present for signature parity with other samplers.
        scores: 2-D tensor of logits (rows are scattered along dim 1).
        top_a: Filter strength; ``<= 0`` disables the filter.
        filter_value: Value written to removed positions.
        min_tokens_to_keep: Number of most likely tokens that are never
            removed (at least 1).

    Returns:
        The filtered scores, or ``scores`` itself when the filter is disabled.
    """
    # The threshold to test is ``top_a``; the previous check looked at
    # ``filter_value`` and therefore never fired.
    if top_a <= 0.0:
        return scores

    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    probs = sorted_logits.softmax(dim=-1)

    # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
    probs_max = probs[..., 0, None]
    sorted_indices_to_remove = probs < probs_max * probs_max * top_a

    # Always protect the most likely tokens so that a large top_a cannot mask
    # the whole distribution.
    keep = max(min_tokens_to_keep, 1)
    sorted_indices_to_remove[..., :keep] = 0

    # Map the mask from sorted order back to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,
                                                         sorted_indices_to_remove)
    scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores

def sample_tail_free(input_ids, scores, tfs, filter_value = -float("Inf"),
                     min_tokens_to_keep = 1):
    """Tail-free sampling: cut the low-probability tail of ``scores``.

    The cut-off is located through the normalised absolute second difference
    of the descending probability curve; tokens past the point where its
    cumulative sum exceeds ``tfs`` are set to ``filter_value``.

    Args:
        input_ids: Unused; present for signature parity with other samplers.
        scores: 2-D tensor of logits (rows are scattered along dim 1).
        tfs: Threshold on the cumulative curve; ``>= 1.0`` disables the filter.
        filter_value: Value written to removed positions.
        min_tokens_to_keep: When greater than 1, that many most likely tokens
            are never removed. The single most likely token is always kept.

    Returns:
        The filtered scores, or ``scores`` itself when the filter is disabled.
    """
    # The threshold to test is ``tfs``; the previous check looked at
    # ``filter_value``, never fired, and so tfs=1.0 still dropped the least
    # likely token (see the trailing ones-column below).
    if tfs >= 1.0:
        return scores
    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    probs = sorted_logits.softmax(dim=-1)

    # Compute second derivative normalized CDF
    d2 = probs.diff().diff().abs()
    normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
    normalized_d2_cdf = normalized_d2.cumsum(dim=-1)

    # Remove tokens with CDF value above the threshold (token with 0 are kept)
    sorted_indices_to_remove = normalized_d2_cdf > tfs

    # Centre the distribution around the cutoff as in the original implementation of the algorithm
    # (the two diffs shortened the row by two; pad with "keep" in front and
    # "remove" at the back to restore the vocabulary length).
    sorted_indices_to_remove = torch.cat(
        (
            torch.zeros(scores.shape[0], 1, dtype=torch.bool,
                        device=scores.device),
            sorted_indices_to_remove,
            torch.ones(scores.shape[0], 1, dtype=torch.bool,
                       device=scores.device),
        ),
        dim=-1,
    )

    if min_tokens_to_keep > 1:
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0

    # Map the mask from sorted order back to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,
                                                         sorted_indices_to_remove)
    scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores

Main GUI.

In [None]:
import ipywidgets as widgets
from IPython.display import display
import time
from enum import Enum

# Colab form fields: context window, generation length and sampler settings.
# They are read as module-level globals by build_context() and generate().
context_size = 1024 #@param {type:"number"}
max_gen_len = 160 #@param {type:"number"}
temperature = 0.7 #@param {type:"number"}
top_p = 0.6 #@param {type:"number"}
tfs = 0.5 #@param {type:"number"}
typical = 0.2 #@param {type:"number"}
penalty_range = 1024 #@param {type:"number"}
penalty_slope = 0.3 #@param {type:"number"}
penalty = 1.1 #@param {type:"number"}
# When True, generate() uses the streaming generator and refreshes the text
# area after every token; otherwise the text is written once at the end.
output_streaming = True #@param {type:"boolean"}

# The single text area holds the prompt, the generated text, and (while the
# corresponding toggle is active) the memory or the built context.
input_text_area = widgets.Textarea(placeholder='Enter a prompt...',
                                   layout=widgets.Layout(width='500px',
                                                         height='600px'))
# The context window is stored on the model object; build_context() and
# generate() read it back as model.seqlen.
model.seqlen = context_size
send_button = widgets.Button(description='Send')
undo_button = widgets.Button(description='Undo')
redo_button = widgets.Button(description='Redo')
retry_button = widgets.Button(description='Retry')
prev_retry_button = widgets.Button(description='Previous Retry')
memory_button = widgets.ToggleButton(description='Memory')
context_button = widgets.ToggleButton(description='Context')

# Layout: text area on the left, a column of buttons on the right.
hbox = widgets.HBox([input_text_area,
                     widgets.VBox([send_button, undo_button, redo_button,
                                  retry_button, prev_retry_button, memory_button,
                                  context_button])])
output = widgets.Output()

# UI modes; update_buttons_visible() enables/disables controls based on them.
Mode = Enum('Mode', ['INPUT', 'MEMORY', 'GENERATING', 'CONTEXT'])

class State:
    """Mutable UI state: current tree node, current mode and text buffers."""

    def __init__(self, pos, mode):
        # Node currently shown, and the mode the UI is in.
        self.pos, self.mode = pos, mode
        # Memory text and the text-area content stashed while another view
        # occupies the text area; both start out empty.
        self.mem = self.saved_input = ''

class Position:
    """One node of the undo/redo/retry tree."""

    def __init__(self):
        # Parent link, child list, and index of the currently selected child
        # (-1 while there is none).
        self.pred, self.succs, self.succ_idx = None, [], -1
        # Text stored for this node and the cache stored alongside it.
        self.text, self.past_key_values = '', None

# Root node of the history tree and the global UI state, starting in INPUT mode.
init_pos = Position()
cur_state = State(init_pos, Mode.INPUT)

def build_context():
    """Assemble the text that is sent to the model.

    The context is the full memory followed by a newline, then as many of the
    most recent input tokens as fit. The token budget is
    ``model.seqlen - 1 - max_gen_len``. If the memory alone uses up the
    budget, only the last ``budget`` memory tokens are kept; if the budget is
    not positive, the context is empty.

    Reads the globals ``cur_state``, ``input_text_area``, ``tokenizer``,
    ``model`` and ``max_gen_len``.

    Returns:
        The context decoded back to a string.
    """
    budget = model.seqlen - 1 - max_gen_len

    if cur_state.mem:
        mem_tokenized = tokenizer.encode(cur_state.mem + '\n', return_tensors='pt')[0].tolist()
    else:
        mem_tokenized = []

    inp_tokenized = tokenizer.encode(input_text_area.value, return_tensors='pt')[0].tolist()
    num_inp_tokens = max(budget - len(mem_tokenized), 0)

    if num_inp_tokens > 0:
        tokenized = mem_tokenized + inp_tokenized[-num_inp_tokens:]
    elif mem_tokenized and budget > 0:
        tokenized = mem_tokenized[-budget:]
    else:
        # No room at all. Slicing with a zero budget (``lst[-0:]``) would have
        # returned the *entire* memory, and a negative one a wrong slice.
        tokenized = []

    return tokenizer.decode(tokenized)

def _generated_suffix(decoded, gen_context):
    """Return the part of ``decoded`` that follows the prompt ``gen_context``."""
    # The extra 1 is carried over unchanged — presumably it accounts for one
    # separator character added by decoding; TODO confirm.
    num_characters = len(decoded) - len(gen_context) - 1
    # ``decoded[-0:]`` would be the whole string and a negative count a wrong
    # slice, so anything that is not strictly positive means "nothing new".
    return decoded[-num_characters:] if num_characters > 0 else ''


def generate():
    """Run the model on the current context and show the result.

    Builds the context, generates up to ``max_gen_len`` tokens with the
    module-level sampler settings, and writes the current node's text plus the
    newly generated text into the text area — after every token when
    ``output_streaming`` is true, otherwise once at the end.

    Returns:
        ``(generated_text, past_key_values)`` where ``past_key_values`` is the
        cache handed back by the token generator.
    """
    gen_context = build_context()
    retokenized = tokenizer.encode(gen_context, return_tensors='pt').to(DEV)

    sampler_settings = dict(temperature=temperature, top_p=top_p, tfs=tfs,
                            typical=typical, penalty_range=penalty_range,
                            penalty_slope=penalty_slope, penalty=penalty,
                            past_key_values=cur_state.pos.past_key_values)

    generated = ''
    past_key_values = None

    with torch.no_grad():
        if output_streaming:
            out_tokens = retokenized[0].tolist()
            stream = stm_next_tokens(model, tokenizer, retokenized,
                                     model.seqlen, max_gen_len, 1,
                                     **sampler_settings)
            for tkn, pkv in stream:
                if tkn is None:
                    # Terminal item: carries the cache instead of a token.
                    past_key_values = pkv
                    continue
                out_tokens.append(tkn.item())
                # Decode the whole sequence each time rather than the single
                # new token, then cut off the prompt.
                decoded = tokenizer.decode(out_tokens)
                generated = _generated_suffix(decoded, gen_context)
                input_text_area.value = cur_state.pos.text + generated
        else:
            output_tokenized, past_key_values = gen_next_tokens(
                model, tokenizer, retokenized, model.seqlen, max_gen_len, 1,
                **sampler_settings)
            decoded = tokenizer.decode(output_tokenized[0].tolist())
            generated = _generated_suffix(decoded, gen_context)
            input_text_area.value = cur_state.pos.text + generated

    return generated, past_key_values

def on_update_input_text_area(change):
    """Observer for the text area: invalidate what an edit makes stale.

    Only acts in INPUT mode, and only when the text now differs from the
    text stored on the current node. In that case the node's successors are
    discarded and its cache is cleared.
    """
    node = cur_state.pos

    if cur_state.mode != Mode.INPUT:
        return
    if not (node.succs or node.past_key_values):
        # Nothing that could be invalidated.
        return
    if node.text == input_text_area.value:
        return

    if node.succs:
        node.succs = []
        node.succ_idx = -1
        update_buttons_visible()
    if node.past_key_values:
        node.past_key_values = None

def send():
    """Generate a continuation of the text area and move to a new node.

    Stores the text area content on the current node, switches to GENERATING
    mode while the model runs, then appends a new successor node holding the
    resulting text, stores the returned cache on the current node, and jumps
    to the successor. The mode is always put back to INPUT afterwards.

    Raises:
        Whatever ``generate()`` raises; the UI is unlocked before the
        exception propagates.
    """
    cur_state.pos.text = input_text_area.value
    cur_state.mode = Mode.GENERATING
    update_buttons_visible()
    try:
        _generation, past_key_values = generate()

        new_succ = Position()
        new_succ.pred = cur_state.pos
        # generate() has already written prompt + continuation to the text area.
        new_succ.text = input_text_area.value
        cur_state.pos.succs.append(new_succ)
        cur_state.pos.succ_idx = len(cur_state.pos.succs) - 1
        cur_state.pos.past_key_values = past_key_values

        jump_to(new_succ)
    finally:
        # Without this, a failure during generation (e.g. running out of GPU
        # memory) would leave the mode at GENERATING and every control
        # disabled for good.
        cur_state.mode = Mode.INPUT
        update_buttons_visible()

def send_button_clicked(b):
    """Click handler for the Send button; delegates to ``send()``.

    ``send()`` stores the text area content on the current node, generates
    text (the mode is GENERATING in the meantime), appends a new successor
    holding the result and jumps to it.

    Action allowed criterion: state.mode == 'input'.
    """
    send()
    
def undo_button_clicked(b):
    """Click handler for the Undo button: jump to the predecessor node.

    Action allowed criterion: state.mode == 'input', state.predecessor != nil.
    """
    parent = cur_state.pos.pred
    jump_to(parent)

def redo_button_clicked(b):
    """Click handler for the Redo button: jump to the selected successor.

    Action allowed criterion: state.mode == 'input', state.successor_list !=
    nil.
    """
    node = cur_state.pos
    jump_to(node.succs[node.succ_idx])

def retry_button_clicked(b):
    """Click handler for the Retry button.

    Goes back to the predecessor. If it has a successor after the currently
    selected one, selects and shows that one; otherwise generates a fresh
    successor via ``send()``.

    Action allowed criterion: state.mode == 'input', state.predecessor != nil.
    """
    jump_to(cur_state.pos.pred)

    parent = cur_state.pos
    next_idx = parent.succ_idx + 1
    if next_idx >= len(parent.succs):
        # No further stored alternative: produce a new one.
        send()
        return

    parent.succ_idx = next_idx
    jump_to(parent.succs[next_idx])

def prev_retry_button_clicked(b):
    """Click handler for the Previous Retry button.

    Goes back to the predecessor, moves its successor selection one step
    back and shows that successor.

    Action allowed criterion: state.mode == 'input', state.predecessor != nil,
    state.predecessor.succ_idx > 0.
    """
    jump_to(cur_state.pos.pred)

    parent = cur_state.pos
    parent.succ_idx = parent.succ_idx - 1
    jump_to(parent.succs[parent.succ_idx])

def memory_button_clicked(b):
    """Observer for the Memory toggle: switch the memory editor on or off.

    From INPUT mode: stash the text area content, show the memory text and
    enter MEMORY mode. From MEMORY mode: store the edited memory, restore the
    stashed text and return to INPUT mode; if the memory text changed, the
    caches of all nodes are cleared first. Any other mode: no effect.

    Action allowed criterion: state.mode == 'input' or state.mode == 'memory'.
    """
    mode = cur_state.mode

    if mode == Mode.INPUT:
        # The mode is switched before the text area is touched, so the text
        # observer (which only acts in INPUT mode) ignores the swap.
        cur_state.mode = Mode.MEMORY
        cur_state.saved_input = input_text_area.value
        input_text_area.value = cur_state.mem
        update_buttons_visible()
        return

    if mode != Mode.MEMORY:
        return

    edited_memory = input_text_area.value
    if cur_state.mem != edited_memory:
        apply_to_all_nodes(delete_past_key_values)
    cur_state.mode = Mode.INPUT
    cur_state.mem = edited_memory
    input_text_area.value = cur_state.saved_input
    update_buttons_visible()

def context_button_clicked(b):
    """Observer for the Context toggle: show or hide the built context.

    From INPUT mode: stash the text area content, show the output of
    ``build_context()`` and enter CONTEXT mode. From CONTEXT mode: restore the
    stashed text and return to INPUT mode. Any other mode: no effect.

    Action allowed criterion: state.mode == 'input' or state.mode == 'context'.
    """
    mode = cur_state.mode

    if mode == Mode.INPUT:
        # The mode is switched before the text area is touched, so the text
        # observer (which only acts in INPUT mode) ignores the swap.
        cur_state.mode = Mode.CONTEXT
        cur_state.saved_input = input_text_area.value
        input_text_area.value = build_context()
        update_buttons_visible()
        return

    if mode != Mode.CONTEXT:
        return

    cur_state.mode = Mode.INPUT
    input_text_area.value = cur_state.saved_input
    update_buttons_visible()

def jump_to(pos):
    """Make ``pos`` the current node and show its text in the text area."""
    # Order matters: the text area is observed by on_update_input_text_area,
    # which compares the new value against cur_state.pos.text. The node must
    # therefore be switched before the text is written.
    cur_state.pos = pos
    input_text_area.value = pos.text
    update_buttons_visible()

def apply_to_all_nodes(fn):
    """Call ``fn(node)`` on every node of the history tree, breadth-first.

    The tree root is found by following ``pred`` links upwards from the
    current node.
    """
    top = cur_state.pos
    while top.pred:
        top = top.pred

    pending = [top]
    while pending:
        node = pending.pop(0)
        # Children are queued before fn runs on the node itself.
        pending.extend(node.succs)
        fn(node)

def delete_past_key_values(pos):
    """Clear the cache stored on the node ``pos``."""
    setattr(pos, 'past_key_values', None)

def update_buttons_visible():
    """Enable or disable every control according to the current state.

    History buttons need INPUT mode plus the relevant neighbour node; the two
    toggles stay usable in INPUT mode and in their own mode; the text area is
    locked while generating and while the context is shown.
    """
    mode = cur_state.mode
    node = cur_state.pos

    in_input = mode == Mode.INPUT
    has_pred = bool(node.pred)
    has_succs = bool(node.succs)

    send_button.disabled = not in_input
    undo_button.disabled = not (in_input and has_pred)
    redo_button.disabled = not (in_input and has_succs)
    retry_button.disabled = not (in_input and has_pred)
    # Only offered when the predecessor has an earlier alternative selected.
    prev_retry_button.disabled = not (in_input and has_pred
                                      and node.pred.succ_idx > 0)
    memory_button.disabled = mode not in (Mode.INPUT, Mode.MEMORY)
    context_button.disabled = mode not in (Mode.INPUT, Mode.CONTEXT)
    input_text_area.disabled = mode in (Mode.GENERATING, Mode.CONTEXT)

# Plain buttons fire click callbacks.
send_button.on_click(send_button_clicked)
undo_button.on_click(undo_button_clicked)
redo_button.on_click(redo_button_clicked)
retry_button.on_click(retry_button_clicked)
prev_retry_button.on_click(prev_retry_button_clicked)
# Toggle buttons and the text area are observed on their 'value' trait, so the
# handlers run on every value change.
memory_button.observe(memory_button_clicked, names='value')
context_button.observe(context_button_clicked, names='value')
input_text_area.observe(on_update_input_text_area, names='value')
# Set the initial enabled/disabled state before anything is shown.
update_buttons_visible()

display(hbox, output)