In [None]:
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from transformers import AutoTokenizer
from gatherer_sage.rag import RAG

# System instructions for the reader model, kept one sentence per entry so
# individual rules are easy to edit; joined with newlines below.
_SYSTEM_PROMPT = "\n".join(
    (
        "Using the information contained in the context,",
        "give a comprehensive and concise answer to the question.",
        "Respond only to the question asked, response should be concise and relevant to the question.",
        "Provide the number of the rule when relevant.",
        "If the answer cannot be deduced from the context, do not give an answer.",
        "The questions are related with Magic The Gathering card game.",
    )
)

# User turn. `{context}` and `{question}` are `str.format` placeholders that
# are filled in per example after the chat template has been rendered.
_USER_PROMPT = "\n".join(
    (
        "Context:",
        "{context}",
        "---",
        "Now here is the question you need to answer.",
        "",
        "Question: {question}",
    )
)

# Chat messages in the role/content layout passed to
# `tokenizer.apply_chat_template`.
PROMPT_TEMPLATE = [
    {"role": "system", "content": _SYSTEM_PROMPT},
    {"role": "user", "content": _USER_PROMPT},
]

# Hugging Face identifier of the model used to generate the answers.
READER_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"


class RedditDataset(Dataset):
    """Map-style dataset yielding one ready-to-generate prompt per question.

    Each item is a string: the chat-formatted prompt built from a row's
    ``question`` and the context that ``rag`` retrieves for that question.
    """

    def __init__(
        self,
        reddit_data: pd.DataFrame,
        rag,
        prompt_template=PROMPT_TEMPLATE,
        model_path: str = "meta-llama/Meta-Llama-3-8B-Instruct",
    ) -> None:
        """Load the tokenizer and pre-render the chat template.

        Args:
            reddit_data: Frame with a ``question`` column; one item per row.
            rag: Object exposing ``retrieve_context(question)``; its result is
                substituted for ``{context}`` in the prompt.
            prompt_template: Chat messages (``role``/``content`` dicts) whose
                contents carry ``{context}`` and ``{question}`` placeholders.
            model_path: Name or path given to ``AutoTokenizer.from_pretrained``;
                that tokenizer's chat template is used to render the prompt.
        """
        self.rag = rag
        self.data = reddit_data
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Render the chat template once, up front. The placeholders survive
        # rendering as plain text and are filled per item in __getitem__.
        # `prompt_template` is stored only in this rendered (str) form; the
        # raw message list is not kept on the instance.
        self.prompt_template = self.tokenizer.apply_chat_template(
            prompt_template, tokenize=False, add_generation_prompt=True
        )

    def __len__(self) -> int:
        """Return the number of rows (questions) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> str:
        """Return the complete prompt for the row at position ``idx``.

        Retrieval runs on every access: the question is sent to
        ``rag.retrieve_context`` each time the item is requested.
        """
        question = self.data.iloc[idx]["question"]
        context = self.rag.retrieve_context(question)
        # NOTE: `str.format` treats every `{`/`}` in the rendered template as
        # a field delimiter, so the template text itself must contain no
        # literal braces other than the two placeholders.
        return self.prompt_template.format(question=question, context=context)


# Build the retriever with RAG's default arguments.
# NOTE(review): presumably this is the object passed as `rag` to
# RedditDataset (which calls `rag.retrieve_context`) — confirm in later cells.
rag = RAG()

In [None]:
from transformers import pipeline
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Quantization settings: 4-bit NF4 weights with double quantization,
# computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# Load the reader model (quantized as configured above) and then its tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    READER_MODEL_NAME,
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)

# Generation settings forwarded to the pipeline: low-temperature sampling,
# a mild repetition penalty, at most 500 new tokens, and
# `return_full_text=False` so the prompt is not included in the output.
_generation_settings = {
    "do_sample": True,
    "temperature": 0.2,
    "repetition_penalty": 1.1,
    "return_full_text": False,
    "max_new_tokens": 500,
}

# Text-generation pipeline used to answer the prompts.
READER_LLM = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    **_generation_settings,
)