In [5]:
import os

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_community.document_loaders import Docx2txtLoader
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough


In [6]:
load_dotenv()

True

In [13]:
# uncomment if facing SSL errors
# tiktoken_cache_dir = "tiktoken_cache"
# os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir

In [8]:
openai_api_key = os.environ["OPENAI_API_KEY"]

### Setup LLM

In [9]:
llm = ChatOpenAI(
    api_key=openai_api_key,
    model="gpt-4.1-mini",
    temperature=0,
)

In [10]:
res = llm.invoke("hello")

In [11]:
print(res.content)

Hello! How can I assist you today?


### Setup embeddings model

In [14]:
embeddings = OpenAIEmbeddings(
    api_key=openai_api_key,
    model="text-embedding-3-small",
)

In [15]:
text = "Hello, this is a test for embeddings"

vector = embeddings.embed_query(text)

In [16]:
print(vector)

[0.015600232407450676, -0.015445519238710403, 0.017766214907169342, -0.033753231167793274, -0.019532522186636925, -0.049121394753456116, 0.022291572764515877, 0.03811098262667656, -0.008580127730965614, 0.00013154638872947544, 0.015677589923143387, -0.03058161400258541, -0.012486632913351059, -0.03867826238274574, -0.012280348688364029, 0.03481043502688408, -0.02074444107711315, 0.023322992026805878, -0.00726506719365716, 0.006124058272689581, 0.01429806463420391, -0.03042690083384514, 0.006394806317985058, 0.048631470650434494, 0.010765450075268745, -0.0049185859970748425, 0.02557922527194023, 0.05195780098438263, 0.04156624153256416, -0.04022539407014847, 0.01793382130563259, -0.033340662717819214, -0.005576116032898426, -0.028879769146442413, -0.014942701905965805, 0.031561464071273804, -0.013588963076472282, 0.030452685430645943, -0.03689906373620033, -0.002017715945839882, 0.014169136993587017, -0.0030523596797138453, -0.035403504967689514, 0.01864292286336422, -0.0121578676626086

### Load Data

In [17]:
def load_documents_with_docx2txt(folder_path):
    documents = []
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        loader = Docx2txtLoader(file_path)
        documents.extend(loader.load())
    
    return documents

In [19]:
folder_path = "data"
documents = load_documents_with_docx2txt(folder_path)

In [22]:
# print(documents[0].page_content)

### Store vectors in Chroma db

In [23]:
texts = [doc.page_content for doc in documents]

In [24]:
vectorstore = Chroma.from_texts(texts=texts, embedding=embeddings)

In [26]:
query = "How do I find out which artist sang a particular song?"
search_results = vectorstore.similarity_search(query, k = 2)
print(search_results)

[Document(id='f9b6a852-5ec8-4402-84e1-e7ef945e57d2', metadata={}, page_content='Table Description:\nStores detailed information about individual songs, including their titles, durations, and associations with artists and albums.\n\nAttribute Description:\n\nsong_id – Unique identifier for each song.\n\ntitle – Title of the song.\n\nartist_id – Foreign key referencing the artist who performed the song.\n\nalbum – Name of the album the song belongs to.\n\nduration – Length of the song in seconds.\n\nrelease_date – Official release date of the song.\n\ngenre – Musical genre of the song.\n\nTable Schema:\nsongs(song_id INT PRIMARY KEY, title VARCHAR(255), artist_id INT, album VARCHAR(255), duration INT, release_date DATE, genre VARCHAR(100), FOREIGN KEY (artist_id) REFERENCES artists(artist_id))'), Document(id='bb9ffc28-f218-4a7e-8a14-27160962b95d', metadata={}, page_content='Table Description:\nContains information about musical artists, including their names, genres, and related metadata

In [27]:
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})

### Setup prompt template

In [28]:
template = """
You are an expert SQL assistant. Based on the following retrieved document information, generate the SQL query to answer the user's question.

Database Information:
{context}

User Query: 
{query}

Your response should include:
1. Explanation of which tables to join and why.
2. The SQL query.

Response:
"""

In [29]:
prompt_template = PromptTemplate.from_template(template)

In [31]:
res = prompt_template.invoke({"context": "this is some context", "query": "this is the user's query"})
print(res)

text="\nYou are an expert SQL assistant. Based on the following retrieved document information, generate the SQL query to answer the user's question.\n\nDatabase Information:\nthis is some context\n\nUser Query: \nthis is the user's query\n\nYour response should include:\n1. Explanation of which tables to join and why.\n2. The SQL query.\n\nResponse:\n"


### Create RAG chain

In [32]:
def process_docs(docs: list[Document]):
    return "\n\n".join(doc.page_content for doc in docs)

In [33]:
RAG_chain = (
    {
        "context": retriever | process_docs,
        "query": RunnablePassthrough()
    }
    | prompt_template
    | llm
    | StrOutputParser()
)

In [34]:
query = "How do I find out which artist sang a particular song?"

In [35]:
res = RAG_chain.invoke(query)

In [36]:
print(res)

To find out which artist sang a particular song, you need to join the **songs** table with the **artists** table. This is because the **songs** table contains the song details including the `artist_id`, and the **artists** table contains the artist's name and other related information. By joining these tables on the `artist_id`, you can retrieve the artist's name for a given song title.

### Explanation:
- Join **songs** and **artists** on `songs.artist_id = artists.artist_id`.
- Filter the results by the song title to find the specific song.
- Select the artist's name along with the song title for clarity.

### SQL Query:
```sql
SELECT s.title, a.artist_name
FROM songs s
JOIN artists a ON s.artist_id = a.artist_id
WHERE s.title = 'Your Song Title';
```

Replace `'Your Song Title'` with the actual title of the song you want to look up.
