# IMPORT PACKAGES

In [None]:
import json
import os
import shutil
import csv
from functools import partial
import collections
from collections import Counter , defaultdict
import random
from typing import List, Optional, Dict, Tuple, Union 
from dataclasses import dataclass, field
from pathlib import Path
import unicodedata
import time
import logging

import re 
import difflib

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns


from lxml import etree
import xml.etree.ElementTree as ET    


from nltk.tokenize import sent_tokenize

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics.pairwise import cosine_distances
from sklearn.model_selection import GroupShuffleSplit, StratifiedGroupKFold


from transformers import (
                            AutoTokenizer, AutoModelForTokenClassification,                
                            AutoModelForSequenceClassification, Trainer,                    
                            TrainingArguments, PreTrainedTokenizer, 
                            DataCollatorForTokenClassification, 
                            DataCollatorWithPadding,PreTrainedTokenizer, 
                            PreTrainedModel,EarlyStoppingCallback
                        )
from datasets import Dataset, DatasetDict, ClassLabel
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

import pymupdf
import seqeval
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score,accuracy_score
import fitz  # PyMuPDF

# DOWNLOAD THE DATA FROM KAGGLE

**Download the data from [make-data-count-finding-data-references](https://www.kaggle.com/competitions/make-data-count-finding-data-references/data). Extract the data into a folder called ```data```.**

# CREATE A WORKING DIRECTORY AND ALL NECESSARY FOLDERS

In [None]:
# Project layout: every path below is derived from ROOT_PATH.
ROOT_PATH = "./" # You can customize
WORKING_DIR = Path(ROOT_PATH)
WORKING_DIR.mkdir(parents=True, exist_ok=True)

# Output file written at the end of the pipeline.
SUBMISSION_FILE_PATH = WORKING_DIR / "submission.csv"

# Competition data (extracted into the "data" folder).
DATA_ROOT = WORKING_DIR / "data"
TRAIN_LABELS_PATH = DATA_ROOT / "train_labels.csv"
TRAIN_PDF_FILES = DATA_ROOT / "train" / "PDF"
TRAIN_XML_FILES = DATA_ROOT / "train" / "XML"
TEST_PDF_FILES = DATA_ROOT / "test" / "PDF"
TEST_XML_FILES = DATA_ROOT / "test" / "XML"

# Folder for models, created up front.
MODELS_DIR = WORKING_DIR / "Models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Hugging Face model identifier.
MODEL_NAME = "allenai/scibert_scivocab_uncased"

# EXPLORATORY DATA ANALYSIS

In [None]:
# Display every row when printing a DataFrame
pd.set_option('display.max_rows', None)

# Display every column when printing a DataFrame
pd.set_option('display.max_columns', None)

# Turn on pandas copy-on-write mode for the rest of the notebook
pd.options.mode.copy_on_write = True

In [None]:
train_labels_df = pd.read_csv(TRAIN_LABELS_PATH)
train_labels_df.head()

In [None]:
train_labels_df.article_id.nunique()

In [None]:
train_labels_df.shape[0]

In [None]:
# Count the non-null dataset_id values per article and show the 25 articles with the most.
train_labels_df.groupby("article_id").agg('count').sort_values(by="dataset_id",ascending=False)["dataset_id"][:25]

In [None]:
train_pdf_articles = [art.stem  for art in TRAIN_PDF_FILES.glob("*.pdf")]
train_xml_articles = [art.stem  for art in TRAIN_XML_FILES.glob("*.xml")]

In [None]:
len(train_pdf_articles)

In [None]:
len(train_xml_articles)

In [None]:
train_pdf_articles_df = pd.DataFrame(train_pdf_articles, columns=["article_id"])
train_xml_articles_df = pd.DataFrame(train_xml_articles, columns=["article_id"])

In [None]:
train_pdf_articles_df.head()

In [None]:
train_xml_articles_df.head()

In [None]:
# Missing values for the dataset_id variable
len(train_labels_df[train_labels_df["dataset_id"]=="Missing"])

In [None]:
# Missing values for the type variable
len(train_labels_df[train_labels_df["type"]=="Missing"])

# DATA CLEANING

1. Handle missing values
2. Create an identification format(DOI or ACCESSION ID) for any mention
3. Extract the raw text from the articles
4. Extract all the training mentions from the article in their raw format:
   
   4.1. Regex based extractions
   
   4.2. Manual corrections

In [None]:
train_labels_df.shape[0]

## 1. Handle missing values

In [None]:
# Keep only the label rows whose "type" is not the literal string "Missing",
# then display how many rows remain.
cleaned_train_labels_df = train_labels_df[train_labels_df["type"] != "Missing"]
cleaned_train_labels_df.shape[0]

In [None]:
cleaned_train_labels_df.head()

## 2. Create an identification format(DOI or ACCESSION ID) for any mention

In [None]:
def add_doi_type(mention):
    """
    Tag a dataset identifier as a DOI or as an accession ID.

    Args:
        mention (str): The dataset identifier to classify.

    Returns:
        str: "DOI" when the whole string is a doi.org / dx.doi.org URL or a
        "doi:" prefixed identifier (case-insensitive), "ACC" otherwise.
    """
    doi_regex = re.compile(
        r'(https?://(?:dx\.)?doi\.org/10\.\d{4,9}/[-._;()/:a-zA-Z0-9]+|doi:10\.\d{4,9}/[-._;()/:a-zA-Z0-9]+)',
        flags=re.IGNORECASE,
    )

    if doi_regex.fullmatch(mention) is None:
        return "ACC"
    return "DOI"

cleaned_train_labels_df["dataset_id_type"] = cleaned_train_labels_df.dataset_id.apply(add_doi_type)

cleaned_train_labels_df.head()

In [None]:
dois_labels_df = cleaned_train_labels_df.query("dataset_id_type == 'DOI'")
dois_labels_df.head(10)

In [None]:
acc_labels_df = cleaned_train_labels_df.query("dataset_id_type == 'ACC'")
acc_labels_df.head(10)

In [None]:
acc_labels_df.shape[0]

## 3. Extract the raw text from the articles

In [None]:
class ExtractTextFromArticles:
    """
    Extract raw text from scientific articles stored as XML or PDF files.

    Every extraction result is a dict with two keys:
        - 'file_name': the file name without its extension
        - 'raw_full_text': the extracted text ("" when extraction failed)

    Attributes:
        file_format (str): Lower-cased format of the articles ("xml" or "pdf").
        invisible_chars (Pattern): Characters deleted from PDF text.
        invisible_spaces (Pattern): Space-like characters replaced by a plain
            space in PDF text.
        logger (logging.Logger): Logger used for progress and error messages.
    """

    # XML text keeps only printable ASCII, the line feed and the en dash.
    _XML_DISALLOWED_CHARS = re.compile(r'[^\x20-\x7E\u000A\u2013]')
    # Lines made only of a number followed by a pipe, e.g. "5622  |" or "14 |".
    _PAGE_NUMBER_LINE = re.compile(r"\d+\s*\|")

    def __init__(self, file_format: str):
        """
        Initialize the extractor with a specified file format.

        Args:
            file_format (str): The format of the articles ("xml" or "pdf"),
                in any letter case.
        """
        self.file_format = file_format.lower()

        self.invisible_chars = re.compile(
            r'[\u00A0\u200B\u200C\u200D\u2060\uFEFF\u202F\u00AD\u2003\x00-\x09\x0B-\x0C\x0E-\x1F\x7F\xa0]')
        self.invisible_spaces = re.compile(r'[\u00A0\u2000-\u200A\u202F\u205F]')

        # Configure logging
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        # The logger is shared by every instance: attach a handler only once
        # so messages are not duplicated.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
            self.logger.addHandler(handler)

    def remove_inv_chars(self, text: str) -> str:
        """
        Remove invisible characters and normalize spaces.

        For "pdf", space-like characters become a plain space and the
        characters matched by `invisible_chars` are deleted. For "xml",
        everything except printable ASCII, line feed and en dash is deleted.
        Any other format returns the text unchanged.
        """
        if self.file_format == "pdf":
            text = self.invisible_spaces.sub(' ', text)
            text = self.invisible_chars.sub('', text)
        elif self.file_format == "xml":
            text = self._XML_DISALLOWED_CHARS.sub('', text)
        return text

    def clean_articles(self, articles: List[Dict[str, str]]) -> Dict[str, str]:
        """
        Clean the text of every extracted article.

        Args:
            articles: Extraction results (dicts with 'file_name' and
                'raw_full_text').

        Returns:
            Dict mapping each 'file_name' to its cleaned text.
        """
        return {
            article['file_name']: self.remove_inv_chars(article['raw_full_text'])
            for article in articles
        }

    def extract_all_texts(self, file_paths):
        """
        Extract raw text from all files in the given list of file paths.

        Args:
            file_paths (List[Path]): Paths to XML or PDF files.

        Returns:
            map: Lazy iterator of dicts with 'file_name' and 'raw_full_text'.
            Files are only read when the iterator is consumed.
        """
        self.logger.info("Starting extraction from %d file(s)...", len(file_paths))
        # Any format other than "xml" is handled as PDF.
        if self.file_format == "xml":
            return map(self.from_xml_file, file_paths)
        return map(self.from_pdf_file, file_paths)

    def from_xml_file(self, file_path):
        """
        Extract all text content from a single XML file.

        Args:
            file_path (Path): Path to an XML file.

        Returns:
            dict: A dictionary with:
                - 'file_name': file name without extension
                - 'raw_full_text': all text nodes joined with single spaces
                  ("" if the file could not be parsed)
        """
        try:
            tree = ET.parse(file_path)
            root = tree.getroot()
            text_content = ' '.join(root.itertext())
            self.logger.debug("Extracted XML: %s", file_path.name)
            return {
                "file_name": file_path.stem,
                # Collapse every run of whitespace into one space.
                "raw_full_text": ' '.join(text_content.split())
            }
        except Exception as e:
            # Best effort: one unreadable file must not stop the whole batch.
            self.logger.error("Failed to parse XML: %s — %s", file_path.name, e)
            return {"file_name": file_path.stem, "raw_full_text": ""}

    def from_pdf_file(self, file_path):
        """
        Extract all text content from a single PDF file, removing repetitive headers and footers.

        Args:
            file_path (Path): Path to a PDF file.

        Returns:
            dict: A dictionary with:
                - 'file_name': file name without extension
                - 'raw_full_text': text without repeated headers/footers,
                  pages separated by a form feed ("" if extraction failed)
        """
        try:
            with fitz.open(file_path) as doc:
                line_occurrences = collections.Counter()
                page_lines_list = []

                # Step 1: collect all lines per page and count occurrences
                for page in doc:
                    lines = page.get_text().splitlines()
                    page_lines_list.append(lines)
                    line_occurrences.update(set(lines))  # use set to count once per page

                # Step 2: identify potential headers/footers (appear on many pages)
                min_repeats = max(2, int(len(doc) * 0.7))  # heuristic: appears in >=70% of pages
                common_lines = {line for line, count in line_occurrences.items() if count >= min_repeats}

                # Step 3: drop blank lines, repeated lines and page-number lines
                cleaned_pages = []
                for lines in page_lines_list:
                    cleaned = [
                        line for line in lines
                        if line.strip()
                        and line not in common_lines
                        and not self._PAGE_NUMBER_LINE.fullmatch(line.strip())
                    ]
                    cleaned_pages.append("\n".join(cleaned))

                full_text = "\f".join(cleaned_pages)  # \f = page separator (chr(12))

            self.logger.debug("Extracted and cleaned PDF: %s", file_path.name)
            return {
                "file_name": file_path.stem,
                "raw_full_text": full_text.strip()
            }

        except Exception as e:
            # Best effort: one unreadable file must not stop the whole batch.
            self.logger.error("Failed to extract PDF: %s — %s", file_path.name, e)
            return {"file_name": file_path.stem, "raw_full_text": ""}

    def get_one_path(self, file_name: str, root_path: Path):
        """
        Construct a full file path from a root path and a file name.

        Args:
            file_name (str): The name of the file (e.g., "article1.xml").
            root_path (Path): The directory containing the file.

        Returns:
            Path: Full path to the file.
        """
        return root_path / file_name

    def get_all_paths(self, root_path: Path, cleaned_train_article_ids: Optional[list] = None):
        """
        Retrieve all relevant file paths from a directory, optionally filtering by article ID.

        Args:
            root_path (Path): Directory containing the files.
            cleaned_train_article_ids (list, optional): Article IDs (file
                stems) to keep. When empty or None, every file is returned.

        Returns:
            List[Path]: List of matching file paths.
        """
        ext = "pdf" if self.file_format == "pdf" else "xml"
        all_files = list(root_path.glob(f"*.{ext}"))

        if cleaned_train_article_ids:
            # Build the lookup set once instead of scanning the list per file.
            wanted = set(cleaned_train_article_ids)
            filtered = [self.get_one_path(art.name, root_path)
                        for art in all_files
                        if art.stem in wanted]
            self.logger.info("Filtered %d %s files out of %d based on article IDs.",
                             len(filtered), ext.upper(), len(all_files))
            return filtered

        self.logger.info("Retrieved %d %s file(s).", len(all_files), ext.upper())
        return [self.get_one_path(art.name, root_path) for art in all_files]

    def extract_text_data(self, file_paths):
        """
        Extract all text content from the given files and return as a list of dicts,
        without exporting to any file.

        Args:
            file_paths (List[Path]): List of Path objects to XML or PDF files.

        Returns:
            List[Dict]: List of dicts with 'file_name' and 'raw_full_text' keys
            (empty list when no path is given).
        """
        if not file_paths:
            self.logger.warning("No files provided for extraction. Returning empty list.")
            return []

        self.logger.info("Extracting texts from %d file(s)...", len(file_paths))
        all_texts = list(self.extract_all_texts(file_paths))

        self.logger.info("Extraction complete. %d article(s) extracted.", len(all_texts))
        return all_texts


In [None]:
cleaned_train_article_ids = cleaned_train_labels_df.article_id.drop_duplicates().tolist()

cleaned_train_article_ids[:5]

### XML FILES

In [None]:
xml_extractor = ExtractTextFromArticles(file_format="XML")

In [None]:
# TRAIN 
cleaned_train_xml_paths = [xml_extractor.get_one_path(art.name, TRAIN_XML_FILES)  for art in TRAIN_XML_FILES.glob("*.xml") if art.stem in cleaned_train_article_ids]


In [None]:
# TEST

test_xml_paths = [xml_extractor.get_one_path(art.name, TEST_XML_FILES)  for art in TEST_XML_FILES.glob("*.xml")]

### PDF FILES

In [None]:
pdf_extractor = ExtractTextFromArticles(file_format="PDF")

In [None]:
# TRAIN 
cleaned_train_pdf_paths = [pdf_extractor.get_one_path(art.name, TRAIN_PDF_FILES)  for art in TRAIN_PDF_FILES.glob("*.pdf") if art.stem in cleaned_train_article_ids]



In [None]:
# TEST
test_pdf_paths = [pdf_extractor.get_one_path(art.name, TEST_PDF_FILES)  for art in TEST_PDF_FILES.glob("*.pdf")]



### PDF ONLY ARTICLES

In [None]:
def get_pdf_only(pdf_paths, xml_paths):
    """
    Return the PDF paths that have no XML counterpart.

    Two files are counterparts when they share the same stem (file name
    without extension).

    Args:
        pdf_paths (List[Path]): Paths to the PDF articles.
        xml_paths (List[Path]): Paths to the XML articles.

    Returns:
        List[Path]: The PDF paths without an XML counterpart, in the order
        they appear in `pdf_paths`.
    """
    # Build the lookup set once so each PDF costs a single hash lookup.
    xml_names = {fpath.stem for fpath in xml_paths}

    return [fpath for fpath in pdf_paths if fpath.stem not in xml_names]

In [None]:
cleaned_train_pdf_only = get_pdf_only(cleaned_train_pdf_paths, cleaned_train_xml_paths)

len(cleaned_train_pdf_only)

In [None]:
len(cleaned_train_pdf_paths)

In [None]:
len(cleaned_train_xml_paths)

In [None]:
test_pdf_only = get_pdf_only(test_pdf_paths, test_xml_paths)

len(test_pdf_only)

## 4. Retrieve all the training mentions for each article

### 4.1. Automated extraction : regex-based extraction

In [None]:
class DOIsFormatHandler:
    """
    A handler class to extract, clean, normalize, and deduplicate DOIs (Digital Object Identifiers)
    from raw text (e.g., full-text scientific articles).

    This class is designed to process DOIs from text files with various formatting issues,
    including invisible characters, inconsistent DOI formats, line breaks, and punctuation artifacts.

    Attributes:
        doi_pattern (str): Regex pattern used to match DOIs in text.
        source_file_format (str): Format of the input source file, e.g., "xml".
        invisible_chars (Pattern): Regex to detect invisible control characters.
        invisible_spaces (Pattern): Regex to detect various non-breaking/invisible spaces.
        modern_doi_format_https (str): Regex for HTTPS modern DOI format.
        modern_doi_format_http (str): Regex for HTTP modern DOI format.
        old_doi_format_https (str): Regex for old-style HTTPS DOI format.
        old_doi_format_http (str): Regex for old-style HTTP DOI format.
        short_doi_format (str): Regex for abbreviated DOI formats.
        very_short_doi_format (str): Regex for minimal DOIs with just prefix/suffix.
    """

    def __init__(self, doi_pattern=None, source_file_format="xml", invisible_chars=None, invisible_spaces=None):
        """
        Build a handler, falling back to the default patterns when none is given.

        Args:
            doi_pattern (str, optional): Regex used to find DOIs. It is applied
                with the IGNORECASE and VERBOSE flags.
            source_file_format (str): Format of the source files; stored as is.
            invisible_chars (Pattern, optional): Compiled regex of characters to delete.
            invisible_spaces (Pattern, optional): Compiled regex of characters to
                replace with a plain space.
        """
        # Use default regex pattern for DOIs if none is provided.
        # The pattern tolerates whitespace (including line breaks) between DOI parts.
        self.doi_pattern = doi_pattern or r'''(?ix)
            (?:https?\s*:\s*/\s*/\s*(?:dx\s*\.\s*)?doi\s*\.\s*org\s*/\s*|
               (?:dx\s*\.\s*)?doi\s*\.\s*org\s*/\s*|
               doi\s*:\s*)?
            10\s*\.\s*\d{4,9}
            (?:\s*/\s*[\w\-]+(?:\s*[\./\-]\s*[\w\-]+)*)'''

        # Patterns to match and remove invisible characters and spaces
        self.invisible_chars = invisible_chars or re.compile(
            r'[\u00A0\u200B\u200C\u200D\u2060\uFEFF\u202F\u00AD\u2003\x00-\x09\x0B-\x0C\x0E-\x1F\x7F\xa0]')
        self.invisible_spaces = invisible_spaces or re.compile(r'[\u00A0\u2000-\u200A\u202F\u205F]')

        # DOI format regexes for validation and normalization
        self.modern_doi_format_https = r'https://doi\.org/10\.\d{4,9}/[-._;()/:a-zA-Z0-9]+'
        self.modern_doi_format_http = r'http://doi\.org/10\.\d{4,9}/[-._;()/:a-zA-Z0-9]+'
        self.old_doi_format_https = r'https://dx\.doi\.org/10\.\d{4,9}/[-._;()/:a-zA-Z0-9]+'
        self.old_doi_format_http = r'http://dx\.doi\.org/10\.\d{4,9}/[-._;()/:a-zA-Z0-9]+'

        # Short and very short DOI formats
        self.short_doi_format = r'''(?ix)
            (
                (?:doi:|dx\.doi\.org/|doi\.org/)
                10\.\d{4,9}/[-._;()/:a-zA-Z0-9]+
            )
            \.?'''
        self.very_short_doi_format = r'10\.\d{4,9}/[^\s"<>]+'

        self.source_file_format = source_file_format

    def extract_dois(self, text: str):
        """
        Extracts and cleans all DOIs from the given text.

        Args:
            text (str): Raw article text.

        Returns:
            list[dict]: One dict per regex match, in order of appearance, with:
                - "raw_extraction": the exact regex match
                - "start" / "end": match offsets in the text after invisible
                  characters were removed (not offsets in the original input)
                - "raw_doi": the match after trailing-noise removal
                - "cleaned_doi": "raw_doi" without line breaks, with
                  compatibility characters decomposed and spaces removed
        """
        extracted_dois = []

        # Clean the text of invisible or control characters
        text = self.remove_inv_chars(text)

        # Find all potential DOIs using regex
        for m in re.finditer(self.doi_pattern, text, re.IGNORECASE | re.VERBOSE):
            raw_ext = m.group()

            raw_doi = self.clean_extraction(raw_ext)

            cleaned_doi = self.remove_trailing_spaces(
                self.remove_ligatures(
                    self.remove_line_break(raw_doi)
                )
            )
            extracted_dois.append({
                "raw_extraction": raw_ext,
                "start": m.start(),
                "end": m.end(),
                "raw_doi": raw_doi,
                "cleaned_doi": cleaned_doi
            })

        return extracted_dois

    def contains_no_digits(self, sequence):
        """Check if a given sequence contains no digits."""
        return not re.search(r'\d', sequence)

    def remove_chars_after_period(self, raw_extraction):
        """
        Recursively removes the sentence-like fragment after a period if it contains no digits.
        Helps isolate the DOI from trailing sentence fragments.
        """
        segments = raw_extraction.split(". ")
        if self.contains_no_digits(segments[-1]):
            if len(segments) >= 2:
                to_remove = ". " + segments[-1]
                if raw_extraction.endswith(to_remove):
                    raw_extraction = raw_extraction[: -len(to_remove)].rstrip()
                    # More ". " separators remain: check the new last fragment too.
                    if len(segments) > 2:
                        return self.remove_chars_after_period(raw_extraction)
        return raw_extraction

    def _contains_reference_num(self, sequence):
        """
        Detects reference-like numbering (e.g., '1. Smith') at line endings.
        """
        segments = sequence.replace(" ", "").split('.')
        return len(segments) == 2 and segments[0].isdigit() and self.contains_no_digits(segments[1])

    def remove_reference_num(self, raw_extraction):
        """Removes reference number line endings if detected."""
        segments = raw_extraction.split("\n")
        last_element = segments[-1]
        if self._contains_reference_num(last_element):
            return raw_extraction[: -len("\n" + last_element)].rstrip()
        return raw_extraction

    def remove_chars_after_line_break(self, raw_extraction):
        """
        Removes line-break suffixes that are not DOI content.

        A trailing line is dropped while it contains no digit or starts with
        "GBIF" (case-insensitive).
        """
        while True:
            segments = raw_extraction.split("\n")
            last_element = segments[-1]
            if self.contains_no_digits(last_element) or last_element.upper().startswith("GBIF"):
                if raw_extraction.endswith("\n" + last_element):
                    raw_extraction = raw_extraction[: -len("\n" + last_element)].rstrip()
                else:
                    # Single line left: nothing more to cut.
                    break
            else:
                break
        return raw_extraction

    def clean_very_short_format(self, raw_extraction):
        """
        Handles cleaning of very short DOI patterns that might have extra trailing text.
        """
        raw_doi = self.remove_line_break(raw_extraction).replace(" ", "")
        if re.fullmatch(self.very_short_doi_format, raw_doi):
            segments = re.split(self.very_short_doi_format, raw_extraction)
            if len(segments) == 2 and segments[0] == "":
                if segments[-1]:
                    # NOTE(review): str.rstrip treats its argument as a set of
                    # characters, not as a suffix, so this can also strip
                    # characters that belong to the DOI itself. Behaviour is
                    # kept as is because downstream manual corrections may
                    # depend on the exact output; confirm before changing to
                    # a real suffix removal.
                    raw_extraction = raw_extraction.rstrip(segments[-1])
        return raw_extraction

    def clean_extraction(self, raw_extraction):
        """
        Performs a full cleaning pipeline on a raw DOI extraction.

        The steps run in a fixed order: reference number, trailing lines,
        trailing sentence fragments, very short format, final period.
        """
        raw_extraction = self.remove_reference_num(raw_extraction)
        raw_extraction = self.remove_chars_after_line_break(raw_extraction)
        raw_extraction = self.remove_chars_after_period(raw_extraction)
        raw_extraction = self.clean_very_short_format(raw_extraction)
        return self._remove_last_period(raw_extraction)

    def _remove_last_period(self, raw_doi):
        """Remove trailing period (".") if it exists."""
        return raw_doi[:-1].rstrip() if raw_doi and raw_doi.endswith(".") else raw_doi

    def normalize_doi(self, dois_dict):
        """
        Normalizes DOI to a standard HTTPS format (e.g., https://doi.org/...)

        Args:
            dois_dict (dict): Dict holding at least the key "cleaned_doi".
                It is updated in place.

        Returns:
            dict: The same dict with a "normalized_doi" entry added. The entry
            is "" when "cleaned_doi" is empty, so callers always receive a
            dict and can index "normalized_doi" safely.
        """
        cleaned_doi = dois_dict["cleaned_doi"]

        if not cleaned_doi:
            dois_dict.update({"normalized_doi": ""})
            return dois_dict

        doi = cleaned_doi.strip().replace('\n', '').replace(' ', '')
        # Strip any "doi:" or (dx.)doi.org prefix so only "10.xxxx/..." remains.
        doi = re.sub(r'^(doi:|DOI:)', '', doi, flags=re.IGNORECASE)
        doi = re.sub(r'^(https?://)?(dx\./?)?doi\.org/', '', doi, flags=re.IGNORECASE)

        normalized_doi = f"https://doi.org/{doi}"

        dois_dict.update({"normalized_doi": normalized_doi})
        return dois_dict

    def remove_trailing_spaces(self, raw_doi):
        """Remove all spaces from the DOI string."""
        return re.sub(" ", "", raw_doi)

    def remove_inv_chars(self, text):
        """Remove invisible characters and normalize spaces."""
        text = self.invisible_spaces.sub(' ', text)
        text = self.invisible_chars.sub('', text)
        return text

    def remove_line_break(self, raw_doi):
        """Replace every line break in a DOI string with a space."""
        return re.sub(r"\n", " ", raw_doi)

    def remove_ligatures(self, raw_doi):
        """Normalize ligatures and combined characters."""
        return unicodedata.normalize('NFKD', raw_doi)

    def add_doi_type(self, mention):
        """
        Classifies string as DOI or ACCESSION based on its format.

        Only the modern "https://doi.org/..." form counts as a DOI here.
        """
        return "DOI" if re.fullmatch(self.modern_doi_format_https, mention, flags=re.IGNORECASE) else "ACCESSION"

    def _replace_invisible_spaces(self, text):
        """
        Replace various invisible space-like characters with a normal space.
        """
        invisible_spaces = ['\u00A0', '\u202F', '\u2007', '\u2060', '\u200B']
        for char in invisible_spaces:
            text = text.replace(char, ' ')
        return text

    def remove_duplicates(self, dois_dicts):
        """
        Deduplicates DOI entries based on their normalized DOI string.

        The first occurrence of each normalized DOI is kept, in input order.
        """
        seen = set()
        result = []
        for d in dois_dicts:
            if d["normalized_doi"] not in seen:
                seen.add(d["normalized_doi"])
                result.append(d)
        return result

    def process_article(self, art_dict):
        """
        Main method to process a single article's text and extract all cleaned and normalized DOIs.

        Args:
            art_dict (dict): A dictionary with 'file_name' and 'raw_full_text'.
                A missing or None text is treated as an empty text.

        Returns:
            dict: {"file_name": ..., "extracted_dois": [...]} where the list
            holds the deduplicated DOI dictionaries.
        """
        file_name = art_dict.get("file_name")
        text = art_dict.get("raw_full_text") or ""
        raw_dois = self.extract_dois(text)
        # Entries whose DOI is empty after cleaning carry no identifier: skip them.
        normalized_dois = [d for d in map(self.normalize_doi, raw_dois) if d["normalized_doi"]]
        cleaned_dois = self.remove_duplicates(normalized_dois)
        return {"file_name": file_name, "extracted_dois": cleaned_dois}


#### a. PDF FILES

In [None]:
pdf_dois_handler = DOIsFormatHandler(source_file_format="pdf")

In [None]:
pdf_train_texts = pdf_extractor.extract_text_data(cleaned_train_pdf_paths)

In [None]:
all_dois_from_pdfs=[pdf_dois_handler.process_article(art) for art in pdf_train_texts]

In [None]:
pdf_clean_train_texts = pdf_extractor.clean_articles(pdf_train_texts)

In [None]:
def get_normalized_dois(all_dois, cleaned_train_article_ids):
    """
    Extracts and organizes normalized DOIs from a list of article DOI data.

    For each article in the input, this function:
    - Collects the normalized DOIs (in lowercase) if the article's file name
      is in `cleaned_train_article_ids`.
    - Creates a mapping from each normalized DOI to its corresponding raw DOI
      string (for every article, whether or not it is in
      `cleaned_train_article_ids`).

    Args:
        all_dois (iterable): Dictionaries, each containing:
            - "file_name" (str): The identifier of the article (file stem).
            - "extracted_dois" (list): A list of DOI dictionaries with at least:
                - "normalized_doi" (str): The normalized DOI.
                - "raw_doi" (str): The original, unprocessed DOI string.

        cleaned_train_article_ids (set or list): Article identifiers to
            include in the first output.

    Returns:
        tuple:
            - dict: Keys are file names, values are lists of lowercase
                    normalized DOIs (only for articles in
                    `cleaned_train_article_ids`).
            - dict: Keys are file names, values are dictionaries mapping
                    normalized DOIs (lowercase) to their corresponding raw DOIs.
    """
    # Build the lookup set once so each article costs a single hash lookup.
    wanted_ids = set(cleaned_train_article_ids)

    normalized = {}
    norm_raw_mapping = {}

    # Single pass: also works when `all_dois` is a one-shot iterator.
    for article in all_dois:
        name = article["file_name"]
        dois = article["extracted_dois"]

        if name in wanted_ids:
            normalized[name] = [doi["normalized_doi"].lower() for doi in dois]

        norm_raw_mapping[name] = {
            doi["normalized_doi"].lower(): doi["raw_doi"] for doi in dois
        }

    return normalized, norm_raw_mapping


In [None]:
clean_normalized_dois, norm_raw_mapping  = get_normalized_dois(all_dois_from_pdfs,cleaned_train_article_ids)

In [None]:
def check_identification(article_id, dataset_id, normalized_dois, norm_raw_map):
    """
    Look up a labelled dataset ID among the DOIs extracted from one article.

    An exact match is tried first; otherwise the closest extracted DOI is
    searched with difflib (similarity cutoff 0.8).

    Args:
        article_id (str): Identifier of the article to check.
        dataset_id (str): The labelled DOI/identifier to look for.
        normalized_dois (dict): article_id -> list of normalized DOIs.
        norm_raw_map (dict): article_id -> {normalized DOI -> raw DOI}.

    Returns:
        tuple:
            - bool: True only for an exact match.
            - str or None: The matched normalized DOI (exact or closest), or
              None when nothing matches.
            - str or None: The raw DOI recorded for the matched DOI, or None.
    """
    candidates = normalized_dois.get(article_id, [])

    # Nothing was extracted for this article.
    if not candidates:
        return False, None, None

    # Exact hit.
    if dataset_id in candidates:
        return True, dataset_id, norm_raw_map.get(article_id, {}).get(dataset_id)

    # Fall back to the single closest candidate, if any is similar enough.
    nearest = difflib.get_close_matches(dataset_id, candidates, n=1, cutoff=0.8)
    best = nearest[0] if nearest else None

    if best:
        return False, best, norm_raw_map.get(article_id, {}).get(best)
    return False, best, None


In [None]:
# Work on a copy so dois_labels_df itself is left unchanged.
dois_labels_df2 = dois_labels_df.copy()

# Run check_identification on every label row; the returned 3-tuple is
# expanded into the three new columns (exact-match flag, matched normalized
# DOI, matched raw DOI).
dois_labels_df2[["is_well_identified", "normalized_match", "raw_match"]] = dois_labels_df2.apply(
    
    lambda row: check_identification(row["article_id"], row["dataset_id"],clean_normalized_dois,norm_raw_mapping),
    
    axis=1,
    
    result_type='expand'  
)


In [None]:
dois_labels_df2.head()

#### b. XML FILES

In [None]:
xml_train_texts = xml_extractor.extract_text_data(cleaned_train_xml_paths)

xml_clean_train_texts = xml_extractor.clean_articles(xml_train_texts)

In [None]:
not_identified_dois = dois_labels_df2[dois_labels_df2.is_well_identified == False]

not_identified_dois.shape[0]

In [None]:
not_identified_dois.head()

**The following mentions are present in the corresponding XML files**:
'10.7554_elife.63455'	: 'https://doi.org/10.5061/dryad.37pvmcvj9',
'10.7554_elife.63455' :	'https://doi.org/10.5061/dryad.qnk98sffp',
'10.7554_elife.73695':	'https://doi.org/10.5441/001/1.3q2131q5',
'10.7554_elife.74937'	:'https://doi.org/10.5281/zenodo.6335347'

In [None]:
# Manual overrides for mentions located in the XML files.
# Key: "<article_id>&<dataset_id>"; value: [normalized_match, raw_match, is_well_identified].
# NOTE(review): the first two entries store raw_match without the
# "https://doi.org/" prefix while the last two include it — confirm this
# mirrors how each DOI is written in its XML file.
dois_found_in_xml = {
'10.7554_elife.63455&https://doi.org/10.5061/dryad.37pvmcvj9'	:['https://doi.org/10.5061/dryad.37pvmcvj9',
                                                                 '10.5061/dryad.37pvmcvj9',True],
'10.7554_elife.63455&https://doi.org/10.5061/dryad.qnk98sffp' :['https://doi.org/10.5061/dryad.qnk98sffp',
                                                                '10.5061/dryad.qnk98sffp',True],
'10.7554_elife.73695&https://doi.org/10.5441/001/1.3q2131q5':['https://doi.org/10.5441/001/1.3q2131q5',
                                                             'https://doi.org/10.5441/001/1.3q2131q5',True],
'10.7554_elife.74937&https://doi.org/10.5281/zenodo.6335347':['https://doi.org/10.5281/zenodo.6335347',
                                                                  'https://doi.org/10.5281/zenodo.6335347',True]
}

# (article_id, raw mention) pairs for the same four mentions.
# NOTE(review): not used by the loop below — check whether a later cell needs it.
dois_found_in_xml_tuples =     [
('10.7554_elife.63455','10.5061/dryad.37pvmcvj9'),
('10.7554_elife.63455','10.5061/dryad.qnk98sffp'),
('10.7554_elife.73695','https://doi.org/10.5441/001/1.3q2131q5'),
('10.7554_elife.74937','https://doi.org/10.5281/zenodo.6335347')
    ]

# Write each override into the rows matching both the article and the dataset ID.
for k, v in dois_found_in_xml.items():

    # The key packs both identifiers, separated by "&".
    article_id, dataset_id = k.split("&")
    
    index = dois_labels_df2[(dois_labels_df2.article_id ==article_id) & (dois_labels_df2.dataset_id ==dataset_id)].index

    # The column order here must match the order of the values in v.
    dois_labels_df2.loc[index,['normalized_match',"raw_match", 'is_well_identified']]= v

### 4.2. Manual corrections

In [None]:
not_identified_dois = dois_labels_df2[dois_labels_df2.is_well_identified == False]

not_identified_dois.shape[0]

**There are 42 uncaught DOI mentions.** They fall into several situations:  
1. some mentions do not exist in the articles (PDF format) even though they are given in the training data
2. some mentions are incorrect in the training data: v2 instead of v1, for example
3. some correct mentions are not caught by the regex-based function, for example because of line breaks

In [None]:
not_identified_art = not_identified_dois[not_identified_dois.normalized_match.isnull()].article_id.values.tolist()
not_identified_art

#### 1 => We remove these articles from the training data because they are not found in the corresponding pdf articles. One of them (10.1590_1809-9823.2007.10034) is not in English

Articles to remove :  
['10.1016_j.cpc.2024.109087',
 '10.1021_acs.jcim.9b01185',
 '10.1038_s41558-022-01301-z',
 '10.1111_1365-2656.12594',
 '10.3390_s19030479']

In [None]:
# Articles dropped from the DOI training labels.
# NOTE(review): the markdown list above this cell also names
# '10.1111_1365-2656.12594', which is not included here — confirm whether
# that omission is intentional.
art_to_remove = ['10.1016_j.cpc.2024.109087',
 '10.1021_acs.jcim.9b01185',
 '10.1038_s41558-022-01301-z',
 '10.3390_s19030479']

In [None]:
# 1.
dois_labels_df2 = dois_labels_df2.loc[~dois_labels_df2.article_id.isin(art_to_remove),:]

dois_labels_df2.reset_index(drop=True, inplace=True)

dois_labels_df2.shape[0]

#### 2 => We'll keep the DOIs in the same format in which they are mentioned in the article

'https://doi.org/10.17862/cranfield.rd.19146182' => 'https://doi.org/10.17862/cranfield.rd.19146182.v1'  

'https://doi.org/10.11583/dtu.20555586' => 'https://doi.org/10.11583/dtu.20555586.v2'  

'https://doi.org/10.11583/dtu.20555586.v3' => 'https://doi.org/10.11583/dtu.20555586.v2'  

'https://doi.org/10.23642/usn.15134442'  => 'https://doi.org/10.23642/usn.15134442.v1'  

'https://doi.org/10.5281/zenodo.8014149' => 'https://doi.org/10.5281/zenodo.8014150'  

'https://doi.org/10.25377/sussex.21184705.v1' => 'https://doi.org/10.25377/sussex.21184705'  

'https://doi.org/10.5281/zenodo.1472499' =>'https://doi.org/10.5281/zenodo.1472500'	  

'https://doi.org/10.5281/zenodo.1135371' => 'https://doi.org/10.5281/zenodo.1135372'

'https://doi.org/10.5285/378f0f77-1842-4789-ba15-6fbdf7d02299' => 'http://\ndx.doi.org/10.5061/dryad.gb5mkkwk8'

'https://doi.org/10.5061/dryad.ns1rn8pnt' => 'http://doi.pangaea.de/10.1594/PANGAEA.806957' 


#### 3 => We fix the extractions

In [None]:
# Manual fixes for raw DOI extractions: each key is a raw string produced by
# the regex-based extraction, each value is the raw string as it should have
# been extracted (line breaks are written as "\n").
# A duplicated 'DOI: 10.5281/zenodo.6010342.\n29' entry was removed; a dict
# literal keeps only one value per key, so the mapping is unchanged.
to_change = {
    '10.5256/f1000research. \n13622.d19423315':'10.5256/f1000research. \n13622.d194233',
    'https://doi.org/10.5066/P9VYSWEH.\n64': 'https://doi.org/10.5066/P9VYSWEH',
    'https://doi.org/10.5066/P9H1S1PV.\n34':'https://doi.org/10.5066/P9H1S1PV',
    'https://doi.org/10.5066/P9FZ4OJW.\n54':'https://doi.org/10.5066/P9FZ4OJW',
    'https://doi.org/10.5066/P92BGHW1.\n62':'https://doi.org/10.5066/P92BGHW1',
    'https://doi.org/10.5061/dryad.g1jws': 'https://doi.org/10.5061/dryad.g1jws\ntqsg',
    'https://doi.org/10.17638/datacat.\nliverpool.ac.uk/417.\n37': 'https://doi.org/10.17638/datacat.\nliverpool.ac.uk/417',
    'https://doi.org/10.5281/zenodo.10908593/': 'https://doi.org/10.5281/zenodo.10908593',
    'https://doi.org/10.5281/zenodo.817658.\n856': 'https://doi.org/10.5281/zenodo.817658',
    'DOI: 10.5281/zenodo.6010342.\n29':'DOI: 10.5281/zenodo.6010342',
    'https://doi.org/10.6073/pasta/be42bb841e696b7bca':'https://doi.org/10.6073/pasta/be42bb841e696b7bca\nd9957aed33db5e',
    'https://doi.org/10.5061/dryad':'https://doi.org/10.5061/dryad.\nwpzgmsbps',
    'https://doi.org/10.6073/pasta/\nbb935444378d112d9189556fd22a441d17':'https://doi.org/10.6073/pasta/\nbb935444378d112d9189556fd22a441d',
    'https://\ndoi.org/10.15468/dl.yu2xpvSara':'https://\ndoi.org/10.15468/dl.yu2xpv',
    'https://doi.\norg/10.13012/B2IDB-5784165_V1.reports': 'https://doi.\norg/10.13012/B2IDB-5784165_V1',
    'https://\ndoi .org /10 .17862 /cranﬁeld .rd .19146182 .v1':'https://doi .org /10 .17862 /cranﬁeld .rd .\n19146182 .v1'
    
}

In the '10.3389_fevo.2023.1112519' article, I found the following mentions:  
- 'doi.org/10.1594/PANGAEA.921544' => DOI
- '10.1594/PANGAEA.941237' => DOI
- 'PRJNA780103' => ACCESSION

In [None]:
# Start from the raw string matched by the regex-based extraction...
dois_labels_df2['raw_dataset_id'] = dois_labels_df2["raw_match"]

# ...then apply the manual fixes: in the 'raw_dataset_id' column only, values
# equal to a key of to_change are replaced by the corresponding value.
dois_labels_df2.replace({'raw_dataset_id' : to_change},inplace=True )

In [None]:
# Overwrite two individual rows, addressed by hard-coded index labels.
# NOTE(review): 67 and 136 depend on the row order produced by the earlier
# filtering and reset_index; verify they still point at the intended rows if
# the data or any upstream step changes.
dois_labels_df2.at[67, 'raw_dataset_id'] = 'http://doi.pangaea.de/10.1594/PANGAEA.806957'

dois_labels_df2.at[136, 'raw_dataset_id'] = 'http://\ndx.doi.org/10.5061/dryad.gb5mkkwk8'

In [None]:
def get_normalized_doi(dataset_id_type, dataset_id):
    """Return the normalized form of a DOI; pass any other identifier through.

    For rows typed 'DOI', line breaks are removed and the value is lower-cased
    before being handed to ``pdf_dois_handler.normalize_doi``; the
    'normalized_doi' entry of its result is returned. Any other identifier
    type is returned unchanged.
    """
    if str(dataset_id_type) != 'DOI':
        return dataset_id

    cleaned = str(dataset_id).replace('\n','').lower()
    result = pdf_dois_handler.normalize_doi({'cleaned_doi': cleaned})
    return result['normalized_doi']

    

In [None]:
indices_to_change = dois_labels_df2[dois_labels_df2.article_id == '10.3389_fevo.2023.1112519'].index.tolist()

indices_to_change

In [None]:
# Overwrite the raw mentions of article '10.3389_fevo.2023.1112519' with the strings
# found manually in the article text. The assignment is positional: the list must have
# exactly one entry per index in `indices_to_change`, in the same order.
dois_labels_df2.loc[indices_to_change, 'raw_dataset_id'] = ['https://doi.\npangaea.de/10.1594/PANGAEA.941237',
                                                            'doi.org/10.1594/PANGAEA.921544',
                                                            'PRJNA780103'
                                                           ]

In [None]:
# The last row of that article holds 'PRJNA780103', an accession rather than a DOI.
dois_labels_df2.loc[indices_to_change[-1], 'dataset_id_type'] = 'ACC'

In [None]:
# Row-wise normalization: DOIs are normalized through get_normalized_doi,
# every other identifier type is copied as-is.
dois_labels_df2['normalized_dataset_id'] = dois_labels_df2.apply(
    
    lambda row: get_normalized_doi(row["dataset_id_type"], row["raw_dataset_id"]),
    
    axis=1)


In [None]:
dois_labels_df3 = dois_labels_df2.loc[:, ['article_id','raw_dataset_id','type','dataset_id_type']]

In [None]:
dois_labels_df3.drop_duplicates(ignore_index=True,inplace=True)

In [None]:
dois_labels_df3.dropna(ignore_index=True,inplace=True)

In [None]:
dois_labels_df3.shape[0]

In [None]:
dois_labels_df3[dois_labels_df3.article_id == '10.3389_fevo.2023.1112519']

In [None]:
# Rows of article '10.1107_s2059798322005691' whose raw mentions are replaced by hand.
change_locations = dois_labels_df3[dois_labels_df3.article_id =='10.1107_s2059798322005691'].index.tolist()
# Replacement strings, written as they appear in the text (including line breaks).
# They are assigned positionally to `change_locations`, so the number of active
# (non-commented) entries must equal the number of rows selected above.
new_values_list = [
 'https://doi.org/10.18150/VAOJLJ',
 'https://\ndoi.org/10.18150/R8VJ7V', 
 'https://doi.org/10.18150/\nVAZZ2F',
 # 'https://doi.org/10.18150/T0WC49', 
 'https://doi.org/10.18150/LIENZ5',
 'https://doi.org/10.18150/6OXPLO',
 'https://doi.org/\n10.18150/WEFSC9',
 # 'https://doi.org/10.18150/\nR3BTBM',
 'https://doi.org/10.18150/XSEXUF']

In [None]:
dois_labels_df3.loc[change_locations,'raw_dataset_id']= new_values_list

In [None]:
clean_dois_labels = dois_labels_df3.copy()

clean_dois_labels.rename(columns={'raw_dataset_id':'dataset_id'}, inplace=True)

clean_dois_labels.head()

##### ACCESSIONS

In [None]:
acc_labels_df.head()

In [None]:
def acc_is_well_identified(dataset_id, text):
    """Tell whether an accession ID occurs literally in an article text.

    The search is case-insensitive. The ID is escaped before matching so that
    regex metacharacters it may contain (e.g. '.', '(') are treated as plain
    characters instead of altering the match or raising ``re.error``.

    Args:
        dataset_id: Accession identifier to look for.
        text: Full text of the article; may be None when the article has no text.

    Returns:
        bool: True if the ID is found; False otherwise, including when either
        argument is missing, empty, or not a string.
    """
    if not isinstance(dataset_id, str) or not isinstance(text, str):
        return False
    if not dataset_id:
        # An empty pattern matches everywhere; it is not a real identification.
        return False

    return re.search(re.escape(dataset_id), text, flags=re.IGNORECASE) is not None


def add_acc_well_identified(article_id, dataset_id,texts=pdf_clean_train_texts):
    """Check whether `dataset_id` is found in the text stored for `article_id`.

    The article text is looked up in `texts` (whatever `texts.get` returns for
    the article, including None for an unknown article, is forwarded) and the
    decision is delegated to `acc_is_well_identified`.
    """
    return acc_is_well_identified(dataset_id, texts.get(article_id))


acc_labels_df['is_well_identified'] = acc_labels_df.apply(lambda row : add_acc_well_identified(row['article_id'], row['dataset_id']), axis=1)

In [None]:
acc_labels_df[acc_labels_df.is_well_identified ==False].shape[0], acc_labels_df[acc_labels_df.is_well_identified ==False].article_id.drop_duplicates().tolist()

In [None]:
not_found_accs = acc_labels_df.loc[acc_labels_df.is_well_identified ==False,:]

not_found_accs

**The article '10.1080_21645515.2023.2189598' does not contain any of the extracted mentions, so we are removing it from the training set. We are also removing the unmatched mentions from the two other articles ('10.7554_eLife.63194', '10.7554_eLife.72626').**

**We search for the not found mentions in the xml files.**

In [None]:
# (article_id, dataset_id) pairs for the accession mentions that were not matched
# in the PDF text; these are the candidates looked up in the XML text below.
accs_found_in_xml_tuples = list(zip(not_found_accs['article_id'].tolist(),
                                    not_found_accs['dataset_id'].tolist()))

# Case-insensitive literal search of each mention in the XML text of its article.
# - re.escape: the accession is matched as plain text, not interpreted as a regex.
# - .get(..., ""): an article without XML text counts as "not found" instead of
#   raising KeyError.
check_list = [
    re.search(re.escape(dataset_id), xml_clean_train_texts.get(article_id, ""), flags=re.I) is not None
    for article_id, dataset_id in accs_found_in_xml_tuples
]

all(check_list)

**All 16 mentions are present in the corresponding XML files, so we will keep them in the training data.**

In [None]:
acc_labels_df2 = acc_labels_df.copy()

In [None]:
# Flag the mentions that were missing from the PDF text as well identified.
# NOTE(review): rows are selected by dataset_id only, not by (article_id, dataset_id);
# the same accession appearing in another article would be flagged too - confirm this is intended.
acc_labels_df2.loc[acc_labels_df2.dataset_id.isin(not_found_accs[['dataset_id']].values.flatten().tolist()),'is_well_identified'] =True

In [None]:
clean_acc_labels = acc_labels_df2.loc[:, clean_dois_labels.columns].copy()


clean_acc_labels.head()

### Final labels DataFrame

In [None]:
dataset_id_labels_df = pd.concat([clean_dois_labels, clean_acc_labels],ignore_index=True)

In [None]:
final_labels_df = dataset_id_labels_df.copy()

In [None]:
#final_labels_df = final_labels_df[final_labels_df.article_id!='10.1029_2022gl100473']

In [None]:
final_labels_df.shape[0], dataset_id_labels_df.shape[0]

In [None]:
train_article_ids = dataset_id_labels_df.article_id.drop_duplicates().values.tolist()

len(train_article_ids)

In [None]:
final_labels_df.columns.tolist()

### Add source file

In [None]:
def add_source_file(row):
    """Decide which text source ("xml" or "pdf") a label row should be matched against.

    A row is assigned to "xml" when its (article_id, dataset_id) pair is listed in
    `dois_found_in_xml_tuples` or `accs_found_in_xml_tuples`, or when its article
    is one of the articles hard-coded below; every other row is assigned to "pdf".
    """
    xml_pairs = dois_found_in_xml_tuples + accs_found_in_xml_tuples
    pair = (row['article_id'], row['dataset_id'])

    # Articles always read from their XML version.
    xml_only_articles = (
        '10.1093_nar_gkp1049',
        '10.1021_acsomega.3c06074',
        '10.1038_s41598-020-59839-x',
        '10.3390_microorganisms8121872',
    )

    use_xml = pair in xml_pairs or row['article_id'] in xml_only_articles
    return "xml" if use_xml else "pdf"


In [None]:
dois_found_in_xml_tuples

In [None]:
final_labels_df["source_file"] = final_labels_df.apply(lambda row : add_source_file(row), axis=1)

In [None]:
final_labels_df.head()

### Add mention spans

In [None]:
len(pdf_clean_train_texts)

In [None]:
pdf_clean_texts = pdf_clean_train_texts 

In [None]:
len(pdf_clean_texts)

In [None]:
def add_mention_spans(row):
    """Locate the first literal occurrence of a row's dataset_id in its article text.

    The text is taken from `pdf_clean_texts` when the row's source_file is "pdf"
    and from `xml_clean_train_texts` otherwise; an article missing from the chosen
    dictionary is treated as empty text.

    Returns:
        The (start, end) character span of the first match, or None when the
        mention does not occur in the text.
    """
    # NOTE(review): this search is case-sensitive, whereas the earlier
    # "well identified" checks use re.IGNORECASE - confirm this is intended.
    article_id, dataset_id = row['article_id'], row['dataset_id']

    texts = pdf_clean_texts if row['source_file'] =="pdf" else xml_clean_train_texts
    hit = re.search(re.escape(dataset_id), texts.get(article_id,""))

    return hit.span() if hit else None

In [None]:
final_labels_df[["start","end"]] = final_labels_df.apply(lambda row: add_mention_spans(row), axis=1, result_type="expand")

In [None]:
final_labels_df.head()

# TRAINING DATA PREPARATION

1. Prepare the training data for the NER task (extract all DOIs and accessions from the article text)
2. Prepare the training data for the classification task (get the type of usage of the data)
3. Split the training data into training and validation datasets

In [None]:
# Utility class for training data preparation
@dataclass
class TrainingDataPreparator:
    """Builds token-classification (NER) training examples from article texts.

    For each labelled mention (character span in an article), the article is
    tokenized in full, a token window is cut around the mention, BIO labels are
    attached, and the window is wrapped in the tokenizer's CLS/SEP tokens and
    padded to ``max_seq_len``.
    """
    tokenizer: PreTrainedTokenizer
    # Not used by the methods of this class; kept for callers that set them.
    class_labels: Optional[List[str]] = None
    # Always rebuilt from `ner_labels` in __post_init__ (values passed in are overwritten).
    label2id_ner: Optional[Dict[str, int]] = None
    id2label_ner: Optional[Dict[int, str]] = None
    # Not populated by this class.
    label2id_cls: Optional[Dict[str, int]] = None
    id2label_cls: Optional[Dict[int, str]] = None
    # Total sequence length of each example, special tokens and padding included.
    max_seq_len: int = 512
    ner_labels: List[str] = field(default_factory=lambda: ['O', 'B-DOI', 'I-DOI', 'B-ACC', 'I-ACC'])

    def __post_init__(self):
        """Derive the NER label <-> id mappings from `ner_labels`."""
        self.label2id_ner = {label: idx for idx, label in enumerate(self.ner_labels)}
        self.id2label_ner = {idx: label for label, idx in self.label2id_ner.items()}

    def tokenize_full_text(self, text: str, add_special_tokens: bool = True):
        """Tokenize a whole text without truncation.

        Args:
            text: Text to tokenize.
            add_special_tokens: Forwarded to the tokenizer. The default (True)
                preserves the previous behaviour of this method; pass False to
                obtain only the tokens of the text itself.

        Returns:
            Tuple ``(input_ids, offset_mapping, tokens)`` where ``tokens`` are
            the string tokens corresponding to ``input_ids``.
        """
        encoding = self.tokenizer(
            text,
            add_special_tokens=add_special_tokens,
            return_offsets_mapping=True,
            return_attention_mask=False,
            return_token_type_ids=False,
            truncation=False
        )
        input_ids = encoding["input_ids"]
        return input_ids, encoding["offset_mapping"], self.tokenizer.convert_ids_to_tokens(input_ids)

    def get_tokens_for_mention(
        self,
        offsets: List[Tuple[int, int]],
        mention_start: int,
        mention_end: int,
        allow_partial_overlap: bool = True
    ) -> List[int]:
        """Return the indices of the tokens covering a character span.

        Args:
            offsets: (start, end) character offsets of each token.
            mention_start: Start character of the mention.
            mention_end: End character of the mention (exclusive).
            allow_partial_overlap: If True, keep every token that overlaps the
                span; if False, keep only tokens lying entirely inside it.

        Returns:
            Token indices in increasing order (empty if nothing matches).
        """
        if allow_partial_overlap:
            return [
                idx for idx, (tok_start, tok_end) in enumerate(offsets)
                if tok_end > mention_start and tok_start < mention_end
            ]
        return [
            idx for idx, (tok_start, tok_end) in enumerate(offsets)
            if tok_start >= mention_start and tok_end <= mention_end
        ]

    def extract_token_windows_around_mentions(
        self,
        input_ids: List[int],
        offsets: List[Tuple[int, int]],
        mentions: List[Dict[str, Union[str, int]]],
        window_context_size: int = 100
    ) -> List[Dict]:
        """Cut one token window around each mention.

        Each window holds the mention tokens plus up to `window_context_size`
        tokens on either side, and never exceeds ``max_seq_len - 2`` tokens
        (two positions are reserved for the CLS/SEP tokens added later).

        Args:
            input_ids: Token ids of the full text.
            offsets: Character offsets aligned with `input_ids`.
            mentions: Dicts with at least 'start' and 'end' character positions.
            window_context_size: Context tokens requested on each side.

        Returns:
            One dict per mention that maps to at least one token, with keys
            'input_ids', 'offsets', 'mention', 'window_start_idx' and
            'window_end_idx' (end exclusive).
        """
        windows = []
        max_len = self.max_seq_len - 2

        for mention in mentions:
            token_indices = self.get_tokens_for_mention(
                offsets, mention['start'], mention['end'], allow_partial_overlap=True
            )
            if not token_indices:
                continue

            mention_start_idx = token_indices[0]
            mention_end_idx = token_indices[-1]

            window_start = max(0, mention_start_idx - window_context_size)
            window_end = min(len(input_ids), mention_end_idx + 1 + window_context_size)

            # Shrink the window to max_len by trimming context only. Each side gives
            # up at most the context it actually has, so the mention is never cut
            # while context remains (this matters when the mention sits near a
            # document edge and the context is asymmetric).
            excess = (window_end - window_start) - max_len
            if excess > 0:
                context_before = mention_start_idx - window_start
                context_after = window_end - (mention_end_idx + 1)

                remove_before = min(excess // 2, context_before)
                remove_after = min(excess - remove_before, context_after)
                # If the right side ran out of context, take the rest from the left.
                remove_before = min(excess - remove_after, context_before)

                window_start += remove_before
                # Whatever is still in excess means the mention alone is longer than
                # max_len: its tail is truncated.
                window_end -= excess - remove_before

            windows.append({
                'input_ids': input_ids[window_start:window_end],
                'offsets': offsets[window_start:window_end],
                'mention': mention,
                'window_start_idx': window_start,
                'window_end_idx': window_end
            })

        return windows

    def biofy_labels(self, windows: List[Dict], mentions: List[Dict]) -> List[List[int]]:
        """Produce BIO label ids for every window.

        Args:
            windows: Windows as returned by `extract_token_windows_around_mentions`.
            mentions: Dicts with 'start', 'end' and 'dataset_id_type'.

        Returns:
            One list of label ids per window, aligned with the window tokens.
            Labels not present in `label2id_ner` fall back to 'O'.
        """
        outside_id = self.label2id_ner['O']
        all_labels = []
        for w in windows:
            token_spans = w['offsets']
            labels = ['O'] * len(w['input_ids'])

            # Every mention is tested against the window: a mention that does not
            # overlap it simply yields no token. (No pre-filter on the first/last
            # offsets: that check fails on empty windows and on windows ending with
            # a zero-width special-token offset.)
            for mention in mentions:
                token_indices = self.get_tokens_for_mention(token_spans, mention['start'], mention['end'])
                if not token_indices:
                    continue
                labels[token_indices[0]] = f"B-{mention['dataset_id_type']}"
                for idx in token_indices[1:]:
                    labels[idx] = f"I-{mention['dataset_id_type']}"

            all_labels.append([self.label2id_ner.get(label, outside_id) for label in labels])
        return all_labels

    def prepare_dataset(
        self,
        df: pd.DataFrame,
        text_dicts: Dict[str, str],
        entity_col: str = "dataset_id",
        type_col: str = "dataset_id_type",
        article_id_col: str = "article_id",
        start_col: str = "start",
        end_col: str = "end",
        window_context_size: int = 50,
        source_col: str = "source_file"
    ) -> Dataset:
        """Build the padded NER dataset, one example per located mention.

        Args:
            df: Label table with article id, source, mention span and type columns.
            text_dicts: Maps a source name (value of `source_col`) to a
                ``{article_id: text}`` dictionary.
            entity_col: Unused; kept for interface compatibility.
            type_col: Column holding the mention type used in the BIO tags.
            article_id_col: Column holding the article id.
            start_col: Column holding the mention start character.
            end_col: Column holding the mention end character.
            window_context_size: Context tokens kept on each side of a mention.
            source_col: Column naming the text source of each row.

        Returns:
            Dataset with columns 'input_ids', 'attention_mask', 'labels'
            (-100 on special and padding positions) and 'article_id'.
        """
        input_ids_all = []
        attention_mask_all = []
        labels_all = []
        article_ids_all = []

        outside_id = self.label2id_ner['O']

        for (article_id, source_file), group in df.groupby([article_id_col, source_col]):
            text_dict = text_dicts.get(source_file)
            if not text_dict:
                continue

            text = text_dict.get(article_id)
            if not text:
                continue

            # Tokenize WITHOUT special tokens: CLS/SEP are added once per window
            # below. Tokenizing with them would duplicate CLS/SEP in windows that
            # touch the beginning or the end of the document.
            input_ids, offsets, _ = self.tokenize_full_text(text, add_special_tokens=False)

            # Keep only mentions with a usable span (missing or negative
            # positions mark mentions that were not located in the text).
            mentions = [
                {
                    "start": int(row[start_col]),
                    "end": int(row[end_col]),
                    "dataset_id_type": str(row[type_col])
                }
                for _, row in group.iterrows()
                if pd.notna(row[start_col]) and pd.notna(row[end_col])
                and row[start_col] >= 0 and row[end_col] >= 0
            ]

            # BIO tags for the whole article; windows are sliced from it so that
            # every mention falling inside a window is labelled, not only the one
            # the window was built around.
            full_labels = ['O'] * len(input_ids)
            for mention in mentions:
                token_indices = self.get_tokens_for_mention(offsets, mention['start'], mention['end'])
                if not token_indices:
                    continue
                full_labels[token_indices[0]] = f"B-{mention['dataset_id_type']}"
                for idx in token_indices[1:]:
                    full_labels[idx] = f"I-{mention['dataset_id_type']}"

            windows = self.extract_token_windows_around_mentions(
                input_ids, offsets, mentions, window_context_size=window_context_size
            )

            for w in windows:
                label_window_raw = full_labels[w['window_start_idx']:w['window_end_idx']]
                label_ids = [self.label2id_ner.get(lab, outside_id) for lab in label_window_raw]

                input_ids_window = [self.tokenizer.cls_token_id] + w['input_ids'] + [self.tokenizer.sep_token_id]
                label_window = [-100] + label_ids + [-100]
                attention_mask = [1] * len(input_ids_window)

                # Pad (or, defensively, truncate) to exactly max_seq_len.
                pad_len = self.max_seq_len - len(input_ids_window)
                if pad_len > 0:
                    input_ids_window += [self.tokenizer.pad_token_id] * pad_len
                    label_window += [-100] * pad_len
                    attention_mask += [0] * pad_len
                else:
                    input_ids_window = input_ids_window[:self.max_seq_len]
                    label_window = label_window[:self.max_seq_len]
                    attention_mask = attention_mask[:self.max_seq_len]

                input_ids_all.append(input_ids_window)
                labels_all.append(label_window)
                attention_mask_all.append(attention_mask)
                article_ids_all.append(article_id)

        return Dataset.from_dict({
            "input_ids": input_ids_all,
            "attention_mask": attention_mask_all,
            "labels": labels_all,
            "article_id": article_ids_all
        })

    def add_start_end_positions(
        self,
        df: pd.DataFrame,
        text_dict: Dict[str, str],
        entity_col: str = 'dataset_id',
        article_id_col: str = 'article_id'
    ) -> pd.DataFrame:
        """Return a copy of `df` with 'start'/'end' character spans of each mention.

        Within an article, each row receives the first occurrence of its mention
        that has not already been assigned to another row. Rows whose mention is
        not found (or whose article has no text) keep -1 in both columns.

        Args:
            df: Label table.
            text_dict: ``{article_id: text}`` dictionary.
            entity_col: Column holding the mention string.
            article_id_col: Column holding the article id.

        Returns:
            A new DataFrame; the input is left untouched.
        """
        df = df.copy()
        df['start'] = -1
        df['end'] = -1

        for article_id, group in df.groupby(article_id_col):
            text = text_dict.get(article_id, "")
            if not text:
                continue

            used_spans = set()
            for idx, row in group.iterrows():
                mention = str(row[entity_col]).strip()
                mention_escaped = re.escape(mention)

                found = False
                for match in re.finditer(mention_escaped, text):
                    span = (match.start(), match.end())
                    if span not in used_spans:
                        df.at[idx, 'start'] = span[0]
                        df.at[idx, 'end'] = span[1]
                        used_spans.add(span)
                        found = True
                        break

                if not found:
                    print(f"Mention '{mention}' not found or exhausted in article {article_id}")

        return df

# NER MODEL TRAINING AND EVALUATION

In [None]:
# Labels (with character spans) used to build the NER dataset.
df_with_spans = final_labels_df.copy()

# Exclude article '10.1029_2022gl100473' from the NER training data.
df_with_spans = df_with_spans[df_with_spans.article_id != '10.1029_2022gl100473']

# Text lookup per source: keys match the values of the 'source_file' column.
text_dicts = {"pdf":pdf_clean_texts,"xml":xml_clean_train_texts}

In [None]:
df_with_spans.shape[0]

In [None]:
# 1. Load the tokenizer and the token-classification model named by MODEL_NAME
#    (SciBERT, per the original note)
ner_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

preparator = TrainingDataPreparator(tokenizer=ner_tokenizer)

ner_model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(preparator.ner_labels),
    id2label=preparator.id2label_ner,
    label2id=preparator.label2id_ner
)

# 2. Add start and end of mentions (disabled: df_with_spans already has 'start'/'end')
# df_with_spans = preparator.add_start_end_positions(df_mentions, text_dict)

# 3. Prepare the dataset
full_dataset = preparator.prepare_dataset(df_with_spans, text_dicts)

# 4. Split train/validation by article, so that windows from the same article
#    never end up on both sides of the split
# Ensure the 'article_id' column exists in the dataset
assert "article_id" in full_dataset.column_names, "'article_id' column must be present in the dataset"

# Extract article IDs as a NumPy array
article_ids = np.array(full_dataset["article_id"])

# Instantiate the GroupShuffleSplit with fixed random seed
gss = GroupShuffleSplit(test_size=0.25, random_state=42)

# Perform the split using article_id as grouping variable (only the first split is used)
train_indices, val_indices = next(gss.split(np.arange(len(article_ids)), groups=article_ids))

# Select subsets from the full dataset using the computed indices
ner_train_dataset = full_dataset.select(train_indices)
ner_val_dataset = full_dataset.select(val_indices)

# Previous alternative, kept for reference: random split that ignores article grouping
# dataset_split = full_dataset.train_test_split(test_size=0.25, seed=42)
# ner_train_dataset = dataset_split["train"]
# ner_val_dataset = dataset_split["test"]

# 5. Define metrics

def ner_compute_metrics(p):
    """Compute entity-level precision/recall/F1 and token accuracy for the Trainer.

    `p` is a (predictions, labels) pair. Predictions are arg-maxed over the last
    axis, positions whose gold label is -100 are dropped, and the remaining ids
    are mapped to tag strings through `preparator.id2label_ner` before scoring.
    """
    logits, gold_ids = p
    predicted_ids = np.argmax(logits, axis=2)
    id2label = preparator.id2label_ner

    true_labels = []
    for gold_row in gold_ids:
        true_labels.append([id2label[g] for g in gold_row if g != -100])

    true_preds = []
    for pred_row, gold_row in zip(predicted_ids, gold_ids):
        kept = [q for q, g in zip(pred_row, gold_row) if g != -100]
        true_preds.append([id2label[q] for q in kept])

    scores = {}
    scores["precision"] = precision_score(true_labels, true_preds)
    scores["recall"] = recall_score(true_labels, true_preds)
    scores["f1"] = f1_score(true_labels, true_preds)
    scores["accuracy"] = accuracy_score(true_labels, true_preds)
    return scores
# 6. Training configuration
ner_output_dir = MODELS_DIR/'ner_model'
# Evaluation and checkpointing both happen once per epoch; the best checkpoint,
# selected on the "f1" value returned by ner_compute_metrics, is reloaded at the end.
ner_training_args = TrainingArguments(
    output_dir = ner_output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=20,
    weight_decay=0.01,
    logging_dir="./logs",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    save_total_limit=3,
    report_to="none", 
)

# Early stopping is configured with a patience of 3.
ner_trainer = Trainer(
    model=ner_model,
    args=ner_training_args,
    train_dataset=ner_train_dataset,
    eval_dataset=ner_val_dataset,
    processing_class=ner_tokenizer,
    compute_metrics=ner_compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

In [None]:
len(ner_train_dataset)

In [None]:
len(ner_val_dataset)

In [None]:
from collections import Counter
all_labels = [label for sequence in full_dataset['labels'] for label in sequence if label != -100]
Counter(all_labels)

In [None]:
ner_trainer.train()

In [None]:
# Evaluate the model on validation data
eval_results = ner_trainer.predict(ner_trainer.eval_dataset)

predictions = eval_results.predictions  
labels = eval_results.label_ids         


In [None]:
def compute_metrics_no_O(p, id2label):
    """Score predictions only at positions whose gold tag is not 'O'.

    `p` is a (predictions, labels) pair and `id2label` maps label ids to tag
    strings. Positions with gold label -100 are dropped first; then, inside each
    sequence, only positions whose gold tag differs from "O" are kept, and
    sequences left empty are discarded. The seqeval classification report is
    printed and precision, recall and F1 are returned.

    NOTE(review): predictions made on gold-'O' positions are excluded from the
    score, so spurious detections there are not counted - confirm that this is
    the intended reading of these metrics.
    """
    logits, gold_ids = p
    predicted_ids = np.argmax(logits, axis=2)

    # Map ids to tag strings, ignoring -100 positions
    gold_tags = [[id2label[g] for g in gold_row if g != -100] for gold_row in gold_ids]
    pred_tags = [
        [id2label[q] for q, g in zip(pred_row, gold_row) if g != -100]
        for pred_row, gold_row in zip(predicted_ids, gold_ids)
    ]

    filtered_labels = []
    filtered_preds = []
    for gold_seq, pred_seq in zip(gold_tags, pred_tags):
        kept_pairs = [(g, q) for g, q in zip(gold_seq, pred_seq) if g != "O"]
        if not kept_pairs:
            continue
        filtered_labels.append([g for g, _ in kept_pairs])
        filtered_preds.append([q for _, q in kept_pairs])

    # Text report
    print(classification_report(filtered_labels, filtered_preds))

    # Numeric metrics
    return {
        "precision_no_O": precision_score(filtered_labels, filtered_preds),
        "recall_no_O": recall_score(filtered_labels, filtered_preds),
        "f1_no_O": f1_score(filtered_labels, filtered_preds),
    }


results = compute_metrics_no_O((predictions, labels), preparator.id2label_ner)
print(results)


In [None]:
# Plotting the confusion matrix on validation data
# metric = evaluate.load("seqeval")
# Run the trained model on the validation windows (this repeats the prediction pass).
predictions_output = ner_trainer.predict(ner_val_dataset)
logits = predictions_output.predictions
labels = predictions_output.label_ids


# Predicted label id of each token = arg-max over the last axis
preds = np.argmax(logits, axis=2)


# Convert ids to tag strings, dropping the positions labelled -100
true_labels = [
    [preparator.id2label_ner[label] for label in label_seq if label != -100]
    for label_seq in labels
]
true_preds = [
    [preparator.id2label_ner[pred] for pred, label in zip(pred_seq, label_seq) if label != -100]
    for pred_seq, label_seq in zip(preds, labels)
]


# Flatten to token level
flat_preds = [p for seq in true_preds for p in seq]
flat_labels = [l for seq in true_labels for l in seq]

# Keep only the tokens whose gold tag is not 'O' (their predictions may still be 'O')
filtered_preds = [p for p, l in zip(flat_preds, flat_labels) if l != 'O']
filtered_labels = [l for l in flat_labels if l != 'O']

# Axis order of the matrix: every tag seen in either the gold or the predicted tokens
labels_unique = sorted(list(set(filtered_labels + filtered_preds)))

cm = confusion_matrix(filtered_labels, filtered_preds, labels=labels_unique)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", xticklabels=labels_unique, yticklabels=labels_unique, cmap="Blues")
plt.xlabel("Predictions")
plt.ylabel("Actual values")
plt.title("Confusion Matrix (without 'O')")
plt.tight_layout()
plt.show()


# POST-PROCESSING AND NER INFERENCE  

### Building heuristic rules to distinguish dataset mentions from article references

In [None]:
# Extraction of prefixes and suffixes representing data repositories from the training set
def get_dois_prefixes_suffixes(raw_doi: str) -> Tuple[str, str]:
    """Split a raw DOI mention into its registrant prefix and repository keyword.

    The mention is stripped of whitespace and line breaks, then of a leading
    'doi:' marker and of a leading '(https://)(dx.)doi.org/' resolver. What
    remains is split on '/':

    - no '/' at all: not a DOI, ('', '') is returned;
    - exactly two parts ('10.5281/zenodo.123'): the prefix is the first part and
      the suffix is the text before the first '.' of the second part, or '' when
      the second part contains no '.' or when the prefix is '10.17632';
    - more than two parts: the first two parts are returned as they are (for a
      URL with an unrecognised resolver this yields a prefix such as 'http:',
      which callers are expected to discard).

    Args:
        raw_doi: DOI mention as found in the text (may contain line breaks).

    Returns:
        Tuple[str, str]: (prefix, suffix); always a 2-tuple, ('', '') for empty input.
    """
    if not raw_doi:
        # Always a pair, so that callers can unpack the result.
        return "", ""

    doi = raw_doi.strip().replace('\n', '').replace(' ', '')
    doi = doi.replace('https://dx/','')
    doi = re.sub(r'^(doi:|DOI:)', '', doi, flags=re.IGNORECASE)
    doi = re.sub(r'^(https?://)?(dx\.)?doi\.org/', '', doi, flags=re.IGNORECASE)

    parts = doi.split('/')
    if len(parts) < 2:
        return "", ""

    prefix = parts[0]
    if len(parts) > 2:
        return prefix, parts[1]

    name_parts = parts[1].split('.')
    if len(name_parts) < 2 or prefix == '10.17632':
        return prefix, ''
    return prefix, name_parts[0]

In [None]:
# All DOI mentions of the final label table
dois = final_labels_df.loc[final_labels_df.dataset_id_type == 'DOI', 'dataset_id'].tolist()

pre_suffixes = list(map(get_dois_prefixes_suffixes, dois))

# Unique, non-empty prefixes and suffixes, in order of first appearance.
# 'http:' / 'https:' are URL leftovers, not DOI prefixes, and are skipped.
prefixes, suffixes = [], []

for pre_suffix in pre_suffixes:
    pre, suf = pre_suffix

    if pre and pre not in ('http:', 'https:') and pre not in prefixes:
        prefixes.append(pre)

    if suf and suf not in suffixes:
        suffixes.append(suf)

len(prefixes), len(suffixes)

**By cleaning the prefixes and suffixes, and grouping them, we obtain:**
['10.5281','10.5061','10.7937','10.23728', '10.6084','10.34740', '10.17632', '10.7910',
 '10.5072','10.15125','10.5255', '10.5257', '10.1594', '10.17605', '10.3886', '10.25337',
 '10.34739', '10.5878', '10.15468', '10.5067', '10.5066', '10.5291', '10.17638', '10.2305',
 '10.7289','10.17862', '10.11583', '10.23642', '10.5285', '10.5441',  '10.17882', '10.25349', 
 '10.6075', '10.6073', '10.17863', '10.11588', '10.21942', '10.25377', '10.3334', '10.13020', 
 '10.25326', '10.25921', '10.34973', '10.15482', '10.18150', '10.15131', '10.4121', '10.5256', 
 '10.13012', '10.25387', '10.25386', '10.5518', '10.6078', '10.15485', '10.25422', '10.22033',
 '10.6096', '10.18434', '10.24381', '10.7291',
 'zenodo', 'dryad', 'figshare', 'mendeley', 'dataverse', 'ICPSR', 'ukda-sn', 'pangaea', 'snd', 
 'ILL-DATA', 'datacat', 'CDIAC', 'IUCN.UK', 'GPM', 'IMMERGDF', 'OTG', 'sussex', 'CHEMBL', 'TCIA',
 'cranfield', 'dtu', 'usn', 'K9', 'cranﬁeld', 'cranfield','pasta', 'CAM', 'data', 'DVN', 'uva', 'dl',
 'USDA.ADC', 'shef.data', 'ICPSR', 'datacat.liverpool.ac.uk', 'f1000research', 'g3',
 'genetics', 'MODIS', 'azu', 'ESGF', 'Suborbital', 'AEROCLO', 'cds',
]
**I also added a few suffixes not found in the training data.**

In [None]:
def extract_alpha_prefix(acc: str) -> Optional[str]:
    """
    Extracts the alphabetic prefix from an accession string.
    Keeps underscores and hyphens that appear before the first digit.
    Discards prefixes with fewer than 2 characters once leading/trailing
    underscores and hyphens are ignored.

    Examples:
        - 'GSE12345'       -> 'GSE'
        - 'EPI_ISL_291131' -> 'EPI_ISL_'
        - 'E-PROT-100'     -> 'E-PROT-'
        - 'A1'             -> None (too short)
        - 'ABC'            -> None (no digit after the prefix)

    Args:
        acc (str): Accession string

    Returns:
        Optional[str]: Extracted prefix, or None if not valid
    """
    # The prefix is the leading run of letters, '_' and '-', which must be
    # immediately followed by a digit. A single character class captures the same
    # prefix as a pattern with nested groups over those characters, without the
    # ambiguous alternatives that only add backtracking.
    match = re.match(r'^([A-Za-z_-]+)\d', acc)
    if match is None:
        return None

    prefix = match.group(1)
    if len(prefix.strip("_-")) < 2:
        return None
    return prefix


def collect_acc_prefixes(acc_list: List[str]) -> set:
    """
    Extracts the unique alphabetic prefixes from a list of accession strings.

    Args:
        acc_list (List[str]): List of accession strings.

    Returns:
        set: Set of unique prefixes (unordered, no duplicates). Accessions for
        which `extract_alpha_prefix` returns None contribute nothing.
    """
    # Each accession is parsed once; None results are dropped.
    return {
        prefix
        for prefix in map(extract_alpha_prefix, acc_list)
        if prefix is not None
    }

accs = acc_labels_df.dataset_id.tolist()
acc_prefixes = collect_acc_prefixes(accs)

**List of ACC prefixes** :['BX', 'CAB', 'CHEMBL', 'CP', 'CVCL_','CVCL_D',
 'CVCL_KS', 'E-GEOD-', 'E-PROT-', 'EMPIAR-', 'ENSBTAG', 'ENSMMUT', 'ENSOARG',
 'EPI', 'EPI_ISL_', 'ERR', 'GSE', 'HPA', 'IPR', 'KX', 'MODEL', 'NC_',
 'NM_', 'PF', 'PRJNA', 'PXD', 'SAMN', 'SRP', 'SRR', 'SRX', 'STH', 'rs']

In [None]:
#A class that filters the most plausible data mentions.
class DatasetMentionFilter:
    """
    Filters extracted mentions to retain only those likely referring to datasets.
    Applies confidence threshold, deduplication, context extraction, and format-specific heuristics
    (DOI or ACC), including support for detection of mentions embedded in tables.
    """

    def __init__(self, confidence_threshold: float = 0.90, 
                file_format="pdf",
                dataset_indicators: List[str] = None,
                dataset_keywords: List[str] = None,
                context_window_chars: int = 300,
                ):
        """
        Initializes the filter.

        :param confidence_threshold: Minimum confidence required to keep a mention.
        :param file_format: Name of the source file format; stored lowercased.
        :param dataset_indicators: DOI prefixes / repository names searched as
            case-insensitive substrings of a mention. Defaults to the built-in list below.
        :param dataset_keywords: Phrases searched in the context of a mention.
            Defaults to the built-in list below.
        :param context_window_chars: Number of characters taken on each side of a
            mention by get_context_window_around_chars.
        """
        self.confidence_threshold = confidence_threshold
        self.file_format = file_format.lower()
        self.context_window_chars = context_window_chars

        # Accession prefixes accepted by has_valid_acc_prefix (case-sensitive startswith).
        # NOTE: a few entries are repeated (e.g. 'CP', 'FM', 'KT'); this is harmless
        # for a prefix check.
        self.acc_valid_prefixes = [
        'BX', 'CAB', 'CHEMBL', 'CP', 'CVCL_', 'CVCL_D', 'CVCL_KS', 'E-GEOD-', 'E-PROT-', 'EMPIAR-',
        'ENSBTAG', 'ENSMMUT', 'ENSOARG', 'EPI', 'EPI_ISL_', 'ERR', 'GSE', 'HPA', 'IPR', 'KX',"GCA",
        'MODEL', 'NC_', 'NM_', 'PF','PRJNA','PXD', 'SAMN', 'SRP', 'SRR', 'SRX', 'STH', 'rs',
        'KT','MN','JN', 'JX', 'FM', 'KC','KP','AY', 'AF','AM','PNQY','FM','EU','KM','GU','GQ','CP',
        'FJ','HQ','DQ','JQ','MG','MSV','PRJEB','KT','MF','MH','K0','HGNC','LT','LC','Be',
    ]

        # Known DOI prefixes and keywords that indicate datasets
        # NOTE(review): entries are matched as plain substrings, so short ones
        # such as 'data', 'dl' or 'cds' match very broadly -- confirm this is intended.
        self.dataset_indicators = dataset_indicators or ['10.5281','10.5061','10.7937','10.23728', '10.6084','10.34740', '10.17632', '10.7910',
             '10.5072','10.15125','10.5255', '10.5257', '10.1594', '10.17605', '10.3886', '10.25337',
             '10.34739', '10.5878', '10.15468', '10.5067', '10.5066', '10.5291', '10.17638', '10.2305',
             '10.7289','10.17862', '10.11583', '10.23642', '10.5285', '10.5441',  '10.17882', '10.25349', 
             '10.6075', '10.6073', '10.17863', '10.11588', '10.21942', '10.25377', '10.3334', '10.13020', 
             '10.25326', '10.25921', '10.34973', '10.15482', '10.18150', '10.15131', '10.4121', '10.5256', 
             '10.13012', '10.25387', '10.25386', '10.5518', '10.6078', '10.15485', '10.25422', '10.22033',
             '10.6096', '10.18434', '10.24381', '10.7291', '10.5065','zenodo', 'dryad', 'figshare', 'mendeley', 
             'dataverse', 'ICPSR', 'ukda-sn', 'pangaea', 'snd', 'ILL-DATA', 'datacat', 'CDIAC', 'IUCN.UK', 
             'GPM', 'IMMERGDF', 'OTG', 'sussex', 'CHEMBL', 'TCIA', 'cranfield', 'dtu', 'usn', 'K9', 
             'cranﬁeld', 'pasta', 'CAM', 'data', 'DVN', 'uva', 'dl','USDA.ADC', 'shef.data', 'ICPSR', 
             'datacat.liverpool.ac.uk',  'MODIS', 'azu', 'ESGF',  'Suborbital', 'AEROCLO', 'cds',
             'ccdc','uhhfdm', 'hepdata', 'wdcc', 'm9', 'rodare', '10.5517', '10.5061', '10.5441',
             '10.25592','10.17182', '10.1594', '10.6084', '10.11583',
             '10.15468', '10.5878','10.14278', '10.26093','10.13155', #'f1000research', 'g3','genetics',
        ]

        # Phrases that suggest a data statement; matched case-insensitively
        # by contains_keywords.
        # NOTE(review): "Acession numbers" looks like a typo for "Accession numbers";
        # confirm before changing, since this list drives matching.
        self.dataset_keywords = dataset_keywords or [
             "data availability statement","These data are provided", 
            "deposition number",'data accessibility',"Associated data for this publication",
            "data available for download from", "deposition numbers",
            "data downloaded from","Data Repository", "Data used in this",  "ccdc",
            "Data referring to","Data deposition", "Acession numbers","Accession no",
            "PDB references","The following data sets","The following data set",
            "The following datasets","The following dataset",'accession number',
            "contains the supplementary crystallographic data for this paper",
            'Protein Data Bank','GenBank','NCBI','Accession no.',
            'International Nucleotide Sequence Database Collaboration','INSDC','DDBJ','ENA',
            'Global Initiative on Sharing All Influenza Data (GISAID)','GISAID','(NCBI) database',
            'submitted to the National Center for Biotechnology Information',
            'Data associated with this paper have been deposited','Imagery considered in the analysis',
            'sample data','Dryad. Dataset',
          
        ]

        # Words/phrases used by is_likely_funding_mention to spot funding statements.
        self.funding_keywords = [
            "funding", 
            "grant", 
            "grant number", 
            "grant/award", 
            "award number",
            "funding information",
            'This research was funded by',
            'This work was supported financially by',
            'is supported by',
                ]

    def is_likely_funding_mention(self, context: str, mention_text: str) -> bool:
        """
        Checks if the sentence containing the mention_text suggests it's part of funding or grant information.
    
        :param context: Full context text
        :param mention_text: The exact mention (used to find the sentence containing it)
        :return: True if likely funding-related
        """
    
        # Normalize
        context = context.strip()
        mention_text = mention_text.strip()
    
        if not mention_text:
            return False
    
        # Tokenize context into sentences
        sentences = sent_tokenize(context)
    
        # Find the sentence that contains the mention
        target_sentence = ""
        for sentence in sentences:
            if mention_text in sentence:
                target_sentence = sentence.lower()
                break
    
        if not target_sentence:
            return False  # mention not found in any sentence
    
        # Search for keywords in the sentence only
        for keyword in self.funding_keywords:
            keyword = keyword.lower()
            if " " in keyword:
                if keyword in target_sentence:
                    return True
            else:
                pattern = r'\b' + re.escape(keyword) + r'\b'
                if re.search(pattern, target_sentence):
                    return True
    
        return False

    def has_valid_acc_prefix(self,acc: str) -> bool:
        """
        Checks if the accession string starts with one of the known valid prefixes.
    
        Args:
            acc (str): Accession string to check.
    
        Returns:
            bool: True if the accession starts with a valid prefix, False otherwise.
        """
        acc = acc.strip()
        return any(acc.startswith(prefix) for prefix in self.acc_valid_prefixes)

        
    def get_canonical_doi(self,mention: str) ->str:
        
        if not mention:
            return ""
        doi = mention.strip().replace('\n', '').replace(' ', '')
        doi = doi.replace('https://dx/','')
        doi = re.sub(r'^(doi:|DOI:)', '', doi, flags=re.IGNORECASE)
        doi = re.sub(r'^(https?://)?(dx\.)?doi\.org/', '', doi, flags=re.IGNORECASE)
        
        return doi

    def is_valid_doi_format(self,doi: str) -> bool:
        
        doi_pattern = r'(10\.\d{4,9}/[-._;()/:a-zA-Z0-9]+)'
        
        doi_clean = self.get_canonical_doi(doi)
        
        return re.fullmatch(doi_pattern, doi_clean, flags=re.IGNORECASE) is not None
    
    def is_valid_acc_format(self,acc: str) -> bool:
    
        acc_clean = acc.strip()
        pattern = re.compile(
            r'^[A-Z0-9_.\-]{4,}$'                      
            r'|'
            r'^[A-Z0-9_.\-]{4,}\s*-\s*[A-Z0-9_.\-]{4,}$' 
            ,
            flags=re.IGNORECASE 
        )
        if acc_clean.isalpha():
            
            return False

        if len(acc_clean)<4:
            
            return False
            
        if acc_clean[0].isdigit():
            
            return False
            
        return re.fullmatch(pattern, acc_clean) is not None


    def is_article_doi(self, mention: str, article_id: str) -> bool:
        """
        Checks if the mention corresponds to the article's own DOI.

        :param mention: Mention string
        :param article_id: ID of the article (underscores instead of slashes)
        :return: True if mention refers to article DOI
        """
        # pattern = article_id.replace("_", "/")
        pattern = article_id.split('_')[0]
        return pattern in mention 

    def has_dataset_indicator(self, mention: str) -> bool:
        """
        Checks whether the mention contains a known dataset-related DOI prefix or name.

        :param mention: Mention string
        :return: True if known dataset indicator is found
        """
        mention = mention.lower()
        return any(indicator.lower() in mention for indicator in self.dataset_indicators)


    def contains_keywords(self, context: str,mention_start:int) -> bool:
        context = context.lower()
    
        for keyword in self.dataset_keywords:
            keyword = keyword.lower()
            if " " in keyword:
                kw_index = context.find(keyword)
            else:
                pattern = r'\b' + re.escape(keyword) + r'\b'
                match = re.search(pattern, context)
                kw_index = match.start() if match else -1
    
            if kw_index != -1 and mention_start > kw_index:
                return True
    
        return False
    def deduplicate_extractions(self, extractions: List[Dict]) -> List[Dict]:
        """
        Removes duplicate mentions based on normalized mention text.

        :param extractions: List of mention dictionaries
        :return: Deduplicated list
        """
        seen = set()
        deduped = []
        for m in extractions:
            key = m["mention"].strip().lower()
            if key not in seen:
                seen.add(key)
                deduped.append(m)
        return deduped

    def is_truncated_mention(self, mention: str) -> bool:
        """
        Detects obviously incomplete or broken mentions, such as partial DOI strings.
        Returns True only if the mention is clearly invalid or unusable.
        """
        mention = mention.strip().lower()
    
        # Known invalid or partial forms
        known_invalid = [
            "https", "http", "doi", "doi.org", "dx.doi.org",
            "https://", "http://", "https://doi.org", "http://doi.org",
            "https://doi.org/", "http://doi.org/",
            "https://dx.doi.org", "http://dx.doi.org",
            "https://dx.doi.org/", "http://dx.doi.org/",
            "https://doi.org/10.", "http://doi.org/10.",
            "https://dx.doi.org/10.", "http://dx.doi.org/10."
        ]
        if mention in known_invalid:
            
            return True
            
        if re.fullmatch(r"10\.\d{4,9}/", mention):
            
            return True

        if len(self.get_canonical_doi(mention)) <12:

            return True
    
        # Check for DOI-like prefixes but too short to be valid (≤25 chars)
        doi_prefixes = [
            "https://doi.org/10.", "http://doi.org/10.",
            "https://dx.doi.org/10.", "http://dx.doi.org/10."
        ]
        if any(mention.startswith(prefix) and len(mention) <= 25 for prefix in doi_prefixes):
            return True
    
        return False

    def remove_truncated_mentions(self, mentions: List[Dict]) -> List[Dict]:
        """
        Removes truncated or partial dataset mentions based solely on textual comparison.
    
        This function compares all mentions and removes any mention that is a strict substring
        of a longer one (e.g., "zenodo.12345" will be removed if "10.5281/zenodo.12345" is present).
        It does not rely on character offsets, making it robust to duplicate mentions found in
        different parts of the text.
    
        Args:
            mentions (List[Dict]): A list of mention dictionaries, each containing at least a
                                   "mention" key with the extracted text.
    
        Returns:
            List[Dict]: A filtered list of mentions with truncated duplicates removed.
        """
        mention_texts = [m["mention"].strip().lower() for m in mentions]
        to_remove = set()
    
        for i, m1 in enumerate(mention_texts):
            for j, m2 in enumerate(mention_texts):
                if i == j:
                    continue
                if m1 != m2 and m1 in m2 and len(m1) < len(m2):
                    to_remove.add(i)
                    break
    
        return [m for i, m in enumerate(mentions) if i not in to_remove]


    def get_context_window_around_chars(self, text: str, start: int, end: int) -> str:
        """
        Extracts ±N characters around a mention.

        :param text: Full article text
        :param start: Start position of the mention
        :param end: End position of the mention
        :param window: Number of characters before and after to include
        :return: Context string
        """
        left = max(0, start - self.context_window_chars)
        right = min(len(text), end + self.context_window_chars)
        return text[left:right]      

    def extract_context(self,text: str, start: int, end: int) -> str:
        """
        Extracts the sentence(s) containing the mention plus neighbouring sentences.

        Sentences are tokenized inside a window of 1500 characters on each side
        of the mention. The result is built as follows:

        - All sentences overlapping the mention are included.
        - One sentence before and one after are added (when they exist).
        - If the mention starts fewer than 40 characters into its sentence,
          two sentences before are included instead of one.
        - If the mention ends fewer than 5 characters before the end of its
          sentence, no sentence after is added.
        - If no sentence overlaps the mention, a plain window of 100 characters
          on each side of the mention is returned instead.

        :param text: Full article text
        :param start: Start offset of the mention in ``text``
        :param end: End offset (exclusive) of the mention in ``text``
        :return: Context string
        """
        # Define a local window around the mention to improve performance
        window_size = 1500
        window_start = max(0, start - window_size)
        window_end = min(len(text), end + window_size)
        local_text = text[window_start:window_end]

        # Tokenize the local window into sentences
        sentences = sent_tokenize(local_text)

        # Compute mention's position relative to the local window
        local_start = start - window_start
        local_end = end - window_start

        # Track character spans of each sentence within the local window
        spans = []
        current_pos = 0
        for sentence in sentences:
            s_start = local_text.find(sentence, current_pos)
            s_end = s_start + len(sentence)
            spans.append((s_start, s_end))
            current_pos = s_end

        # Two half-open intervals overlap when the larger start is below the
        # smaller end; this avoids materialising a set of every character index.
        overlapping_indices = [
            i for i, (s_start, s_end) in enumerate(spans)
            if max(local_start, s_start) < min(local_end, s_end)
        ]

        # Fallback: if no overlap is found, return a basic character window.
        # The lower bound is clamped: a negative slice start would count from
        # the end of the window and return the wrong (often empty) text.
        if not overlapping_indices:
            return local_text[max(0, local_start - 100): local_end + 100]

        first_idx = overlapping_indices[0]
        last_idx = overlapping_indices[-1]

        # Default: include one sentence before and one after
        start_idx = max(0, first_idx - 1)
        end_idx = min(len(sentences), last_idx + 2)

        # Rule: if the mention starts very close to the beginning of a sentence, include two before
        if abs(local_start - spans[first_idx][0]) < 40:
            start_idx = max(0, first_idx - 2)

        # Rule: if the mention ends very close to the end of a sentence, skip adding one after
        if abs(local_end - spans[last_idx][1]) < 5:
            end_idx = last_idx + 1

        # Join the selected sentences as the final context
        context = " ".join(sentences[start_idx:end_idx])
        return context.strip()
    
    
    def filter_mentions(
        self,
        mentions_dict: Dict[str, List[Dict]],
        #texts: Dict[str, str]
        texts_dict
    ) -> Dict[str, List[Dict]]:
        """
        Filters mentions for all articles.

        :param mentions_dict: Dictionary with article_id as key and list of mention dicts as value
        :param texts: Dictionary with article_id as key and article text as value
        :return: Filtered dictionary with same structure
        """
        result = defaultdict(dict)

        for article_id, mentions in mentions_dict.items():

            mentions = self.remove_truncated_mentions(mentions)

            #  Extract dois
            doi_mentions = [m for m in mentions if m["mention_format"] == "DOI"]

            # Remove truncated dois
            doi_mentions = [m for m in doi_mentions if not self.is_truncated_mention(m["mention"])]

            # Validate doi format 
            validated_dois = [m for m in doi_mentions if self.is_valid_doi_format(m["mention"])]

            cleaned_dois = []
            for m in validated_dois:
                source_file = m.get('source_file','')
                if not source_file:
                    continue
                texts = texts_dict.get(source_file,{})
                if not texts:
                    continue
                text = texts.get(article_id, "")
                if not text:
                    continue
                    
                if self.is_article_doi(m["mention"], article_id):
                    continue
                m["context"] = self.extract_context(text, m["start"], m["end"])
                
                cleaned_dois.append(m)

            cleaned_dois = [ m for m in cleaned_dois if m['context']]
            
            # Filter dois by confidence threshold
            cleaned_dois = [m for m in cleaned_dois if m["confidence"] >= self.confidence_threshold]
        

            # Filter relevant DOI mentions
            selected_dois = [m for m in cleaned_dois if self.contains_keywords(m["context"],m['start'])]
            remaining_dois = [m for m in cleaned_dois if m not in selected_dois]
            selected_dois += [m for m in remaining_dois if self.has_dataset_indicator(m["mention"])]
 

            # Deduplicate dois 
            # selected_dois = self.deduplicate_extractions(selected_dois)

            # Extract accessions
            acc_mentions = [m for m in mentions if m["mention_format"] == "ACC"]
            # Validate acc format
            validated_accs = [m for m in acc_mentions if self.is_valid_acc_format(m["mention"])]

            # Filter accessions by confidence threshold
            validated_accs = [m for m in validated_accs if m["confidence"] >= self.confidence_threshold]

            # Get context for accessions
            cleaned_accs = []
            for m in validated_accs:
                source_file = m.get('source_file','')
                if not source_file:
                    continue
                texts = texts_dict.get(source_file,{})
                if not texts:
                    continue
                text = texts.get(article_id, "")
                if not text:
                    continue
                m["context"] = self.extract_context(text, m["start"], m["end"])
                # if not self.is_likely_funding_mention(m["context"]):
                #     cleaned_accs.append(m)
            
            cleaned_accs = [ m for m in validated_accs if m['context']]
            selected_accs = [m for m in cleaned_accs if self.contains_keywords(m["context"],m['start'])]
            remaining_accs = [m for m in cleaned_accs if m not in selected_accs]
            selected_accs += [m for m in remaining_accs if self.has_valid_acc_prefix(m["mention"])]
            selected_accs = [m for m in selected_accs if not self.is_likely_funding_mention(m["context"],m['mention'])]
            # Filter relevant ACC mentions 
            # selected_accs = [
            #                 {**m, "mention": self.remove_ccdc_from_acc(m["mention"])}
            #                 for m in selected_accs
            #                 ]
            # Remove duplicated accessions                
            # selected_accs = self.deduplicate_extractions(selected_accs)
            
            # Remove temporary 'context' fields
            for m in selected_dois:
                m.pop("context", None)
            for m in selected_accs:
                m.pop("context", None)

            selected_mentions = selected_dois + selected_accs


            result[article_id] = selected_mentions


        return result


# CLASSIFICATION OF DATA USAGE TYPE

In [None]:
# Utility class to prepare classification dataset
class ClassificationDataPreparator:
    def __init__(self, tokenizer, max_seq_len: int = 512,max_ctx_length =1600):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_ctx_length = max_ctx_length
        self.label2id = {
            "primary": 0,
            "secondary": 1,
            # "missing": 2,
        }
    
    
    def tokenize_classification_dataset(self, dataset: Dataset) -> Dataset:
        """
        Tokenizes the "text" column of the dataset in batches.

        Every example is truncated and padded to ``self.max_seq_len`` tokens.

        :param dataset: Dataset with a "text" column
        :return: Result of ``dataset.map`` with the tokenizer outputs
        """
        def encode_batch(batch):
            return self.tokenizer(
                batch["text"],
                truncation=True,
                padding="max_length",
                max_length=self.max_seq_len,
            )

        return dataset.map(encode_batch, batched=True)


    def extract_mention_context(self, text: str, mention_start: int, mention_end: int) -> str:
        """
        Extracts the context around a mention that is fed to the classifier.

        - A standard sentence-based context is extracted around the mention.
        - If that context is at least ``self.max_ctx_length`` characters long,
          the table-aware extraction is used instead.
        - Otherwise the context is enriched with the data-related section of
          the document and lines containing only numbers are removed.

        Parameters:
        - text (str): The full document text.
        - mention_start (int): The character start index of the mention in the text.
        - mention_end (int): The character end index of the mention in the text.

        Returns:
        - str: The context string for the mention.
        """
        standard_context = self.extract_standard_context(text, mention_start, mention_end)

        if len(standard_context) >= self.max_ctx_length:
            # Overlong contexts are handed to the table-aware extraction as is.
            return self.extract_table_context(text, mention_start, mention_end, standard_context)

        context = self.extract_data_section_context(text, standard_context)
        return self.remove_number_only_lines(context)

    def extract_data_section_context(self,text: str, standard_context: str) -> str:
        """
        Extracts additional context from the data-related section using the last match of known patterns.
        Only includes sentences not already present in the standard_context. Uses a local window for efficiency.
    
        :param text: Full document text
        :param standard_context: Previously extracted context (e.g. near a mention)
        :return: Combined enriched context string
        """
    
        # Compile common patterns used in data availability sections
        pattern = re.compile(
            r'Data accessibility|Data availability|DNA Deposition|Data Deposition|'
            r'Data acquisition|Data preparation|Data and software availability|'
            r'DATA ARCHIVING STATEMENT|DATA AVAILABILITY STATEMENT|Data and metadata repository|DATA AND RESOURCES',
            flags=re.IGNORECASE
        )
    
        # First, check if pattern is already in the standard context
        match = pattern.search(standard_context)
        if match:
            start, _ = match.span()
            return standard_context[start:].strip()
    
        # Find all matches in the full text
        matches = list(re.finditer(pattern, text))
        if not matches:
            return standard_context.strip()
    
        # Get the last match position
        last_match = matches[-1]
        match_start = last_match.start()
    
        # === Local window optimization ===
        window_size = 1500
        window_start = max(0, match_start - window_size)
        window_end = min(len(text), match_start + window_size)
        local_text = text[window_start:window_end]
    
        # Tokenize sentences only in the local window
        sentences = sent_tokenize(local_text)
        standard_sentences = set(sent_tokenize(standard_context))
    
        # Adjust match index relative to local_text
        relative_match_start = match_start - window_start
    
        # Find sentence in local_text that contains the pattern
        current_pos = 0
        for i, sentence in enumerate(sentences):
            sentence_start = local_text.find(sentence, current_pos)
            sentence_end = sentence_start + len(sentence)
            current_pos = sentence_end
    
            if sentence_start <= relative_match_start < sentence_end:
                selected_sentences = sentences[i : i + 3]  # sentence with match + next 2
                unique_sentences = [s for s in selected_sentences if s not in standard_sentences]
    
                if not unique_sentences:
                    return standard_context.strip()
    
                return (" ".join(unique_sentences) + " " + standard_context).strip()
    
        # Fallback: nothing matched cleanly
        return standard_context.strip()

    def extract_standard_context(self,text: str, start: int, end: int) -> str:
        """
        Extracts the sentence(s) containing the mention plus neighbouring sentences.

        Sentences are tokenized inside a window of 1500 characters on each side
        of the mention. The result is built as follows:

        - All sentences overlapping the mention are included.
        - One sentence before and one after are added (when they exist).
        - If the mention starts fewer than 40 characters into its sentence,
          two sentences before are included instead of one.
        - If the mention ends fewer than 5 characters before the end of its
          sentence, no sentence after is added.
        - If no sentence overlaps the mention, a plain window of 100 characters
          on each side of the mention is returned instead.

        :param text: Full document text
        :param start: Start offset of the mention in ``text``
        :param end: End offset (exclusive) of the mention in ``text``
        :return: Context string
        """
        # Define a local window around the mention to improve performance
        window_size = 1500
        window_start = max(0, start - window_size)
        window_end = min(len(text), end + window_size)
        local_text = text[window_start:window_end]

        # Tokenize the local window into sentences
        sentences = sent_tokenize(local_text)

        # Compute mention's position relative to the local window
        local_start = start - window_start
        local_end = end - window_start

        # Track character spans of each sentence within the local window
        spans = []
        current_pos = 0
        for sentence in sentences:
            s_start = local_text.find(sentence, current_pos)
            s_end = s_start + len(sentence)
            spans.append((s_start, s_end))
            current_pos = s_end

        # Two half-open intervals overlap when the larger start is below the
        # smaller end; this avoids materialising a set of every character index.
        overlapping_indices = [
            i for i, (s_start, s_end) in enumerate(spans)
            if max(local_start, s_start) < min(local_end, s_end)
        ]

        # Fallback: if no overlap is found, return a basic character window.
        # The lower bound is clamped: a negative slice start would count from
        # the end of the window and return the wrong (often empty) text.
        if not overlapping_indices:
            return local_text[max(0, local_start - 100): local_end + 100]

        first_idx = overlapping_indices[0]
        last_idx = overlapping_indices[-1]

        # Default: include one sentence before and one after
        start_idx = max(0, first_idx - 1)
        end_idx = min(len(sentences), last_idx + 2)

        # Rule: if the mention starts very close to the beginning of a sentence, include two before
        if abs(local_start - spans[first_idx][0]) < 40:
            start_idx = max(0, first_idx - 2)

        # Rule: if the mention ends very close to the end of a sentence, skip adding one after
        if abs(local_end - spans[last_idx][1]) < 5:
            end_idx = last_idx + 1

        # Join the selected sentences as the final context
        context = " ".join(sentences[start_idx:end_idx])
        return context.strip()

    def extract_table_context(self, text: str, mention_start: int, mention_end: int, standard_context) -> str:
        """
        Extracts context for mentions in tables by:
        - Finding the last match of data-related section patterns
        - Taking the sentence that contains the match and the next two
        - Adjusting the local window around the mention:
            * If a section is found: 300 chars before the end of the mention
            * Otherwise: 1600 chars before the mention start
    
        :param text: Full article text
        :param mention_start: Character index where the mention starts
        :param mention_end: Character index where the mention ends
        :return: Combined context string
        """
    
        pattern = re.compile(
            r'Data accessibility|Data availability|DNA Deposition|Data Deposition|'
            r'Data acquisition|Data preparation|Data and software availability|'
            r'DATA ARCHIVING STATEMENT|DATA AVAILABILITY STATEMENT|Data and metadata repository|DATA AND RESOURCES',
            flags=re.IGNORECASE
        )
    
    
        # Find all pattern matches and take the last one
        matches = list(re.finditer(pattern, text))
        global_context = ""
    
        if matches:
            last_match = matches[-1]
            match_start = last_match.start()
    
            # === Local window optimization ===
            window_size = 1500
            window_start = max(0, match_start - window_size)
            window_end = min(len(text), match_start + window_size)
            local_text = text[window_start:window_end]
    
            # Tokenize only the local window
            sentences = sent_tokenize(local_text)
            relative_match_start = match_start - window_start
    
            current_pos = 0
            for i, sentence in enumerate(sentences):
                sentence_start = local_text.find(sentence, current_pos)
                sentence_end = sentence_start + len(sentence)
                current_pos = sentence_end
    
                if sentence_start <= relative_match_start < sentence_end:
                    # Get sentence with match + next two
                    selected = sentences[i: i + 4]
                    global_context = " ".join(selected).strip()
                    break
    
        # === Define local context around the mention ===
        if global_context:
            local_context = text[max(0, mention_end - 300): mention_end].strip()
        else:
            # Get mention_text from global text
            mention_text = text[mention_start:mention_end].strip()
        
            # Find position of mention in standard_context
            local_pos = standard_context.find(mention_text)
            if local_pos == -1:
                # Fallback if mention_text not found: return truncated standard_context
                local_context = standard_context[:1600].strip()
            else:
                # Slice standard_context with local window
                window_start = max(0, local_pos - 2000)
                window_end = min(len(standard_context), local_pos + len(mention_text))
                local_context = standard_context[window_start:window_end].strip()
    
        # === Combine contexts ===
        if global_context:
            match = pattern.search(global_context)
            if match:
                start, _ = match.span()
                global_context = global_context[start:]
            return f"{global_context} {local_context}".strip()
        else:
            return local_context


    def remove_number_only_lines(self, text: str) -> str:
        """
        Removes lines that contain only numbers or numbers with spaces (e.g., line/page numbers).
        
        :param text: Raw input text from article
        :return: Cleaned text without number-only lines
        """
        cleaned_lines = []
        for line in text.splitlines():
            # Strip line to check if it only contains digits and optional spaces
            if not re.fullmatch(r'\s*\d[\d\s]*\s*', line):
                cleaned_lines.append(line)
        return '\n'.join(cleaned_lines)

    def prepare_classification_dataset(
        self,
        df_mentions,
        text_dicts: Dict[str, Dict[str, str]],
    ) -> Dataset:
        """Builds the classification dataset from labelled mentions.

        For every mention whose lowercased "type" is a known label and whose
        article text is available, the mention context is extracted. Each
        (article_id, context) pair is emitted only once.

        :param df_mentions: DataFrame with columns "article_id", "source_file",
                            "start", "end" and "type"
        :param text_dicts: Nested dictionary ``{source_file: {article_id: text}}``
        :return: Dataset with columns "text", "label" and "article_id"
        """
        columns = {"text": [], "label": [], "article_id": []}
        emitted_pairs = set()  # (article_id, context) combinations already added

        # One text lookup per (article, source file) group.
        grouped = df_mentions.groupby(["article_id", "source_file"])
        for (article_id, source_file), group in grouped:
            document = text_dicts.get(source_file, {}).get(article_id, None)
            if not document:
                continue

            for _, row in group.iterrows():
                span_start, span_end = row['start'], row['end']
                label_name = row['type'].lower()
                if label_name not in self.label2id:
                    continue

                context = self.extract_mention_context(document, span_start, span_end)
                if not context:
                    continue
                context = context.strip()

                pair = (article_id, context)
                if pair in emitted_pairs:
                    continue
                emitted_pairs.add(pair)

                columns["text"].append(context)
                columns["label"].append(self.label2id[label_name])
                columns["article_id"].append(article_id)

        return Dataset.from_dict(columns)

    def stratified_group_split(self, dataset, label_col="label", group_col="article_id", test_size=0.25, seed=44):
        """
        Splits the dataset into train/validation parts, stratified by label and grouped by article.

        A StratifiedGroupKFold with ``round(1 / test_size)`` folds is built and
        only its first fold is used, so rows sharing a ``group_col`` value never
        end up on both sides. The class distribution of both parts is printed.

        :param dataset: Dataset to split
        :param label_col: Column holding the class labels
        :param group_col: Column holding the group identifiers
        :param test_size: Target validation fraction; determines the number of folds
        :param seed: Random state of the splitter
        :return: Tuple ``(train_dataset, val_dataset)``
        """
        from collections import Counter

        frame = dataset.to_pandas()
        splitter = StratifiedGroupKFold(
            n_splits=round(1 / test_size), shuffle=True, random_state=seed
        )

        folds = splitter.split(frame, frame[label_col], frame[group_col])
        first_fold = next(iter(folds), None)
        if first_fold is None:
            return None

        train_idx, val_idx = first_fold
        train_dataset = dataset.select(train_idx)
        val_dataset = dataset.select(val_idx)

        # Optional: show class distribution
        print(f"Train class distribution: {Counter(train_dataset[label_col])}")
        print(f"Val class distribution: {Counter(val_dataset[label_col])}")

        return train_dataset, val_dataset
            

In [None]:
# Instantiate the tokenizer and the data preparation helper for the type classifier.

# Tokenizer loaded from the checkpoint named by MODEL_NAME.
cls_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# The preparator keeps its default length limits (max_seq_len=512, max_ctx_length=1600).
cls_preparator = ClassificationDataPreparator(tokenizer=cls_tokenizer)


In [None]:
# 1. Prepare raw data
# Work on a copy so that final_labels_df itself is left untouched.
df_mentions = final_labels_df.copy()

#df_mentions = df_mentions[df_mentions.article_id!='10.1080_21645515.2023.2189598']

# Lowercase the labels so that they match the keys of cls_preparator.label2id.
df_mentions['type'] = df_mentions['type'].str.lower()


raw_dataset = cls_preparator.prepare_classification_dataset(df_mentions, text_dicts)
# Define the classes (same order as the ids in label2id: primary=0, secondary=1)
class_label = ClassLabel(names=["primary", "secondary"])

# Convert the "label" column of the dataset to the ClassLabel feature
raw_dataset = raw_dataset.cast_column("label", class_label)

In [None]:
# 2. Tokenization
# Tokenize the context texts of the raw dataset with the classifier tokenizer.
tokenized_dataset = cls_preparator.tokenize_classification_dataset(raw_dataset)

In [None]:
# 3. Split by article
# Stratified on the label and grouped on article_id, so an article's contexts
# never appear in both the training and the validation set.
cls_train_dataset, cls_val_dataset = cls_preparator.stratified_group_split(tokenized_dataset)

In [None]:
# Load Metric for Evaluation
def cls_compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for the type classifier.

    Args:
        pred: Evaluation output exposing ``label_ids`` (gold class ids) and
            ``predictions`` (per-class scores); the arg-max over the last axis
            of ``predictions`` is taken as the predicted class.

    Returns:
        dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    acc = accuracy_score(y_true=labels, y_pred=preds)

    # Gold labels are y_true and model outputs are y_pred. Passing them the
    # other way round exchanges precision with recall and weights the average
    # by predicted (instead of true) class support.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true=labels, y_pred=preds, average="weighted"
    )

    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}

In [None]:
# Model Initialization
# Two-class sequence classifier on top of the MODEL_NAME checkpoint.
# The id <-> label mapping is stored in the model config:
# 0 = "primary", 1 = "secondary".
cls_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    id2label={0: "primary", 1: "secondary"},
    label2id={"primary": 0, "secondary": 1},
)


In [None]:
# TrainingArguments and Trainer Setup
# Pad each batch dynamically using the classifier tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=cls_tokenizer)

# Checkpoints of the type classifier are written below MODELS_DIR.
cls_output_dir = MODELS_DIR/"type_classifier"
cls_training_args = TrainingArguments(
    output_dir=cls_output_dir,
    # Evaluate and checkpoint once per epoch; the two strategies are kept
    # identical because load_best_model_at_end is enabled below.
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Upper bound only: the EarlyStoppingCallback below can stop sooner.
    num_train_epochs=20,
    weight_decay=0.01,
    # Select the best checkpoint by the "f1" value from cls_compute_metrics.
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    logging_dir="./logs",
    logging_steps=50,
    report_to="none", 
    save_total_limit=2,
)

cls_trainer = Trainer(
    model=cls_model,
    args=cls_training_args,
    train_dataset=cls_train_dataset,
    eval_dataset=cls_val_dataset,
    processing_class=cls_tokenizer,
    data_collator=data_collator,
    compute_metrics=cls_compute_metrics,
    # Patience of 3 evaluations on metric_for_best_model.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

In [None]:
# Train and Evaluate
# Fine-tune the classifier, then evaluate on cls_val_dataset and print the metrics.
cls_trainer.train()
eval_results = cls_trainer.evaluate()
print(f"Validation results: {eval_results}")

In [None]:
def plot_confusion_matrix(y_true, y_pred, class_names=None, normalize=False, title='Confusion Matrix'):
    """Draw a confusion-matrix heatmap for the given labels.

    Args:
        y_true: Gold labels.
        y_pred: Predicted labels.
        class_names: Tick labels used on both axes.
        normalize: When truthy, the matrix is normalised with
            ``normalize='true'`` and cells show two decimals; otherwise raw
            integer counts are shown.
        title: Figure title.
    """
    if normalize:
        norm_mode, cell_format = 'true', '.2f'
    else:
        norm_mode, cell_format = None, 'd'

    matrix = confusion_matrix(y_true, y_pred, normalize=norm_mode)

    plt.figure(figsize=(6, 5))
    sns.heatmap(
        matrix,
        annot=True,
        fmt=cell_format,
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(title)
    plt.tight_layout()
    plt.show()


In [None]:
# Predict on the validation set and take the arg-max class per example.
predictions = cls_trainer.predict(cls_val_dataset)
y_pred = predictions.predictions.argmax(axis=1)
y_true = predictions.label_ids

# Tick order follows the class ids: 0 = Primary, 1 = Secondary.
plot_confusion_matrix(y_true, y_pred, class_names=["Primary", "Secondary"])

# INFERENCE ON TEST DATA

In [None]:
# Utility class to extract, filter and classify data reference mentions in articles
class MentionExtractorAndClassifier:
    """Run NER over article texts, filter the mentions and classify their usage.

    The NER model tags tokens with BIO labels (``B-<format>`` /
    ``I-<format>``); consecutive tags are merged into character-level
    mentions. Mentions kept by ``mention_filter`` are then passed through the
    sequence classifier, whose arg-max class id is stored under ``"type"``
    (0 = Primary, 1 = Secondary; 2 marks a mention that could not be
    classified).
    """

    def __init__(
        self,
        text_dicts,
        ner_model: PreTrainedModel,
        ner_tokenizer: PreTrainedTokenizer,
        classifier_model: PreTrainedModel,
        classifier_tokenizer: PreTrainedTokenizer,
        mention_filter: DatasetMentionFilter,
        cls_preparator: ClassificationDataPreparator,
        classification_window: int = 300,
        strides: Optional[List[int]] = None,
        device: Optional[str] = None,
        batch_size: int = 8,
    ):
        """Store the models, tokenizers and helpers used by the pipeline.

        Args:
            text_dicts: Mapping ``source_file -> {article_id: text}`` used to
                look up the article text when classifying a mention.
            ner_model: Token-classification model.
            ner_tokenizer: Tokenizer of ``ner_model``.
            classifier_model: Sequence-classification model.
            classifier_tokenizer: Tokenizer of ``classifier_model``.
            mention_filter: Object providing ``filter_mentions``.
            cls_preparator: Object providing ``extract_mention_context`` and
                ``extract_standard_context``.
            classification_window: Stored on the instance; not read by the
                methods of this class.
            strides: Token strides for the sliding NER windows; one NER pass
                is run per stride. ``None`` means ``[312]``.
            device: Torch device name; defaults to CUDA when available.
            batch_size: Stored on the instance; not read by the methods of
                this class.
        """
        self.ner_model = ner_model
        self.ner_tokenizer = ner_tokenizer
        self.classifier_model = classifier_model
        self.classifier_tokenizer = classifier_tokenizer
        self.mention_filter = mention_filter
        self.classification_window = classification_window
        # Copy so the instance never shares a list with the caller (or with a
        # default argument).
        self.strides = [312] if strides is None else list(strides)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.text_dicts = text_dicts
        self.cls_preparator = cls_preparator
        # process_articles() reports NER failures through this logger.
        self.logger = logging.getLogger(__name__)

    def ner_inference_with_sliding(self, text: str, source_file: str) -> List[Dict]:
        """Run NER once per configured stride and drop exact duplicates.

        Two mentions are duplicates when they share ``start``, ``end`` and
        ``mention_format``; the first occurrence is kept.
        """
        all_mentions = []
        for stride in self.strides:
            all_mentions.extend(
                self._ner_inference(text, stride=stride, source_file=source_file)
            )

        # Remove exact duplicates produced by the different strides.
        seen = set()
        unique_mentions = []
        for m in all_mentions:
            key = (m['start'], m['end'], m['mention_format'])
            if key not in seen:
                seen.add(key)
                unique_mentions.append(m)
        return unique_mentions

    @staticmethod
    def _build_mention(text, label, start_char, end_char, score_sum, token_count, source_file) -> Dict:
        """Package one finished entity span as a mention dict."""
        return {
            "mention": text[start_char:end_char],
            "mention_format": label,
            "start": start_char,
            "end": end_char,
            # Mean token confidence over the span.
            "confidence": score_sum / token_count,
            "source_file": source_file,
        }

    def _ner_inference(
        self,
        text: str,
        stride: int,
        max_length: int = 512,
        source_file: str = "pdf",
    ) -> List[Dict]:
        """Tag ``text`` with the NER model and return the decoded mentions.

        Texts that fit in one window are processed in a single forward pass.
        Longer texts are processed in overlapping windows of
        ``max_length - 2`` tokens; for each token the prediction with the
        highest confidence across windows is kept.

        Args:
            text: Article text.
            stride: Number of tokens the window advances by (long texts only).
            max_length: Model input length including the two special tokens.
            source_file: Value copied into each mention's ``"source_file"``.

        Returns:
            List of mention dicts with the keys ``mention``,
            ``mention_format``, ``start``, ``end``, ``confidence`` and
            ``source_file``.
        """
        model = self.ner_model
        tokenizer = self.ner_tokenizer
        device = self.device

        model.eval()
        model.to(device)

        # Tokenize WITHOUT special tokens: each window below adds its own
        # [CLS]/[SEP]. Leaving them in here would duplicate them in the first
        # and last window and make the length check off by two.
        encoding = tokenizer(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
            return_attention_mask=False,
            truncation=False,
            padding=False
        )
        input_ids = encoding["input_ids"]
        offsets = encoding["offset_mapping"]
        id2label = model.config.id2label

        all_preds = [None] * len(input_ids)
        all_scores = [0.0] * len(input_ids)
        all_offsets = [None] * len(input_ids)

        def get_label_at(idx):
            # Label currently assigned to token idx ("O" if none / out of range).
            if 0 <= idx < len(all_preds):
                pred_id = all_preds[idx]
                return id2label[pred_id] if pred_id is not None else "O"
            return "O"

        if len(input_ids) <= max_length - 2:
            # Short text: one padded forward pass.
            inputs = tokenizer(
                text,
                return_tensors="pt",
                return_offsets_mapping=True,
                padding="max_length",
                truncation=True,
                max_length=max_length
            )
            offset_mapping = inputs.pop("offset_mapping")[0]
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model(**inputs)
                # Drop the first and last position; remaining special/padding
                # positions have empty (start == end) offsets and are skipped
                # during decoding.
                logits = outputs.logits[0][1:-1]
                probs = torch.softmax(logits, dim=-1)
                pred_ids = torch.argmax(probs, dim=-1).cpu().tolist()
                confidences = torch.max(probs, dim=-1).values.cpu().tolist()

            all_preds = pred_ids
            all_scores = confidences
            all_offsets = offset_mapping[1:-1]

        else:
            window_size = max_length - 2
            start = 0
            while start < len(input_ids):
                end = min(start + window_size, len(input_ids))
                chunk_ids = input_ids[start:end]
                chunk_offsets = offsets[start:end]

                input_chunk = [tokenizer.cls_token_id] + chunk_ids + [tokenizer.sep_token_id]
                attention_mask = [1] * len(input_chunk)

                pad_len = max_length - len(input_chunk)
                input_chunk += [tokenizer.pad_token_id] * pad_len
                attention_mask += [0] * pad_len

                input_tensor = {
                    "input_ids": torch.tensor([input_chunk], device=device),
                    "attention_mask": torch.tensor([attention_mask], device=device)
                }

                with torch.no_grad():
                    outputs = model(**input_tensor)
                    # Keep only the positions of the real chunk tokens.
                    logits = outputs.logits[0][1:1 + len(chunk_ids)]
                    probs = torch.softmax(logits, dim=-1)
                    pred_ids = torch.argmax(probs, dim=-1).cpu().tolist()
                    confidences = torch.max(probs, dim=-1).values.cpu().tolist()

                # For tokens seen by several windows keep the most confident
                # prediction.
                for i, idx in enumerate(range(start, end)):
                    if confidences[i] > all_scores[idx]:
                        all_preds[idx] = pred_ids[i]
                        all_scores[idx] = confidences[i]
                        all_offsets[idx] = chunk_offsets[i]

                # Advance by the stride, then move further while the would-be
                # first token is tagged I-*, so a window does not begin in the
                # middle of an entity.
                new_start = start + stride
                while (
                    new_start < len(input_ids) and
                    new_start < start + window_size and
                    get_label_at(new_start).startswith("I-")
                ):
                    new_start += 1
                if new_start == start:
                    # No progress (e.g. stride == 0): stop instead of looping forever.
                    break
                start = new_start

        # Mention extraction: decode the BIO tags into character spans.
        mentions = []
        current_label = None
        current_score_sum = 0.0
        token_count = 0
        start_char = None
        end_char = None

        for pred_id, offset, score in zip(all_preds, all_offsets, all_scores):
            if offset is None:
                continue
            start, end = map(int, offset)
            if start == end:
                # Special / padding token: carries no text.
                continue

            label = id2label[pred_id]

            if label.startswith("B-"):
                # Save previous mention
                if current_label:
                    mentions.append(self._build_mention(
                        text, current_label, start_char, end_char,
                        current_score_sum, token_count, source_file,
                    ))
                # Start new mention
                current_label = label[2:]
                start_char = start
                end_char = end
                current_score_sum = score
                token_count = 1

            elif label.startswith("I-") and current_label == label[2:]:
                end_char = end
                current_score_sum += score
                token_count += 1

            else:
                # "O", or an I-* tag that does not continue the open mention:
                # end current mention if exists
                if current_label:
                    mentions.append(self._build_mention(
                        text, current_label, start_char, end_char,
                        current_score_sum, token_count, source_file,
                    ))
                current_label = None
                start_char = None
                end_char = None
                current_score_sum = 0.0
                token_count = 0

        if current_label:
            mentions.append(self._build_mention(
                text, current_label, start_char, end_char,
                current_score_sum, token_count, source_file,
            ))

        # Merge consecutive mentions of the same format that touch or overlap.
        merged = []
        for m in mentions:
            if merged and merged[-1]['mention_format'] == m['mention_format'] and m['start'] <= merged[-1]['end']:
                overlap = merged[-1]['end'] - m['start']
                suffix = m['mention'][overlap:] if overlap < len(m['mention']) else ""
                merged[-1]['mention'] += suffix
                merged[-1]['end'] = m['end']
                merged[-1]['confidence'] = (merged[-1]['confidence'] + m['confidence']) / 2
            else:
                merged.append(m)

        return merged

    def classify_mentions(self, article_id: str, mentions: List[Dict]) -> List[Dict]:
        """Classify each mention of one article in place and return the list.

        Every mention receives the keys ``type`` (0 = Primary, 1 = Secondary,
        2 = not classifiable), ``cls_confidence`` and ``context``; classified
        mentions additionally receive ``standard_context``.
        """
        self.classifier_model.eval()
        self.classifier_model.to(self.device)

        for m in mentions:
            start = m["start"]
            end = m['end']
            source_file = m['source_file']

            text_dict = self.text_dicts.get(source_file, {})
            text = text_dict.get(article_id, "")

            context = (
                self.cls_preparator.extract_mention_context(text, start, end)
                if text else None
            )
            if not context:
                # No article text or no usable context: mark as unclassifiable,
                # with the same keys in both cases so consumers see a uniform
                # schema.
                m["type"] = 2
                m['cls_confidence'] = 0.0
                m['context'] = None
                continue

            inputs = self.classifier_tokenizer(
                            context,
                            return_tensors="pt",
                            truncation=True,
                            max_length=512,
                            padding="max_length"
                        ).to(self.device)
            with torch.no_grad():
                output = self.classifier_model(**inputs)
                logits = getattr(output, "logits", output[0])
                prob = torch.softmax(logits, dim=-1)
                pred = torch.argmax(prob, dim=-1).item()
                confidence = prob[0, pred].item()

            m['context'] = context
            m["type"] = pred  # 0 = Primary, 1 = Secondary
            m['cls_confidence'] = confidence
            m['standard_context'] = self.cls_preparator.extract_standard_context(text, start, end)
        return mentions

    def process_articles(self, articles_dicts) -> Dict[str, List[Dict]]:
        """Extract, filter and classify the mentions of all articles.

        Args:
            articles_dicts: Mapping with optional ``"pdf"`` and ``"xml"``
                entries, each an ``{article_id: text}`` dict.

        Returns:
            ``{article_id: [classified mention, ...]}`` for every article that
            still has mentions after filtering.
        """
        all_mentions = defaultdict(list)

        # PDF texts first, then XML texts; mentions of an article present in
        # both are collected in the same list.
        for source_file in ("pdf", "xml"):
            for article_id, text in articles_dicts.get(source_file, {}).items():
                try:
                    mentions = self.ner_inference_with_sliding(text=text, source_file=source_file)
                except Exception as e:
                    # Best effort: one failing article must not abort the run.
                    self.logger.warning(
                        f"NER {source_file.upper()} failed for {article_id}: {e}"
                    )
                    continue
                if mentions:
                    all_mentions[article_id].extend(mentions)

        filtered_mentions = self.mention_filter.filter_mentions(all_mentions, articles_dicts)

        classified_mentions = {}
        for article_id, mentions in filtered_mentions.items():
            if mentions:
                classified_mentions[article_id] = self.classify_mentions(article_id, mentions)

        return classified_mentions


In [None]:
# A complete pipeline to extract, filter, classify, normalize and export data reference mentions from articles
class FindDataReferences:
    """
    Pipeline class to extract and classify dataset references from a set of scientific articles.

    It runs the mention extractor/classifier, normalises DOI and accession
    mentions, and writes the ``row_id, article_id, dataset_id, type`` CSV.
    """

    def __init__(
        self,
        mention_extractor_classifier: MentionExtractorAndClassifier,
        mention_filter,  # Instance of DatasetMentionFilter
        text_dicts,
        missing_ids,
    ):
        """Store the collaborators and configure the instance logger.

        Args:
            mention_extractor_classifier: Object whose ``process_articles``
                returns ``{article_id: [mention dict, ...]}``.
            mention_filter: Stored on the instance; not read by the methods of
                this class.
            text_dicts: Mapping with optional ``"pdf"`` / ``"xml"`` entries,
                each an ``{article_id: text}`` dict.
            missing_ids: Article ids to exclude from the output.
        """
        self.mention_extractor_classifier = mention_extractor_classifier
        self.mention_filter = mention_filter
        self.text_dicts = text_dicts
        self.missing_ids = missing_ids

        # Logger
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        # Only attach a handler when none is configured, to avoid duplicated
        # log lines when the class is instantiated repeatedly.
        if not self.logger.hasHandlers():
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
            self.logger.addHandler(handler)

    def normalize_article_doi(self, doi):
        """Return the article id with ``/`` -> ``_``, spaces removed, lower-cased."""
        return doi.replace('/', '_').replace(' ', '').lower()

    def normalize_data_doi(self, mention: str) -> str:
        """Normalise a DOI mention to ``https://doi.org/<lower-cased doi>``.

        Whitespace and newlines are removed, a leading ``doi:`` prefix or
        ``(http(s)://)(dx.)doi.org/`` resolver prefix is stripped, and the
        result is NFKD-normalised. An empty mention yields ``""``.
        """
        if not mention:
            return ""
        doi = mention.strip().replace('\n', '').replace(' ', '')
        doi = doi.replace('https://dx/', '')
        doi = re.sub(r'^(doi:|DOI:)', '', doi, flags=re.IGNORECASE)
        doi = re.sub(r'^(https?://)?(dx\.)?doi\.org/', '', doi, flags=re.IGNORECASE)
        doi = f"https://doi.org/{doi.lower()}"
        doi = unicodedata.normalize('NFKD', doi)

        return doi

    def normalize_all_dois(self, dois_mentions):
        """Return copies of the mention dicts with a normalised ``mention``."""
        return [
            {**item, 'mention': self.normalize_data_doi(item['mention'])}
            for item in dois_mentions
        ]

    def normalize_data_acc(self, mentions: List[dict]) -> List[dict]:
        """
        Normalize a list of extracted accession mentions.
        Rules:
            1. NFKD-normalize, trim and upper-case the mention
            2. Keep only mentions shaped like an accession or an
               ``accession - accession`` range
            3. If a range is detected, split it into its two ends
            4. For each accession, expand to all accessions with the same
               prefix found in the mention's context
            5. Remove duplicates

        Args:
            mentions (List[dict]): mention dicts with a ``mention`` key and a
                ``standard_context`` key (the text searched during expansion).
                Mentions lacking either are skipped.

        Returns:
            List[dict]: normalized, expanded and deduplicated mentions
        """
        def expand_mention_range(mention: str, context: str) -> list[str]:
            mention = mention.strip()

            # Heuristic: determine the prefix by cutting at the last underscore (e.g., LT05_232072_)
            sep_pos = mention.rfind('_')

            if sep_pos != -1:
                prefix = mention[:sep_pos + 1]
            else:
                # Fallback: use leading alphabetical characters as prefix (e.g., "GSE" in GSE12345)
                match = re.match(r'^([A-Z]+)', mention, flags=re.I)
                prefix = match.group(1) if match else ''

            if not prefix:
                return [mention]

            # Find every accession in the context that shares the prefix and
            # is followed by at least two digits.
            prefix_escaped = re.escape(prefix)
            acc_pattern = re.compile(rf'\b{prefix_escaped}\d{{2,}}[0-9_.\-]*\b', re.IGNORECASE)

            # Deduplicate and return in a deterministic order.
            return sorted(set(acc_pattern.findall(context)))

        def split_range(norm: str) -> List[str]:
            match = re.search(r'(.+?)\s*-\s*(.+)', norm)
            if not match:
                return [norm.strip()]

            left, right = match.group(1).strip(), match.group(2).strip()

            left_clean = left.replace('.', '').replace(',', '').strip()
            right_clean = right.replace('.', '').replace(',', '').strip()

            # Purely numeric ends form a range.
            if left_clean.isdigit() and right_clean.isdigit():
                return [left, right]

            left_alpha = re.match(r'^[A-Za-z]+', left)
            right_alpha = re.match(r'^[A-Za-z]+', right)

            # Both ends must share the same alphabetic prefix; anything else
            # is discarded.
            if left_alpha and right_alpha and left_alpha.group() == right_alpha.group():
                return [left, right]
            return []

        valid_mention = re.compile(
            r'^[A-Z0-9_.\-]{4,}$'
            r'|'
            r'^[A-Z0-9_.\-]{4,}\s*-\s*[A-Z0-9_.\-]{4,}$'
            ,
            flags=re.IGNORECASE
        )

        cleaned_mentions = []
        seen = set()

        for mention_dict in mentions:
            raw = mention_dict.get("mention", "")
            if not raw:
                continue

            context = mention_dict.get("standard_context", "")
            if not context:
                continue

            # Normalize: NFKD, trim, uppercase
            norm = unicodedata.normalize('NFKD', raw).strip().upper()

            if not valid_mention.match(norm):
                continue

            # Handle ranges
            split_parts = split_range(norm) if '-' in norm else [norm]

            for part in split_parts:
                # Expand mention range in context
                for acc in expand_mention_range(part.strip(), context):
                    if acc not in seen:
                        new_dict = mention_dict.copy()
                        new_dict["mention"] = acc
                        cleaned_mentions.append(new_dict)
                        seen.add(acc)

        return cleaned_mentions

    def deduplicate_extractions(self, extractions: List[Dict]) -> List[Dict]:
        """
        Deduplicates mention dictionaries based on the mention string,
        keeping the one with the highest confidence, regardless of type.
        On equal confidence the first occurrence is kept.

        :param extractions: List of mention dictionaries, each with keys:
                            - "mention"
                            - "type"
                            - "confidence" (treated as 0 when absent)
        :return: Deduplicated list, in order of first appearance
        """
        best_by_mention = {}

        for m in extractions:
            key = m["mention"]
            existing = best_by_mention.get(key)
            # Always keep the one with the highest confidence
            if existing is None or m.get("confidence", 0) > existing.get("confidence", 0):
                best_by_mention[key] = m

        return list(best_by_mention.values())

    def process_articles(
        self,
        output_csv: Path = SUBMISSION_FILE_PATH
    ):
        """
        Main method to run the pipeline on ``self.text_dicts``.

        Mentions are extracted and classified, DOI and accession mentions are
        normalised, and one row per (article_id, dataset_id) is written to
        ``output_csv``. Mentions whose ``type`` is neither 0 nor 1 are
        skipped. A failure to write the CSV is logged, not raised.

        Args:
            output_csv (Path): Path to output CSV.
        """
        columns = ["article_id", "dataset_id", "type"]
        results = []

        raw_texts = self.text_dicts
        pdf_texts = raw_texts.get("pdf", {})
        xml_texts = raw_texts.get("xml", {})

        self.logger.info(f"Extraction pipeline starting: {len(pdf_texts)} PDF articles to process ...")
        self.logger.info(f"Extraction pipeline starting: {len(xml_texts)} XML articles to process ...")

        # Set for O(1) membership tests in the loop below.
        missing_ids = set(self.missing_ids)

        # Extract, filter and classify data mentions
        classified_mentions = self.mention_extractor_classifier.process_articles(raw_texts)

        # Add usage type labels, normalize DOIs if needed and prepare data for CSV export
        for article_id, mentions in classified_mentions.items():
            # Normalize the id of the article itself
            article_id = self.normalize_article_doi(article_id)
            # Remove articles with missing ids
            if article_id in missing_ids:
                continue
            # Split mentions by format: DOI vs ACC
            doi_mentions = [m for m in mentions if m["mention_format"].upper() == "DOI"]
            acc_mentions = [m for m in mentions if m["mention_format"].upper() == "ACC"]
            # Normalize DOI mentions, then remove duplicated dois
            normalized_dois = self.deduplicate_extractions(self.normalize_all_dois(doi_mentions))
            # Normalize ACC mentions
            acc_mentions = self.normalize_data_acc(acc_mentions)

            for mention in normalized_dois + acc_mentions:
                type_value = mention.get("type", 2)
                if type_value not in [0, 1]:
                    continue  # Skip mention with unknown classification

                results.append({
                    "article_id": article_id,
                    "dataset_id": mention["mention"],
                    "type": "Primary" if type_value == 0 else "Secondary",
                })

        # Explicit columns keep the frame well-formed when no mention survived
        # (an empty list would otherwise give a frame without columns and the
        # column accesses below would raise KeyError).
        df = pd.DataFrame(results, columns=columns)

        # Drop rows with missing values BEFORE the string conversion: once
        # converted, a missing value is the text "nan"/"None" and would no
        # longer be detected.
        df.dropna(subset=columns, inplace=True)
        df = df.astype(str)

        # Drop duplicates based on article_id + dataset_id
        df.drop_duplicates(subset=["article_id", "dataset_id"], inplace=True)

        # Reindex and create a clean row_id
        df.reset_index(drop=True, inplace=True)
        df.insert(0, "row_id", df.index)

        try:
            df.to_csv(output_csv, index=False)
        except Exception as e:
            self.logger.error(f"Failed to save CSV: {e}")
            print(f"Failed to write output to {output_csv}: {e}")
        else:
            self.logger.info(f"Extraction pipeline complete. Results saved to: {output_csv}")
            print(f"Print fallback: Results saved to: {output_csv}")
            # Only claim rows were written when the write succeeded.
            self.logger.info(f"{len(df)} rows written to submission file.")



In [None]:
# Articles labelled 'Missing' in the training labels; their ids (as strings)
# are excluded from test-file discovery and from the submission.
missing_labels_df = train_labels_df[train_labels_df.type=='Missing']
missing_ids = [str(aid) for aid in missing_labels_df.article_id.unique()]

In [None]:
def find_data_directory(data_root: Path, subdir: str, ext: str) -> Optional[Path]:
    """
    Locate a directory that directly contains ``*.<ext>`` files.

    Candidates are tried in this order and the first hit wins:
    1. ``data_root / subdir`` (must be a directory)
    2. ``data_root / 'test'`` itself
    3. any direct subdirectory of ``data_root / 'test'``

    Returns:
        The matching directory, or None when no candidate holds such files.
    """
    glob_pattern = f"*.{ext}"

    def holds_matches(directory: Path) -> bool:
        # True as soon as one file with the extension is found.
        return any(directory.glob(glob_pattern))

    preferred = data_root / subdir
    if preferred.exists() and preferred.is_dir() and holds_matches(preferred):
        return preferred

    fallback_root = data_root / "test"
    if not fallback_root.exists():
        return None

    if holds_matches(fallback_root):
        return fallback_root

    # Last resort: first direct subdirectory of 'test' with matching files.
    return next(
        (child for child in fallback_root.iterdir()
         if child.is_dir() and holds_matches(child)),
        None,
    )

def get_test_files(data_root: Path, subdir: str, ext: str) -> Tuple[List[Path], Optional[Path]]:
    """
    Collect the ``*.<ext>`` files from the directory chosen by
    ``find_data_directory`` and report how many were found.

    Returns:
        - the list of matching file paths
        - the directory they were found in
        ``([], None)`` is returned when no directory or no file is found.
    """
    located_dir = find_data_directory(data_root, subdir, ext)
    matches = [] if located_dir is None else list(located_dir.glob(f"*.{ext}"))

    if not matches:
        return [], None

    print(f"{len(matches)} {ext.upper()} files found in {located_dir} for testing.")
    return matches, located_dir



# Attempt to find PDF and XML test files
# (each call returns the file list and the directory it was found in).
test_pdf_paths, TEST_PDF_DIR = get_test_files(DATA_ROOT, "test/PDF", "pdf")
test_xml_paths, TEST_XML_DIR = get_test_files(DATA_ROOT, "test/XML", "xml")


# Filter out missing IDs immediately (the file stem is compared to the id).
test_pdf_paths = [p for p in test_pdf_paths if p.stem not in missing_ids]
test_xml_paths = [p for p in test_xml_paths if p.stem not in missing_ids]


# Cleaned article texts per source format: {"pdf": {...}, "xml": {...}}.
test_text_dicts = {"pdf": {}, "xml": {}}

pdf_text_extractor = ExtractTextFromArticles(file_format="PDF")
xml_text_extractor = ExtractTextFromArticles(file_format="XML")

# Process whichever test data is available
if test_pdf_paths and test_xml_paths:
    # NOTE(review): get_pdf_only presumably keeps the PDFs that have no XML
    # counterpart, so each article is read from one format only -- confirm.
    test_pdf_only_paths = get_pdf_only(test_pdf_paths,test_xml_paths)
    
    pdf_test_texts = pdf_text_extractor.extract_text_data(test_pdf_only_paths)
    pdf_clean_test_texts = pdf_text_extractor.clean_articles(pdf_test_texts)

    xml_test_texts = xml_text_extractor.extract_text_data(test_xml_paths)
    xml_clean_test_texts = xml_text_extractor.clean_articles(xml_test_texts)

    test_text_dicts['pdf'] = pdf_clean_test_texts
    test_text_dicts['xml'] = xml_clean_test_texts
    
elif test_pdf_paths:
    # Only PDF files are available.
    pdf_test_texts = pdf_text_extractor.extract_text_data(test_pdf_paths)
    pdf_clean_test_texts = pdf_text_extractor.clean_articles(pdf_test_texts)

    test_text_dicts['pdf'] = pdf_clean_test_texts

elif test_xml_paths:
    # Only XML files are available.
    xml_test_texts = xml_text_extractor.extract_text_data(test_xml_paths)
    xml_clean_test_texts = xml_text_extractor.clean_articles(xml_test_texts)

    test_text_dicts['xml'] = xml_clean_test_texts

else:
    # Neither PDF nor XML test files found - raise a clear error.
    # This branch is also reached when files exist but every one of them was
    # removed by the missing-id filter above.
    raise RuntimeError(
        "No test data found: no PDF files under 'test/PDF' or 'test' directories, "
        "and no XML files under 'test/XML' or 'test' directories."
    )


In [None]:
# Wire up the inference pipeline: mention filter, context preparator,
# NER + classification stage, and the final export stage.
mention_filter = DatasetMentionFilter(confidence_threshold=0.70)
cls_preparator = ClassificationDataPreparator(tokenizer=cls_tokenizer)

mention_extractor = MentionExtractorAndClassifier(
    ner_model=ner_model,
    ner_tokenizer=ner_tokenizer,
    classifier_model=cls_model,
    classifier_tokenizer=cls_tokenizer,
    mention_filter=mention_filter,
    cls_preparator=cls_preparator,
    classification_window=300,
    # One NER pass is run per stride; duplicate mentions are removed afterwards.
    strides=[256, 312,412, 484],
    text_dicts = test_text_dicts
)

data_references_finder = FindDataReferences(mention_extractor,mention_filter,
                                            test_text_dicts,missing_ids)

# Runs the whole pipeline and writes the CSV to the default SUBMISSION_FILE_PATH.
data_references_finder.process_articles()

In [None]:
# Reload the written submission file to inspect what was actually saved.
submission = pd.read_csv(SUBMISSION_FILE_PATH)

In [None]:
# Display the submission table as the notebook cell's output.
submission