In [None]:
# Install/upgrade the notebook's dependencies.
# The version specifiers on the last line are quoted: unquoted, the shell parses
# `>` as output redirection, which drops the constraint, hides pip's output and
# leaves stray files named `=0.39.0` / `=0.20.0` in the working directory.
!pip install -U pip && \
  pip install -U datasets evaluate sentencepiece transformers wandb scikit-learn nltk rouge-score accelerate && \
  pip install -U flash-attn --no-build-isolation optimum && \
  pip install -U "bitsandbytes>=0.39.0" "accelerate>=0.20.0"

In [None]:
# Run the project's dizertatie/main.py from inside the repository directory,
# with the repository root on PYTHONPATH. The parentheses run it in a subshell.
!(cd Lucrare-de-dizertatie-2024/ && PYTHONPATH=. python3 dizertatie/main.py)

In [None]:
import sys

# Make the project checkout importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not add
# duplicate entries to sys.path.
PROJECT_ROOT = '/root/Lucrare-de-dizertatie-2024/'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import dizertatie
import pathlib
from dizertatie.configs.common import PROJECT_SEED
from dizertatie.dataset.dataset import DatasetConfig, load

# Directory holding the dataset files inside the project checkout.
DATA_PATH = pathlib.Path('/root/Lucrare-de-dizertatie-2024/data')

# Load the 'RoSent' dataset without subsampling, passing the project-wide seed
# as the shuffle seed.
ro_sent_config = DatasetConfig(shuffle_seed=PROJECT_SEED, subsample_size=None, path=DATA_PATH)
ro_sent = load(ro_sent_config, 'RoSent')

In [None]:
# Display the loaded RoSent dataset object.
ro_sent

## Help links
* https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/discussions/36#65b8d5cf23d948d884d19645
* https://huggingface.co/docs/transformers/perf_infer_gpu_one
* https://huggingface.co/docs/transformers/perf_train_gpu_one
* https://huggingface.co/docs/transformers/perf_train_gpu_many
* https://huggingface.co/docs/transformers/big_models#low-memory-loading
* https://github.com/Hannibal046/Awesome-LLM?tab=readme-ov-file
* https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
* https://discuss.huggingface.co/t/model-inference-on-tokenized-dataset/14820
* https://chat.lmsys.org/

* https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
* https://huggingface.co/openchat/openchat-3.5-0106
* https://huggingface.co/TheBloke/claude2-alpaca-13B-GGUF
* https://huggingface.co/TheBloke/Wizard-Vicuna-30B-Uncensored-GPTQ
* https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha
* https://huggingface.co/state-spaces/mamba-2.8b

* https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig

* https://huggingface.co/docs/transformers/chat_templating
* https://huggingface.co/docs/transformers/perf_infer_gpu_one
* https://huggingface.co/docs/transformers/perf_torch_compile
* https://huggingface.co/docs/transformers/llm_tutorial#generate-text

* https://medium.com/@mayvic/llm-multi-gpu-batch-inference-with-accelerate-edadbef3e239
* https://huggingface.co/docs/accelerate/usage_guides/distributed_inference

In [None]:
# Install/upgrade flash-attn (with build isolation disabled) and optimum.
!pip install -U flash-attn --no-build-isolation optimum

In [None]:
# Install/upgrade the quantization and ONNX tooling.
# The version specifiers are quoted so the shell does not treat `>` as an output
# redirection (which would drop the constraints and create stray `=...` files).
!pip install -U "bitsandbytes>=0.39.0" "accelerate>=0.20.0" optimum onnxruntime onnx

In [None]:
# Model identifier passed to AutoTokenizer / AutoModelForCausalLM.from_pretrained below.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

In [None]:
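# Not supported: abandoned attempt to load the model through optimum's ONNX Runtime
# wrapper. Kept commented out for reference only.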
# import torch
# from optimum.onnxruntime import ORTModelForCausalLM

# model = ORTModelForCausalLM.from_pretrained(
#   MODEL_NAME,
#   export=True,
#   provider="CUDAExecutionProvider",
# )
# model = torch.compile(model)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# padding_side="left" puts pad tokens before the prompt, so generated tokens
# directly follow the prompt text.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")

# device_map="auto" leaves weight placement to the library.
# Options that were tried and left disabled:
#   attn_implementation="flash_attention_2"
#   torch_dtype=torch.bfloat16
model = torch.compile(
    AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
)

In [None]:
# Show host memory usage (free -h) and, if that succeeds, GPU status (nvidia-smi).
!free -h && nvidia-smi

In [None]:
# Display the device reported by the loaded model.
model.device

In [None]:
# Reuse the end-of-sequence token as the padding token; padding is needed by the
# fixed-length tokenization and by generate() further down.
# NOTE(review): presumably the tokenizer ships without a pad token of its own — confirm.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
import html
import re

# Maximum number of tokens kept from each source text in make_template; also used
# as the max_new_tokens budget at generation time.
N = 270
# Length (in tokens) to which every full prompt is padded in mistral_tokenize.
M = 360

def make_template(text):
    """Build the single-turn chat message that asks for a Romanian-to-English translation.

    The input is cleaned with ``prep_text``, cut to at most ``N`` tokens by an
    encode/decode round trip through the tokenizer, and embedded at the end of
    the instruction prompt.

    Args:
        text: Raw Romanian text.

    Returns:
        A one-element list holding a ``{'role': 'user', 'content': ...}`` dict,
        which is what ``mistral_tokenize`` passes to ``tokenizer.apply_chat_template``.
    """
    # Truncate in token space; special tokens added by encode() are dropped again
    # when decoding back to text.
    token_ids = tokenizer.encode(prep_text(text), padding=False, truncation=True, max_length=N)
    text = tokenizer.decode(token_ids, skip_special_tokens=True)

    prompt = f"""You are a helpful professional translator. You will be prompted with texts to translate. You will respond only with the translation.
You will receive prompts with the format: "Translate from Romanian to English: [Romanian text]".
You will respond with: "Translation: [English text].
Translate from Romanian to English: {text}"""
    return [{'role': 'user', 'content': prompt}]

def prep_text(x):
    """Normalize a raw text before it is placed in the prompt.

    HTML entities are unescaped, every run of whitespace (including newlines)
    is collapsed to a single space, and the result is trimmed and capped at
    30719 characters. If nothing is left, a fixed Romanian placeholder sentence
    is returned instead so the prompt never ends with an empty text.

    Args:
        x: Raw text, possibly containing HTML entities and irregular whitespace.

    Returns:
        The cleaned text, or the placeholder sentence when the cleaned text is empty.
    """
    # re.sub takes (pattern, replacement, string): the text being cleaned must be
    # the third argument. With the text in the replacement slot no whitespace is
    # collapsed and backslashes in the text get rewritten.
    x = re.sub(r'\s+', ' ', html.unescape(x)).strip()[:30719].strip()
    return "un produs interesant, nici bun, nici rau" if x == "" else x

def mistral_tokenize(examples):
    """Tokenize a batch of Romanian texts as fixed-length chat prompts.

    Args:
        examples: Batch mapping with a ``'text_ro'`` entry holding the texts.

    Returns:
        The same ``examples`` mapping, with ``'input_ids'`` and
        ``'attention_mask'`` entries added.
    """
    # Render every message list to prompt text and drop the first '<s>'.
    # NOTE(review): presumably removed because the tokenizer call below adds its
    # own beginning-of-sequence token — confirm.
    prompts = [
        tokenizer.apply_chat_template(make_template(text_ro), tokenize=False).replace('<s>', '', 1)
        for text_ro in examples['text_ro']
    ]

    # Pad every prompt to M tokens; truncation is off, so longer prompts are not cut.
    encoded = tokenizer(prompts, padding='max_length', truncation=False, max_length=M)

    examples['input_ids'] = encoded['input_ids']
    examples['attention_mask'] = encoded['attention_mask']
    return examples

# Tokenize the whole dataset in batches, drop the columns that are no longer
# needed, and return values as torch tensors.
ro_sent_tokenized = (
    ro_sent
    .map(mistral_tokenize, batched=True)
    .remove_columns(['text_ro', 'target'])
    .with_format('torch')
)

In [None]:
# Sanity check: the last five token ids of each of the first ten examples, plus
# the shape of the full input_ids column.
[x[-5:] for x in ro_sent_tokenized[:10]['input_ids']], ro_sent_tokenized['input_ids'].shape

In [None]:
# Display the tokenized dataset object.
ro_sent_tokenized

In [None]:
import gc
# Uncomment the next line to drop the model itself before collecting.
# del model
# Run Python's garbage collector, then ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()


In [None]:
# !mkdir -p mistral_ro_sent

In [None]:
# Peek at the first batch only (batch size 86, no shuffling): print its ids and
# the decoded text of its first prompt, then stop.
for x in torch.utils.data.DataLoader(ro_sent_tokenized, batch_size = 86, shuffle=False):
    print(x['id'])
    print(tokenizer.decode(x['input_ids'][0]))
    break

In [None]:
%%time

import tqdm
import json

loader = torch.utils.data.DataLoader(
    ro_sent_tokenized.remove_columns(['id']), batch_size = 1,
    shuffle=False,
    pin_memory=True,
    num_workers=4
)

SEP_TOKEN = '[/INST]'
ANS_PREFIX = 'Translation:'

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm.tqdm(loader)):
        inputs = {k: v.cuda() for k, v in batch.items()}

        generated_ids = model.generate(
            **inputs, max_new_tokens=int(N+0.2), pad_token_id=tokenizer.pad_token_id, do_sample=True,
            temperature=0.7, top_p=1 # settings from https://chat.lmsys.org/
        )
        decoded = tokenizer.batch_decode(generated_ids)

        for i, v in enumerate(decoded):
            start = v.index('<s>')
            stop = v.index('</s>', start)
            v = v[start:stop].replace('<s>', '').replace('</s>', '')

            separator = v.index(SEP_TOKEN)
            prompt = v[:separator].replace('[INST]', '').strip()
            answer = v[separator+len(SEP_TOKEN):].strip()
            if answer.startswith(ANS_PREFIX):
                answer = answer[len(ANS_PREFIX):]

            try:
                answer = answer[:answer.index(ANS_PREFIX)].strip()
            except:
                pass

            answer = answer.strip()

            # print("Prompt:", prompt)
            # print("###")
            # print("Answer:", answer)
            # print("====================")

            decoded[i] = answer

        with open(f'mistral_ro_sent/batch_{batch_idx}', 'w') as f:
            json.dump(decoded, f)

        del inputs
        if (batch_idx)%50==0:
            gc.collect()
            torch.cuda.empty_cache()
        break

In [None]:
# Drop a leftover `inputs` batch, e.g. after interrupting the inference loop.
# The loop deletes `inputs` itself on every completed pass, so the name may
# already be gone; tolerate that instead of raising NameError.
try:
    del inputs
except NameError:
    pass

In [None]:
# Run infer_mistral.py as a separate process.
# NOTE(review): presumably the full-dataset counterpart of the inference loop above — confirm.
!python infer_mistral.py

In [None]:
# Show the current GPU status as reported by nvidia-smi.
!nvidia-smi