In [None]:
#| default_exp model_wrapper

In [None]:
#| export
from dataclasses import dataclass
from typing import Dict, Any, Union
from transformers.models.llama import LlamaModel
from llama_memorizing_transformers.memorizing_block import MemorizingLlamaDecoderLayer
from llama_memorizing_transformers.memory_collection import BaseMemoryCollection, CosineKnnMemoryCollection
from llama_memorizing_transformers.context_choice import BaseContextChoice, ContextChoiceConstant, ContextChoiceLinear

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import os
import time
import torch
from llama_4bit_wrapper import import_llama

In [None]:
#| export
def replace_llama_layer_with_memory(model: LlamaModel,
                                    layer_index: int,
                                    context: BaseContextChoice,
                                    memory: BaseMemoryCollection) -> LlamaModel:
    """Wrap one decoder layer of `model` in a `MemorizingLlamaDecoderLayer`.

    The layer found at `model.layers[layer_index]` becomes the `module` of the
    new memorizing layer, which then takes its place in `model.layers`.

    Args:
        model: Model whose `layers` container is patched in place.
        layer_index: Position in `model.layers` of the layer to wrap.
        context: Context-choice object; `context.to(model.device)` is what gets
            handed to the new layer.
        memory: Memory collection passed to the new layer unchanged.

    Returns:
        The same `model` object, with the layer swapped and the
        `_memorizing_patch` attribute set to `True`.
    """
    # Look the layer up first so a bad index fails before `context` is touched.
    wrapped_layer = model.layers[layer_index]
    context_on_device = context.to(model.device)

    model.layers[layer_index] = MemorizingLlamaDecoderLayer(
        module=wrapped_layer,
        context_choice=context_on_device,
        memory=memory,
    )

    # Marker attribute recording that this model has been patched.
    model._memorizing_patch = True
    return model

In [None]:
# One-time download of the checkpoint; skipped when ../vicuna-13b-GPTQ-4bit-128g
# already exists.
# NOTE(review): the weights are presumably stored via Git LFS, so `git-lfs` may
# need to be installed for the clone to fetch them — confirm.
if not os.path.exists("../vicuna-13b-GPTQ-4bit-128g"):
    # Clone into the current directory, then move the checkout one level up so
    # it ends up at the path tested above.
    !git clone "https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g"
    !mv "vicuna-13b-GPTQ-4bit-128g" ..

In [None]:
# import_llama returns nine values; only three are used in this notebook:
# the low-RAM 4-bit loader (3rd), model_to_half (5th) and AMPWrapper (9th).
# The unpacking is positional, so it must track the helper's return order.
# The flags below request flash attention and the Triton 4-bit autograd
# implementation; xformers and the CUDA implementation are switched off.
_, _, load_llama_model_4bit_low_ram, _, model_to_half, _, _, _, AMPWrapper = import_llama(
    use_flash_attention=True,
    use_xformers=False,
    autograd_4bit_cuda=False,
    autograd_4bit_triton=True,
)

Using Triton implementation.


In [None]:
# Load the model and tokenizer from the checkpoint directory fetched by the
# download cell (config from the directory, weights from the .safetensors file).
# NOTE(review): groupsize=128 presumably corresponds to the "128g" in the
# checkpoint name — confirm before pointing this at a different checkpoint.
model, tokenizer = load_llama_model_4bit_low_ram(
    config_path="../vicuna-13b-GPTQ-4bit-128g/",
    model_path="../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors",
    groupsize=128,
    is_v1_model=False,
)

Loading Model ...


The safetensors archive passed at ../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors does not contain metadata. Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata.


Loaded the model in 3.54 seconds.


In [None]:
# Replace decoder layer 21 of the inner model with a memorizing wrapper.
# ContextChoiceLinear is built from the model config (number of attention heads
# and hidden size).
# NOTE(review): the CosineKnnMemoryCollection arguments (1024, 10) are positional
# and their meaning is not visible in this notebook — confirm against
# memory_collection.
# Not idempotent: replace_llama_layer_with_memory wraps whatever currently sits
# at the index, so re-running this cell wraps the already-wrapped layer again.
# Reload the model before running it a second time.
model.model = replace_llama_layer_with_memory(
    model.model,
    21,
    ContextChoiceLinear(model.config.num_attention_heads, model.config.hidden_size),
    CosineKnnMemoryCollection(1024, 10),
)

In [None]:
# The return value is discarded, so this relies on model_to_half modifying
# `model` in place. Presumably it converts to half precision (the helper logs
# "Converted as Half.") — confirm against llama_4bit_wrapper.
model_to_half(model)

Converted as Half.


In [None]:
# NOTE(review): AMPWrapper.apply_generate presumably patches model.generate to
# run under mixed-precision autocast — confirm against llama_4bit_wrapper.
# `wrapper` itself is not referenced again in this notebook.
wrapper = AMPWrapper(model)
wrapper.apply_generate()

In [None]:
# Prompt used by the generation smoke test below.
prompt = "I think the meaning of life is"

# Tokenize without special tokens, then move every returned tensor to the GPU.
encoded_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
batch = {name: tensor.cuda() for name, tensor in encoded_prompt.items()}

In [None]:
def test_generate():
    """Sample one completion for the notebook's prompt and report how long it took.

    Reads the notebook-level globals `model`, `tokenizer` and `batch`.
    Prints the decoded text of the first returned sequence, then the elapsed
    time in seconds (generation plus decoding). Sampling is enabled
    (`do_sample=True`) and no seed is set here, so the text is expected to
    differ between runs.
    """
    # perf_counter is the clock intended for measuring intervals: unlike
    # time.time() it cannot jump when the system clock is adjusted.
    start = time.perf_counter()
    with torch.no_grad():
        generated = model.generate(
            inputs=batch["input_ids"],
            do_sample=True,
            use_cache=False,
            repetition_penalty=1.1,
            max_new_tokens=128,
            temperature=0.9,
            top_p=0.95,
            top_k=40,
            return_dict_in_generate=True,
            output_attentions=False,
            output_hidden_states=False,
            output_scores=False,
        )
    # Only the first sequence of the returned batch is decoded.
    result_text = tokenizer.decode(generated['sequences'].cpu().tolist()[0])
    elapsed = time.perf_counter() - start

    print(result_text)
    print(elapsed)


In [None]:
# Smoke test: prints the sampled continuation followed by the elapsed seconds.
test_generate()

I think the meaning of life is to be found in family, love and creating something positive with it and this is a book show it can be to be about that relationship that problem they whole thing, most concept idea story potential ending has to be biggest issue it story biggest challenge best song scene should have two problems solution first one I think you and movie best reason public and term end of the debate most t answer is to question new yok me next point will main thing could president last word movie same.
 R most important part most big the biggest idea message of the character story could world should be the future in world most people are more most key word point is probably he
45.534175634384155


In [None]:
#| hide
# Run nbdev's export step for this project.
import nbdev

nbdev.nbdev_export()