<a href="https://colab.research.google.com/github/Tanaya2012/QA-chatbot/blob/main/Fine_Tuned_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine Tuning BERT for better results

This notebook demonstrates the process of fine-tuning the BERT model using a dataset generated from a pre-trained question generation model. The notebook utilizes the BERT model to train on the generated dataset, which consists of questions, answers, and corresponding context. 

The notebook includes steps to preprocess the dataset, such as tokenization and encoding. It then proceeds to fine-tune the BERT model on the preprocessed dataset, adjusting the model's weights and parameters to optimize performance for the specific question generation task.

Throughout the notebook, various evaluation metrics and techniques may be employed to assess the performance of the fine-tuned BERT model. The notebook provides code examples and explanations, allowing users to understand and reproduce the fine-tuning process for BERT in the context of question generation tasks.

In [1]:
import locale
def getpreferredencoding(do_setlocale = True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments, PreTrainedModel, PreTrainedTokenizerFast, AdamW
from pathlib import Path
import pandas as pd

## Importing the generated dataset

In [3]:
data = pd.read_csv("/content/drive/MyDrive/df_bert_train.csv")
pdf_text = Path('/content/drive/MyDrive/cleaned_sentences.txt').read_text()
cleaned_sentences_with_stopwords = pdf_text.split('\n')
clean_text = pdf_text.replace('\n', ' ')
data.head()

Unnamed: 0.1,Unnamed: 0,answer,question,text
0,0,<pad> emphysema and chronic bronchitis,What are the symptoms of chronic obstructive p...,chronic obstructive pulmonary disease copd is ...
1,1,airflow obstruction,What is the cause of chronic obstructive pulmo...,chronic obstructive pulmonary disease copd is ...
2,0,<pad> copd,What is an important contributor to mortality ...,copd is an important contributor to mortality ...
3,0,<pad> to reduce activity limitations among adu...,What is the goal of healthy people 2020?,healthy people 2020 has several copdrelated ob...
4,0,<pad> 2013,In what year did cdc analyze data from the beh...,to assess the statelevel prevalence of copd an...


In [21]:
user_questions = ['When did the GARDASIL 9 recommendations change?',
'What were the past 3 recommendation changes for GARDASIL 9?',
'Is GARDASIL 9 recommended for Adults?',
'Does the ACIP recommend one dose GARDASIL 9?']

## Fine Tuning

This code snippet performs several tasks related to preparing a dataset for training a question answering model using BERT.

First, it splits the questions, answers, and contexts into training and validation sets.

Next, it specifies the choice of a pre-trained model and tokenizer, specifically using the 'distilbert-base-uncased-distilled-squad' model.

Then, it prepares the dataset by encoding the training and validation contexts and questions using the tokenizer, incorporating truncation and padding.

The code proceeds to add the start and end positions of the answer within the input, by finding the answer tokens' occurrence in the encoded input.

Afterward, a custom dataset class called 'QADataset' is created, which handles the data in a suitable format for training. It converts the encodings into tensors and defines the length of the dataset.

Finally, the training and validation datasets are instantiated using the 'QADataset' class, using the respective encodings.

Overall, this code segment organizes the data into appropriate structures, prepares the inputs for BERT, and sets up custom dataset classes for training and validation datasets.

In [4]:
questions = data.question.tolist()
answers = data.answer.tolist()
contexts = data.text.tolist()

# Split the data into training and validation sets
split_index = int(0.8 * len(questions))
train_questions, val_questions = questions[:split_index], questions[split_index:]
train_answers, val_answers = answers[:split_index], answers[split_index:]
train_contexts, val_contexts = contexts[:split_index], contexts[split_index:]

# Choose a pre-trained model and tokenizer
model_name = 'distilbert-base-uncased-distilled-squad'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# Prepare the dataset using the tokenizer
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)

# Add start_positions and end_positions to the dataset
def add_token_positions(encodings, answers):
    start_positions = []
    end_positions = []
    for i in range(len(answers)):
        # Get the answer tokens
        answer_tokens = tokenizer.encode(answers[i], add_special_tokens=False)
        
        # Find the first occurrence of the answer tokens in the encoded input
        for j in range(len(encodings.input_ids[i]) - len(answer_tokens) + 1):
            if encodings.input_ids[i][j:j + len(answer_tokens)] == answer_tokens:
                ans_start = j
                ans_end = j + len(answer_tokens) - 1
                break
        else:
            # If the answer is not found, set the start and end positions to 0
            ans_start = 0
            ans_end = 0

        start_positions.append(ans_start)
        end_positions.append(ans_end)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, val_answers)

# Create a custom dataset class
class QADataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = QADataset(train_encodings)
val_dataset = QADataset(val_encodings)

In [5]:
# !pip uninstall -y transformers accelerate
# !pip install transformers accelerate

This code snippet sets up training arguments and a Trainer object for fine-tuning the BERT model on the provided dataset. It defines parameters such as the output directory, number of training epochs, batch sizes, learning rate, and logging directory. The Trainer is instantiated with the model, training arguments, training, and evaluation datasets, along with a custom optimizer. The model is then fine-tuned using the trainer's `train()` method. Finally, the fine-tuned model and tokenizer are saved in the specified directory for future use.

In [6]:
# Set up training arguments and the Trainer
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=11,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    logging_dir='./logs',
    learning_rate=1e-4, # Custom learning rate
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    optimizers=(AdamW(model.parameters(), lr=1e-4), None), # Custom optimizer with the learning rate
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
model.save_pretrained('./FineTunedModel')
tokenizer.save_pretrained('./FineTunedModel')



Step,Training Loss


Step,Training Loss


('./FineTunedModel/tokenizer_config.json',
 './FineTunedModel/special_tokens_map.json',
 './FineTunedModel/vocab.txt',
 './FineTunedModel/added_tokens.json',
 './FineTunedModel/tokenizer.json')

## Testing

This code snippet loads a fine-tuned question answering model and tokenizer from a specified directory. It defines a function called `get_answer` that takes the loaded model, tokenizer, question, and context as input. The function tokenizes the question and context, processes the inputs with the model, and obtains the start and end scores. It identifies the answer tokens and converts them back into a string format. If the answer is empty, the score is set to 0.0. The function returns the answer and corresponding score.

In [27]:
# Load the fine-tuned model and tokenizer
model = AutoModelForQuestionAnswering.from_pretrained('./FineTunedModel')
tokenizer = AutoTokenizer.from_pretrained('./FineTunedModel')

def get_answer(model: PreTrainedModel, tokenizer: PreTrainedTokenizerFast, question: str, context: str = None) -> str:
    
    inputs = tokenizer(question, context, return_tensors='pt', max_length=512, truncation=True)
    outputs = model(**inputs)

    start_scores, end_scores = outputs.start_logits, outputs.end_logits
    score = float(torch.max(start_scores))
    start_index = torch.argmax(start_scores)
    end_index = torch.argmax(end_scores)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    answer_tokens = tokens[start_index:end_index + 1]
    answer_tokens = [token for token in answer_tokens if token not in tokenizer.all_special_tokens]
    answer = tokenizer.convert_tokens_to_string(answer_tokens)

    if answer == '':
      score = 0.0

    return answer, score

question = "When did the GARDASIL 9 recommendations change?"
answer = get_answer(model, tokenizer, question)
print("Answer:", answer)

Answer: ('', 0.0)


## Final Answer

In [30]:
for question in user_questions:
  max_score = 0.0
  final_ans = ''
  for text in cleaned_sentences_with_stopwords:
    ans, score = get_answer(model, tokenizer, question, text)
    if score > max_score:
      final_ans = ans
      max_score = score 
  print('Question:', question)
  print("final answer: ", final_ans)

Question: When did the GARDASIL 9 recommendations change?
final answer:  july 1
Question: What were the past 3 recommendation changes for GARDASIL 9?
final answer:  costeffectiveness modeling
Question: Is GARDASIL 9 recommended for Adults?
final answer:  
Question: Does the ACIP recommend one dose GARDASIL 9?
final answer:  81
