#### 1. Setup and Dependencies

First, we'll ensure all necessary libraries are installed. Given your previous work with `lxml`, `PyMuPDF`, and `spaCy`, we'll include those for text extraction and potentially more advanced NLP preprocessing.

In [1]:
# 1.1. Install necessary libraries
# Use !pip install for notebook environment
# !pip install transformers trl accelerate bitsandbytes sentencepiece lxml PyMuPDF spacy peft
# !python -m spacy download en_core_web_sm # Download a small spaCy model

# 1.2. Import Libraries
import os
import re
import json
import pandas as pd
from dataclasses import dataclass, field, asdict
from typing import Set, List, Optional, Dict, Any

import fitz # PyMuPDF
from lxml import etree # For XML parsing
import spacy
import kagglehub

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils.quantization_config import BitsAndBytesConfig
from trl import SFTTrainer
from datasets import Dataset #, load_metric
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from tqdm.auto import tqdm

# For KaggleHub integration (assuming it's set up or models are downloaded)
# You might need to install kagglehub if you plan to use it directly for model download
# !pip install kagglehub

# 1.3. Select the compute device: the first CUDA GPU when one is visible, otherwise the CPU
if not torch.cuda.is_available():
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")
else:
    print(f"CUDA is available! Using GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
    torch.cuda.empty_cache()  # release cached allocator blocks left over from earlier cells



CUDA is available! Using GPU: NVIDIA GeForce RTX 3050 Laptop GPU


In [2]:
# Define constants for file paths and model configurations.
# All paths are relative to the notebook's working directory.
BASE_INPUT_DIR = './kaggle/input/make-data-count-finding-data-references'
BASE_OUTPUT_DIR = "./kaggle/working"

# Define directories for articles in train and test sets.
# Each is expected to contain 'PDF' and 'XML' sub-folders (see load_file_paths).
TRAIN_DATA_DIR = os.path.join(BASE_INPUT_DIR, 'train')
TEST_DATA_DIR = os.path.join(BASE_INPUT_DIR, 'test')
TRAIN_LABELS_PATH = os.path.join(BASE_INPUT_DIR, 'train_labels.csv')

# Define the base model path.
# NOTE(review): this call runs at cell execution time and presumably needs network
# access / Kaggle credentials the first time it is used - confirm for offline reruns.
QWEN_BASE_MODEL_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")

# Output directory for the fine-tuned model and results
FINE_TUNED_MODEL_OUTPUT_DIR = os.path.join(BASE_OUTPUT_DIR, "qwen_finetuned_dataset_classifier")
SAMPLE_SUBMISSION_PATH = os.path.join(BASE_OUTPUT_DIR, "submission.csv")

# Load the spaCy pipeline used for sentence segmentation throughout the notebook.
# Requires the model to be installed first: python -m spacy download en_core_web_sm
NLP_SPACY = spacy.load("en_core_web_sm")

In [3]:
# Information Extraction (IE) - Dataset Identification ---
NON_STD_UNICODE_DASHES = re.compile(r'[\u2010\u2011\u2012\u2013\u2014]')
NON_STD_UNICODE_TICKS = re.compile(r'[\u201c\u201d]')
def clean_text(text: str) -> str:
    """
    Normalise extracted article text so identifiers survive line wrapping.

    Steps, in order: drop zero-width spaces and re-join identifiers that were
    split across a line break; map typographic dashes to '-'; map curly double
    quotes to a plain apostrophe; collapse every whitespace run to one space.

    Args:
        text (str): Raw text; a falsy value (None or "") is accepted.

    Returns:
        str: The normalised, stripped text ("" for falsy input).
    """
    if not text:
        return ""

    # Ordered (broken, repaired) pairs; applied one after another
    line_break_repairs = (
        ('\u200b', ''),
        ('-\n', '-'),
        ('_\n', '_'),
        ('/\n', '/'),
        ('dryad.\n', 'dryad.'),
        ('doi.\norg', 'doi.org'),
    )
    repaired = text
    for broken, fixed in line_break_repairs:
        repaired = repaired.replace(broken, fixed)

    # Typographic dashes -> '-', then curly double quotes -> "'"
    repaired = NON_STD_UNICODE_TICKS.sub("'", NON_STD_UNICODE_DASHES.sub('-', repaired))

    # Collapse all whitespace (including remaining newlines) to single spaces
    return re.sub(r'\s+', ' ', repaired).strip()

# Information Extraction (IE) - Dataset Identification
# Regex patterns for common dataset identifiers.

# DOI: requires a "doi:" or "http(s)://(dx.)doi.org/" prefix (matched case-insensitively)
# and captures the bare DOI in group 1. The suffix class accepts lower-case letters because
# the patterns are compiled below WITHOUT re.IGNORECASE and repository DOIs such as
# 10.5061/dryad.xxxx are normally written in lower case.
DOI_PATTERN = r"(?i:doi:|https?://(?:dx\.)?doi\.org/)(10\.\d{4,9}/[-._;()/:A-Za-z0-9]+)"
EPI_PATTERN = r'\bEPI[-_A-Z0-9]{2,}'
SAM_PATTERN = r'\bSAMN[0-9]{2,}'          # SAMN07159041
IPR_PATTERN = r'\bIPR[0-9]{2,}'
CHE_PATTERN = r'\bCHEMBL[0-9]{2,}'
PRJ_PATTERN = r'\bPRJ[A-Z0-9]{2,}'
E_G_PATTERN = r'\bE-[A-Z]{4}-[0-9]{2,}'   # E-GEOD-19722 or E-PROT-100
ENS_PATTERN = r'\bENS[A-Z]{4}[0-9]{2,}'
CVC_PATTERN = r'\bCVCL_[A-Z0-9]{2,}'
EMP_PATTERN = r'\bEMPIAR-[0-9]{2,}'
PXD_PATTERN = r'\bPXD[0-9]{2,}'
HPA_PATTERN = r'\bHPA[0-9]{2,}'
SRR_PATTERN = r'\bSRR[0-9]{2,}'
GSE_PATTERN = r'\b(GSE|GSM|GDS|GPL)\d{4,6}\b' # Example for GEO accession numbers (e.g., GSE12345, GSM12345)
GNB_PATTERN = r'\b[A-Z]{1,2}\d{5,6}\b' # GenBank accession numbers (e.g., AB123456, AF000001)
CAB_PATTERN = r'\bCAB[0-9]{2,}'
# The "pdb" prefix is matched case-insensitively so both "pdb 5yfp" and "PDB 5YFP" are found
PDB_PATTERN = r"\b(?i:pdb)\s*\d[A-Za-z0-9]{3}"

# Combine all patterns into a list (DOI_PATTERN first; it is the only one whose
# identifier lives in capture group 1 rather than in the whole match)
DATASET_ID_PATTERNS = [
    DOI_PATTERN,
    EPI_PATTERN,
    SAM_PATTERN,
    IPR_PATTERN,
    CHE_PATTERN,
    PRJ_PATTERN,
    E_G_PATTERN,
    ENS_PATTERN,
    CVC_PATTERN,
    EMP_PATTERN,
    PXD_PATTERN,
    HPA_PATTERN,
    SRR_PATTERN,
    GSE_PATTERN,
    GNB_PATTERN,
    CAB_PATTERN,
    PDB_PATTERN
]

# Alias: the test-time preprocessing loop iterates over ALL_ID_PATTERNS (pattern strings),
# which was otherwise undefined on a fresh kernel.
ALL_ID_PATTERNS = DATASET_ID_PATTERNS

# Compile all patterns for efficiency (case-sensitive; see the DOI/PDB notes above)
COMPILED_DATASET_ID_REGEXES = [re.compile(p) for p in DATASET_ID_PATTERNS]

# Data related keywords to look for in the text
# These keywords help to ensure that the text is relevant to datasets
DATA_RELATED_KEYWORDS = ['data release', 'data associated', 'data availability', 'data access', 'download', 'program data', 'the data', 'dataset', 'database', 'repository', 'data source', 'archive', 'arch.', 'digital']


#### 2. Data Classes


In [4]:
# 2.1. DatasetCitation Class
@dataclass
class DatasetCitation:
    """A group of dataset identifiers together with the text that cites them."""
    dataset_ids: Set[str] = field(default_factory=set)  # Set to store unique dataset IDs
    citation_context: str = ""  # cleaned text surrounding the citation
    citation_type: Optional[str] = None # "Primary" or "Secondary" - for ground truth during training

    def add_dataset_id(self, dataset_id: str):
        """Registers a dataset ID; duplicates are ignored because IDs are kept in a set."""
        self.dataset_ids.add(dataset_id)

    def set_citation_context(self, context: str):
        """Sets the citation context, cleaning it.

        Newlines become spaces, square brackets are removed and whitespace runs are
        collapsed. A falsy ``context`` leaves the current value untouched.
        """
        if context:
            context = context.replace('\n', ' ').replace('[', '').replace(']', '')
            self.citation_context = re.sub(r'\s+', ' ', context.strip())

    def has_dataset(self) -> bool:
        """Returns True if there are both dataset IDs and citation context."""
        return bool(self.dataset_ids and self.citation_context.strip())

    def to_dict(self):
        """Returns a JSON-serialisable dict of this citation.

        ``dataset_ids`` is emitted as a sorted list: a set cannot be JSON-encoded and
        its iteration order is not stable between runs, which made exports differ.
        """
        d = asdict(self)
        d["dataset_ids"] = sorted(self.dataset_ids)
        return d

# 2.2. ArticleData Class
@dataclass
class ArticleData:
    """Metadata for one article plus the dataset citations found in it."""
    article_id: str = ""
    article_doi: str = ""
    title: str = ""
    author: str = ""
    abstract: str = ""
    dataset_citations: List[DatasetCitation] = field(default_factory=list)
    full_text: str = ""

    def __post_init__(self):
        """Derives ``article_doi`` from ``article_id`` when only the ID was supplied."""
        if self.article_doi or not self.article_id:
            return
        # File-style IDs use '_' where the DOI has '/'; DOIs are stored lower-cased
        self.article_doi = self.article_id.replace("_", "/").lower()

    def add_dataset_citation(self, dataset_citation: DatasetCitation):
        """Appends the citation, but only if it carries both IDs and context."""
        if not dataset_citation.has_dataset():
            return
        self.dataset_citations.append(dataset_citation)

    def to_dict(self):
        """Returns a serialisable dict of the article without the ``full_text`` field."""
        payload = asdict(self)
        # Use each citation's own to_dict() so its ID set is converted to a list
        payload["dataset_citations"] = [citation.to_dict() for citation in self.dataset_citations]
        payload.pop("full_text", None)
        return payload

    def to_json(self):
        """Returns the dict form encoded as compact JSON (no spaces after separators)."""
        return json.dumps(self.to_dict(), separators=(',', ':'))

    def has_data(self) -> bool:
        """Returns True if there are any dataset citations."""
        return True if self.dataset_citations else False
    
@dataclass
class LlmTrainingData:
    """Flat record describing one (article, dataset ID) training example."""
    article_id: str = ""
    article_doi: str = ""
    article_title: str = ""
    article_abstract: str = ""
    citation_context: str = ""
    dataset_id: str = ""
    label: str = ""

    def to_dict(self):
        """Returns the record as a plain dict keyed by field name."""
        record = asdict(self)
        return record

    def to_json(self):
        """Returns the record as compact JSON (no spaces after separators)."""
        compact_separators = (',', ':')
        return json.dumps(self.to_dict(), separators=compact_separators)

    

#### 3. Data Loading and Initial Preprocessing

This section will cover how to load the raw competition data (full text articles and labels) and begin structuring it.

#### Load Labeled Training Data

In [5]:
def load_file_paths(dataset_type_dir: str) -> pd.DataFrame:
    """
    Index the article files found under ``<dataset_type_dir>/PDF`` and ``<dataset_type_dir>/XML``.

    Args:
        dataset_type_dir (str): Directory of one data split (e.g. the train or test folder).

    Returns:
        pd.DataFrame: One row per article with the columns ``article_id``,
        ``pdf_file_path``, ``xml_file_path`` and ``dataset_type``. The two path
        columns come from an outer merge, so an article that exists in only one
        format has NaN in the other path column.

    Raises:
        FileNotFoundError: If either the PDF or the XML sub-folder does not exist.
    """
    def _index_articles(subdir: str, extension: str, path_column: str) -> pd.DataFrame:
        # Lists the files with `extension` in one sub-folder as (article_id, path) rows.
        folder = os.path.join(dataset_type_dir, subdir)
        names = sorted(f for f in os.listdir(folder) if f.endswith(extension))
        return pd.DataFrame({
            # Strip only the trailing extension: article IDs are DOI-like and contain
            # dots, so replacing every occurrence of the extension text could mangle them.
            'article_id': [f[:-len(extension)] for f in names],
            path_column: [os.path.join(folder, f) for f in names],
        })

    df_pdf = _index_articles('PDF', '.pdf', 'pdf_file_path')
    df_xml = _index_articles('XML', '.xml', 'xml_file_path')
    merge_df = pd.merge(df_pdf, df_xml, on='article_id', how='outer', validate="one_to_many")
    # normpath drops a trailing separator, which would otherwise give an empty basename
    merge_df['dataset_type'] = os.path.basename(os.path.normpath(dataset_type_dir))
    return merge_df

# Read the ground-truth labels CSV
print(f"Loading labeled training data from: {TRAIN_LABELS_PATH}")
train_labels_df = pd.read_csv(TRAIN_LABELS_PATH)
print(f"Training labels shape: {train_labels_df.shape}")

# Index the labels by article:
# article_id -> list of {'dataset_id': ..., 'type': ...} records
grouped_training_data = {
    article_id: group_df[['dataset_id', 'type']].to_dict('records')
    for article_id, group_df in train_labels_df.groupby('article_id')
}

# Show the records of one known article as a sanity check
print(f"Example grouped training data for article_id '10.1002_2017jc013030': {grouped_training_data['10.1002_2017jc013030']}")

# Pick the article directory: the test split during a competition rerun, else the train split
if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    base_file_dir = TEST_DATA_DIR
else:
    base_file_dir = TRAIN_DATA_DIR

# Just for testing: the choice above is overridden and the test split is always used
base_file_dir = TEST_DATA_DIR

# Build the per-article file index; articles without an XML file get '' instead of NaN
file_paths_df = load_file_paths(base_file_dir)
file_paths_df['xml_file_path'] = file_paths_df['xml_file_path'].fillna('')

# Attach each article's label records; articles without labels get ''
file_paths_df['ground_truth_dataset_info'] = (
    file_paths_df['article_id'].map(grouped_training_data).fillna('')
)

print(f"Files paths shape: {file_paths_df.shape}")
display(file_paths_df.sample(3))

Loading labeled training data from: ./kaggle/input/make-data-count-finding-data-references\train_labels.csv
Training labels shape: (1028, 3)
Example grouped training data for article_id '10.1002_2017jc013030': [{'dataset_id': 'https://doi.org/10.17882/49388', 'type': 'Primary'}]
Files paths shape: (30, 5)


Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,ground_truth_dataset_info
1,10.1002_anie.201916483,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,"[{'dataset_id': 'Missing', 'type': 'Missing'}]"
28,10.1002_nafm.10870,./kaggle/input/make-data-count-finding-data-re...,,test,[{'dataset_id': 'https://doi.org/10.5066/p9gtu...
16,10.1002_ece3.6303,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test,[{'dataset_id': 'https://doi.org/10.5061/dryad...


#### Define File Extract Functions

In [6]:
# 3.1. Helper function to extract text from various file types
def extract_text_from_file(filepath: str) -> str:
    """
    Extracts text from an XML, PDF, or TXT file.

    XML and PDF text is passed through ``clean_text``; TXT content is only stripped.

    Args:
        filepath (str): Path to the file. Non-string values (for example the NaN left
            by an outer merge, or None) and missing files are reported and yield "".

    Returns:
        str: The extracted text, or "" when the path is unusable, the extension is
        not one of .xml/.pdf/.txt, or parsing the XML/PDF fails.
    """
    # Guard non-string paths first: os.path.exists raises TypeError on float/None
    if not isinstance(filepath, str) or not os.path.exists(filepath):
        print(f"File not found: {filepath}")
        return ""

    print(f"Extracting text from file: {filepath}")
    if filepath.endswith(".xml"):
        # Entity resolution and network access are disabled for the parser
        parser = etree.XMLParser(resolve_entities=False, no_network=True)
        try:
            tree = etree.parse(filepath, parser)
            # A common way to get all text from an XML scientific article
            # This might need adjustment based on the specific XML schema
            return clean_text(" ".join(tree.xpath("//text()")).strip())
        except Exception as e:
            print(f"Error parsing XML {filepath}: {e}")
            return ""
    elif filepath.endswith(".pdf"):
        try:
            # The context manager closes the document even if a page fails to extract
            with fitz.open(filepath) as doc:
                text = "".join(page.get_textpage().extractTEXT() + "\n" for page in doc)
            return clean_text(text.strip())
        except Exception as e:
            print(f"Error parsing PDF {filepath}: {e}")
            return ""
    elif filepath.endswith(".txt"):
        with open(filepath, 'r', encoding='utf-8') as f:
            return f.read().strip()
    return ""

def extract_first_few_sentences(text: str, num_sentences: int = 5) -> str:
    """
    Return the leading sentences of ``text`` as segmented by the spaCy pipeline.

    Args:
        text (str): The input text; a falsy value yields "".
        num_sentences (int): Upper bound on the number of sentences kept.

    Returns:
        str: The kept sentences joined by single spaces and stripped.
    """
    if text:
        leading = list(NLP_SPACY(text).sents)[:num_sentences]
        return " ".join(span.text for span in leading).strip()
    return ""

def extract_article_data_from_text(full_text: str, article_id: str) -> ArticleData:
    """
    Build an ArticleData record from an article's full text.

    Only the abstract is extracted; title and author are left empty because no
    reliable extraction for them is implemented.

    Args:
        full_text (str): The full text of the article.
        article_id (str): The ID of the article.

    Returns:
        ArticleData: Record holding the article ID and a shortened abstract.
    """
    # Take whatever follows the first "abstract" (any case) up to a blank line or end of text
    found = re.search(r"Abstract\s*(.*?)(?=\n\n|\Z)", full_text, re.IGNORECASE | re.DOTALL)
    if found:
        raw_abstract = found.group(1).strip()
    else:
        raw_abstract = "No Abstract"

    # Keep at most the first 400 characters, then at most three sentences of those
    short_abstract = extract_first_few_sentences(raw_abstract[:400], num_sentences=3)

    return ArticleData(
        article_id=article_id,
        title="",
        author="",
        abstract=short_abstract
    )

# 4.2. Function to extract context around an ID
def extract_context_around_id(sentences, dataset_id: str, window_size_sentences: int = 3) -> str:
    """
    Extracts a window of sentences around the first mention of a dataset ID.

    Args:
        sentences: List of sentence strings (e.g. from spaCy segmentation). A single
            raw string is also accepted and is split on sentence-ending punctuation
            followed by whitespace.
        dataset_id (str): Identifier to look for (case-insensitive substring match).
        window_size_sentences (int): Number of sentences kept on EACH side of the
            sentence that contains the ID.

    Returns:
        str: The window joined by single spaces, or "" when the input is empty, the
        ID is empty or the "Missing" placeholder, or the ID does not occur.
    """
    if not sentences or not dataset_id or dataset_id == "Missing":
        return ""

    # A raw string would otherwise be scanned character by character
    if isinstance(sentences, str):
        sentences = [s for s in re.split(r'(?<=[.!?])\s+', sentences) if s]

    # Only the first mention is used.
    # You might want to refine this to capture all relevant contexts or the most prominent one.
    needle = dataset_id.lower()
    first_match_idx = next(
        (i for i, sent in enumerate(sentences) if needle in sent.lower()),
        None,
    )
    if first_match_idx is None:
        return ""

    # Clamp both ends so the window never runs outside the sentence list
    start_idx = max(0, first_match_idx - window_size_sentences)
    end_idx = min(len(sentences), first_match_idx + window_size_sentences + 1)

    return " ".join(sentences[start_idx:end_idx])


def extract_training_data_for_llm(file_paths_df: pd.DataFrame) -> list[dict[str, str]]:
    """
    Builds one training example per labelled (article, dataset ID) pair.

    Args:
        file_paths_df (pd.DataFrame): One row per article with ``article_id``,
            ``pdf_file_path``, ``xml_file_path`` and ``ground_truth_dataset_info``
            (a list of ``{'dataset_id': ..., 'type': ...}`` records; any other
            value, such as '' or NaN, means the article has no labels).

    Returns:
        list[dict[str, str]]: Example dicts with the keys ``article_id``,
        ``article_doi``, ``article_abstract``, ``citation_context``, ``dataset_id``
        and ``label``. Articles without labels are skipped.
    """
    training_data_for_llm: list[dict[str, str]] = []
    for _, row in tqdm(file_paths_df.iterrows(), total=len(file_paths_df)):
        article_id = row['article_id']

        # Check the labels first so unlabelled articles are skipped before the
        # expensive text extraction and sentence segmentation.
        ground_truth_list = row['ground_truth_dataset_info'] if 'ground_truth_dataset_info' in row else []
        if not isinstance(ground_truth_list, (list, tuple)) or not ground_truth_list:
            print(f"No ground truth data found for article_id: {article_id}. Skipping this article.")
            continue

        # Prefer the PDF, fall back to the XML. A missing path may be '' or NaN;
        # NaN is truthy, so a plain truthiness test would pick it over a valid XML path.
        filepath = next(
            (p for p in (row.get('pdf_file_path'), row.get('xml_file_path')) if isinstance(p, str) and p),
            "",
        )

        full_text = extract_text_from_file(filepath)
        article_data = extract_article_data_from_text(full_text, article_id)
        sentences = [sent.text for sent in NLP_SPACY(full_text).sents]

        for gt in ground_truth_list:
            # Labels store DOIs as full URLs; strip the resolver prefix so the bare
            # identifier can be searched for in the article text.
            dataset_id = gt['dataset_id'].replace("https://doi.org/", "").replace("doi:", "").strip()
            citation_type = gt.get('type', 'Primary')
            if dataset_id:
                training_data_for_llm.append(
                    {
                        "article_id": article_data.article_id,
                        "article_doi": article_data.article_doi,
                        "article_abstract": article_data.abstract,
                        "citation_context": extract_context_around_id(sentences, dataset_id),
                        "dataset_id": dataset_id,
                        "label": citation_type
                    }
                )

    print(f"Loaded training data for {len(training_data_for_llm)} articles.")
    return training_data_for_llm


In [7]:
# For testing: narrow the file index to one specific article and display that row
sample_file_paths_df = file_paths_df[file_paths_df['article_id'].eq('10.1002_esp.5058')]
sample_file_paths_df

Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type,ground_truth_dataset_info
25,10.1002_esp.5058,./kaggle/input/make-data-count-finding-data-re...,,test,[{'dataset_id': 'https://doi.org/10.5061/dryad...


#### 4. Advanced Preprocessing: Extracting Dataset Mentions and Context (Training)

Look up each labelled dataset ID from `train_labels.csv` in the article text (case-insensitive substring match on spaCy-segmented sentences) and keep the nearby sentences as the citation context.


In [None]:
# 4.3. Build the LLM training examples from the labelled articles
training_data_for_llm = extract_training_data_for_llm(file_paths_df)
print(f"Prepared {len(training_data_for_llm)} training examples for the LLM.")

# Save the examples as CSV for inspection; create the output directory on first run
os.makedirs(BASE_OUTPUT_DIR, exist_ok=True)
training_data_for_llm_df = pd.DataFrame(training_data_for_llm)
training_data_for_llm_df.to_csv(os.path.join(BASE_OUTPUT_DIR, "training_data_for_llm.csv"), index=False)

# Convert to Hugging Face Dataset format
train_dataset = Dataset.from_list(training_data_for_llm)
train_dataset = train_dataset.shuffle(seed=42) # Shuffle for good measure

# Split into train/validation. The split is seeded as well, so the same examples
# land in the validation set on every Restart & Run All.
train_test_split = train_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']
print(f"Training set size: {len(train_dataset)} examples")
print(f"Validation set size: {len(eval_dataset)} examples")

  0%|          | 0/30 [00:00<?, ?it/s]

Extracting text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_2017jc013030.pdf
Extracting text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_anie.201916483.pdf
Extracting text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_anie.202005531.pdf
Extracting text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_anie.202007717.pdf
Extracting text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_chem.201902131.pdf
Extracting text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_chem.201903120.pdf
Extracting text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_chem.202000235.pdf
Extracting text from file: ./kaggle/input/make-data-count-finding-data-references\test\PDF\10.1002_chem.202001412.pdf
Extracting text from file: ./kaggle/input/make-data-count-

#### 5. Model Selection and Configuration

We'll use a Qwen model.

In [9]:
# 5.1. Choose a Model from KaggleHub
# The base model is the local path resolved by kagglehub in the constants cell.
model_name = QWEN_BASE_MODEL_PATH

# 5.2. Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Only fall back to EOS when the tokenizer ships without a pad token. Overwriting an
# existing pad token makes padding indistinguishable from end-of-sequence.
# NOTE(review): presumably the Qwen tokenizer already defines its own pad token - confirm.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 5.3. Load Model with Quantization (4-bit)
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16 # Or torch.float16 if bfloat16 is not supported by your GPU
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16, # Match compute_dtype
    device_map="auto", # Automatically maps model to available devices
    trust_remote_code=True # Required for some models like Qwen
)

# Prepare model for k-bit training (LoRA compatible)
model.config.use_cache = False  # the generation KV cache is not needed while training
model.config.pretraining_tp = 1
model = prepare_model_for_kbit_training(model)

print(f"Model {model_name} loaded with 4-bit quantization.")

Model C:\Users\jim\.cache\kagglehub\models\qwen-lm\qwen-3\transformers\0.6b\1 loaded with 4-bit quantization.


#### 6. Dataset Preparation for Training

Format the extracted data into instruction-tuning prompts using the ChatML format, which Qwen models are trained on.

In [10]:
# 6.1. Formatting function handed to SFTTrainer: one dataset row -> one chat-formatted string
def format_example(example):
    """
    Render one training example as a single chat-template string.

    The conversation has a system instruction, a user turn holding the article
    DOI, abstract, citation context and dataset ID, and an assistant turn holding
    the example's label.

    Args:
        example: Mapping with the keys 'article_doi', 'article_abstract',
            'citation_context', 'dataset_id' and 'label'.

    Returns:
        str: The untokenized text produced by the tokenizer's chat template.
    """
    user_prompt = (
        f"Given the following article context and a specific data citation, classify if the data was generated as 'Primary' (newly generated for this study), 'Secondary' (reused from existing records), or 'Missing' (no data citation context given).\n\n"
        f"Article DOI: {example['article_doi']}\n"
        f"Article Abstract: {example['article_abstract']}\n"
        f"Data Citation Context: {example['citation_context']}\n"
        f"Dataset ID: {example['dataset_id']}\n\n"
        f"Classification:"
    )
    conversation = [
        {"role": "system", "content": "You are an expert assistant for classifying research data citations. /no_think"},
        {"role": "user", "content": user_prompt},
        # The training target is the bare label text
        {"role": "assistant", "content": example['label']},
    ]

    # Return the rendered string itself (not a dict), with no trailing generation prompt
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
        enable_thinking=False,
    )

# Apply the formatting to the dataset
# IMPORTANT: When formatting_func returns a string directly, you typically don't
# need to call .map() on the dataset beforehand if SFTTrainer handles it internally.
# However, if you want to inspect the formatted text, you can still do this:
# formatted_train_dataset = train_dataset.map(format_example)
# But for SFTTrainer, you pass the original `train_dataset` and the `formatting_func`
# and `dataset_text_field` (which will be ignored if formatting_func is used to generate the text).

# Render the first training example by calling format_example directly, to verify the prompt
print("\nExample of formatted training data (string output):")
if len(train_dataset) == 0:
    print("No training data to display example.")
else:
    sample_formatted_text = format_example(train_dataset[0])
    print(sample_formatted_text)


Example of formatted training data (string output):
<|im_start|>system
You are an expert assistant for classifying research data citations. /no_think<|im_end|>
<|im_start|>user
Given the following article context and a specific data citation, classify if the data was generated as 'Primary' (newly generated for this study), 'Secondary' (reused from existing records), or 'Missing' (no data citation context given).

Article DOI: 10.1002/esp.5090
Article Abstract: The 20 May 2016 MW 6.1 Petermann earthquake in central Australia generated a 21 km surface rupture with 0.1 to 1 m vertical displacements across a low-relief landscape. No paleo-scarps or potentially analogous topographic features are evident in pre-earthquake Worldview-1 and Worldview-2 satellite data. Two excavations across the surface rupture expose near-surface fault geometry and mixed aeolian
Data Citation Context: Erosion has been locally enhanced by bed-rock shattering, rock fragment displacement (Figure 2), and rockfalls

In [None]:
# ---------------------------------------------------------
# This version uses the evaluation dataset directly in the SFTTrainer
# ---------------------------------------------------------

# First, import SFTConfig from trl
from trl import SFTConfig

# 7.1. Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear", # Adjust based on model architecture if needed
)

# 7.2. Configure Training Arguments (now using SFTConfig)
output_dir = "./results"
training_args = SFTConfig(
    output_dir=output_dir,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
    save_steps=500,
    optim="paged_adamw_8bit",
    # NOTE(review): the model is loaded with a bfloat16 compute dtype while training
    # runs in fp16 mixed precision - confirm this combination is intended.
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    report_to="none",
    disable_tqdm=False,
    remove_unused_columns=False,
    # The Trainer only computes an evaluation loss when it knows the label key.
    # An empty list disables that, which leaves no "eval_loss" for metric_for_best_model.
    label_names=["labels"],

    # SFTTrainer-specific parameters moved into SFTConfig
    max_seq_length=256,
    packing=False,
    dataset_text_field="text",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant':False},

    # --- Evaluation Parameters ---
    # `eval_strategy` is the current name of the former `evaluation_strategy` argument
    eval_strategy="steps",       # Evaluate every 'eval_steps'
    eval_steps=500,              # How often to run evaluation (e.g., every 500 steps)
                                 # You can also use "epoch" for eval_strategy
                                 # NOTE(review): a short run may end before step 500,
                                 # in which case no evaluation happens - confirm.
    save_strategy="steps",       # How often to save checkpoints
    save_total_limit=1,          # Limit how many checkpoints are kept on disk
    load_best_model_at_end=True, # Load the model with the best validation metric at the end of training
    metric_for_best_model="eval_loss", # Metric to monitor for best model (default for CLM)
    greater_is_better=False,     # For loss, lower is better
)

# 7.3. Initialize SFTTrainer
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset, # <--- Pass the evaluation dataset here
    peft_config=peft_config,
    args=training_args,
    formatting_func=format_example
)


In [None]:
# ---------------------------------------------------------
# This version works but does NOT use the evaluation dataset
# ---------------------------------------------------------

# # First, import SFTConfig from trl
# from trl import SFTConfig

# # 7.1. Configure LoRA
# peft_config = LoraConfig(
#     lora_alpha=16,
#     lora_dropout=0.1,
#     r=64,
#     bias="none",
#     task_type="CAUSAL_LM",
#     target_modules="all-linear", # Adjust based on model architecture if needed
# )

# # 7.2. Configure Training Arguments (now using SFTConfig)
# # SFTConfig combines TrainingArguments with SFTTrainer-specific parameters
# output_dir = "./results"
# training_args = SFTConfig( # <--- IMPORTANT CHANGE: Use SFTConfig instead of TrainingArguments
#     output_dir=output_dir,
#     per_device_train_batch_size=1, # Adjust based on your GPU memory
#     gradient_accumulation_steps=16,
#     learning_rate=2e-4,
#     num_train_epochs=3,
#     logging_steps=10,
#     save_steps=500,
#     optim="paged_adamw_8bit",
#     fp16=True,  # <--- CHANGED: Try fp16 for broader compatibility and memory
#     bf16=False, # <--- CHANGED: Disable bf16 if fp16 is used
#     max_grad_norm=0.3,
#     warmup_ratio=0.03,
#     lr_scheduler_type="constant",
#     report_to="none",
#     disable_tqdm=False,
#     remove_unused_columns=False, # Keep columns for formatting
#     label_names=[], # Explicitly tell Trainer not to look for label columns in the dataset
#     # Additional SFT-specific parameters    
#     max_seq_length=512, # Max input sequence length (adjust based on context size)
#     packing=False, # Set to True for more efficient training if your data is short
#     dataset_text_field="text", # The name of the column in your dataset containing the text
#     # <--- NEW: Enable gradient checkpointing
#     gradient_checkpointing=True,
#     # This line is important for gradient checkpointing with PeftModel
#     # It tells the model to use the Peft (LoRA) layers for checkpointing
#     gradient_checkpointing_kwargs={'use_reentrant':False} # Recommended for newer PyTorch/Accelerate
# )

# # 7.3. Initialize SFTTrainer (Corrected for trl 0.19.1)
# trainer = SFTTrainer(
#     model=model,
#     processing_class=tokenizer, 
#     train_dataset=train_dataset,
#     peft_config=peft_config,
#     args=training_args, # This is now an SFTConfig object
#     formatting_func=format_example # This remains a direct argument to SFTTrainer
# )


Applying formatting function to train dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

In [12]:

# 7.4. Run the fine-tuning loop
print("\nStarting model training...")
trainer.train()
print("Training complete!")

# Persist the fine-tuned model (LoRA adapters) under <output_dir>/final_model
trainer.save_model(output_dir=os.path.join(output_dir, "final_model"))
print(f"Fine-tuned model saved to {os.path.join(output_dir, 'final_model')}")


Starting model training...


Step,Training Loss


Training complete!
Fine-tuned model saved to ./results\final_model


In [13]:
import gc # Import the garbage collection module

# --- Explicit GPU Memory Cleanup ---
print("\nInitiating GPU memory cleanup...")

# 1. Explicitly delete the objects that hold GPU memory.
#    This removes the notebook-level references, allowing Python's garbage collector to act.
#    The tokenizer is deliberately kept: it holds no GPU memory and is still needed
#    to build prompts once the model is reloaded for inference.
if 'trainer' in globals():
    del trainer
if 'model' in globals():
    del model
# If you had other large tensors or datasets explicitly moved to GPU,
# you would delete them here too. For Hugging Face datasets, they are usually
# on CPU unless you manually call .to('cuda').

# 2. Force Python's garbage collection
#    This helps ensure that deleted objects are immediately cleaned up.
gc.collect()

# 3. Clear PyTorch's CUDA memory cache
#    This tells PyTorch to release any cached memory back to the OS/driver.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("GPU memory cleanup complete. Please check nvidia-smi to confirm.")


Initiating GPU memory cleanup...
GPU memory cleanup complete. Please check nvidia-smi to confirm.


#### 8. Inference and Evaluation

After training, load the best model (or the final one) and apply it to the test data.

In [None]:
# 8.1. Load the Trained Model (or merge LoRA adapters for full model)
# If you saved LoRA adapters, you'll need to load the base model and then the adapters.
# For inference, it's often easier to merge them.
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     quantization_config=nf4_config, # Use the same config as training
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     trust_remote_code=True
# )
# model = PeftModel.from_pretrained(model, os.path.join(output_dir, "final_model"))
# model = model.merge_and_unload() # Merge LoRA adapters into the base model

# For simplicity, if you just want to test the last saved checkpoint:
# You can also load the model directly from the checkpoint if it's a full save
# model = AutoModelForCausalLM.from_pretrained(os.path.join(output_dir, "final_model"), device_map="auto")
# tokenizer = AutoTokenizer.from_pretrained(os.path.join(output_dir, "final_model"))

# Inference setup: reload the quantized base model and attach the saved LoRA adapters
from peft import PeftModel

model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=nf4_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    ),
    os.path.join(output_dir, "final_model"),
)
model.eval()  # switch to evaluation mode

print("Model loaded for inference.")


In [None]:

# 8.2. Preprocess Test Data (similar to training data)
# Builds one ArticleData per file in TEST_DATA_DIR and attaches every candidate
# dataset mention found by the ID regexes, each with its surrounding context.
test_articles_data: Dict[str, ArticleData] = {}
# Assuming test data structure is similar to train data (full text files).
# Sorted because os.listdir() returns entries in arbitrary order; this keeps
# the article order (and therefore the prediction order) reproducible.
for article_file in sorted(os.listdir(TEST_DATA_DIR)):
    article_id = os.path.splitext(article_file)[0]
    filepath = os.path.join(TEST_DATA_DIR, article_file)

    # NOTE(review): assumes extract_text_from_file may return None/"" when
    # extraction fails - confirm. Falling back to "" keeps the regexes below
    # from raising; such an article is still recorded, with no citations.
    full_text = extract_text_from_file(filepath) or ""

    # Extract title, author, abstract (same as training)
    title_match = re.search(r"Title:\s*(.*)", full_text, re.IGNORECASE)
    title = title_match.group(1).strip() if title_match else "Unknown Title"
    author_match = re.search(r"Author(?:s)?:\s*(.*)", full_text, re.IGNORECASE)
    author = author_match.group(1).strip() if author_match else "Unknown Author"
    abstract_match = re.search(r"Abstract\s*(.*?)(?=\n\n|\Z)", full_text, re.IGNORECASE | re.DOTALL)
    abstract = abstract_match.group(1).strip() if abstract_match else "No Abstract"

    article_data = ArticleData(
        article_id=article_id,
        title=title,
        author=author,
        abstract=abstract
    )
    # For test data, we need to find *all* potential dataset IDs, not just ground truth
    # This is the "finding datasets" part of your goal.

    # Use regex to find all potential dataset IDs in the full text.
    # Each dataset ID is kept once per article: repeated mentions would
    # otherwise produce duplicate predictions and duplicate
    # "<article_id>_<dataset_id>" rows in the submission file.
    found_dataset_mentions = []
    seen_dataset_ids = set()
    for pattern in ALL_ID_PATTERNS:
        for match in re.finditer(pattern, full_text, re.IGNORECASE):
            dataset_id = match.group(1) if pattern == DOI_PATTERN else match.group(0)
            if dataset_id in seen_dataset_ids:
                continue
            span_text = match.group(0) # The full matched text

            context = extract_context_around_id(full_text, span_text, window_size_sentences=3)

            if context:
                # Mark the ID as seen only once a usable context exists, so a
                # later mention can still supply one if this one did not.
                seen_dataset_ids.add(dataset_id)
                dc = DatasetCitation()
                dc.add_dataset_id(dataset_id)
                dc.set_citation_context(context)
                found_dataset_mentions.append(dc)

    article_data.dataset_citations = found_dataset_mentions # Assign found mentions
    test_articles_data[article_id] = article_data

print(f"Prepared {len(test_articles_data)} test articles for inference.")

# 8.3. Generate Predictions
# For every candidate citation, build the classification prompt, greedily
# generate a short completion and map it to "Primary" / "Secondary" /
# "Unknown". One dict per citation is appended to `predictions`.
predictions = []
true_labels = [] # Only if you have a test_labels.json for evaluation

for article_id, article_data in tqdm(test_articles_data.items(), desc="Classifying citations"):
    for dc in article_data.dataset_citations:
        dataset_id = list(dc.dataset_ids)[0] # Assuming one ID per citation

        # Create the prompt for inference
        messages = [
            {"role": "system", "content": "You are an expert assistant for classifying research data citations."},
            {"role": "user", "content": (
                f"Given the following article context and a specific data citation, classify if the data was generated as 'Primary' (newly generated for this study) or 'Secondary' (reused from existing records).\n\n"
                f"Article Title: {article_data.title}\n"
                f"Article Abstract: {article_data.abstract}\n"
                f"Data Citation Context: {dc.citation_context}\n"
                f"Dataset ID: {dataset_id}\n\n"
                f"Classification:"
            )}
        ]

        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # NOTE(review): with the tokenizer's default right-side truncation, a
        # prompt longer than 512 tokens loses its end (the "Classification:"
        # cue and the generation prompt) - consider shortening the abstract or
        # context instead if long prompts occur.
        # Inputs go to the model's own device: with device_map="auto" the model
        # is placed by accelerate, not by the notebook-level `device`.
        encoded = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(model.device)
        prompt_length = encoded["input_ids"].shape[1]

        with torch.no_grad():
            # Keyword arguments throughout, and the attention mask is passed
            # explicitly: pad_token_id equals eos_token_id here, so generate()
            # cannot reliably infer the mask from the input ids.
            output = model.generate(
                input_ids=encoded["input_ids"],
                attention_mask=encoded.get("attention_mask"),
                max_new_tokens=10, # Expecting "Primary" or "Secondary"
                do_sample=False, # Use greedy decoding as per your preference
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens (everything after the prompt)
        generated_text = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()

        # Post-process the generated text to get the classification
        predicted_type = "Unknown"
        if "Primary" in generated_text:
            predicted_type = "Primary"
        elif "Secondary" in generated_text:
            predicted_type = "Secondary"

        predictions.append({
            "article_id": article_id,
            "dataset_id": dataset_id,
            "predicted_type": predicted_type
        })

        # If you have test labels, you can collect true_labels here for evaluation
        # For Kaggle, you'll typically submit predictions without knowing test labels.

print(f"Generated {len(predictions)} predictions.")

# 8.4. Evaluation (if test labels are available)
# If you have a separate test_labels.json for local evaluation:
# test_labels = load_labels(TEST_LABELS_PATH) # Load test labels
#
# # Match predictions to true labels and calculate metrics
# # This part requires careful matching of dataset_id within article_id
# # and might involve fuzzy matching for context if exact span isn't available.
# # For simplicity, assuming exact match on article_id and dataset_id.
#
# y_true = []
# y_pred = []
#
# for pred_entry in predictions:
#     article_id = pred_entry["article_id"]
#     dataset_id = pred_entry["dataset_id"]
#     predicted_type = pred_entry["predicted_type"]
#
#     # Find the true label for this specific dataset_id in this article
#     found_true_label = False
#     if article_id in test_labels:
#         for gt_info in test_labels[article_id]:
#             if gt_info["dataset_id"] == dataset_id: # Exact match on ID
#                 y_true.append(gt_info["citation_type"])
#                 y_pred.append(predicted_type)
#                 found_true_label = True
#                 break
#     if not found_true_label:
#         # Handle cases where a predicted ID might not be in ground truth
#         # or where the ID extraction was imperfect.
#         # For competition, this means your ID extraction needs to be precise.
#         pass
#
# if y_true and y_pred:
#     from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
#     # Map "Primary" to 1, "Secondary" to 0 for sklearn metrics
#     label_map = {"Primary": 1, "Secondary": 0}
#     y_true_mapped = [label_map.get(l, -1) for l in y_true]
#     y_pred_mapped = [label_map.get(l, -1) for l in y_pred]
#
#     # Filter out -1 if there were unknown labels
#     valid_indices = [i for i, val in enumerate(y_true_mapped) if val != -1 and y_pred_mapped[i] != -1]
#     y_true_mapped = [y_true_mapped[i] for i in valid_indices]
#     y_pred_mapped = [y_pred_mapped[i] for i in valid_indices]
#
#     if y_true_mapped:
#         print("\nEvaluation Results:")
#         print(f"Accuracy: {accuracy_score(y_true_mapped, y_pred_mapped):.4f}")
#         print(f"F1 Score (weighted): {f1_score(y_true_mapped, y_pred_mapped, average='weighted'):.4f}")
#         print(f"Precision (weighted): {precision_score(y_true_mapped, y_pred_mapped, average='weighted'):.4f}")
#         print(f"Recall (weighted): {recall_score(y_true_mapped, y_pred_mapped, average='weighted'):.4f}")
#     else:
#         print("No matching true labels found for evaluation.")
# else:
#     print("Not enough data to perform evaluation.")

#### 9. Submission File Generation (Kaggle Specific)

Finally, format your predictions into the required `submission.csv` file.

In [None]:
# 9.1. Create Submission DataFrame
# Turns `predictions` (dicts with 'article_id', 'dataset_id', 'predicted_type')
# into a two-column table and writes it to submission.csv.

# Raw per-prediction table, kept for inspection.
submission_df = pd.DataFrame(predictions)
# Rename columns to match Kaggle's expected format (e.g., 'id', 'class_label')
# This will depend on the exact submission format specified by Kaggle.
# Example:
# submission_df = submission_df.rename(columns={"article_id": "Id", "dataset_id": "DatasetId", "predicted_type": "Type"})
# submission_df["Id"] = submission_df["Id"] + "_" + submission_df["DatasetId"] # If Id is a combination

# Assuming the submission format is a single combined ID column
# ("article_id_dataset_id") plus the predicted type.
# You might need to adjust this based on the exact competition requirements.
final_submission_data = [
    {
        "Id": f"{pred['article_id']}_{pred['dataset_id']}", # Example: combine IDs
        "Type": pred['predicted_type']
    }
    for pred in predictions
]

# Explicit columns keep the CSV header intact even when there are no
# predictions; duplicate Ids (the same dataset found more than once in an
# article) are collapsed to their first prediction so each Id appears once.
final_submission_df = pd.DataFrame(
    final_submission_data, columns=["Id", "Type"]
).drop_duplicates(subset="Id", keep="first")
final_submission_df.to_csv("submission.csv", index=False)

print("Submission file 'submission.csv' created successfully!")