<a href="https://colab.research.google.com/github/Aman-Kothari7/RAGsystem-FastAPI/blob/main/RBI_LORA_FT_Gemma_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Finetuning Google Gemma's Model Using LORA - RBI QnA dataset


In [1]:
#Installing necessary packages
!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.4/183.4 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m280.0/280.0 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m150.9/150.9 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m536.7/536.7 kB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.8/79.8 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━

In [2]:
#Importing packages
import os
import transformers
import torch
from google.colab import userdata
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, GemmaTokenizer

In [3]:
#Setting hugging face token and permissions
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
os.environ["WANDB_DISABLED"] = "false"

In [None]:
#Creating a simple QnA dataset and using LORA to fine tune, using GPT-3.5 to generate QnA pairs

import pandas as pd

def count_words(text):

    text = str(text)
    return len(text.split())

df['WordCount'] = df['RBI Notification Text'].apply(count_words)

filtered_df = df[df['WordCount'] <= 250]

filtered_df = filtered_df.drop(columns=['WordCount'])


In [None]:
filtered_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 125 entries, 1 to 499
Data columns (total 2 columns):
 #   Column                  Non-Null Count  Dtype 
---  ------                  --------------  ----- 
 0   RBI Notification Title  125 non-null    object
 1   RBI Notification Text   116 non-null    object
dtypes: object(2)
memory usage: 2.9+ KB


In [None]:
import pandas as pd

def create_qa_prompts(df):
    prompts = []

    for index, row in df.iterrows():
        prompt_text = f"Based on the following RBI notification titled '{row['RBI Notification Title']}' and its content: {row['RBI Notification Text']}.\n Generate a single relevant general question and answer pair in this JSON format: Q:Question Text, A:Answer Text.Skip any introductory phrases, only give the questions and answer pair "

        prompts.append(prompt_text)

    return prompts



In [None]:
qa_prompts_list = create_qa_prompts(filtered_df)

In [None]:
qa_prompts_list[0]

"Based on the following RBI notification titled 'Marginal Standing Facility' and its content: RBI/2018-2019/161    FMOD.MAOG. No.131/01.18.001/2018-19 April 4, 2019 All Marginal Standing Facility (MSF) participants Madam/Sir, Marginal Standing Facility As announced in the First Bi-monthly Monetary Policy Statement, 2019-20, today, it has been decided by the Monetary Policy Committee (MPC) to reduce the policy Repo rate under the Liquidity Adjustment Facility (LAF) by 25 basis points from 6.25 per cent to 6.00 per cent with immediate effect. 2. Consequently, the Marginal Standing Facility (MSF) rate stands adjusted to 6.25 per cent with immediate effect. 3. All other terms and conditions of the extant MSF scheme will remain unchanged. Yours sincerely (Radha Shyam Ratho)    Chief General Manager.\n Generate a single relevant general question and answer pair in this JSON format: Q:Question Text, A:Answer Text.Skip any introductory phrases, only give the questions and answer pair "

In [None]:
len(qa_prompts_list)

125

In [None]:
qa_prompts_list_first_100 = qa_prompts_list[:100]
len(qa_prompts_list_first_100)

100

In [None]:
#Generating qna pairs
import asyncio
from openai import OpenAI
import pandas as pd
import itertools

client = OpenAI(
    api_key="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",
)

def send_prompt(prompt, model="gpt-3.5-turbo"):
    return client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model=model,
    )
responses = []
for prompt in qa_prompts_list_first_100:
    print("Added", len(responses))
    response = send_prompt(prompt)
    responses.append(response)

In [None]:
for response in responses:
  print(response.choices[0].message.content)

{
  "Q": "What is the new Marginal Standing Facility (MSF) rate after the recent decision by the Monetary Policy Committee?",
  "A": "The new Marginal Standing Facility (MSF) rate is 6.25 per cent."
}
{
  "Q": "What is the purpose of the RBI notification regarding auction of Government of India Dated Securities?",
  "A": "The purpose is to inform all scheduled commercial banks, financial institutions, and primary dealers about the upcoming auctions of government securities and provide details on the auction process and terms."
}
{
  "Q": "What is the purpose of the RBI notification titled 'StCBs/RRBs – Increase in CRR'?",
  "A": "The purpose of the notification is to increase the Cash Reserve Ratio (CRR) of all Scheduled State Co-operative Banks (StCBs) and Regional Rural Banks (RRBs) in two stages due to the current macroeconomic and monetary conditions."
}
{
  "Q": "What was the revised Rupee value of the Special Currency Basket as per the RBI notification dated February 6, 2012?",
 

In [None]:
#Extracting QnA pairs and creating Dataframe
import pandas as pd
import json

qna_pairs = []

for response in responses:

    content = response.choices[0].message.content
    try:
        qna = json.loads(content)
        question = qna.get("Q")
        answer = qna.get("A")
        qna_pairs.append({"question": question, "answer": answer})
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        continue

rbi_qna_pairs = pd.DataFrame(qna_pairs)

print(rbi_qna_pairs.head())



Error decoding JSON: Expecting value: line 1 column 1 (char 0)
                                            question  \
0  What is the new Marginal Standing Facility (MS...   
1  What is the purpose of the RBI notification re...   
2  What is the purpose of the RBI notification ti...   
3  What was the revised Rupee value of the Specia...   
4  What is the purpose of the auction of Governme...   

                                              answer  
0  The new Marginal Standing Facility (MSF) rate ...  
1  The purpose is to inform all scheduled commerc...  
2  The purpose of the notification is to increase...  
3  The revised Rupee value of the Special Currenc...  
4  The purpose of the auction is for the Governme...  


In [None]:
rbi_qna_pairs.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 99 entries, 0 to 98
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   question  99 non-null     object
 1   answer    99 non-null     object
dtypes: object(2)
memory usage: 1.7+ KB


In [None]:
rbi_qna_pairs.to_csv('df_rbi_qna_pairs.csv', index=False)


In [4]:
#Defining model and quantization configuration
model_id = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [5]:
#Initializing tokenizer and model details
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             quantization_config=bnb_config,
                                             device_map={"":0},
                                             token=os.environ['HF_TOKEN'])

tokenizer_config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/555 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [6]:
os.environ["WANDB_DISABLED"] = "false"

In [7]:
# Initializing LORA configuration and setting parameters
lora_config = LoraConfig(
    r = 8,
    target_modules = ["q_proj", "o_proj", "k_proj", "v_proj",
                      "gate_proj", "up_proj", "down_proj"],
    task_type = "CAUSAL_LM",
)

In [9]:
#Loading dataset
from datasets import load_dataset
data = load_dataset('csv', data_files='/content/df_rbi_qna_pairs.csv')
data = data.map(lambda samples: tokenizer(samples["question"]), batched=True)

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/99 [00:00<?, ? examples/s]

In [10]:
data['train']['question']

['What is the new Marginal Standing Facility (MSF) rate after the recent decision by the Monetary Policy Committee?',
 'What is the purpose of the RBI notification regarding auction of Government of India Dated Securities?',
 "What is the purpose of the RBI notification titled 'StCBs/RRBs – Increase in CRR'?",
 'What was the revised Rupee value of the Special Currency Basket as per the RBI notification dated February 6, 2012?',
 'What is the purpose of the auction of Government of India Dated Securities under Market Stabilisation Scheme (MSS)?',
 'When will the revised regulatory framework for Urban Co-operative Banks (UCBs) regarding Net Worth and Capital Adequacy come into effect?',
 'What is the quantum of Government securities that standalone Primary Dealers (PDs) can hold in the HTM category?',
 "What is the Government of India's decision regarding interest payments under the Agricultural Debt Waiver and Debt Relief Scheme, 2008?",
 'What disclosures are required to be made by Sec

In [11]:
# Formatting function to pass to SFTTrainer
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

In [12]:
data['train'][4]

{'question': 'What is the purpose of the auction of Government of India Dated Securities under Market Stabilisation Scheme (MSS)?',
 'answer': "The purpose of the auction is for the Government of India to sell (re-issue) '6.18 percent Government Stock 2005' for a notified amount of Rs.5,000 crore through a price-based auction using multiple price auction method.",
 'input_ids': [2,
  1841,
  603,
  573,
  6187,
  576,
  573,
  27788,
  576,
  6632,
  576,
  5339,
  87640,
  54887,
  1362,
  11526,
  43775,
  136232,
  37288,
  591,
  141170,
  15939],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1]}

In [13]:
# Setting training parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=100,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_prompts_func,
)



Map:   0%|          | 0/99 [00:00<?, ? examples/s]



In [14]:
trainer.train()

Step,Training Loss
1,2.5851
2,3.1648
3,2.411
4,2.9464
5,2.5975
6,2.0607
7,2.486
8,2.3467
9,2.1481
10,2.1751


TrainOutput(global_step=100, training_loss=1.4028206366300582, metrics={'train_runtime': 133.5521, 'train_samples_per_second': 2.995, 'train_steps_per_second': 0.749, 'total_flos': 293877878292480.0, 'train_loss': 1.4028206366300582, 'epoch': 4.04})

In [54]:
#Generating output
text = "question:What is the purpose of the RBI notification titled 'StCBs/RRBs – Increase in CRR'?"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

question:What is the purpose of the RBI notification titled 'StCBs/RRBs – Increase in CRR'?
answer:The purpose of the RBI notification is to inform all Scheduled Commercial Banks (SCBs) and Regional Rural Banks (RRBs) that the Cash Reserve Ratio (CRR) has been increased by 50 basis points from 5.50 per cent to 6.00 per cent with effect from April 1, 2016.


In [60]:

text = "question:What is the new Marginal Standing Facility (MSF) rate after the recent decision by the Monetary Policy Committee?"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

question:What is the new Marginal Standing Facility (MSF) rate after the recent decision by the Monetary Policy Committee?
answer:The new Marginal Standing Facility (MSF) rate is 6.25 per cent.
question:What is the new Repo rate after the recent decision by the Monetary Policy Committee


In [49]:
generated_answer = tokenizer.decode(outputs[0], skip_special_tokens=True).split("answer:")[-1].strip()

In [50]:
print(generated_answer)

The purpose of the RBI notification is to inform all Scheduled Commercial Banks (SCBs) and Regional Rural Banks (RRBs) that the Cash Reserve Ratio (CRR) has been increased by 50 basis points from 5.50 per cent to 6.00 per cent with effect from April 1, 2016.


In [None]:
#Generating multiple outputs
generated_data = []
i = 0
for example in data["train"]:
  i = i + 1
  print(i)
  device = "cuda:0"
  question = example['question']
  answer = example['answer']
  formatted_question = "question:" + question
  inputs = tokenizer(formatted_question, return_tensors="pt").to(device)

  outputs = model.generate(**inputs, max_new_tokens=100)
  generated_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
  print("generated_response:", generated_response)

  generated_data.append({"question": question, "answer": answer, "generated_response": generated_response})

In [80]:
# Evaluting using BLUE score to compare generated responses and ground truth
from nltk.translate.bleu_score import corpus_bleu

reference_responses = ["The new Marginal Standing Facility (MSF) rate is 6.25 per cent.", "The purpose is to inform all scheduled commercial banks, financial institutions, and primary dealers about the upcoming auctions of government securities and provide details on the auction process and terms.", "The purpose of the notification is to increase the Cash Reserve Ratio (CRR) of all Scheduled State Co-operative Banks (StCBs) and Regional Rural Banks (RRBs) in two stages due to the current macroeconomic and monetary conditions.", "The revised Rupee value of the Special Currency Basket was fixed at Rs.68.838139 with effect from February 9, 2012.", "The revised regulatory framework for Urban Co-operative Banks (UCBs) regarding Net Worth and Capital Adequacy will come into effect from March 31, 2023.", "Standalone Primary Dealers (PDs) can hold Government securities in the HTM category up to the extent of their audited net owned funds (NOF) as at the end of March of the preceding financial year.", "The Government of India has decided to pay interest on the 2nd, 3rd, and 4th instalments, payable by July 2009, July 2010, and July 2011 respectively, at the prevailing Yield to Maturity Rate on 364-day Government of India Treasury Bills.","Representations received from banks, Federation/Association of urban co-operative banks, and the need to align with international practices and current risk management practices in India.","The FATF is calling upon jurisdictions to complete the implementation of their action plan within a timeframe.","The reverse repo rate under the Liquidity Adjustment Facility (LAF) was increased to 5.00 per cent from 4.75 per cent effective from April 29, 2005.", "The purpose of the SJSRY is to assist urban poor beneficiaries with self-employment opportunities.", "Primary co-operative banks should exercise due caution with regard to valuation while sanctioning loans against mortgage of house property."]
generated_responses = ["The new Marginal Standing Facility (MSF) rate is 6.25 per cent.", "The purpose of the RBI notification is to inform all concerned that the auction of Government of India Dated Securities will be conducted through the auction based bidding system on a single auction date. The auction will be conducted through the Reserve Bank's portal", "The purpose of the RBI notification is to inform all Scheduled Commercial Banks (SCBs) and Regional Rural Banks (RRBs) that the Cash Reserve Ratio (CRR) has been increased by 50 basis points from 5.50 per cent to 6.00 per cent with effect from April 1, 2016.", "The Rupee value of the Special Currency Basket was revised from 67.75 per cent of the value of the Special Currency to 67.50 per cent of the value of the Special Currency.", "The revised regulatory framework for Urban Co-operative Banks (UCBs) regarding Net Worth and Capital Adequacy will come into effect from April 1, 2007.", "The Government securities that standalone PDs can hold in the HTM category are limited to 25 per cent of their net owned funds (NOF) as per the RBI notification.", "Interest payments under the Agricultural Debt Waiver and Debt Relief Scheme, 2008 will be made by the Reserve Bank of India directly to the banks.", "The RBI decided to review the existing guidelines of classification of investments for urban co-operative banks in order to ensure that the investments made by urban co-operative banks are in accordance with the guidelines issued by the RBI.", "The FATF is calling upon jurisdictions to complete their action plans within a timeframe.", "The new reverse repo rate under the LAF is 5.25 per cent.", "The purpose of the SJSRY is to provide financial assistance to urban poor households for self-employment ventures.", "Primary co-operative banks should ensure that the loan amount is not more than the market value of the property. They should also verify the title of the property and ensure that the loan is secured by a mortgage over the property."]


bleu_score = corpus_bleu([[ref.split()] for ref in reference_responses], [gen.split() for gen in generated_responses])
print("BLEU Score:", bleu_score)


12
12
BLEU Score: 0.2588679828951621


In [87]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"

In [91]:
!pip install rouge_score



In [92]:
# Evaluting using Rouge score to compare generated responses and ground truth
from datasets import load_metric

rouge = load_metric("rouge")

rouge_output = rouge.compute(predictions=generated_responses, references=reference_responses)
print("ROUGE Score:", rouge_output)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


ROUGE Score: {'rouge1': AggregateScore(low=Score(precision=0.4142487829588807, recall=0.4531047311842984, fmeasure=0.4185759133076496), mid=Score(precision=0.5580127619497548, recall=0.5686101328512256, fmeasure=0.5475945988529005), high=Score(precision=0.7216834771430676, recall=0.6978882219323396, fmeasure=0.6875983841932014)), 'rouge2': AggregateScore(low=Score(precision=0.2274760488545012, recall=0.24968334461171351, fmeasure=0.22854595950795617), mid=Score(precision=0.3960222527022167, recall=0.4036174117009357, fmeasure=0.3891435508926452), high=Score(precision=0.5745409815766072, recall=0.5754780212885411, fmeasure=0.5667289352069257)), 'rougeL': AggregateScore(low=Score(precision=0.36412932566925976, recall=0.397937139768874, fmeasure=0.37322879396545194), mid=Score(precision=0.5237435002736727, recall=0.5277555333165687, fmeasure=0.5087383298056922), high=Score(precision=0.6903999964449381, recall=0.667790182076334, fmeasure=0.6635210319057566)), 'rougeLsum': AggregateScore(lo

##Conclusion

In [None]:
# The project achieved moderate performance with a BLEU score of 25.9% and ROUGE scores indicating moderate overlap between generated and reference responses.
# However, there is room for improvement, especially in handling numerical and temporal information. Notably, a BLEU score above 40% is generally considered excellent for most tasks.

# Improvements:

# Diversify and expand the training data to include a wider range of examples, particularly those involving numerical and temporal contexts.
# Fine-tune the model with domain-specific datasets or tailored pre-training steps.
# Experiment with different model architectures, hyperparameters, and optimization techniques.
# Implement post-processing techniques to refine generated responses.
# Continuously evaluate and iterate the model based on user feedback and performance metrics.

## Finetuning Google Gemma's Model Using LORA - original scraped dataset

In [None]:
!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.4/183.4 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m280.0/280.0 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m150.9/150.9 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m536.7/536.7 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.8/79.8 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━

In [None]:
pip install openai

Collecting openai
  Downloading openai-1.12.0-py3-none-any.whl (226 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.7/226.7 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Collecting httpx<1,>=0.23.0 (from openai)
  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.6/75.6 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)
  Downloading httpcore-1.0.4-py3-none-any.whl (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.8/77.8 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)
  Downloading h11-0.14.0-py3-none-any.whl (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.3/58.3 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: h11, httpcore, httpx, openai
Successfully installed h11-0.14.0 httpcore-1.0.4 h

In [93]:
import os
import transformers
import torch
from google.colab import userdata
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, GemmaTokenizer

In [94]:
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

In [95]:
model_id = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [96]:
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             quantization_config=bnb_config,
                                             device_map={"":0},
                                             token=os.environ['HF_TOKEN'])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [97]:
os.environ["WANDB_DISABLED"] = "false"

In [98]:
lora_config = LoraConfig(
    r = 8,
    target_modules = ["q_proj", "o_proj", "k_proj", "v_proj",
                      "gate_proj", "up_proj", "down_proj"],
    task_type = "CAUSAL_LM",
)

In [None]:
from datasets import load_dataset
data = load_dataset('csv', data_files='/content/notifications_data_small.csv',)
data = data.map(lambda samples: tokenizer(samples["RBI Notification Title"]), batched=True)

In [None]:
data['train']['RBI Notification Title']

['GOI Notification - 5.48 per cent',
 'Marginal Standing Facility',
 'Tender for "7.49 percent Government Stock, 2017" for an aggregate amount of Rs.5,000 crore : Auction to be held on June 23,2005',
 'Foreign Exchange Management (Transfer or Issue of Security by a Person Resident outside India) (Fifteenth Amendment) Regulations, 2013',
 'Auction of Government of India Dated Securities',
 'Income Tax Clearance Certificate/No Objection CertificateA.P. (DIR Series) Circular No.27 (September 28, 2002)',
 'StCBs/RRBs – Increase in CRR',
 'Master Circular – Detection and Impounding of Counterfeit Notes',
 'Tender for Non - Competitive Bids',
 'Year 2000 (Y2K) Issues – Information Sharing and Disclosure (Commercial Banks)',
 'Overseas Foreign Currency Borrowings by Authorised Dealer Banks',
 'Deferred Payment Protocols between Government of India and erstwhile USSR',
 'Auction of Government of India Dated Securities under Market Stabilisation Scheme (MSS)',
 'Auction for Sale (Re-issue ) of 

In [None]:
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Title: {example['RBI Notification Title'][i]}\n ### Explanation: {example['RBI Notification Text'][i]}"
        output_texts.append(text)
    return output_texts

In [None]:
data['train']

Dataset({
    features: ['RBI Notification Title', 'RBI Notification Text', 'input_ids', 'attention_mask'],
    num_rows: 500
})

In [None]:
trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=100,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_prompts_func,
)

In [None]:
trainer.train()

In [None]:
text = "Marginal"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Marginal Standing Facility (MSF) rate was on Wednesday raised by the Monetary Policy Committee to 6.25 per cent.

The decision was taken at the two-day meeting of the Monetary Policy Committee, which concluded on Wednesday.

The decision
