In [32]:
!pip install transformers



In [11]:
import torch

# **Using the Transformers model classes directly**

In [12]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

In [13]:
class QASystem:
    """Extractive question answering on top of a SQuAD-style checkpoint.

    The tokenizer and the question-answering model are loaded once in
    ``__init__``; ``extract_answer`` then picks the highest-scoring start and
    end tokens for a (question, context) pair and decodes that span.
    """

    # Upper bound, in tokens, for the encoded question + context pair.
    MAX_LENGTH = 512

    def __init__(self, model_path='./distilbert', trust_remote_code=False):
        """Load the tokenizer and the question-answering model.

        Args:
            model_path: Local directory or Hub id of the checkpoint, for
                example 'distilbert-base-uncased-distilled-squad'. Defaults
                to the local './distilbert' directory.
            trust_remote_code: Forwarded to the model's ``from_pretrained``.
                Disabled by default because enabling it allows Python code
                shipped with the checkpoint to run; pass True only for
                checkpoints you trust that actually need custom code.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, truncation=True, max_length=self.MAX_LENGTH
        )
        self.model = AutoModelForQuestionAnswering.from_pretrained(
            model_path, trust_remote_code=trust_remote_code
        )

    def extract_answer(self, question, context):
        """Return the answer span the model selects for ``question``.

        Args:
            question: The question text.
            context: The passage in which to look for the answer.

        Returns:
            The decoded answer string, with special tokens dropped. The text
            is rebuilt from token ids, so it reflects the tokenizer's
            normalisation. An empty string is returned when the predicted
            end position precedes the predicted start position.
        """
        # Encode the pair; anything beyond MAX_LENGTH tokens is truncated.
        inputs = self.tokenizer(
            question,
            context,
            add_special_tokens=True,
            max_length=self.MAX_LENGTH,
            truncation=True,
            return_tensors='pt',
        )

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Read the logits by name rather than by position in the output so
        # that extra output fields cannot break the unpacking.
        start_index = torch.argmax(outputs.start_logits, dim=1).item()
        end_index = torch.argmax(outputs.end_logits, dim=1).item()

        # An inverted span contains no tokens.
        if end_index < start_index:
            return ''

        # Let the tokenizer rebuild the text: this merges sub-word pieces for
        # any vocabulary type and removes the stray spaces before punctuation.
        answer_ids = inputs['input_ids'][0, start_index:end_index + 1].tolist()
        answer = self.tokenizer.decode(
            answer_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        return answer.strip()


In [14]:
# Build the QA helper once; constructing it loads the tokenizer and model.
qa_system = QASystem()

In [15]:
# Sample passage and question used by both QA approaches in this notebook.
context = """
Transformers is backed by the three most popular deep learning libraries — Jax, PyTorch and TensorFlow — with a seamless integration
between them. It's straightforward to train your models with one before loading them for inference with the other.
"""
question = "Which deep learning libraries back Transformers?"

In [19]:
# Run the sample question against the sample passage.
answer = qa_system.extract_answer(question, context)

In [20]:
# Show the extracted answer text.
print(answer)

jax , pytorch and tensorflow


# **Using the Pipeline API**

In [18]:
from transformers import pipeline

# Checkpoint served by the pipeline. Swap in your own Hub id, or point it at
# a local directory such as "./distilbert".
model_checkpoint = "deepset/roberta-base-squad2"

# Build a ready-to-use question-answering pipeline for that checkpoint.
question_answerer = pipeline(
    task="question-answering",
    model=model_checkpoint,
)

# Final expression of the cell, so the notebook displays the result: a dict
# with 'score', 'start', 'end' and 'answer' keys.
question_answerer(
    question=question,
    context=context,
)

{'score': 0.9649538993835449,
 'start': 76,
 'end': 103,
 'answer': 'Jax, PyTorch and TensorFlow'}