#### 1. Setup and Dependencies

First, ensure all necessary libraries are installed.

In [2]:
# !pip show transformers
!pip show numpy

# Name: transformers
# Version: 4.52.4

Name: numpy
Version: 2.3.0
Summary: Fundamental package for array computing in Python
Home-page: https://numpy.org
Author: Travis E. Oliphant et al.
Author-email: 
License: Copyright (c) 2005-2025, NumPy Developers.
 All rights reserved.

 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are
 met:

     * Redistributions of source code must retain the above copyright
        notice, this list of conditions and the following disclaimer.

     * Redistributions in binary form must reproduce the above
        copyright notice, this list of conditions and the following
        disclaimer in the documentation and/or other materials provided
        with the distribution.

     * Neither the name of the NumPy Developers nor the names of any
        contributors may be used to endorse or promote products derived
        from this software without specific prior written permission.

 THIS SOFTWARE IS PROVIDED

In [1]:
# 1.1. Install necessary libraries
# Use !pip install for notebook environment
# !pip install transformers trl accelerate bitsandbytes sentencepiece lxml PyMuPDF spacy peft
# !python -m spacy download en_core_web_sm # Download a small spaCy model

# 1.2. Import Libraries
import os
import re
import json
import pandas as pd
from dataclasses import dataclass, field, asdict
from typing import Set, List, Optional, Dict, Any

import fitz # PyMuPDF
from lxml import etree # For XML parsing
import spacy
import kagglehub

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils.quantization_config import BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
from datasets import Dataset, concatenate_datasets
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training

from tqdm.auto import tqdm
import gc # For garbage collection

# For KaggleHub integration (assuming it's set up or models are downloaded)
# You might need to install kagglehub if you plan to use it directly for model download
# !pip install kagglehub

# 1.3. Configure CUDA for local GPU
if torch.cuda.is_available():
    print(f"CUDA is available! Using GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
    torch.cuda.empty_cache() # Clear GPU memory
else:
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")



CUDA is available! Using GPU: NVIDIA GeForce RTX 3050 Laptop GPU


In [2]:
# Import classes from local utility file
import mdc_data_processing_utils

# If mdc_data_processing_utils.py has been changed and saved.
# To load the changes without restarting the kernel:
import importlib
importlib.reload(mdc_data_processing_utils)

# Now, any calls to functions from mdc_data_processing_utils
# will use the newly reloaded code.
from mdc_data_processing_utils import (
    ArticleData,
    DatasetCitation,
    LlmTrainingData,
    SubmissionData,
    MdcFileTextExtractor,
)


In [3]:
# Define constants for file paths and model configurations
BASE_INPUT_DIR = './kaggle/input/make-data-count-finding-data-references'
BASE_OUTPUT_DIR = "./kaggle/working"

# Define directories for articles in train and test sets
TRAIN_DATA_DIR = os.path.join(BASE_INPUT_DIR, 'train')
TEST_DATA_DIR = os.path.join(BASE_INPUT_DIR, 'test')
TRAIN_LABELS_PATH = os.path.join(BASE_INPUT_DIR, 'train_labels.csv')

# Define the path to the few-shot examples CSV
FEW_SHOT_CSV_PATH = os.path.join(BASE_OUTPUT_DIR, "few_shot_examples.csv")

# Define the base model path
QWEN_BASE_MODEL_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")

# Output directory for the fine-tuned model and results
FINE_TUNED_MODEL_OUTPUT_DIR = os.path.join(BASE_OUTPUT_DIR, "results")
SAMPLE_SUBMISSION_PATH = os.path.join(BASE_OUTPUT_DIR, "submission.csv")

# Load spaCy model for sentence segmentation and potentially other NLP tasks
# python -m spacy download en_core_web_sm 
NLP_SPACY = spacy.load("en_core_web_sm")

FEW_SHOT_EXAMPLES = """
### Example 1 (Primary)
Article Abstract: We present a novel dataset of forest growth measurements collected over 5 years in the Amazon rainforest. This data was used to develop a new model of carbon sequestration.
Dataset ID: 10.5061/dryad.r6nq870
Data Citation Context: The raw data for this study, including all measurements and derived variables, has been deposited in the Dryad Digital Repository (dryad.2f62927). This newly generated dataset supports the findings of our research.
Classification: Primary

### Example 2 (Primary - Another Example)
Article Abstract: We developed a new computational model for protein folding and generated a large dataset of simulated protein structures.
Dataset ID: PDB12345
Data Citation Context: The simulated protein structures (PDB12345) are available in the Protein Data Bank and were generated as part of this study's novel computational approach.
Classification: Primary

### Example 3 (Primary - Ambiguous Case)
Article Abstract: This study utilized both newly collected field data and some previously published climate data to analyze ecosystem changes.
Dataset ID: 10.5281/zenodo.7074790
Data Citation Context: Data Availability Statement The data that support the findings of this study are openly available in zenodo at https://doi.org/10.5281/zenodo.7074790.(https://doi.org/10.5281/zenodo.7074790).
Classification: Primary

### Example 4 (Secondary)
Article Abstract: We re-analyzed publicly available gene expression data to identify new biomarkers for cancer.
Dataset ID: GSE12345
Data Citation Context: We downloaded the gene expression profiles from the GEO database (GSE12345), which were previously published by Smith et al. (2020). This existing dataset was re-analyzed for our current study.
Classification: Secondary

### Example 5 (Missing - Irrelevant Context)
Article Abstract: Implications for intraplate earthquake behavior and the geomorphic longevity of bedrock fault scarps in a low strain-rate cratonic region.
Dataset ID: 10.1080/08120090802546977
Data Citation Context: (2009) Constraints on the current rate of deformation and surface uplift of the Australian continent from a new seismic database and low-T thermochronological data. Australian Journal of Earth Sciences, 56(2), 99-110. https://doi.org/(https://doi.org/10.1080/08120090802546977).
Classification: Missing

### Example 6 (Missing - No context)
Article Abstract: Our study investigates the social dynamics of ant colonies.
Dataset ID: 10.1038/s41586-023-06000-0
Data Citation Context: 
Classification: Missing
"""

#### 3. Data Loading and Initial Preprocessing

This section will cover how to load the raw competition data (full text articles and labels) and begin structuring it.

#### Load Labeled Training Data

In [None]:
def load_file_paths(dataset_type_dir: str) -> pd.DataFrame: 
    pdf_path = os.path.join(dataset_type_dir, 'PDF')
    xml_path = os.path.join(dataset_type_dir, 'XML')
    dataset_type = os.path.basename(dataset_type_dir)
    pdf_files = [f for f in os.listdir(pdf_path) if f.endswith('.pdf')]
    xml_files = [f for f in os.listdir(xml_path) if f.endswith('.xml')]
    df_pdf = pd.DataFrame({
        'article_id': [f.replace('.pdf', '') for f in pdf_files],
        'pdf_file_path': [os.path.join(pdf_path, f) for f in pdf_files]
    })
    df_xml = pd.DataFrame({
        'article_id': [f.replace('.xml', '') for f in xml_files],
        'xml_file_path': [os.path.join(xml_path, f) for f in xml_files]
    })
    merge_df = pd.merge(df_pdf, df_xml, on='article_id', how='outer', suffixes=('_pdf', '_xml'), validate="one_to_many")
    merge_df['dataset_type'] = dataset_type
    return merge_df

# Load the labeled training data CSV file
print(f"Loading labeled training data from: {TRAIN_LABELS_PATH}")
train_labels_df = pd.read_csv(TRAIN_LABELS_PATH)
print(f"Training labels shape: {train_labels_df.shape}")

# Group training data by article_id to get all datasets for each article
# This creates a dictionary where keys are article_ids and values are lists of dataset dicts
grouped_training_data = {}
for article_id, group_df in train_labels_df.groupby('article_id'):
    grouped_training_data[article_id] = group_df[['dataset_id', 'type']].to_dict('records')

# Example usage of grouped_training_data
print(f"Example grouped training data for article_id '10.1002_2017jc013030': {grouped_training_data['10.1002_2017jc013030']}")

# Just for testing, always set to the TEST_DATA_DIR
base_file_dir = TRAIN_DATA_DIR

# Load file paths for base directory
file_paths_df = load_file_paths(base_file_dir)
file_paths_df['pdf_file_path'] = file_paths_df['pdf_file_path'].fillna('')
file_paths_df['xml_file_path'] = file_paths_df['xml_file_path'].fillna('')
file_paths_df = file_paths_df[file_paths_df['article_id'] != '10.20944_preprints202009.0353.v1']

# Merge the file paths with the grouped_training_data
file_paths_df['ground_truth_dataset_info'] = file_paths_df['article_id'].map(grouped_training_data)
file_paths_df['ground_truth_dataset_info'] = file_paths_df['ground_truth_dataset_info'].fillna('')

# Reduce the file paths DataFrame to only those with ground truth dataset info and get a sample
# This is to ensure we have a manageable dataset for training
file_paths_df = file_paths_df[file_paths_df['ground_truth_dataset_info'].astype(bool)]
file_paths_df = file_paths_df.reset_index(drop=True)
file_paths_df = file_paths_df.sample(frac=.5, random_state=42).reset_index(drop=True)  # Shuffle the DataFrame
print(f"Files paths shape: {file_paths_df.shape}")
display(file_paths_df.sample(3))

Loading labeled training data from: ./kaggle/input/make-data-count-finding-data-references\train_labels.csv
Training labels shape: (1028, 3)
Example grouped training data for article_id '10.1002_2017jc013030': [{'dataset_id': 'https://doi.org/10.17882/49388', 'type': 'Primary'}]
Files paths shape: (262, 5)


Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,ground_truth_dataset_info
17,10.1107_s2052252514012081,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,"[{'dataset_id': 'Missing', 'type': 'Missing'}]"
158,10.3897_neobiota.82.87455,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,[{'dataset_id': 'https://doi.org/10.15468/dl.y...
240,10.12688_f1000research.13483.1,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train,[{'dataset_id': 'https://doi.org/10.5256/f1000...


#### Define Training Data Extract Function

In [5]:
def extract_training_data_for_llm(file_paths_df: pd.DataFrame) -> list[dict[str, str]]:
    """
    Extracts article data for training set with ground truth.
    
    Args:
        file_paths_df (pd.DataFrame): DataFrame containing file paths and ground truth info.
        
    Returns:
        Dict[str, ArticleData]: Dictionary mapping article IDs to ArticleData objects.
    """
    training_data_for_llm: list[dict[str, str]] = [] # This will be a list of LlmTrainingData for the LLM training dataset
    for i, row in tqdm(file_paths_df.iterrows(), total=len(file_paths_df)):
        article_id = row['article_id']
        filepath = row['pdf_file_path'] if row['pdf_file_path'] else row['xml_file_path']
        ground_truth_list = row['ground_truth_dataset_info'] if 'ground_truth_dataset_info' in row else []

        file_extractor = MdcFileTextExtractor(article_id, filepath)
        article_data = file_extractor.extract_article_data_for_training(NLP_SPACY, ground_truth_list)
        training_data_for_llm.extend(article_data.get_data_for_llm())

    print(f"Loaded training data for {len(training_data_for_llm)} articles.")
    return training_data_for_llm


In [6]:
# For testing, let's extract training data for a specific article
sample_file_paths_df = file_paths_df.loc[file_paths_df['article_id'] == '10.1002_esp.5058']
sample_file_paths_df

Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,ground_truth_dataset_info
154,10.1002_esp.5058,./kaggle/input/make-data-count-finding-data-re...,,train,[{'dataset_id': 'https://doi.org/10.5061/dryad...


#### 4. Advanced Preprocessing: Extracting Dataset Mentions and Context (Training)

Use regex to find the given dataset IDs from the training_labels and then use spaCy to extract surrounding sentences as context.


In [None]:
# This take 10+ minutes

# 4.3. Populate ArticleData with DatasetCitation objects and ground truth
training_data_for_llm = extract_training_data_for_llm(file_paths_df)
print(f"Prepared {len(training_data_for_llm)} training examples for the LLM.")

# Convert the list of LlmTrainingData to a DataFrame and save it
training_data_for_llm_df = pd.DataFrame(training_data_for_llm)
training_data_for_llm_df.to_csv(os.path.join(BASE_OUTPUT_DIR, "training_data_for_llm.csv"), index=False)


  0%|          | 0/262 [00:00<?, ?it/s]

Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.7717_peerj.12422.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.1371_journal.pone.0198382.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.1002_chem.202000235.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.1111_2041-210x.12453.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.1038_s41597-019-0101-y.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.1186_s12974-020-01860-y.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.7717_peerj.13193.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\train\PDF\10.3390_s23177333.pdf
Extracting md text fr

In [13]:
def last_of_string(s: str, length: int = 400) -> str:
    return s[-length:]

# Load data from the csv file
training_data_for_llm_df = pd.read_csv(os.path.join(BASE_OUTPUT_DIR, "training_data_for_llm.csv"))
training_data_for_llm_df = training_data_for_llm_df.sample(290, random_state=42)
training_data_for_llm_df['citation_context'] = training_data_for_llm_df['citation_context'].apply(last_of_string)

# Convert to Hugging Face Dataset format
train_dataset = Dataset.from_pandas(training_data_for_llm_df)

# Split into train/validation
train_test_split = train_dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']
print(f"Training set size: {len(train_dataset)} examples")
print(f"Validation set size: {len(eval_dataset)} examples")

# Clean up
del training_data_for_llm_df
del train_test_split
gc.collect()


Training set size: 261 examples
Validation set size: 29 examples


952

In [14]:
# Concatenate the few-shot examples with the main training dataset
# This will add the few-shot examples as additional rows to the training data
few_shot_dataset = Dataset.from_csv(FEW_SHOT_CSV_PATH)
print(f"Loaded {len(few_shot_dataset)} few-shot examples.")
train_dataset = concatenate_datasets([train_dataset, few_shot_dataset])
print(f"New training set size: {len(train_dataset)} examples")

# Save off the datasets to CSV for later use
train_dataset.to_csv(os.path.join(BASE_OUTPUT_DIR, "train_dataset.csv"), index=False)
eval_dataset.to_csv(os.path.join(BASE_OUTPUT_DIR, "eval_dataset.csv"), index=False)

Loaded 16 few-shot examples.
New training set size: 277 examples


Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

19113

In [None]:
# # Load datasets from CSV
# train_dataset = Dataset.from_csv(os.path.join(BASE_OUTPUT_DIR, "train_dataset.csv"))
# eval_dataset = Dataset.from_csv(os.path.join(BASE_OUTPUT_DIR, "eval_dataset.csv"))

In [15]:
train_dataset

Dataset({
    features: ['article_id', 'article_doi', 'article_abstract', 'dataset_id', 'citation_context', 'label', '__index_level_0__'],
    num_rows: 277
})

In [16]:
eval_dataset

Dataset({
    features: ['article_id', 'article_doi', 'article_abstract', 'dataset_id', 'citation_context', 'label', '__index_level_0__'],
    num_rows: 29
})

#### 5. Model Selection and Configuration

We'll use a Qwen model.

In [17]:
# 5.1. Choose a Model from KaggleHub
# Example: Qwen/Qwen1.5-0.5B-Chat (or 1.8B-Chat if 0.5B is too small/performs poorly)
# You can find these on KaggleHub or Hugging Face Hub.
model_name = QWEN_BASE_MODEL_PATH

# 5.2. Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token # Qwen uses EOS for padding

# 5.3. Load Model with Quantization (4-bit)
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16 # Or torch.float16 if bfloat16 is not supported by your GPU
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16, # Match compute_dtype
    device_map="auto", # Automatically maps model to available devices
    trust_remote_code=True # Required for some models like Qwen
)

# Prepare model for k-bit training (LoRA compatible)
model.config.use_cache = False
model.config.pretraining_tp = 1
model = prepare_model_for_kbit_training(model)

print(f"Model {model_name} loaded with 4-bit quantization.")

Model C:\Users\jim\.cache\kagglehub\models\qwen-lm\qwen-3\transformers\0.6b\1 loaded with 4-bit quantization.


#### 6. Dataset Preparation for Training

Format the extracted data into instruction-tuning prompts using the ChatML format, which Qwen models are trained on.

In [19]:
# 6.1. Define the formatting function for ChatML (Corrected for trl 0.19.1)
def format_example(example):
    messages = [
        {"role": "system", "content": "You are an expert assistant for classifying research data citations. /no_think"},
        {"role": "user", "content": (
            f"""
Given the following 'Article Abstract' and a specific data citation ('Dataset ID' and 'Data Citation Context' combination), classify the data citation as either: 
'Primary' (if the data citation refers to raw or processed **data created/generated as part of the paper**, specifically for this study), 
'Secondary' (if the data citation refers to raw or processed **data derived/reused from existing records** or previously published data), or 
'Missing' (if the data citation refers to another **article/paper/journal**, a **figure, software, or other non-data entity**, or the 'Data Citation Context' is **empty or irrelevant**).\n\n"""
            f"If the data citation refers to raw or processed **data** but the distinction between 'Primary' and 'Secondary' is ambiguous, then default to 'Primary'.\n\n"
            # f"{FEW_SHOT_EXAMPLES.strip()}\n\n" # Add the examples here
            f"Now, classify the following:\n\n" # Add a clear separator            
            f"Article Abstract: {example['article_abstract']}\n" 
            f"Dataset ID: {example['dataset_id']}\n"
            f"Data Citation Context: {example['citation_context']}\n\n"
            f"Classification:"
        )}
    ]
    # The target output for the model is just "Primary", "Secondary, or "Missing"
    messages.append({"role": "assistant", "content": example['label']})
    
    # Apply chat template and return the string directly
    # <--- IMPORTANT CHANGE: Directly return the string, not a dictionary
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False, enable_thinking=False)

# Apply the formatting to the dataset
# IMPORTANT: When formatting_func returns a string directly, you typically don't
# need to call .map() on the dataset beforehand if SFTTrainer handles it internally.
# However, if you want to inspect the formatted text, you can still do this:
# formatted_train_dataset = train_dataset.map(format_example)
# But for SFTTrainer, you pass the original `train_dataset` and the `formatting_func`
# and `dataset_text_field` (which will be ignored if formatting_func is used to generate the text).

# Print an example to verify (you'll need to call format_example directly for this)
print("\nExample of formatted training data (string output):")
# You can't directly print from formatted_train_dataset if you don't map it first.
# Let's print by calling the function on a sample:
if len(train_dataset) > 0:
    sample_formatted_text = format_example(train_dataset[35])
    tokenized_input = tokenizer(sample_formatted_text, return_tensors="pt")
    prompt_length = tokenized_input.input_ids.shape[1]
    print(f"Length of the full prompt in tokens: {prompt_length}")    
    print(sample_formatted_text)
else:
    print("No training data to display example.")


Example of formatted training data (string output):
Length of the full prompt in tokens: 496
<|im_start|>system
You are an expert assistant for classifying research data citations. /no_think<|im_end|>
<|im_start|>user

Given the following 'Article Abstract' and a specific data citation ('Dataset ID' and 'Data Citation Context' combination), classify the data citation as either: 
'Primary' (if the data citation refers to raw or processed **data created/generated as part of the paper**, specifically for this study), 
'Secondary' (if the data citation refers to raw or processed **data derived/reused from existing records** or previously published data), or 
'Missing' (if the data citation refers to another **article/paper/journal**, a **figure, software, or other non-data entity**, or the 'Data Citation Context' is **empty or irrelevant**).

If the data citation refers to raw or processed **data** but the distinction between 'Primary' and 'Secondary' is ambiguous, then default to 'Primar

### 7. Train the Model

In [20]:
# ---------------------------------------------------------
# This version uses the evaluation dataset in the SFTTrainer
# ---------------------------------------------------------

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers.trainer_utils import EvalPrediction # Import this for type hinting if desired

# Assuming your model outputs logits for 3 classes (Primary, Secondary, Missing)
# And your labels are integers (e.g., 0, 1, 2) corresponding to these classes.

def compute_classification_metrics(eval_pred: EvalPrediction):
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    
    # For classification, you typically take the argmax of the logits to get predicted class IDs
    predictions = np.argmax(logits, axis=-1)
    
    # Calculate desired metrics
    accuracy = accuracy_score(labels, predictions)
    f1_macro = f1_score(labels, predictions, average='macro') # Macro F1 treats all classes equally
    f1_weighted = f1_score(labels, predictions, average='weighted') # Weighted F1 accounts for class imbalance
    
    # You might also want precision and recall per class, or just overall
    # precision_macro = precision_score(labels, predictions, average='macro')
    # recall_macro = recall_score(labels, predictions, average='macro')

    return {
        "accuracy": accuracy,
        "f1_macro": f1_macro,
        "f1_weighted": f1_weighted,
        # You can also include 'eval_loss' if you want, but Trainer computes it by default
        # "eval_loss": eval_pred.metrics.get("eval_loss", None) # Access it if Trainer provides it
    }

# 7.1. Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear", # Adjust based on model architecture if needed
)

# 7.2. Configure Training Arguments (now using SFTConfig)
training_args = SFTConfig(
    output_dir=FINE_TUNED_MODEL_OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
    save_steps=25,
    optim="paged_adamw_8bit",
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    report_to="none",
    disable_tqdm=False,
    remove_unused_columns=False,
    label_names=['labels'],
    
    # SFTTrainer-specific parameters moved into SFTConfig
    max_seq_length=512,
    packing=False,
    dataset_text_field="text",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant':False},

    # --- NEW: Evaluation Parameters ---
    eval_strategy="steps", # Evaluate every 'eval_steps'. You can also use "epoch" for evaluation_strategy
    eval_steps=25,               # How often to run evaluation (e.g., every 25 steps)
    save_strategy="steps",       # How often to save checkpoints
    save_total_limit=1,          # Only keep the best model checkpoint
    load_best_model_at_end=True, # Load the model with the best validation metric at the end of training
    greater_is_better=False,     # For loss, lower is better
)

# 7.3. Initialize SFTTrainer
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset, # <--- Pass the evaluation dataset here
    # compute_metrics=compute_classification_metrics,
    peft_config=peft_config,
    args=training_args,
    formatting_func=format_example
)


Applying formatting function to train dataset:   0%|          | 0/277 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/277 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/277 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/277 [00:00<?, ? examples/s]

Applying formatting function to eval dataset:   0%|          | 0/29 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/29 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/29 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/29 [00:00<?, ? examples/s]

In [None]:

# 7.4. Start Training
print("\nStarting model training...")
trainer.train()
print("Training complete!")

# Save the fine-tuned model (LoRA adapters)
trainer.save_model(os.path.join(FINE_TUNED_MODEL_OUTPUT_DIR, "final_model"))
print(f"Fine-tuned model saved to {os.path.join(FINE_TUNED_MODEL_OUTPUT_DIR, 'final_model')}")


Starting model training...


Step,Training Loss,Validation Loss
25,1.2678,1.07803


In [None]:
# --- Explicit GPU Memory Cleanup ---
print("\nInitiating GPU memory cleanup...")

# 1. Explicitly delete large objects that consume GPU memory
#    This removes references, allowing Python's garbage collector to act.
# if 'trainer' in locals() and trainer is not None:
#     del trainer
# if 'model' in locals() and model is not None:
#     del model
# if 'tokenizer' in locals() and tokenizer is not None:
#     del tokenizer
# If you had other large tensors or datasets explicitly moved to GPU,
# you would delete them here too. For Hugging Face datasets, they are usually
# on CPU unless you manually call .to('cuda').

# 2. Force Python's garbage collection
#    This helps ensure that deleted objects are immediately cleaned up.
# gc.collect()

# 3. Clear PyTorch's CUDA memory cache
#    This tells PyTorch to release any cached memory back to the OS/driver.
# if torch.cuda.is_available():
#     torch.cuda.empty_cache()

print("GPU memory cleanup complete. Please check nvidia-smi to confirm.")


Initiating GPU memory cleanup...
GPU memory cleanup complete. Please check nvidia-smi to confirm.


#### 8. Inference and Evaluation

After training, load the best model (or the final one) and apply it to the test data.

In [29]:
# 8.1. Load the Trained Model (or merge LoRA adapters for full model)
# If you saved LoRA adapters, you'll need to load the base model and then the adapters.
# For inference, it's often easier to merge them.
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     quantization_config=nf4_config, # Use the same config as training
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     trust_remote_code=True
# )
# model = PeftModel.from_pretrained(model, os.path.join(FINE_TUNED_MODEL_OUTPUT_DIR, "final_model"))
# model = model.merge_and_unload() # Merge LoRA adapters into the base model

# For simplicity, if you just want to test the last saved checkpoint:
# You can also load the model directly from the checkpoint if it's a full save
# model = AutoModelForCausalLM.from_pretrained(os.path.join(FINE_TUNED_MODEL_OUTPUT_DIR, "final_model"), device_map="auto")
# tokenizer = AutoTokenizer.from_pretrained(os.path.join(FINE_TUNED_MODEL_OUTPUT_DIR, "final_model"))

# If you want to load the base model and then the adapters for inference:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)
model = PeftModel.from_pretrained(model, os.path.join(FINE_TUNED_MODEL_OUTPUT_DIR, "final_model"))
model.eval() # Set to evaluation mode

print("Model loaded for inference.")


Model loaded for inference.


In [30]:
# For testing, always set to the TEST_DATA_DIR
base_file_dir = TEST_DATA_DIR

# Load file paths for base directory
test_file_paths_df = load_file_paths(base_file_dir)
test_file_paths_df['xml_file_path'] = test_file_paths_df['xml_file_path'].fillna('')

print(f"Files paths shape: {test_file_paths_df.shape}")
display(test_file_paths_df.sample(3))

Files paths shape: (30, 4)


Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type
27,10.1002_mp.14424,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test
15,10.1002_ece3.6144,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test
23,10.1002_ejoc.202000139,./kaggle/input/make-data-count-finding-data-re...,,test


In [None]:

def invoke_model_for_inference(tokenizer, article_data: ArticleData) -> list[SubmissionData]:
    submission_data_list = []
    article_id = article_data.article_id
    dataset_citations = article_data.dataset_citations
    if not dataset_citations:
        submission_data_list.append(SubmissionData(article_id, dataset_id="Missing", type="Missing"))
        return submission_data_list

    print(f"Found {len(dataset_citations)} citations for {article_id}")
    for dc in dataset_citations:
        # Create the prompt for inference
        messages = [
            {"role": "system", "content": "You are an expert assistant for classifying research data citations. /no_think"},
            {"role": "user", "content": (
                f"""
Given the following 'Article Abstract' and a specific data citation ('Dataset ID' and 'Data Citation Context' combination), classify the data citation as either: 
'Primary' (if the data citation refers to raw or processed **data created/generated as part of the paper**, specifically for this study), 
'Secondary' (if the data citation refers to raw or processed **data derived/reused from existing records** or previously published data), or 
'Missing' (if the data citation refers to another **article/paper/journal**, a **figure, software, or other non-data entity**, or the 'Data Citation Context' is **empty or irrelevant**).\n\n"""
                f"If the data citation refers to raw or processed **data** but the distinction between 'Primary' and 'Secondary' is ambiguous, then default to 'Primary'.\n\n"
                # f"{FEW_SHOT_EXAMPLES.strip()}\n\n" # Add the examples here
                f"Now, classify the following:\n\n" # Add a clear separator            
                f"Article Abstract: {article_data.abstract}\n"
                f"Dataset ID: {dc.dataset_id}\n"                
                f"Data Citation Context: {dc.citation_context}\n\n"
                f"Classification:"
            )}
        ]

        # --- CHANGE STARTS HERE ---
        # Tokenize and get both input_ids and attention_mask
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)

        with torch.no_grad():
            output = model.generate(
                **inputs, # <--- Pass the entire dictionary (includes input_ids and attention_mask)
                max_new_tokens=10, # Expecting "Primary" or "Secondary"
                do_sample=True,    # <--- Enable sampling
                temperature=0.7,   # <--- Adjust temperature (0.7-0.9 is common)
                top_p=0.9,         # <--- Top-p sampling (consider tokens that sum to 90% probability)
                top_k=50,          # <--- Top-k sampling (consider only the top 50 most probable tokens)                
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        # --- CHANGE ENDS HERE ---        

        generated_text = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip() # Use inputs['input_ids']
        # print(f"LLM Resp: {generated_text}")        
        
        # Post-process the generated text to get the classification
        predicted_type = "Missing"
        if "Primary" in generated_text:
            predicted_type = "Primary"
        elif "Secondary" in generated_text:
            predicted_type = "Secondary"
        
        submission_data_list.append(SubmissionData(article_id, dataset_id=dc.dataset_id, type=predicted_type, context=dc.citation_context))

    return submission_data_list

def process_test_articles(tokenizer, file_paths_df: pd.DataFrame) -> list[SubmissionData]:
    """
    Extracts article data for testing set without ground truth.
    
    Args:
        file_paths_df (pd.DataFrame): DataFrame containing file paths and ground truth info.
        
    Returns:
        Dict[str, ArticleData]: Dictionary mapping article IDs to ArticleData objects.
    """
    submission_data_list = []
    for i, row in tqdm(file_paths_df.iterrows(), total=len(file_paths_df)):
        article_id = row['article_id']
        filepath = row['pdf_file_path'] if row['pdf_file_path'] else row['xml_file_path']
        file_extractor = MdcFileTextExtractor(article_id, filepath)
        
        # Extract article data
        article_data = file_extractor.extract_article_data_for_inference(NLP_SPACY)

        # Invoke the model with the collected article_data
        submission_data_list.extend(invoke_model_for_inference(tokenizer, article_data))

    print(f"Processed testing data for {len(submission_data_list)} article and dataset_id combos.")
    return submission_data_list

In [41]:
sample_test_file_paths_df = test_file_paths_df.sample(2, random_state=42)
sample_test_file_paths_df = test_file_paths_df.loc[test_file_paths_df['article_id']=='10.1002_mp.14424']
sample_test_file_paths_df = test_file_paths_df.loc[test_file_paths_df['article_id']=='10.1002_cssc.202201821']
sample_test_file_paths_df = test_file_paths_df.loc[test_file_paths_df['article_id']=='10.1002_ecs2.1280']
# sample_test_file_paths_df = test_file_paths_df.loc[test_file_paths_df['article_id']=='10.1002_esp.5090']
sample_test_file_paths_df

Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type
20,10.1002_ecs2.1280,./kaggle/input/make-data-count-finding-data-re...,,test


In [43]:
sample_sub = process_test_articles(tokenizer, test_file_paths_df)
display(sample_sub)

  0%|          | 0/30 [00:00<?, ?it/s]

Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_2017jc013030.pdf
Found 2 citations for 10.1002_2017jc013030
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_anie.201916483.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_anie.202005531.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_anie.202007717.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_chem.201902131.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_chem.201903120.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_chem.202000235.pdf
Extracting md text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_chem.20200

[SubmissionData(article_id='10.1002_2017jc013030', dataset_id='10.17882/49388', type='Secondary', context='Data referring to Organelli et al. (2016a; https://doi.org/10.17882/(https://doi.org/10.17882/47142) 47142) and Barbieux et al. (2017;(https://doi.org/10.17882/47142) https://doi.org/10.17882/49388) are(https://doi.org/10.17882/49388) freely available on SEANOE.'),
 SubmissionData(article_id='10.1002_2017jc013030', dataset_id='10.17882/47142', type='Primary', context='Approches Num�eriques); Pierre-Marie Poulain (National Institute of Oceanography and Experimental Geophysics, Italy; ArgoItaly); Sabrina Speich (Laboratoire de M�et�eorologie Dynamique, France; LEFEGMMC); Virginie Thierry (Ifremer, France; LEFE-GMMC); Pascal Conan (Observatoire Oc�eanologique de Banyuls sur mer, France; LEFE-GMMC); Laurent Coppola (Laboratoire d’Oc�eanographie de Villefranche, France; LEFE-GMMC); Anne Petrenko (Mediterranean Institute of Oceanography, France; LEFE-GMMC); and Jean-Baptiste Sall�ee (La

#### 9. Submission File Generation (Kaggle Specific)

Finally, format your predictions into the required `submission.csv` file.

In [48]:
def format_dataset_id(dataset_id: str) -> str:
    """
    Formats the dataset_id by removing any leading/trailing whitespace and ensuring it is a string.
    
    Args:
        dataset_id (str): The dataset identifier to format.
        
    Returns:
        str: The formatted dataset identifier.
    """
    if dataset_id and dataset_id.startswith("10.") and len(dataset_id) > 10:
        # If the dataset_id starts with "10." and is longer than 10 characters, it's likely a DOI
        dataset_id = "https://doi.org/" + dataset_id.lower().strip()
    return dataset_id

def prepare_for_submission(submission_list: list[SubmissionData]) -> pd.DataFrame:
    """
    Prepares the submission_list for submission by ensuring the correct columns and formatting.
    
    Args:
        expanded_df (pd.DataFrame): The DataFrame containing expanded dataset information.
        
    Returns:
        pd.DataFrame: A DataFrame ready for submission with 'article_id', 'dataset_id', and 'type' columns.
    """
    submission_df = pd.DataFrame(sample_sub)
    # Ensure the DataFrame has the correct columns
    submission_df = submission_df[['article_id', 'dataset_id', 'type']].copy()

    # Format dataset_id
    submission_df['dataset_id'] = submission_df['dataset_id'].apply(format_dataset_id)  

    # Remove rows where type is 'Missing' and reset index
    submission_df = submission_df[submission_df['type'] != 'Missing'].reset_index(drop=True)
    submission_df['row_id'] = range(len(submission_df))

    # Reorder columns to match the submission format
    submission_df = submission_df[['row_id', 'article_id', 'dataset_id', 'type']]
    
    return submission_df


In [None]:
# 9.1. Create Submission DataFrame

submission_df = prepare_for_submission(sample_sub)
submission_df.to_csv("submission_df.csv", index=False)
print("Submission file 'submission_df.csv' created successfully!")


Submission file 'submission_df.csv' created successfully!


In [50]:
def f1_score(tp, fp, fn):
    return 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) != 0 else 0.0
    
    
# if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
pred_df = submission_df.copy()
label_df = pd.read_csv("./kaggle/input/make-data-count-finding-data-references/sample_submission.csv")
label_df = label_df[label_df['type'] != 'Missing'].reset_index(drop=True)

hits_df = label_df.merge(pred_df, on=["article_id", "dataset_id", "type"])

tp = hits_df.shape[0]
fp = pred_df.shape[0] - tp
fn = label_df.shape[0] - tp


print("TP:", tp)
print("FP:", fp)
print("FN:", fn)
print("F1 Score:", round(f1_score(tp, fp, fn), 3))

TP: 4
FP: 26
FN: 10
F1 Score: 0.182
