<a href="https://colab.research.google.com/github/Siddharth-R512/Knowledge-Based-QA-using-DistilBERT/blob/main/Knowledge_Based_QA_using_DistilBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
import nltk
# Fetch the NLTK resources this module relies on before they are used:
# 'punkt' for word_tokenize and 'stopwords' for stopwords.words('english').
# These calls run on every import of the module.
# NOTE(review): recent NLTK releases may look for 'punkt_tab' instead of
# 'punkt' when word_tokenize is called — confirm against the installed version.
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

class KnowledgeBasedQASystem:
  """Extractive question answering over a caller-supplied context passage.

  Both the question and the context are normalised with ``preprocess_text``
  and fed to a DistilBERT question-answering model; the answer is the span of
  context tokens the model scores highest.
  """

  def __init__(self, model_name='distilbert-base-uncased-distilled-squad'):
    """Load the tokenizer, the QA model and the English stop-word set.

    Args:
      model_name: Checkpoint name or path handed to ``from_pretrained`` for
        both the tokenizer and the model.
    """
    self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    self.model = DistilBertForQuestionAnswering.from_pretrained(model_name)
    # The model is only used for inference; keep dropout disabled.
    self.model.eval()
    self.stop_words = set(stopwords.words('english'))

  def preprocess_text(self, text: str) -> str:
    """Lower-case *text* and drop non-alphanumeric tokens and stop words.

    Args:
      text: Raw input string.

    Returns:
      The surviving tokens, lower-cased and joined by single spaces. Tokens
      that are not entirely alphanumeric (per ``str.isalnum``) or that are in
      ``self.stop_words`` are removed.
    """
    lowered = (token.lower() for token in word_tokenize(text) if token.isalnum())
    return ' '.join(word for word in lowered if word not in self.stop_words)

  @staticmethod
  def _best_span(start_logits, end_logits, lo, hi):
    """Pick the highest-scoring answer span inside ``[lo, hi)``.

    The score of a span is ``start_logits[start] + end_logits[end]``. Only
    spans with ``lo <= start <= end < hi`` are considered, so the result can
    never be reversed or fall outside the given window.

    Args:
      start_logits: 1-D tensor of per-token start scores.
      end_logits: 1-D tensor of per-token end scores.
      lo: First admissible token index (inclusive).
      hi: Last admissible token index (exclusive).

    Returns:
      ``(start, end)`` as inclusive integer indices into the full sequence,
      or ``None`` when the window is empty.
    """
    width = hi - lo
    if width <= 0:
      return None
    scores = start_logits[lo:hi].unsqueeze(1) + end_logits[lo:hi].unsqueeze(0)
    positions = torch.arange(width, device=scores.device)
    # reversed[i, j] is True where the end index j lies before the start i.
    reversed_span = positions.unsqueeze(0) < positions.unsqueeze(1)
    scores = scores.masked_fill(reversed_span, float('-inf'))
    start_offset, end_offset = divmod(int(torch.argmax(scores)), width)
    return lo + start_offset, lo + end_offset

  def answer_question(self, question: str, context: str) -> str:
    """Return the span of *context* that best answers *question*.

    Args:
      question: The question text.
      context: The passage to extract the answer from.

    Returns:
      The decoded answer span (built from the preprocessed context), or an
      empty string when the context is empty after preprocessing.
    """
    preprocessed_question = self.preprocess_text(question)
    preprocessed_context = self.preprocess_text(context)
    if not preprocessed_context:
      # Nothing to extract an answer from.
      return ''

    inputs = self.tokenizer(preprocessed_question, preprocessed_context, return_tensors='pt', max_length=512, truncation=True)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
      outputs = self.model(**inputs)

    input_ids = inputs['input_ids'][0]
    # A sentence pair is encoded as [CLS] question [SEP] context [SEP]; the
    # answer must come from the tokens strictly between the two separators.
    sep_positions = (input_ids == self.tokenizer.sep_token_id).nonzero(as_tuple=True)[0]
    if sep_positions.numel() >= 2:
      context_start = int(sep_positions[0]) + 1
      context_end = int(sep_positions[-1])
    else:
      # Unexpected layout: fall back to searching the whole sequence.
      context_start, context_end = 0, input_ids.numel()

    span = self._best_span(outputs.start_logits[0], outputs.end_logits[0], context_start, context_end)
    if span is None:
      return ''
    start_index, end_index = span
    return self.tokenizer.decode(input_ids[start_index:end_index + 1])


if __name__ == "__main__":
  qa_system = KnowledgeBasedQASystem()
  # Interactive loop: read a question and a context, print the answer.
  # An empty question, end-of-input (Ctrl-D) or Ctrl-C ends the session
  # instead of looping forever or exiting with a traceback.
  while True:
    try:
      question = input("Enter your question: ")
      if not question.strip():
        break
      context = input("Enter the context: ")
    except (EOFError, KeyboardInterrupt):
      # Finish the interrupted prompt line before leaving.
      print()
      break
    answer = qa_system.answer_question(question, context)
    print("Answer:", answer)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Answer: paris
Answer: william shakespeare
Answer: h2o
Answer: china
