## Importing Libraries

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from typing import Optional, List, Tuple

# PyTorch
import torch

# Hugging Face
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer
from sentence_transformers import SentenceTransformer
from datasets import Dataset

# Langchain
from langchain_huggingface import ChatHuggingFace
from langchain_community.document_loaders import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [None]:
# Never truncate column contents when DataFrames are displayed
pd.options.display.max_colwidth = None

## Device

In [None]:
# Choose the compute device, in order of preference: CUDA, then MPS, then CPU
if torch.cuda.is_available():
    device = "cuda:0"  # Nvidia GPU
elif torch.backends.mps.is_available():
    device = "mps"  # Apple Silicon GPU
else:
    device = "cpu"
print(f"Device = {device}")

In [None]:
# Attention implementation and matching model dtype
#
# "flash_attention_2" is only selected when the GPU is Ampere or newer AND the
# optional `flash_attn` package is importable; otherwise loading the model with
# that setting would fail, so we fall back to "eager".
import importlib.util

if device == "cuda:0":
    if torch.cuda.get_device_capability()[0] >= 8:  # Ampere, Ada, or Hopper GPUs
        torch_dtype = torch.bfloat16
        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
        attn_implementation = "flash_attention_2" if flash_attn_installed else "eager"
    else:
        # Older CUDA GPUs: no FlashAttention-2, use half precision
        attn_implementation = "eager"
        torch_dtype = torch.float16
else:
    # MPS / CPU: full precision with the default attention path
    attn_implementation = "eager"
    torch_dtype = torch.float32
print(f"Attention Implementation = {attn_implementation}")

## Hyperparameters

In [None]:
# ------------------------------------------------------------------------------
# Tokenizer settings (passed to the tokenizer when encoding the prompt)
# ------------------------------------------------------------------------------
max_length = 8192
padding = "do_not_pad"  # one of: "max_length", "longest", "do_not_pad"
truncation = True

# ------------------------------------------------------------------------------
# Generation settings (passed to `model.generate`)
# ------------------------------------------------------------------------------
num_return_sequences = 1
max_new_tokens = 1024
do_sample = True  # True -> sampling, False -> greedy decoding
temperature = 0.6
top_p = 0.9
repetition_penalty = 1.1

# ------------------------------------------------------------------------------
# bitsandbytes quantization settings
# ------------------------------------------------------------------------------
load_in_4bit = True
bnb_4bit_compute_dtype = torch_dtype
bnb_4bit_quant_type = "nf4"  # one of: "nf4", "fp4"
bnb_4bit_use_double_quant = True

# ------------------------------------------------------------------------------
# Retriever / chunking settings
# ------------------------------------------------------------------------------
top_k = 5
chunk_size = 100  # maximum number of characters per chunk
chunk_overlap = 20  # number of characters shared between neighbouring chunks
add_start_index = True  # when True, each chunk's start index is stored in its metadata
strip_whitespace = True  # when True, leading/trailing whitespace is stripped from every document
# Separators for splitting, listed from coarsest to finest
MARKDOWN_SEPARATORS = [
    "\n#{1,6} ",
    "```\n",
    "\n\\*\\*\\*+\n",
    "\n---+\n",
    "\n___+\n",
    "\n\n",
    "\n",
    " ",
    "",
]

## Model

In [None]:
# Hugging Face Hub repository used for both the tokenizer and the model
model_id = "PathFinderKR/Waktaverse-Llama-3-KO-8B-Instruct"

In [None]:
# Load the tokenizer published under `model_id`
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Streamer built on the same tokenizer; it is handed to `model.generate`
# through the `streamer` argument in `generate_response`
streamer = TextStreamer(tokenizer)

In [None]:
# bitsandbytes quantization config, assembled from the hyperparameter cell
quantization_config = BitsAndBytesConfig(
    load_in_4bit=load_in_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
)

In [None]:
# Load the causal LM for `model_id` onto `device`, using the quantization
# config, dtype and attention implementation chosen in the cells above
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    attn_implementation=attn_implementation,
    device_map=device,
)

## Documents

In [None]:
# Load documents
# CSVLoader receives the CSV path at construction time; `load()` itself takes
# no arguments and returns the loaded documents.
loader = CSVLoader(file_path="data/knowledge_base.csv")
knowledge_base = loader.load()

In [None]:
# Splitter that cuts documents into overlapping character chunks.
# The separator list must be passed as `separators`; the splitter has no
# `markdown_separators` parameter.
# NOTE(review): several MARKDOWN_SEPARATORS entries look like regex patterns
# (e.g. "\n#{1,6} "). They are matched literally unless `is_separator_regex=True`
# is also passed — confirm which behavior is intended.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    add_start_index=add_start_index,
    strip_whitespace=strip_whitespace,
    separators=MARKDOWN_SEPARATORS,
)

In [None]:
# Split every loaded document into chunks. `split_documents` accepts the whole
# list at once, so no per-document loop is needed. The documents come from
# `knowledge_base` (the name assigned by the loader cell).
docs_processed = text_splitter.split_documents(knowledge_base)

## Retriever

## Inference

In [None]:
def prompt_template(system, user):
    """Render a single-turn chat prompt from a system and a user message.

    The result is the system turn, then the user turn, then an opened
    assistant header so that generation continues as the assistant reply.

    Args:
        system: Text placed in the system turn.
        user: Text placed in the user turn.

    Returns:
        The full prompt string, starting with ``<|begin_of_text|>``.
    """
    system_turn = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        + f"{system}<|eot_id|>"
    )
    user_turn = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        + f"{user}<|eot_id|>"
    )
    # Left open on purpose: the model writes the assistant turn
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return "".join((system_turn, user_turn, assistant_header))

In [None]:
def generate_response(system, user):
    """Generate a model response for one system/user message pair.

    The prompt is built with ``prompt_template``, tokenized with the
    module-level ``tokenizer`` and tokenizer settings, and passed to
    ``model.generate`` together with the module-level generation settings and
    ``streamer``.

    Args:
        system: Text for the system turn.
        user: Text for the user turn.

    Returns:
        The first sequence returned by ``model.generate``, decoded with special
        tokens kept.
    """
    prompt = prompt_template(system, user)

    # The template already begins with <|begin_of_text|>, so the tokenizer must
    # not add its own special tokens; otherwise the prompt gets a duplicate BOS.
    encoded = tokenizer(
        prompt,
        max_length=max_length,
        padding=padding,
        truncation=truncation,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(device)

    outputs = model.generate(
        input_ids=encoded["input_ids"],
        # pad_token_id is set to the EOS id below, so the mask cannot be
        # inferred from padding tokens; pass the tokenizer's mask explicitly.
        attention_mask=encoded["attention_mask"],
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=num_return_sequences,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=False)

In [None]:
# NOTE(review): the original assignment had no right-hand side, which is a
# syntax error. The text below is only a placeholder so the cell runs —
# replace it with the intended system prompt.
system_prompt = (
    "You are a helpful assistant. "
    "Answer the user's question using only the provided context."
)

In [None]:
# User turn: instruction, then the retrieved context ("문맥") and the
# question ("질문"), one per line.
# NOTE(review): `context` and `question` must already be defined when this cell
# runs; their definitions are not visible in this part of the notebook.
user_prompt = (
    # The instruction needs its own line; without the trailing "\n" it ran
    # straight into the "문맥:" label.
    "Answer the following question based only on the provided context.\n"
    f"문맥: {context}\n"
    f"질문: {question}"
)