<a href="https://colab.research.google.com/github/AdiV121003/RIG-Using-DataGemma/blob/main/RIG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Pre-requisites:

A100 GPU

High-RAM runtime

Hugging Face Token


Step1: Log in to your Hugging Face account and create a new access token. Save it as a Colab secret named `HF_TOKEN`.

Step2: Create a Data Commons API key and save it as a Colab secret named `DC_API_KEY`.

Step3: Enable the Data Commons NL API

Step4: Install and import the necessary libraries

In [None]:
#install the following required libraries
!pip install -q git+https://github.com/datacommonsorg/llm-tools
!pip install -q bitsandbytes accelerate

#load the finetuned Gemma2 27B model

import torch

import data_gemma as dg

from google.colab import userdata
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# --- Data Commons client ---
# The API key is looked up through google.colab.userdata instead of being
# written into the notebook.
DC_API_KEY = userdata.get('DC_API_KEY')
dc = dg.DataCommons(api_key=DC_API_KEY)


# --- DataGemma RIG model from Hugging Face ---
# Access token for downloading the model, looked up the same way.
HF_TOKEN = userdata.get('HF_TOKEN')

# Load the weights in 4-bit NF4 form and run compute in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_name = 'google/datagemma-rig-27b-it'
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
datagemma_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)

# Pair the model with its tokenizer behind the wrapper that the RIG flow
# takes as its `llm` argument.
datagemma_model_wrapper = dg.HFBasic(datagemma_model, tokenizer)

Step5: Pick or enter a query (assign it to the `QUERY` variable)

Step6: Run the RIG technique and Generate Output

In [None]:
from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  """Render a question and its answer as a single styled chat transcript.

  Args:
    prompt: The user's question. It is HTML-escaped before being embedded in
      the surrounding tags so that characters such as '<', '>' and '&' are
      shown literally instead of being parsed as markup.
    text: The answer text. Each '•' is rewritten as a Markdown list marker and
      every line (including blank ones) is prefixed with '> '.

  Returns:
    An IPython.display.Markdown object containing the brown question block
    followed by the teal answer block.
  """
  import html  # local import keeps this function self-contained

  safe_prompt = html.escape(prompt, quote=False)
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + safe_prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  # The always-true predicate makes textwrap.indent prefix blank lines too,
  # so the block quote is not broken by empty lines in the answer.
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  """Return `text` as a block-quoted IPython Markdown object.

  Each '•' is replaced with the Markdown list marker '  *', then every line
  (blank lines included) is prefixed with '> '.
  """
  bulleted = text.replace('•', '  *')
  quoted = textwrap.indent(bulleted, '> ', predicate=lambda _line: True)
  return Markdown(quoted)


# Use a sample question when no earlier cell has assigned QUERY, so that this
# cell does not stop with a NameError after "Restart & Run All".
try:
  QUERY
except NameError:
  QUERY = 'What progress has Pakistan made against health goals?'

# Run the RIG flow for the query using the wrapped model and the Data Commons
# client.
ans = dg.RIGFlow(llm=datagemma_model_wrapper, data_fetcher=dc, verbose=False).query(query=QUERY)

# Fetch the answer text once and render it with the question. A bare
# Markdown(...) expression in the middle of a cell is never displayed, so only
# the final display_chat(...) result is kept as the cell's output.
answer_text = ans.answer()
display_chat(QUERY, answer_text)