In [126]:
#!pip install lxml

In [127]:
# --- 0. Environment Setup & Offline Preparation ---

# Standard Imports
import os
import glob
import re
import pandas as pd
import lxml.etree as etree
from lxml.etree import _Element as Element # Type hinting for lxml.etree.Element
import collections # For deque in parenthesis removal
import fitz # PyMuPDF for PDF processing
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils.quantization_config import BitsAndBytesConfig
from transformers.training_args import TrainingArguments
from trl import SFTTrainer
import torch
from datasets import Dataset # Hugging Face datasets library
import kagglehub
import spacy
import json

# Set device for PyTorch
device = "cuda" if torch and torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cpu


In [128]:
# Uncomment and run this line once to download the model
#!python -m spacy download en_core_web_sm

In [129]:
# Define constants for file paths and model configurations
BASE_INPUT_DIR = './kaggle/input/make-data-count-finding-data-references'
ARTICLE_TRAIN_DIR = os.path.join(BASE_INPUT_DIR, 'train')
ARTICLE_TEST_DIR = os.path.join(BASE_INPUT_DIR, 'test')

# Define directories for articles in train and test sets
LABELED_TRAINING_DATA_CSV_PATH = os.path.join(BASE_INPUT_DIR, 'train_labels.csv')

# Define the base model path
QWEN_BASE_MODEL_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")

# Output directory for the fine-tuned model and results
BASE_OUTPUT_DIR = "./kaggle/working"
FINE_TUNED_MODEL_OUTPUT_DIR = os.path.join(BASE_OUTPUT_DIR, "qwen_finetuned_dataset_classifier")
FINAL_RESULTS_CSV_PATH = os.path.join(BASE_OUTPUT_DIR, "article_dataset_classification.csv")

# Load a spaCy model (e.g., 'en_core_web_sm')
# python -m spacy download en_core_web_sm 
NLP_SPACY = spacy.load("en_core_web_sm")


In [130]:
# --- 2. Information Extraction (IE) - Dataset Identification ---

# Regex patterns for common dataset identifiers
# DOI_PATTERN = r'10\.\d{4,5}/[-._;()/:A-Za-z0-9\u002D\u2010\u2011\u2012\u2013\u2014\u2015]+'	DOI_PATTERN
# DOI_PATTERN = r'10\.\s?\d{4,5}\/[-._()<>;\/:A-Za-z0-9]+\s?(?:(?![A-Z]+)(?!\d{1,3}\.))+[-._()<>;\/:A-Za-z0-9]+'
#DOI_PATTERN = r'\bhttps://doi.org/10\.\d{4,5}\/[-._\/:A-Za-z0-9]+'
DOI_PATTERN = r'\b10\.\d{4,5}\/[-._\/:A-Za-z0-9]+'
EPI_PATTERN = r'\bEPI[-_A-Z0-9]{2,}'
SAM_PATTERN = r'\bSAMN[0-9]{2,}'          # SAMN07159041
IPR_PATTERN = r'\bIPR[0-9]{2,}'
CHE_PATTERN = r'\bCHEMBL[0-9]{2,}'
PRJ_PATTERN = r'\bPRJ[A-Z0-9]{2,}'
E_G_PATTERN = r'\bE-[A-Z]{4}-[0-9]{2,}'   # E-GEOD-19722 or E-PROT-100
ENS_PATTERN = r'\bENS[A-Z]{4}[0-9]{2,}'
CVC_PATTERN = r'\bCVCL_[A-Z0-9]{2,}'
EMP_PATTERN = r'\bEMPIAR-[0-9]{2,}'
PXD_PATTERN = r'\bPXD[0-9]{2,}'
HPA_PATTERN = r'\bHPA[0-9]{2,}'
SRR_PATTERN = r'\bSRR[0-9]{2,}'
GSE_PATTERN = r'\b(GSE|GSM|GDS|GPL)\d{4,6}\b' # Example for GEO accession numbers (e.g., GSE12345, GSM12345)
GNB_PATTERN = r'\b[A-Z]{1,2}\d{5,6}\b' # GenBank accession numbers (e.g., AB123456, AF000001)
CAB_PATTERN = r'\bCAB[0-9]{2,}'

# Combine all patterns into a list
DATASET_ID_PATTERNS = [
    DOI_PATTERN,
    EPI_PATTERN,
    SAM_PATTERN,
    IPR_PATTERN,
    CHE_PATTERN,
    PRJ_PATTERN,
    E_G_PATTERN,
    ENS_PATTERN,
    CVC_PATTERN,
    EMP_PATTERN,
    PXD_PATTERN,
    HPA_PATTERN,
    SRR_PATTERN,
    GSE_PATTERN,
    GNB_PATTERN,
    CAB_PATTERN,
]

# Compile all patterns for efficiency
COMPILED_DATASET_ID_REGEXES = [re.compile(p) for p in DATASET_ID_PATTERNS]

# Data related keywords to look for in the text
# These keywords help to ensure that the text is relevant to datasets
DATA_RELATED_KEYWORDS = ['data release', 'download', 'program data', 'data availability', 'the data', 'dataset', 'database', 'repository', 'data source', 'data access', 'archive', 'arch.', 'digital']

def is_text_data_related(text: str) -> bool:
    if not text:
        return False
    
    text_lower = text.lower()
    return any(keyword in text_lower for keyword in DATA_RELATED_KEYWORDS)

def text_has_dataset_id(text: str) -> bool:
    """
    Check if the given text contains any dataset identifier.
    
    Args:
        text (str): The text to check for dataset identifiers.
        
    Returns:
        bool: True if any dataset identifier is found, False otherwise.
    """

    occurrences_with_context: list[str] = []
    for regex in COMPILED_DATASET_ID_REGEXES:
        if regex.search(text):
            text_lower = text.lower()
            # Check for specific keywords in the text
            if any(keyword in text_lower for keyword in DATA_RELATED_KEYWORDS):
                return True
    return False

def extract_dataset_ids(text: str, context_chars: int = 250) -> str:
    """
    Extract dataset identifiers with context from the given text.
    
    Args:
        text (str): The text to search for dataset identifiers.
        context_chars (int): Number of characters to include before and after the match for context.
        
    Returns:
        list[str]: A list of extracted dataset identifiers with context.
    """
    text = text.replace('\u200b', '')
    is_small_context = len(text) < context_chars * 2
    dataset_ids: list[str] = []
    occurrences_with_context: list[str] = []
    if is_text_data_related(text):
        for regex in COMPILED_DATASET_ID_REGEXES:
            matches = regex.finditer(text, re.IGNORECASE)
            for match in matches:
                dataset_id = text[match.start() : match.end()]
                if is_small_context:
                    dataset_ids.append(dataset_id)
                else:
                    extracted_snippet = text[max(0, match.start() - context_chars): match.end() + context_chars ]
                    if is_text_data_related(extracted_snippet):
                        occurrences_with_context.append("{" + f'"dataset_ids": {[dataset_id]}, context: "{extracted_snippet}"' + "}")
        if dataset_ids:
            occurrences_with_context.append("{" + f'"dataset_ids": {dataset_ids}, context: "{text}"' + "}")
    
    # If no occurrences found, return an empty string
    # Otherwise, join the occurrences with a specific separator
    if not occurrences_with_context:
        return ""
    return ",".join(occurrences_with_context)

In [131]:
# Use NLP to get sentences from the given text
def get_sentences_from_text(text: str, nlp=NLP_SPACY) -> str:
    if not text:
        return ""
    doc_spacy = nlp(text)
    return "\n".join([sent.text.replace('-\n', '-').replace('_\n', '_').replace('/\n', '/').replace('\n', ' ') for sent in doc_spacy.sents])

In [175]:

def extract_element_text(element: Element | None) -> str:
    if element is not None:
        # Use itertext() to get all text content from the <p> tag and its descendants
        # and join them into a single string.
        all_text = " ".join(element.itertext(tag=None)).replace('\u200b', '').strip()
        return all_text[:2000]
    else:
        return ""
    
def extract_next_sibling_text(elements: list[Element] | None, sibling_xpath: str) -> str:
    """
    Extracts text from the next sibling of the given XML element.
    
    Args:
        element (Element | None): The XML element whose next sibling's text is to be extracted.
        sibling_xpath (str): The XPath expression to find the next sibling element. (eg. "following-sibling::passage[1]")
        
    Returns:
        str: A string containing the text from the next sibling element, or an empty string if no sibling exists.

    """
    # Check if the provided elements list is None or empty
    if not elements:
        return ""
    
    # Assuming there's only one such element, take the first one found
    # and find the element immediately following based on the given sibling_xpath.
    first_element = elements[0]
    sibling_elements = first_element.xpath(sibling_xpath)

    if not sibling_elements:
        # print("DEBUG: No following <passage> element found.") # Uncomment for debugging
        return ""
    
    next_sibling = sibling_elements[0]
    if next_sibling is None:
        return ""
    
    return extract_element_text(next_sibling)

def extract_elements_text(elements: list[Element] | None) -> str:
    elements_text = []
    if elements is None:
        return ""
    
    for element in elements:
        text = extract_element_text(element)
        if text:
            elements_text.append(text)

    return " ".join(elements_text).strip()

def extract_text_from_elements_within_element(element: Element | None, child_xpaths: list[str] = [], ns: dict[str, str] | None = None) -> str:
    """
    Extracts text from elements within a given XML element that match the specified tag names.
    
    Args:
        element (Element | None): The XML element to search within.
        tag_names (list[str]): A list of tag names to search for.
        
    Returns:
        str: A string containing the extracted text from the matching elements.
    """
    if element is None:
        return ""
    
    if not child_xpaths:
        # If no child tag names are provided, return the text of the element itself
        return extract_element_text(element)
    
    extracted_text = []
    for xpath in child_xpaths:
        for child in element.findall(xpath, namespaces=ns):
            text = extract_element_text(child)
            if text:
                extracted_text.append(text)
    
    return "|".join(extracted_text)

def extract_data_related_elements_text(elements: list[Element] | None, child_xpaths: list[str] = [], ns: dict[str, str] | None = None) -> list[str]:
    elements_text = []
    if elements is None:
        return elements_text
    
    for element in elements:
        text = extract_dataset_ids(extract_text_from_elements_within_element(element, child_xpaths, ns))
        if text:
            elements_text.append(text)

    return elements_text

def extract_data_related_elements_text_from_xpath_list(root: Element | None, xpath_list: list[str], ns: dict[str, str] | None = None) -> list[str]:
    """
    Extracts text from elements in the XML tree that match the provided XPath expressions.
    
    Args:
        root (Element | None): The root element of the XML tree.
        xpath_list (list[str]): A list of XPath expressions to search for elements.
        
    Returns:
        list[str]: A list of extracted text from the matching elements.
    """
    elements_text = []
    if root is None or not xpath_list:
        return elements_text
    
    for xpath in xpath_list:
        primary_xpath, *child_xpath_text = xpath.split('||')
        child_xpaths = child_xpath_text[0].split(',') if child_xpath_text else []
        elements = root.findall(primary_xpath, namespaces=ns)
        if elements:
            elements_text.extend(extract_data_related_elements_text(elements, child_xpaths, ns))
    return elements_text

In [133]:

def extract_xml_text_jats(root: Element) -> dict[str, str | list[str]]:
    # Find the title, abstract, and data availablity info for Journal Archiving and Interchange DTD (JATS)
    # The ".//" ensures it searches anywhere in the document, not just direct children of root.
    ns = None  # No namespaces for JATS

    xpath_title = ".//article-title"
    xpath_abstract = ".//abstract"
    xpath_data_avail = ".//notes[@notes-type='data-availability']"
    xpath_citations = [".//element-citation||.article-title,.source,.pub-id", ".//mixed-citation"]  # List of XPath expressions for citations
        
    return {
        'title': extract_element_text(root.find(xpath_title, ns)),
        'abstract': extract_element_text(root.find(xpath_abstract, ns)),
        'data_availability': extract_element_text(root.find(xpath_data_avail, ns)),
        'other_data_citations': extract_data_related_elements_text_from_xpath_list(root, xpath_citations, ns=ns),
    }


In [134]:
def extract_xml_text_tei(root: Element) -> dict[str, str | list[str]]:
    # Find the title, abstract, and data availability info for Text Encoding Initiative (TEI)
    # Set the namespace for TEI
    ns = {'tei': 'http://www.tei-c.org/ns/1.0'}

    xpath_title = ".//tei:title"
    xpath_abstract = ".//tei:abstract"
    xpath_data_avail = "" #".//tei:biblstruct"
    xpath_citations = [".//tei:biblstruct||.//tei:title,.//tei:idno,.//tei:notes"]  # List of XPath expressions for citations
        
    return {
        'title': extract_element_text(root.find(xpath_title, namespaces=ns)),
        'abstract': extract_element_text(root.find(xpath_abstract, namespaces=ns)),
        'data_availability': xpath_data_avail,  # No direct extraction for TEI data_availability
        'other_dataset_citations': extract_data_related_elements_text_from_xpath_list(root, xpath_citations, ns=ns),
    }


In [None]:
def extract_xml_text_wiley(root: Element) -> dict[str, str | list[str]]:
    # Find the title, abstract, and data availability info for Wiley XML format
    # Set the namespace for Wiley
    ns = {'ns': 'http://www.wiley.com/namespaces/wiley'}

    xpath_title = ".//ns:publicationMeta[@level='part']/ns:titleGroup"    #<publicationMeta level="part"><titleGroup><title type="main">
    xpath_abstract = ".//ns:abstract[@type='main']"  #<abstract type="main"
    xpath_data_avail = ".//ns:section[@type='dataAvailability']"  #<section numbered="no" type="dataAvailability"
    xpath_citations = [".//ns:citation||.//ns:articleTitle,.//ns:journalTitle,.//ns:url"]  # List of XPath expressions for citations
        
    return {
        'title': extract_elements_text(root.findall(xpath_title, namespaces=ns)),
        'abstract': extract_element_text(root.find(xpath_abstract, namespaces=ns)),
        'data_availability': extract_element_text(root.find(xpath_data_avail, namespaces=ns)),
        'other_dataset_citations': extract_data_related_elements_text_from_xpath_list(root, xpath_citations, ns=ns),
    }

In [None]:
# passages = root.xpath(
#     ".//passage[infon[@key='section_type' and text()='TABLE']]"
# )
# text_elements = root.xpath(".//passage[infon[@key='section_type' and text()='TITLE']]/text")

def extract_xml_text_bioc(root: Element) -> dict[str, str | list[str]]:
    # Find the title, abstract, and data availability info for BioC-API XML format
    ns = None  # No namespaces for BioC

    

    xpath_title = "string(.//passage[infon[@key='section_type' and text()='TITLE']]/text)"
    xpath_abstract = "string(.//passage[infon[@key='section_type' and text()='ABSTRACT']]/text)"
    #xpath_data_avail = ".//passage[infon[@key='section_type' and text()='SUPPL']]/text[text()='DATA ACCESSIBILITY:']"
    xpath_data_avail = ".//passage[text[text()='DATA ACCESSIBILITY:']]"
    xpath_data_avail_sibling = "following-sibling::passage[1]"
    xpath_citations = []
        
    return {
        'title': root.xpath(xpath_title, namespaces=ns),
        'abstract': root.xpath(xpath_abstract, namespaces=ns)[:2000],  # Limit to 2000 characters
        'data_availability': extract_next_sibling_text(root.xpath(xpath_data_avail, namespaces=ns), xpath_data_avail_sibling),
        'other_dataset_citations': xpath_citations,
    }

In [178]:
# Dictionary mapping XML types to their respective extraction functions
XML_TYPE_EXTRACTORS = {
    'jats': extract_xml_text_jats,
    'tei': extract_xml_text_tei,
    'wiley': extract_xml_text_wiley,
    'bioc': extract_xml_text_bioc,
}

# --- Data Loading ---
def load_file_paths(dataset_type_dir: str) -> pd.DataFrame: 
    pdf_path = os.path.join(dataset_type_dir, 'PDF')
    xml_path = os.path.join(dataset_type_dir, 'XML')
    dataset_type = os.path.basename(dataset_type_dir)
    pdf_files = [f for f in os.listdir(pdf_path) if f.endswith('.pdf')]
    xml_files = [f for f in os.listdir(xml_path) if f.endswith('.xml')]
    df_pdf = pd.DataFrame({
        'article_id': [f.replace('.pdf', '') for f in pdf_files],
        'pdf_file_path': [os.path.join(pdf_path, f) for f in pdf_files]
    })
    df_xml = pd.DataFrame({
        'article_id': [f.replace('.xml', '') for f in xml_files],
        'xml_file_path': [os.path.join(xml_path, f) for f in xml_files]
    })
    merge_df = pd.merge(df_pdf, df_xml, on='article_id', how='outer', suffixes=('_pdf', '_xml'), validate="one_to_many")
    merge_df['dataset_type'] = dataset_type
    return merge_df

def read_pdf_text(pdf_path: str, xml_type: str | None = None) -> str:
    """Extracts all text from a PDF file using PyMuPDF."""
    text = ""
    p1 = None
    abstrct_block_found = False
    num_blocks = 0
    if not fitz:
        return text  # Return empty string if fitz is not available
    try:
        with fitz.open(pdf_path) as doc:
            for page in doc:
                # Extract text from the page
                textpage = page.get_textpage()
                if page.number == 0:
                    p1 = get_sentences_from_text(textpage.extractTEXT())
                    p1 = p1[:int(len(p1)/2)] + ".\npotential_dataset_ids: ["

                # Extract text from all blocks that have dataset id's
                blocks = textpage.extractBLOCKS()
                for block in blocks:
                    if "dryad" in block[4]:
                        print(block[4])
                    block_text = get_sentences_from_text(block[4])
                    if page.number == 0 and len(block_text) > 100 and "abstract" in block_text.lower():
                        abstrct_block_found = True
                        text += block_text + ".\npotential_dataset_ids: ["
                    else:
                        context_chars = min(250, len(block_text))  # Use a minimum
                        dataset_ids_found = extract_dataset_ids(block_text, context_chars)  # Extract dataset IDs from the block text
                        if dataset_ids_found:
                            # Append the dataset IDs found in the block to the text
                            num_blocks += 1
                            text += dataset_ids_found + ","  # Ensure the block text is processed for dataset IDs
    except Exception as e:
        print(f"Error reading PDF {pdf_path}: {e}")

    if p1 and not abstrct_block_found:
        text = p1 + text

    print(f"Total blocks with data related dataset IDs found: {num_blocks}")
    print(f"Extracted text from {pdf_path}. Length: {len(text)} characters")
    return text + "]"

def read_xml_text(xml_file_path: str, xml_type: str | None = None) -> str:
    """Reads and concatenates all text content from an XML file."""
    try:
        tree = etree.parse(xml_file_path, etree.XMLParser())
        root = tree.getroot()
        if xml_type and xml_type in XML_TYPE_EXTRACTORS:
            # Use the specific extractor function for the given XML type
            extract_function = XML_TYPE_EXTRACTORS[xml_type]
        else:
            # Default to JATS if no specific type is provided
            extract_function = XML_TYPE_EXTRACTORS['jats']
        parsed_dict = extract_function(root)
        return json.dumps(parsed_dict)
    except Exception as e:
        print(f"Error reading XML {xml_file_path}: {e}")
        return ""

def process_unsupported_file(file_path: str, xml_type: str | None = None) -> str:
    return f"Unsupported file type for: {file_path}"

# Dictionary mapping file extensions to loading functions
FILE_LOADERS = {
    '.xml': read_xml_text,
    '.pdf': read_pdf_text,
}

def load_article_text(file_path: str, xml_type: str | None = None) -> str:
    """
    Loads text content from a single article file (PDF or XML).
    Returns the text content of the given file.
    """
    text_content = ""

    # Get the file extension (e.g., '.xml', '.pdf')
    _, file_extension = os.path.splitext(file_path)
    file_extension = file_extension.lower() # Ensure lowercase for consistent lookup

    # Get the appropriate function from the dictionary,
    # or fall back to a default 'unsupported' function if not found.
    loader_function = FILE_LOADERS.get(file_extension, process_unsupported_file)

    # Call the selected function
    text_content = loader_function(file_path, xml_type=xml_type)

    return text_content


In [None]:
# Test loading various PDF files
#pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'PDF', '10.1002_2017jc013030.pdf')
#pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'PDF', '10.1017_rdc.2022.19.pdf')
#pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'PDF', '10.1017_s0007123423000601.pdf')
#pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'PDF', '10.3389_fcimb.2024.1292467.pdf')
#pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'PDF', '10.1002_esp.5058.pdf') # This one is big
#pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'PDF', '10.1002_esp.5059.pdf') # This one is big
#pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'PDF', '10.1002_ece3.4466.pdf') # dryad
#pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1002_ece3.4466.xml') # dryad
pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1002_mp.14424.xml')
pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1007_s00259-022-06053-8.xml')    # jats
pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1007_s00382-022-06361-7.xml')    # tei
pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1111_1365-2435.13431.xml')       # wiley
pdf_file_path = os.path.join(ARTICLE_TRAIN_DIR, 'XML', '10.1111_mec.16977.xml')             # bioc
file_text = load_article_text(pdf_file_path, 'bioc')
display(len(file_text))
file_text

1773

'{"title": "Genetic basis of ecologically relevant body shape variation among four genera of cichlid fishes", "abstract": "Divergence in body shape is one of the most widespread and repeated patterns of morphological variation in fishes and is associated with habitat specification and swimming mechanics. Such ecological diversification is the first stage of the explosive adaptive radiation of cichlid fishes in the East African Rift Lakes. We use two hybrid crosses of cichlids (Metriaclima sp. \\u00c3\\u0097 Aulonocara sp. and Labidochromis sp. \\u00c3\\u0097 Labeotropheus sp., >975 animals total) to determine the genetic basis of body shape diversification that is similar to benthic-pelagic divergence across fishes. Using a series of both linear and geometric shape measurements, we identify 34 quantitative trait loci (QTL) that underlie various aspects of body shape variation. These QTL are spread throughout the genome, each explain 3.2\\u00e2\\u0080\\u00938.6% of phenotypic variation,

In [91]:
# Load the labeled training data CSV file
print(f"Loading labeled training data from: {LABELED_TRAINING_DATA_CSV_PATH}")
train_labels_df = pd.read_csv(LABELED_TRAINING_DATA_CSV_PATH)

print(f"Training labels shape: {train_labels_df.shape}")
display(train_labels_df.head())

Loading labeled training data from: ./kaggle/input/make-data-count-finding-data-references\train_labels.csv
Training labels shape: (1028, 3)


Unnamed: 0,article_id,dataset_id,type
0,10.1002_2017jc013030,https://doi.org/10.17882/49388,Primary
1,10.1002_anie.201916483,Missing,Missing
2,10.1002_anie.202005531,Missing,Missing
3,10.1002_anie.202007717,Missing,Missing
4,10.1002_chem.201902131,Missing,Missing


In [92]:
# Create a new column 'dataset_id_trim' by extracting the first 3 characters of the 'dataset_id' column
train_labels_df['dataset_id_trim'] = train_labels_df['dataset_id'].str[:3]
# Find the most frequent types of dataset_id's
freq_dataset_id_df = train_labels_df.groupby('dataset_id_trim').count().reset_index()
freq_dataset_id_df = freq_dataset_id_df[['dataset_id_trim', 'article_id']].sort_values(by='article_id', ascending=False)
print(f"Grouped dataset ID counts:\n{freq_dataset_id_df.head(10)}")


Grouped dataset ID counts:
   dataset_id_trim  article_id
52             htt         325
29             Mis         309
20             EPI          64
47             SAM          41
25             IPR          33
11             CHE          29
41             PRJ          26
16             E-G          25
19             ENS          21
26             K02          20


In [93]:
train_labels_df[train_labels_df['dataset_id_trim'] == 'EPI'].sample(3)  # Display the first 10 rows where dataset_id_trim is 'htt'

Unnamed: 0,article_id,dataset_id,type,dataset_id_trim
927,10.3390_v11060565,EPI954557,Secondary,EPI
830,10.3389_fcimb.2024.1292467,EPI_ISL_10271777,Secondary,EPI
377,10.1128_JVI.01717-21,EPI_ISL_293290,Primary,EPI


In [94]:
# Group training data by article_id to get all datasets for each article
# This creates a dictionary where keys are article_ids and values are lists of dataset dicts
grouped_training_data = {}
for article_id, group_df in train_labels_df.groupby('article_id'):
    grouped_training_data[article_id] = group_df[['dataset_id', 'type']].to_dict('records')

# Example usage of grouped_training_data
print(f"Example grouped training data for article_id '10.1002_2017jc013030': {grouped_training_data['10.1002_2017jc013030']}")


Example grouped training data for article_id '10.1002_2017jc013030': [{'dataset_id': 'https://doi.org/10.17882/49388', 'type': 'Primary'}]


In [95]:
def read_first_line_of_xml(file_path: str) -> str | None:
    """
    Reads and returns the first line of an XML file.

    Args:
        file_path (str): The path to the XML file.

    Returns:
        str | None: The first line of the file, stripped of leading/trailing whitespace,
                    or None if the file cannot be read or is empty.
    """
    if not file_path and not os.path.exists(file_path):
        return None
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            first_line = f.readline().replace('<?xml version="1.0" encoding="UTF-8"?>', '').replace('<?xml version="1.0" encoding="UTF-8" standalone="yes"?>', '')
            return first_line.strip()[:90] if first_line else None
    except UnicodeDecodeError:
        print(f"Warning: Could not decode '{file_path}' with UTF-8. Trying ISO-8859-1.")
        try:
            with open(file_path, 'r', encoding='iso-8859-1') as f:
                first_line = f.readline()
                return first_line.strip()[:90] if first_line else None
        except Exception as e:
            print(f"Error: Could not read '{file_path}' with ISO-8859-1 either: {e}")
            return None
    except Exception as e:
        print(f"Error reading file '{file_path}': {e}")
        return None

In [96]:
# Load file paths for training and testing datasets
train_file_paths_df = load_file_paths(ARTICLE_TRAIN_DIR)
test_file_paths_df = load_file_paths(ARTICLE_TEST_DIR)

# Remove rows in train_file_paths_df that have a corresponding article_id in test_file_paths_df
train_file_paths_df = train_file_paths_df[~train_file_paths_df['article_id'].isin(test_file_paths_df['article_id'])]

# Merge the file paths with the grouped_training_data
train_file_paths_df['dataset_info'] = train_file_paths_df['article_id'].map(grouped_training_data)
test_file_paths_df['dataset_info'] = test_file_paths_df['article_id'].map(grouped_training_data)

train_file_paths_df['xml_text'] = train_file_paths_df['xml_file_path'].apply(read_first_line_of_xml)
test_file_paths_df['xml_text'] = test_file_paths_df['xml_file_path'].apply(read_first_line_of_xml)

print(f"Train files paths shape: {train_file_paths_df.shape}")
display(train_file_paths_df.sample(3))
print(f"Test files paths shape: {test_file_paths_df.shape}")
display(test_file_paths_df.sample(3))

Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.PathLike object, not float
Error reading file 'nan': expected str, bytes or os.

Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,dataset_info,xml_text
407,10.1590_0102-77863220007,./kaggle/input/make-data-count-finding-data-re...,,train,"[{'dataset_id': 'Missing', 'type': 'Missing'}]",
320,10.1186_s13059-020-02048-6,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
405,10.1590_0047-2085000000239,./kaggle/input/make-data-count-finding-data-re...,,train,"[{'dataset_id': 'Missing', 'type': 'Missing'}]",


Test files paths shape: (30, 6)


Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,dataset_info,xml_text
25,10.1002_esp.5058,./kaggle/input/make-data-count-finding-data-re...,,test,[{'dataset_id': 'https://doi.org/10.5061/dryad...,
29,10.1007_jhep07(2018)134,./kaggle/input/make-data-count-finding-data-re...,,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]",
11,10.1002_ece3.3985,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."


In [97]:
train_file_paths_df['xml_text'].unique()

array([None,
       '<!DOCTYPE article PUBLIC "-//NLM//DTD JATS (Z39.96) Journal Archiving and Interchange DTD ',
       '<html><body><tei xml:space="preserve" xmlns="http://www.tei-c.org/ns/1.0" xmlns:xlink="htt',
       '<component xmlns="http://www.wiley.com/namespaces/wiley" xmlns:wiley="http://www.wiley.com',
       '<!DOCTYPE article PUBLIC "-//NLM//DTD Journal Archiving and Interchange DTD v3.0 20080202/',
       '<!DOCTYPE collection SYSTEM "BioC.dtd"><collection><source>BioC-API</source><date>20250507',
       ''], dtype=object)

In [98]:
train_file_paths_df

Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,dataset_info,xml_text
30,10.1007_jhep11(2018)113,./kaggle/input/make-data-count-finding-data-re...,,train,"[{'dataset_id': 'Missing', 'type': 'Missing'}]",
31,10.1007_jhep11(2018)115,./kaggle/input/make-data-count-finding-data-re...,,train,"[{'dataset_id': 'Missing', 'type': 'Missing'}]",
32,10.1007_jhep12(2018)117,./kaggle/input/make-data-count-finding-data-re...,,train,"[{'dataset_id': 'Missing', 'type': 'Missing'}]",
33,10.1007_s00259-022-06053-8,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,[{'dataset_id': 'https://doi.org/10.7937/k9/tc...,"<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
34,10.1007_s00382-012-1636-1,./kaggle/input/make-data-count-finding-data-re...,,train,"[{'dataset_id': 'Missing', 'type': 'Missing'}]",
...,...,...,...,...,...,...
519,10.7554_elife.74937,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,[{'dataset_id': 'https://doi.org/10.5281/zenod...,"<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
520,10.7717_peerj.10452,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,"[{'dataset_id': 'PRJNA664798', 'type': 'Second...","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
521,10.7717_peerj.11352,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,[{'dataset_id': 'https://doi.org/10.7291/d11m3...,"<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
522,10.7717_peerj.12422,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,[{'dataset_id': 'https://doi.org/10.15468/dl.c...,"<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."


In [52]:
test_file_paths_df

Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,dataset_info,xml_text
0,10.1002_2017jc013030,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,[{'dataset_id': 'https://doi.org/10.17882/4938...,"<?xml version=""1.0"" encoding=""UTF-8""?><html><b..."
1,10.1002_anie.201916483,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
2,10.1002_anie.202005531,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<?xml version=""1.0"" encoding=""UTF-8""?><!DOCTYP..."
3,10.1002_anie.202007717,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
4,10.1002_chem.201902131,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
5,10.1002_chem.201903120,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
6,10.1002_chem.202000235,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
7,10.1002_chem.202001412,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
8,10.1002_chem.202001668,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."
9,10.1002_chem.202003167,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]","<!DOCTYPE article PUBLIC ""-//NLM//DTD JATS (Z3..."


In [390]:
# --- QwenModelEval Class ---
# kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")
#max_new_tokens=32768
class QwenModelEval:
    def __init__(self, model_name, sys_prompt, enable_thinking=True, max_new_tokens=1024):
        print(f"Loading Qwen model and tokenizer from: {model_name}")
        self.model_name = model_name
        self.sys_prompt = sys_prompt
        self.enable_thinking = enable_thinking  # Enable or disable thinking mode
        self.max_new_tokens = max_new_tokens  # Set the maximum number of new tokens to generate
        # Load the tokenizer and model
        # Using trust_remote_code=True to allow custom model code execution
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True)
        self.model.eval() # Set the model to evaluation mode here.

    def generate_response(self, user_input):  
        inputs = self._get_inputs(user_input)
        # Disable gradient calculation during inference
        # Generate the response using the model
        with torch.no_grad(): 
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
                # do_sample=False, # Use greedy decoding (fastest)
                # num_beams=1,     # Do not use beam search (fastest)
                # temperature=0.0, # Make output deterministic (if do_sample=False, this has no effect)                
                temperature=0.6 if self.enable_thinking else 0.7,
                top_p=0.95 if self.enable_thinking else 0.8,
                top_k=20,
                min_p=0
            )
        # Parse the response and thinking content
        return self._parse_response(inputs, generated_ids)

    def _get_inputs(self, user_input):
        """Prepare the input for the model based on user input."""
        # Trim the user input to a maximum length for better performance
        user_input = user_input[:4096]  # Limit input length to 4096 characters
        print(f"Preparing input with length: {len(user_input)}")
        # Create the messages for the chat template
        messages = [
            {"role": "system", "content": self.sys_prompt},
            {"role": "user", "content": user_input}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking
        )
        return self.tokenizer(text, return_tensors="pt").to(self.model.device)
    
    def _parse_response(self, inputs, generated_ids):
        print("Parsing response from generated IDs...")
        # Extract the output IDs from the generated IDs
        output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
        try:
            index = len(output_ids) - output_ids[::-1].index(151668)
        except ValueError:
            index = 0

        thinking_content = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
        raw_response = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
        response = self._parse_json(raw_response)
        return response, thinking_content
    
    def _parse_json(self, raw_response: str) -> list[dict[str,str]]:
        # Remove code block markers and leading/trailing whitespace
        cleaned = raw_response.strip()
        if cleaned.startswith("```json"):
            cleaned = cleaned[len("```json"):].strip()
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3].strip()

        # Now parse as JSON
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError as e:
            return []        
        


In [396]:
# Define the one-shot reasoning, task, and example prompt
# This prompt is designed to guide the model through a structured reasoning process

# reason_prompt =  '''
# You are an advanced AI reasoning assistant tasked with delivering a comprehensive analysis of a specific problem or question.  Your goal is to outline your reasoning process in a structured and transparent manner, with each step reflecting a thorough examination of the issue at hand, culminating in a well-reasoned conclusion.

# ### Key Instructions:
# 1.  Conduct **at least 5 distinct reasoning steps**, each building on the previous one.
# 2.  **Acknowledge the limitations** inherent to AI, specifically what you can accurately assess and what you may struggle with.
# 3.  **Adopt multiple reasoning frameworks** to resolve the problem or derive conclusions, such as:
# - **Deductive reasoning** (drawing specific conclusions from general principles)
# - **Inductive reasoning** (deriving broader generalizations from specific observations)
# - **Abductive reasoning** (choosing the best possible explanation for the given evidence)
# - **Analogical reasoning** (solving problems through comparisons and analogies)
# 4.  **Critically analyze your reasoning** to identify potential flaws, biases, or gaps in logic.
# 5.  When reviewing, apply a **fundamentally different perspective or approach** to enhance your analysis.
# 6.  **Employ at least 2 distinct reasoning methods** to derive or verify the accuracy of your conclusions.
# 7.  **Incorporate relevant domain knowledge** and **best practices** where applicable, ensuring your reasoning aligns with established standards.
# 8.  **Quantify certainty levels** for each step and your final conclusion, where applicable.
# 9.  Consider potential **edge cases or exceptions** that could impact the outcome of your reasoning.
# 10.  Provide **clear justifications** for dismissing alternative hypotheses or solutions that arise during your analysis.
# '''

reason_prompt =  '''
You are an advanced AI research assistant that is skilled in identifying and classifying datasets used within academic research papers.
Be as accurate as possible but don't over think it.
'''

task_prompt = '''
You are given an article_id and the associated text of an academic research paper.
Within the text of the paper, you are given an abstract and a list of potential_dataset_ids that contains dataset_ids and their associated context within the paper.
You have 3 tasks:

1. Your first task is to identify all citations of datasets used in the research for this article. An article may cite zero or many datasets.
Datasets in an article can be cited within the context using various terms such as "data release", "data availability", "dataset", "database", "repository", "data source", "data access", "data archive".
Each dataset has a unique, persistent identifier to represent it called a dataset_id. If you find more than one citation of the same dataset_id, only process the first one.

There are 2 ways to identify a dataset_id:
The first way to identify a dataset_id is via a Digital Object Identifier (DOI). DOIs are used for all papers and some datasets. We want to identify DOIs that are used as dataset_id's, not the DOI of the paper itself or any other papers.
They take the following form: https://doi.org/[prefix]/[suffix]. The prefix always starts with "10." and is followed by a 4 or 5 digit number. The suffix can contain letters, numbers, and special characters.
Examples of DOI dataset_ids (just for reference):
https://doi.org/10.12345/12345
https://doi.org/10.1234/xxxx.1x1x-xx11

The second way to identify a dataset_id is via an Accession ID. Accession ID's vary in form by individual data repository where the data live. 
Examples of Accession ID's (just for reference):
"EPI_ISL_10271777" (EPI dataset)
"IPR000264" (InterPro dataset)
"SAMN07159041" (NCBI Sequence Read Archive dataset)
"CHEMBL1782574" (ChEMBL dataset)


2. Your second task is to classify the type of each dataset_id that you find as "Primary" or "Secondary" as it is used within the context of the paper.
Primary - raw or processed data generated as part of this paper, specifically for this study
Secondary - raw or processed data derived or reused from existing records or published data

3. Your third task is to return your results in a JSON format.
If an article does not refer to any dataset_id's, return a single JSON object with the following structure:
```json
[
    {
        "dataset_id": "Missing",
        "type": "Missing"
    }
]
If an article refers to one or more dataset_id's, you need to classify the type of each dataset as "Primary" or "Secondary" and
return every dataset found in a JSON array of objects, where each object has the following structure:
```json
[
    {
        "dataset_id": dataset_id here,
        "type": type here
    },
    ...
]
'''

SYS_PROMPT = reason_prompt + task_prompt

In [94]:
# train_file_paths_df['text'] = train_file_paths_df['pdf_file_path'].apply(load_article_text)
# test_file_paths_df['text'] = test_file_paths_df['pdf_file_path'].apply(load_article_text)

#Create a new df from train_file_paths_df where the string representation of dataset_info does not contain 'Missing'.
train_file_paths_df_2 = train_file_paths_df[~train_file_paths_df['dataset_info'].astype(str).str.contains('Missing')]
train_file_paths_df_3 = train_file_paths_df_2[train_file_paths_df_2['article_id'].astype(str).str.contains('10.1017_')]
train_file_paths_df_3


Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,dataset_info
52,10.1017_rdc.2022.19,./kaggle/input/make-data-count-finding-data-re...,,train,[{'dataset_id': 'https://doi.org/10.11588/data...
53,10.1017_s0007123423000601,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,[{'dataset_id': 'https://doi.org/10.7910/dvn/f...


In [397]:
# Instantiate the QwenModelEval class with the model path and system prompt
inference_model = QwenModelEval(QWEN_BASE_MODEL_PATH, sys_prompt=SYS_PROMPT, enable_thinking=True, max_new_tokens=1576)

Loading Qwen model and tokenizer from: C:\Users\jim\.cache\kagglehub\models\qwen-lm\qwen-3\transformers\0.6b\1


In [398]:
def process_articles(file_paths_df: pd.DataFrame, model) -> pd.DataFrame:
    results = []
    for i, row in file_paths_df.iterrows():
        article_id = row['article_id']
        pdf_file_path = row['pdf_file_path']
        xml_file_path = row['xml_file_path']

        print(f"Processing article {i}/{len(file_paths_df)}: {article_id}")

        # Load the text content from the PDF or XML file
        text_content = load_article_text(pdf_file_path) if pdf_file_path else load_article_text(xml_file_path)

        # Prepare the user input for the model
        user_input = f"Article ID: {article_id}\nText Content: {text_content}\n"

        # Generate response from the model
        response, thinking_content = model.generate_response(user_input)

        results.append({
            'article_id': article_id,
            'llm_input': user_input,
            'llm_response': response,
            'llm_thinking_content': thinking_content
        })

    return pd.DataFrame(results).sort_values(by=["article_id"]).reset_index(drop=True)




In [399]:
processed_articles_df = process_articles(train_file_paths_df_3, inference_model)
processed_articles_df

Processing article 52/2: 10.1017_rdc.2022.19
Total blocks with data related dataset IDs found: 1
Extracted text from ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.1017_rdc.2022.19.pdf. Length: 1934 characters
Preparing input with length: 1982
Parsing response from generated IDs...
Processing article 53/2: 10.1017_s0007123423000601
Total blocks with data related dataset IDs found: 1
Extracted text from ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.1017_s0007123423000601.pdf. Length: 1248 characters
Preparing input with length: 1302
Parsing response from generated IDs...


Unnamed: 0,article_id,llm_input,llm_response,llm_thinking_content
0,10.1017_rdc.2022.19,Article ID: 10.1017_rdc.2022.19\nText Content:...,[{'dataset_id': 'https://doi.org/10.11588/data...,"<think>\nOkay, let's tackle this problem step ..."
1,10.1017_s0007123423000601,Article ID: 10.1017_s0007123423000601\nText Co...,[{'dataset_id': 'https://doi.org/10.7910/DVN/F...,"<think>\nOkay, let's tackle this step by step...."


In [400]:
# Save processed_articles_df to CSV
processed_articles_df.to_csv("processed_articles.csv", index=False)


In [401]:
import json
#response_text = processed_articles_df[processed_articles_df['article_id']=='10.1017_rdc.2022.19']['llm_response']
response_text = processed_articles_df.loc[processed_articles_df['article_id'] == '10.1017_rdc.2022.19', 'llm_response'].values
response_text[0]
# parsed_json = json.loads(response_text[0])
# print(parsed_json)

[{'dataset_id': 'https://doi.org/10.11588/data/10100', 'type': 'Primary'}]

In [None]:
test_set_processed_articles_df = process_articles(test_file_paths_df, inference_model)
test_set_processed_articles_df.to_csv("test_set_processed_articles.csv", index=False)
test_set_processed_articles_df


Processing article 0/30: 10.1002_2017jc013030
Total blocks with data related dataset IDs found: 3
Extracted text from ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_2017jc013030.pdf. Length: 4335 characters
Preparing input with length: 4096
Parsing response from generated IDs...
Processing article 1/30: 10.1002_anie.201916483
Total blocks with data related dataset IDs found: 0
Extracted text from ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_anie.201916483.pdf. Length: 789 characters
Preparing input with length: 840
Parsing response from generated IDs...
Processing article 2/30: 10.1002_anie.202005531
Total blocks with data related dataset IDs found: 0
Extracted text from ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_anie.202005531.pdf. Length: 1088 characters
Preparing input with length: 1139
Parsing response from generated IDs...
Processing article 3/30: 10.1002_anie.202007717
Total blocks with data related 

In [25]:
eval_model = QwenModelEval(QWEN_BASE_MODEL_PATH, sys_prompt="You are a chatbot.",)

Loading Qwen model and tokenizer from: C:\Users\jim\.cache\kagglehub\models\qwen-lm\qwen-3\transformers\0.6b\1


In [26]:
# --- Example Usage of eval_model ---

# First input (without /think or /no_think tags, thinking mode is enabled by default)
user_input_1 = "How many r's in strawberries?"
print(f"User: {user_input_1}")
response_1 = eval_model.generate_response(user_input_1)
print(f"Bot: {response_1}")
print("----------------------")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


User: How many r's in strawberries?
Preparing input for user: How many r's in strawberries?
Generating response for user input:
Parsing response from generated IDs...
Bot: ('There are 2 \'r\'s in the word "strawberries".', '<think>\nOkay, the user is asking how many \'r\'s are in "strawberries". Let me start by breaking down the word. "Strawberries" is spelled S-T-R-A-W-B-E-R-R-I-N-G-S. Let me count each \'r\' here. \n\nFirst letter: S, no \'r\'. Second: T, no. Third: R, one \'r\'. Fourth: A, no. Fifth: W, no. Sixth: B, no. Seventh: E, no. Eighth: R, another \'r\'. Ninth: I, no. Tenth: N, no. Eleventh: G, no. So that\'s two \'r\'s in total. \n\nWait, maybe I should double-check. Let me write it out again: S-T-R-A-W-B-E-R-R-I-N-G-S. Yes, the third and eighth letters are both \'r\'s. So two \'r\'s. The user might be testing if I can count them correctly. I should make sure there\'s no other \'r\'s I missed. No, I think that\'s all. So the answer is two.\n</think>')
----------------------

In [None]:

# --- 2. Data Preparation for LLM Training (Revised for Combined Task) ---

def load_base_llm_for_training():
    """Loads the base Qwen model and tokenizer for fine-tuning."""
    global llm_tokenizer, llm_model
    if not AutoModelForCausalLM or not QWEN_BASE_MODEL_PATH:
        print("LLM components not available or base model path not set. Skipping LLM loading.")
        return False
    try:
        print(f"Loading Qwen tokenizer from: {QWEN_BASE_MODEL_PATH}")
        llm_tokenizer = AutoTokenizer.from_pretrained(QWEN_BASE_MODEL_PATH, trust_remote_code=True)
        if llm_tokenizer.pad_token is None:
            llm_tokenizer.pad_token = llm_tokenizer.eos_token
            print("Set tokenizer.pad_token to tokenizer.eos_token")

        print(f"Loading Qwen model from: {QWEN_BASE_MODEL_PATH}")
        llm_model = AutoModelForCausalLM.from_pretrained(
            QWEN_BASE_MODEL_PATH,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32,
            device_map="auto", # Automatically uses GPU if available
            trust_remote_code=True,
            # load_in_8bit=True if bnb else False # Uncomment if bitsandbytes is used
        )
        print(f"Base LLM loaded successfully on {llm_model.device}.")
        return True
    except Exception as e:
        print(f"Error loading base LLM for training: {e}")
        llm_tokenizer, llm_model = None, None # Reset to None on failure
        return False

def prepare_training_data_for_llm(
    training_df: pd.DataFrame,
    all_article_texts: dict[str, str],
    tokenizer_max_length: int
) -> Dataset:
    """
    Prepares training data for LLM fine-tuning, aggregating dataset IDs and classifications
    per article and formatting into ChatML JSON output.
    """
    formatted_examples = []

    # Group training data by article_id to get all datasets for each article
    # This creates a dictionary where keys are article_ids and values are lists of dataset dicts
    grouped_training_data = training_df.groupby('article_id').apply(
        lambda x: [{"dataset_id": row['dataset_id'], "classification": row['label']} for _, row in x]
    ).to_dict()

    # Get all article IDs for which we have text content
    all_article_ids_with_text = set(all_article_texts.keys())
    
    # Iterate through all articles for which we have text (these are our potential training examples)
    for article_id in all_article_ids_with_text:
        article_text = all_article_texts.get(article_id, "")
        if not article_text:
            print(f"Warning: Article text for {article_id} not found. Skipping training example.")
            continue

        # Truncate article text to fit within the model's context window
        # Reserve tokens for the prompt and the expected JSON response.
        # A typical Qwen 1.5 model has 32768 max_seq_length.
        # 512 tokens for prompt/response is a safe estimate.
        truncated_article_text = article_text[:tokenizer_max_length - 512] 

        # Determine the ground truth output for this article
        if article_id in grouped_training_data:
            # Article has datasets, format them as JSON
            ground_truth_datasets = grouped_training_data[article_id]
            assistant_response_json = json.dumps(ground_truth_datasets, ensure_ascii=False)
        else:
            # Article has no datasets in training data, so the model should output an empty list.
            # This explicitly trains the model to output '[]' for "Missing" cases.
            assistant_response_json = "[]"
            # print(f"Info: Article {article_id} has no datasets in training data. Training to output '[]'.")

        # Construct the user message for the LLM
        user_message = f"""
Article Text:
{truncated_article_text}

Task: Identify all datasets or databases used in this research article and classify each as "Primary" (if created by the authors for this research) or "Secondary" (if an existing dataset used in this research).

Output Format: Provide a JSON list of objects. Each object should have "dataset_id" and "classification" keys. If no datasets are identified, return an empty JSON list: [].
"""
        # Construct the full ChatML formatted string for SFTTrainer
        # The trainer will use this entire string as the 'text' field.
        chatml_formatted_string = f"<|im_start|>system\nYou are an expert research assistant. Your task is to extract and classify datasets from scientific articles.<|im_end|>\n<|im_start|>user\n{user_message.strip()}<|im_end|>\n<|im_start|>assistant\n{assistant_response_json}<|im_end|>"
        
        formatted_examples.append({"text": chatml_formatted_string})

    if not formatted_examples:
        raise ValueError("No training examples could be prepared. Check your data and article texts.")

    return Dataset.from_list(formatted_examples)

# --- 3. LLM Model Training (Fine-tuning) ---

# Attempt to load tokenizer and model if not already loaded (e.g., if previous training failed or was skipped)
if llm_model is None:
    load_base_llm_for_training()

if llm_model and not training_df.empty and Dataset: # Ensure Dataset is imported
    print("\n--- Preparing data for Fine-tuning (Combined Task) ---")
    # Use the model's max_length for context, or a reasonable default if tokenizer isn't loaded
    max_len = llm_tokenizer.model_max_length if llm_tokenizer else 4096 
    train_dataset = prepare_training_data_for_llm(training_df, all_article_texts, max_len)
    
    print(f"Prepared {len(train_dataset)} examples for fine-tuning.")
    print("Example formatted training instance (first 500 chars):")
    print(train_dataset[0]['text'][:500])

    print("\n--- Starting Fine-tuning (Combined Task) ---")
    try:
        training_args = TrainingArguments(
            output_dir=f"{FINE_TUNED_MODEL_OUTPUT_DIR}/checkpoints",
            num_train_epochs=1,  # Start with 1 epoch, adjust as needed
            per_device_train_batch_size=1, # Adjust based on VRAM
            gradient_accumulation_steps=4, # Effective batch size = 1 * 4 = 4
            learning_rate=2e-5,
            logging_steps=10,
            save_steps=50, # Save checkpoints periodically
            fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
            optim="paged_adamw_8bit", # Good for memory efficiency if bitsandbytes is installed
            # report_to="none", # Disable logging to external services
            # max_steps=100, # For quick testing
        )

        trainer = SFTTrainer(
            model=llm_model,
            tokenizer=llm_tokenizer,
            train_dataset=train_dataset,
            dataset_text_field="text", # This field contains the full ChatML string
            args=training_args,
            max_seq_length=max_len, # Use the model's full context length
            packing=False, # Set to True if your inputs are much shorter than max_seq_length
        )

        trainer.train()
        print("Fine-tuning completed.")

        print(f"Saving fine-tuned model to: {FINE_TUNED_MODEL_OUTPUT_DIR}")
        trainer.save_model(FINE_TUNED_MODEL_OUTPUT_DIR)
        print("Model and tokenizer saved.")

    except Exception as e:
        print(f"An error occurred during fine-tuning: {e}")
        import traceback
        traceback.print_exc()
        llm_model = None # Mark model as failed to load/train
else:
    print("Skipping LLM fine-tuning due to missing training data or LLM components.")


# --- 4. LLM-based Extraction & Classification (Inference) ---

# Load the fine-tuned model for inference (if training was successful)
# If training was skipped or failed, this will attempt to load from the base path or fail.
if inference_model is None: # Only load if not already loaded
    if AutoModelForCausalLM: # Check if transformers is available
        if os.path.exists(FINE_TUNED_MODEL_OUTPUT_DIR) and os.path.isdir(FINE_TUNED_MODEL_OUTPUT_DIR):
            MODEL_TO_LOAD = FINE_TUNED_MODEL_OUTPUT_DIR
            print(f"Loading fine-tuned model for inference from: {MODEL_TO_LOAD}")
        else:
            MODEL_TO_LOAD = QWEN_BASE_MODEL_PATH
            print(f"Fine-tuned model not found. Loading base model for inference from: {MODEL_TO_LOAD}")

        try:
            inference_tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_LOAD, trust_remote_code=True)
            if inference_tokenizer.pad_token is None:
                inference_tokenizer.pad_token = inference_tokenizer.eos_token
            inference_model = AutoModelForCausalLM.from_pretrained(
                MODEL_TO_LOAD,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32,
                device_map="auto",
                trust_remote_code=True
            ).eval() # Set to evaluation mode
            print(f"Inference LLM loaded successfully on {inference_model.device}.")
        except Exception as e:
            print(f"Error loading inference LLM from {MODEL_TO_LOAD}: {e}")
            inference_model, inference_tokenizer = None, None
    else:
        print("Transformers library not available. Cannot load LLM for inference.")


def extract_and_classify_with_llm(article_text: str) -> list[dict]:
    """
    Uses the loaded LLM to extract dataset IDs and classify them.
    Returns a list of dictionaries like [{"dataset_id": "...", "classification": "..."}].
    Returns an empty list if LLM is unavailable or parsing fails.
    """
    if not inference_model or not inference_tokenizer:
        print("  LLM unavailable for extraction/classification.")
        return [] # Return empty list if LLM is not loaded

    # Truncate article text for inference if it exceeds model's context window
    # Use the same max_length as during training for consistency
    max_inference_context_length = inference_tokenizer.model_max_length - 256 # Reserve tokens for prompt and response
    truncated_article_text = article_text[:max_inference_context_length]

    user_message = f"""
Article Text:
{truncated_article_text}

Task: Identify all datasets or databases used in this research article and classify each as "Primary" (if created by the authors for this research) or "Secondary" (if an existing dataset used in this research).

Output Format: Provide a JSON list of objects. Each object should have "dataset_id" and "classification" keys. If no datasets are identified, return an empty JSON list: [].
"""
    messages = [
        {"role": "system", "content": "You are an expert research assistant. Your task is to extract and classify datasets from scientific articles."},
        {"role": "user", "content": user_message.strip()}
    ]
    
    input_ids = inference_tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(inference_model.device)

    try:
        with torch.no_grad():
            outputs = inference_model.generate(
                input_ids,
                max_new_tokens=512, # Allow more tokens for multiple dataset outputs
                pad_token_id=inference_tokenizer.eos_token_id,
                eos_token_id=inference_tokenizer.convert_tokens_to_ids("<|im_end|>")
            )
        
        response_text = inference_tokenizer.decode(
            outputs[0][input_ids.shape[1]:],
            skip_special_tokens=False # Keep special tokens to remove <|im_end|> explicitly
        ).strip()
        response_text = response_text.replace("<|im_end|>", "").strip()
        
        print(f"  LLM raw response: '{response_text}'")

        # Attempt to parse the JSON output
        try:
            parsed_data = json.loads(response_text)
            if isinstance(parsed_data, list):
                # Validate structure: each item should be a dict with 'dataset_id' and 'classification'
                valid_datasets = []
                for item in parsed_data:
                    if isinstance(item, dict) and 'dataset_id' in item and 'classification' in item:
                        # Basic validation for classification label
                        if item['classification'] in ["Primary", "Secondary"]:
                            valid_datasets.append(item)
                        else:
                            print(f"  Warning: Invalid classification '{item['classification']}' for dataset '{item.get('dataset_id', 'N/A')}'. Skipping.")
                    else:
                        print(f"  Warning: Malformed JSON object: {item}. Skipping.")
                return valid_datasets
            else:
                print(f"  Warning: LLM did not return a JSON list: {response_text}")
                return []
        except json.JSONDecodeError as jde:
            print(f"  Error decoding JSON from LLM response: {jde}. Raw response: '{response_text}'")
            return []

    except Exception as e:
        print(f"  Error during LLM generation: {e}")
        return []

# --- Main Processing Loop for all articles (Revised) ---
print("\n--- Starting Article Processing and Classification (LLM-driven) ---")
final_results = []

for article_id, article_text in all_article_texts.items():
    print(f"\nProcessing article: {article_id}")
    
    # LLM directly extracts and classifies
    identified_datasets = extract_and_classify_with_llm(article_text)
    
    if not identified_datasets:
        # If LLM returns an empty list, classify the article as "Missing"
        print(f"  LLM identified no datasets for {article_id}. Classifying as 'Missing'.")
        final_results.append({
            "article_id": article_id,
            "dataset_id": "N/A", # Indicate no specific dataset ID
            "classification_label": "Missing"
        })
    else:
        print(f"  LLM identified {len(identified_datasets)} dataset(s) for {article_id}.")
        for item in identified_datasets:
            final_results.append({
                "article_id": article_id,
                "dataset_id": item.get("dataset_id", "Unknown"), # Use .get() for safety
                "classification_label": item.get("classification", "Uncertain_LLM")
            })


# --- 5. Results & Output ---

print("\n--- Final Results ---")
if final_results:
    results_df = pd.DataFrame(final_results)
    print(results_df.head(10)) # Print first 10 rows
    
    # Save to CSV
    results_df.to_csv(FINAL_RESULTS_CSV_PATH, index=False)
    print(f"\nResults saved to: {FINAL_RESULTS_CSV_PATH}")
else:
    print("No results generated.")

print("\nProcessing complete, Jim!")