# Retrieval Augmented Generation (RAG) with Langchain
*Using IBM Granite Models*

## In this notebook
This notebook contains instructions for performing Retrieval Augmented Generation (RAG). RAG is an architectural pattern that can be used to augment the performance of language models by recalling factual information from a knowledge base, and adding that information to the model query. The most common approach in RAG is to create dense vector representations of the knowledge base in order to retrieve text chunks that are semantically similar to a given user query.

RAG use cases include:
- Customer service: Answering questions about a product or service using facts from the product documentation.
- Domain knowledge: Exploring a specialized domain (e.g., finance) using facts from papers or articles in the knowledge base.
- News chat: Chatting about current events by calling up relevant recent news articles.

In its simplest form, RAG requires 3 steps:

- Initial setup:
  - Index knowledge-base passages for efficient retrieval. In this recipe, we take embeddings of the passages, and store them in a vector database.
- Upon each user query:
  - Retrieve relevant passages from the database. In this recipe, we use an embedding of the query to retrieve semantically similar passages.
  - Generate a response by feeding the retrieved passages into a large language model, along with the user query.
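The three steps above can be sketched in plain Python. This is a toy illustration under stated assumptions: a bag-of-words term-count vector stands in for the dense embedding model, a Python list stands in for the vector database, and the three passages are hypothetical. The pipeline later in this notebook uses `HuggingFaceEmbeddings` and Milvus instead.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    # Toy "embedding": term counts. A real RAG pipeline uses a dense neural embedding model.
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a)  # Counter returns 0 for missing terms
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


# Step 1 (initial setup): index the knowledge-base passages (hypothetical examples).
passages = [
    "Our bakery sells chocolate and vanilla cakes.",
    "Drone deliveries take about fifteen minutes.",
    "The shop opens at nine in the morning.",
]
index = [(embed(p), p) for p in passages]


# Step 2 (per query): retrieve the k passages most similar to the query.
def retrieve(query: str, k: int = 2) -> list[str]:
    q = embed(query)
    ranked = sorted(index, key=lambda item: cosine(q, item[0]), reverse=True)
    return [p for _, p in ranked[:k]]


# Step 3 (per query): build the prompt that would be sent to the LLM.
def build_prompt(query: str, k: int = 2) -> str:
    context = "\n".join(retrieve(query, k))
    return f"Context:\n{context}\n\nQuestion: {query}"


print(build_prompt("How long do drone deliveries take?", k=1))
```

Swapping the toy `embed` for a neural embedding model and the list for a vector database gives the structure used in the rest of the notebook.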

## Setting up the environment

Ensure you are running Python 3.10, 3.11, or 3.12 in a freshly-created virtual environment.

In [None]:
import sys
assert sys.version_info >= (3, 10) and sys.version_info < (3, 13), "Use Python 3.10, 3.11, or 3.12"

### Install dependencies

Granite utils provides some helpful functions for recipes.

In [None]:
%pip install git+https://github.com/ibm-granite-community/utils \
    transformers \
    langchain_community \
    'langchain_huggingface[full]' \
    langchain_milvus \
    replicate \
    wget \
    pypdf \
    beautifulsoup4 \
    ipywidgets \
    tiktoken


## Selecting System Components

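### Choose your Embeddings Model, Vector Database, and LLM

The next two cells import the libraries the notebook needs and configure its three components: the Granite embedding model (`ibm-granite/granite-embedding-30m-english`, loaded from Hugging Face) that turns text into vectors, a Milvus vector database stored in a temporary local file, and the Granite LLM (`ibm-granite/granite-3.3-8b-instruct`) served on Replicate.

To use an embeddings model from a provider other than Hugging Face, replace the embeddings part of the setup cell with code from [this Embeddings Model recipe](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Components/Langchain_Embeddings_Models.ipynb).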

In [None]:
# --- Component setup: imports ---
import re
import tempfile

from transformers import AutoTokenizer
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_milvus import Milvus
from langchain_community.llms import Replicate
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains.retrieval import create_retrieval_chain
from ibm_granite_community.notebook_utils import get_env_var
from ibm_granite_community.langchain import TokenizerChatPromptTemplate, create_stuff_documents_chain

In [None]:
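# 1. Embeddings model: Granite embedding model from Hugging Face, plus its tokenizer
#    (the tokenizer is used later to measure chunk length in tokens)
embeddings_model_path = "ibm-granite/granite-embedding-30m-english"
embeddings_model = HuggingFaceEmbeddings(model_name=embeddings_model_path)
embeddings_tokenizer = AutoTokenizer.from_pretrained(embeddings_model_path)

# 2. Vector database: Milvus Lite, stored in a local temporary file
db_file = tempfile.NamedTemporaryFile(prefix="milvus_", suffix=".db", delete=False).name
print(f"Vector database: {db_file}")

vector_db = Milvus(
    embedding_function=embeddings_model,
    connection_args={"uri": db_file},
    auto_id=True,
    index_params={"index_type": "AUTOINDEX"},
)

# 3. LLM: Granite instruct model served on Replicate (requires REPLICATE_API_TOKEN)
model_path = "ibm-granite/granite-3.3-8b-instruct"
model = Replicate(
    model=model_path,
    replicate_api_token=get_env_var('REPLICATE_API_TOKEN'),
    model_kwargs={
        "temperature": 0.3,
        "max_tokens": 2048,  # output length limit; Granite models on Replicate take max_tokens
        "top_p": 0.9,
    },
)
# LLM tokenizer, used by TokenizerChatPromptTemplate to format prompts for Granite
tokenizer = AutoTokenizer.from_pretrained(model_path)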

Vector database: /tmp/milvus__n_79xe0.db
REPLICATE_API_TOKEN loaded from Google Colab secret.

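### Load the knowledge sources

Load the startup-related documents (PDFs and web pages from Indian government and regulator sites) that form the knowledge base, and define a text splitter that measures chunk length in embedding-model tokens. A source that cannot be fetched, for example because the server returns HTTP 403 or presents an invalid certificate, is reported and skipped.

The Milvus vector database that will store these documents was created in the component setup cell above. To connect to a vector database other than Milvus, substitute that part of the setup cell with code from [this Vector Store recipe](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Components/Langchain_Vector_Stores.ipynb).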

In [None]:
documents = []

# Startup knowledge sources (add more as needed)
sources = [
    "https://www.startupindia.gov.in/content/dam/invest-india/Templates/public/Startup_Definition_PRN.pdf",
    "https://www.dpiit.gov.in/sites/default/files/StartupIndiaActionPlan_16January2016.pdf",
    "https://www.rbi.org.in/commonperson/English/Scripts/Notification.aspx?Id=3310",
    "https://www.mca.gov.in/MinistryV2/companiesact.html",
    "https://www.investindia.gov.in/schemes"
]

for url in sources:
    try:
        if url.endswith('.pdf'):
            loader = PyPDFLoader(url)
        else:
            loader = WebBaseLoader(url)
        documents.extend(loader.load())
        print(f"Loaded: {url}")
    except Exception as e:
        print(f"Failed to load {url}: {str(e)}")

# Custom text splitter for business documents
business_text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=50,
    length_function=lambda x: len(embeddings_tokenizer.encode(x)),
    separators=["\n\n", "\n", ". ", "! ", "? ", ", ", " ", ""]
)

Failed to load https://www.startupindia.gov.in/content/dam/invest-india/Templates/public/Startup_Definition_PRN.pdf: Check the url of your file; returned status code 403
Failed to load https://www.dpiit.gov.in/sites/default/files/StartupIndiaActionPlan_16January2016.pdf: HTTPSConnectionPool(host='www.dpiit.gov.in', port=443): Max retries exceeded with url: /sites/default/files/StartupIndiaActionPlan_16January2016.pdf (Caused by SSLError(SSLCertVerificationError(1, "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: Hostname mismatch, certificate is not valid for 'www.dpiit.gov.in'. (_ssl.c:1016)")))
Loaded: https://www.rbi.org.in/commonperson/English/Scripts/Notification.aspx?Id=3310
Loaded: https://www.mca.gov.in/MinistryV2/companiesact.html
Loaded: https://www.investindia.gov.in/schemes
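The splitter above measures length in embedding-model tokens (via `embeddings_tokenizer.encode`) rather than characters, and tries coarse separators before finer ones. The recursive idea can be sketched in plain Python as follows. This is a simplified illustration, not LangChain's `RecursiveCharacterTextSplitter` implementation: it has no chunk overlap, it drops separators at chunk boundaries, and a whitespace word count (the hypothetical `word_count`) stands in for the tokenizer.

```python
from typing import Callable


def word_count(text: str) -> int:
    # Stand-in for len(embeddings_tokenizer.encode(text)) in the notebook
    return len(text.split())


def split_recursive(text: str, chunk_size: int, length: Callable[[str], int],
                    separators: list[str]) -> list[str]:
    """Simplified recursive splitter: no overlap; separators at chunk boundaries are dropped."""
    if length(text) <= chunk_size or not separators:
        return [text]
    sep, finer = separators[0], separators[1:]
    if sep and sep not in text:
        # This separator does not occur in the text; fall back to the next, finer one
        return split_recursive(text, chunk_size, length, finer)
    pieces = text.split(sep) if sep else list(text)
    chunks: list[str] = []
    current = ""
    for piece in pieces:
        candidate = current + sep + piece if current else piece
        if length(candidate) <= chunk_size:
            current = candidate  # the piece still fits in the current chunk
            continue
        if current:
            chunks.append(current)  # close the current chunk
        if length(piece) <= chunk_size:
            current = piece
        else:
            # The piece alone is too long: split it with the finer separators
            chunks.extend(split_recursive(piece, chunk_size, length, finer))
            current = ""
    if current:
        chunks.append(current)
    return chunks


print(split_recursive("Short one. Another short one. A third sentence here.",
                      4, word_count, [". ", " ", ""]))
```

Because sentence boundaries (`". "`) are tried before single spaces, chunks tend to end at sentence breaks when they fit within the limit.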


### Split the documents and build the vector database

Clean each loaded document, split it into chunks of up to 512 embedding-model tokens with 50 tokens of overlap, and add the chunks to the vector database.

The LLM that answers from these chunks is a Granite model from the [`ibm-granite`](https://replicate.com/ibm-granite) org on Replicate, connected through the Replicate LangChain client in the component setup cell above. To connect to a model on a provider other than Replicate, substitute that part of the setup cell with code from the [LLM component recipe](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Components/Langchain_LLMs.ipynb).

In [None]:
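from langchain_core.documents import Document

texts = []
for doc in documents:
    # Clean and preprocess text. Collapsing all whitespace also removes newlines,
    # so the splitter's "\n\n" and "\n" separators never match after this step.
    clean_content = re.sub(r'\s+', ' ', doc.page_content)  # Collapse runs of whitespace
    clean_content = re.sub(r'\[.*?\]', '', clean_content)  # Remove [bracketed] references/citations
    chunks = business_text_splitter.split_text(clean_content)
    # Keep the source metadata (e.g. the URL) on every chunk
    texts.extend(Document(page_content=chunk, metadata=dict(doc.metadata)) for chunk in chunks)

print(f"{len(texts)} knowledge chunks created")

# Add to the vector database (skip if no source could be loaded)
if texts:
    ids = vector_db.add_documents(texts)
    print(f"{len(ids)} documents indexed")
else:
    print("No documents to index; check the source URLs above")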

6 knowledge chunks created
6 documents indexed
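The indexing cell above cleans each page with two regular expressions before splitting. Their combined effect can be checked on a small, made-up string. The sketch below reproduces the two substitutions in the same order (`clean`), and then shows a variant (`clean_then_collapse`, an assumption, not the notebook's code) that removes bracketed spans first and strips the result, so no double spaces are left where brackets were.

```python
import re


def clean(text: str) -> str:
    # Same two substitutions, in the same order, as the indexing cell
    text = re.sub(r"\s+", " ", text)     # collapse whitespace runs, including newlines
    text = re.sub(r"\[.*?\]", "", text)  # drop any [bracketed] span, not only citations
    return text


def clean_then_collapse(text: str) -> str:
    # Hypothetical variant: remove brackets first, then collapse whitespace and trim
    text = re.sub(r"\[.*?\]", "", text)
    return re.sub(r"\s+", " ", text).strip()


raw = "Eligibility criteria [1]\n\nAn entity is a startup [see note] if...\n"
print(repr(clean(raw)))
print(repr(clean_then_collapse(raw)))
```

Both versions remove every newline, which is why paragraph-level separators play no role in the later split; keeping `"\n\n"` intact would require collapsing only spaces and tabs.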


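## Building the RAG Chain

In this example, the chunks retrieved for a startup idea are combined with the idea itself in a single prompt, and the Granite LLM turns them into a structured business blueprint.

### Define the prompt

The prompt casts the model as StartupGPT: `{input}` receives the startup idea, `{context}` receives the retrieved chunks, and the numbered outline lists the sections the blueprint should contain. The retrieval chain uses this prompt to combine the retrieved chunks with the idea.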

In [None]:
# --- RAG Pipeline ---
# Specialized prompt template for business blueprints
BUSINESS_PROMPT_TEMPLATE = """
You are StartupGPT, an expert AI business consultant. Generate a comprehensive business blueprint using the following context:

Startup Idea: {input}

Context:
{context}

Blueprint Structure:
1. Business Model Canvas:
   - Value Proposition: [Unique value]
   - Customer Segments: [Target customers]
   - Channels: [Distribution methods]
   - Revenue Streams: [Monetization strategy]
   - Cost Structure: [Key expenses]

2. Funding Strategy:
   - Estimated Initial Budget: [Amount with breakdown]
   - Recommended Funding Sources: [VCs, angels, grants]
   - Government Schemes: [Applicable programs]

3. Market Analysis:
   - Target Market Size: [Estimate]
   - Top Competitors: [3-5 competitors]
   - Differentiation Strategy: [Competitive advantage]

4. Go-to-Market Plan:
   - Launch Timeline: [3-6 month plan]
   - Customer Acquisition: [Marketing strategy]
   - Key Metrics: [KPIs to track]

5. Legal & Compliance:
   - Business Structure: [LLP, Pvt Ltd, etc]
   - Key Registrations: [GST, DPIIT, etc]
   - Intellectual Property: [Patents/trademarks needed]

6. Investor Connections:
   - Recommended Firms: [VC firms matching stage/domain]
   - Introduction Strategy: [How to approach]
"""

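# Create the prompt template; the LLM tokenizer formats it for the Granite model
prompt_template = TokenizerChatPromptTemplate.from_template(
    BUSINESS_PROMPT_TEMPLATE,
    tokenizer=tokenizer
)

# Assemble the RAG chain. The prompt has a {context} slot, so LangChain's stuff-documents
# chain is used: it joins the retrieved chunks into {context} before calling the LLM.
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain

combine_docs_chain = create_stuff_documents_chain(
    llm=model,
    prompt=prompt_template,
)
rag_chain = create_retrieval_chain(
    retriever=vector_db.as_retriever(),
    combine_docs_chain=combine_docs_chain,
)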
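The `{input}` and `{context}` slots in the prompt are filled by the chain: `{input}` gets the startup idea, and with a "stuff" strategy all retrieved chunks are concatenated into the single `{context}` string. The plain-Python sketch below mimics that idea with a hypothetical `stuff` helper and a shortened, hypothetical template; it is not LangChain's implementation. It also shows why a prompt without a `{context}` slot is rejected, which is the same kind of check behind the `ValueError: Prompt must accept context as an input variable` raised by LangChain's `create_stuff_documents_chain`. The `"\n\n"` separator is an assumption chosen for the example.

```python
import string

# Hypothetical, shortened version of the StartupGPT prompt
TEMPLATE = "Startup Idea: {input}\n\nContext:\n{context}\n"


def template_variables(template: str) -> set[str]:
    # Formatter.parse yields (literal_text, field_name, format_spec, conversion) tuples
    return {field for _, field, _, _ in string.Formatter().parse(template) if field}


def stuff(template: str, idea: str, chunks: list[str], separator: str = "\n\n") -> str:
    """Concatenate all retrieved chunks into {context} and fill {input} with the idea."""
    missing = {"input", "context"} - template_variables(template)
    if missing:
        raise ValueError(f"Prompt is missing input variables: {sorted(missing)}")
    return template.format(input=idea, context=separator.join(chunks))


print(stuff(TEMPLATE, "Cake shop with drones as delivery", ["chunk A", "chunk B"]))
```

Stuffing keeps the chain to a single LLM call, but every retrieved chunk must fit in the model's context window together with the prompt, which is one reason the chunks are kept small.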


### Generate a blueprint

Invoke the RAG chain with a startup idea. The chain returns a dictionary: the generated blueprint is under the `answer` key, and the retrieved chunks are under `context`.

In [None]:
# --- Query Example ---
startup_idea = "Cake shop with drones as delivery"
blueprint = rag_chain.invoke({"input": startup_idea})

print("\n" + "="*50)
print("STARTUP BLUEPRINT GENERATED")
print("="*50)
print(blueprint['answer'])


STARTUP BLUEPRINT GENERATED
1. Business Model Canvas:
   - Value Proposition: A cake shop offering a wide variety of customizable, high-quality cakes, delivered swiftly via drones, ensuring convenience and timely delivery.
   - Customer Segments: Urban millennials, event organizers, corporate clients, and individuals seeking unique and high-quality cakes for special occasions.
   - Channels: Online ordering platform, social media, partnerships with event management companies, and corporate clients.
   - Revenue Streams: Online sales, catering services, and potential licensing of the drone delivery technology.
   - Cost Structure: Ingredients, labor, bakery equipment, drone maintenance, licensing, and regulatory compliance.

2. Funding Strategy:
   - Estimated Initial Budget: ₹50 lakhs (₹30 lakhs for bakery setup, ₹10 lakhs for drone technology, ₹5 lakhs for marketing, and ₹5 lakhs for miscellaneous expenses)
   - Recommended Funding Sources: Seed funding from angel investors, venture 
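The printed answer uses Markdown-style markup such as `- ` bullets, and the UI section below also handles `**bold**`. Before such text is inserted into HTML, it is safer to escape it, so that characters like `<` and `&` in the model output are shown literally, and to convert `**` markers in pairs so that every `<b>` is closed. A minimal sketch using only the standard library, assuming the answer uses just these two kinds of markup (`answer_to_html` is a hypothetical helper name):

```python
import html
import re


def answer_to_html(answer: str) -> str:
    # Escape first so "<", ">" and "&" in the model output are displayed literally
    text = html.escape(answer)
    # Pair up **bold** markers; a bare replace("**", "<b>") would never close the tag
    text = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", text)
    # Blank lines and "- " bullets become explicit line breaks
    text = text.replace("\n\n", "<br><br>").replace("\n- ", "<br><b>-</b> ")
    return text


print(answer_to_html("**Problem:** cost < 5 & rising\n- item"))
```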

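## Interactive UI

The cell below adds an ipywidgets front end: type a startup idea, click **Generate Blueprint**, and the RAG chain's answer is rendered as HTML.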


In [None]:
# Frontend for the RAG pipeline: run this after the pipeline cells above.
import re

from IPython.display import display, clear_output, HTML
import ipywidgets as widgets
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain

# Reuse the components defined in the RAG pipeline cells above:
# `model` (Granite LLM on Replicate), `prompt_template` (the StartupGPT prompt, which has
# {input} and {context} slots) and `vector_db` (Milvus with the indexed chunks).
# They are not replaced by mocks here, so the UI queries the real knowledge base.

# Build RAG pipeline
combine_docs_chain = create_stuff_documents_chain(
    llm=model,
    prompt=prompt_template,
)
rag_chain = create_retrieval_chain(
    retriever=vector_db.as_retriever(),
    combine_docs_chain=combine_docs_chain,
)

# Custom CSS for professional styling
css_style = """
<style>
.startup-card {
    background: white;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    padding: 20px;
    margin: 15px 0;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.section-title {
    color: #1f3a93;
    border-bottom: 2px solid #1f3a93;
    padding-bottom: 5px;
    margin-top: 20px;
    font-weight: bold;
}
.idea-input {
    width: 100%;
    padding: 12px;
    border: 2px solid #1f3a93;
    border-radius: 5px;
    font-size: 16px;
}
.generate-btn {
    background: #1f3a93;
    color: white;
    border: none;
    padding: 12px 24px;
    font-size: 16px;
    border-radius: 5px;
    cursor: pointer;
    transition: background 0.3s;
}
.generate-btn:hover {
    background: #152b6b;
}
.output-area-custom { /* Renamed to avoid confusion with the widget name */
    background: #f8f9fa;
    border-left: 4px solid #1f3a93;
    padding: 15px;
    margin-top: 20px;
    white-space: pre-wrap; /* Ensures line breaks are respected */
}
</style>
"""

# Display custom CSS (still useful for widgets outside the Output area)
display(HTML(css_style))

# Create widgets
idea_input = widgets.Textarea(
    value='',
    placeholder='Describe your startup idea (e.g., "AI-powered food waste reduction for supermarkets")',
    description='',
    disabled=False,
    layout=widgets.Layout(width='100%', height='100px')
)
# Apply the CSS class to the input for better styling
idea_input.add_class("idea-input")


generate_btn = widgets.Button(
    description='Generate Blueprint',
    # We will use CSS classes instead of button_style for better control
    # button_style='success',
    layout=widgets.Layout(width='auto', margin='20px 0 20px 0')
)
# Apply our custom CSS class to the button
generate_btn.add_class("generate-btn")


output_area = widgets.Output() # No need for border here, CSS will handle it

# Progress indicator
progress = widgets.IntProgress(
    value=0,
    min=0,
    max=4,
    step=1,
    description='Processing:',
    bar_style='info',
    orientation='horizontal',
    layout=widgets.Layout(width='100%', visibility='hidden')
)

# Display widgets
display(widgets.VBox([
    widgets.HTML("<h2 style='color:#1f3a93; font-family: Segoe UI, sans-serif;'>Startup Blueprint Generator</h2>"),
    widgets.HTML("<p style='font-family: Segoe UI, sans-serif;'>Describe your business idea below to get a complete startup blueprint.</p>"),
    idea_input,
    generate_btn,
    progress,
    widgets.HTML("<div class='section-title'>Business Blueprint</div>"),
    output_area
]))

# Generation function
def generate_blueprint(btn):
    with output_area:
        clear_output()
        progress.value = 0
        progress.layout.visibility = 'visible'

        idea = idea_input.value.strip()
        if not idea:
            # Using HTML for a nicer-looking warning
            display(HTML("<p style='color:red;'>⚠️ Please enter a startup idea</p>"))
            progress.layout.visibility = 'hidden'
            return

        import html  # stdlib; escapes <, > and & before text is inserted into HTML

        try:
            progress.description = "Generating blueprint..."
            progress.value = 1

            # Execute the RAG pipeline: retrieve relevant chunks, then generate the blueprint
            blueprint = rag_chain.invoke({"input": idea})

            progress.value = 2
            progress.description = "Formatting output..."

            # Format the LLM output for HTML: escape it first, then apply simple markup
            result_html = html.escape(blueprint['answer'])
            # Convert **bold** pairs to <b>...</b> (a bare replace('**', '<b>') never closes the tag)
            result_html = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', result_html)
            result_html = result_html.replace('\n\n', '<br><br>')
            result_html = result_html.replace('\n- ', '<br><b>-</b> ')

            progress.value = 3
            progress.description = "Rendering..."

            # Display result with HTML formatting AND the required CSS
            display(HTML(f"""
            {css_style}
            <div class="startup-card">
                <h3 style="color:#1f3a93;">Your Idea: {html.escape(idea)}</h3>
                <div class="output-area-custom">
                    {result_html}
                </div>
            </div>
            """))

            progress.value = 4
            progress.description = "Complete!"

        except Exception as e:
            display(HTML(f"<p style='color:red;'>🚨 Error generating blueprint: {html.escape(str(e))}</p>"))
        finally:
            # Hide progress bar after a short delay so user can see "Complete!"
            import time
            time.sleep(1)
            progress.layout.visibility = 'hidden'


# Attach event handler
generate_btn.on_click(generate_blueprint)

# Add sample ideas
display(widgets.HTML("""
<div class="startup-card">
    <h3 style="color:#1f3a93; font-weight: bold;">Sample Ideas to Try:</h3>
    <ul style="list-style-type: '💡'; padding-left: 20px;">
        <li>AI-powered tutoring platform for rural students</li>
        <li>Blockchain-based land registry system</li>
        <li>Eco-friendly packaging made from agricultural waste</li>
        <li>Telemedicine app for pet healthcare</li>
        <li>Vertical farming solutions for urban apartments</li>
    </ul>
</div>
"""))


In [None]:
%pip install faiss-cpu


Successfully installed faiss-cpu-1.11.0.post1
