# In [1]:
from __future__ import annotations

import io
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Tuple, List

import gradio as gr
import pandas as pd
import requests
# NOTE(review): xmltodict is not referenced in the visible code, and the
# recorded notebook output ends in "ModuleNotFoundError: No module named
# 'xmltodict'". Install it, or drop this import if nothing else uses it.
import xmltodict
from PyPDF2 import PdfReader
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from transformers.pipelines.question_answering import QuestionAnsweringPipeline

QA_MODEL_NAME = "ixa-ehu/SciBERT-SQuAD-QuAC"
TEMP_PDF_PATH = "/tmp/arxiv_paper.pdf"

@dataclass
class PaperMetaData:
    """Metadata plus extracted full text for a single paper.

    Build instances through :meth:`from_dataset`, which normalises the
    whitespace of every text field.
    """

    paper_id: str
    title: str
    abstract: str
    text: str

    @staticmethod
    def _clean_field(text: str) -> str:
        """Collapse every run of whitespace (newlines included) into one space.

        Leading/trailing whitespace is collapsed to a single space, not stripped.
        """
        # ``\s`` already matches ``\n``, so one substitution is equivalent to
        # first replacing newlines and then collapsing whitespace runs.
        return re.sub(r"\s+", " ", text)

    @classmethod
    def from_dataset(
        cls,
        paper_id: str,
        title: str,
        abstract: str,
        text: str = "",
    ) -> "PaperMetaData":
        """Create an instance from raw dataset fields, normalising whitespace.

        Args:
            paper_id: Identifier of the paper; stored unchanged.
            title: Raw title text.
            abstract: Raw abstract text.
            text: Optional raw full text (e.g. extracted from the PDF).
                Defaults to "" so callers that only have metadata keep working.

        Returns:
            A new instance whose title, abstract and text are whitespace-cleaned.
        """
        return cls(
            paper_id=paper_id,
            title=cls._clean_field(title),
            abstract=cls._clean_field(abstract),
            text=cls._clean_field(text),
        )

class ScrapedPaper:
    """A paper entry from the dataset whose PDF can be fetched from ``url``."""

    def __init__(self, paper_id: str, title: str, abstract: str, url: str):
        self.paper_id = paper_id
        self.title = title
        self.abstract = abstract
        self.url = url  # URL the PDF is downloaded from

    def _fetch_pdf_bytes(self, timeout: float = 30.0) -> bytes:
        """Download the PDF at ``self.url`` and return its raw bytes.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            requests.Timeout: if no response arrives within ``timeout`` seconds.
        """
        # Without a timeout, requests can block indefinitely on a stalled server.
        response = requests.get(self.url, timeout=timeout)
        response.raise_for_status()
        return response.content

    def _download_pdf(self, download_path: str, timeout: float = 30.0) -> None:
        """Download the PDF at ``self.url`` and write it to ``download_path``."""
        content = self._fetch_pdf_bytes(timeout=timeout)
        with open(download_path, "wb") as pdf_file:
            pdf_file.write(content)

    def extract_text(self) -> str:
        """Download the PDF and return the text of all pages joined by spaces.

        The PDF is parsed from memory instead of a single shared temp-file
        path, so concurrent calls cannot overwrite each other's downloads.
        """
        reader = PdfReader(io.BytesIO(self._fetch_pdf_bytes()))
        # Defensive: treat a page that yields no text (None) as empty so
        # join() cannot fail.
        return " ".join(page.extract_text() or "" for page in reader.pages)

    def get_paper_full_data(self) -> PaperMetaData:
        """Return this paper's cleaned metadata with the PDF text attached."""
        text = self.extract_text()
        metadata = PaperMetaData.from_dataset(
            paper_id=self.paper_id, title=self.title, abstract=self.abstract
        )
        # PaperMetaData is a regular (mutable) dataclass, so attach the text
        # after construction.
        metadata.text = text
        return metadata

def load_dataset(file_path: str) -> List[ScrapedPaper]:
    """Read the paper index CSV at ``file_path`` into ScrapedPaper objects.

    The CSV must provide ``paper_id``, ``title``, ``abstract`` and ``url``
    columns; a missing column raises ``KeyError``.
    """
    # dtype=str keeps ids such as "2101.00010" verbatim (default inference
    # would parse them as floats and drop the trailing zero, breaking id
    # lookups), and keep_default_na=False yields "" for empty cells instead
    # of NaN floats that would break later string handling.
    df = pd.read_csv(file_path, dtype=str, keep_default_na=False)
    return [
        ScrapedPaper(
            paper_id=record["paper_id"],
            title=record["title"],
            abstract=record["abstract"],
            url=record["url"],
        )
        for record in df.to_dict(orient="records")
    ]

def get_paper_data(papers: List[ScrapedPaper], paper_id: str) -> Tuple[str, str, str]:
    """Find ``paper_id`` in ``papers`` and return ``(title, abstract, full_text)``.

    The full text comes from the matching paper's ``extract_text()``, which
    for ScrapedPaper downloads the PDF. Ids are compared as strings, so a
    numeric id read from a CSV still matches a string query. The first match
    wins. Returns ``("", "", "")`` when no paper matches.
    """
    wanted = str(paper_id)
    match = next((p for p in papers if str(p.paper_id) == wanted), None)
    if match is None:
        return "", "", ""
    return match.title, match.abstract, match.extract_text()

@lru_cache(maxsize=4)
def get_qa_pipeline(qa_model_name: str = QA_MODEL_NAME) -> QuestionAnsweringPipeline:
    """Return a question-answering pipeline for ``qa_model_name``.

    Loading the tokenizer and model weights is expensive, so the pipeline is
    memoised per argument (bounded to a few entries). Repeated calls, such
    as one per ``get_answer``, reuse the same pipeline object instead of
    reloading the model.
    """
    tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
    return pipeline("question-answering", model=model, tokenizer=tokenizer)

def get_answer(question: str, context: str) -> str:
    """Answer ``question`` by extracting a span from ``context``.

    Returns "" without loading or running the model when either argument
    is empty or whitespace-only. For example, the ``("", "", "")`` result
    that ``get_paper_data`` returns for an unknown id leaves nothing to
    extract an answer from.
    """
    if not (question and question.strip() and context and context.strip()):
        return ""
    qa_pipeline = get_qa_pipeline()
    prediction = qa_pipeline(question=question, context=context)
    return prediction["answer"]

# --- Example driver ---------------------------------------------------------
# Both values below are placeholders: point the path at a real CSV with
# paper_id/title/abstract/url columns and pick an id that exists in it.
dataset_file_path = "your_dataset.csv"  # placeholder: path to the CSV
papers = load_dataset(file_path=dataset_file_path)

# Look up one paper; this fetches its PDF when the id is found.
paper_id = "your_paper_id_here"
title, abstract, text = get_paper_data(papers=papers, paper_id=paper_id)


# Output: ModuleNotFoundError: No module named 'xmltodict'