## Importing Libraries

In [1]:
# Standard library
import importlib.util
import locale
import os

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pacmap
import pandas as pd
import plotly.express as px
from dotenv import load_dotenv
from tqdm.notebook import tqdm

# PyTorch
import torch

# Hugging Face
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer
from datasets import load_dataset

# cohere
import cohere

In [2]:
# Set locale to UTF-8
# Override locale.getpreferredencoding so it always reports UTF-8. The
# replacement accepts the same optional `do_setlocale` argument as the stdlib
# function, so callers that invoke `locale.getpreferredencoding(False)` do not
# raise TypeError.
locale.getpreferredencoding = lambda do_setlocale=True: 'UTF-8'
# Set pandas display options: do not truncate column contents when rendering
pd.set_option("display.max_colwidth", None)
# Set MKL_THREADING_LAYER to GNU
os.environ['MKL_THREADING_LAYER']='GNU'

## Login to Cohere

In [None]:
# Load variables from a local .env file (if any) into the process environment,
# then read the Cohere key from COHERE_API_KEY.
load_dotenv()
api_key = os.getenv("COHERE_API_KEY")
# Fail fast with an actionable message when no key is configured.
# NOTE(review): the Cohere SDK also reads CO_API_KEY when no key is passed
# explicitly, so that variable is accepted as a fallback here — confirm against
# the installed SDK version.
if not api_key and not os.getenv("CO_API_KEY"):
    raise RuntimeError(
        "Cohere API key not found: set COHERE_API_KEY in the environment or in a .env file."
    )
# An empty string is passed as None so the SDK can apply its own fallback.
co = cohere.Client(api_key or None)

## Device

In [3]:
# Device setup
# Preference order: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda:0"  # Nvidia GPU
elif torch.backends.mps.is_available():
    device = "mps"  # Apple Silicon GPU
else:
    device = "cpu"
print(f"Device = {device}")

Device = cuda:0


In [4]:
# Flash Attention Implementation
# Selects the attention backend and the working dtype from the device:
#   * CUDA with compute capability >= 8 -> bfloat16, and "flash_attention_2"
#     only when the optional `flash_attn` package is importable; otherwise
#     "eager", so model loading does not fail on machines without the package.
#   * older CUDA GPUs                   -> "eager" with float16
#   * anything else (mps / cpu)         -> "eager" with float32
flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
if device == "cuda:0":
    if torch.cuda.get_device_capability()[0] >= 8: # Ampere, Ada, or Hopper GPUs
        attn_implementation = "flash_attention_2" if flash_attn_installed else "eager"
        torch_dtype = torch.bfloat16
    else:
        attn_implementation = "eager"
        torch_dtype = torch.float16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float32
print(f"Attention Implementation = {attn_implementation}")

Attention Implementation = flash_attention_2


## Hyperparameters

In [5]:
# ------------------------------------------------------------------------------
# Tokenizer parameters (forwarded to the tokenizer when encoding the prompt)
# ------------------------------------------------------------------------------
max_length = 8192
padding = "do_not_pad"  # one of "max_length", "longest", "do_not_pad"
truncation = True

# ------------------------------------------------------------------------------
# Generation parameters (forwarded to model.generate)
# ------------------------------------------------------------------------------
num_return_sequences = 1
max_new_tokens = 1024
do_sample = True  # True for sampling, False for greedy decoding
temperature = 0.6
top_p = 0.9
repetition_penalty = 1.1

# ------------------------------------------------------------------------------
# bitsandbytes parameters (forwarded to BitsAndBytesConfig)
# ------------------------------------------------------------------------------
load_in_4bit = True
bnb_4bit_compute_dtype = torch_dtype  # same dtype chosen in the attention cell
bnb_4bit_quant_type = "nf4"  # "nf4" or "fp4"
bnb_4bit_use_double_quant = True

# ------------------------------------------------------------------------------
# Retriever parameters
# ------------------------------------------------------------------------------
k = 5  # number of top-scoring documents selected with torch.topk

## Model

In [None]:
# Model ID
# Hugging Face Hub identifier; the same value is passed to
# AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Streamer built on the same tokenizer; it is handed to model.generate()
# through the `streamer` argument inside generate_response.
streamer = TextStreamer(tokenizer)

In [None]:
# Quantization
# Gather the bitsandbytes settings from the hyperparameter cell into one
# mapping, then build the config object from it.
bnb_options = {
    "load_in_4bit": load_in_4bit,
    "bnb_4bit_compute_dtype": bnb_4bit_compute_dtype,
    "bnb_4bit_quant_type": bnb_4bit_quant_type,
    "bnb_4bit_use_double_quant": bnb_4bit_use_double_quant,
}
quantization_config = BitsAndBytesConfig(**bnb_options)

In [None]:
# Load model
# Keyword options are collected first so the load call itself stays short.
model_load_options = dict(
    device_map=device,
    attn_implementation=attn_implementation,
    torch_dtype=torch_dtype,
    quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_options)

In [6]:
# Embedding model ID
# This value is sent as `model=` to co.embed(), so it must be a Cohere API
# model name, and it must be the model that produced the corpus vectors;
# otherwise the query vector and the document vectors are not comparable.
# NOTE(review): per its dataset card, "Cohere/wikipedia-22-12-ko-embeddings"
# was embedded with Cohere's multilingual-22-12 model, exposed by the API as
# "embed-multilingual-v2.0". The previous value was a Hugging Face repository
# name for a different model — confirm the name against the current Cohere
# model list.
embedding_model_id = "embed-multilingual-v2.0"

In [None]:
# Reranking model ID
# NOTE(review): not referenced by any cell visible in this part of the
# notebook — presumably meant for a Cohere rerank step; confirm or remove.
reranking_model_id = "rerank-multilingual-v3.0"

## Documents

In [9]:
# Document ID
# Dataset identifier handed to datasets.load_dataset when the corpus is loaded.
document_id = "Cohere/wikipedia-22-12-ko-embeddings"

In [10]:
# Load documents
# `documents` stays a Hugging Face Dataset. Its rows are dicts of columns, so
# the whole dataset cannot be passed to torch.tensor(); keeping the Dataset also
# preserves the per-row fields (text, metadata) next to the embedding column.
documents = load_dataset(document_id, split="train")

  0%|          | 0/647897 [00:00<?, ?it/s]

## RAG

In [None]:
def prompt_template(context, question):
    """Assemble the chat prompt for one retrieval-augmented question.

    Args:
        context: Retrieved material; interpolated as-is below "###Context".
        question: The user's question; interpolated after "###Question: ".

    Returns:
        One string made of the system turn, the user turn and an opened
        assistant header, in that order.
    """
    system_turn = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are Korean. Use Korean only. 한국어만 사용하세요.\n"
        "Using the information contained in the context, give a comprehensive answer to the question. Respond only to the question asked. \n"
        "<|eot_id|>"
    )
    user_turn = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "###Context\n"
        f"{context}\n"
        f"###Question: {question}<|eot_id|>"
    )
    # Left open so the model continues with the assistant's reply
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return "".join((system_turn, user_turn, assistant_header))

In [None]:
def _document_embedding_matrix():
    """Return the corpus embeddings as one float32 tensor, building it once.

    The matrix is read from the "emb" column of the module-level ``documents``
    dataset and cached on this function, keyed by the identity of
    ``documents``, so repeated queries do not rebuild it.

    NOTE(review): assumes ``documents`` is a Hugging Face Dataset with an
    "emb" column of equal-length float vectors — confirm against the dataset.
    """
    cached = getattr(_document_embedding_matrix, "_cache", None)
    if cached is not None and cached[0] is documents:
        return cached[1]

    matrix = documents.with_format("torch", columns=["emb"])[:]["emb"]
    if not isinstance(matrix, torch.Tensor):
        # Presumably a sequence of per-row tensors; stack them into one matrix.
        matrix = torch.stack(list(matrix))
    matrix = matrix.to(torch.float32)

    _document_embedding_matrix._cache = (documents, matrix)
    return matrix


def generate_response(query):
    """Answer ``query`` with retrieval-augmented generation.

    Steps: embed the query with Cohere, score it against every document
    embedding by dot product, keep the ``k`` best rows, put their title and
    text into the prompt, and let the language model generate.

    Relies on module-level names: ``co``, ``embedding_model_id``,
    ``documents``, ``k``, ``tokenizer``, ``model``, ``streamer``, ``device``
    and the tokenizer / generation hyperparameters.

    Args:
        query: The user's question as a single string.

    Returns:
        The generated answer only (prompt tokens are stripped), decoded with
        special tokens removed.

    Raises:
        ValueError: If the query embedding and the document embeddings have
            different dimensionality.
    """
    print("=> Retrieving documents...")
    # co.embed takes a list of texts. Without `embedding_types`, the response's
    # `.embeddings` is a plain list of float vectors that torch.tensor accepts.
    query_embedding = co.embed(
        texts=[query],
        model=embedding_model_id,
        input_type='search_query',
    ).embeddings
    query_embedding = torch.tensor(query_embedding, dtype=torch.float32)

    doc_embeddings = _document_embedding_matrix()
    if query_embedding.shape[-1] != doc_embeddings.shape[-1]:
        raise ValueError(
            f"Query embedding has {query_embedding.shape[-1]} dimensions but the documents "
            f"have {doc_embeddings.shape[-1]}; the query must be embedded with the same "
            "model as the documents."
        )

    dot_scores = torch.mm(query_embedding, doc_embeddings.transpose(0, 1))
    # Never ask for more results than there are documents
    top_k = torch.topk(dot_scores, k=min(k, doc_embeddings.shape[0]))

    retrieved_docs = [documents[doc_id] for doc_id in top_k.indices[0].tolist()]
    # Feed readable passages to the model instead of the repr of a Python list.
    # NOTE(review): assumes each row has "title" and "text" fields — confirm.
    context = "\n\n".join(f"{doc['title']}\n{doc['text']}" for doc in retrieved_docs)

    print("=> Generating response...")
    prompt = prompt_template(context, query)

    encoded = tokenizer(
        prompt,
        max_length=max_length,
        padding=padding,
        truncation=truncation,
        # The template already begins with <|begin_of_text|>; adding special
        # tokens again would prepend a second BOS token.
        add_special_tokens=False,
        return_tensors="pt"
    ).to(device)
    input_ids = encoded["input_ids"]

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=encoded["attention_mask"],
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=num_return_sequences,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        streamer=streamer
    )

    # generate() returns prompt + continuation; keep only the continuation
    generated_ids = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)

In [None]:
# Example query (Korean for "Who is the president of Korea?")
user_prompt = "한국의 대통령은 누구인가요?"

In [None]:
# Inspect what the retriever returns for `user_prompt` before generating.
print(f"Retrieval for {user_prompt}...")
print("\n==================================Top document==================================")
# co.embed takes a list of texts. Without `embedding_types`, the response's
# `.embeddings` is a plain list of float vectors that torch.tensor accepts.
query_embedding = co.embed(
    texts=[user_prompt],
    model=embedding_model_id,
    input_type='search_query',
).embeddings
query_embedding = torch.tensor(query_embedding, dtype=torch.float32)

# Build the embedding matrix from the dataset's "emb" column.
# NOTE(review): assumes `documents` is a Hugging Face Dataset with an "emb"
# column of equal-length float vectors — confirm against the dataset.
doc_embeddings = documents.with_format("torch", columns=["emb"])[:]["emb"]
if not isinstance(doc_embeddings, torch.Tensor):
    # Presumably a sequence of per-row tensors; stack them into one matrix.
    doc_embeddings = torch.stack(list(doc_embeddings))
doc_embeddings = doc_embeddings.to(torch.float32)

dot_scores = torch.mm(query_embedding, doc_embeddings.transpose(0, 1))
top_k = torch.topk(dot_scores, k=k)
del doc_embeddings  # the large matrix is not needed after scoring

docs = [documents[doc_id] for doc_id in top_k.indices[0].tolist()]
# NOTE(review): assumes rows carry a "text" field — confirm.
print(docs[0].get("text", ""))
print("====================================================================")
print("==================================Metadata==================================")
# Use a loop variable other than `k`, which holds the retriever's top-k setting.
for rank, doc in enumerate(docs, start=1):
    print(f"Document {rank}:")
    # Dataset rows are dicts; show every field except the vector and the body text
    print({field: value for field, value in doc.items() if field not in ("emb", "text")})
print("====================================================================\n")

In [None]:
# Run retrieval + generation for the example query and keep the returned text
response = generate_response(user_prompt)