In [4]:
import json

# Load your preprocessed data
with open('processed_jax_docs/jax_knowledge_base.json', 'r') as f:
    knowledge_base = json.load(f)

# Check a sample entry
sample_doc = knowledge_base[0]
print(f"Title: {sample_doc['title']}")
print(f"Content snippet: {sample_doc['content']}...")
print(f"Code blocks: {len(sample_doc['code_blocks'])}")

Title: Building from source — JAX  documentation
Content snippet: Building from source — JAX documentation Skip to main content .md Building from source Contents Building from source # First, obtain the JAX source code: 
```python
git clone https://github.com/jax-ml/jax
cd jax
```
```python
python build/build.py build --wheels=jaxlib --verbose
pip install dist/*.whl  # installs jaxlib (includes XLA)
```
 To build a wheel for a version of Python different from your current system installation pass --python_version flag to the build command: 
```python
python build/build.py build --wheels=jaxlib --python_version=3.12 --verbose
```
 The rest of this document assumes that you are building for Python version matching your current system installation. If you need to build for a different version, simply append --python_version=<py version> flag every time you call python build/build.py . Note, the Bazel build will always use a hermetic Python installation regardless of whether the --python_v

In [2]:
!pip install pinecone langchain langchain_community

Collecting langchain
  Downloading langchain-0.3.22-py3-none-any.whl.metadata (7.8 kB)
Collecting langchain_community
  Downloading langchain_community-0.3.20-py3-none-any.whl.metadata (2.4 kB)
Collecting langchain-core<1.0.0,>=0.3.49 (from langchain)
  Downloading langchain_core-0.3.49-py3-none-any.whl.metadata (5.9 kB)
Collecting langchain-text-splitters<1.0.0,>=0.3.7 (from langchain)
  Downloading langchain_text_splitters-0.3.7-py3-none-any.whl.metadata (1.9 kB)
Collecting langsmith<0.4,>=0.1.17 (from langchain)
  Downloading langsmith-0.3.22-py3-none-any.whl.metadata (15 kB)
Collecting SQLAlchemy<3,>=1.4 (from langchain)
  Downloading sqlalchemy-2.0.40-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting tenacity!=8.4.0,<10,>=8.1.0 (from langchain_community)
  Downloading tenacity-9.1.2-py3-none-any.whl.metadata (1.2 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain_community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadat

In [7]:
!pip install sentence-transformers

Collecting sentence-transformers
  Downloading sentence_transformers-4.0.1-py3-none-any.whl.metadata (13 kB)
Collecting transformers<5.0.0,>=4.41.0 (from sentence-transformers)
  Downloading transformers-4.50.3-py3-none-any.whl.metadata (39 kB)
Collecting safetensors>=0.4.3 (from transformers<5.0.0,>=4.41.0->sentence-transformers)
  Downloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Downloading sentence_transformers-4.0.1-py3-none-any.whl (340 kB)
Downloading transformers-4.50.3-py3-none-any.whl (10.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m28.5 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hDownloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (471 kB)
Installing collected packages: safetensors, transformers, sentence-transformers
Successfully installed safetensors-0.5.3 sentence-transformers-4.0.1 transformers-4.50.3


In [67]:
!pip install faiss-cpu



In [None]:
import os
from dotenv import load_dotenv
load_dotenv()


In [None]:
import os
import json
from typing import List, Dict
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.storage import InMemoryStore
from langchain.retrievers import MultiVectorRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter

class JAXFAISSRetriever:
    def __init__(self, knowledge_base_path: str):
        # Initialize embeddings
        self.embeddings = OpenAIEmbeddings(model="text-embedding-3-small")  
        
        # Load knowledge base
        self.knowledge_base = self._load_knowledge_base(knowledge_base_path)
        
        # Initialize vector stores
        self.vectorstore, self.docstore = self._create_vector_stores()
        
        # Create retriever
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=self.docstore,
            id_key="doc_id"
        )
    
    def _load_knowledge_base(self, path: str) -> List[Dict]:
        """Load preprocessed knowledge base from JSON file"""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    
    def _create_vector_stores(self):
        """Create FAISS vector store and document store"""
        # Prepare documents
        documents = []
        metadatas = []
        ids = []
        
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        
        for doc in self.knowledge_base:
            # Split document content
            splits = text_splitter.split_text(doc['content'])
            for i, split in enumerate(splits):
                doc_id = f"{doc['id']}_{i}"
                documents.append(split)
                metadatas.append({
                    'type': 'documentation',
                    'title': doc['title'],
                    'source': doc['path'],
                    'doc_id': doc['id'],
                    'chunk_id': i
                })
                ids.append(doc_id)
            
            # Add code blocks
            for i, code_block in enumerate(doc['code_blocks']):
                code_id = f"code_{doc['id']}_{i}"
                documents.append(code_block)
                metadatas.append({
                    'type': 'code',
                    'title': doc['title'],
                    'source': doc['path'],
                    'doc_id': doc['id'],
                    'code_block_id': i
                })
                ids.append(code_id)
        
        # Create FAISS vector store
        vectorstore = FAISS.from_texts(
            texts=documents,
            embedding=self.embeddings,
            metadatas=metadatas
        )
        
        # Create document store - CORRECTED IMPLEMENTATION
        docstore = InMemoryStore()
        # Convert to list of tuples as required by mset
        docstore.mset([(doc['id'], doc) for doc in self.knowledge_base])
        
        return vectorstore, docstore
    
    def query(self, question: str, include_code: bool = True, top_k: int = 3):
        """Execute a query with optional code filtering"""
        if include_code:
            docs = self.vectorstore.similarity_search(
                question,
                k=top_k,
                filter=lambda meta: meta.get('type') == 'code'
            )
        else:
            docs = self.vectorstore.similarity_search(
                question,
                k=top_k,
                filter=lambda meta: meta.get('type') == 'documentation'
            )
        
        # Convert to LangChain Document objects
        lc_docs = [
            Document(
                page_content=doc.page_content,
                metadata=doc.metadata
            ) for doc in docs
        ]
        
        # Get full documents for context
        doc_ids = list(set([doc.metadata['doc_id'] for doc in docs]))
        full_docs = [doc for doc in self.docstore.mget(doc_ids) if doc is not None]
        
        return {
            "relevant_chunks": lc_docs,
            "source_documents": full_docs
        }
    
    def save_index(self, path: str):
        """Save FAISS index to disk"""
        self.vectorstore.save_local(path)
    
    @classmethod
    def load_index(cls, path: str, knowledge_base_path: str):
        """Load FAISS index from disk"""
        retriever = cls(knowledge_base_path)
        retriever.vectorstore = FAISS.load_local(
            path,
            retriever.embeddings,
            allow_dangerous_deserialization=True
        )
        return retriever

In [18]:
# Initialize retriever
retriever = JAXFAISSRetriever('processed_jax_docs/jax_knowledge_base.json')

# Save index for later use
retriever.save_index("jax_faiss_index")

# Query examples
results = retriever.query(
    "How to use jax.jit with a neural network?",
    include_code=True
)

# Display results
for chunk in results["relevant_chunks"]:
    print(f"\nFrom {chunk.metadata['title']} ({chunk.metadata['type']}):")
    print(chunk.page_content[:200] + "...")


From Distributed arrays and automatic parallelization — JAX  documentation (code):
import jax
import jax.numpy as jnp...

From Control autodiff’s saved values with jax.checkpoint (aka jax.remat) — JAX  documentation (code):
import jax
import jax.numpy as jnp...

From jax.named_scope — JAX  documentation (code):
>>> import jax
>>>
>>> @jax.jit
... def layer(w, x):
...   with jax.named_scope("dot_product"):
...     logits = w.dot(x)
...   with jax.named_scope("activation"):
...     return jax.nn.relu(logits)...


In [19]:
results = retriever.query(
    "How to use jax.vmap with multiple arguments?",
    include_code=True
)

# Display results
for chunk in results["relevant_chunks"]:
    print(f"\nFrom {chunk.metadata['title']} ({chunk.metadata['type']}):")
    print(chunk.page_content[:200] + "...")


From External callbacks — JAX  documentation (code):
x = jnp.arange(5.0)
jax.vmap(f)(x);...

From Introduction to debugging — JAX  documentation (code):
x = jnp.arange(5.0)
jax.vmap(f)(x);...

From Pseudorandom numbers — JAX  documentation (code):
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))...


In [20]:
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI


llm = ChatOpenAI(model="gpt-4", temperature=0)

# Create a comprehensive prompt template that encourages code examples
prompt_template = """You are an expert in JAX and machine learning. Use the following pieces of context to answer the question at the end. 
If the question involves code or implementation details, always provide a complete, executable code example using JAX.

Context information:
-------------------
{context}

Question: {question}

When providing code examples:
1. Use proper JAX imports (jax.numpy as jnp, jax, etc.)
2. Include type annotations where appropriate
3. Add brief comments explaining key parts
4. Ensure the code is syntactically correct

Answer in the following format:
[Explanation] Provide a clear explanation of the concept or solution
[Code Example] (if applicable):
```python
# Your code here
[Additional Notes] Any caveats or important considerations"""

QA_PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)


qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever.vectorstore.as_retriever(),  
    chain_type_kwargs={"prompt": QA_PROMPT},
    return_source_documents=True
)


question = "How to use jax.vmap with multiple arguments?"
result = qa_chain({"query": question})

print("Answer:")
print(result["result"])
print("\nSources:")
for doc in result["source_documents"]:
    print(f"- {doc.metadata['title']} (chunk {doc.metadata.get('chunk_id', 'code')})")

Answer:
[Explanation]
The `jax.vmap` function in JAX is used to vectorize or batch computations over one or more arguments of a function. It is a powerful tool for parallelizing computations over a batch dimension. 

When using `jax.vmap` with multiple arguments, you need to specify the `in_axes` parameter to indicate which axes of the input arguments should be mapped over. The `in_axes` parameter can be a tuple, list, or dictionary, depending on the structure of the input arguments. 

If the function you want to vectorize takes multiple arguments, you can pass them as a tuple to `jax.vmap`. The `in_axes` parameter should also be a tuple of the same length, specifying the axis to map for each argument. If an argument should not be mapped over, you can specify its axis as `None`.

[Code Example]
```python
import jax
import jax.numpy as jnp

# Define a function that takes two arguments
def f(x, y):
  return x + y

# Create some data
x = jnp.arange(5.0)
y = jnp.arange(5.0, 10.0)

# Use vm

In [33]:
question = "How does JAX's grad function work for automatic differentiation?"
result = qa_chain({"query": question})

print("Answer:")
print(result["result"])

Answer:
[Explanation]
JAX's `grad` function is used for automatic differentiation. It takes a function as an input and returns a new function that computes the gradient of the input function. The gradient of a function at a certain point is a vector that points in the direction of the greatest rate of increase of the function at that point, and its magnitude is the rate of increase in that direction.

In the context of machine learning, gradients are used to update the parameters of models during training in order to minimize a loss function. The `grad` function in JAX makes it easy to compute these gradients.

One of the powerful features of JAX's `grad` function is that it can be applied to its own output to compute higher-order derivatives. This is because the functions that compute derivatives are themselves differentiable in JAX.

[Code Example]
```python
import jax
import jax.numpy as jnp
from jax import grad

# Define a function
f = lambda x: x**3 + 2*x**2 - 3*x + 1

# Compute t

In [25]:
!pip install streamlit streamlit_chat

Collecting streamlit_chat
  Downloading streamlit_chat-0.1.1-py3-none-any.whl.metadata (4.2 kB)
Downloading streamlit_chat-0.1.1-py3-none-any.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: streamlit_chat
Successfully installed streamlit_chat-0.1.1


In [34]:
%%writefile temp_app.py
import os
import json
from dotenv import load_dotenv
import streamlit as st
from streamlit_chat import message

# Load environment variables
load_dotenv()

# Set Streamlit page configuration
st.set_page_config(
    page_title="JAX Helper Bot",
    page_icon="🦜",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Initialize retriever
knowledge_base_path = "jax_faiss_index/index.faiss"  # Change path if needed
retriever = JAXFAISSRetriever(knowledge_base_path)

# Function to create source links
def create_sources_string(sources):
    if not sources:
        return ""
    sources_string = "Sources:\n"
    for i, source in enumerate(sorted(sources)):
        sources_string += f"{i+1}. {source}\n"
    return sources_string

# Sidebar
with st.sidebar:
    st.title("User Profile")
    user_name = "Manish Sharma"
    user_email = "manishsharma@gmail.com"
    st.write(f"**Name:** {user_name}")
    st.write(f"**Email:** {user_email}")

st.header("JAX Helper Bot 🦜🔗")

# Initialize session state
if "chat_answers_history" not in st.session_state:
    st.session_state["chat_answers_history"] = []
    st.session_state["user_prompt_history"] = []

# User input
prompt = st.text_input("Ask a question about JAX:", placeholder="How to use jax.vmap with multiple arguments?")

if st.button("Submit") and prompt:
    with st.spinner("Fetching relevant documents and generating response..."):
        retrieval_result = retriever.query(prompt)
        qa_result = qa_chain({"query": prompt, "context": retrieval_result["relevant_chunks"]})
        
        sources = set(doc.metadata["source"] for doc in retrieval_result["source_documents"])
        formatted_response = f"{qa_result['result']}\n\n{create_sources_string(sources)}"
        
        # Store chat history
        st.session_state["user_prompt_history"].append(prompt)
        st.session_state["chat_answers_history"].append(formatted_response)

# Display chat history
if st.session_state["chat_answers_history"]:
    for user_query, response in zip(st.session_state["user_prompt_history"], st.session_state["chat_answers_history"]):
        message(user_query, is_user=True, key=f"user_{user_query}")
        message(response, key=f"bot_{response}")

st.markdown("---")
st.markdown("Powered by LangChain and Streamlit")


Overwriting temp_app.py
