In [1]:
# Load variables from a .env file into the process environment.
# NOTE(review): presumably this supplies the credentials used later by OpenAI() and Index.from_env() — confirm.
from dotenv import load_dotenv
load_dotenv()

True

# Development

## Ingestion

In [2]:
import os
import json
import fitz
from llm_utils import to_markdown
from llama_index.core import Document
from llama_index.core.node_parser import MarkdownNodeParser

  from .autonotebook import tqdm as notebook_tqdm


### Load PDF

In [3]:
# Relative path of the paper's PDF
pdf_path = "pdfs/2403.19887v1.Jamba__A_Hybrid_Transformer_Mamba_Language_Model.pdf"
# Open the PDF (`fitz` is the module imported in the imports cell)
paper_pdf = fitz.open(pdf_path)
# Convert the opened PDF to markdown text with the `to_markdown` helper from llm_utils
md_text = to_markdown(paper_pdf)

In [4]:
# Create document object from the markdown text
document_paper = Document(text=md_text)

# Break the document down into nodes using the MarkdownNodeParser
parser = MarkdownNodeParser()
nodes = parser.get_nodes_from_documents([document_paper])  # Returns a list of nodes

# Keep only nodes whose stripped text is longer than 10 characters
# (drops empty nodes and nodes with 10 characters or fewer)
nodes = [node for node in nodes if len(node.text.strip()) > 10]

# Number of nodes that remain after filtering
len(nodes)

11

### Embed PDF Nodes

In [5]:
from openai import OpenAI

# Client used for the embedding and chat-completion requests in this notebook.
# NOTE(review): no credentials are passed here, so the client presumably reads the API key from the environment — confirm.
oai_client = OpenAI()

In [None]:
# Sample embedding: embed one short string to check that the embeddings call works
response = oai_client.embeddings.create(
    input="This is a sample text",
    model="text-embedding-3-small"
)

# Show the vector's length and its first few values instead of dumping the
# whole vector into the cell output
sample_embedding = response.data[0].embedding
len(sample_embedding), sample_embedding[:5]

In [20]:
# Drop any 'text' entry that a previous run of this cell stored in the metadata.
# Without this, re-running the cell would feed the chunk text into the embedding
# input twice (once as content, once as metadata) and silently change the vectors.
for node in nodes:
    node.metadata.pop('text', None)

# Embed every node with one batched request instead of one API call per node
response = oai_client.embeddings.create(
    input=[node.get_content(metadata_mode="all") for node in nodes],
    model="text-embedding-3-small"
)

# Each returned item carries the position (`index`) of the input it belongs to
embedding_by_position = {item.index: item.embedding for item in response.data}

for i, node in enumerate(nodes):
    node.embedding = embedding_by_position[i]

    # Add original text to metadata as well for reference
    node.metadata['text'] = node.text
    node.id_ = i  # Assign straightforward index as ids for now

# Note:
# All nodes are embedded in a single request here, which is fine for one paper.
# For large collections, split the inputs over several requests to stay within
# the API's per-request limits.

In [21]:
# Inspect the first node's metadata (now also carries the 'text' entry added while embedding)
nodes[0].metadata

{'Header_2': 'A Hybrid Transformer-Mamba Language Model',
 'text': 'A Hybrid Transformer-Mamba Language Model\n\n**Opher Lieber** _∗_ **Barak Lenz** _∗_ **Hofit Bata** **Gal Cohen** **Jhonathan Osin**\n**Itay Dalmedigos** **Erez Safahi** **Shaked Meirom** **Yonatan Belinkov**\n**Shai Shalev-Shwartz** **Omri Abend** **Raz Alon** **Tomer Asida**\n**Amir Bergman** **Roman Glozman** **Michael Gokhman** **Avashalom Manevich**\n**Nir Ratner** **Noam Rozen** **Erez Shwartz** **Mor Zusman** **Yoav Shoham**'}

In [None]:
# Preview the first node's embedding (length and first few values) instead of
# printing the entire vector
len(nodes[0].embedding), nodes[0].embedding[:5]

In [30]:
# Serialize each node's metadata to a JSON string (one string per node)
[json.dumps(node.metadata) for node in nodes]

['{"Header_2": "A Hybrid Transformer-Mamba Language Model", "text": "A Hybrid Transformer-Mamba Language Model\\n\\n**Opher Lieber** _\\u2217_ **Barak Lenz** _\\u2217_ **Hofit Bata** **Gal Cohen** **Jhonathan Osin**\\n**Itay Dalmedigos** **Erez Safahi** **Shaked Meirom** **Yonatan Belinkov**\\n**Shai Shalev-Shwartz** **Omri Abend** **Raz Alon** **Tomer Asida**\\n**Amir Bergman** **Roman Glozman** **Michael Gokhman** **Avashalom Manevich**\\n**Nir Ratner** **Noam Rozen** **Erez Shwartz** **Mor Zusman** **Yoav Shoham**"}',
 '{"Header_2": "A Hybrid Transformer-Mamba Language Model", "Header_3": "Abstract", "text": "Abstract\\n\\nWe present Jamba, a new base large language model based on a novel hybrid\\nTransformer-Mamba mixture-of-experts (MoE) architecture. Specifically, Jamba\\ninterleaves blocks of Transformer and Mamba layers, enjoying the benefits of both\\nmodel families. MoE is added in some of these layers to increase model capacity\\nwhile keeping active parameter usage manageab

### Index the Embeddings to a Vector Index

In [22]:
from upstash_vector import Index, Vector

# Build the vector index client from environment configuration.
# NOTE(review): presumably reads the index URL/token made available by load_dotenv() — confirm.
index = Index.from_env()

In [23]:
# Wrap every node as a Vector carrying its id, embedding values and metadata
vectors = [
    Vector(id=node.node_id, vector=node.embedding, metadata=node.metadata)
    for node in nodes
]

In [25]:
# Upsert the vectors to the index (the whole list is sent in one call)
index.upsert(vectors)

'Success'

## Retrieval

In [52]:
# Example question used to try out the retrieval step
query = "What is the Jamba language model?"
# Number of results requested from the index for this example query
TOP_K = 3

In [53]:
# Get the embeddings for the query
# (same embedding model as the one used for the nodes during ingestion)
query_vector = oai_client.embeddings.create(
    input=query,
    model="text-embedding-3-small"
).data[0].embedding

In [54]:
# Execute the query: request TOP_K results, including their metadata
# but not the stored vectors
query_result = index.query(
    vector=query_vector,
    include_metadata=True,
    include_vectors=False,
    top_k=TOP_K
)

In [55]:
# Print score, id and metadata of every hit, followed by a blank line
for result in query_result:
    print(f"Score: {result.score}\nID: {result.id}\nMetadata: {result.metadata}\n")

Score: 0.8411506
ID: 1
Metadata: {'Header_2': 'A Hybrid Transformer-Mamba Language Model', 'Header_3': 'Abstract', 'text': 'Abstract\n\nWe present Jamba, a new base large language model based on a novel hybrid\nTransformer-Mamba mixture-of-experts (MoE) architecture. Specifically, Jamba\ninterleaves blocks of Transformer and Mamba layers, enjoying the benefits of both\nmodel families. MoE is added in some of these layers to increase model capacity\nwhile keeping active parameter usage manageable. This flexible architecture allows\nresource- and objective-specific configurations. In the particular configuration we\nhave implemented, we end up with a powerful model that fits in a single 80GB\nGPU. Built at large scale, Jamba provides high throughput and small memory\nfootprint compared to vanilla Transformers, and at the same time state-of-the-art\nperformance on standard language model benchmarks and long-context evaluations.\nRemarkably, the model presents strong results for up to 256K

In [63]:
def build_result_str(metadata):
    '''
    Build a string from the metadata dictionary of a query result for adding to the context of the LLM.

    Arguments:
        metadata (dict | None): Metadata of a query result. The 'text' entry holds the chunk text;
                                every other entry is rendered as a "key: value" header line.

    Returns:
        str: The header lines, a blank line, then the chunk text. When there are no header
             entries only the text is returned. A missing or empty metadata dictionary
             yields an empty string.
    '''
    # A result may come back without metadata; nothing to render in that case
    if not metadata:
        return ""

    # Tolerate entries that were indexed without a 'text' field
    text = metadata.get('text', '')
    header_lines = [f"{k}: {v}" for k, v in metadata.items() if k != 'text']

    # Avoid emitting leading blank lines when there is no header to separate
    if not header_lines:
        return text

    meta_str = "\n".join(header_lines)
    return f"{meta_str}\n\n{text}"

In [64]:
# Preview the formatted string for the second retrieved result
print(build_result_str(query_result[1].metadata))

Header_2: A Hybrid Transformer-Mamba Language Model
Header_3: 4 ### Training Infrastructure and Dataset

4 ### Training Infrastructure and Dataset

The model was trained on NVIDIA H100 GPUs. We used an in-house proprietary framework
allowing efficient large-scale training including FSDP, tensor parallelism, sequence parallelism, and
expert parallelism.

Jamba is trained on an in-house dataset that contains text data from the Web, books, and code, with
the last update in March 2024. Our data processing pipeline includes quality filters and deduplication.


In [66]:
# Template that wraps the formatted retrieval results before they are returned
context_prompt = """Retrieved context to answer the query is as follows:
{context_str}
"""

def build_context_prompt(retrieval_results):
    '''
    Format every retrieval result and merge them into the context prompt.

    Arguments:
        retrieval_results: Iterable of query results; each one must expose a `metadata` attribute.

    Returns:
        str: The context prompt with all formatted results, separated by a dashed divider.
    '''
    divider = "\n\n---------------------\n\n"

    formatted_chunks = []
    for hit in retrieval_results:
        formatted_chunks.append(build_result_str(hit.metadata))

    return context_prompt.format(context_str=divider.join(formatted_chunks))

In [68]:
# Preview the full context prompt built from the retrieved results
print(build_context_prompt(query_result))

Retrieved context to answer the query is as follows:
Header_2: A Hybrid Transformer-Mamba Language Model
Header_3: Abstract

Abstract

We present Jamba, a new base large language model based on a novel hybrid
Transformer-Mamba mixture-of-experts (MoE) architecture. Specifically, Jamba
interleaves blocks of Transformer and Mamba layers, enjoying the benefits of both
model families. MoE is added in some of these layers to increase model capacity
while keeping active parameter usage manageable. This flexible architecture allows
resource- and objective-specific configurations. In the particular configuration we
have implemented, we end up with a powerful model that fits in a single 80GB
GPU. Built at large scale, Jamba provides high throughput and small memory
footprint compared to vanilla Transformers, and at the same time state-of-the-art
performance on standard language model benchmarks and long-context evaluations.
Remarkably, the model presents strong results for up to 256K tokens con

In [77]:
def context_retrieval(search_query: str, top_k: int = 3) -> str:
    '''
    This function lets you semantically retrieve relevant context chunks from a given document based on a query.

    Arguments:
        search_query (str): The query to search for in the document. Based on the original user query, write a good
                            search query which is more logically sound to retrieve the relevant information from
                            the document.
        top_k (int): Number of chunks to request from the vector index. Defaults to 3.

    Returns:
        str: The retrieved context chunks from the document based on the search query formatted as a string.
    '''
    # Get the embeddings for the search query
    # (same embedding model as the one used to embed the document nodes)
    query_vector = oai_client.embeddings.create(
        input=search_query,
        model="text-embedding-3-small"
    ).data[0].embedding

    # Execute the query; only the metadata is needed to build the context
    query_result = index.query(
        vector=query_vector,
        include_metadata=True,
        include_vectors=False,
        top_k=top_k
    )

    return build_context_prompt(query_result)

In [78]:
# Try the retrieval helper end to end with a sample question
print(context_retrieval("on what hardware was jamba trained on?"))

Retrieved context to answer the query is as follows:
Header_2: A Hybrid Transformer-Mamba Language Model
Header_3: 4 ### Training Infrastructure and Dataset

4 ### Training Infrastructure and Dataset

The model was trained on NVIDIA H100 GPUs. We used an in-house proprietary framework
allowing efficient large-scale training including FSDP, tensor parallelism, sequence parallelism, and
expert parallelism.

Jamba is trained on an in-house dataset that contains text data from the Web, books, and code, with
the last update in March 2024. Our data processing pipeline includes quality filters and deduplication.

---------------------

Header_2: A Hybrid Transformer-Mamba Language Model
Header_3: 7 ### Conclusion

7 ### Conclusion

We presented Jamba, a novel architecture which combines Attention and Mamba layers, with MoE
modules, and an open implementation of it, reaching state-of-the-art performance and supporting
long contexts. We showed how Jamba provides flexibility for balancing perfor

## Response Generation

In [79]:
# System prompt, assembled line by line so the line breaks and the trailing
# spaces of the original text stay visible
system_prompt = (
    "You are a Q&A bot. You are here to answer questions based on the context retrieved\n"
    "from a vector index of the chunks of a document. You are prohibited from using prior knowledge and you\n"
    "can only use the context given. If you need more information, please ask the user. If you cannot answer \n"
    "the question from the context, you can tell the user that you cannot answer the question. You can also \n"
    "ask for more information from the user."
)

# Chat message that carries the system prompt
system_message = dict(role='system', content=system_prompt)

In [132]:
# A mapping of the tool name to the function that should be called
available_functions = {
    "context_retrieval": context_retrieval,
}

# Here we have only one function, but you can have multiple as well

In [140]:
# JSON schema of the tools the LLM can use.
# Only the context_retrieval function is described here.
context_retrieval_description = (
    "This function let's you semantically retrieve relevant context chunks from a given document based on a query. "
    "Based on the original user query, write a good search query which is more logically sound to retrieve the relevant information from the document. "
    "You might even have to break down the user query into multiple search queries and call this function multiple times separately. "
    "This function finally returns the retrieved context chunks from the document based on the search query formatted as a string."
)

# The function takes a single required string argument
context_retrieval_parameters = {
    "type": "object",
    "properties": {
        "search_query": {
            "type": "string",
            "description": "The sub-query to search for in the document."
        }
    },
    "required": ["search_query"],
}

tools_schema = [
    {
        "type": "function",
        "function": {
            "name": "context_retrieval",
            "description": context_retrieval_description,
            "parameters": context_retrieval_parameters,
        },
    }
]

In [151]:
def _run_tool_call(tool_call, verbose=True):
    '''Execute one tool call requested by the model and return the matching `tool` message.'''
    function_name = tool_call.function.name
    function_to_call = available_functions.get(function_name)

    if function_to_call is None:
        # The model asked for a tool that is not registered. Report it back instead of
        # raising, so the tool call still gets the response message it needs.
        function_response = f"Error: unknown function `{function_name}`."
    else:
        try:
            function_args = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError as exc:
            # The arguments are generated by the model and are not guaranteed to be valid JSON
            function_response = f"Error: could not parse the arguments of `{function_name}` as JSON ({exc})."
        else:
            if verbose:
                print(f"\n<< Calling Function `{function_name}` with Args: {function_args} >>")

            # Call the function
            function_response = function_to_call(**function_args)

    return {
        "tool_call_id": tool_call.id,
        "role": "tool",
        "name": function_name,
        "content": function_response,
    }


def conversation_turn(user_message, messages, tools, model='gpt-3.5-turbo', temperature=0.2, max_tokens=512, verbose=True, **kwargs):
    '''
    Run one user turn of the conversation, executing any tool calls the model requests.

    Arguments:
        user_message (str): The user's message for this turn.
        messages (list): Conversation history so far. It is extended in place with the user message,
                         the model responses and any tool results.
        tools (list): JSON schemas of the tools the model may call.
        model (str): Name of the chat model.
        temperature (float | None): Sampling temperature sent with every request of the turn.
                                    Pass None to leave it out of the request.
        max_tokens (int | None): Completion token limit sent with every request of the turn.
                                 Pass None to leave it out of the request.
        verbose (bool): Print the user message, the tool calls and the final response.
        **kwargs: Extra keyword arguments forwarded to every chat completion request.

    Returns:
        tuple: (final response message of the turn, the `messages` list).
    '''
    # Options shared by every request of this turn. `temperature` and `max_tokens`
    # are part of the signature, so they have to be forwarded explicitly.
    request_options = dict(kwargs)
    if temperature is not None:
        request_options['temperature'] = temperature
    if max_tokens is not None:
        request_options['max_tokens'] = max_tokens

    # Add user message to messages list
    messages.append({
        'role': 'user',
        'content': user_message
    })

    if verbose:
        print("\n<< User Message >>")
        print(user_message)

    # Send the conversation and available tools/functions to the model
    response = oai_client.chat.completions.create(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice="auto",  # auto is default, but we'll be explicit
        **request_options
    )
    final_message = response.choices[0].message
    tool_calls = final_message.tool_calls

    # Add the response to the messages list
    messages.append(final_message)

    # Check if the model wanted to call a function
    if tool_calls:
        # Run each requested tool and add its result to the messages list
        for tool_call in tool_calls:
            messages.append(_run_tool_call(tool_call, verbose=verbose))

        # Get a new response from the model based on the function responses.
        # No tools are offered here, so this request ends the turn.
        second_response = oai_client.chat.completions.create(
            model=model,
            messages=messages,
            **request_options
        )
        final_message = second_response.choices[0].message
        messages.append(final_message)

    if verbose:
        print("\n<< Response >>")
        print(final_message.content)

    return final_message, messages

In [152]:
# Start the conversation history with only the system message
messages = [
    system_message
]

In [154]:
# Ask a two-part question; the call returns the final response together with
# the updated conversation history
response, messages = conversation_turn(
    "hardware of jamba and what is the conclusion?",
    messages,
    tools_schema
)


<< User Message >>
hardware of jamba and what is the conclusion?

<< Calling Function `context_retrieval` with Args: {'search_query': 'hardware of jamba'} >>

<< Calling Function `context_retrieval` with Args: {'search_query': 'conclusion of jamba'} >>

<< Response >>
The hardware used for training Jamba includes NVIDIA H100 GPUs, and a proprietary in-house framework was utilized to enable large-scale training with features like FSDP, tensor parallelism, sequence parallelism, and expert parallelism.

The conclusion drawn about Jamba is that it is a novel architecture that combines Attention and Mamba layers with MoE modules, reaching state-of-the-art performance and supporting long contexts. Jamba provides flexibility for balancing performance and memory requirements while maintaining a high throughput. The design choices, such as the ratio of Attention-to-Mamba layers, were experimented with, leading to insights that will guide future work on hybrid attention-state-space models.

If 