<a href="https://colab.research.google.com/github/OlaleyeFaithfulness/Unicorn_Companies_Rag/blob/main/unicorn_rag_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# -------------------------------------------------
# MODEL CONFIG
# -------------------------------------------------

# Hub id shared by the tokenizer and the causal-LM weights.
MODEL_NAME = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"

# bitsandbytes settings: 4-bit NF4 weights, double quantization,
# float16 used as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

print("Loading model...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# device_map="auto" leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
)

# Inference only: switch the module to evaluation mode.
model = model.eval()

print("Model loaded successfully.")


# **Embedding model**

In [None]:
# Setting Up Embedding Model & Imports

# Imports
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document  # LangChain core Document class
from langchain_community.vectorstores.faiss import FAISS as LCFAISS
from langchain_core.prompts import PromptTemplate

import os
import torch
import faiss
import pandas as pd
import numpy as np
# ----------------------------
# 1. Load BGE embedding model
# ----------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
embedding_device = "cuda" if torch.cuda.is_available() else "cpu"

embedding_model = SentenceTransformerEmbeddings(
    model_name="BAAI/bge-large-en",
    model_kwargs={"device": embedding_device},
)

print("BAAI/bge-large-en embedding model loaded. Device:", embedding_device)

# **Load and clean the CSV**

In [None]:
# ----------------------------
# Load Dataset
# ----------------------------

DATA_PATH = "datasets/unicorn_companies.csv"

# The path has a .csv extension but was historically read with read_excel.
# Try the Excel reader first so an Excel workbook saved under this name keeps
# loading exactly as before; if pandas cannot recognise the content as an
# Excel format it raises ValueError, and the file is then parsed as CSV text.
try:
    df = pd.read_excel(DATA_PATH)
except ValueError:
    df = pd.read_csv(DATA_PATH)
    # read_csv leaves dates as plain strings, while later cells call
    # .strftime on the date column. Convert every column whose name mentions
    # "date"; values that cannot be parsed become NaT.
    for col in df.columns:
        if "date" in str(col).lower():
            df[col] = pd.to_datetime(df[col], errors="coerce")

print("Dataset loaded successfully.")
print(f"DF Shape: {df.shape}")
print(f"DF Columns: {list(df.columns)}")
print(df.head(5))


In [None]:
# ----------------------------
# Cleaning Dataset
# ----------------------------

# Standardize column names.
# Every replacement is a literal substitution (regex=False). With regex=True,
# "(" and ")" are invalid regular expressions on pandas >= 2.0, and "$" is an
# end-of-string anchor that would leave the dollar sign in place.
df.columns = (
    df.columns
      .str.strip()                          # remove surrounding whitespace
      .str.lower()                          # lowercase
      .str.replace(" ", "_", regex=False)   # spaces -> underscores
      .str.replace("(", "", regex=False)    # drop opening parenthesis
      .str.replace(")", "", regex=False)    # drop closing parenthesis
      .str.replace("$", "", regex=False)    # drop dollar sign
)

# Rename specific columns for clarity
df.rename(columns={"valuation_b": "valuation"}, inplace=True)

# Quick sanity checks
print("Columns after cleaning:")
print(df.columns)

print("\nData types:")
print(df.dtypes)

print("\nDataFrame info:")
df.info()

# Rows without investor data get a placeholder sentence when documents are built.
print("\nRows with missing 'select_investors':")
print(df[df["select_investors"].isna()])


# **Creating langchain docs and metadata**

In [None]:
# ----------------------------
# Create LangChain Documents with Metadata
# ----------------------------

from langchain_core.documents import Document

# One Document per DataFrame row: the text is what gets embedded, the
# metadata carries the structured fields alongside it.
lc_docs = []

for i, row in df.iterrows():
    # Investors line: fixed sentence when the cell is missing
    investors = row["select_investors"]
    if pd.notna(investors):
        investors_text = f"Select investors: {investors}"
    else:
        investors_text = "Select investors not publicly listed"

    # Human-readable date; any value that cannot be formatted gets a placeholder
    try:
        pretty_date = row["date_added"].strftime("%B %d, %Y")
    except Exception:
        pretty_date = "Date not available"

    # Text to embed, one field per line
    page_content = "\n".join([
        f"Company: {row['company']}",
        f"Category: {row['category']}",
        f"Country: {row['country']}",
        f"Valuation: ${row['valuation']}B",
        f"Date added to unicorn list: {pretty_date}",
        investors_text,
    ])

    # Structured fields stored next to the text
    metadata = dict(
        row_id=int(i),
        company=row["company"],
        country=row["country"],
        category=row["category"],
        valuation=row["valuation"],
        date_added=row["date_added"],  # raw cell value, not the formatted string
    )

    lc_docs.append(Document(page_content=page_content, metadata=metadata))

print(f"Created {len(lc_docs)} Document objects.")

# Optional: show the first entry as a spot check
print("\nSample page content:")
print(lc_docs[0].page_content)

print("\nSample metadata:")
print(lc_docs[0].metadata)


# **Building the vector store**

In [None]:
# ----------------------------
# Embed Documents into FAISS Vector Store
# ----------------------------

# Embed every Document with the embedding model and index the vectors.
# normalize_L2=True is forwarded to the FAISS wrapper.
vectorstore = LCFAISS.from_documents(lc_docs, embedding_model, normalize_L2=True)

print(f"Vector store built successfully. Documents indexed: {len(lc_docs)}")


# **Build retriever**

In [None]:
# ----------------------------
# Retriever Setup
# ----------------------------

# Plain similarity retriever; "k" sets how many documents come back per query.
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs=dict(k=5),
)

print("Retriever ready. Returns top 5 similar documents.")

# Optional: metadata-filtered retriever example (uncomment to use).
# NOTE(review): the LangChain FAISS wrapper expects the filter inside
# search_kwargs and applies it to the candidates it fetched (see fetch_k),
# rather than before the search. Confirm against the installed version.
# category_retriever = vectorstore.as_retriever(
#     search_type="similarity",
#     search_kwargs={"k": 4, "filter": {"category": "Artificial intelligence"}},
# )
# print("Category-filtered retriever ready for AI category.")


# **Check similarity scores**

In [None]:
# ----------------------------
# Query Vector Store with Similarity Scores
# ----------------------------

query = "tell me about companies that achieved unicorn status in 2021"

# Each result is a (Document, score) pair
results_with_scores = vectorstore.similarity_search_with_score(query, k=4)

print(f"Top {len(results_with_scores)} results for query: '{query}'\n")

for i, (doc, score) in enumerate(results_with_scores, start=1):
    # Cap the preview at 300 characters and mark truncation with an ellipsis
    text = doc.page_content
    preview = text[:300] + ("..." if len(text) > 300 else "")

    print(f"Result {i}")
    # NOTE(review): the meaning of the score depends on the index metric;
    # with an L2 index a lower value means a closer match. Confirm.
    print(f"Similarity score: {score:.4f}")
    print(f"Metadata: {doc.metadata}")
    print(f"Text preview:\n{preview}")
    print("-" * 60)


# **Build the prompt**

In [None]:
# ----------------------------
# Build Prompt Template
# ----------------------------

from langchain_core.prompts import PromptTemplate

# Raw template text. {context} and {question} are filled in per query;
# the text ends with "ANSWER:" so the model continues from there.
_CHAD_TEMPLATE = """
You are CHAD, CHAD stands for (Computational Hyper-Advanced Decoder), a personal AI created by Olaleye Faithfulness Ibukun.

You know about unicorn companies and their info ONLY from the context provided below.

------------------
CONTEXT:
{context}
------------------

QUESTION:
{question}

ANSWER:
"""

PROMPT = PromptTemplate.from_template(_CHAD_TEMPLATE)

print("Prompt template created successfully.")


In [None]:
# ----------------------------
# Helper Function: Format Docs into Context
# ----------------------------
def format_docs(docs):
    """Join the page_content of each retrieved document, separated by a blank line."""
    chunks = [item.page_content for item in docs]
    return "\n\n".join(chunks)


# ----------------------------
# Ask CHAD: RAG + LLM Chain
# ----------------------------
def ask_chad(question: str) -> str:
    """
    Retrieve relevant documents, build context, and generate an answer using CHAD.

    Relies on the module-level ``retriever``, ``PROMPT``, ``tokenizer`` and
    ``model`` objects created in earlier cells.

    Args:
        question (str): User query

    Returns:
        str: Text generated by the model after the prompt, with surrounding
        whitespace stripped.
    """
    # Step A — Retrieve documents
    docs = retriever.invoke(question)

    # Step B — Build context string (fixed marker when nothing was retrieved)
    context = format_docs(docs) if docs else "NO RELEVANT CONTEXT FOUND."

    # Step C — Fill prompt template
    prompt = PROMPT.format(
        context=context,
        question=question
    )

    # Step D — Tokenize and move the tensors to the model's device
    inputs = tokenizer(
        prompt,
        return_tensors="pt"
    ).to(model.device)

    # Number of prompt tokens, used below to separate prompt from completion
    prompt_length = inputs["input_ids"].shape[-1]

    # Step E — Generate response (sampling; no gradients needed)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.4,
            top_p=0.6,
            eos_token_id=tokenizer.eos_token_id  # stop at end-of-sequence
        )

    # Step F — Decode only the newly generated tokens.
    # For a causal LM the output sequence starts with the prompt tokens.
    # Slicing them off is more reliable than splitting the decoded text on
    # "ANSWER:", which returns the wrong span whenever the generated text
    # itself contains that marker.
    generated_ids = outputs[0][prompt_length:]
    answer_only = tokenizer.decode(
        generated_ids,
        skip_special_tokens=True
    ).strip()

    return answer_only


# **Testing the AI**

In [None]:
# End-to-end smoke test: one question through retrieval and generation.
question = "Which companies reached unicorn status in 2021?"
answer = ask_chad(question)

# print separates its arguments with a single space (made explicit here).
print("CHAD says:\n", answer, sep=" ")
