In [2]:
# Uncomment the next line and run this cell once to install dependencies
# %pip install -qU langchain langchain-openai langchain-chroma chromadb sentence-transformers

import os, getpass

# Drop any OPENAI_API_KEY inherited from the environment so this session only
# uses the key typed below. os.environ.pop() already unsets the variable at the
# process level, so a separate os.unsetenv() call (and the try/except that
# silenced its errors) is not needed.
os.environ.pop("OPENAI_API_KEY", None)

# Read the key without echoing it. Strip whitespace that tends to come along
# when pasting, and fail here instead of with an opaque authentication error on
# the first API call if nothing was entered.
_api_key = getpass.getpass("Enter your OpenAI API key: ").strip()
if not _api_key:
    raise ValueError("No OpenAI API key entered; re-run this cell and paste your key.")
os.environ["OPENAI_API_KEY"] = _api_key
del _api_key  # do not keep the secret in a notebook-level variable

from langchain_openai import OpenAIEmbeddings
import chromadb

# Point a persistent Chroma client at a local directory, then delete and
# recreate the collection so the indexing cells start from a fresh collection.
persist_path = "./chroma_storage"
client = chromadb.PersistentClient(path=persist_path)

collection_name = "code_snippet_collection"

# Best-effort removal of a collection left over from an earlier run.
# NOTE(review): `except Exception` catches every failure, so an error other
# than "collection does not exist" is also reported as "No existing collection".
try:
    client.delete_collection(collection_name)
    print(f"Deleted collection '{collection_name}'")
except Exception:
    print(f"No existing collection named '{collection_name}' to delete")

# Create the collection; the "hnsw:space" metadata entry selects the distance
# used for similarity search.
collection = client.create_collection(
    name=collection_name,
    metadata={"hnsw:space": "cosine"}  # cosine distance metric
)
print(f"Created new collection '{collection_name}' at {persist_path}")

# Embedding client (model=text-embedding-3-small) shared by the indexing cell
# (embed_documents) and the query cell (embed_query), so stored chunks and
# queries are embedded with the same model.
# NOTE(review): request_timeout is presumably in seconds — confirm against the
# langchain-openai documentation.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    request_timeout=60,
    max_retries=8,
)
print("Initialized OpenAI embeddings")

Enter your OpenAI API key:  ········


Deleted collection 'code_snippet_collection'
Created new collection 'code_snippet_collection' at ./chroma_storage
Initialized OpenAI embeddings


In [3]:
import hashlib
from typing import List

def load_file_chunks(filepath: str, chunk_size: int = 500) -> List[str]:
    """Read a UTF-8 text file and split it into line-aligned chunks.

    Lines are accumulated until the running character count (newlines are not
    counted) exceeds ``chunk_size``. The line that crosses the limit stays in
    the chunk, so a chunk may be longer than ``chunk_size`` and a line is never
    split. Chunks that contain only whitespace (for example blank lines that
    trail the last full chunk) are dropped, because they carry nothing worth
    embedding or summarizing.

    Args:
        filepath: Path of the file to read.
        chunk_size: Soft limit, in characters, at which a chunk is closed.

    Returns:
        The chunks in file order; an empty list for an empty or blank file.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()

    chunks: List[str] = []
    cur_chunk: List[str] = []
    cur_len = 0
    for line in lines:
        cur_chunk.append(line)
        cur_len += len(line)
        if cur_len > chunk_size:
            text = "\n".join(cur_chunk)
            if text.strip():  # skip chunks that are only blank lines/spaces
                chunks.append(text)
            cur_chunk = []
            cur_len = 0

    # Whatever is left after the last flush forms the final (short) chunk.
    tail = "\n".join(cur_chunk)
    if tail.strip():
        chunks.append(tail)
    return chunks

# Input files: point these at your own PySpark source and lineage output.
pyspark_code_file = "./lineage_test.py"
lineage_output_file = "./lineage_output_clean.log"

# Chunk both files with the default chunk size. The code chunks are indexed in
# the vector store; the lineage chunks are re-joined and summarized by the LLM.
code_chunks = load_file_chunks(pyspark_code_file)
lineage_chunks = load_file_chunks(lineage_output_file)

print(f"Loaded {len(code_chunks)} code chunks")
print(f"Loaded {len(lineage_chunks)} lineage chunks")

Loaded 6 code chunks
Loaded 2 lineage chunks


In [4]:
# Helper to create stable ids for documents
def make_id(text: str) -> str:
    """Return the SHA-256 hex digest of ``text`` encoded as UTF-8.

    The same text always yields the same 64-character id.
    """
    digest = hashlib.sha256()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()

def upsert_snippets(snippets: List[str]) -> int:
    """Embed ``snippets`` and upsert them into the Chroma ``collection``.

    Each snippet's id is derived from its text via ``make_id``, so re-indexing
    unchanged text overwrites the existing entry instead of adding a copy.
    Snippets repeated within one call are collapsed to a single entry before
    embedding: they would otherwise put the same id twice into one batch
    (which Chroma's id validation is expected to reject) and each repeat would
    cost an extra embedding.

    Args:
        snippets: Text chunks to index.

    Returns:
        The number of distinct snippets upserted; 0 when ``snippets`` is empty.
    """
    if not snippets:
        # Nothing to index: skip the embedding request and the Chroma call.
        print("Upserted 0 snippets")
        return 0

    # A dict keeps first-seen order; equal ids mean equal text (barring hash
    # collisions), so keeping one value per id loses nothing.
    unique = {make_id(s): s for s in snippets}
    ids = list(unique)
    documents = list(unique.values())

    vectors = embeddings.embed_documents(documents)
    collection.upsert(documents=documents, embeddings=vectors, ids=ids)
    print(f"Upserted {len(ids)} snippets")
    return len(ids)

# Index only the code chunks. The lineage log is not stored in the vector
# store; it is summarized per question by the LLM in a later cell.
num_upserted = upsert_snippets(code_chunks)

Upserted 6 snippets


In [9]:
#!uv pip install langchain[community]

In [5]:
from langchain_openai import ChatOpenAI
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document

# Chat model used for the lineage summary; temperature=0 is presumably chosen
# for repeatable output.
llm = ChatOpenAI(model_name="gpt-4", temperature=0)

# Build a "refine" summarization chain.
# NOTE(review): in the cells visible here `summarization_chain` is never
# invoked — the summary printed below comes from `summary_chain`. Confirm
# whether a later cell uses it; otherwise this is dead setup.
summarization_chain = load_summarize_chain(llm, chain_type="refine")

# Re-join the lineage chunks into one text block for the prompt.
# NOTE(review): `lineage_doc` is likewise not referenced again in the visible
# cells.
lineage_text = "\n".join(lineage_chunks)
lineage_doc = Document(page_content=lineage_text)

# Prompt for a question-focused lineage summary. It has two placeholders,
# {user_query} and {lineage_text}, filled in when the chain is invoked.
# The 800-character limit is only an instruction to the model; nothing in this
# cell enforces it.
# NOTE(review): the prompt calls the input "lineage JSON", while the text is
# loaded from lineage_output_clean.log — confirm the log content is JSON.
prompt_template = """
You are a data engineer assistant. Given the user's question and a detailed data lineage JSON, generate a concise natural language summary focused on the transformations relevant to the question.

User query: {user_query}

Lineage details:
{lineage_text}

Instructions:
- Explain key columns involved (e.g., computed columns like discounted_amount).
- Describe the transformations, joins, and filters affecting those columns.
- Mention any aggregations or final outputs.
- Avoid technical metadata or runtime info.
- Keep summary under 800 characters.
"""

# Turn the raw template string into a PromptTemplate with its two variables
# declared explicitly.
from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    input_variables=["user_query", "lineage_text"],
    template=prompt_template
)

from langchain.chains import LLMChain

# Chain that fills the prompt and stores the model's reply under "summary".
# NOTE(review): the saved cell output echoes this line the way Python warnings
# quote their source line — likely a deprecation warning for LLMChain. Confirm
# and consider migrating once it is clear which later cells use `summary_chain`.
summary_chain = LLMChain(llm=llm, prompt=prompt, output_key="summary")

# Question the summary is tailored to (replace with the actual user query).
current_user_query = "Explain how discounted_amount is calculated"

# Single call with the whole lineage text and the question.
summary_result = summary_chain.invoke(
    {"user_query": current_user_query, "lineage_text": lineage_text}
)

# Fall back to an empty string if the "summary" key is missing from the result.
lineage_summary_text = summary_result.get("summary", "").strip()

print("=== Focused Lineage Summary ===")
print(lineage_summary_text)

  summary_chain = LLMChain(llm=llm, prompt=prompt, output_key="summary")


=== Focused Lineage Summary ===
The 'discounted_amount' is calculated through a series of transformations. Initially, two datasets are joined on the 'order_id' attribute. Then, a new attribute 'discounted_amount' is created using a 'CaseWhen' function. This function checks if the 'tier' attribute is equal to 'Premium'. If it is, it multiplies the 'amount' attribute by 0.9, effectively applying a 10% discount. If the 'tier' is not 'Premium', the original 'amount' is retained. This operation results in a new column 'discounted_amount'. The data is then grouped by 'product' and 'tier' attributes, and several aggregations are performed, including the sum of 'discounted_amount' (named 'tier_revenue'), the average of 'discounted_amount' (named 'avg_tier_revenue'), and the sum of 'quantity' (named 'tier_quantity'). The final output is sorted in ascending order by 'product' and 'tier'.


In [6]:
def query_codebase_with_focused_summary(user_query: str, lineage_summary: str, top_k: int = 5):
    """Print the stored code chunks nearest to the question plus its lineage summary.

    The question and the summary are concatenated into one query string, which
    is embedded with the module-level ``embeddings`` client and looked up in the
    module-level Chroma ``collection``.

    Args:
        user_query: The user's question.
        lineage_summary: Summary text appended to the question before embedding.
        top_k: Number of results requested from the collection.

    Returns:
        None. The first 300 characters of the combined query and each returned
        document with its distance are printed.
    """
    combined_query = "\nLineage summary:\n".join((user_query, lineage_summary))

    hits = collection.query(
        query_embeddings=[embeddings.embed_query(combined_query)],
        n_results=top_k,
        include=["documents", "distances"],
    )

    print(f"\nCombined query (partial):\n{combined_query[:300]}...\n")
    print(f"Top {top_k} results:")

    # Only one query vector was sent, so the matches are in the first row.
    matches = zip(hits["documents"][0], hits["distances"][0])
    for position, match in enumerate(matches):
        doc, dist = match
        print(f"Rank {position + 1} - Distance: {dist:.4f}")
        print(doc)
        print("-" * 40)

# Example usage. This is the same question as `current_user_query`, which the
# lineage summary was generated for; keep the two in sync, since the summary is
# tailored to that question.
user_query = "Explain how discounted_amount is calculated"
query_codebase_with_focused_summary(user_query, lineage_summary_text)


Combined query (partial):
Explain how discounted_amount is calculated
Lineage summary:
The 'discounted_amount' is calculated through a series of transformations. Initially, two datasets are joined on the 'order_id' attribute. Then, a new attribute 'discounted_amount' is created using a 'CaseWhen' function. This function chec...

Top 5 results:
Rank 1 - Distance: 0.4485
    )

# Transformation 2: Customer Enrichment with Discounts
enriched_df = sales_df.join(customer_df, "order_id", "inner") \
    .withColumn(
        "discounted_amount",
        when(col("tier") == "Premium", col("amount") * 0.9)
        .otherwise(col("amount"))
    )

# Transformation 3: Final Analytics by Product and Tier
final_analytics_df = enriched_df.groupBy("product", "tier") \
    .agg(
        spark_sum("discounted_amount").alias("tier_revenue"),
        avg("discounted_amount").alias("avg_tier_revenue"),
----------------------------------------
Rank 2 - Distance: 0.4840
        spark_sum("quantity").alias("