In [1]:
# --- 0. Environment Setup & Offline Preparation ---

# Standard Imports
import os
import glob
import re
import pandas as pd
import xml.etree.ElementTree as ET
import collections # For deque in parenthesis removal
import fitz # PyMuPDF for PDF processing
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.training_args import TrainingArguments
from trl import SFTTrainer
import torch
from datasets import Dataset # Hugging Face datasets library
import kagglehub

# Set device for PyTorch
device = "cuda" if torch and torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cpu


In [2]:
# Define constants for file paths and model configurations
BASE_INPUT_DIR = './kaggle/input/make-data-count-finding-data-references'
ARTICLE_TRAIN_DIR = os.path.join(BASE_INPUT_DIR, 'train')
ARTICLE_TEST_DIR = os.path.join(BASE_INPUT_DIR, 'test')

# Define directories for articles in train and test sets
LABELED_TRAINING_DATA_CSV_PATH = os.path.join(BASE_INPUT_DIR, 'train_labels.csv')

# Output directory for the fine-tuned model and results
BASE_OUTPUT_DIR = "./kaggle/working"
FINE_TUNED_MODEL_OUTPUT_DIR = os.path.join(BASE_OUTPUT_DIR, "qwen_finetuned_dataset_classifier")
FINAL_RESULTS_CSV_PATH = os.path.join(BASE_OUTPUT_DIR, "article_dataset_classification.csv")

inference_model = None
inference_tokenizer = None


In [3]:
# --- Data Loading ---
def load_file_paths(dataset_type_dir: str) -> pd.DataFrame: 
    pdf_path = os.path.join(dataset_type_dir, 'PDF')
    xml_path = os.path.join(dataset_type_dir, 'XML')
    dataset_type = os.path.basename(dataset_type_dir)
    pdf_files = [f for f in os.listdir(pdf_path) if f.endswith('.pdf')]
    xml_files = [f for f in os.listdir(xml_path) if f.endswith('.xml')]
    df_pdf = pd.DataFrame({
        'article_id': [f.replace('.pdf', '') for f in pdf_files],
        'pdf_file_path': [os.path.join(pdf_path, f) for f in pdf_files]
    })
    df_xml = pd.DataFrame({
        'article_id': [f.replace('.xml', '') for f in xml_files],
        'xml_file_path': [os.path.join(xml_path, f) for f in xml_files]
    })
    merge_df = pd.merge(df_pdf, df_xml, on='article_id', how='outer', suffixes=('_pdf', '_xml'), validate="one_to_many")
    merge_df['dataset_type'] = dataset_type
    return merge_df

def read_pdf_text(pdf_path: str) -> str:
    """Extracts all text from a PDF file using PyMuPDF."""
    text = ""
    if not fitz:
        return text  # Return empty string if fitz is not available
    try:
        with fitz.open(pdf_path) as doc:
            for page in doc:
                text += page.get_textpage().extractTEXT().replace('\u200b', '').strip()
    except Exception as e:
        print(f"Error reading PDF {pdf_path}: {e}")
        
    return text

def read_xml_text(xml_file_path: str) -> str:
    """Reads and concatenates all text content from an XML file."""
    all_text_parts = []
    try:
        tree = ET.parse(xml_file_path)
        root = tree.getroot()
        for element in root.iter():
            if element.text:
                cleaned_text = element.text.strip()
                if cleaned_text:
                    all_text_parts.append(cleaned_text)
            if element.tail:
                cleaned_tail = element.tail.strip()
                if cleaned_tail:
                    all_text_parts.append(cleaned_tail)
        return " ".join(all_text_parts) if all_text_parts else ""
    except Exception as e:
        print(f"Error reading XML {xml_file_path}: {e}")
        return ""

def load_article_text(filepath: str) -> str:
    """
    Loads text content from a single article file (PDF or XML).
    Returns the text content of the given file.
    """
    article_id = os.path.splitext(os.path.basename(filepath))[0]
    text_content = ""

    if filepath.endswith(".pdf"):
        text_content = read_pdf_text(filepath)
    elif filepath.endswith(".xml"):
        text_content = read_xml_text(filepath)

    return text_content


In [4]:
# Load the labeled training data CSV file
print(f"Loading labeled training data from: {LABELED_TRAINING_DATA_CSV_PATH}")
train_labels_df = pd.read_csv(LABELED_TRAINING_DATA_CSV_PATH)
train_labels_df['has_dataset'] = train_labels_df['dataset_id'] != 'Missing'

print(f"Training labels shape: {train_labels_df.shape}")
display(train_labels_df.head())

Loading labeled training data from: ./kaggle/input/make-data-count-finding-data-references\train_labels.csv
Training labels shape: (1028, 4)


Unnamed: 0,article_id,dataset_id,type,has_dataset
0,10.1002_2017jc013030,https://doi.org/10.17882/49388,Primary,True
1,10.1002_anie.201916483,Missing,Missing,False
2,10.1002_anie.202005531,Missing,Missing,False
3,10.1002_anie.202007717,Missing,Missing,False
4,10.1002_chem.201902131,Missing,Missing,False


In [5]:
# Load file paths for training and testing datasets
train_file_paths_df = load_file_paths(ARTICLE_TRAIN_DIR)
test_file_paths_df = load_file_paths(ARTICLE_TEST_DIR)

print(f"Train files paths shape: {train_file_paths_df.shape}")
display(train_file_paths_df.sample(3))
print(f"Test files paths shape: {test_file_paths_df.shape}")
display(test_file_paths_df.sample(3))

Train files paths shape: (524, 4)


Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type
272,10.1186_s12885-018-4314-9,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train
440,10.1590_1981-52712015v43n3rb20180090,./kaggle/input/make-data-count-finding-data-re...,,train
295,10.1186_s12935-020-01373-x,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,train


Test files paths shape: (30, 4)


Unnamed: 0,article_id,pdf_file_path,xml_file_path,dataset_type
21,10.1002_ecs2.4619,./kaggle/input/make-data-count-finding-data-re...,./kaggle/input/make-data-count-finding-data-re...,test
28,10.1002_nafm.10870,./kaggle/input/make-data-count-finding-data-re...,,test
25,10.1002_esp.5058,./kaggle/input/make-data-count-finding-data-re...,,test


In [6]:
# Global variables for LLM components (loaded once)
llm_tokenizer = None
llm_model = None

# --- Load the Qwen Model and Tokenizer---
model_name = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")
print(f"Loading Qwen model and tokenizer from: {model_name}")
llm_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
llm_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)
llm_model.eval()  # Set the model to evaluation mode

# prepare the model input
prompt = "Give me a short introduction to large language model."
messages = [
    {"role": "user", "content": prompt}
]
text = llm_tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = llm_tokenizer([text], return_tensors="pt").to(llm_model.device)

# conduct text completion
generated_ids = llm_model.generate(
    **model_inputs,
    max_new_tokens=32768
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() 

# parsing thinking content
try:
    # rindex finding 151668 ()
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0

thinking_content = llm_tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
content = llm_tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)

Loading Qwen model and tokenizer from: C:\Users\jim\.cache\kagglehub\models\qwen-lm\qwen-3\transformers\0.6b\1
thinking content: <think>
Okay, the user wants a short introduction to a large language model. Let me start by recalling what I know about LLMs. They are large language models, right? So I should mention their size and capabilities. Maybe start with "Large language models" as the main term. Then explain that they're trained on vast amounts of text, so they can understand and generate a lot of text.

I should highlight their ability to understand and generate human-like text. Maybe mention specific examples, like answering questions or writing creative content. Also, it's important to note that they can work with various tasks, not just text. Oh, and maybe touch on their applications in fields like customer service, content creation, etc.

Wait, the user might be a student or someone new to the topic. They need a concise yet informative introduction. I should keep it short, may

In [13]:
llm_tokenizer.model_max_length

131072

In [37]:
class QwenChatbotSimple:
    """ A simple chatbot class using the Qwen model for generating responses.
    It maintains a history of messages to provide context for the conversation.
    """

    def __init__(self, model_name = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.history = []

    def generate_response(self, user_input):
        messages = self.history + [{"role": "user", "content": user_input}]

        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = self.tokenizer(text, return_tensors="pt")
        response_ids = self.model.generate(**inputs, max_new_tokens=32768)[0][len(inputs.input_ids[0]):].tolist()
        response = self.tokenizer.decode(response_ids, skip_special_tokens=True)

        # Update history
        self.history.append({"role": "user", "content": user_input})
        self.history.append({"role": "assistant", "content": response})

        return response

In [7]:
# --- QwenChatbot Class ---
class QwenChatbot:
    def __init__(self, model_name = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True)
        self.model.eval() # Set the model to evaluation mode here as well.
        self.history = []

    def generate_response(self, user_input, enable_thinking=True, max_new_tokens=512):  # Reduce max_new_tokens
        messages = self.history + [{"role": "user", "content": user_input}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking
        )

        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        with torch.no_grad(): # Disable gradient calculation during inference
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,  # Use the reduced value
                temperature=0.6 if enable_thinking else 0.7,
                top_p=0.95 if enable_thinking else 0.8,
                top_k=20,
                min_p=0
            )

        output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()

        #parsing thinking content
        try:
            # rindex finding 151668 ()
            index = len(output_ids) - output_ids[::-1].index(151668)
        except ValueError:
            index = 0

        thinking_content = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
        content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

        response = content

        # Update history
        self.history.append({"role": "user", "content": user_input})
        self.history.append({"role": "assistant", "content": response})

        return response, thinking_content
    
    def clear_history(self):
        self.history = []


In [8]:
chatbot = QwenChatbot()

In [9]:
# --- Example Usage of QwenChatbot ---

# First input (without /think or /no_think tags, thinking mode is enabled by default)
user_input_1 = "How many r's in strawberries?"
print(f"User: {user_input_1}")
response_1 = chatbot.generate_response(user_input_1)
print(f"Bot: {response_1}")
print("----------------------")

User: How many r's in strawberries?
Bot: ('How many r\'s are in the word "strawberries"? Let\'s break it down:\n\n- S, T, R, A, W, B, E, R, R, I, N, G, S.  \n\nThe letters \'r\' appear at positions 3 and', '<think>\nOkay, so the user is asking, "How many r\'s in strawberries?" Let me think about this. First, I need to check how many times the letter \'r\' appears in the word "strawberries". Let me write it out: S-T-R-A-W-B-E-R-R-I-N-G-S. Now, I\'ll go through each letter one by one.\n\nStarting with the first letter, S. No \'r\'s here. Then T. Still no. R! That\'s the first \'r\'. Next, A. Then W, B, E, R, R, I, N, G. So here, there are two more \'r\'s. Let me count again: positions 3 and 6. So total of 3 \'r\'s. Wait, is that correct? Let me double-check. The word is S-T-R-A-W-B-E-R-R-I-N-G-S. Yes, the letters are S-T-R-A-W-B-E-R-R-I-N-G-S. So the \'r\'s are at the 3rd and 6th positions. So that\'s two \'r\'s. Wait, but I thought maybe three? Let me count again. S (no), T (no), R (yes

In [12]:
response_1

('How many r\'s are in the word "strawberries"? Let\'s break it down:\n\n- S, T, R, A, W, B, E, R, R, I, N, G, S.  \n\nThe letters \'r\' appear at positions 3 and',
 '<think>\nOkay, so the user is asking, "How many r\'s in strawberries?" Let me think about this. First, I need to check how many times the letter \'r\' appears in the word "strawberries". Let me write it out: S-T-R-A-W-B-E-R-R-I-N-G-S. Now, I\'ll go through each letter one by one.\n\nStarting with the first letter, S. No \'r\'s here. Then T. Still no. R! That\'s the first \'r\'. Next, A. Then W, B, E, R, R, I, N, G. So here, there are two more \'r\'s. Let me count again: positions 3 and 6. So total of 3 \'r\'s. Wait, is that correct? Let me double-check. The word is S-T-R-A-W-B-E-R-R-I-N-G-S. Yes, the letters are S-T-R-A-W-B-E-R-R-I-N-G-S. So the \'r\'s are at the 3rd and 6th positions. So that\'s two \'r\'s. Wait, but I thought maybe three? Let me count again. S (no), T (no), R (yes, first), A (no), W (no), B (no), E (no)

In [None]:
reason_prompt =  '''
You are an advanced AI reasoning assistant tasked with delivering a comprehensive analysis of a specific problem or question.  Your goal is to outline your reasoning process in a structured and transparent manner, with each step reflecting a thorough examination of the issue at hand, culminating in a well-reasoned conclusion.

### Key Instructions:
1.  Conduct **at least 5 distinct reasoning steps**, each building on the previous one.
2.  **Acknowledge the limitations** inherent to AI, specifically what you can accurately assess and what you may struggle with.
3.  **Adopt multiple reasoning frameworks** to resolve the problem or derive conclusions, such as:
- **Deductive reasoning** (drawing specific conclusions from general principles)
- **Inductive reasoning** (deriving broader generalizations from specific observations)
- **Abductive reasoning** (choosing the best possible explanation for the given evidence)
- **Analogical reasoning** (solving problems through comparisons and analogies)
4.  **Critically analyze your reasoning** to identify potential flaws, biases, or gaps in logic.
5.  When reviewing, apply a **fundamentally different perspective or approach** to enhance your analysis.
6.  **Employ at least 2 distinct reasoning methods** to derive or verify the accuracy of your conclusions.
7.  **Incorporate relevant domain knowledge** and **best practices** where applicable, ensuring your reasoning aligns with established standards.
8.  **Quantify certainty levels** for each step and your final conclusion, where applicable.
9.  Consider potential **edge cases or exceptions** that could impact the outcome of your reasoning.
10.  Provide **clear justifications** for dismissing alternative hypotheses or solutions that arise during your analysis.
'''

task_prompt = '''
You are given a piece of academic text. Your task is to identify the single DOI citation string, if present.
Then normalize it into its full URL format: https://doi.org/...

Each object (paper and dataset) has a unique, persistent identifier to represent it. In this competition there will be two types:

the definition of dataset_id is as follows:
the dataset identifier and citation type in the paper.
DOIs are used for all papers and some datasets. They take the following form: https://doi.org/[prefix]/[suffix]. Examples:
https://doi.org/10.1371/journal.pone.0303785
https://doi.org/10.5061/dryad.r6nq870
Accession IDs are used for some datasets. They vary in form by individual data repository where the data live. Examples:
"GSE12345" (Gene Expression Omnibus dataset)
“PDB 1Y2T” (Protein Data Bank dataset)
"E-MEXP-568" (ArrayExpress dataset)

the definition of type is as follows:
the type citation type, Primary - raw or processed data generated as part of this paper, specifically for this study
Secondary - raw or processed data derived or reused from existing records or published data
'''

example_prompt = '''
use markdown json style to get the result as the following examples:

```json
[
    {
        "dataset_id": "https://doi.org/10.1371/journal.pone.0303785",
        "type": "Primary"
    },
    {
        "dataset_id": "https://doi.org/10.1371/journal.pone.0303785",
        "type": "Secondary"
    },
    {
        "dataset_id": "GSE12345",
        "type": "Secondary"
    },
    {
        "dataset_id": "Missing",
        "type": "Missing, "
    }
    ...
]
```
'''

In [None]:

# --- 2. Data Preparation for LLM Training (Revised for Combined Task) ---

def load_base_llm_for_training():
    """Loads the base Qwen model and tokenizer for fine-tuning."""
    global llm_tokenizer, llm_model
    if not AutoModelForCausalLM or not QWEN_BASE_MODEL_PATH:
        print("LLM components not available or base model path not set. Skipping LLM loading.")
        return False
    try:
        print(f"Loading Qwen tokenizer from: {QWEN_BASE_MODEL_PATH}")
        llm_tokenizer = AutoTokenizer.from_pretrained(QWEN_BASE_MODEL_PATH, trust_remote_code=True)
        if llm_tokenizer.pad_token is None:
            llm_tokenizer.pad_token = llm_tokenizer.eos_token
            print("Set tokenizer.pad_token to tokenizer.eos_token")

        print(f"Loading Qwen model from: {QWEN_BASE_MODEL_PATH}")
        llm_model = AutoModelForCausalLM.from_pretrained(
            QWEN_BASE_MODEL_PATH,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32,
            device_map="auto", # Automatically uses GPU if available
            trust_remote_code=True,
            # load_in_8bit=True if bnb else False # Uncomment if bitsandbytes is used
        )
        print(f"Base LLM loaded successfully on {llm_model.device}.")
        return True
    except Exception as e:
        print(f"Error loading base LLM for training: {e}")
        llm_tokenizer, llm_model = None, None # Reset to None on failure
        return False

def prepare_training_data_for_llm(
    training_df: pd.DataFrame,
    all_article_texts: dict[str, str],
    tokenizer_max_length: int
) -> Dataset:
    """
    Prepares training data for LLM fine-tuning, aggregating dataset IDs and classifications
    per article and formatting into ChatML JSON output.
    """
    formatted_examples = []

    # Group training data by article_id to get all datasets for each article
    # This creates a dictionary where keys are article_ids and values are lists of dataset dicts
    grouped_training_data = training_df.groupby('article_id').apply(
        lambda x: [{"dataset_id": row['dataset_id'], "classification": row['label']} for _, row in x]
    ).to_dict()

    # Get all article IDs for which we have text content
    all_article_ids_with_text = set(all_article_texts.keys())
    
    # Iterate through all articles for which we have text (these are our potential training examples)
    for article_id in all_article_ids_with_text:
        article_text = all_article_texts.get(article_id, "")
        if not article_text:
            print(f"Warning: Article text for {article_id} not found. Skipping training example.")
            continue

        # Truncate article text to fit within the model's context window
        # Reserve tokens for the prompt and the expected JSON response.
        # A typical Qwen 1.5 model has 32768 max_seq_length.
        # 512 tokens for prompt/response is a safe estimate.
        truncated_article_text = article_text[:tokenizer_max_length - 512] 

        # Determine the ground truth output for this article
        if article_id in grouped_training_data:
            # Article has datasets, format them as JSON
            ground_truth_datasets = grouped_training_data[article_id]
            assistant_response_json = json.dumps(ground_truth_datasets, ensure_ascii=False)
        else:
            # Article has no datasets in training data, so the model should output an empty list.
            # This explicitly trains the model to output '[]' for "Missing" cases.
            assistant_response_json = "[]"
            # print(f"Info: Article {article_id} has no datasets in training data. Training to output '[]'.")

        # Construct the user message for the LLM
        user_message = f"""
Article Text:
{truncated_article_text}

Task: Identify all datasets or databases used in this research article and classify each as "Primary" (if created by the authors for this research) or "Secondary" (if an existing dataset used in this research).

Output Format: Provide a JSON list of objects. Each object should have "dataset_id" and "classification" keys. If no datasets are identified, return an empty JSON list: [].
"""
        # Construct the full ChatML formatted string for SFTTrainer
        # The trainer will use this entire string as the 'text' field.
        chatml_formatted_string = f"<|im_start|>system\nYou are an expert research assistant. Your task is to extract and classify datasets from scientific articles.<|im_end|>\n<|im_start|>user\n{user_message.strip()}<|im_end|>\n<|im_start|>assistant\n{assistant_response_json}<|im_end|>"
        
        formatted_examples.append({"text": chatml_formatted_string})

    if not formatted_examples:
        raise ValueError("No training examples could be prepared. Check your data and article texts.")

    return Dataset.from_list(formatted_examples)

# --- 3. LLM Model Training (Fine-tuning) ---

# Attempt to load tokenizer and model if not already loaded (e.g., if previous training failed or was skipped)
if llm_model is None:
    load_base_llm_for_training()

if llm_model and not training_df.empty and Dataset: # Ensure Dataset is imported
    print("\n--- Preparing data for Fine-tuning (Combined Task) ---")
    # Use the model's max_length for context, or a reasonable default if tokenizer isn't loaded
    max_len = llm_tokenizer.model_max_length if llm_tokenizer else 4096 
    train_dataset = prepare_training_data_for_llm(training_df, all_article_texts, max_len)
    
    print(f"Prepared {len(train_dataset)} examples for fine-tuning.")
    print("Example formatted training instance (first 500 chars):")
    print(train_dataset[0]['text'][:500])

    print("\n--- Starting Fine-tuning (Combined Task) ---")
    try:
        training_args = TrainingArguments(
            output_dir=f"{FINE_TUNED_MODEL_OUTPUT_DIR}/checkpoints",
            num_train_epochs=1,  # Start with 1 epoch, adjust as needed
            per_device_train_batch_size=1, # Adjust based on VRAM
            gradient_accumulation_steps=4, # Effective batch size = 1 * 4 = 4
            learning_rate=2e-5,
            logging_steps=10,
            save_steps=50, # Save checkpoints periodically
            fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
            optim="paged_adamw_8bit", # Good for memory efficiency if bitsandbytes is installed
            # report_to="none", # Disable logging to external services
            # max_steps=100, # For quick testing
        )

        trainer = SFTTrainer(
            model=llm_model,
            tokenizer=llm_tokenizer,
            train_dataset=train_dataset,
            dataset_text_field="text", # This field contains the full ChatML string
            args=training_args,
            max_seq_length=max_len, # Use the model's full context length
            packing=False, # Set to True if your inputs are much shorter than max_seq_length
        )

        trainer.train()
        print("Fine-tuning completed.")

        print(f"Saving fine-tuned model to: {FINE_TUNED_MODEL_OUTPUT_DIR}")
        trainer.save_model(FINE_TUNED_MODEL_OUTPUT_DIR)
        print("Model and tokenizer saved.")

    except Exception as e:
        print(f"An error occurred during fine-tuning: {e}")
        import traceback
        traceback.print_exc()
        llm_model = None # Mark model as failed to load/train
else:
    print("Skipping LLM fine-tuning due to missing training data or LLM components.")


# --- 4. LLM-based Extraction & Classification (Inference) ---

# Load the fine-tuned model for inference (if training was successful)
# If training was skipped or failed, this will attempt to load from the base path or fail.
if inference_model is None: # Only load if not already loaded
    if AutoModelForCausalLM: # Check if transformers is available
        if os.path.exists(FINE_TUNED_MODEL_OUTPUT_DIR) and os.path.isdir(FINE_TUNED_MODEL_OUTPUT_DIR):
            MODEL_TO_LOAD = FINE_TUNED_MODEL_OUTPUT_DIR
            print(f"Loading fine-tuned model for inference from: {MODEL_TO_LOAD}")
        else:
            MODEL_TO_LOAD = QWEN_BASE_MODEL_PATH
            print(f"Fine-tuned model not found. Loading base model for inference from: {MODEL_TO_LOAD}")

        try:
            inference_tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_LOAD, trust_remote_code=True)
            if inference_tokenizer.pad_token is None:
                inference_tokenizer.pad_token = inference_tokenizer.eos_token
            inference_model = AutoModelForCausalLM.from_pretrained(
                MODEL_TO_LOAD,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32,
                device_map="auto",
                trust_remote_code=True
            ).eval() # Set to evaluation mode
            print(f"Inference LLM loaded successfully on {inference_model.device}.")
        except Exception as e:
            print(f"Error loading inference LLM from {MODEL_TO_LOAD}: {e}")
            inference_model, inference_tokenizer = None, None
    else:
        print("Transformers library not available. Cannot load LLM for inference.")


def extract_and_classify_with_llm(article_text: str) -> list[dict]:
    """
    Uses the loaded LLM to extract dataset IDs and classify them.
    Returns a list of dictionaries like [{"dataset_id": "...", "classification": "..."}].
    Returns an empty list if LLM is unavailable or parsing fails.
    """
    if not inference_model or not inference_tokenizer:
        print("  LLM unavailable for extraction/classification.")
        return [] # Return empty list if LLM is not loaded

    # Truncate article text for inference if it exceeds model's context window
    # Use the same max_length as during training for consistency
    max_inference_context_length = inference_tokenizer.model_max_length - 256 # Reserve tokens for prompt and response
    truncated_article_text = article_text[:max_inference_context_length]

    user_message = f"""
Article Text:
{truncated_article_text}

Task: Identify all datasets or databases used in this research article and classify each as "Primary" (if created by the authors for this research) or "Secondary" (if an existing dataset used in this research).

Output Format: Provide a JSON list of objects. Each object should have "dataset_id" and "classification" keys. If no datasets are identified, return an empty JSON list: [].
"""
    messages = [
        {"role": "system", "content": "You are an expert research assistant. Your task is to extract and classify datasets from scientific articles."},
        {"role": "user", "content": user_message.strip()}
    ]
    
    input_ids = inference_tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(inference_model.device)

    try:
        with torch.no_grad():
            outputs = inference_model.generate(
                input_ids,
                max_new_tokens=512, # Allow more tokens for multiple dataset outputs
                pad_token_id=inference_tokenizer.eos_token_id,
                eos_token_id=inference_tokenizer.convert_tokens_to_ids("<|im_end|>")
            )
        
        response_text = inference_tokenizer.decode(
            outputs[0][input_ids.shape[1]:],
            skip_special_tokens=False # Keep special tokens to remove <|im_end|> explicitly
        ).strip()
        response_text = response_text.replace("<|im_end|>", "").strip()
        
        print(f"  LLM raw response: '{response_text}'")

        # Attempt to parse the JSON output
        try:
            parsed_data = json.loads(response_text)
            if isinstance(parsed_data, list):
                # Validate structure: each item should be a dict with 'dataset_id' and 'classification'
                valid_datasets = []
                for item in parsed_data:
                    if isinstance(item, dict) and 'dataset_id' in item and 'classification' in item:
                        # Basic validation for classification label
                        if item['classification'] in ["Primary", "Secondary"]:
                            valid_datasets.append(item)
                        else:
                            print(f"  Warning: Invalid classification '{item['classification']}' for dataset '{item.get('dataset_id', 'N/A')}'. Skipping.")
                    else:
                        print(f"  Warning: Malformed JSON object: {item}. Skipping.")
                return valid_datasets
            else:
                print(f"  Warning: LLM did not return a JSON list: {response_text}")
                return []
        except json.JSONDecodeError as jde:
            print(f"  Error decoding JSON from LLM response: {jde}. Raw response: '{response_text}'")
            return []

    except Exception as e:
        print(f"  Error during LLM generation: {e}")
        return []

# --- Main Processing Loop for all articles (Revised) ---
print("\n--- Starting Article Processing and Classification (LLM-driven) ---")
final_results = []

for article_id, article_text in all_article_texts.items():
    print(f"\nProcessing article: {article_id}")
    
    # LLM directly extracts and classifies
    identified_datasets = extract_and_classify_with_llm(article_text)
    
    if not identified_datasets:
        # If LLM returns an empty list, classify the article as "Missing"
        print(f"  LLM identified no datasets for {article_id}. Classifying as 'Missing'.")
        final_results.append({
            "article_id": article_id,
            "dataset_id": "N/A", # Indicate no specific dataset ID
            "classification_label": "Missing"
        })
    else:
        print(f"  LLM identified {len(identified_datasets)} dataset(s) for {article_id}.")
        for item in identified_datasets:
            final_results.append({
                "article_id": article_id,
                "dataset_id": item.get("dataset_id", "Unknown"), # Use .get() for safety
                "classification_label": item.get("classification", "Uncertain_LLM")
            })


# --- 5. Results & Output ---

print("\n--- Final Results ---")
if final_results:
    results_df = pd.DataFrame(final_results)
    print(results_df.head(10)) # Print first 10 rows
    
    # Save to CSV
    results_df.to_csv(FINAL_RESULTS_CSV_PATH, index=False)
    print(f"\nResults saved to: {FINAL_RESULTS_CSV_PATH}")
else:
    print("No results generated.")

print("\nProcessing complete, Jim!")