In [1]:
#| default_exp model_wrapper

In [2]:
#| export
from dataclasses import dataclass
from typing import Dict, Any, Union
from transformers.models.llama import LlamaModel
from llama_memorizing_transformers.memorizing_block import MemorizingLlamaDecoderLayer
from llama_memorizing_transformers.memory_collection import BaseMemoryCollection, CosineKnnMemoryCollection
from llama_memorizing_transformers.context_choice import BaseContextChoice, ContextChoiceConstant, ContextChoiceLinear

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import os
import time
import torch
from llama_4bit_wrapper import import_llama

In [4]:
#| export
def replace_llama_layer_with_memory(model: LlamaModel,
                                    layer_index: int,
                                    context: BaseContextChoice,
                                    memory: BaseMemoryCollection) -> LlamaModel:
    """Swap one decoder layer of ``model`` for a memorizing wrapper, in place.

    The layer currently stored at ``model.layers[layer_index]`` is wrapped in
    a ``MemorizingLlamaDecoderLayer`` and the wrapper is written back to the
    same position. Whatever sits at that index is wrapped, so calling this
    twice with the same index wraps the previous wrapper.

    Args:
        model: Model whose ``layers`` container is modified.
        layer_index: Position in ``model.layers`` of the layer to wrap.
        context: Context-choice object; the result of
            ``context.to(model.device)`` is what the wrapper receives.
        memory: Memory collection handed to the wrapper as-is.

    Returns:
        The same ``model`` object, with ``_memorizing_patch`` set to ``True``.
    """
    target_device = model.device
    model.layers[layer_index] = MemorizingLlamaDecoderLayer(
        module=model.layers[layer_index],
        context_choice=context.to(target_device),
        memory=memory,
        device=target_device,
    )
    # Marker attribute recording that this model has been patched.
    model._memorizing_patch = True
    return model

In [5]:
if not os.path.exists("../vicuna-13b-GPTQ-4bit-128g"):
    !git clone "https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g"
    !mv "vicuna-13b-GPTQ-4bit-128g" ..

In [6]:
# Backend switches forwarded to `import_llama`.
llama_backend_options = dict(
    use_flash_attention=True,
    use_xformers=False,
    autograd_4bit_cuda=False,
    autograd_4bit_triton=True,
)

# Nine values come back; the loader, the half-precision converter and the AMP
# wrapper are kept, the remaining positions are discarded into `_`.
(
    _, _,
    load_llama_model_4bit_low_ram,
    _,
    model_to_half,
    _, _, _,
    AMPWrapper,
) = import_llama(**llama_backend_options)

Using Triton implementation.


In [7]:
# Checkpoint directory (passed as the config path) and the weights file in it.
checkpoint_dir = "../vicuna-13b-GPTQ-4bit-128g/"
checkpoint_weights = "../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors"

loader_kwargs = {
    "config_path": checkpoint_dir,
    "model_path": checkpoint_weights,
    "groupsize": 128,
    "is_v1_model": False,
}
model, tokenizer = load_llama_model_4bit_low_ram(**loader_kwargs)

# Move the model to the GPU; as the last expression, the returned module is
# echoed as the cell output.
model.cuda()

Loading Model ...


The safetensors archive passed at ../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors does not contain metadata. Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata.


Loaded the model in 3.36 seconds.


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32001, 5120, padding_idx=0)
    (layers): ModuleList(
      (0-39): 40 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Autograd4bitQuantLinear()
          (k_proj): Autograd4bitQuantLinear()
          (v_proj): Autograd4bitQuantLinear()
          (o_proj): Autograd4bitQuantLinear()
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Autograd4bitQuantLinear()
          (down_proj): Autograd4bitQuantLinear()
          (up_proj): Autograd4bitQuantLinear()
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=5120, out_features=32001, bias=False)
)

In [8]:
# Index of the decoder layer that receives the memory wrapper.
MEMORY_LAYER_INDEX = 21

# `replace_llama_layer_with_memory` wraps whatever currently sits at the given
# index, so running this cell a second time would wrap the wrapper again.
# The helper marks a patched model with `_memorizing_patch`; checking that flag
# keeps the cell safe to re-run.
if not getattr(model.model, "_memorizing_patch", False):
    model.model = replace_llama_layer_with_memory(
        model.model,
        MEMORY_LAYER_INDEX,
        ContextChoiceLinear(model.config.num_attention_heads, model.config.hidden_size),
        # NOTE(review): the meaning of (1024, 10) is defined by
        # CosineKnnMemoryCollection — confirm against its signature.
        CosineKnnMemoryCollection(1024, 10),
    )

In [9]:
# Sanity check: show which device the inner LlamaModel reports after patching.
model.model.device

device(type='cuda', index=0)

In [10]:
# Sanity check: show the device of the `scaler` attribute on the replaced
# layer 21, to compare against the model's device.
model.model.layers[21].scaler.device

device(type='cuda', index=0)

In [11]:
# Convert the model with the `model_to_half` helper returned by `import_llama`.
# In this notebook it runs after layer 21 has been replaced, so it is applied
# to the patched model.
model_to_half(model)

Converted as Half.


In [12]:
# Wrap the model with the AMPWrapper class returned by `import_llama` and call
# its `apply_generate()` method.
# NOTE(review): presumably this makes `model.generate` run under autocast —
# confirm in llama_4bit_wrapper.
wrapper = AMPWrapper(model)
wrapper.apply_generate()

In [13]:
# Prompt to continue; tokenized without special tokens.
prompt = "I think the meaning of life is"
encoded_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

# Plain dict of the tokenizer's outputs, each moved to the GPU.
batch = {field: tensor.cuda() for field, tensor in encoded_prompt.items()}

In [14]:
def test_generate():
    """Sample a continuation of the prompt in ``batch`` and report the time taken.

    Reads the notebook-level ``model``, ``tokenizer`` and ``batch``. Prints the
    decoded first generated sequence, then the elapsed seconds (generation plus
    decoding). Returns nothing.

    Sampling is enabled (``do_sample=True``) and no seed is set here, so the
    generated text differs between runs.
    """
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; a difference of time.time() values can be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    with torch.no_grad():
        generated = model.generate(
            inputs=batch["input_ids"],
            do_sample=True,
            use_cache=False,
            repetition_penalty=1.1,
            max_new_tokens=128,
            temperature=0.9,
            top_p=0.95,
            top_k=40,
            return_dict_in_generate=True,
            output_attentions=False,
            output_hidden_states=False,
            output_scores=False,
        )
    # Copying the sequences to the CPU happens before the clock is stopped.
    result_text = tokenizer.decode(generated['sequences'].cpu().tolist()[0])
    end = time.perf_counter()

    print(result_text)
    print(end - start)


In [15]:
# Run one sampled generation; prints the generated text and the elapsed seconds.
test_generate()

I think the meaning of life is of:, forO01,2 of4A andA2 O of
i 
DR
O ( and and
.o, of
 Be,(C with A SX, with of
O X or of
LL, of with for for of
 Cl for of of with, and and
 of) for of
 for for
 for
 on for for for BeO D for and for and ofO from and
 of with if for of for
 and fromfor of of for of forX of forS for withCT for for for
 for for for for of if of for
44.89495253562927


In [16]:
#| hide
# Run nbdev's export step from inside the notebook.
# NOTE(review): presumably this writes the cells marked `#| export` to the
# module named by `#| default_exp` — confirm against the nbdev version in use.
import nbdev; nbdev.nbdev_export()