In [1]:
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from transformers import AutoTokenizer
from gatherer_sage.rag import RAG
from sklearn.model_selection import train_test_split

# Hugging Face checkpoint id used for the reader model.
READER_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"

# System turn: fixed answering instructions for the reader model.
_SYSTEM_INSTRUCTIONS = """Using the information contained in the context,
give a comprehensive and concise answer to the question.
Respond only to the question asked, response should be concise and relevant to the question.
Provide the number of the rule when relevant.
If the answer cannot be deduced from the context, do not give an answer.
The questions are related with Magic The Gathering card game."""

# User turn: carries the `{context}` and `{question}` placeholders that are
# later filled with `str.format`.
_USER_TURN = """Context:
{context}
---
Now here is the question you need to answer.

Question: {question}"""

# Chat messages (system turn first, then user turn) in role/content form.
PROMPT_TEMPLATE = [
    dict(role="system", content=_SYSTEM_INSTRUCTIONS),
    dict(role="user", content=_USER_TURN),
]


class RedditDataset(Dataset):
    """Map-style dataset that yields one retrieval-augmented prompt per row.

    Each item is built from the row's ``"question"`` column: the question is
    passed to ``rag.retrieve_context`` and both the question and the returned
    context are substituted into the rendered chat prompt.

    Args:
        reddit_data: DataFrame with a ``"question"`` column; rows are
            addressed positionally via ``iloc``.
        rag: Object exposing ``retrieve_context(question)``.
        prompt_template: Chat messages (role/content dicts) whose content
            may contain ``{context}`` and ``{question}`` placeholders.
        model_path: Checkpoint whose tokenizer (and chat template) is used to
            render ``prompt_template``.

    Attributes:
        prompt_template: The *rendered* prompt string (not the message list),
            still containing the ``{context}`` / ``{question}`` placeholders.
    """

    def __init__(
        self,
        reddit_data: pd.DataFrame,
        rag,
        prompt_template: list[dict[str, str]] = PROMPT_TEMPLATE,
        model_path: str = "meta-llama/Meta-Llama-3-8B-Instruct",
    ):
        self.rag = rag
        self.data = reddit_data
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Render the chat messages to a single string once, up front, so that
        # __getitem__ only has to fill in the placeholders. (The raw message
        # list was previously stored in this attribute and then immediately
        # overwritten; that dead store is gone.)
        self.prompt_template = self.tokenizer.apply_chat_template(
            prompt_template, tokenize=False, add_generation_prompt=True
        )

    def __len__(self) -> int:
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.data)

    def __getitem__(self, idx) -> str:
        """Return the fully formatted prompt string for row ``idx``."""
        row = self.data.iloc[idx]
        question = row["question"]
        # Context is retrieved on every access; nothing is cached here.
        context = self.rag.retrieve_context(question)
        # NOTE(review): str.format assumes the rendered chat template holds no
        # literal braces other than the two placeholders — confirm for any
        # tokenizer other than the default one.
        complete_prompt = self.prompt_template.format(
            question=question, context=context
        )
        return complete_prompt


# Fixed seed so the train/test partition is the same on every
# Restart Kernel -> Run All (the split was previously unseeded).
SPLIT_SEED = 42

print("Loading RAG model")
rag = RAG()

print("Loading Reddit dataset")
reddit_df = pd.read_csv("data/reddit/reddit_qa_dataset.csv")
# Hold out 20% of the rows for evaluation.
train, test = train_test_split(reddit_df, test_size=0.2, random_state=SPLIT_SEED)

print("Creating datasets")
# Both datasets share the same RAG instance.
train_dataset = RedditDataset(train, rag)
test_dataset = RedditDataset(test, rag)

  from .autonotebook import tqdm as notebook_tqdm


Loading RAG model




Loading Reddit dataset
Creating datasets


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
from transformers import pipeline
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Quantization settings: 4-bit loading with the "nf4" quant type, double
# quantization enabled, and bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# Load the reader checkpoint with the quantization config above, then the
# tokenizer for the same checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    READER_MODEL_NAME,
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)

# Generation options forwarded unchanged to the text-generation pipeline.
generation_options = dict(
    do_sample=True,
    temperature=0.2,
    repetition_penalty=1.1,
    return_full_text=False,
    max_new_tokens=500,
)

READER_LLM = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    **generation_options,
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.36it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
# LoRA fine-tuning of the reader model with an SFTTrainer, followed by saving
# the resulting model.
#
# NOTE(review): this cell cannot run on a fresh kernel as written, judging by
# the cells visible here:
#   * LoraConfig, TrainingArguments and SFTTrainer are not imported anywhere
#     above.
#   * base_model, training_data, llama_tokenizer and refined_model are never
#     assigned above. Presumably they correspond to `model`, `train_dataset`,
#     `tokenizer` and an output path — TODO confirm and define them.
#   * dataset_text_field="text" names a "text" field, whereas
#     RedditDataset.__getitem__ returns a plain prompt string — verify the
#     dataset passed in actually exposes that field.

# LoRA Config
# Adapter settings: r=8, lora_alpha=16, dropout 0.1, no bias terms,
# causal-LM task type.
peft_parameters = LoraConfig(
    lora_alpha=16, lora_dropout=0.1, r=8, bias="none", task_type="CAUSAL_LM"
)

# Training Params
# One epoch (max_steps=-1), batch size 4 per device with no gradient
# accumulation; checkpoints and logs every 25 steps under ./results_modified.
# NOTE(review): fp16 and bf16 are both False here while the quantization config
# in the previous cell computes in bfloat16 — confirm this is intended.
train_params = TrainingArguments(
    output_dir="./results_modified",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
)

# Trainer
# Binds the model, training data, LoRA config, tokenizer and training
# arguments together.
fine_tuning = SFTTrainer(
    model=base_model,
    train_dataset=training_data,
    peft_config=peft_parameters,
    dataset_text_field="text",
    tokenizer=llama_tokenizer,
    args=train_params,
)

# Training
fine_tuning.train()

# Save Model
# Writes the trained model to the location held in `refined_model`.
fine_tuning.model.save_pretrained(refined_model)