In [1]:
# Import Necessary Libraries

from dotenv import load_dotenv, dotenv_values
import google.generativeai as genai
from IPython.display import Markdown, display
import os

# Load variables from the local .env file into the process environment

load_dotenv()

# Set the TESSDATA_PREFIX environment variable.
# os.environ only accepts str values, so the assignment is skipped when the
# variable is not defined anywhere (os.getenv returns None in that case and
# assigning None would raise TypeError).

tessdata = os.getenv("TESSDATA_PREFIX")
if tessdata is not None:
    os.environ['TESSDATA_PREFIX'] = tessdata

# Set the Google api key (read from the environment, never hard-coded)

my_api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=my_api_key)

In [2]:

from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
import warnings 
  
# Shared chat model used by every later cell (image summaries and the RAG chain);
# temperature is set to 0.
chat = ChatGoogleGenerativeAI(model= "gemini-1.5-flash", temperature = 0)
# Ignore all Python warnings from this point on
warnings.filterwarnings('ignore') 

In [3]:
import base64
import os

from langchain_core.messages import HumanMessage


def encode_image(image_path):
    """Read the file at ``image_path`` in binary mode and return its content as a base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def image_summarize(img_base64, prompt,chat):
    """
    Ask the chat model to summarize one image.

    img_base64: base64 string of the image; it is embedded in a PNG data URI.
    prompt: instruction text sent alongside the image.
    chat: object exposing ``invoke``; it receives a single HumanMessage.

    Returns the ``content`` attribute of the model response.
    """
    text_part = {"type": "text", "text": prompt}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{img_base64}"},
    }
    # One message carrying both parts: the instruction first, then the image
    request = [HumanMessage(content=[text_part, image_part])]
    response = chat.invoke(request)
    return response.content


def generate_img_summaries(path):
    """
    Generate summaries and base64 encoded strings for images.

    path: Directory containing the .png files to summarize.

    Returns a tuple ``(img_base64_list, image_summaries)``. Both lists are
    index-aligned and follow the sorted order of the file names; entries whose
    name does not end with ".png" are skipped. Each summary is produced by
    ``image_summarize`` with the module-level ``chat`` model.
    """

    # Prompt. Built from adjacent literals so that neither source indentation
    # nor a stray backslash ends up in the text sent to the model (the earlier
    # triple-quoted literal contained "\ ", an invalid escape sequence).
    prompt = (
        "You are an assistant tasked with summarizing database schemas for retrieval. "
        "These summaries will be embedded and used to retrieve the raw image. "
        "Start with the Name of the Schema and provide a concise summary of the image "
        "that is well optimized for retrieval."
    )

    # Store base64 encoded images
    img_base64_list = []

    # Store image summaries
    image_summaries = []

    # Apply to images (sorted so the result order is deterministic)
    for img_file in sorted(os.listdir(path)):
        if not img_file.endswith(".png"):
            continue
        base64_image = encode_image(os.path.join(path, img_file))
        img_base64_list.append(base64_image)
        image_summaries.append(image_summarize(base64_image, prompt, chat))

    return img_base64_list, image_summaries


# Image summaries


# Directory (relative to the working directory) that holds the schema .png images
fpath = "schemas/"
# Encode every image and collect one model-generated summary per image
img_base64_list, image_summaries = generate_img_summaries(fpath)

In [4]:
# Print each generated summary prefixed with its position in image_summaries
for index, image_summary in enumerate(image_summaries):
    print(index,"-", image_summary)


0 - ## Database Schema: Movie Rental Database

This database schema represents a movie rental system. It includes information about movies, actors, categories, languages, stores, customers, rentals, payments, and addresses. The schema is designed to track movie inventory, customer information, rental transactions, and payment details. It also includes information about staff members, store locations, and addresses. 

1 - ## regions
This schema contains information about regions, including their ID and name.

## countries
This schema contains information about countries, including their ID, name, and region ID.

## locations
This schema contains information about locations, including their ID, street address, postal code, city, state/province, and country ID.

## departments
This schema contains information about departments, including their ID, name, manager ID, and location ID.

## employees
This schema contains information about employees, including their ID, first name, last name, e

In [5]:
import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser


def create_multi_vector_retriever(
    vectorstore, image_summaries, images
):
    """
    Create retriever that indexes summaries, but returns raw images or texts.

    vectorstore: vector store that receives the summary documents.
    image_summaries: list of summary strings, one per entry of ``images``.
    images: list of raw contents stored in the docstore and returned on lookup.

    Returns the MultiVectorRetriever. When ``image_summaries`` is empty nothing
    is indexed and the retriever is returned empty.

    Raises ValueError when summaries are supplied but their count differs from
    the number of images.
    """
    # Validate before creating any store: a length mismatch would otherwise
    # surface as an IndexError or leave contents stored without a summary.
    if image_summaries and len(image_summaries) != len(images):
        raise ValueError(
            f"image_summaries has {len(image_summaries)} entries "
            f"but images has {len(images)}; they must match one-to-one"
        )

    # Initialize the storage layer
    store = InMemoryStore()
    id_key = "doc_id"

    # Create the multi-vector retriever
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )

    # Helper function to add documents to the vectorstore and docstore
    def add_documents(retriever, doc_summaries, doc_contents):
        # One fresh id per content; the same id links summary and raw content
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=summary, metadata={id_key: doc_id})
            for doc_id, summary in zip(doc_ids, doc_summaries)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

    if image_summaries:
        add_documents(retriever, image_summaries, images)

    return retriever

# Embedding model handed to the vectorstore below
embedding = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
# The vectorstore to use to index the summaries
vectorstore = Chroma(
    collection_name="mm_rag_cj_blog", embedding_function=embedding
)

# Create retriever: summaries are indexed in the vectorstore, the base64 images
# are kept in the docstore under matching ids
retriever_multi_vector_img = create_multi_vector_retriever(
    vectorstore,
    image_summaries,
    img_base64_list,
)

In [9]:
import io
import re

from IPython.display import HTML, display
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from PIL import Image


def plt_img_base64(img_base64):
    """Display a base64 encoded string as an image in the notebook output."""
    # Wrap the payload in a data URI and render the resulting <img> tag
    display(HTML(f'<img src="data:image/jpeg;base64,{img_base64}" />'))


def looks_like_base64(sb):
    """Return True when ``sb`` is made of base64 alphabet characters followed by at most two '='."""
    base64_shape = re.compile("^[A-Za-z0-9+/]+[=]{0,2}$")
    if base64_shape.match(sb):
        return True
    return False


def is_image_data(b64data):
    """
    Tell whether base64 data decodes to something that starts like an image.

    The first 8 decoded bytes are compared with a fixed set of file
    signatures. Anything that fails to decode yields False.
    """
    known_signatures = (
        b"\xff\xd8\xff",  # jpg
        b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a",  # png
        b"\x47\x49\x46\x38",  # gif
        b"\x52\x49\x46\x46",  # "RIFF" prefix, treated as webp
    )
    try:
        leading_bytes = base64.b64decode(b64data)[:8]
    except Exception:
        return False
    # bytes.startswith accepts a tuple and is True if any signature matches
    return leading_bytes.startswith(known_signatures)


def resize_base64_image(base64_string, size=(128, 128)):
    """
    Resize an image that is encoded as a Base64 string.

    base64_string: Base64 text of the source image.
    size: target size passed unchanged to ``Image.resize``.

    Returns the resized image as a Base64 string, saved in the format that
    was detected when the source image was opened.
    """
    # Decode and open the source image
    source_image = Image.open(io.BytesIO(base64.b64decode(base64_string)))

    # Scale with the LANCZOS filter
    scaled_image = source_image.resize(size, Image.LANCZOS)

    # Serialize into memory, then encode back to Base64 text
    out_buffer = io.BytesIO()
    scaled_image.save(out_buffer, format=source_image.format)
    encoded_bytes = base64.b64encode(out_buffer.getvalue())
    return encoded_bytes.decode("utf-8")


def split_image_text_types(docs):
    """
    Separate retrieved items into base64-encoded images and texts.

    Document instances are unwrapped to their ``page_content`` first. Items
    recognised as base64 image data are resized to (1300, 600) and collected
    under "images"; everything else goes under "texts".
    """
    images_b64, plain_texts = [], []
    for item in docs:
        content = item.page_content if isinstance(item, Document) else item
        is_image = looks_like_base64(content) and is_image_data(content)
        if not is_image:
            plain_texts.append(content)
            continue
        images_b64.append(resize_base64_image(content, size=(1300, 600)))
    return {"images": images_b64, "texts": plain_texts}


def img_prompt_func(data_dict):
    """
    Build the single HumanMessage sent to the chat model.

    data_dict: mapping with a "question" entry and a "context" entry holding
    "images" (base64 strings) and "texts" (strings).

    Returns a one-element list with a HumanMessage whose content lists every
    image first, followed by one text part made of the instructions, the
    question and the newline-joined texts.
    """
    context = data_dict["context"]
    joined_texts = "\n".join(context["texts"])

    # One image part per retrieved image (none when the list is empty)
    content_parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
        }
        for encoded in (context["images"] or [])
    ]

    # Fixed instructions that precede the question
    instructions = """You are a technical database schema and sql query generation expert. 
            You will be provided with various types of schemas in various file and image formats. 
            Your task is to generate concise, accurate answers without adding any information you are not confident about. 
            Do not be verbose - answer only what is asked in the question. 
            While generating queries follow the standard convention of each database type like Oracle, hive, mysql etc.


                    Important Guidelines:
                    * Prioritize accuracy:  If you are uncertain about any detail, state "Unknown" or "Not visible" instead of guessing.
                    * Avoid hallucinations: Do not add information that is not directly supported by the image.
                    * Be specific: Use precise language to describe shapes, colors, textures, and any interactions depicted.
                    * Consider context: If the image is a screenshot or contains text, incorporate that information into your description.
                    """
    prompt_text = (
        instructions
        + f"User-provided question: {data_dict['question']}\n\n"
        + "Text and / or tables:\n"
        + f"{joined_texts}"
    )
    content_parts.append({"type": "text", "text": prompt_text})
    return [HumanMessage(content=content_parts)]


def multi_modal_rag_chain(retriever,chat):
    """
    Assemble the multi-modal RAG chain.

    retriever: runnable that returns the documents for a question.
    chat: chat model that receives the message built by ``img_prompt_func``.

    Returns a chain that takes the question as input and yields the model
    answer parsed by StrOutputParser.
    """
    # Retrieved items are split into images and texts before prompting
    context_step = retriever | RunnableLambda(split_image_text_types)
    chain_inputs = {
        "context": context_step,
        "question": RunnablePassthrough(),
    }
    build_prompt = RunnableLambda(img_prompt_func)

    # RAG pipeline: inputs -> prompt -> model -> string
    return chain_inputs | build_prompt | chat | StrOutputParser()

# Create the RAG chain over the image retriever, answered by the shared chat model
query_generation_chain= multi_modal_rag_chain(retriever_multi_vector_img,chat)


### Metadata Test

In [10]:
# Metadata check: ask for the column data types and lengths of the jobs table
prompt = "data types and data lengths of all the columns in jobs table "
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


| Column Name | Data Type | Length |
|---|---|---|
| JOB_ID | VARCHAR | 10 |
| JOB_TITLE | VARCHAR | 35 |
| MIN_SALARY | NUMBER |  |
| MAX_SALARY | NUMBER |  |


In [11]:
# Metadata check: ask for the list of tables in the dvd rental schema
prompt = """Enumerate all the tables in bulleted points mentioned in dvd rental schema"""
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


- film_category
- category
- film
- language
- inventory
- rental
- customer
- store
- staff
- address
- city
- country
- payment
- actor
- film_actor
- film_text
- film_list


In [66]:
# Metadata change: ask for an Oracle statement altering the COMMISSION_PCT data type
prompt = """Generate a query to update the datatype of the column COMMISSION_PCT to decimal in oracle"""
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
ALTER TABLE employees MODIFY (COMMISSION_PCT DECIMAL);
```


### Fake schema test

In [12]:

## Fake Prompt: refers to the table names langchain and haystack to see how the chain responds
prompt = "Write sql queries to join langchain and haystack tables"
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


Unknown. The provided image does not contain tables named 'langchain' and 'haystack'.


### DDL Queries

In [63]:
# DDL: ask for statements that create and delete the employees schema
prompt = "Generate a query to create and delete employees schema"
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
CREATE TABLE employees (
    EMPLOYEE_ID NUMBER,
    FIRST_NAME VARCHAR2(20),
    LAST_NAME VARCHAR2(25),
    EMAIL VARCHAR2(25),
    PHONE_NUMBER VARCHAR2(20),
    HIRE_DATE DATE,
    JOB_ID VARCHAR2(10),
    SALARY NUMBER,
    COMMISSION_PCT NUMBER,
    MANAGER_ID NUMBER,
    DEPARTMENT_ID NUMBER
);

DROP TABLE employees;
```


In [64]:
# DDL: ask for a statement renaming the salary column
prompt = "Generate a query to rename the salary column to Total_salary"
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
ALTER TABLE employees
RENAME COLUMN salary TO Total_salary;
```


### DML

In [67]:
# DML: ask for one DELETE and one INSERT against the customer table
prompt = "Generate a query each to delete and insert a row in customer table"
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
-- Delete a row from the customer table
DELETE FROM customer WHERE customer_id = 1;

-- Insert a row into the customer table
INSERT INTO customer (customer_id, store_id, first_name, last_name, email, address_id, active, last_update_timestamp) 
VALUES (1, 1, 'John', 'Doe', 'john.doe@example.com', 1, 1, NOW());
```


### Joins

In [13]:

# Join: departments and jobs, allowing the model to go through intermediate tables
prompt = "Write a join query to join departments and jobs table - if they do not have relationship - frame the query using intermediate tables"
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
SELECT *
FROM departments d
JOIN employees e ON d.department_id = e.department_id
JOIN jobs j ON e.job_id = j.job_id;
```


In [14]:
# Join: full outer join between employees and departments
prompt = "Write a full outer join query to join employees and departments table?"
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
SELECT *
FROM employees e
FULL OUTER JOIN departments d ON e.department_id = d.department_id;
```


### Case Statement

In [15]:
# Case statements: label each employee depending on whether a manager is set
# (spelling of "employees" / "managers" corrected in the text sent to the model)

prompt = """Write a sql query for the following situation : 
 if the employees have managers then create a temporary column called 'M_YES' 
 and if the employees do not have managers then 'M_NO'"""

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
SELECT 
    *,
    CASE
        WHEN MANAGER_ID IS NOT NULL THEN 'M_YES'
        ELSE 'M_NO'
    END AS M_YES_NO
FROM employees;
```


### Analytical Query 

In [16]:


# Rank: order employees by salary, highest first, and number them
# (the redundant chained assignment "prompt = prompt = ..." was removed)
prompt = """ 1. Order the employees in the order of the salary they earn - 
           2. start with the one getting the highest 
           3. rank them in numbers 
           """

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
SELECT EMPLOYEE_ID, FIRST_NAME, LAST_NAME, SALARY, DENSE_RANK() OVER (ORDER BY SALARY DESC) AS SALARY_RANK
FROM employees
ORDER BY SALARY_RANK;
```


In [18]:
# Dense Rank: same ordering as the rank prompt, with ties sharing the same number
prompt = """ 1. Order the employees in the order of the salary they earn - 
           2. start with the one getting the highest 
           3. rank them in numbers 
           and 
           4.if two employees have same salary provide them the same number"""
# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
SELECT EMPLOYEE_ID,
       FIRST_NAME,
       LAST_NAME,
       SALARY,
       DENSE_RANK() OVER (ORDER BY SALARY DESC) AS SALARY_RANK
FROM employees
ORDER BY SALARY_RANK;
```


In [19]:

# Analytical query: lowest number of rentals per store
prompt = """Calculate the lowest number of rentals for each store. Use analytical functions if feasible"""

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
SELECT store_id, MIN(rental_id) AS min_rental_id
FROM rental
GROUP BY store_id;
```


In [20]:
## Row Number: number the rental rows per store, in descending order
prompt = """Assign a row number to the output set when the output set is organized by store in rental table and ordered in descending order."""

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
WITH RentalStore AS (
  SELECT
    r.store_id,
    ROW_NUMBER() OVER (PARTITION BY r.store_id ORDER BY r.rental_date DESC) AS row_num
  FROM rental AS r
)
SELECT
  rs.store_id,
  rs.row_num
FROM RentalStore AS rs;
```


### Functions and Procedures

In [21]:
# Function: ask for an Oracle function that inserts a new employee
prompt = """Write a sql function to add a new employee in the employees table in oracle database"""

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
CREATE OR REPLACE FUNCTION add_employee (
    p_first_name VARCHAR2,
    p_last_name VARCHAR2,
    p_email VARCHAR2,
    p_phone_number VARCHAR2,
    p_hire_date DATE,
    p_job_id NUMBER,
    p_salary NUMBER,
    p_commission_pct NUMBER,
    p_manager_id NUMBER,
    p_department_id NUMBER
)
RETURN NUMBER IS
    v_employee_id NUMBER;
BEGIN
    INSERT INTO employees (
        first_name,
        last_name,
        email,
        phone_number,
        hire_date,
        job_id,
        salary,
        commission_pct,
        manager_id,
        department_id
    ) VALUES (
        p_first_name,
        p_last_name,
        p_email,
        p_phone_number,
        p_hire_date,
        p_job_id,
        p_salary,
        p_commission_pct,
        p_manager_id,
        p_department_id
    );
    SELECT employee_id INTO v_employee_id FROM employees WHERE first_name = p_first_name AND last_name = p_last_name;
    RETURN v_employee_id;
END;
/
```


In [22]:
# Procedure: ask for a MySQL stored procedure that adds a rental customer
prompt = """Write a stored procedure to add a new rental customer information in the Mysql database"""

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
-- Stored procedure to add a new rental customer
DELIMITER //
CREATE PROCEDURE add_rental_customer(
    IN customer_name VARCHAR(45),
    IN first_name VARCHAR(45),
    IN last_name VARCHAR(45),
    IN email VARCHAR(50),
    IN address VARCHAR(50),
    IN active TINYINT(1),
    IN active_date DATE
)
BEGIN
    INSERT INTO customer (
        cust_name,
        first_name,
        last_name,
        email,
        address,
        active,
        active_date
    ) VALUES (
        customer_name,
        first_name,
        last_name,
        email,
        address,
        active,
        active_date
    );
END //
DELIMITER ;
```


### Specifying Datasources

In [56]:
# Data sources: request the same 100-record query for several database engines
prompt = """Write queries to retrieve 100 records from employees database for oracle, mysql, mongodb, hive, databases  """

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```sql
-- Oracle
SELECT * FROM employees WHERE ROWNUM <= 100;

-- MySQL
SELECT * FROM employees LIMIT 100;

-- MongoDB
db.employees.find().limit(100);

-- Hive
SELECT * FROM employees LIMIT 100;
```


### Generating Code

In [61]:
# Code generation: PySpark program reading 100 employees records from Oracle
prompt = """Write queries to retrieve 100 records from employees database using pyspark for oracle database: 
          1. Connect to the Database 
        2. Execute query 
        3. Retrieve results  """

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```python
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import *

# Create a SparkSession
spark = SparkSession.builder.appName("OracleQuery").getOrCreate()

# Define the connection parameters
url = "jdbc:oracle:thin:@<hostname>:<port>/<database_name>"
driver = "oracle.jdbc.driver.OracleDriver"
user = "<username>"
password = "<password>"

# Create a DataFrame from the Oracle table
df = spark.read.format("jdbc") \
    .option("url", url) \
    .option("driver", driver) \
    .option("user", user) \
    .option("password", password) \
    .option("dbtable", "employees") \
    .load()

# Retrieve the first 100 records
df.limit(100).show()

# Close the SparkSession
spark.stop()
```


In [60]:
# Code generation: pandas program reading 100 employees records from Oracle
prompt = """Write queries to retrieve 100 records from employees database using pandas python for oracle database: 
          1. Connect to the Database 
        2. Execute query 
        3. Retrieve results  """

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```python
import cx_Oracle
import pandas as pd

# Replace with your Oracle connection details
conn_string = 'username/password@host:port/service_name'

# Connect to the database
conn = cx_Oracle.connect(conn_string)

# Create a cursor object
cursor = conn.cursor()

# Execute the query
cursor.execute("SELECT * FROM employees WHERE ROWNUM <= 100")

# Fetch the results
results = cursor.fetchall()

# Create a Pandas DataFrame from the results
df = pd.DataFrame(results, columns=[desc[0] for desc in cursor.description])

# Print the DataFrame
print(df)

# Close the cursor and connection
cursor.close()
conn.close()
```


### Test Data Generation

In [69]:
# Test data: ask for 50 CSV rows matching the employees table
prompt = """Generate 50 rows of data in .csv format for the employees table  """

# Run RAG chain
result = query_generation_chain.invoke(prompt)
print(result)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


```csv
EMPLOYEE_ID,FIRST_NAME,LAST_NAME,EMAIL,PHONE_NUMBER,HIRE_DATE,JOB_ID,SALARY,COMMISSION_PCT,MANAGER_ID,DEPARTMENT_ID
1,Steven,King,SKING,515.123.4567,1987-06-17,AD_PRES,24000,0,NULL,90
2,Neena,Kochhar,NKOCHHAR,515.123.4568,1989-09-21,AD_VP,17000,0,1,90
3,Lex,De Haan,LDEHAAN,515.123.4569,1993-01-13,AD_VP,17000,0,1,90
4,Alexander,Hunold,AHUNOLD,590.423.4567,1990-03-03,IT_PROG,9000,0,2,60
5,Bruce,Ernst,BERNST,590.423.4568,1991-05-21,IT_PROG,6000,0,2,60
6,David,Austin,DAUSTIN,590.423.4569,1997-06-25,IT_PROG,6000,0,2,60
7,Valli,Pataballa,VPATABAL,590.423.4570,1998-02-05,IT_PROG,4800,0,2,60
8,Diana,Lorentz,DLORENTZ,590.423.5567,1999-02-07,IT_PROG,4200,0,2,60
9,Nancy,Greenberg,NGREENBE,515.124.4569,1989-12-08,FI_MGR,12008,0,3,100
10,Daniel,Faviet,DFAVIET,515.124.4169,1994-08-16,FI_ACCOUNT,9000,0,9,100
11,John,Chen,JCHEN,515.124.4269,1997-09-20,FI_ACCOUNT,8200,0,9,100
12,Ismael,Sciarra,ISCIARRA,515.124.4369,1997-07-30,FI_ACCOUNT,7700,0,9,100
13,Jose,Manuel,JMANU,515.124.4469,1999-09-07,F