## 0. Kaggle Environment and Imports

In [63]:
# Imports
import os
import glob
import re
import pandas as pd
import collections
import xml.etree.ElementTree as ET
import PyPDF2

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


## 1. Configuration

In [5]:

# Constants
MAX_TOKENS = 4096  # Adjust based on your model's capabilities
QWEN_MODEL_NAME_OR_PATH = "/kaggle/input/qwen-model-files/qwen-7b-chat" # Example
OUTPUT_CSV_PATH = "/kaggle/working/article_dataset_classification.csv"
ARTICLES_BASE_DIR = './kaggle/input/make-data-count-finding-data-references/'
ARTICLES_TRAIN_DIR = ARTICLES_BASE_DIR + 'train/'
ARTICLES_TEST_DIR = ARTICLES_BASE_DIR + 'test/'
ARTICLE_FORMATS = [{'format':'PDF', 'ext': '.pdf'}, {'format': 'XML', 'ext': '.xml'}]
train_labels_file_path = ARTICLES_BASE_DIR+'train_labels.csv'
sample_submission_file_path = ARTICLES_BASE_DIR+'sample_submission.csv'


## 2. Data Loading

In [20]:
# Read text from PDF files using PyPDF2
def read_pdf_text(pdf_file_path) -> str:
    """Extracts all text from a PDF file."""
    text = ""
    # Ensure the file path is a string and not NaN or empty
    if pd.isna(pdf_file_path) or not pdf_file_path:
        return text
    
    # Convert to string and strip whitespace
    pdf_file_path = str(pdf_file_path).strip()

    try:
        with open(pdf_file_path, 'rb') as pdf_file_obj:
            pdf_reader = PyPDF2.PdfReader(pdf_file_obj)
            num_pages = len(pdf_reader.pages)
            for page_num in range(num_pages):
                page_obj = pdf_reader.pages[page_num]
                text += page_obj.extract_text()
    except Exception as e:
        print(f"Error reading {pdf_file_path} with PyPDF2: {e}")
        
    return text

def read_xml_text(xml_file_path) -> str:
    """Reads and concatenates all text content from an XML file."""
    # Using your previously developed function
    all_text_parts = []
    try:
        tree = ET.parse(xml_file_path)
        root = tree.getroot()
        for element in root.iter():
            if element.text:
                cleaned_text = element.text.strip()
                if cleaned_text:
                    all_text_parts.append(cleaned_text)
            if element.tail:
                cleaned_tail = element.tail.strip()
                if cleaned_tail:
                    all_text_parts.append(cleaned_tail)
        return " ".join(all_text_parts) if all_text_parts else ""
    except Exception as e:
        print(f"Error reading XML {xml_file_path}: {e}")
        return ""



## 3. Data Extraction

In [28]:
# Function to normalize dataset IDs
# This function takes a dataset ID as input and normalizes it by removing the "doi.org/" prefix if it exists.
def get_dataset_id_regex(id: str) -> str:
    # Regex to capture the DOI part after "doi.org/"
    # It handles optional "https://" and "www."
    regex_id = id
    dryad_marker = "/dryad."
    regex = r"(?:https://)?(?:www\.)?doi\.org/(.+)"
    match = re.search(regex, str(id).lower())
    if match:
        # The DOI is in the first capturing group
        full_doi_candidate = match.group(1)
        dryad_index = full_doi_candidate.find(dryad_marker)
        if dryad_index != -1:
            # Calculate the starting point of the suffix (right after "/dryad." + 5 characters)
            start_of_suffix = dryad_index + len(dryad_marker) + 5
            # "/dryad." is found in the DOI candidate
            prefix = full_doi_candidate[:start_of_suffix]
            
            # Get the remaining characters for the suffix
            suffix = full_doi_candidate[start_of_suffix : ]
            
            # Construct the regex ID
            regex_id = prefix + '\\s?' + suffix
        else:
            # Otherwise, return the full DOI
            regex_id = full_doi_candidate

    return regex_id.replace('.', '\\.\\s?').replace('/', '\\/')

In [36]:
def find_regex_with_context(main_string: str, search_regex: str, context_chars: int = 200) -> list[str]:
    """
    Finds all occurrences of search_regex within main_string and returns
    a context window for each. The context window includes the matching search_string
    itself, surrounded by up to 'context_chars' characters from before and
    after its occurrence in the main_string.

    Args:
        main_string (str): The string to search within.
        search_regex (str): The regular expression to search for.
        context_chars (int): The number of characters to include before and after
                             the search_string in the context window. Defaults to 100.

    Returns:
        List[str]: A list of strings, where each string is an occurrence of
                   search_string surrounded by its context. Returns an empty
                   list if search_string is not found, or if either
                   main_string or search_string is empty.
    """
    # Ensure the main_string and search_regex are valid
    if not main_string or not search_regex:
        return []

    re_doi = re.compile(search_regex, re.IGNORECASE)
    occurrences_with_context: list[str] = []
    len_search: int = len(search_regex)

    doi_matches = re_doi.finditer(main_string, re.IGNORECASE)
    for match in doi_matches:
        extracted_snippet = main_string[max(0, match.start() - context_chars): match.start() + len_search ]
        occurrences_with_context.append(extracted_snippet)
            
    return occurrences_with_context

In [60]:
def remove_unmatched_parentheses(s: str) -> str:
    """
    Removes non-matching '(' and ')' characters from a string.
    A parenthesis is considered matching if it forms a valid pair.

    Args:
        s (str): The input string.

    Returns:
        str: The string with all non-matching parentheses removed.
    """
    if not s:
        return ""

    # Use a deque as a stack to store indices of opening parentheses.
    # When we find a '(', we push its index. When we find a ')', we pop an index.
    open_paren_indices_stack = collections.deque()
    
    # A boolean list to mark characters that should be kept in the final string.
    # Initially, assume all characters are kept. We'll mark unmatched parentheses as False.
    keep_char = [True] * len(s)

    for i, char in enumerate(s):
        if char == '(':
            # This is a potential opening parenthesis. Store its index.
            open_paren_indices_stack.append(i)
        elif char == ')':
            if open_paren_indices_stack:
                # Found a matching opening parenthesis for this closing one.
                # Pop the index of the matched opening parenthesis from the stack.
                open_paren_indices_stack.pop()
            else:
                # This closing parenthesis has no matching opening parenthesis.
                # It is unmatched and should be removed.
                keep_char[i] = False
        # For non-parenthesis characters, keep_char[i] remains True (its default value).
    
    # After iterating through the entire string, any opening parentheses
    # remaining in the stack are unmatched because they never found a closing pair.
    # Mark these for removal.
    while open_paren_indices_stack:
        unmatched_open_idx = open_paren_indices_stack.pop()
        keep_char[unmatched_open_idx] = False
            
    # Construct the final string by iterating through the original string
    # and appending only the characters marked to be kept.
    final_chars = [s[i] for i, should_keep in enumerate(keep_char) if should_keep]
            
    return "".join(final_chars)


In [61]:
def scrub_doi(doi: str) -> str:
    doi = doi.strip()
    # Remove non-matching "(" and ")" characters
    doi = remove_unmatched_parentheses(doi)
    # Remove any leading or trailing whitespace
    doi = doi.strip()
    return doi

In [56]:
def extract_dois_from_text(text) -> list[str]:
    """
    Extracts DOIs from text using a flexible regex.
    (This would use your modified extract_doi_flexible logic,
     but here we'll just use a general DOI regex for simplicity in this plan)
    """
    if not text:
        return []
    # A common DOI regex pattern (simplified for this example)
    # For example, to find any string starting with 10. followed by numbers/dots/slashes
    #doi_pattern = r'\b10\.\d{4,9}/[-._;()/:A-Za-z0-9]+\b'
    #doi_pattern = r'\b10\.\s?\d{4,9}\/[-._()<>;\/:A-Za-z0-9\s]+[-._()<>;\/:0-9]+'
    doi_pattern = r'\b10\.\s?\d{4,9}\/[-._()<>;\/:A-Za-z0-9]+\s?(?![A-Z]+)+[-._()<>;\/:A-Za-z0-9]+'
    found_dois = set(re.findall(doi_pattern, text))
    found_dois = ["".join(scrub_doi(doi).split()) for doi in found_dois]  # Clean up whitespace
    return list(set(found_dois))  # Return unique DOIs)


def extract_dataset_ids(text, extracted_dois: list[str]) -> list[str]:
    """
    Identifies dataset IDs. For now, let's assume dataset_ids are primarily the DOIs found.
    This function can be expanded with more regexes for other ID types.
    """
    # For this version, we'll consider all found DOIs as potential dataset_ids
    # You might want to add more sophisticated logic here to find non-DOI dataset IDs
    data_related_dois = []
    for doi in extracted_dois:
        print(f"Processing DOI: {doi}")
        # Skip if DOI is empty or NaN
        if pd.isna(doi) or not doi:
            continue
        # Normalize the DOI to a regex format
        regex_id = get_dataset_id_regex(doi)
        # Find occurrences of this DOI in the text with context
        occurrences = find_regex_with_context(text, regex_id)
        # Check if any string in occurrences contains 'data' (case-insensitive)
        if occurrences and any(('dataset' in s.lower() or 'database' in s.lower()) for s in occurrences):
            data_related_dois.append(doi)
    return data_related_dois 


## 4. LLM Classification

In [None]:
# Global LLM model and tokenizer (load once)
llm_model = None
llm_tokenizer = None
device = "cuda" if torch and torch.cuda.is_available() else "cpu"

def load_llm():
    global llm_model, llm_tokenizer
    if not AutoModelForCausalLM or not QWEN_MODEL_NAME_OR_PATH:
        print("LLM components not available or path not set. Skipping LLM loading.")
        return False
    try:
        print(f"Loading Qwen tokenizer from: {QWEN_MODEL_NAME_OR_PATH}")
        llm_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_NAME_OR_PATH, trust_remote_code=True)
        print(f"Loading Qwen model from: {QWEN_MODEL_NAME_OR_PATH}")
        llm_model = AutoModelForCausalLM.from_pretrained(
            QWEN_MODEL_NAME_OR_PATH,
            device_map="auto", # Automatically uses GPU if available
            trust_remote_code=True
        ).eval() # Set to evaluation mode
        print(f"LLM loaded successfully on {llm_model.device}.")
        return True
    except Exception as e:
        print(f"Error loading LLM: {e}")
        return False

def generate_llm_classification(article_text_snippet, dataset_id):
    """
    Uses the LLM to classify dataset usage.
    article_text_snippet: A relevant portion of article text, or full text if manageable.
    """
    if not llm_model or not llm_tokenizer:
        print("LLM not loaded. Cannot classify.")
        return "Error: LLM not loaded"

    prompt = f"""
    You are an expert research assistant. Your task is to determine how a dataset was used in a research article.
    Read the following article context and the dataset identifier carefully.

    Article Context (excerpt):
    "{article_text_snippet[:4000]}"  # Truncate for context window, adjust as needed

    Dataset Identifier: "{dataset_id}"

    Question: Based on the provided article context, was the dataset (identified as "{dataset_id}"):
    1. Created by the authors primarily for the research described in THIS article? (If so, it's "Primary")
    2. An existing dataset that the authors obtained and used for their research in THIS article? (If so, it's "Secondary")

    Please respond with only one word: "Primary" or "Secondary".
    """
    
    # This is a simplified generation example for Qwen.
    # You'll need to adapt this to the specific chat/completion format Qwen expects.
    # For many chat models, it's a list of messages:
    # messages = [{"role": "user", "content": prompt}]
    # inputs = llm_tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(device)
    
    # Simpler non-chat completion style (check Qwen docs for best practice):
    inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=llm_tokenizer.model_max_length - 50).to(llm_model.device) # Reserve some tokens for generation

    try:
        with torch.no_grad(): # Important for inference
            outputs = llm_model.generate(
                **inputs,
                max_new_tokens=10, # We only expect "Primary" or "Secondary"
                pad_token_id=llm_tokenizer.eos_token_id # Important for some models
            )
        
        # Decode the generated tokens, skipping special tokens and the prompt
        response_text = llm_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
        
        print(f"LLM raw response for {dataset_id}: '{response_text}'")

        if "Primary" in response_text:
            return "Primary"
        elif "Secondary" in response_text:
            return "Secondary"
        else:
            print(f"Warning: LLM response for {dataset_id} not clearly Primary/Secondary: '{response_text}'")
            return "Uncertain" # Or handle as per your logic

    except Exception as e:
        print(f"Error during LLM generation for {dataset_id}: {e}")
        return "Error: LLM generation failed"



## 5. Main Processing Logic

In [13]:
def get_all_article_files(base_directory):
    """
    Returns a list of all article files (PDF and XML) in the specified base_directory.
    """
    # Collect files from all formats and flatten the list
    all_article_files = [glob.glob(os.path.join(base_directory, fmt['format'], f"*{fmt['ext']}")) for fmt in ARTICLE_FORMATS]  
    return [item for sublist in all_article_files for item in sublist]


In [26]:

# --- 5. Main Processing Logic ---
def process_articles(articles_directory):
    results = []
    
    # if not load_llm(): # Attempt to load LLM once
    #     print("Proceeding without LLM classification.")

    article_files = get_all_article_files(articles_directory)

    for i, filepath in enumerate(article_files):
        print(f"\nProcessing article {i+1}/{len(article_files)}: {os.path.basename(filepath)}")
        article_id = os.path.splitext(os.path.basename(filepath))[0]
        text_content = None

        if filepath.endswith(".pdf"):
            text_content = read_pdf_text(filepath)
        elif filepath.endswith(".xml"):
            text_content = read_xml_text(filepath)

        if not text_content:
            print(f"Could not extract text from {filepath}. Skipping.")
            results.append({
                "article_id": article_id,
                "dataset_id_found": "N/A",
                "doi_found": "N/A",
                "classification_label": "Error: No text"
            })
            continue

        # Extract DOIs and other potential dataset IDs
        extracted_dois = extract_dois_from_text(text_content)
        dataset_ids_found = extract_dataset_ids(text_content, extracted_dois) # Currently just uses DOIs

        if not dataset_ids_found:
            print(f"No dataset IDs/DOIs found in {article_id}.")
            results.append({
                "article_id": article_id,
                "dataset_id_found": "None",
                "doi_found": "None",
                "classification_label": "Missing"
            })
        else:
            print(f"Found {len(dataset_ids_found)} potential dataset(s)/DOI(s) in {article_id}: {dataset_ids_found}")
            for ds_id in dataset_ids_found:
                # For LLM classification, provide context.
                # You might want to be more selective about the text snippet.
                # For now, using the beginning of the text.
                classification = "LLM_Disabled"
                # if llm_model and llm_tokenizer: # Check if LLM is loaded
                #     classification = generate_llm_classification(text_content, ds_id)
                
                # Determine associated DOI (might be the same as ds_id if it's a DOI)
                # This logic might need refinement based on how ds_id is defined
                doi_for_dataset = ds_id if re.match(r"^10\.", ds_id) else "N/A (Non-DOI ID)"

                results.append({
                    "article_id": article_id,
                    "dataset_id_found": ds_id,
                    "doi_found": doi_for_dataset, 
                    "classification_label": classification
                })
    return pd.DataFrame(results)


In [64]:
result_df = process_articles(ARTICLES_TEST_DIR)
result_df



Processing article 1/55: 10.1002_2017jc013030.pdf
Processing DOI: 10.1029/2000JC000319
Processing DOI: 10.1002/lom3.10144
Processing DOI: 10.1002/2015GL064540
Processing DOI: 10.4319/lo.2007.52.2.0739
Processing DOI: 10.4319/lom.2012.10.910
Processing DOI: 10.1016/j.dsr.2006.07.009
Processing DOI: 10.1016/S0967-0637(01)00025-5
Processing DOI: 10.3354/meps07998
Processing DOI: 10.1016/j.dsr.2008.11.006
Processing DOI: 10.1364/AO.41.005755
Processing DOI: 10.1111/j.0022-3646.1991.00008.x
Processing DOI: 10.1002/2013GB004781
Processing DOI: 10.5194/bg-5-171-2008
Processing DOI: 10.1016/j.dsr.2003.09.002Acknowledgments
Processing DOI: 10.1175/JTECH-D-15-0193.1
Processing DOI: 10.5670/oceanog.2014.78
Processing DOI: 10.1016/j.rse.2013.09.009
Processing DOI: 10.5194/bg-6-947-2009
Processing DOI: 10.1029/2010JC006899
Processing DOI: 10.5194/os-11-759-2015
Processing DOI: 10.1364/OE.18.015073
Processing DOI: 10.1016/j.rse.2013.03.025
Processing DOI: 10.13155/39459
Processing DOI: 10.1016/j.ds

Unnamed: 0,article_id,dataset_id_found,doi_found,classification_label
0,10.1002_2017jc013030,10.17882/47142,10.17882/47142,LLM_Disabled
1,10.1002_2017jc013030,10.17882/49388,10.17882/49388,LLM_Disabled
2,10.1002_2017jc013030,10.1002/2017JC013030,10.1002/2017JC013030,LLM_Disabled
3,10.1002_2017jc013030,10.5194/essd-9-861-2017,10.5194/essd-9-861-2017,LLM_Disabled
4,10.1002_anie.201916483,,,Missing
...,...,...,...,...
60,10.1002_ecs2.4619,10.25349/D9QW5X,10.25349/D9QW5X,LLM_Disabled
61,10.1002_ejic.201900904,,,Missing
62,10.1002_ejoc.202000916,,,Missing
63,10.1002_esp.5090,10.1080/08120090802546977,10.1080/08120090802546977,LLM_Disabled


In [None]:

# --- 6. Execution ---
if __name__ == "__main__":
    # Ensure the ARTICLES_DIR exists or adjust path
    if not os.path.isdir(ARTICLES_DIR):
        print(f"Articles directory not found: {ARTICLES_DIR}")
        print("Please create dummy files or point to a valid directory for testing.")
        # Create dummy files for a quick test if ARTICLES_DIR is missing
        # This part is for local testing, remove or adapt for Kaggle
        if ARTICLES_DIR == "/kaggle/input/research-articles-dataset/articles/": # Be careful with this
             print("Cannot create dummy files in /kaggle/input. Please provide data via Kaggle Datasets.")
        else: # Local testing
            os.makedirs(ARTICLES_DIR, exist_ok=True)
            with open(os.path.join(ARTICLES_DIR, "article1.pdf"), "w") as f: f.write("Dummy PDF with DOI 10.1234/foo.bar and dataset created by us.") # Needs actual PDF content
            with open(os.path.join(ARTICLES_DIR, "article2.xml"), "w") as f: f.write("<root><text>Used dataset 10.5678/baz.qux from another study.</text></root>")
    
    print("Starting article processing...")
    df_results = process_articles(ARTICLES_DIR)
    
    print("\n--- Results ---")
    print(df_results.head())
    
    df_results.to_csv(OUTPUT_CSV_PATH, index=False)
    print(f"\nResults saved to {OUTPUT_CSV_PATH}")

    # If you have training data, you can load it here and compare/evaluate
    # Example:
    # if os.path.exists(TRAINING_DATA_PATH):
    #     df_train = pd.read_csv(TRAINING_DATA_PATH)
    #     print("\nTraining Data Head:")
    #     print(df_train.head())
    #     # ... further evaluation logic ...