<a href="https://colab.research.google.com/github/Uzeeeee/CH341A/blob/main/notebooks/datagemma_rig.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q git+https://github.com/datacommonsorg/llm-tools
!pip install -q bitsandbytes accelerate

## Step 1: Load the model

This section loads the fine-tuned Gemma 2 27B model from Hugging Face and creates a transformer model wrapper that can be used in the Retrieval Interleaved Generation (RIG) workflow. More technical details of this approach can be found in the [paper](https://datacommons.org/link/DataGemmaPaper).

In [None]:
import torch

import data_gemma as dg

from google.colab import userdata
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Initialize Data Commons API client
DC_API_KEY = userdata.get('DC_API_KEY')
dc = dg.DataCommons(api_key=DC_API_KEY)


# Get finetuned Gemma2 model from HuggingFace
HF_TOKEN = userdata.get('HF_TOKEN')

nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_compute_dtype=torch.bfloat16
)

model_name = 'google/datagemma-rig-27b-it'
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
datagemma_model = AutoModelForCausalLM.from_pretrained(model_name,
                                             device_map="auto",
                                             quantization_config=nf4_config,
                                             torch_dtype=torch.bfloat16,
                                             token=HF_TOKEN)

# Build the LLM Model stub to use in RIG flow
datagemma_model_wrapper = dg.HFBasic(datagemma_model, tokenizer)
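
The cell above loads the weights with a 4-bit NF4 quantization config. As a rough, back-of-the-envelope sketch of what that means for memory, the snippet below compares weight storage at 16 bits and 4 bits per parameter. It assumes a nominal 27 billion parameters (taken from the "27B" name) and counts weights only; quantization metadata, activations, and the KV cache are ignored, so real usage will be higher. The helper `weight_memory_gb` is hypothetical and not part of any library.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights only, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Assumption: nominal parameter count implied by the "27B" model name.
NUM_PARAMS = 27e9

bf16_gb = weight_memory_gb(NUM_PARAMS, 16)  # 16 bits per weight (bfloat16)
nf4_gb = weight_memory_gb(NUM_PARAMS, 4)    # 4 bits per weight (NF4)

print(f"bf16 weights: ~{bf16_gb:.1f} GB")
print(f"nf4 weights:  ~{nf4_gb:.1f} GB")
```

Under these assumptions the 4-bit weights take about a quarter of the space of the bfloat16 weights, which is the main reason a quantized load is attractive on a single accelerator.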

## Step 2: Pick or enter a query for RIG

You can select a sample query or enter your own query to test RIG.


In [None]:

#@title Pick a query from a sample list{ run: "auto" }
QUERY = "What progress has Pakistan made against health goals?" #@param ["In which countries are more women getting college degrees than men?","What is the percentage of the financial sector in GDP for different countries like the United States and China, based on the latest data?","Does India have more people living in the urban areas or rural areas?  How does that vary by states?  Are the districts with the most urban population also located in the states with the most urban population?","What are some interesting trends in Sunnyvale spanning gender, age, race, immigration, health conditions, economic conditions, crime and education?","Which US counties share a very similar demographic composition to the US overall in terms of gender, age and racial breakdown?","Compare Cambridge, MA and Palo Alto, CA in terms of demographics, education, and economy stats.","How does the size of household compare across counties in Utah vs. California?  Does it change between owned vs. rental properties?","Based on the distribution of foreign language speakers compare the diversity of people in: NYC, Seattle, Austin, Chicago and Tampa","Are there significant differences in the prevalence of various types of disabilities (such as vision, hearing, mobility, cognitive) between Dallas and Houston?","Are there countries in the world where the forest area has actually increased?","What progress has Pakistan made against health goals?","Does an increase in female participation in education result in a higher number of women holding political office?","Which countries have the highest life expectancy?","Has the average lifespan increased globally?"]


In [None]:
#@title Use your own query (Please see disclaimer at the top)
QUERY = 'What progress has Pakistan made against health goals?' #@param {type:"string"}

## Step 3: Run RIG and Print Output


In [None]:
from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))


ans = dg.RIGFlow(llm=datagemma_model_wrapper, data_fetcher=dc, verbose=False).query(query=QUERY)
Markdown(textwrap.indent(ans.answer(), '> ', predicate=lambda _: True))


display_chat(QUERY, ans.answer())

In [None]:
import kagglehub
path = kagglehub.model_download('google/gemma/tensorRtLlm/2b-it/2')

# More Information on Retrieval Interleaved Generation (RIG)

Retrieval Interleaved Generation (RIG): This novel approach fine-tunes Gemma 2 to recognize when it needs to replace a generated number with more accurate information from Data Commons. Think of it as the model double-checking its work against a trusted source.

Here's how RIG works:
1. User Query: A user submits a query to the LLM.
2. Initial Response & Data Commons Query: The DataGemma model (based on the Gemma 2 27 billion parameter (27B) model and fully fine-tuned for this RIG task) generates a response that includes a natural language query, addressed to Data Commons' existing natural language interface and specifically designed to retrieve relevant data.
3. Data Retrieval & Correction: Data Commons is queried, and the data are retrieved. These data, along with source information and a link, are then used to replace potentially inaccurate numbers in the initial response.
4. Final Response with Source Link: The final response is presented to the user, including a link to the source data and metadata in Data Commons for transparency and verification.
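
The four steps above can be sketched with a toy example. This is only an illustration under stated assumptions, not the `data_gemma` implementation: the inline marker format `[DC("question") --> "value"]`, the `interleave` function, and the `FAKE_DATA_COMMONS` lookup table are all hypothetical stand-ins, and the fallback of keeping the model's own value when no data is found is likewise an assumption. The numbers are the ones quoted in this notebook's Pakistan example.

```python
import re

# Hypothetical inline marker: [DC("<question>") --> "<model's own value>"]
MARKER = re.compile(r'\[DC\("([^"]+)"\) --> "([^"]*)"\]')

# Stand-in for Data Commons: question -> (value, source).
FAKE_DATA_COMMONS = {
    "what was the life expectancy in Pakistan in 2000?": ("62.102 yr", "World Bank"),
}


def interleave(draft, fetch):
    """Replace each marker with the fetched value plus a numbered citation."""
    sources = []

    def substitute(match):
        question, llm_value = match.group(1), match.group(2)
        hit = fetch(question)
        if hit is None:
            # Assumption: with no data available, keep the model's value.
            return llm_value
        value, source = hit
        sources.append(source)
        return f"{value} [{len(sources)}]"

    return MARKER.sub(substitute, draft), sources


# Step 2: a draft response containing an embedded question (hypothetical format).
DRAFT = (
    "Life expectancy in Pakistan was "
    '[DC("what was the life expectancy in Pakistan in 2000?") --> "61.8 years"]'
    " in 2000."
)

# Steps 3 and 4: fetch, substitute, and keep the sources for citation.
final_text, sources = interleave(DRAFT, FAKE_DATA_COMMONS.get)
print(final_text)
print(sources)
```

Passing the lookup in as a `fetch` callable keeps the substitution logic separate from the data source, so the dictionary could be swapped for a real client without touching the regex handling.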

In the example above, notice the questions being asked of Data Commons (e.g. "what was the life expectancy in Pakistan in 2000?"). The answer is compared against the initial LLM response in `[__DC__#1(62.102 yr [1] || 61.8 years)]`, where `61.8 years` is the value generated by Gemma 2 27B and `62.102 yr` is the value returned by Data Commons. DataGemma is trained to query Data Commons with "what was the life expectancy in Pakistan in 2000?". Data Commons returns statistics from the World Bank, along with citations to the original source, and these replace the model's value in the final response.
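
To inspect such annotations programmatically, a small parser can split each one into its Data Commons value and its model-generated value. The sketch below assumes every annotation follows the exact shape shown in this notebook, `[__DC__#N(dc_value || llm_value)]`; the regular expressions and the `parse_annotations` helper are hypothetical and are not part of the `data_gemma` package.

```python
import re

# Assumed annotation shape: [__DC__#<id>(<Data Commons value> || <LLM value>)]
ANNOTATION = re.compile(r"\[__DC__#(\d+)\((.*?) \|\| (.*?)\)\]")
# First decimal number in a string such as "62.102 yr [1]".
NUMBER = re.compile(r"-?\d+(?:\.\d+)?")


def first_number(text):
    """Return the first number found in text as a float, or None."""
    match = NUMBER.search(text)
    return float(match.group()) if match else None


def parse_annotations(text):
    """Extract id, Data Commons value, and LLM value from each annotation."""
    rows = []
    for m in ANNOTATION.finditer(text):
        rows.append({
            "id": int(m.group(1)),
            "dc_text": m.group(2),
            "llm_text": m.group(3),
            "dc_value": first_number(m.group(2)),
            "llm_value": first_number(m.group(3)),
        })
    return rows


SAMPLE = "Life expectancy was [__DC__#1(62.102 yr [1] || 61.8 years)] in 2000."
rows = parse_annotations(SAMPLE)
for row in rows:
    gap = abs(row["dc_value"] - row["llm_value"])
    print(f"#{row['id']}: DC={row['dc_text']!r}, LLM={row['llm_text']!r}, gap={gap:.3f}")
```

A listing like this makes it easy to see where the model's own number was already close to the Data Commons statistic and where the retrieved value corrected it more substantially.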