## Importing Libraries

In [None]:
import os
from dotenv import load_dotenv
import locale
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from typing import Optional, List, Tuple

# PyTorch
import torch

# Hugging Face
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer
from sentence_transformers import SentenceTransformer
from datasets import Dataset

# Langchain
from langchain.document_loaders import GitHubIssuesLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceHubEmbeddings
#from langchain_community.document_loaders.csv_loader import CSVLoader

In [None]:
# Override the preferred-encoding lookup so it always reports UTF-8.
# The replacement mirrors the stdlib signature
# `locale.getpreferredencoding(do_setlocale=True)`; a zero-argument lambda
# would raise TypeError for any caller that passes the argument.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"
# Do not truncate column contents when pandas displays a DataFrame
pd.set_option("display.max_colwidth", None)

## Login to GitHub

In [None]:
# Read GITHUB_TOKEN from the environment (a local .env file is loaded first)
load_dotenv()
access_token = os.environ.get("GITHUB_TOKEN")

# Issue loader for the huggingface/transformers repository:
# `include_prs=False` and `state="all"` are passed through to the loader.
loader = GitHubIssuesLoader(
    access_token=access_token,
    repo="huggingface/transformers",
    state="all",
    include_prs=False,
)

## Device

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = "cuda:0"  # Nvidia GPU
elif torch.backends.mps.is_available():
    device = "mps"  # Apple Silicon GPU
else:
    device = "cpu"
print(f"Device = {device}")

In [None]:
# Attention implementation and dtype selection.
# Start from the values used on every non-CUDA device ...
attn_implementation = "eager"
torch_dtype = torch.float32
# ... and override them when running on the CUDA device.
if device == "cuda:0":
    # Compute capability major >= 8: Ampere, Ada, or Hopper GPUs
    supports_flash = torch.cuda.get_device_capability()[0] >= 8
    attn_implementation = "flash_attention_2" if supports_flash else "eager"
    torch_dtype = torch.bfloat16 if supports_flash else torch.float16
print(f"Attention Implementation = {attn_implementation}")

## Hyperparameters

In [None]:
################################################################################
# Tokenizer parameters (passed to `tokenizer.encode`)
################################################################################
max_length = 8192
padding = "do_not_pad"  # one of: "max_length", "longest", "do_not_pad"
truncation = True

################################################################################
# Generation parameters (passed to `model.generate`)
################################################################################
num_return_sequences = 1
max_new_tokens = 1024
do_sample = True  # True for sampling, False for greedy decoding
temperature = 0.6
top_p = 0.9
repetition_penalty = 1.1

################################################################################
# bitsandbytes parameters (passed to `BitsAndBytesConfig`)
################################################################################
load_in_4bit = True
bnb_4bit_compute_dtype = torch_dtype  # reuse the dtype chosen for the model
bnb_4bit_quant_type = "nf4"  # one of: "nf4", "fp4"
bnb_4bit_use_double_quant = True

################################################################################
# Retriever parameters
################################################################################
top_k = 5  # used as the retriever's "k" search argument
chunk_size = 100  # The maximum number of characters in a chunk
chunk_overlap = 20  # The number of characters to overlap between chunks
add_start_index = True  # If `True`, includes chunk's start index in metadata
strip_whitespace = True  # If `True`, strips whitespace from the start and end of every document
# Candidate split points for the text splitter, in list order
MARKDOWN_SEPARATORS = [
    "\n#{1,6} ",
    "```\n",
    "\n\\*\\*\\*+\n",
    "\n---+\n",
    "\n___+\n",
    "\n\n",
    "\n",
    " ",
    "",
]

## Model

In [None]:
# Model ID: identifier passed to `from_pretrained` for both the tokenizer and the model
model_id = "PathFinderKR/Waktaverse-Llama-3-KO-8B-Instruct"

In [None]:
# Load the tokenizer that belongs to `model_id`
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Streamer built on that tokenizer; it is handed to `model.generate` in `generate_response`
streamer = TextStreamer(tokenizer)

In [None]:
# Quantization: bitsandbytes configuration assembled from the
# `load_in_4bit` / `bnb_4bit_*` hyperparameters; passed to the model loader.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=load_in_4bit,
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_use_double_quant=bnb_4bit_use_double_quant
)

In [None]:
# Load model: causal LM for `model_id`, placed via `device_map=device`, using the
# attention implementation, dtype and quantization config selected earlier.
# NOTE(review): the bitsandbytes config is passed regardless of `device`;
# confirm the installed bitsandbytes supports the non-CUDA devices selected above.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    attn_implementation=attn_implementation,
    torch_dtype=torch_dtype,
    quantization_config=quantization_config
)

## Documents

In [None]:
# Load documents (the issues returned by the GitHub loader)
documents = loader.load()
# Split documents into chunks.
# - The splitter's keyword for the separator list is `separators`.
# - MARKDOWN_SEPARATORS contains regular expressions (e.g. "\n#{1,6} "),
#   so they must be interpreted as regexes rather than escaped literals.
splitter = RecursiveCharacterTextSplitter(
    separators=MARKDOWN_SEPARATORS,
    is_separator_regex=True,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    add_start_index=add_start_index,
    strip_whitespace=strip_whitespace,
)
# `split_documents` takes Document objects and returns the chunked Documents
documents = splitter.split_documents(documents)

## Retriever

In [None]:
# Embedding model ID: identifier of the model used to embed document chunks for retrieval
embedding_model_id = "sentence-transformers/all-mpnet-base-v2"

In [None]:
# Embeddings: Hugging Face Hub embedding client for `embedding_model_id`.
# The class identifies the model through `repo_id`; it has no `model_name`
# field and rejects unknown keyword arguments.
embeddings = HuggingFaceHubEmbeddings(
    repo_id=embedding_model_id
)

In [None]:
# Vector database: FAISS index built from the document chunks.
# `from_documents` names its embedding-model parameter `embedding` (singular).
vector_database = FAISS.from_documents(
    documents,
    embedding=embeddings
)

In [None]:
# Retriever: similarity search over the vector database, requesting `top_k`
# results per query. Vector stores expose this through `as_retriever`.
retriever = vector_database.as_retriever(
    search_type="similarity",
    search_kwargs={"k": top_k}
)

## Inference

In [None]:
def prompt_template(context, question):
    """Assemble the single-turn chat prompt string.

    The prompt consists of three sections delimited by header and
    ``<|eot_id|>`` markers: a system message (Korean-only instruction
    followed by ``context``), a user message (``question``), and an opened
    assistant header with no content.

    Args:
        context: Text inserted at the end of the system message.
        question: Text used as the user message.

    Returns:
        The concatenated prompt string.
    """
    segments = [
        # System turn
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n",
        "You are Korean. Use Korean only. 한국어만 사용하세요.\n",
        "Answer the question based on your knowledge. Use the following context to help:\n",
        f"{context}<|eot_id|>",
        # User turn
        "<|start_header_id|>user<|end_header_id|>\n\n",
        f"{question}<|eot_id|>",
        # Assistant turn (left open)
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
    ]
    return "".join(segments)

In [None]:
def generate_response(context, question):
    """Generate a model response for ``question`` given retrieved ``context``.

    Builds the prompt with ``prompt_template``, tokenizes it using the
    module-level tokenizer settings, runs ``model.generate`` with the
    module-level generation settings (output is also streamed through
    ``streamer``), and decodes the first returned sequence.

    Args:
        context: Text placed in the system message of the prompt.
        question: Text placed in the user message of the prompt.

    Returns:
        The decoded text of ``outputs[0]`` with special tokens kept
        (``skip_special_tokens=False``).
    """
    prompt = prompt_template(context, question)

    input_ids = tokenizer.encode(
        prompt,
        max_length=max_length,
        padding=padding,
        truncation=truncation,
        # The template already starts with `<|begin_of_text|>`; letting the
        # tokenizer add special tokens would prepend a second BOS token.
        add_special_tokens=False,
        return_tensors="pt"
    ).to(device)

    outputs = model.generate(
        input_ids=input_ids,
        # The EOS token id is reused as the padding id during generation
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=num_return_sequences,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        streamer=streamer
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=False)