In [None]:
import os
import sys
os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/"
import torch
from torch.nn import DataParallel
from utils.prompter import Prompter
from time import time
from time import perf_counter
from peft import PeftModel

from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig, BitsAndBytesConfig

In [None]:
device_map = {'model.embed_tokens': 0,
 'model.layers.0': 0,
 'model.layers.1': 0,
 'model.layers.2': 0,
 'model.layers.3': 0,
 'model.layers.4': 0,
 'model.layers.5': 0,
 'model.layers.6': 0,
 'model.layers.7': 0,
 'model.layers.8': 0,
 'model.layers.9': 0,
 'model.layers.10': 0,
 'model.layers.11': 0,
 'model.layers.12': 0,
 'model.layers.13': 0,
 'model.layers.14': 0,
 'model.layers.15': 0,
 'model.layers.16': 0,
 'model.layers.17': 0,
 'model.layers.18': 0,
 'model.layers.19': 0,
 'model.layers.20': 0,
 'model.layers.21': 0,
 'model.layers.22': 0,
 'model.layers.23': 0,
 'model.layers.24': 0,
 'model.layers.25': 0,
 'model.layers.26': 0,
 'model.layers.27': 0,
 'model.layers.28': 0,
 'model.layers.29': 0,
 'model.layers.30': 1,
 'model.layers.31': 1,
 'model.layers.32': 1,
 'model.layers.33': 1,
 'model.layers.34': 1,
 'model.layers.35': 1,
 'model.layers.36': 1,
 'model.layers.37': 1,
 'model.layers.38': 1,
 'model.layers.39': 1,
 'model.layers.40': 1,
 'model.layers.41': 1,
 'model.layers.42': 1,
 'model.layers.43': 1,
 'model.layers.44': 1,
 'model.layers.45': 1,
 'model.layers.46': 1,
 'model.layers.47': 1,
 'model.layers.48': 1,
 'model.layers.49': 1,
 'model.layers.50': 1,
 'model.layers.51': 1,
 'model.layers.52': 1,
 'model.layers.53': 1,
 'model.layers.54': 1,
 'model.layers.55': 1,
 'model.layers.56': 1,
 'model.layers.57': 1,
 'model.layers.58': 1,
 'model.layers.59': 1,
 'model.layers.60': 1,
 'model.layers.61': 1,
 'model.layers.62': 1,
 'model.layers.63': 1,
 'model.layers.64': 1,
 'model.layers.65': 1,
 'model.layers.66': 1,
 'model.layers.67': 1,
 'model.layers.68': 1,
 'model.layers.69': 1,
 'model.layers.70': 1,
 'model.layers.71': 1,
 'model.layers.72': 1,
 'model.layers.73': 1,
 'model.layers.74': 1,
 'model.layers.75': 1,
 'model.layers.76': 1,
 'model.layers.77': 1,
 'model.layers.78': 1,
 'model.layers.79': 1,
 'model.norm': 1,
 'lm_head': 1}

In [None]:
# Load the LLaMA tokenizer from the Guanaco-65B merged checkpoint.
# NOTE(review): the model weights are loaded from a different repo
# (decapoda-research/llama-7b-hf); confirm the vocabularies/special tokens match.
TOKENIZER_REPO = "timdettmers/guanaco-65b-merged"
tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_REPO)

In [None]:
# Load the base LLaMA weights quantised to 4-bit NF4 (double quantisation,
# bfloat16 compute) via bitsandbytes, letting accelerate place layers
# automatically across available devices.
#
# The 4-bit switch is given only inside `quantization_config`: passing
# `load_in_4bit=True` as a separate kwarg as well is redundant, and recent
# transformers releases raise a ValueError when both are supplied.
# NOTE(review): the tokenizer comes from timdettmers/guanaco-65b-merged while
# the weights come from this repo — confirm the pairing is intended.
MODEL_ID = "decapoda-research/llama-7b-hf"

model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # the manual `device_map` dict defined earlier is not used
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    ),
)

# Optional: apply the Guanaco-7B LoRA adapter on top of the base model.
# model = PeftModel.from_pretrained(
#     model,
#     "timdettmers/guanaco-7b",
#     torch_dtype=torch.float16,
# )

In [None]:
# Put the model in inference mode, then wrap it with torch.compile when
# running PyTorch 2.x on a non-Windows platform.
# NOTE(review): compiling a bitsandbytes 4-bit model may not speed it up, and the
# first forward pass after compiling includes compilation time — benchmark both ways.
model.eval()
_compile_supported = sys.platform != "win32" and torch.__version__ >= "2"
if _compile_supported:
    model = torch.compile(model)

In [None]:

#prompter = Prompter("sum")

In [None]:
# Build the summarisation prompt. The French instruction asks the model to
# summarise a law-course excerpt while keeping dates, abbreviations and key
# principles.
#
# `prompter` is created here because the only other `Prompter(...)` line in the
# notebook is commented out, which made `generate_prompt` fail with NameError
# on a fresh kernel. Presumably "sum" names a template known to
# utils.prompter.Prompter — confirm it exists.
prompter = Prompter("sum")

instruction = "Résume ce texte issue d'un cours de droit en conservant les dates, les abréviations et les principes importants."
# Named `course_text` (not `input`) to avoid shadowing the built-in input().
course_text = """Il arrive également que la garde juridique soit la conséquence de la loi. C’est notamment le cas des tuteurs chargés de gérer le mode de vie de l’enfant placé sous tutelle :
Chambre criminelle, 28 mars 2000 : Un enfant de 14 ans est placé sous la tutelle de son beau-père après avoir perdu ses deux parents et en manipulant une arme l’enfant cause la mort d’un camarade. La cour d’appel écarte la faute de surveillance du beau-père mais retient sa responsabilité sur le fondement de l’article 1242 alinéa 1er en sa qualité de tuteur. La Cour de cassation rejette le pourvoi."""
prompt = prompter.generate_prompt(instruction, course_text)

In [None]:
# Tokenise the benchmark input and move its token ids onto the CUDA device.
# NOTE(review): this encodes the literal "They", not the `prompt` built above —
# set BENCH_TEXT = prompt to benchmark the summarisation request instead.
BENCH_TEXT = "They"
inputs = tokenizer(BENCH_TEXT, return_tensors="pt")
input_ids = inputs["input_ids"].to(device="cuda")

In [None]:
# Number of threads PyTorch uses for intra-op parallelism on CPU.
NUM_TORCH_THREADS = 14
torch.set_num_threads(NUM_TORCH_THREADS)

In [None]:
# Benchmark generation of up to 100 new tokens and report throughput.
#
# `generation_config` is kept for reference but is NOT passed to generate().
# Note that temperature/top_p only take effect when sampling (do_sample=True),
# and use_cache=False would disable the KV cache and slow generation down.
generation_config = GenerationConfig(
    temperature=0.1,
    top_p=0.1,
    use_cache=False
)

# CUDA work is asynchronous: synchronise around the timed region so the
# measurement covers exactly the generate() call. perf_counter is a monotonic,
# high-resolution clock suited to benchmarking (unlike time.time()).
torch.cuda.synchronize()
gen_start = perf_counter()
with torch.no_grad():
    generation_output = model.generate(
        input_ids=input_ids,
        # generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=100,
    )
torch.cuda.synchronize()
duration = perf_counter() - gen_start

# With return_dict_in_generate=True the result exposes `.sequences`; fall back to
# plain indexing if a bare tensor is returned. Only AttributeError is caught so
# unrelated errors are not silently swallowed.
try:
    sequence = generation_output.sequences[0]
except AttributeError:
    sequence = generation_output[0]

# The returned sequence contains the prompt tokens followed by the new ones.
new_tokens = sequence.shape[0] - input_ids.shape[1]
tokens_per_s = new_tokens / duration
print(f"{tokens_per_s} tokens/s")
print(f"{1 / tokens_per_s} s/token")  # inverse of throughput: latency per token
output = tokenizer.decode(sequence)
print(output)
# res = prompter.get_response(output)
# print(res)

In [None]:
# Time a single forward pass over `input_ids` and report peak GPU memory.
# - no_grad(): inference only, so do not build an autograd graph.
# - synchronize(): CUDA kernels run asynchronously, so without synchronising
#   before each clock read the timer would mostly measure kernel launches.
# NOTE(review): if torch.compile was applied, the first call may include compilation time.
torch.cuda.synchronize()
t1_start = perf_counter()
with torch.no_grad():
    # Keep only the logits at the last position (presumably (batch, seq, vocab)).
    logits = model(input_ids).logits[:, -1, :]
torch.cuda.synchronize()
t1_stop = perf_counter()
print(f"{1/(t1_stop - t1_start)} tokens/s")
print(f"{torch.cuda.max_memory_allocated() / 1e9} GB peak allocated")

In [None]:
# Largest value among the last-position logits (quick sanity check of the forward pass).
torch.max(logits)

In [None]:
# Decode the generated sequence without special tokens. The batch contains a
# single input, so the only valid index is 0 (index 1 raised IndexError).
tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

In [None]:
# Size of the generated token-id tensor (presumably (batch, prompt + new tokens)).
generation_output.sequences.size()