In [2]:
#!pip install lxml
#!pip install PyMuPDF
# Uncomment and run this line once to download the model
#!python -m spacy download en_core_web_sm
#!pip install bitsandbytes
#!pip install flash_attn
#!pip install xformers


In [3]:
# --- 0. Environment Setup & Offline Preparation ---

# Standard Imports
import os
import re
import pandas as pd
import fitz # PyMuPDF is imported as 'fitz'
import lxml.etree as etree
from lxml.etree import _Element as Element # Type hinting for lxml.etree.Element
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForTokenClassification
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils.quantization_config import BitsAndBytesConfig
from transformers.training_args import TrainingArguments
import torch
import kagglehub
import spacy
import json
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Union
from tqdm.auto import tqdm

os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# Set device for PyTorch
device = "cuda" if torch and torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cuda


In [4]:
# Define constants for file paths and model configurations
BASE_INPUT_DIR = './kaggle/input/make-data-count-finding-data-references'
ARTICLE_TRAIN_DIR = os.path.join(BASE_INPUT_DIR, 'train')
ARTICLE_TEST_DIR = os.path.join(BASE_INPUT_DIR, 'test')

# Define directories for articles in train and test sets
LABELED_TRAINING_DATA_CSV_PATH = os.path.join(BASE_INPUT_DIR, 'train_labels.csv')

# Define the base model path
#MODEL_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")
MODEL_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/1.7b")
#MODEL_PATH = kagglehub.model_download("google/gemma-3n/transformers/gemma-3n-e2b-it")
#T5GEMMA_PATH = kagglehub.model_download("google/t5gemma/transformers/t5gemma-2b-2b-ul2-it")

# Output directory for the fine-tuned model and results
BASE_OUTPUT_DIR = "./kaggle/working"
SUBMISSION_CSV_PATH = os.path.join(BASE_OUTPUT_DIR, "submission.csv")

# Load a spaCy model (e.g., 'en_core_web_sm')
# python -m spacy download en_core_web_sm 
try:
    NLP_SPACY = spacy.load("en_core_web_sm")
except OSError:
    !python -m spacy download en_core_web_sm 

## Data Extraction

In [5]:

@dataclass
class DatasetCitation:
    dataset_ids: List[str]
    citation_context: str


@dataclass
class ArticleData:
    article_id: str = ""
    article_doi: str = ""
    title: str = ""
    author: str = ""
    abstract: str = ""
    # datasets: List[str] = field(default_factory=list)
    data_availability: str = ""
    other_dataset_citations: List[DatasetCitation] = field(default_factory=list)

    def __post_init__(self):
        # Custom initialization
        if self.article_id and not self.article_doi:
            # If article_id is provided but not article_doi, set article_doi
            self.article_doi = self.article_id.replace("_", "/").lower()

    def add_dataset_citation(self, dataset_citation: DatasetCitation):
        """Adds a dataset citation to the article."""
        self.other_dataset_citations.append(dataset_citation)
        
    def to_dict(self):
        return asdict(self)

    def to_json(self):
        return json.dumps(self.to_dict(), separators=(',', ':'))

    def has_data(self) -> bool:
        """Returns True if there is any data availability or dataset citation."""
        return bool(self.data_availability or self.other_dataset_citations)

In [6]:
article = ArticleData(article_id="10.1234_example_article")
article.article_doi

'10.1234/example/article'

### Common dataset identifiers

In [7]:
# --- 2. Information Extraction (IE) - Dataset Identification ---
NON_STD_UNICODE_DASHES = re.compile(r'[\u2010\u2011\u2012\u2013\u2014]')
NON_STD_UNICODE_TICKS = re.compile(r'[\u201c\u201d]')

def clean_text(text: str) -> str:
    """
    Clean the input text by removing non-standard unicode dashes and extra whitespace.
    
    Args:
        text (str): The text to clean.
        
    Returns:
        str: The cleaned text.
    """
    if not text:
        return ""
    # Replace all non-standard unicode dashes with '-'
    text = text.replace('\u200b', '').replace('-\n', '-').replace('_\n', '_').replace('/\n', '/')
    text = NON_STD_UNICODE_DASHES.sub('-', text)
    text = NON_STD_UNICODE_TICKS.sub("'", text)
    # Remove extra whitespace
    return re.sub(r'\s+', ' ', text).strip()

# Regex patterns for common dataset identifiers
# DOI_PATTERN = r'10\.\d{4,5}/[-._;()/:A-Za-z0-9\u002D\u2010\u2011\u2012\u2013\u2014\u2015]+'	DOI_PATTERN
# DOI_PATTERN = r'10\.\s?\d{4,5}\/[-._()<>;\/:A-Za-z0-9]+\s?(?:(?![A-Z]+)(?!\d{1,3}\.))+[-._()<>;\/:A-Za-z0-9]+'
#DOI_PATTERN = r'\bhttps://doi.org/10\.\d{4,5}\/[-._\/:A-Za-z0-9]+'
DOI_PATTERN = r'\b10\.\d{4,5}\/[-._\/:A-Za-z0-9]+'
EPI_PATTERN = r'\bEPI[-_A-Z0-9]{2,}'
SAM_PATTERN = r'\bSAMN[0-9]{2,}'          # SAMN07159041
IPR_PATTERN = r'\bIPR[0-9]{2,}'
CHE_PATTERN = r'\bCHEMBL[0-9]{2,}'
PRJ_PATTERN = r'\bPRJ[A-Z0-9]{2,}'
E_G_PATTERN = r'\bE-[A-Z]{4}-[0-9]{2,}'   # E-GEOD-19722 or E-PROT-100
ENS_PATTERN = r'\bENS[A-Z]{4}[0-9]{2,}'
CVC_PATTERN = r'\bCVCL_[A-Z0-9]{2,}'
EMP_PATTERN = r'\bEMPIAR-[0-9]{2,}'
PXD_PATTERN = r'\bPXD[0-9]{2,}'
HPA_PATTERN = r'\bHPA[0-9]{2,}'
SRR_PATTERN = r'\bSRR[0-9]{2,}'
GSE_PATTERN = r'\b(GSE|GSM|GDS|GPL)\d{4,6}\b' # Example for GEO accession numbers (e.g., GSE12345, GSM12345)
GNB_PATTERN = r'\b[A-Z]{1,2}\d{5,6}\b' # GenBank accession numbers (e.g., AB123456, AF000001)
CAB_PATTERN = r'\bCAB[0-9]{2,}'

# Combine all patterns into a list
DATASET_ID_PATTERNS = [
    DOI_PATTERN,
    EPI_PATTERN,
    SAM_PATTERN,
    IPR_PATTERN,
    CHE_PATTERN,
    PRJ_PATTERN,
    E_G_PATTERN,
    ENS_PATTERN,
    CVC_PATTERN,
    EMP_PATTERN,
    PXD_PATTERN,
    HPA_PATTERN,
    SRR_PATTERN,
    GSE_PATTERN,
    GNB_PATTERN,
    CAB_PATTERN,
]

# Compile all patterns for efficiency
COMPILED_DATASET_ID_REGEXES = [re.compile(p) for p in DATASET_ID_PATTERNS]

# Data related keywords to look for in the text
# These keywords help to ensure that the text is relevant to datasets
DATA_RELATED_KEYWORDS = ['data release', 'data associated', 'data availability', 'data access', 'download', 'program data', 'the data', 'dataset', 'database', 'repository', 'data source', 'data access', 'archive', 'arch.', 'digital']

def is_text_data_related(text: str) -> bool:
    if not text:
        return False
    
    text_lower = text.lower()
    return any(keyword in text_lower for keyword in DATA_RELATED_KEYWORDS)

def text_has_dataset_id(text: str) -> bool:
    """
    Check if the given text contains any dataset identifier.
    
    Args:
        text (str): The text to check for dataset identifiers.
        
    Returns:
        bool: True if any dataset identifier is found, False otherwise.
    """

    occurrences_with_context: list[str] = []
    for regex in COMPILED_DATASET_ID_REGEXES:
        if regex.search(text):
            text_lower = text.lower()
            # Check for specific keywords in the text
            if any(keyword in text_lower for keyword in DATA_RELATED_KEYWORDS):
                return True
    return False

def extract_dataset_ids(text: str, context_chars: int = 250) -> list[dict[str, list[str] | str]]:
    """
    Extract dataset identifiers with context from the given text.
    
    Args:
        text (str): The text to search for dataset identifiers.
        context_chars (int): Number of characters to include before and after the match for context.
        
    Returns:
        list[str]: A list of extracted dataset identifiers with context.
    """
    text = clean_text(text)
    is_small_context = len(text) < context_chars * 2
    dataset_ids: list[str] = []
    occurrences_with_context: list[dict[str, list[str] | str]] = []
    if is_text_data_related(text):
        for regex in COMPILED_DATASET_ID_REGEXES:
            matches = regex.finditer(text, re.IGNORECASE)
            for match in matches:
                dataset_id = text[match.start() : match.end()]
                if is_small_context:
                    dataset_ids.append(dataset_id)
                else:
                    citation_context = text[max(0, match.start() - context_chars): match.end() + context_chars ]
                    citation_context = citation_context.replace('\n', '').replace('[', '').replace(']', '')
                    citation_context = re.sub(r'\s+', ' ', citation_context).strip()
                    if is_text_data_related(citation_context):
                        occurrences_with_context.append({"dataset_ids": [dataset_id], "citation_context": citation_context})
                        #occurrences_with_context.append("{" + f'"dataset_ids": {[dataset_id]}, citation_context: "{citation_context}"' + "}")
        if dataset_ids:
            occurrences_with_context.append({"dataset_ids": dataset_ids, "citation_context": text})
    
    # # If no occurrences found, return an empty string
    # # Otherwise, join the occurrences with a specific separator
    # if not occurrences_with_context:
    #     return ""
    # return json.dumps(occurrences_with_context, separators=(',', ':'))
    return occurrences_with_context

# Use NLP to get sentences from the given text

def get_sentences_from_text(text: str, nlp=NLP_SPACY) -> str:
    if not text:
        return ""
    
    # Replace all non-standard unicode dashes with '-'
    text = clean_text(text)
    text = text.replace('\n', ' ').strip()
    doc_spacy = nlp(text)
    return " ".join([sent.text for sent in doc_spacy.sents])

In [8]:
s = "Seventy\u2010eight pleural effusions. CT scans acquired from The Cancer \nImaging Archive \u201cNSCLC Radiomics\u201d data collection. All expert\u2010vetted segmentations are publicly available in NIfTI format through The Cancer Imaging Archive at  https://doi.org/10.7937/tcia.2020.6c7y\u2010gq39"
print(f"Cleaned text: {get_sentences_from_text(s)}")

Cleaned text: Seventy-eight pleural effusions. CT scans acquired from The Cancer Imaging Archive 'NSCLC Radiomics' data collection. All expert-vetted segmentations are publicly available in NIfTI format through The Cancer Imaging Archive at https://doi.org/10.7937/tcia.2020.6c7y-gq39


### XML Element Extraction

In [9]:

def extract_element_text(element: Element | None) -> str:
    if element is not None:
        # Use itertext() to get all text content from the <p> tag and its descendants
        # and join them into a single string.
        all_text = " ".join(element.itertext(tag=None)).replace('\u200b', '').strip()
        return all_text[:2000]
    else:
        return ""
    
def extract_next_sibling_text(elements: list[Element] | None, sibling_xpath: str) -> str:
    """
    Extracts text from the next sibling of the given XML element.
    
    Args:
        element (Element | None): The XML element whose next sibling's text is to be extracted.
        sibling_xpath (str): The XPath expression to find the next sibling element. (eg. "following-sibling::passage[1]")
        
    Returns:
        str: A string containing the text from the next sibling element, or an empty string if no sibling exists.

    """
    # Check if the provided elements list is None or empty
    if not elements:
        return ""
    
    # Assuming there's only one such element, take the first one found
    # and find the element immediately following based on the given sibling_xpath.
    first_element = elements[0]
    sibling_elements = first_element.xpath(sibling_xpath)

    if not sibling_elements:
        # print("DEBUG: No following <passage> element found.") # Uncomment for debugging
        return ""
    
    next_sibling = sibling_elements[0]
    if next_sibling is None:
        return ""
    
    return extract_element_text(next_sibling)

def extract_elements_text(elements: list[Element] | None, sep: str = " ") -> str:
    elements_text = []
    if elements is None:
        return ""
    
    for element in elements:
        text = extract_element_text(element)
        if text:
            elements_text.append(text)

    return sep.join(elements_text).strip()

def extract_elements_text_from_xpath_list(root: Element | None, xpath_list: list[str], ns: dict[str, str] | None = None) -> str:
    elements_text = ""
    if root is None or not xpath_list:
        return ""
    
    for xpath in xpath_list:
        element = root.find(xpath, namespaces=ns)
        elements_text += extract_element_text(element)
    return elements_text

def extract_text_from_elements_within_element(element: Element | None, child_xpaths: list[str] = [], ns: dict[str, str] | None = None) -> str:
    """
    Extracts text from elements within a given XML element that match the specified tag names.
    
    Args:
        element (Element | None): The XML element to search within.
        tag_names (list[str]): A list of tag names to search for.
        
    Returns:
        str: A string containing the extracted text from the matching elements.
    """
    if element is None:
        return ""
    
    if not child_xpaths:
        # If no child tag names are provided, return the text of the element itself
        return extract_element_text(element)
    
    extracted_text = []
    for xpath in child_xpaths:
        for child in element.findall(xpath, namespaces=ns):
            text = extract_element_text(child)
            if text:
                extracted_text.append(text)
    
    return "|".join(extracted_text)

def extract_data_related_elements_text(elements: list[Element] | None, child_xpaths: list[str] = [], ns: dict[str, str] | None = None) -> list[dict[str, list[str] | str]]:
    elements_text = []
    if elements is None:
        return elements_text
    
    for element in elements:
        text = extract_dataset_ids(extract_text_from_elements_within_element(element, child_xpaths, ns))
        if text:
            elements_text.extend(text)

    return elements_text

def extract_data_related_elements_text_from_xpath_list(root: Element | None, xpath_list: list[str], ns: dict[str, str] | None = None) -> list[dict[str, list[str] | str]]:
    """
    Extracts text from elements in the XML tree that match the provided XPath expressions.
    
    Args:
        root (Element | None): The root element of the XML tree.
        xpath_list (list[str]): A list of XPath expressions to search for elements.
        
    Returns:
        list[str]: A list of extracted text from the matching elements.
    """
    elements_text = []
    if root is None or not xpath_list:
        return elements_text
    
    for xpath in xpath_list:
        primary_xpath, *child_xpath_text = xpath.split('||')
        child_xpaths = child_xpath_text[0].split(',') if child_xpath_text else []
        elements = root.findall(primary_xpath, namespaces=ns)
        if elements:
            elements_text.extend(extract_data_related_elements_text(elements, child_xpaths, ns))
    return elements_text




### PDF File Extraction

In [10]:

def extract_author_names(full_text: str, nlp=NLP_SPACY) -> str:
    """
    Extracts potential author names from the beginning of a research article's text
    using spaCy's Named Entity Recognition. It attempts to isolate the author section
    and applies heuristics to filter out non-author entities.

    Args:
        full_text (str): The complete text content of the research article,
                         typically extracted from a PDF.

    Returns:
        List[str]: A list of unique strings, each representing a potential author name,
                   sorted alphabetically. Returns an empty list if no authors are found.
    """
    if not full_text or not full_text.strip():
        return ""

    full_text = full_text.replace('1\n,', ',').replace('1,', ',').replace('\u2019', "'")

    # 1. Isolate the potential author section
    # Authors are typically at the very beginning, before the abstract or introduction.
    # We'll search for common section headers to define the end of the author block.
    # Using regex for case-insensitive search and handling various newline/spacing.
    header_patterns = [
        r"\n\s*Abstract\s*\n",
        r"\n\s*Introduction\s*\n",
        r"\n\s*Summary\s*\n",
        r"\n\s*Keywords\s*\n",
        r"\n\s*Graphical Abstract\s*\n",
        r"\n\s*1\.\s*Introduction\s*\n", # Common for numbered sections
        r"\n\s*DOI:\s*\n" # Sometimes DOI appears before abstract
    ]

    author_section_end_index = len(full_text)
    for pattern in header_patterns:
        match = re.search(pattern, full_text, re.IGNORECASE)
        if match:
            # Take text up to the start of the found header
            author_section_end_index = min(author_section_end_index, match.start())
            break
    
    # As a fallback or if no header is found early, limit the search to the first
    # 2500 characters. This prevents processing the entire document for authors.
    author_section_text = full_text[:min(author_section_end_index, 2500)]

    if not author_section_text.strip():
        return ""

    # 2. Process the isolated author section with spaCy
    doc = nlp(author_section_text)

    # 3. Extract PERSON entities and apply initial filtering
    potential_authors: list[str] = []
    for ent in doc.ents:
        if ent.label_ == "PERSON":
            name = ent.text.strip()
            # Basic filtering to reduce false positives:
            # - Exclude very short strings (e.g., single letters, common conjunctions)
            # - Exclude common stop words (e.g., "The", "And")
            # - Exclude all-uppercase strings that might be acronyms (e.g., "WHO", "NASA")
            # - Ensure it contains at least one space (e.g., "John Doe") or is a capitalized
            #   single word that's longer than 2 characters (e.g., "Smith").
            if (len(name) > 1 and
                name.lower() not in nlp.Defaults.stop_words and
                not name.isupper() and
                (' ' in name or (name[0].isupper() and len(name) > 2))):
                
                potential_authors.append(name)

    # 4. Apply more advanced heuristics to filter out non-author names
    # This step is crucial for accuracy and often requires tuning.
    filtered_authors = []
    for author in potential_authors:
        # Heuristic 1: Filter out names that contain common affiliation keywords.
        # This is a simple check; more robust solutions might use spaCy's dependency
        # parsing to check if a PERSON entity is part of an ORG entity.
        affiliation_keywords = ["univ", "observ", "institute", "department", "center", "lab",
                                "hospital", "college", "school", "inc.", "ltd.", "company",
                                "corp.", "group", "foundation", "research"]
        if any(keyword in author.lower() for keyword in affiliation_keywords):
            continue # Skip if it looks like an affiliation

        # Heuristic 2: Filter out names that contain email patterns or ORCID patterns.
        if '@' in author or re.search(r'\b\d{4}-\d{4}-\d{4}-\d{3}[\dX]\b', author):
            continue # Skip if it contains an email or ORCID

        # Heuristic 3: Filter out names that are likely just initials or very short.
        # This is partially covered by initial filtering, but can be refined.
        # E.g., "J. D." might be an author, but "J." alone is unlikely.
        if len(author.split()) == 1 and len(author) <= 2 and author.isupper():
            continue # Skip single-letter or two-letter uppercase (e.g., "JD")

        filtered_authors.append(author)

    # Convert to list and sort for consistent output
    return filtered_authors[0] if filtered_authors else ""

def extract_pdf_doc_text(pdf_doc: fitz.Document)  -> dict[str, str | list[dict[str, list[str] | str]]]:
    """
    Extracts all text from a PDF document using PyMuPDF.
    
    Args:
        pdf_doc (fitz.Document): The PDF document to extract text from.
        
    Returns:
        str: A JSON string of the article_dict containing specific elements extracted from the PDF.
    """

    # Initialize the article dictionary with empty strings
    article_dict = {
        'title': '',
        'author': '',
        'abstract': '',
        'data_availability': '',
        'other_dataset_citations': []
    }

    # Initialize variables for text extraction
    p1 = None  # Placeholder for the first page text
    other_dataset_citations = []
    for page in pdf_doc:
        # Extract text from the page
        textpage = page.get_textpage()
        if page.number == 0:
            p1_txt = textpage.extractTEXT()
            p1 = get_sentences_from_text(p1_txt)
            p1 = p1[:int(len(p1)/2)]
            article_dict['author'] = extract_author_names(p1_txt, nlp=NLP_SPACY)

        # Extract text from all blocks that have an abstract or dataset id's
        blocks = textpage.extractBLOCKS()
        for block in blocks:
            block_text = get_sentences_from_text(block[4])
            block_text_lower = block_text.lower()
            if page.number == 0 and len(block_text) > 100 and "abstract" in block_text_lower:
                # Add the abstract block text to the article dictionary
                article_dict['abstract'] = block_text
            elif "data availability" in block_text_lower or "data accessibility" in block_text_lower or "acknowledgments" in block_text_lower:
                # Add the data availability block text to the article dictionary
                article_dict['data_availability'] = block_text
            else:
                context_chars = min(250, len(block_text))  # Use a minimum
                dataset_ids_found = extract_dataset_ids(block_text, context_chars)  # Extract dataset IDs from the block text
                if dataset_ids_found:
                    # print(f"DEBUG: Found dataset IDs in block: {dataset_ids_found}")  # Debugging output
                    # print(f"DEBUG: block_text: {block_text}")  # Debugging output
                    if len(article_dict['data_availability']) > 0 and len(article_dict['data_availability']) < 25:
                        # If data availability text is only a few characters, append the next block text to it
                        # This is a heuristic to ensure that we capture relevant dataset IDs
                        article_dict['data_availability'] = block_text
                    else:
                        # Append the dataset IDs found in the block to the other_dataset_citations
                        other_dataset_citations.extend(dataset_ids_found)

    article_dict['other_dataset_citations'] = other_dataset_citations if other_dataset_citations else []
    
    # If an abstract was not found, use the first page text as the abstract
    if not article_dict['abstract'] and p1:
        article_dict['abstract'] = p1

    # Return the article dictionary as a JSON string
    return article_dict


### XML File Extraction

In [11]:

def extract_xml_text_jats(root: Element) -> dict[str, str | list[dict[str, list[str] | str]]]:
    # Find the title, abstract, and data availablity info for Journal Archiving and Interchange DTD (JATS)
    # The ".//" ensures it searches anywhere in the document, not just direct children of root.
    ns = None  # No namespaces for JATS

    xpath_title = ".//article-title"
    xpath_authors_1 = ".//contrib-group/contrib[@contrib-type='author']/name"
    xpath_authors_2 = ".//biblstruct/analytic/author[@role='corresp']/persname"
    author = extract_element_text(root.find(xpath_authors_1, namespaces=ns))
    if not author:
        author = extract_element_text(root.find(xpath_authors_2, namespaces=ns))
    xpath_abstract = ".//abstract"
    xpath_data_avails = [".//notes[@notes-type='data-availability']", ".//sec[@sec-type='data-availability']"]
    xpath_citations = [".//element-citation||.article-title,.source,.pub-id", ".//mixed-citation"]  # List of XPath expressions for citations

    return {
        'title': extract_element_text(root.find(xpath_title, ns)),
        'author': author,
        'abstract': get_sentences_from_text(extract_element_text(root.find(xpath_abstract, ns))),
        'data_availability': extract_elements_text_from_xpath_list(root, xpath_data_avails, ns=ns),
        'other_dataset_citations': extract_data_related_elements_text_from_xpath_list(root, xpath_citations, ns=ns),
    }

def extract_xml_text_tei(root: Element) -> dict[str, str | list[dict[str, list[str] | str]]]:
    # Find the title, abstract, and data availability info for Text Encoding Initiative (TEI)
    # Set the namespace for TEI
    ns = {'tei': 'http://www.tei-c.org/ns/1.0'}

    xpath_title = ".//tei:title"
    xpath_authors = ".//tei:sourcedesc/tei:biblstruct/tei:analytic/tei:author/tei:persname"
    xpath_abstract = ".//tei:abstract"
    xpath_data_avail = "" #".//tei:biblstruct"
    xpath_citations = [".//tei:biblstruct||.//tei:title,.//tei:idno,.//tei:notes"]  # List of XPath expressions for citations
        
    return {
        'title': extract_element_text(root.find(xpath_title, namespaces=ns)),
        'author': extract_element_text(root.find(xpath_authors, namespaces=ns)),
        'abstract': get_sentences_from_text(extract_element_text(root.find(xpath_abstract, namespaces=ns))),
        'data_availability': xpath_data_avail,  # No direct extraction for TEI data_availability
        'other_dataset_citations': extract_data_related_elements_text_from_xpath_list(root, xpath_citations, ns=ns),
    }

def extract_xml_text_wiley(root: Element) -> dict[str, str | list[dict[str, list[str] | str]]]:
    # Find the title, abstract, and data availability info for Wiley XML format
    # Set the namespace for Wiley
    ns = {'ns': 'http://www.wiley.com/namespaces/wiley'}

    xpath_title = ".//ns:publicationMeta[@level='part']/ns:titleGroup"    #<publicationMeta level="part"><titleGroup><title type="main">
    xpath_authors = ".//selfCitationGroup/citation[@type='self']/author"
    xpath_abstract = ".//ns:abstract[@type='main']"  #<abstract type="main"
    xpath_data_avail = ".//ns:section[@type='dataAvailability']"  #<section numbered="no" type="dataAvailability"
    xpath_citations = [".//ns:citation||.//ns:articleTitle,.//ns:journalTitle,.//ns:url"]  # List of XPath expressions for citations
        
    return {
        'title': extract_elements_text(root.findall(xpath_title, namespaces=ns)),
        'author': extract_element_text(root.find(xpath_authors, namespaces=ns)),
        'abstract': get_sentences_from_text(extract_element_text(root.find(xpath_abstract, namespaces=ns))),
        'data_availability': extract_element_text(root.find(xpath_data_avail, namespaces=ns)),
        'other_dataset_citations': extract_data_related_elements_text_from_xpath_list(root, xpath_citations, ns=ns),
    }

def extract_xml_text_biorxiv(root: Element) -> dict[str, str | list[dict[str, list[str] | str]]]:
    # Find the title, abstract, and data availability info for BioRxiv XML format
    # Set the namespace for BioRxiv
    ns = {'biorxiv': 'http://www.biorxiv.org'}

    xpath_title = ".//biorxiv:title"
    xpath_authors = ".//biorxiv:contrib[@contrib-type='author']/biorxiv:name"
    xpath_abstract = ".//biorxiv:abstract"
    xpath_data_avail = ".//biorxiv:sec[@sec-type='data-availability']"  #<sec sec-type="data-availability"
    xpath_citations = [".//biorxiv:biblio||.//biorxiv:title,.//biorxiv:source,.//biorxiv:pub-id"]  # List of XPath expressions for citations
        
    return {
        'title': extract_element_text(root.find(xpath_title, namespaces=ns)),
        'author': extract_element_text(root.find(xpath_authors, namespaces=ns)),
        'abstract': get_sentences_from_text(extract_element_text(root.find(xpath_abstract, namespaces=ns))),
        'data_availability': extract_element_text(root.find(xpath_data_avail, namespaces=ns)),
        'other_dataset_citations': extract_data_related_elements_text_from_xpath_list(root, xpath_citations, ns=ns),
    }

def extract_xml_text_bioc(root: Element) -> dict[str, str | list[dict[str, list[str] | str]]]:
    # Find the title, abstract, and data availability info for BioC-API XML format
    ns = None  # No namespaces for BioC

    xpath_title = "string(.//passage[infon[@key='section_type' and text()='TITLE']]/text)"
    xpath_authors = "string(.//infon[@key='name_0'] | .//infon[@key='name_1'])"
    xpath_abstract = "string(.//passage[infon[@key='section_type' and text()='ABSTRACT']]/text)"
    xpath_data_avail = ".//passage[text[text()='DATA ACCESSIBILITY:']]"
    xpath_data_avail_sibling = "following-sibling::passage[1]"
    xpath_citations = []
        
    return {
        'title': root.xpath(xpath_title, namespaces=ns),
        'author': root.xpath(xpath_authors, namespaces=ns).strip().replace('surname:', '').replace(';given-names:', ' '),
        'abstract': get_sentences_from_text(root.xpath(xpath_abstract, namespaces=ns)[:2000]),  # Limit to 2000 characters
        'data_availability': extract_next_sibling_text(root.xpath(xpath_data_avail, namespaces=ns), xpath_data_avail_sibling),
        'other_dataset_citations': xpath_citations,
    }

def extract_xml_text_taxonx(root: Element) -> dict[str, str | list[dict[str, list[str] | str]]]:
    # Find the title, abstract, and data availability info for TaxonX format
    ns = None  # No namespaces for Taxonomic Treatment Publishing DTD

    xpath_title = "string(.//article-meta/title-group/article-title)"
    xpath_authors = ""
    xpath_abstract = "string(.//article-meta/abstract)"
    xpath_data_avail = ""
    xpath_citations = []
        
    return {
        'title': root.xpath(xpath_title, namespaces=ns),
        'author': xpath_authors,
        'abstract': get_sentences_from_text(root.xpath(xpath_abstract, namespaces=ns)[:2000]),  # Limit to 2000 characters
        'data_availability': xpath_data_avail,  # No direct extraction for TaxonX data_availability
        'other_dataset_citations': xpath_citations,
    }


## File Processing

In [12]:

# Dictionary mapping XML types to their respective extraction functions
XML_TYPE_EXTRACTORS = {
    'jats': extract_xml_text_jats,
    'tei': extract_xml_text_tei,
    'wiley': extract_xml_text_wiley,
    'bioc': extract_xml_text_bioc,
    'taxonx': extract_xml_text_taxonx,
}

# --- Data Loading ---
def get_file_extension(file_path: str) -> str:
    """
    Returns the file extension of the given file path.
    
    Args:
        file_path (str): The path to the file.
        
    Returns:
        str: The file extension, or an empty string if no extension is found.
    """
    _, ext = os.path.splitext(file_path)
    return ext.lower() if ext else ""

def read_first_line_of_xml(file_path: str) -> str | None:
    """
    Reads and returns the first line of an XML file.

    Args:
        file_path (str): The path to the XML file.

    Returns:
        str | None: The first line of the file, stripped of leading/trailing whitespace,
                    or None if the file cannot be read or is empty.
    """
    if not file_path and not os.path.exists(file_path):
        return None
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            first_line = f.readline().replace('<?xml version="1.0" encoding="UTF-8"?>', '').replace('<?xml version="1.0" encoding="UTF-8" standalone="yes"?>', '').strip()
            # If the first line is empty, read the next line
            if not first_line:
                first_line = f.readline()
            return first_line.strip()[:90] if first_line else None
    except UnicodeDecodeError:
        try:
            with open(file_path, 'r', encoding='iso-8859-1') as f:
                first_line = f.readline().replace('<?xml version="1.0" encoding="UTF-8"?>', '').replace('<?xml version="1.0" encoding="UTF-8" standalone="yes"?>', '').strip()
                if not first_line:
                    first_line = f.readline()
                return first_line.strip()[:90] if first_line else None
        except Exception as e:
            return None
    except Exception as e:
        print(f"Error reading file '{file_path}': {e}")
        return None
    
def identify_xml_type(first_line: str) -> str:
    """
    Identifies the XML type based on the first line of the XML file.
    
    Args:
        file_path (str): The path to the XML file.
        
    Returns:
        str: The identified XML type ('jats', 'tei', 'wiley', 'bioc', or 'unknown').
    """
    if not first_line:
        return "unknown"
    first_line_lower = first_line.lower()
    # Check for specific patterns in the first line
    if 'journal archiving and interchange dtd' in first_line_lower:
        return "jats"
    elif 'xmlns="http://www.tei-c.org/ns/1.0"' in first_line_lower:
        return "tei"
    elif 'xmlns="http://www.wiley.com/namespaces/wiley"' in first_line_lower:
        return "wiley"
    elif 'bioc.dtd' in first_line_lower or 'bioc-api' in first_line_lower:
        return "bioc"
    elif 'taxonomic treatment publishing dtd' in first_line_lower:
        return "taxonx"
    
    return "unknown"    

def get_xml_type(file_path: str) -> str:
    """
    Determines the XML type of a file based on its first line.
    
    Args:
        file_path (str): The path to the XML file.
        
    Returns:
        str: The identified XML type ('jats', 'tei', 'wiley', 'bioc', 'taxonx', or 'unknown').
    """
    first_line = ""
    if ".xml" == get_file_extension(file_path):
        # If the file is an XML file, read the first line and identify the type
        first_line = read_first_line_of_xml(file_path)
    return identify_xml_type(first_line) if first_line else "unknown"

def load_file_paths(dataset_type_dir: str) -> pd.DataFrame: 
    pdf_path = os.path.join(dataset_type_dir, 'PDF')
    xml_path = os.path.join(dataset_type_dir, 'XML')
    dataset_type = os.path.basename(dataset_type_dir)
    pdf_files = [f for f in os.listdir(pdf_path) if f.endswith('.pdf')]
    xml_files = [f for f in os.listdir(xml_path) if f.endswith('.xml')]
    df_pdf = pd.DataFrame({
        'article_id': [f.replace('.pdf', '') for f in pdf_files],
        'pdf_file_path': [os.path.join(pdf_path, f) for f in pdf_files]
    })
    df_xml = pd.DataFrame({
        'article_id': [f.replace('.xml', '') for f in xml_files],
        'xml_file_path': [os.path.join(xml_path, f) for f in xml_files]
    })
    merge_df = pd.merge(df_pdf, df_xml, on='article_id', how='outer', suffixes=('_pdf', '_xml'), validate="one_to_many")
    merge_df['dataset_type'] = dataset_type
    return merge_df

def extract_pdf_text(file_path: str, xml_type: str | None = None) -> dict[str, str | list[dict[str, list[str] | str]]]:
    """Extracts all text from a PDF file using PyMuPDF."""
    article_dict = {}
    if file_path and os.path.exists(file_path):
        try:
            with fitz.open(file_path) as doc:
                article_dict = extract_pdf_doc_text(doc)  # Extract text from the PDF document
        except Exception as e:
            print(f"Error reading PDF {file_path}: {e}")
    else:
        print(f"PDF file not found: {file_path}")
    
    return article_dict

def extract_xml_text(file_path: str, xml_type: str) -> dict[str, str | list[dict[str, list[str] | str]]]:
    """Reads and extracts text from an XML file based on the specified XML type.
    Args:
        file_path (str): The path to the XML file.
        xml_type (str): The type of XML format (e.g., 'jats', 'tei', 'wiley', 'bioc', 'taxonx').
    Returns:
        dict: A dictionary containing the extracted text from the XML file.
    """
    # Initialize the article dictionary
    article_dict = {}
    if file_path and os.path.exists(file_path):
        # Disable external entity resolution for security
        parser = etree.XMLParser(resolve_entities=False, no_network=True)
        try:
            tree = etree.parse(file_path, parser)
            root = tree.getroot()
            # Use the appropriate extraction function based on the xml_type
            extract_function = XML_TYPE_EXTRACTORS.get(xml_type, extract_xml_text_jats)  
            article_dict = extract_function(root)
        except Exception as e:
            print(f"Error reading XML {file_path}: {e}")
    else:
        print(f"XML file not found: {file_path}")    
    return article_dict

def dedupe_other_dataset_citations(article_dict: dict[str, str | list[dict[str, list[str] | str]]]) -> dict[str, str | list[dict[str, list[str] | str]]]:
    """
    Deduplicates dataset IDs in the article dictionary.
    
    Args:
        article_dict (dict): The article dictionary containing dataset IDs.
        
    Returns:
        dict: The updated article dictionary with deduplicated dataset IDs.
    """
    unique_dataset_ids = set()
    unique_dataset_id_citations = []
    if 'other_dataset_citations' in article_dict and isinstance(article_dict['other_dataset_citations'], list):
        dataset_citations = article_dict['other_dataset_citations']
        for citation_dict in dataset_citations:
            dataset_ids = citation_dict['dataset_ids']
            if isinstance(dataset_ids, list):
                if len(dataset_ids) == 1:
                    dataset_id = dataset_ids[0]
                    if dataset_id not in article_dict['data_availability'] and dataset_id not in unique_dataset_ids:
                        unique_dataset_ids.add(dataset_id)
                        unique_dataset_id_citations.append(citation_dict)
                else:
                    unique_dataset_id_citations.append(citation_dict)
        article_dict['other_dataset_citations'] = unique_dataset_id_citations
    
    return article_dict 

def process_unsupported_file(file_path: str, xml_type: str | None = None) -> dict:
    return {
        'title': f"Unsupported file type for: {file_path}",
        'data_availability': "",
        'other_dataset_citations': [],
    }

# Dictionary mapping file extensions to loading functions
FILE_EXTRACTORS = {
    '.xml': extract_xml_text,
    '.pdf': extract_pdf_text,
}

def extract_article_dict(file_path: str) -> dict[str, str | list[dict[str, list[str] | str]]]:
    # Get the file extension (e.g., '.xml', '.pdf')
    file_extension = get_file_extension(file_path)

    # Get the XML type if the file is an XML file
    xml_type = get_xml_type(file_path)

    # Get the appropriate function from the dictionary,
    # or fall back to a default 'unsupported' function if not found.
    extract_function = FILE_EXTRACTORS.get(file_extension, process_unsupported_file)

    # Call the selected function
    article_dict = extract_function(file_path, xml_type=xml_type)
    article_dict = dedupe_other_dataset_citations(article_dict)
    
    return article_dict

def extract_article_text(arg: str| dict) -> str:
    """
    Overloaded function: Accepts either a file_path (str) or an article_dict (dict).
    Returns the article text as a JSON string.
    """
    if isinstance(arg, dict):
        # If it's already a dict, just serialize it
        return json.dumps(arg, separators=(',', ':'))
    elif isinstance(arg, str):
        # If it's a file path, process as before
        file_path = arg
        file_extension = get_file_extension(file_path)
        xml_type = get_xml_type(file_path)
        extract_function = FILE_EXTRACTORS.get(file_extension, process_unsupported_file)
        article_dict = extract_function(file_path, xml_type=xml_type)
        article_dict = dedupe_other_dataset_citations(article_dict)
        text_content = json.dumps(article_dict, separators=(',', ':'))
        print(f"Extracted text from {file_path}. Length: {len(text_content)} characters, xml_type: {xml_type}")
        return text_content
    else:
        raise TypeError("extract_article_text expects a file path (str) or article_dict (dict)")


In [13]:
# Test extracting text from various PDF and XML files
# pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'PDF', '10.1002_2017jc013030.pdf')
pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1002_mp.14424.xml')
# pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1007_s00259-022-06053-8.xml')    # jats
# pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1007_s00382-022-06361-7.xml')    # tei
# pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1111_1365-2435.13431.xml')       # wiley
# pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1111_mec.16977.xml')             # bioc
# pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.3897_zoologia.35.e23481.xml')    # taxonx
article_text = extract_article_text(extract_article_dict(pdf_file_path))
display(len(article_text))
article_text

2704

'{"title":"PleThora: Pleural effusion and thoracic cavity segmentations in diseased lungs for benchmarking chest CT processing pipelines","author":"Kiser Kendall J.","abstract":"This manuscript describes a dataset of thoracic cavity segmentations and discrete pleural effusion segmentations we have annotated on 402 computed tomography (CT) scans acquired from patients with non-small cell lung cancer. The segmentation of these anatomic regions precedes fundamental tasks in image analysis pipelines such as lung structure segmentation, lesion detection, and radiomics feature extraction. Bilateral thoracic cavity volumes and pleural effusion volumes were manually segmented on CT scans acquired from The Cancer Imaging Archive \'NSCLC Radiomics\' data collection. Four hundred and two thoracic segmentations were first generated automatically by a U-Net based algorithm trained on chest CTs without cancer, manually corrected by a medical student to include the complete thoracic cavity (normal, p

## Pre-processing

In [14]:
# Set the base file dir for the articles to be processed
# base_file_dir = ARTICLE_TEST_DIR \
#     if os.getenv('KAGGLE_IS_COMPETITION_RERUN') \
#     else ARTICLE_TRAIN_DIR

base_file_dir = ARTICLE_TEST_DIR

display(f"Base dir: {base_file_dir}")

# Load the PDF and XML file paths from the base_file_dir
file_paths_df = load_file_paths(base_file_dir)
file_paths_df['xml_file_path'] = file_paths_df['xml_file_path'].fillna('')

# Get the xml type for each file based on the first line of the XML file
file_paths_df['xml_type'] = file_paths_df['xml_file_path'].apply(get_xml_type)

file_paths_df.to_csv(os.path.join(BASE_OUTPUT_DIR, 'file_paths.csv'), index=False)

print(f"Files paths shape: {file_paths_df.shape}")
display(file_paths_df.sample(3))


'Base dir: ./kaggle/input/make-data-count-finding-data-references\\test'

Files paths shape: (30, 5)


Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,xml_type
24,10.1002_ejoc.202000916,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,jats
18,10.1002_ece3.961,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,jats
13,10.1002_ece3.5260,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,jats


## Define the Qwen evaluation model class

In [32]:
# --- QwenModelEval Class ---
#max_new_tokens=32768
# class QwenModelEval:
#     def __init__(self, model_name, sys_prompt, enable_thinking=False, max_new_tokens=100, max_input_length=8200): # <--- Increased max_new_tokens slightly for safety with greedy
#         print(f"Loading Qwen model and tokenizer from: {model_name}")
#         self.model_name = model_name
#         self.sys_prompt = sys_prompt
#         self.enable_thinking = enable_thinking  # Enable or disable thinking mode
#         self.max_new_tokens = max_new_tokens  # Set the maximum number of new tokens to generate
#         self.max_input_length = max_input_length  # Set the maximum input length for the model
#         self.bnb_config = BitsAndBytesConfig(
#             load_in_4bit=True,
#             bnb_4bit_quant_type="nf4",
#             bnb_4bit_compute_dtype=torch.float16,
#             bnb_4bit_use_double_quant=True,
#             llm_int8_enable_fp32_cpu_offload=True
#         )
#         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
#         if self.tokenizer.pad_token is None:
#             self.tokenizer.pad_token = self.tokenizer.eos_token        
#         self.model = AutoModelForCausalLM.from_pretrained(
#             self.model_name,
#             quantization_config=self.bnb_config,
#             device_map="auto",
#             torch_dtype=torch.float16,
#             attn_implementation="sdpa",
#             trust_remote_code=True
#         )
#         self.model.eval()
#         print(self.model.hf_device_map)

#     def generate_response(self, user_input):  
#         inputs = self._get_inputs(user_input)
        
#         with torch.no_grad(): 
#             generation_args = {
#                 "max_new_tokens": self.max_new_tokens,
#                 "pad_token_id": self.tokenizer.eos_token_id,
#                 "eos_token_id": self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
#                 # "num_beams": 1, # num_beams is only relevant for beam search, not greedy or simple sampling
#             }

#             if self.enable_thinking:
#                 # Model card's recommended "thinking mode" parameters
#                 generation_args["do_sample"] = True
#                 generation_args["temperature"] = 0.6
#                 generation_args["top_p"] = 0.95
#                 generation_args["top_k"] = 20
#                 generation_args["min_p"] = 0
#             else:
#                 # Forcing greedy decoding for concise, classification-focused output
#                 generation_args["do_sample"] = False # <--- CRITICAL CHANGE: Force Greedy Decoding
#                 # Remove sampling parameters as they are not used with do_sample=False
#                 # generation_args["temperature"] = 0.1
#                 # generation_args["top_p"] = 0.7
#                 # generation_args["top_k"] = 5
#                 # generation_args["min_p"] = 0

#             generated_ids = self.model.generate(
#                 **inputs,
#                 **generation_args,
#             )

#         input_len = inputs['input_ids'].shape[1]
#         generated_only_ids = generated_ids[0][input_len:]
#         decoded_generated_text = self.tokenizer.decode(generated_only_ids, skip_special_tokens=False)
        
#         print("\n--- RAW DECODED OUTPUT (GENERATED ONLY) ---")
#         print(decoded_generated_text)
#         print("--------------------------------------------\n")

#         return self._parse_response(decoded_generated_text)
    
#     def _get_inputs(self, user_input):
#         user_input = user_input[:self.max_input_length]
#         messages = [
#             {"role": "system", "content": self.sys_prompt},
#             {"role": "user", "content": user_input}
#         ]
#         text = self.tokenizer.apply_chat_template(
#             messages,
#             tokenize=False,
#             add_generation_prompt=True,
#         )
#         return self.tokenizer(text, return_tensors="pt").to(self.model.device)
    
#     def _parse_response(self, generated_text: str):
#         thinking_content = ""
#         raw_response = generated_text

#         think_match = re.search(r'<think>(.*?)(?=\[|\<\|im_end\|\>|$)', generated_text, re.DOTALL)
#         if think_match:
#             thinking_content = think_match.group(1).strip()
#             raw_response = generated_text[generated_text.find(think_match.group(0)) + len(think_match.group(0)):]
#             raw_response = raw_response.strip()

#         response = self._parse_json(raw_response)
        
#         print(f"Parsed response: {response}")
#         print(f"Extracted thinking content: {thinking_content}")
#         return response, thinking_content
    
#     def _parse_json(self, raw_response: str) -> list[dict[str,str]]:
#         cleaned = raw_response.strip()
#         if cleaned.startswith("```json"):
#             cleaned = cleaned[len("```json"):].strip()
#         if cleaned.endswith("```"):
#             cleaned = cleaned[:-3].strip()

#         json_match = re.search(r'\[.*?\]', cleaned, re.DOTALL)
#         if json_match:
#             json_string = json_match.group(0)
#             try:
#                 parsed_json = json.loads(json_string)
#                 if isinstance(parsed_json, list) and all(isinstance(item, dict) for item in parsed_json):
#                     return parsed_json
#                 else:
#                     print(f"Warning: Parsed JSON is not in expected list of dicts format: {parsed_json}")
#                     return []
#             except json.JSONDecodeError as e:
#                 print(f"JSON decoding error: {e}")
#                 print(f"Problematic JSON string: {json_string}")
#                 return []
#         else:
#             print("No JSON array found in generated output.")
#             return []


class QwenModelEval:
    def __init__(self, model_name, sys_prompt, enable_thinking=True, max_new_tokens=256, max_input_length=8200):
        print(f"Loading Qwen model and tokenizer from: {model_name}")
        self.model_name = model_name
        self.sys_prompt = sys_prompt
        self.enable_thinking = enable_thinking  # Enable or disable thinking mode
        self.max_new_tokens = max_new_tokens  # Set the maximum number of new tokens to generate
        self.max_input_length = max_input_length  # Set the maximum input length for the model
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4", # "fp4" or "nf4"
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            llm_int8_enable_fp32_cpu_offload=True # <--- ADD THIS LINE
        )
        # Load the tokenizer and ensure pad_token_id is set for generation if not already in tokenizer config
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token        
        # Load model with quantization
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=self.bnb_config,
            device_map="auto", # Automatically maps model layers to available devices
            torch_dtype=torch.float16, # Match compute_dtype if using 4-bit
            attn_implementation="sdpa", # Use SDPA for better performance
            trust_remote_code=True
        )
        self.model.eval() # Set the model to evaluation mode here.
        print(self.model.hf_device_map)

    def generate_response(self, user_input):  
        inputs = self._get_inputs(user_input)
        # Disable gradient calculation during inference
        # Generate the response using the model
        with torch.no_grad(): 
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
                do_sample=False,  # CRUCIAL: Always enable sampling as per model card
                # num_beams=1,     # Keep at 1 for standard sampling (not beam search)
                # temperature=0.6 if self.enable_thinking else 0.7,
                # top_p=0.95 if self.enable_thinking else 0.8,
                # top_k=20,
                # min_p=0,
            )

        # Get the length of the input tokens
        input_len = inputs['input_ids'].shape[1]
        
        # Slice generated_ids to get only the newly generated tokens
        # generated_ids[0] is the first (and only) sequence in the batch
        # [input_len:] slices from the end of the input tokens onwards
        generated_only_ids = generated_ids[0][input_len:]
        
        # Decode only the generated part (including special tokens like <think>)
        decoded_generated_text = self.tokenizer.decode(generated_only_ids, skip_special_tokens=False)
        
        print("\n--- RAW DECODED OUTPUT (GENERATED ONLY) ---")
        print(decoded_generated_text)
        print("--------------------------------------------\n")

        # Pass only the generated text to the parsing method
        return self._parse_response(inputs, generated_ids)
    
    def _get_inputs(self, user_input):
        """Prepare the input for the model based on user input."""
        # Trim the user input to a maximum length for better performance
        user_input = user_input[:self.max_input_length]  # Limit input length to 4096 characters
        # Create the messages for the chat template
        messages = [
            {"role": "system", "content": self.sys_prompt},
            {"role": "user", "content": user_input}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking
        )
        return self.tokenizer(text, return_tensors="pt").to(self.model.device)
    
    def _parse_response(self, inputs, generated_ids):
        # Extract the output IDs from the generated IDs
        output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
        try:
            index = len(output_ids) - output_ids[::-1].index(151668)
        except ValueError:
            index = 0

        thinking_content = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
        raw_response = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
        response = self._parse_json(raw_response)
        print(f"Parsed response: {response}")
        return response, thinking_content
    
    def _parse_json(self, raw_response: str) -> list[dict[str,str]]:
        # Remove code block markers and leading/trailing whitespace
        cleaned = raw_response.strip()
        if cleaned.startswith("```json"):
            cleaned = cleaned[len("```json"):].strip()
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3].strip()

        # Now parse as JSON
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError as e:
            return []        
        


## Define the System prompt

In [33]:
# Define the one-shot reasoning and task prompt
# This prompt is designed to guide the model through a structured reasoning process
SYS_PROMPT = """
You are an advanced AI research assistant specialized in identifying and classifying datasets used within academic research papers.
Your primary goal is to accurately extract and categorize dataset identifiers (dataset_ids) from provided paper sections.

---

### Input Data Structure

You will receive a JSON string representing key sections of an academic paper, structured as follows:

```json
{
    "title": "Title of the paper",
    "author": "The primary author of the paper, e.g., 'Author A'",
    "abstract": "Abstract of the paper",
    "data_availability": "Data availability information",
    "other_dataset_citations": [
        {"dataset_ids": ["10.12345/12345"], "citation_context": "Dataset citation context 1"},
        {"dataset_ids": ["10.1234/xxxx.1x1x-xx11", "EPI_ISL_12345678"], "citation_context": "Dataset citation context 2"},
        ...
    ]
}
```

**Guidance on Input Sections:**
*   **`title`**: Provides general context for the paper's topic.
*   **`author`**: The primary author of the paper. This is crucial for determining if the dataset's *raw data* was *originally generated by this author*.
*   **`abstract`**: **CRITICAL** for understanding the research scope and, most importantly, for determining if a dataset's *raw data* was *originally generated by the authors of *this* paper*.
*   **`data_availability`**: This section provides information on datasets. Its content must be evaluated to determine if the data was *originally generated by the authors of this paper* (Primary) or *acquired from an existing source* (Secondary).
*   **`other_dataset_citations`**: A list of potential dataset citations. The `citation_context` is vital to confirm if a `dataset_id` truly refers to a dataset and to aid in classifying its origin (Primary or Secondary).

---

### Core Objective & Critical Exclusion

Your overarching objective is to identify and classify **only valid, data-related `dataset_id`s**.

**CRITICAL EXCLUSION**: You **MUST NOT** extract any `dataset_id`s that refer to other academic papers, articles, or the paper itself. Focus strictly on identifiers for *datasets* only found within the `abstract`, `data_availability` and `other_dataset_citations` sections. DO NOT make up any dataset_ids.

---

### Key Definitions

*   **`dataset_id`**: A unique, persistent identifier for a dataset. There are two main types:

    1.  **Digital Object Identifier (DOI)**:
        *   DOI's are used to identify both academic papers/articles and datasets. Your goal is to find DOI's that are related to data/datasets NOT to papers/articles.
        *   **Format**: `[prefix]/[suffix]`. The prefix always starts with "10." and is followed by a 4 or 5 digit number. The suffix can contain letters, numbers, and special characters.
        *   DO NOT look for Accession IDs within the suffix of a DOI.
        *   May or may not start with "https://doi.org/" or "doi:".
        *   **NOTE ON EXTRACTION**: A DOI may appear as a standalone string (e.g., `10.1234/abc`), or embedded within a URL (e.g., `https://doi.org/10.1234/abc`), or within a sentence. In all cases, **extract the `10.xxxx/yyyy` string as a potential `dataset_id`** and proceed with validation rules.
        *   **IMPORTANT DOI VALIDATION RULE**:
            *   A DOI is a `dataset_id` ONLY if the surrounding `citation_context` or `data_availability` section clearly indicates it refers to a dataset, data repository, data archive, or similar data-specific entity.
            *   Only identify DOIs that are explicitly used as `dataset_id`s.
            *   **DO NOT extract DOIs for academic papers/articles.**
            *   **If a DOI is presented as a reference to a paper, article, or publication (e.g., "as described in [DOI]", "cited in [DOI]", "see [DOI] for details on the method"), it is NOT a dataset_id.**

    2.  **Accession ID**:
        *   DO NOT look for Accession IDs within a DOI. If a dataset_id has a DOI format, no portion of the DOI should be identified as an Accession ID.
        *   They always start with two or more alpha characters, including underscores ("_"), followed by three or more digits.
        *   These identifiers are often used in biological databases, chemical databases, or other scientific data repositories.
        *   *Examples*: `"EPI_ISL_12345678"` (e.g., a genomic sequence dataset), `"IPR000264"` (e.g., a protein family identifier), `"SAMN07159041"` (e.g., a biological sample record), `"CHEMBL1782574"` (e.g., a chemical compound entry)

*   **Distinction: Data Source vs. Specific Data Product (dataset_id):**
    *   A **Data Source** is a repository or collection where data is stored (e.g., The Cancer Imaging Archive).
    *   A **Specific Data Product** is the actual dataset or data item being referenced by a `dataset_id` (e.g., "images", "scans", "segmentations", "raw data", etc.).
    *   **CRITICAL**: The classification (Primary/Secondary) applies to the *Specific Data Product* associated with the `dataset_id`, not the general data source it might have come from.

*   **Dataset Type Classification**: **This is the MOST CRITICAL distinction. For each `dataset_id`, ask: Was the *specific data product* it refers to *CREATED* by the authors of *this paper*?**

    *   **Primary**: The `dataset_id` refers to **NEW DATA** (e.g., measurements, annotations, segmentations, etc.) that was **ORIGINALLY GENERATED, COLLECTED, PROCESSED, or CREATED by the AUTHORS OF *THIS SPECIFIC PAPER*** as a direct output of their research. This is the *novel contribution* of the paper.
        *   *Keywords for Primary*: "generated", "sequenced", "collected", "created", "produced", "developed", "annotated", "our data".
        *   **CRITICAL CLARIFICATION**: If a `dataset_id` refers to a novel data product (like annotations or segmentations) that was *created by the authors* (Primary) but was *derived from* or *applied to* existing, external data (Secondary input), the `dataset_id` for the *novel data product* is still **Primary**. The act of making this novel data product publicly available does not change its origin.
        *   *Example*: For an example DOI of `10.7937/tcia.2020.6c7y-gq39`, if the context states "a dataset of thoracic cavity segmentations and discrete pleural effusion segmentations **we have annotated**... All expert-vetted segmentations are publicly available... at https://doi.org/10.7937/tcia.2020.6c7y-gq39". The "we have annotated" indicates original creation of the *segmentations* by the authors, making this **Primary**.

    *   **Secondary**: The `dataset_id` refers to **EXISTING DATA** that was **ACQUIRED, DERIVED, USED, REUSED, or RE-ANALYZED from EXISTING RECORDS or PREVIOUSLY PUBLISHED DATASETS** and that was *not originally generated by the authors of this specific paper*. This is *input data* that the authors *used*, but did not create.
        *   *Keywords for Secondary*: "previously published", "existing", "external", "re-analyzed", "obtained from", "acquired from", "derived from", "sourced from", "data from [external source]".
        *   **IMPORTANT**: If the data was *not created by the authors of this paper*, it is **Secondary**.
        *   *Example*: For an example DOI of `10.7937/K9/TCIA.2015.PF0M9REI`, if the context states "CT scans **acquired from The Cancer Imaging Archive 'NSCLC Radiomics' data collection**". This indicates the raw CT scans were acquired and used by the authors, but not created by them, making this **Secondary**.


---

### Classification Logic Flow (for each identified `dataset_id`):

To classify a `dataset_id` as Primary or Secondary, follow these steps strictly:

1.  **For the specific `dataset_id` being evaluated, identify the *data product* it refers to** by examining the `abstract` and associated `citation_context`. (e.g., "images", "scans", "segmentations", "genomic sequences", "raw data", etc.).
2.  **STEP 1: Is this *specific data product* a NEW CREATION by the authors of *this paper*?**
    *   Look for phrases like "we have annotated", "generated by us", "created by the authors", "our data", "produced in this study", or descriptions of *original data collection/creation* by the authors.
        *   **IF YES**: Classify as **Primary**.
        *   *Example*: If the abstract states things like "a dataset of thoracic cavity segmentations and discrete pleural effusion segmentations **we have annotated**... All expert-vetted segmentations are publicly available at ...". The "we have annotated" indicates original creation of the *segmentations* by the authors, making this **Primary**.
        *   **REMEMBER**: Even if this *Primary* data product was derived from *Secondary* input data, it is still **Primary** because the *data product itself* is a novel creation of these authors.
3.  **STEP 2: Is this *specific data product* EXISTING DATA ACQUIRED from an *external source*?**
    *   If the `dataset_id` was *not* classified as Primary in Step 1, then look for phrases like "acquired from", "obtained from", "derived from", or "previously published".
        *   **IF YES**: Classify as **Secondary**.
        *   *Example*: If the abstract states things like "scans **acquired from** The Cancer Imaging Archive 'NSCLC Radiomics' data collection". The **acquired from** indicates the raw CT scans were acquired and used by the authors, but not created by them, making this **Secondary**.
4.  **STEP 3: Fallback Rule:**
    *   If, after applying Step 1 and Step 2, the origin of the *specific data product* remains truly ambiguous, then default to "Primary".

---

### Tasks: Step-by-Step Instructions

Follow these three tasks in order:

**SHORT-CIRCUIT RULE:**
**IF** the `data_availability` section is an empty string (`""`) **AND** the `other_dataset_citations` section is an empty list (`[]`), **THEN** immediately **skip all other tasks** and proceed directly to **Task 3** to output the "Missing" JSON structure. Do not perform any further analysis or reasoning.

**Task 1: Identify Valid Dataset IDs**

1.  **Search Priority**:
    *   **IF** `data_availability` is NOT an empty string (`""`), search its text first.
    *   **THEN**, **IF** `other_dataset_citations` is NOT an empty list (`[]`), search its text.
    *   **IMPORTANT**: If `data_availability` is empty, proceed directly to search `other_dataset_citations`.
2.  **Validation and Extraction**: For each potential `dataset_id` (DOI or Accession ID) found *within the text* of `data_availability` or `citation_context`, confirm it is truly data-related and **extract the identifier string**.
    *   **For DOIs**: Strictly apply the **IMPORTANT DOI VALIDATION RULE** defined above. If it refers to a publication, **DO NOT** extract it.
    *   **For all IDs**: Look for surrounding terms like "data release", "data availability", "dataset", "database", "repository", "data source", "data access", or "data archive" within the `data_availability` section or the `citation_context`.
    *   *Example of Extraction from `data_availability`*: If `data_availability` contains "Data are available at Dryad Digital Repository at: https://doi.org/10.5061/dryad.zw3r22854 . ...", then `10.5061/dryad.zw3r22854` is a valid `dataset_id` to extract.
3.  **Deduplication**: If the same `dataset_id` is found multiple times, **only process the first instance encountered**.
4.  **Conditional Proceeding**:
    *   If **no valid `dataset_id`s are found** after searching both sections, **skip directly to Task 3** and output the "Missing" JSON structure.
    *   If one or more valid `dataset_id`s are found, proceed to Task 2.

**Task 2: Classify Dataset Types**

1.  For each valid `dataset_id` identified in Task 1, classify its type as either "Primary" or "Secondary".
2.  **STRICTLY APPLY THE "CLASSIFICATION LOGIC FLOW" ABOVE for each `dataset_id`.** Use the `abstract` section and the `citation_context` to determine if the *specific data product* associated with the `dataset_id` was *originally generated by the authors of *this* paper* (Primary) or *acquired/reused from an existing source* (Secondary).
3.  Apply the "Key Definitions" for Primary and Secondary types, paying close attention to the associated keywords and the provided examples.
4.  Remember the "Fallback Rule": Default to "Primary" if the classification remains truly ambiguous regarding the *original generation* of the raw data.

**Task 3: Format and Return Results**

**CRITICAL: Your entire response MUST ONLY be the JSON array. Do NOT include any conversational text, explanations, reasoning steps, or internal thoughts (like <think> tags).**

Return your final results as a JSON array of objects.

1.  **Scenario A: No Valid Datasets Found**
    If Task 1 resulted in no valid `dataset_id`s, return a single JSON object with the following structure:
    ```json
    [
        {
            "dataset_id": "Missing",
            "type": "Missing"
        }
    ]
    ```
2.  **Scenario B: One or More Valid Datasets Found**
    If Task 1 identified one or more valid `dataset_id`s, return every valid dataset found in a JSON array of objects, where each object has the following structure:
    ```json
    [
        {
            "dataset_id": "example_id_1",
            "type": "Primary"
        },
        {
            "dataset_id": "example_id_2",
            "type": "Secondary"
        },
        ...
    ]
    ```
"""

## Instantiate the Qwen Model

In [34]:
# Instantiate the QwenModelEval class with the model path and system prompt
inference_model = QwenModelEval(MODEL_PATH, sys_prompt=SYS_PROMPT, enable_thinking=False, max_new_tokens=100, max_input_length=8200)

Loading Qwen model and tokenizer from: C:\Users\jim\.cache\kagglehub\models\qwen-lm\qwen-3\transformers\1.7b\1


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

RuntimeError: Tensor.item() cannot be called on meta tensors

## Evaluate the articles

In [26]:
def evaluate_articles(file_paths_df: pd.DataFrame, model) -> pd.DataFrame:
    results = []
    # tqdm(os.listdir(pdf_directory), total=len(os.listdir(pdf_directory)))
    for i, row in tqdm(file_paths_df.iterrows(), total=len(file_paths_df)):
        article_id = row['article_id']
        text_type = "XML"
        article_dict = extract_article_dict(row['xml_file_path'])
        if not article_dict['data_availability'] and not article_dict['other_dataset_citations']:
            text_type = "PDF"
            article_dict = extract_article_dict(row['pdf_file_path'])

        user_input = ""
        response = [{'dataset_id': 'Missing', 'type': 'Missing'}]
        thinking_content = ""
        
        # Only process articles that have data_availability and/or other_dataset_citations
        if article_dict['data_availability'] or article_dict['other_dataset_citations']:
            # Prepare the user input for the model
            user_input = f"Text Content: {extract_article_text(article_dict)}\n"
            print(f"Processing article {i}/{len(file_paths_df)}: {article_id}, type: {text_type}, input: {len(user_input)}")
            # Generate response from the model
            response, thinking_content = model.generate_response(user_input)

        results.append({
            'article_id': article_id,
            'text_type': text_type,
            'llm_input': user_input,
            'llm_response': response,
            'llm_thinking_content': thinking_content
        })

    return pd.DataFrame(results).sort_values(by=["article_id"]).reset_index(drop=True)


In [19]:
# # Load the file paths DataFrame from the CSV file
# file_paths_df = pd.read_csv(os.path.join(BASE_OUTPUT_DIR, 'file_paths.csv'))
# # Fill NaN values in the 'xml_type' and 'xml_text' columns with empty strings
# file_paths_df['xml_file_path'] = file_paths_df['xml_file_path'].fillna('')
# file_paths_df['xml_text'] = file_paths_df['xml_text'].fillna('')
# Display the first few rows of the file paths DataFrame
print(f"File paths DataFrame shape: {file_paths_df.shape}")
display(file_paths_df.head(3))


File paths DataFrame shape: (30, 5)


Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,xml_type
0,10.1002_2017jc013030,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,tei
1,10.1002_anie.201916483,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,jats
2,10.1002_anie.202005531,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,bioc


In [None]:
# sample_file_paths_df = file_paths_df.copy().sample(3, random_state=42).reset_index(drop=True)
# sample_file_paths_df

In [31]:
processed_articles_df = evaluate_articles(file_paths_df, inference_model)
# Save processed_articles_df to CSV
processed_articles_df.to_csv(os.path.join(BASE_OUTPUT_DIR, 'sample_evaluated_articles.csv'), index=False)
print(f"Processed articles DataFrame shape: {processed_articles_df.shape}")
processed_articles_df

  0%|          | 0/30 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processing article 0/30: 10.1002_2017jc013030, type: XML, input: 1441

--- RAW DECODED OUTPUT (GENERATED ONLY) ---
<think>
Okay, let's tackle this problem step by step. The user provided a JSON structure with a paper's details, and I need to extract and classify dataset IDs based on the given instructions. 

First, I'll check the input data. The title is about assessing variability in the relationship between particulate backscattering coefficient and chlorophyll a concentration using a global Biogeochemical-Argo database. The author is Marie Barbieux. The abstract mentions the BGC-Argo
--------------------------------------------

No JSON array found in generated output.
Parsed response: []
Extracted thinking content: Okay, let's tackle this problem step by step. The user provided a JSON structure with a paper's details, and I need to extract and classify dataset IDs based on the given instructions. 

First, I'll check the input data. The title is about assessing variability in the re

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processing article 1/30: 10.1002_anie.201916483, type: PDF, input: 1340


KeyboardInterrupt: 

In [30]:
def format_dataset_id(dataset_id: str) -> str:
    """
    Formats the dataset_id by removing any leading/trailing whitespace and ensuring it is a string.
    
    Args:
        dataset_id (str): The dataset identifier to format.
        
    Returns:
        str: The formatted dataset identifier.
    """
    if dataset_id and dataset_id.startswith("10.") and len(dataset_id) > 10:
        # If the dataset_id starts with "10." and is longer than 10 characters, it's likely a DOI
        dataset_id = "https://doi.org/" + dataset_id.lower().strip()
    return dataset_id

# Create a DataFrame to hold the evaluation results by expaning the 'llm_response' column
def expand_evaluation_results(df: pd.DataFrame) -> pd.DataFrame:
    """
    Expands the evaluation results DataFrame by extracting dataset_id and type from the 'llm_response' column.
    
    Args:
        df (pd.DataFrame): The DataFrame containing evaluation results.
        
    Returns:
        pd.DataFrame: A new DataFrame with expanded dataset_id and type columns.
    """
    expanded_rows = []
    for _, row in df.iterrows():
        article_id = row['article_id']
        article_doi = article_id.replace('_', '/')
        datasets = row['llm_response']
        missing_dataset = {
            'article_id': article_id,
            'dataset_id': 'Missing',
            'type': 'Missing',
        }

        if datasets:
            for dataset in datasets:
                dataset_id = dataset.get('dataset_id', 'Missing')
                # Skip if the dataset_id is the same as the article DOI
                if dataset_id == article_doi:
                    # If the dataset_id is the same as the article DOI add it as Missing
                    expanded_rows.append(missing_dataset)
                else:
                    expanded_rows.append({
                        'article_id': article_id,
                        'dataset_id': dataset.get('dataset_id', 'Missing'),
                        'type': dataset.get('type', 'Missing'),
                    })
        else:
            # If no datasets were found, add a row with 'Missing' values
            expanded_rows.append(missing_dataset)
    
    # Create a DataFrame from the expanded rows
    expanded_df = pd.DataFrame(expanded_rows)
    expanded_df['dataset_id'] = expanded_df['dataset_id'].apply(format_dataset_id)  # Format dataset_id
    expanded_df['type'] = expanded_df['type'].str.strip().str.capitalize()  # Ensure type is capitalized and stripped of whitespace
    expanded_df = expanded_df.sort_values(by=["article_id", "dataset_id", "type"], ascending=True).drop_duplicates(subset=['article_id', 'dataset_id'], keep="first").reset_index(drop=True)
    
    return expanded_df

def prepare_for_submission(expanded_df: pd.DataFrame) -> pd.DataFrame:
    """
    Prepares the expanded DataFrame for submission by ensuring the correct columns and formatting.
    
    Args:
        expanded_df (pd.DataFrame): The DataFrame containing expanded dataset information.
        
    Returns:
        pd.DataFrame: A DataFrame ready for submission with 'article_id', 'dataset_id', and 'type' columns.
    """
    # Ensure the DataFrame has the correct columns
    submission_df = expanded_df[['article_id', 'dataset_id', 'type']].copy()
    # Rename columns to match the expected format
    submission_df.columns = ['article_id', 'dataset_id', 'type']

    # Remove rows where type is 'Missing' and reset index
    submission_df = submission_df[submission_df['type'] != 'Missing'].reset_index(drop=True)
    submission_df['row_id'] = range(len(submission_df))

    # Reorder columns to match the submission format
    submission_df = submission_df[['row_id', 'article_id', 'dataset_id', 'type']]
    
    return submission_df


In [31]:
expanded_df = expand_evaluation_results(processed_articles_df)
expanded_df.to_csv(os.path.join(BASE_OUTPUT_DIR, 'expanded_eval_results.csv'), index=False)
print(f"Expanded Eval DataFrame shape: {expanded_df.shape}")
display(expanded_df.head(10))

Expanded Eval DataFrame shape: (40, 3)


Unnamed: 0,article_id,dataset_id,type
0,10.1002_2017jc013030,https://doi.org/10.17882/47142,Primary
1,10.1002_2017jc013030,https://doi.org/10.17882/49388,Primary
2,10.1002_2017jc013030,https://doi.org/10.5194/essd-2017-58,Primary
3,10.1002_anie.201916483,Missing,Missing
4,10.1002_anie.202005531,Missing,Missing
5,10.1002_anie.202007717,Missing,Missing
6,10.1002_chem.201902131,Missing,Missing
7,10.1002_chem.201903120,Missing,Missing
8,10.1002_chem.202000235,Missing,Missing
9,10.1002_chem.202001412,Missing,Missing


In [32]:

submission_df = prepare_for_submission(expanded_df)
submission_df.to_csv(SUBMISSION_CSV_PATH, index=False)

submission_df["type"].value_counts()

type
Secondary    17
Primary       8
Name: count, dtype: int64

In [34]:
def f1_score(tp, fp, fn):
    return 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) != 0 else 0.0
    
    
# if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
pred_df = submission_df.copy()
label_df = pd.read_csv("./kaggle/input/make-data-count-finding-data-references/sample_submission.csv")
label_df = label_df[label_df['type'] != 'Missing'].reset_index(drop=True)

hits_df = label_df.merge(pred_df, on=["article_id", "dataset_id", "type"])

tp = hits_df.shape[0]
fp = pred_df.shape[0] - tp
fn = label_df.shape[0] - tp

print("TP:", tp)
print("FP:", fp)
print("FN:", fn)
print("F1 Score:", round(f1_score(tp, fp, fn), 3))

TP: 5
FP: 20
FN: 9
F1 Score: 0.256
