## add
- post_validate.py : remove fp
- predict.py : use llb to predict primary or secondary

In [1]:
! uv pip uninstall --system 'tensorflow'
! uv pip install --system --no-index --find-links='/kaggle/input/latest-mdc-whls/whls' 'pymupdf' 'vllm' 'triton' 'logits-processor-zoo' 'numpy<2'
! mkdir -p /tmp/src

[2mUsing Python 3.11.13 environment at: /usr[0m
[2mUsing Python 3.11.13 environment at: /usr[0m
[2mAudited [1m5 packages[0m [2min 193ms[0m[0m


In [2]:
%%writefile /tmp/src/helpers.py
import logging, os, kagglehub, inspect
from pathlib import Path
import polars as pl

IS_KAGGLE_ENV = sum(['KAGGLE' in k for k in os.environ]) > 0
IS_KAGGLE_SUBMISSION = bool(os.getenv("KAGGLE_IS_COMPETITION_RERUN"))
COMP_DIR = Path(('/kaggle/input/make-data-count-finding-data-references' if IS_KAGGLE_SUBMISSION else kagglehub.competition_download('make-data-count-finding-data-references')))
PDF_DIR = COMP_DIR / ('test' if IS_KAGGLE_SUBMISSION else 'train') / 'PDF'
XML_DIR = COMP_DIR / ('test' if IS_KAGGLE_SUBMISSION else 'train') / 'XML'
WORKING_DIR = Path(('/kaggle/working/' if IS_KAGGLE_ENV else '.working/'))



DOI_LINK = 'https://doi.org/'

DEFAULT_LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG").upper() if not IS_KAGGLE_SUBMISSION else "WARNING"
LOG_FILE_PATH = os.getenv("LOG_FILE", "logs/project.log")
LOG_DIR = Path(LOG_FILE_PATH).parent

LOG_DIR.mkdir(parents=True, exist_ok=True)

LOG_FORMAT = "%(levelname)s %(asctime)s  [%(filename)s:%(lineno)d - %(funcName)s()] %(message)s"
LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"

def get_logger(name=None):
    if name is None:
        frame = inspect.currentframe()
        if frame is None or frame.f_back is None:
            name = "__main__"
        else:
            name = frame.f_back.f_globals.get("__name__", "__main__")

    logger = logging.getLogger(name)

    if not logger.handlers:
        logger.setLevel(DEFAULT_LOG_LEVEL)
        formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATEFMT)
        ch = logging.StreamHandler()
        ch.setLevel(DEFAULT_LOG_LEVEL)
        ch.setFormatter(formatter)
        fh = logging.FileHandler(LOG_FILE_PATH)
        fh.setLevel(DEFAULT_LOG_LEVEL)
        fh.setFormatter(formatter)
        logger.addHandler(ch)
        logger.addHandler(fh)
        logger.propagate = False
    return logger

def is_doi_link(name: str) -> pl.Expr:
    return pl.col(name).str.starts_with(DOI_LINK).and_(
        ~pl.col(name).str.contains(r"/dl\.")
    )

def string_normalization(name: str) -> pl.Expr:
    return pl.col(name).str.normalize("NFKC").str.replace_all(r"[^\p{Ascii}]", '').str.replace_all(r"https?://zenodo\.org/record/(\d+)", r" 10.5281/zenodo.$1 ")

def get_df(parse_dir: str):
    records = []
    txt_files = list(Path(parse_dir).glob('*.txt'))
    for txt_file in txt_files:
        id_ = txt_file.stem
        with open(txt_file, 'r') as f:
            text = f.read()
        records.append({'article_id': id_, 'text': text})
    return pl.DataFrame(records).with_columns(string_normalization('text').alias('text'))

def assume_type(df: pl.DataFrame) -> pl.DataFrame:
    return (
        df.with_columns(pl.when(is_doi_link('dataset_id').or_(pl.col('dataset_id').str.starts_with('SAMN'))).then(pl.lit('Primary')).otherwise(pl.lit('Secondary')).alias('type'))
    )

def score(df, gt, on, tag='all'):
    hits = gt.join(df, on=on)
    tp = hits.height
    fp = df.height - tp
    fn = gt.height - tp
    f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) != 0 else 0.0
    return f"{tag} - f1: {f1:.4f} [{tp}/{fp}/{fn}]"

def evaluate(df, on=['article_id', 'dataset_id']):
    gt = pl.read_csv(COMP_DIR/'train_labels.csv').filter(pl.col('type')!='Missing')
    return (
        score(df, gt, on),
        score(df.filter(is_doi_link('dataset_id')), gt.filter(is_doi_link('dataset_id')), on, 'doi'),
        score(df.filter(~is_doi_link('dataset_id')), gt.filter(~is_doi_link('dataset_id')), on, 'acc'),
    )

Overwriting /tmp/src/helpers.py


In [3]:
%%writefile /tmp/src/parse.py
import argparse
from pathlib import Path
import pymupdf
import xml.etree.ElementTree as ET
import os
import glob
import re
from helpers import get_logger, PDF_DIR, XML_DIR

l = get_logger()

def pdf_to_txt(output_dir: Path):
    output_dir.mkdir(parents=True, exist_ok=True)
    pdf_files = list(PDF_DIR.glob("*.pdf")) + list(PDF_DIR.glob("*.PDF"))
    existing_txt_files = {f.stem for f in output_dir.glob("*.txt")}
    pdf_count = len(pdf_files)
    for pdf_file in pdf_files:
        txt_file = output_dir / f"{pdf_file.stem}.txt"
        if pdf_file.stem in existing_txt_files:
            continue
        try:
            text = ""
            with pymupdf.open(pdf_file) as doc:
                for page in doc:
                    text += page.get_text()
            txt_file.write_text(text, encoding='utf-8')
        except Exception:
            pass
    return pdf_count

def detect_xml_style(root):
    if '}' in root.tag and 'http://www.tei-c.org/ns/1.0' in root.tag:
        return 'tei'
    html_tags = {'html', 'body', 'div', 'p', 'span', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6'}
    for elem in root.iter():
        if elem.tag in html_tags:
            return 'html'
    return 'generic'

def get_block_elements(style):
    if style == 'tei':
        return {
            'p', 'head', 'title', 'abstract', 'div', 'item', 'list', 'table',
            'row', 'cell', 'note', 'quote', 'lg', 'l', 'sp', 'speaker'
        }
    elif style == 'html':
        return {
            'p', 'div', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'ul', 'ol', 'li',
            'table', 'tr', 'td', 'th', 'blockquote', 'pre', 'section', 'article',
            'header', 'footer', 'nav', 'aside'
        }
    else:
        return {
            'paragraph', 'section', 'title', 'abstract', 'item', 'list', 'entry',
            'cell', 'note', 'block'
        }

def extract_text_with_structure(element, style, block_elements, short_content_threshold=50):
    text_parts = []
    if element.text and element.text.strip():
        cleaned_text = re.sub(r'\s+', ' ', element.text.strip())
        text_parts.append(cleaned_text)
    for child in element:
        child_text = extract_text_with_structure(child, style, block_elements, short_content_threshold)
        if child_text:
            text_parts.append(child_text)
    if element.tail and element.tail.strip():
        cleaned_tail = re.sub(r'\s+', ' ', element.tail.strip())
        text_parts.append(cleaned_tail)
    if '}' in element.tag:
        tag_name = element.tag.split('}', 1)[1]
    else:
        tag_name = element.tag
    result_text = ' '.join(text_parts)
    if tag_name in block_elements:
        if len(result_text) < short_content_threshold:
            return result_text + ' '
        else:
            return result_text + '\n\n'
    else:
        return result_text + ' '

def convert_xml_to_txt(xml_file_path, txt_file_path):
    try:
        tree = ET.parse(xml_file_path)
        root = tree.getroot()
        style = detect_xml_style(root)
        l.info(f"Detected XML style: {style}")
        block_elements = get_block_elements(style)
        short_content_threshold = 50
        structured_text = extract_text_with_structure(root, style, block_elements, short_content_threshold)
        with open(txt_file_path, 'w', encoding='utf-8') as txt_file:
            txt_file.write(structured_text)
        return True
    except ET.ParseError as e:
        l.error(f"Error: Could not parse file {xml_file_path}. It may not be valid XML. Error: {e}")
        return False
    except Exception as e:
        l.error(f"Unknown error processing file {xml_file_path}: {e}")
        return False

def batch_convert_xml_folder(input_folder, output_folder):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    xml_files = glob.glob(os.path.join(input_folder, "*.xml"))
    xml_count = len(xml_files)
    overwrite_count = 0
    for xml_file in xml_files:
        original_filename = os.path.splitext(os.path.basename(xml_file))[0]
        txt_file_path = os.path.join(output_folder, original_filename + ".txt")
        if os.path.exists(txt_file_path):
            overwrite_count += 1
        if convert_xml_to_txt(xml_file, txt_file_path):
            l.info(f"Converted: {original_filename}.xml -> {original_filename}.txt")
    return xml_count, overwrite_count

def main():
    parser = argparse.ArgumentParser()
    # In a notebook environment, you might need to handle args differently,
    # but for a script file, this is fine.
    # We can pass arguments via a list to parse_args() if not running from CLI.
    # For now, keeping as is to match the file structure.
    parser.add_argument('--pdf-dir', type=Path, default=PDF_DIR, help='Directory containing PDF files')
    parser.add_argument('--xml-dir', type=Path, default=XML_DIR, help='Directory containing XML files')
    parser.add_argument('--output-dir', type=Path, default=Path('/tmp/train_parse'), help='Directory to save text files')
    
    # When running in a notebook, you might call main with specific args
    # so we prevent parse_args() from looking at sys.argv
    args = parser.parse_args(args=[]) 

    # Process PDFs
    pdf_count = pdf_to_txt(args.output_dir)
    l.info(f"Found and processed {pdf_count} PDF files.")

    # Process XMLs
    xml_count, overwrite_count = batch_convert_xml_folder(args.xml_dir, str(args.output_dir))
    l.info(f"Found and processed {xml_count} XML files.")
    l.info(f"Overwrote {overwrite_count} text files from XML conversions.")

    # Print summary to terminal
    print(f"Processed {pdf_count} PDF files.")
    print(f"Processed {xml_count} XML files.")
    print(f"Overwrote {overwrite_count} text files from XML conversions.")

if __name__ == "__main__":
    main()

Overwriting /tmp/src/parse.py


In [4]:
%%writefile /tmp/src/check_parse.py
import polars as pl
from pathlib import Path
from helpers import *

l=get_logger()

def gt_dataset_id_normalization(name:str) -> pl.Expr:
    return (
        pl.when(is_doi_link(name))
        .then(pl.col(name).str.split(DOI_LINK).list.last())
        .otherwise(name)
        .str.to_lowercase()
    )

def main():
    if IS_KAGGLE_SUBMISSION:
        l.debug('skipping check_parse for submission')
        return
    df = (
        get_df('/tmp/train_parse')
        .with_columns(pl.col('text').str.replace_all('\s+', '').str.to_lowercase().alias('text'))
    )

    gt = (
        pl.read_csv(COMP_DIR/'train_labels.csv')
        .filter(pl.col('article_id').is_in(df['article_id']))
        .filter(pl.col('type')!='Missing')
        .with_columns(gt_dataset_id_normalization('dataset_id').alias('norm_id'))
    )

    l.info(f"pymupdf misses: {gt.join(df, on='article_id').with_columns(hit=pl.col('text').str.contains(pl.col('norm_id'), literal=True)).filter(~pl.col('hit')).height} dataset_ids")

if __name__=='__main__': main()

Overwriting /tmp/src/check_parse.py


In [5]:
%%writefile /tmp/src/getid.py
import re
import polars as pl
from typing import Optional, Tuple

from helpers import *

COMPILED_PATTERNS = {
    # 'ref_header_patterns': [re.compile(r'\b(R\s*E\s*F\s*E\s*R\s*E\s*N\s*C\s*E\s*S|BIBLIOGRAPHY|LITERATURE CITED|WORKS CITED|CITED WORKS|ACKNOWLEDGEMENTS)\b[:\s]*', re.IGNORECASE)],    
    'ref_header_patterns': [re.compile(r'\b(R\s*E\s*F\s*E\s*R\s*E\s*N\s*C\s*E\s*S|BIBLIOGRAPHY|LITERATURE CITED|WORKS CITED|CITED WORKS|ACKNOWLEDGEMENTS|REFERENCES AND NOTES)\b[:\s]*', re.IGNORECASE)],
    'citation_pattern': re.compile(r'^\s*(\[\d+\]|\(\d+\)|\d+\.|\d+\)|\d+(?=\s|$))\s*'),
    'first_citation_patterns': [
        re.compile(r'^\s*\[1\]\s*'),
        re.compile(r'^\s*\(1\)\s*'),
        re.compile(r'^\s*1\.\s*'),
        re.compile(r'^\s*1\)\s*'),
        re.compile(r'^\s*1(?=\s|$)'),
    ],
}

l = get_logger()

def find_last_reference_header(text: str, header_patterns: list[re.Pattern]) -> Optional[int]:
    last_match_idx = None
    for pattern in header_patterns:
        matches = list(pattern.finditer(text))
        if matches:
            last_match_idx = matches[-1].start()
    return last_match_idx

def find_last_first_citation(text: str) -> Optional[int]:
    lines = text.splitlines()
    last_match_line = None
    for line_num, line in enumerate(lines):
        line = line.strip()
        for pattern in COMPILED_PATTERNS['first_citation_patterns']:
            if pattern.match(line):
                next_lines = lines[line_num:line_num+3]
                if any(COMPILED_PATTERNS['citation_pattern'].match(l.strip()) for l in next_lines[1:]):
                    last_match_line = line_num
                break
    return last_match_line

def find_reference_start(text: str) -> Optional[int]:
    lines = text.splitlines()
    last_first_citation = find_last_first_citation(text)
    if last_first_citation is not None:
        return last_first_citation
    start_search_idx = int(len(lines) * 0.5)
    for i in range(start_search_idx, len(lines)):
        line = lines[i].strip()
        if COMPILED_PATTERNS['citation_pattern'].match(line):
            next_lines = lines[i:i+3]
            if sum(1 for l in next_lines if COMPILED_PATTERNS['citation_pattern'].match(l.strip())) >= 2:
                for j in range(i, max(-1, i-10), -1):
                    if not COMPILED_PATTERNS['citation_pattern'].match(lines[j].strip()):
                        return j + 1
                return max(0, i-10)
    return None

def split_text_and_references(text: str) -> Tuple[str, str]:
    header_idx = find_last_reference_header(text, COMPILED_PATTERNS['ref_header_patterns'])
    if header_idx is not None:
        header_idx2 = find_last_reference_header(text[:header_idx].strip(), COMPILED_PATTERNS['ref_header_patterns'])
        if header_idx2 is not None:
            header_idx3 = find_last_reference_header(text[:header_idx2].strip(), COMPILED_PATTERNS['ref_header_patterns'])
            if header_idx3 is not None:
                return text[:header_idx3].strip(), text[header_idx3:].strip()
            return text[:header_idx2].strip(), text[header_idx2:].strip()
        return text[:header_idx].strip(), text[header_idx:].strip()
    ref_start_line = find_reference_start(text)
    if ref_start_line is not None:
        lines = text.splitlines()
        body = '\n'.join(lines[:ref_start_line])
        refs = '\n'.join(lines[ref_start_line:])
        return body.strip(), refs.strip()
    return text.strip(), ''

def get_splits(df: pl.DataFrame) -> pl.DataFrame:
    bodies, refs = [], []
    for raw_text in df['text']:
        main, ref = split_text_and_references(raw_text)
        bodies.append(main)
        refs.append(ref)
    return df.with_columns(pl.Series('body', bodies), pl.Series('ref', refs))

def tidy_extraction(df) -> pl.DataFrame:
    bad_ids = [f'{DOI_LINK}{e}' for e in ['10.5061/dryad', '10.5281/zenodo', '10.6073/pasta']]
    
    # 学术论文DOI前缀黑名单 (从post_filter.py移植)
    PAPER_PREFIXES = [
        "10.1038","10.1007","10.1126","10.1016","10.1101","10.1021","10.1145","10.1177",
        "10.1093","10.1080","10.1111","10.1098","10.1103","10.1186","10.1371","10.7554",
        "10.1039","10.1002","10.3390","10.1073","10.1097","10.15252","10.1136","10.1091",
        "10.1523", "10.1152", "10.1128", "10.1155", "10.1242", "10.1182", "10.1012","10.1023",
        "10.1001","10.1006","10.1017","10.1029","10.1034","10.1037","10.1042","10.1044","10.1046",
        "10.1053","10.1056","10.1061","10.1063","10.5194","10.1029","10.5194","10.1175","10.2307",
        "10.3389","10.1590","10.1130","10.1088","10.1146","10.1890","10.1017","10.1086","10.3133",
        "10.1046","10.1109","10.1140","10.3354","10.1534","10.1023","10.6084","10.1158","10.1139",
        "10.1006","10.4319","10.1785","10.1099","10.1143","10.1089","10.1104","10.1074","10.3897",
        "10.1071","10.1056","10.1121","10.3201","10.3109","10.18637","10.1061","10.1364","10.1163",
        "10.1144","10.1159","10.1063","10.1161","10.1113","10.7717","10.1515","10.1001","10.11646",
        "10.1108","10.1115","10.6070","10.1617","10.1306","10.1645","10.14379","10.1899","10.4271",
        "10.1210","10.4161","10.21105","10.1183","10.14411","10.12688","10.1148","10.1105","10.3892",
        "10.18632","10.3945","10.1107","10.1659","10.1162","10.1586","10.3322","10.1641","10.2147",
        "10.1603","10.1067","10.1201","10.5441","10.48550","10.1200","10.5860","10.1078","10.3168",
        "10.2217","10.1127","10.2193","10.1164","10.5027","10.17161","10.2136","10.1142","10.7589",
        "10.1292","10.13155","10.1053","10.1554","10.3920","10.2337","10.5065","10.1037","10.2134",
        "10.1248","10.1044","10.17600","10.5479","10.5751","10.2110","10.3174","10.1212","10.17660",
        "10.1530","10.4067","10.1172","10.1094","10.1674","10.18194","10.1042","10.4103","10.1190",
        "10.2174","10.1117","10.3233","10.1577","10.2737","10.4172","10.2475","10.3732","10.15454",
        "10.1643","10.1214","10.1642","10.13039","10.2135","10.1084","10.4049","10.1124","10.1261",
        "10.5155","10.1118","10.1083","10.3324","10.7326","10.1055","10.1270","10.1213","10.3835",
        "10.1385","10.3171","10.1373","10.1637","10.4269","10.2478","10.1096","10.1137","10.1378",
        "10.4143","10.26197","10.1194","10.7150","10.13140","10.1246","10.2987","10.2144","10.4315",
        "10.1593","10.2202","10.1196","10.1110","10.1134","10.3748","10.21273","10.1503","10.1517","10.1215"
    ]
    
    # 创建学术论文DOI过滤函数
    def is_paper_prefix_func(dataset_id: str) -> bool:
        """检查是否为学术论文DOI前缀"""
        for prefix in PAPER_PREFIXES:
            if dataset_id.startswith(f"{DOI_LINK}{prefix}"):
                return True
        return False

    doi_df = (
        df.with_columns(pl.col('text').str.extract_all(r'10\s*\.\s*\d{4,9}\s*/\s*\S+').alias('match'))
        .explode('match')
        .drop_nulls('match')
        .with_columns(
            pl.col('match').str.replace_all(r'\s+', '')
                           .str.replace(r'[^A-Za-z0-9]+$', '')
                           .str.to_lowercase()
                           .alias('dataset_id')
        )
        .group_by('article_id', 'dataset_id')
        .agg('match')
        .with_columns((DOI_LINK + pl.col('dataset_id')).alias('dataset_id'))
        # 排除学术论文DOI前缀
        .filter(~pl.col('dataset_id').map_elements(is_paper_prefix_func, return_dtype=pl.Boolean))
    )

    REGEX_IDS = (
        r"(?i)\b(?:"
        r"CHEMBL\d+|"
        r"E-GEOD-\d+|E-PROT-\d+|E-MTAB-\d+|E-MEXP-\d+|EMPIAR-\d+|"
        r"ENSBTAG\d+|ENSOARG\d+|"
        r"EPI_ISL_\d{5,}|EPI\d{6,7}|"
        r"HPA\d+|CP\d{6}|IPR\d{6}|PF\d{5}|BX\d{6}|KX\d{6}|K0\d{4}|CAB\d{6}|"
        r"NC_\d{6}\.\d{1}|NM_\d{9}|"
        r"PRJNA\d+|PRJEB\d+|PRJDB\d+|PXD\d+|SAMN\d+|"
        r"GSE\d+|GSM\d+|GPL\d+|"
        r"PDB\s?[1-9][A-Z0-9]{3}|HMDB\d+|"
        r"dryad\.[^\s\"<>]+|pasta\/[^\s\"<>]+|"
        r"(?:SR[PRX]|STH|ERR|DRR|DRX|DRP|ERP|ERX)\d+|"
        r"CVCL_[A-Z0-9]{4}|"
        r"[1-5]\.(?:10|20|30|40|50|60|70|80|90)\.\d{2,4}\.\d{2,4}"
        r")"
    )

    acc_df = (
        df.with_columns(
            pl.col('text').str.extract_all(REGEX_IDS).alias('match')
        )
        .explode('match')
        .drop_nulls('match')
        .with_columns(
            pl.col('match').str.replace_all(r'\s+', '')
                           .str.replace(r'[^A-Za-z0-9]+$', '')
                           .str.replace(r'(?i)^PDB', '')
                           .alias('dataset_id')
        )
        .group_by('article_id', 'dataset_id')
        .agg('match')
        .with_columns(
            pl.when(pl.col('dataset_id').str.starts_with('dryad.'))
              .then(f'{DOI_LINK}10.5061/' + pl.col('dataset_id'))
              .otherwise('dataset_id')
              .alias('dataset_id')
        )
        .with_columns(
            pl.when(pl.col('dataset_id').str.starts_with('pasta/'))
              .then(f'{DOI_LINK}10.6073/' + pl.col('dataset_id'))
              .otherwise('dataset_id')
              .alias('dataset_id')
        )
    )

    df = pl.concat([doi_df, acc_df])
    
    # ======== 智能过滤策略 ========
    # 1. 基本过滤（保留所有可能的匹配）
    df = df.unique(['article_id', 'dataset_id'])
    
    # 2. 添加置信度评分而不是直接过滤
    def calculate_confidence(dataset_id):
        """为每个ID计算置信度分数"""
        score = 0.5  # 基础分数
        
        # 高置信度模式
        high_confidence_patterns = [
            r'10\.\d{4,9}/',  # DOI格式
            r'chebi\.org',    # 已知仓库
            r'ensembl\.org',
            r'GSE\d+',
            r'PRJNA\d+',
            r'CHEMBL\d+',
            r'PXD\d+',
        ]
        
        # 低置信度模式（可能误判）
        low_confidence_patterns = [
            r'^\d+$',  # 纯数字
            r'^\d+\.\d+$',  # 简单小数
            r'^[A-Z]{1,2}\d{1,3}$',  # 短字母数字组合
            r'^Figure\s+\d+',  # 图表引用
            r'^Table\s+\d+',  # 表格引用
        ]
        
        # 应用规则
        for pattern in high_confidence_patterns:
            if re.search(pattern, dataset_id, re.IGNORECASE):
                score += 0.3
        
        for pattern in low_confidence_patterns:
            if re.search(pattern, dataset_id, re.IGNORECASE):
                score -= 0.4
        
        # 确保分数在0-1之间
        return max(0.1, min(1.0, score))
    
    # 添加置信度列
    confidence_scores = [calculate_confidence(id) for id in df['dataset_id'].to_list()]
    df = df.with_columns(pl.Series('confidence', confidence_scores))
    
    # 3. 应用智能过滤（只过滤明显错误的匹配）
    df = (
        df
        # 排除明显错误的匹配（置信度极低）
        .filter(pl.col('confidence') > 0.2)
        # 弱化自身引用过滤
        .filter(
            ~pl.col('dataset_id').str.replace("https?://", "")
            .str.contains(pl.col('article_id').str.replace('_','/'))
        )
        # 更精确的坏ID过滤
        .filter(~pl.col('dataset_id').is_in(bad_ids))
    )
    
    # 保留原有的match去重
    df = df.with_columns(pl.col('match').list.unique())
    # ======== 智能过滤结束 ========
    return df

def get_context_window(text: str, substring: str, window: int = 100) -> str:
    idx = text.find(substring)
    if idx == -1:
        raise ValueError
    start = max(idx - window, 0)
    end = min(idx + len(substring) + window, len(text))
    return text[start:end]

def get_window_df(text_df, ids_df):
    df = ids_df.join(text_df, on='article_id')
    windows = []
    for text, match_ids in df.select('text', 'match').rows():
        windows.append(get_context_window(text, match_ids[0]))
    return df.with_columns(pl.Series('window', windows)).select('article_id', 'dataset_id', 'window')

def preprocess_text(text: str) -> str:
    """
    Preprocessing text for better ID extraction:
    1. Replace punctuation (commas, parentheses) with spaces to help ID extraction
    2. Join lines that were likely broken during text extraction
    """
    # Step 1: Replace punctuation with spaces
    cleaned_text = re.sub(r'[,()]', ' ', text)
    
    # Step 2: Join broken lines (original logic)
    lines = cleaned_text.split('\n')
    processed_lines = []
    i = 0
    while i < len(lines):
        line = lines[i]
        # Heuristic: If a line ends with a common DOI prefix part or a hyphen,
        # it might be a broken line.
        if (line.strip().endswith('/') or line.strip().endswith('-')) and i + 1 < len(lines):
            # Join with the next line
            processed_lines.append(line.strip() + lines[i+1].strip())
            i += 2 # Skip the next line as it has been merged
        else:
            processed_lines.append(line)
            i += 1
    return '\n'.join(processed_lines)

def main(input_dir: str, parquet_dir: str, output_dir: str) -> None:
    text_df = get_df(input_dir)

    # !!! ADD THIS PREPROCESSING STEP !!!
    original_texts = text_df['text'].to_list()
    preprocessed_texts = [preprocess_text(t) for t in original_texts]
    text_df = text_df.with_columns(pl.Series("text", preprocessed_texts))

    df = get_splits(text_df)
    df = tidy_extraction(df)
    df = get_window_df(text_df, df)
    df.write_parquet(parquet_dir)
    df = assume_type(df)
    df.select(['article_id', 'dataset_id', 'type']).with_row_index(name='row_id').write_csv(output_dir)
    
    if not IS_KAGGLE_SUBMISSION:
        print("*"*10)
        results = evaluate(df)
        for r in results: l.info(r)
        print("*"*10)
        results = evaluate(df, on=['article_id', 'dataset_id', 'type'])
        for r in results: l.info(r)

if __name__=='__main__': 
    # In a Kaggle environment, you would typically get these paths from environment variables
    # or predefined constants, but for this script structure, we keep the local-style setup.
    # The actual execution in the notebook will likely call main() with different paths.
    
    # For local execution:
    # input_dir = './temp/parse'
    # parquet_dir = './temp/extracted.parquet'
    # output_dir = './temp/submission.csv'
    # main(input_dir, parquet_dir, output_dir)
    
    # For Kaggle notebook execution (mimicking the original main function's behavior):
    main(
        input_dir='/tmp/train_parse',
        parquet_dir='/tmp/extracted.parquet',
        output_dir='/kaggle/working/submission.csv'
    )

# def main():
#     text_df = get_df('/tmp/train_parse')
#     df = get_splits(text_df)
#     df = tidy_extraction(df)
        
#     write_the_match(text_df,df)
    
#     df = get_window_df(text_df, df)
#     df.write_parquet('/tmp/extracted.parquet')
#     df = assume_type(df)
#     df.select(['article_id', 'dataset_id', 'type']).with_row_index(name='row_id').write_csv('/kaggle/working/submission.csv')
#     if not IS_KAGGLE_SUBMISSION:
#         results = evaluate(df)
#         for r in results: l.info(r)
#         results = evaluate(df, on=['article_id', 'dataset_id', 'type'])
#         for r in results: l.info(r)

# if __name__=='__main__': main()

Overwriting /tmp/src/getid.py


In [6]:
%%writefile /tmp/src/llm_validate.py
import polars as pl
import os

from helpers import *

l = get_logger()

SYS_PROMPT_CLASSIFY_DOI = """
1. Priority Rules (highest → lowest)
1.1 Always classify as A (Data) if:
DOI prefix matches a known data repository:

Dryad: 10.5061

Zenodo: 10.5281

Dl: 10.15468

ICPSR: 10.3886

USGS data: 10.5066

Mendeley Data: 10.17632

Dataverse: 10.7910/DVN

OpenNeuro: 10.18112/openneuro.

PANGAEA: 10.1594/PANGAEA.


2. Classify as B (Literature) if:
DOI prefix belongs to a publisher (e.g., 10.1038, 10.1007, 10.1126, 10.1016, 10.1101, 10.1021, 10.1145, 10.1177, 10.1093, 10.1080, 10.1111, etc.).

Context indicates a journal article, book, conference paper, preprint, protocol, or method paper, without any repository/data storage signal.

Mentions only “supplementary material” or “supplementary information” without a repository.

3. Ambiguous cases
No repository prefix and no clear context → default to B.


4. Output
Only output:

A → data repository / dataset

B → literature / non-data resource

Few-shot examples

“Raw images are stored on Figshare (DOI 10.6084/m9.figshare.1234567).” → A

“Sequence reads available under BioProject accession PRJNA765432.” → A

“As described in Nature Methods (DOI 10.1038/s41592-020-0793-2).” → B

“See Supplementary Data at Zenodo (10.5281/zenodo.987654).” → A

“Method details published in J. Proteome Res. DOI: 10.1021/acs.jproteome.0c00845.” → B

“Data uploaded to Dryad (10.5061/dryad.x1y2z3).” → A

“Referenced paper: DOI 10.1101/2020.01.01.123456 (bioRxiv preprint).” → B

“Metabolomics data in MetaboLights MTBLS1234.” → A

“The MRI scans are deposited at OpenNeuro (DOI 10.18112/openneuro.ds000001.v1.0.0).” → A

“Protein structure described in Science (DOI 10.1126/science.abc1234).” → B
""".strip()

def build_df():
    df = pl.read_parquet('/tmp/extracted.parquet')
    df.filter(~is_doi_link('dataset_id')).select('article_id', 'dataset_id').write_csv('/tmp/accid_sub.csv')
    return df.filter(is_doi_link('dataset_id'))

def build_prompt(tokenizer, df):
    prompts = []
    for doi, text in df.select('dataset_id', 'window').rows():
        messages = [{'role':'system','content': SYS_PROMPT_CLASSIFY_DOI}, {'role':'user', 'content': text}]
        prompts.append(tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False))
    return df.with_columns(pl.Series('prompt', prompts))

if __name__=='__main__':
    os.environ["VLLM_USE_V1"] = "0"
    import vllm
    from logits_processor_zoo.vllm import MultipleChoiceLogitsProcessor
    model_path = "/kaggle/input/qwen2.5/transformers/32b-instruct-awq/1"
    llm = vllm.LLM(model_path, quantization='awq', tensor_parallel_size=2, gpu_memory_utilization=0.9, trust_remote_code=True, dtype="half", enforce_eager=True, max_model_len=2048, disable_log_stats=True, disable_custom_all_reduce=True, enable_prefix_caching=True, task='generate')
    tokenizer = llm.get_tokenizer()
    df = build_df()
    df = build_prompt(tokenizer, df)
    prompts = df['prompt'].to_list()
    mclp = MultipleChoiceLogitsProcessor(tokenizer, choices=["A", "B"])
    outputs = llm.generate(prompts, vllm.SamplingParams(seed=777, temperature=0.2, skip_special_tokens=True, max_tokens=1, logits_processors=[mclp], logprobs=len(mclp.choices)), use_tqdm=True)
    logprobs = [{lp.decoded_token: lp.logprob for lp in list(lps)} for lps in [output.outputs[0].logprobs[0].values() for output in outputs]]
    choices = [max(d, key=d.get) for d in logprobs]
    types = {'A': True, 'B': False}
    choices = [types[c] for c in choices]
    df = df.with_columns(pl.Series('type', choices))
    df.filter(pl.col('type')).select('article_id', 'dataset_id').write_csv('/tmp/doi_sub.csv')
    df = pl.concat([pl.read_csv('/tmp/doi_sub.csv'), pl.read_csv('/tmp/accid_sub.csv')])
    df = assume_type(df)
    df.select(['article_id', 'dataset_id', 'type']).with_row_index(name='row_id').write_csv('/kaggle/working/submission.csv')
    if not IS_KAGGLE_SUBMISSION:
        results = evaluate(df)
        for r in results: l.info(r) 
        results = evaluate(df, on=['article_id', 'dataset_id', 'type'])
        for r in results: l.info(r)


    
    try:
        del llm, tokenizer
    except:
        pass
    
    import gc, torch
    gc.collect()
    torch.cuda.empty_cache()

Overwriting /tmp/src/llm_validate.py


In [7]:
%%writefile /tmp/src/post_filter.py
import polars as pl
from helpers import *

"""
Fourth essence: Post-filter to cut FP DOIs that look like literature.
- Read /kaggle/working/submission.csv (output of llm_validate.py)
- Join with /tmp/extracted.parquet to get context window
- Drop DOI rows that (1) start with typical publisher prefixes AND (2) have no data-ish words nearby
- Keep accessions untouched
"""

l = get_logger()

PAPER_PREFIXES = [
    "10.5061","10.5281","10.17632","10.1594","10.15468","10.17882","10.7937","10.7910","10.6073",
    "10.3886","10.3334","10.4121","10.5066","10.5067","10.18150","10.25377","10.25387","10.23642","10.24381","10.22033"
]

CONTEXT_RE = r"(?i)\b(data(?:set)?|repository|archive|deposited|available|supplementary|raw(?:\s+data)?|uploaded|hosted|stored|accession)\b"

def is_paper_prefix(col: str = "dataset_id") -> pl.Expr:
    expr = pl.lit(False)
    for p in PAPER_PREFIXES:
        expr = expr | pl.col(col).str.starts_with(f"{DOI_LINK}{p}")
    return expr

def main():
    sub = pl.read_csv("/kaggle/working/submission.csv")

    # Normalize columns: drop row_id if present so concat widths match
    if "row_id" in sub.columns:
        sub = sub.drop("row_id")

    # Context windows
    win = pl.read_parquet("/tmp/extracted.parquet").select("article_id", "dataset_id", "window")

    # DOI & ACC split
    doi_rows = sub.filter(is_doi_link("dataset_id")).join(win, on=["article_id", "dataset_id"], how="left")
    acc_rows = sub.filter(~is_doi_link("dataset_id"))

    keep_mask = (
        (~is_paper_prefix("dataset_id"))  # not a known paper prefix
        | doi_rows["window"].fill_null("").str.contains(CONTEXT_RE)
    )

    kept_doi = doi_rows.filter(keep_mask).select("article_id", "dataset_id", "type")
    final = pl.concat([kept_doi, acc_rows.select("article_id", "dataset_id", "type")])

    # Re-eval & save
    if not IS_KAGGLE_SUBMISSION:
        for r in evaluate(final): l.info(r)
        for r in evaluate(final, on=["article_id", "dataset_id", "type"]): l.info(r)

    final.with_row_index("row_id").write_csv("/kaggle/working/submission.csv")

if __name__ == "__main__":
    main()

Overwriting /tmp/src/post_filter.py


In [8]:
%%writefile /tmp/src/post_validate.py

from helpers import *
import polars as pl
import os


l = get_logger()


PROMPT_CLASSIFY_CITATION_TYPE = '''
# Role & Task
You are an expert data citation analyst. Your task is to classify a given citation from a scientific paper into one of two categories: **A** (Data) or **B** (Not Data). Base your decision strictly on the provided abstract and the context of the citation.

## Instructions
1.  **Read the provided abstract** to understand the research context.
2.  **Analyze the citation context** for key linguistic cues.
3.  **Classify the citation** as either **A** or **B** based on the definitions below.
4.  **Output only a single letter: A or B.** Do not output any other text, explanation, or formatting.

## Category Definitions

### **Category A: DATA**
The citation points to a dataset. This includes:
*   **Primary Data:** Raw or processed data that the current study's authors collected, generated, or created.
*   **Secondary Data:** Data that was originally produced by other researchers but is being *used as a dataset* in the current study.
*   **Key Phrases:** "data are available at", "we collected", "we measured", "data were obtained from", "dataset", "downloaded from", "deposited in", repository names (e.g., GenBank, Zenodo, Figshare, TCIA).

### **Category B: NOT DATA**
The citation points to a traditional scholarly publication or other non-data resource. This includes:
*   Journal articles, books, conference proceedings, preprints, protocols, methods papers.
*   **Key Phrases:** "as described in", "according to", "previous study", "et al.", "paper", "article", "methodology", "was used for analysis" (without indicating data access).
*   Citations that provide background context or methodological description but do not serve as the source of the data used in the analysis.

## Input Format
You will be provided with the following three pieces of information:
Paper Abstract: {abstract}
Citation: {dataset_id}
Citation Context: {context}

## Critical Thinking Guidelines
*   A DOI or URL can point to either data (A) or a paper (B). The context determines the classification.
*   If the citation is used to describe the *source* of the data for the current study's analysis, it is likely **A**.
*   If the citation is used to provide background, justify a method, or compare results, it is likely **B** (a reference to another paper).
*   When in doubt, rely on the linguistic cues in the "Citation Context".

## Examples for Pattern Recognition

**Example 1 (Classify as A):**
*   Context: "Three out of four cohorts used in this study can be found on The Cancer Imaging Archive (TCIA)24: Canadian benchmark dataset23: https://doi.org/10.7937/K9/TCIA.2017.8oje5q00."
*   **Reasoning:** The text states cohorts are "used in this study" and provides direct repository links. This is a clear case of citing external data for use.
*   **Output:** A

**Example 2 (Classify as B):**
*   Context: "data presented here are available at the SEANOE dataportal: https://doi.org/10.17882/94052 (ZooScan dataset Grandremy et al. 2023c)"
*   **Reasoning:** The phrase "data presented here" indicates this is the authors' own data being deposited, not a citation to an external source they are using. The "(Author et al. Year)" format is a classic literature citation style.
*   **Output:** B

**Example 3 (Classify as A):**
*   Context: "GBIF occurrence data: Vulpes vulpes: https://doi.org/10.15468/dl.wgtneb (28 May 2021)."
*   **Reasoning:** Explicitly names the data source (GBIF) and provides a direct access link/DOI for the specific dataset used.
*   **Output:** A

**Example 4 (Classify as B):**
*   Context: "North American soil NCBI SRA SRP035367 Smith & Peay [36] ITS2-Soil"
*   **Reasoning:** While it mentions a data repository ID (SRP035367), it couples it with a standard literature citation "[36]". The context suggests it is referencing the *paper* by Smith & Peay that describes the data, not directly citing the dataset itself for use.
*   **Output:** B

## Ready for Input
Begin your analysis. Remember: Output only **A** or **B**.
'''

def get_context_window(text: str, substring: str, window: int = 600) -> str:
    idx = text.find(substring)
    if idx == -1:
        return "no context", "no abstraction"
    start = max(idx - window, 0)
    end = min(idx + len(substring) + window, len(text))
    return text[start:end] , text[:1000]




def find_context_win(tokenizer,df):
    text_df = pl.read_parquet('/tmp/context_data.parquet')
    # print(text_df)
    df = df.join(text_df, on=["article_id","dataset_id"], how="inner")
    df = df.drop("type")
    print(df)

    prompts = []
    
    for article_id,dataset_id,text,match in df.select(["article_id","dataset_id","text",'match']).rows():

        context, abstract = get_context_window(text,match)
        user_content = f"""
        Paper Abstract: {abstract}
        
        Citation: {dataset_id}

        
        Citation Context: {context}
        """
        messages = [
            {"role": "system", "content": PROMPT_CLASSIFY_CITATION_TYPE},
            {"role": "user", "content": user_content.strip()}
        ]
        prompts.append(
            tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        )
        
    return df.with_columns(pl.Series("prompt", prompts))

    

if __name__=="__main__":
    os.environ["VLLM_USE_V1"] = "0"
    MODEL_PATH = "/kaggle/input/qwen2.5/transformers/32b-instruct-awq/1"
    import vllm
    from logits_processor_zoo.vllm import MultipleChoiceLogitsProcessor

    llm = vllm.LLM(
        MODEL_PATH,
        quantization='awq',
        tensor_parallel_size=2,
        gpu_memory_utilization=0.9,
        trust_remote_code=True,
        dtype="half",
        enforce_eager=True,
        max_model_len=16384,
        disable_log_stats=True, 
        disable_custom_all_reduce=True,
        enable_prefix_caching=True,
        task='generate')

    tokenizer = llm.get_tokenizer()

    df=pl.read_csv("/kaggle/working/submission.csv")
    
    if "row_id" in df.columns:
        df = df.drop("row_id")

    # print(df)

    doi_df = df.filter(is_doi_link("dataset_id"))
    acc_df = df.filter(~is_doi_link("dataset_id"))

    # print(doi_df)

    df = find_context_win(tokenizer,doi_df)

    
    
    prompts = df['prompt'].to_list()
    mclp = MultipleChoiceLogitsProcessor(tokenizer, choices=["A", "B","C"])
    outputs = llm.generate(prompts, vllm.SamplingParams(seed=777, temperature=0.7, skip_special_tokens=True, max_tokens=1, logits_processors=[mclp], logprobs=len(mclp.choices)), use_tqdm=True)
    logprobs = [{lp.decoded_token: lp.logprob for lp in list(lps)} for lps in [output.outputs[0].logprobs[0].values() for output in outputs]]
    choices = [max(d, key=d.get) for d in logprobs]
    types = {'A': True, 'B': False}
    choices = [types[c] for c in choices]
    df = df.with_columns(pl.Series('type', choices))
    df.filter(pl.col('type')).select('article_id', 'dataset_id').write_csv('/tmp/doi_sub.csv')
    df = pl.concat([pl.read_csv('/tmp/doi_sub.csv'), pl.read_csv('/tmp/accid_sub.csv')])
    df = assume_type(df)
    df.select(['article_id', 'dataset_id', 'type']).with_row_index(name='row_id').write_csv('/kaggle/working/submission.csv')
    # print(df)
    if not IS_KAGGLE_SUBMISSION:
        results = evaluate(df)
        for r in results: l.info(r) 
        results = evaluate(df, on=['article_id', 'dataset_id', 'type'])
        for r in results: l.info(r)
    
    
    try:
        del llm, tokenizer
    except:
        pass
    
    import gc, torch
    gc.collect()
    torch.cuda.empty_cache()

Overwriting /tmp/src/post_validate.py


In [None]:
%%writefile /tmp/src/predict.py

from helpers import *
import polars as pl
import os


l = get_logger()


PROMPT_CLASSIFY_CITATION_TYPE = '''
# Role & Task
You are an expert data citation analyst. Your task is to classify a given citation from a scientific paper into one of two categories based on the context: **A (Primary Data)** or **B (Secondary Data)**.

## Instructions
1.  **Read the provided abstract** to understand the research context.
2.  **Analyze the citation context** for key linguistic cues.
3.  **Classify the citation** as either **A** or **B** based on the definitions below.
4.  **Output only a single letter: A or B.** Do not output any other text, explanation, or formatting.

## Category Definitions

### **Category A: PRIMARY DATA**
The data was generated, collected, or created by the **authors of the current study**. This is *their* data.
*   **Key Phrases:** "we collected", "we generated", "our data", "data are available at [URL/DOI]", "data have been deposited", "this study presents", "supplementary data".

### **Category B: SECONDARY DATA**
The data was produced by **other researchers** or external sources and is being reused or analyzed by the current study's authors.
*   **Key Phrases:** "data were obtained from", "publicly available data", "previously published data", "retrieved from", "downloaded from", "[Dataset Name] dataset", "database", citing a specific external source.

## Input Format
You will be provided with the following three pieces of information:
Paper Abstract: {abstract}
Citation: {dataset_id}
Citation Context: {context}


## Decision Framework
Answer these questions based on the **Citation Context**:

1.  **Who is the source of the data?**
    *   If the context implies the **authors themselves** are the source (e.g., "we," "our"), classify as **A**.
    *   If the context names an **external source** (e.g., a repository, another study, a database), classify as **B**.

2.  **What is the action being described?**
    *   **A (Primary)** actions: *depositing, making available, presenting* their own data.
    *   **B (Secondary)** actions: *using, obtaining, accessing, downloading, analyzing* existing data from elsewhere.

## Examples for Pattern Recognition

**Example 1 (Classify as B):**
*   Context: "Three out of four cohorts **used in this study** can be found on The Cancer Imaging Archive (TCIA)24: Canadian benchmark dataset23: https://doi.org/10.7937/K9/TCIA.2017.8oje5q00."
*   **Reasoning:** The authors are describing external datasets they **used** (a Secondary action). The source is TCIA, not themselves.
*   **Output:** B

**Example 2 (Classify as A):**
*   Context: "Additional research data **supporting this publication are available** at 10.25377/sussex.21184705."
*   **Reasoning:** The authors are stating the availability of data that **supports their own publication**. The source is implied to be themselves.
*   **Output:** A

**Example 3 (Classify as B):**
*   Context: "GBIF occurrence data: Vulpes vulpes: https://doi.org/10.15468/dl.wgtneb (28 May 2021)."
*   **Reasoning:** The data is explicitly sourced from an external repository (GBIF). The authors are referring to data they reused.
*   **Output:** B

**Example 4 (Classify as A):**
*   Context: "Data referring to Barbieux et al. (2017; https://doi.org/10.17882/49388) are freely available on SEANOE."
*   **Reasoning:** This is a tricky case. The citation format "(Author et al. Year)" suggests a literature reference. However, the phrase "Data referring to" and the direct data DOI indicate the authors are citing **their own previously published dataset** (from a 2017 paper) that is now available. This is their Primary data.
*   **Output:** A

## Ready for Input
Begin your analysis. Remember: Output only **A** or **B**.

'''

def get_context_window(text: str, substring: str, window: int = 600) -> str:
    idx = text.find(substring)
    if idx == -1:
        return "no context", "no abstraction"
    start = max(idx - window, 0)
    end = min(idx + len(substring) + window, len(text))
    return text[start:end] , text[:1000]




def find_context_win(tokenizer,df):
    text_df = pl.read_parquet('/tmp/context_data.parquet')
    # print(text_df)
    df = df.join(text_df, on=["article_id","dataset_id"], how="inner")
    df = df.drop("type")
    print(df)

    prompts = []
    
    for article_id,dataset_id,text,match in df.select(["article_id","dataset_id","text",'match']).rows():

        context, abstract = get_context_window(text,match)
        user_content = f"""
        Paper Abstract: {abstract}
        
        Citation: {dataset_id}

        
        Citation Context: {context}
        """
        messages = [
            {"role": "system", "content": PROMPT_CLASSIFY_CITATION_TYPE},
            {"role": "user", "content": user_content.strip()}
        ]
        prompts.append(
            tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        )
        
    return df.with_columns(pl.Series("prompt", prompts))

    

if __name__=="__main__":
    os.environ["VLLM_USE_V1"] = "0"
    MODEL_PATH = "/kaggle/input/qwen2.5/transformers/32b-instruct-awq/1"
    import vllm
    from logits_processor_zoo.vllm import MultipleChoiceLogitsProcessor

    llm = vllm.LLM(
        MODEL_PATH,
        quantization='awq',
        tensor_parallel_size=2,
        gpu_memory_utilization=0.9,
        trust_remote_code=True,
        dtype="half",
        enforce_eager=True,
        max_model_len=16384,
        disable_log_stats=True, 
        disable_custom_all_reduce=True,
        enable_prefix_caching=True,
        task='generate')

    tokenizer = llm.get_tokenizer()

    df=pl.read_csv("/kaggle/working/submission.csv")
    
    if "row_id" in df.columns:
        df = df.drop("row_id")


    doi_df = df.filter(is_doi_link("dataset_id"))
    acc_df = df.filter(~is_doi_link("dataset_id"))



    df = find_context_win(tokenizer,doi_df)

    
    
    prompts = df['prompt'].to_list()
    mclp = MultipleChoiceLogitsProcessor(tokenizer, choices=["A", "B"])
    outputs = llm.generate(prompts, vllm.SamplingParams(seed=777, temperature=0.8, skip_special_tokens=True, max_tokens=1, logits_processors=[mclp], logprobs=len(mclp.choices)), use_tqdm=True)
    logprobs = [{lp.decoded_token: lp.logprob for lp in list(lps)} for lps in [output.outputs[0].logprobs[0].values() for output in outputs]]
    choices = [max(d, key=d.get) for d in logprobs]
    types = {'A':'Primary', 'B':'Secondary'}
    choices = [types[c] for c in choices]


    
    df = df.with_columns(pl.Series('type', choices))
    df.select('article_id', 'dataset_id','type').write_csv('/tmp/doi_sub.csv')

    acc_df = assume_type(acc_df)
    acc_df.select('article_id','dataset_id','type').write_csv("/tmp/accid_sub.csv")
    df = pl.concat([pl.read_csv('/tmp/doi_sub.csv'), pl.read_csv('/tmp/accid_sub.csv')])
    
    df.select(['article_id', 'dataset_id', 'type']).with_row_index(name='row_id').write_csv('/kaggle/working/submission.csv')
    # print(df)
    if not IS_KAGGLE_SUBMISSION:
        results = evaluate(df)
        for r in results: l.info(r) 
        results = evaluate(df, on=['article_id', 'dataset_id', 'type'])
        for r in results: l.info(r)
    
    
    try:
        del llm, tokenizer
    except:
        pass
    
    import gc, torch
    gc.collect()
    torch.cuda.empty_cache()

Overwriting /tmp/src/predict.py


In [10]:
%cd /tmp
! python src/parse.py /tmp/train_parse


/tmp
INFO 2025-09-08 08:36:08  [parse.py:133 - main()] Found and processed 524 PDF files.
INFO 2025-09-08 08:36:08  [parse.py:88 - convert_xml_to_txt()] Detected XML style: html
INFO 2025-09-08 08:36:08  [parse.py:114 - batch_convert_xml_folder()] Converted: 10.1590_1678-4685-gmb-2018-0055.xml -> 10.1590_1678-4685-gmb-2018-0055.txt
INFO 2025-09-08 08:36:08  [parse.py:88 - convert_xml_to_txt()] Detected XML style: html
INFO 2025-09-08 08:36:08  [parse.py:114 - batch_convert_xml_folder()] Converted: 10.1021_jacs.2c06519.xml -> 10.1021_jacs.2c06519.txt
INFO 2025-09-08 08:36:08  [parse.py:88 - convert_xml_to_txt()] Detected XML style: html
INFO 2025-09-08 08:36:08  [parse.py:114 - batch_convert_xml_folder()] Converted: 10.1107_s2056989015019891.xml -> 10.1107_s2056989015019891.txt
INFO 2025-09-08 08:36:08  [parse.py:88 - convert_xml_to_txt()] Detected XML style: html
INFO 2025-09-08 08:36:08  [parse.py:114 - batch_convert_xml_folder()] Converted: 10.1186_s12881-019-0773-3.xml -> 10.1186_s1

In [11]:
! python src/check_parse.py

INFO 2025-09-08 08:36:17  [check_parse.py:31 - main()] pymupdf misses: 36 dataset_ids


In [12]:
! python src/getid.py
# old
# INFO 2025-09-08 07:12:48  [getid.py:299 - main()] all - f1: 0.1256 [626/8620/93]
# INFO 2025-09-08 07:12:48  [getid.py:299 - main()] doi - f1: 0.0530 [233/8270/54]
# INFO 2025-09-08 07:12:48  [getid.py:299 - main()] acc - f1: 0.6689 [393/350/39]

# INFO 2025-09-08 07:12:48  [getid.py:301 - main()] all - f1: 0.0997 [497/8749/222]
# INFO 2025-09-08 07:12:48  [getid.py:301 - main()] doi - f1: 0.0362 [159/8344/128]
# INFO 2025-09-08 07:12:48  [getid.py:301 - main()] acc - f1: 0.5753 [338/405/94]
# add Codeadd Markdown

# new
# **********
# INFO 2025-09-08 07:22:31  [getid.py:313 - main()] all - f1: 0.6038 [615/703/104]
# INFO 2025-09-08 07:22:31  [getid.py:313 - main()] doi - f1: 0.5432 [236/346/51]
# INFO 2025-09-08 07:22:31  [getid.py:313 - main()] acc - f1: 0.6490 [379/357/53]
# **********
# INFO 2025-09-08 07:22:31  [getid.py:316 - main()] all - f1: 0.4742 [483/835/236]
# INFO 2025-09-08 07:22:31  [getid.py:316 - main()] doi - f1: 0.3659 [159/423/128]
# INFO 2025-09-08 07:22:31  [getid.py:316 - main()] acc - f1: 0.5548 [324/412/108]

**********
INFO 2025-09-08 08:36:25  [getid.py:314 - main()] all - f1: 0.5814 [648/862/71]
INFO 2025-09-08 08:36:25  [getid.py:314 - main()] doi - f1: 0.4871 [246/477/41]
INFO 2025-09-08 08:36:25  [getid.py:314 - main()] acc - f1: 0.6596 [402/385/30]
**********
INFO 2025-09-08 08:36:25  [getid.py:317 - main()] all - f1: 0.4612 [514/996/205]
INFO 2025-09-08 08:36:25  [getid.py:317 - main()] doi - f1: 0.3366 [170/553/117]
INFO 2025-09-08 08:36:25  [getid.py:317 - main()] acc - f1: 0.5644 [344/443/88]


In [13]:
! python src/llm_validate.py

INFO 09-08 08:36:34 [__init__.py:243] No platform detected, vLLM is running on UnspecifiedPlatform
Traceback (most recent call last):
  File "/tmp/src/llm_validate.py", line 90, in <module>
    llm = vllm.LLM(model_path, quantization='awq', tensor_parallel_size=2, gpu_memory_utilization=0.9, trust_remote_code=True, dtype="half", enforce_eager=True, max_model_len=2048, disable_log_stats=True, disable_custom_all_reduce=True, enable_prefix_caching=True, task='generate')
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/vllm/utils.py", line 1161, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/vllm/entrypoints/llm.py", line 247, in __init__
    

In [14]:
! python src/post_validate.py


INFO 09-08 08:36:47 [__init__.py:243] No platform detected, vLLM is running on UnspecifiedPlatform
Traceback (most recent call last):
  File "/tmp/src/post_validate.py", line 121, in <module>
    llm = vllm.LLM(
          ^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/vllm/utils.py", line 1161, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/vllm/entrypoints/llm.py", line 247, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/vllm/engine/llm_engine.py", line 503, in from_engine_args
    vllm_config = engine_args.create_engine_config(usage_context)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/vllm/engine/arg_utils.py", line 1098, in create_engine_config
    device_config = DeviceConfig(device=self.device)
                    ^^^^^

In [15]:
! grep "f1:" /tmp/logs/project.log

INFO 2025-09-08 08:29:34  [getid.py:314 - main()] all - f1: 0.5814 [648/862/71]
INFO 2025-09-08 08:29:34  [getid.py:314 - main()] doi - f1: 0.4871 [246/477/41]
INFO 2025-09-08 08:29:34  [getid.py:314 - main()] acc - f1: 0.6596 [402/385/30]
INFO 2025-09-08 08:29:34  [getid.py:317 - main()] all - f1: 0.4612 [514/996/205]
INFO 2025-09-08 08:29:34  [getid.py:317 - main()] doi - f1: 0.3366 [170/553/117]
INFO 2025-09-08 08:29:34  [getid.py:317 - main()] acc - f1: 0.5644 [344/443/88]
INFO 2025-09-08 08:36:25  [getid.py:314 - main()] all - f1: 0.5814 [648/862/71]
INFO 2025-09-08 08:36:25  [getid.py:314 - main()] doi - f1: 0.4871 [246/477/41]
INFO 2025-09-08 08:36:25  [getid.py:314 - main()] acc - f1: 0.6596 [402/385/30]
INFO 2025-09-08 08:36:25  [getid.py:317 - main()] all - f1: 0.4612 [514/996/205]
INFO 2025-09-08 08:36:25  [getid.py:317 - main()] doi - f1: 0.3366 [170/553/117]
INFO 2025-09-08 08:36:25  [getid.py:317 - main()] acc - f1: 0.5644 [344/443/88]


In [None]:
! python src/predict.py