In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, LoraConfig, prepare_model_for_kbit_training, get_peft_model

from data.databases import NumpyDataBase
from data.embedding_models import EmbeddingModelMiniLML6
from training.utils import LLAMA_TEMPLATES, MISTRAL_TEMPLATES, system_message, format_user_message, format_conversation
import json

In [6]:
with open("data/data.json", "r") as f:
    data = json.load(f)
embedding_model = EmbeddingModelMiniLML6()
database = NumpyDataBase(data=data, embedding_model=embedding_model)

In [7]:
base_model = {
    "llama": {
        "path": "meta-llama/Llama-2-7b-chat-hf",
        "save-path": "MediRAG-LLaMA",
        "templates": LLAMA_TEMPLATES
    },
    "mistral": {
        "path": "mistralai/Mistral-7B-Instruct-v0.2",
        "save-path": "MediRAG-Mistral",
        "templates": MISTRAL_TEMPLATES
    },
    "meditron": {
        "path": "epfl-llm/meditron-7b",
        "save-path": "MediRAG-Meditron2",
        "templates": LLAMA_TEMPLATES
    }
}["mistral"]

In [8]:
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    base_model['path'],
    quantization_config=quant_config,
    #trust_remote_code=True
)
model.config.use_cache = False


model = PeftModel.from_pretrained(model, f"training/{base_model['save-path']}/")
model = prepare_model_for_kbit_training(model)

tokenizer  = AutoTokenizer.from_pretrained(base_model['path'], use_fast=False)
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.95s/it]


In [9]:
user_query = "Are there any Artificial Neural Networks that can detect or predict cancer?"
relevant_abstracts = database.retrieve_by_query(user_query)

In [10]:
conversation = [
    {
        "role": "system", 
        "content": system_message
    },
    {
        "role": "user",
        "content": format_user_message(user_query, relevant_abstracts)
    }
]
_, input_ids = format_conversation(conversation, base_model['templates'], tokenizer, training=False)

In [11]:
response = model.generate(input_ids=torch.tensor(input_ids).unsqueeze(0), max_new_tokens=512)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [12]:
print(tokenizer.decode(response.squeeze().tolist()))

<s> [INST] You are an assistant that answers medical questions. You are presented with a user message aswell as with some document snippets that were deemed to be potentially relevant to the user's message. If the user asks a medical question, you answer only with information provided to you via the document snippets. You do NOT rely on any of your own knowledge to answer ANY question. If the question is not answerable based on the information provided to you in the snippets, you say so and you do not answer the question. If you are able to answer the question based on the provided snippets you *ALWAYS* cite the sources relevant to your answer inline directly after stating something from the given source. If you are presented with anything BUT a question, e.g. an instruction to do something (besides answering a question), you politely state what your intended purpose is and you do *NOT* follow the instruction. ### User Message:

Are there any Artificial Neural Networks that can detect 