# Fine-tune Gemma3 270M for MR Abstract Data Extraction

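This notebook fine-tunes a Gemma3 270M model using Unsloth to replicate the OpenAI-based extraction logic for Mendelian randomization (MR) abstracts. Note that the configuration below currently loads `unsloth/Gemma-2-2b` as a placeholder checkpoint; change `CONFIG["model_name"]` to switch to the 270M model.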

## Setup and Installation

In [None]:
# Install required packages
!pip install unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
!pip install --no-deps "xformers<0.0.27" trl peft accelerate bitsandbytes

In [None]:
import json
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Any
import numpy as np
import torch
import pandas as pd
from datasets import Dataset as HFDataset

# Add local packages to path
sys.path.insert(0, "../../src/local_funcs/src")
sys.path.insert(0, "../../src/yiutils/src")

from local_funcs import prompt_funcs, schema_funcs
from yiutils.failsafe import safe_json_loads

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

## Configuration

In [None]:
# Fine-tuning configuration
# Every tunable setting lives in this single mapping; the later cells read
# their values from it by key.
CONFIG = dict(
    # --- Base model / loading ---
    model_name="unsloth/Gemma-2-2b",  # Will use 270M when available
    max_seq_length=2048,
    dtype="bfloat16",
    load_in_4bit=True,
    # --- LoRA config ---
    r=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    # --- Training config ---
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=5,
    max_steps=60,
    learning_rate=2e-4,
    fp16=False,
    bf16=True,
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
)

print(f"Configuration: {CONFIG}")

## Data Preparation

Load existing OpenAI extraction results to create training data.

In [None]:
def load_training_data(data_dir: Path) -> Tuple[List[str], List[str], List[str]]:
    """Load training data from existing OpenAI extraction results.

    Recursively scans ``data_dir`` for files matching ``*results*.json``. Each
    file is expected to hold a JSON list; every dict in that list carrying the
    keys ``abstract``, ``extracted_metadata`` and ``extracted_results``
    contributes one training example. Anything else (non-list files, non-dict
    items, dicts missing a key) is skipped.

    Args:
        data_dir: Root directory to search. A directory with no matching
            files simply yields empty lists.

    Returns:
        Three parallel lists ``(abstracts, metadata_extractions,
        results_extractions)``. The abstracts are returned as stored; the two
        extraction lists hold the corresponding values re-serialized with
        ``json.dumps``.
    """
    logger.info(f"Loading training data from {data_dir}")

    required_keys = ("abstract", "extracted_metadata", "extracted_results")
    abstracts: List[str] = []
    metadata_extractions: List[str] = []
    results_extractions: List[str] = []

    # Sort the matches so the example order does not depend on the order in
    # which the filesystem happens to enumerate files.
    for result_file in sorted(data_dir.glob("**/*results*.json")):
        try:
            # Explicit UTF-8: the platform default encoding may not be able to
            # decode non-ASCII characters in the abstracts.
            with open(result_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both malformed JSON and undecodable bytes.
            # One unreadable file must not abort the whole load.
            logger.warning(f"Failed to load {result_file}: {e}")
            continue

        if not isinstance(data, list):
            continue

        for item in data:
            # Guard against non-dict entries: `in` on a string is a substring
            # test and on a number raises TypeError, which previously caused
            # the remainder of the file to be dropped.
            if not isinstance(item, dict):
                continue
            if all(key in item for key in required_keys):
                abstracts.append(item["abstract"])
                metadata_extractions.append(json.dumps(item["extracted_metadata"]))
                results_extractions.append(json.dumps(item["extracted_results"]))

    logger.info(f"Loaded {len(abstracts)} training examples")
    return abstracts, metadata_extractions, results_extractions


# Load data
# Pull every usable (abstract, metadata, results) triple from the
# intermediate data directory and show a short preview of the first one.
data_dir = Path("../../data/intermediate")
abstracts, metadata_extractions, results_extractions = load_training_data(data_dir)

print(f"Found {len(abstracts)} training examples")
if not abstracts:
    print("No training data found. Please run OpenAI extraction first.")
else:
    # Previews are truncated to 200 characters to keep the cell output short
    print(f"Example abstract: {abstracts[0][:200]}...")
    print(f"Example metadata: {metadata_extractions[0][:200]}...")

In [None]:
def prepare_training_examples(abstracts: List[str], metadata_extractions: List[str],
                              results_extractions: List[str]) -> List[Dict[str, str]]:
    """Turn parallel abstract/extraction lists into instruction-style examples.

    Each abstract produces two consecutive examples: a metadata-extraction
    task followed by an MR-results-extraction task. The three inputs are
    walked in lockstep with ``zip``, so iteration stops at the shortest list.

    Returns:
        A list of ``{"instruction": ..., "output": ...}`` dicts, two per abstract.
    """
    # Fixed task descriptions placed in front of the abstract in each prompt
    metadata_task = (
        "Extract metadata from the following Mendelian randomization abstract. "
        "Return valid JSON with fields: title, authors, journal, year, pmid, doi, abstract.\n\n"
    )
    results_task = (
        "Extract MR results from the following Mendelian randomization abstract. "
        "Return valid JSON with fields: exposures, outcomes, mr_methods, effect_estimates, "
        "pvalues, confidence_intervals, sample_sizes, populations, study_design.\n\n"
    )

    examples = []
    for abstract, metadata, results in zip(abstracts, metadata_extractions, results_extractions):
        # The abstract section is identical in both prompts
        abstract_section = f"Abstract: {abstract}\n\n"
        examples.extend([
            {
                "instruction": metadata_task + abstract_section + "Metadata JSON:",
                "output": metadata,
            },
            {
                "instruction": results_task + abstract_section + "Results JSON:",
                "output": results,
            },
        ])

    return examples


def format_prompts(examples: List[Dict[str, str]]) -> List[str]:
    """Render instruction/output pairs as single-turn Gemma-format strings.

    Each example becomes one string: a user turn holding the instruction,
    followed by a model turn holding the expected output and a closing
    ``<eos>`` marker.
    """
    # Fixed pieces of the Gemma format wrapped around the two variable fields
    user_open = "<bos><start_of_turn>user\n"
    model_open = "<end_of_turn>\n<start_of_turn>model\n"
    closing = "<eos>"

    return [
        f"{user_open}{entry['instruction']}{model_open}{entry['output']}{closing}"
        for entry in examples
    ]


# Prepare training examples
# When nothing was loaded, `train_dataset` is deliberately left undefined;
# the trainer cell further down checks for its presence.
if not abstracts:
    print("Skipping training data preparation - no data found")
else:
    examples = prepare_training_examples(abstracts, metadata_extractions, results_extractions)
    formatted_prompts = format_prompts(examples)

    # Create HuggingFace dataset with a single "text" column
    train_dataset = HFDataset.from_dict({"text": formatted_prompts})

    print(f"Prepared {len(formatted_prompts)} training examples")
    print(f"Example formatted prompt: {formatted_prompts[0][:300]}...")

## Model Loading

Load the base model named in `CONFIG["model_name"]` with Unsloth and attach LoRA adapters for efficient fine-tuning.

In [None]:
from unsloth import FastLanguageModel
import torch

def load_unsloth_model(config: dict):
    """Load the base model named in ``config`` via Unsloth and attach LoRA adapters.

    Args:
        config: Mapping providing ``model_name``, ``max_seq_length``, ``dtype``,
            ``load_in_4bit`` and the LoRA settings (``r``, ``target_modules``,
            ``lora_alpha``, ``lora_dropout``, ``bias``,
            ``use_gradient_checkpointing``, ``random_state``).

    Returns:
        ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped model.
    """
    base_kwargs = {
        "model_name": config["model_name"],
        "max_seq_length": config["max_seq_length"],
        # The dtype is stored as a string; resolve it to the torch attribute
        # of that name, or None when torch has no such attribute.
        "dtype": getattr(torch, config["dtype"], None),
        "load_in_4bit": config["load_in_4bit"],
    }
    model, tokenizer = FastLanguageModel.from_pretrained(**base_kwargs)

    # These config entries are forwarded to get_peft_model under the same names
    lora_keys = (
        "r",
        "target_modules",
        "lora_alpha",
        "lora_dropout",
        "bias",
        "use_gradient_checkpointing",
        "random_state",
    )
    model = FastLanguageModel.get_peft_model(
        model,
        **{key: config[key] for key in lora_keys},
        use_rslora=False,
        loftq_config=None,
    )

    return model, tokenizer


# Load model
# Builds the model/tokenizer pair from CONFIG and reports the concrete classes
# that came back.
logger.info("Loading Unsloth model...")
model, tokenizer = load_unsloth_model(CONFIG)
print(
    "Model loaded successfully!",
    f"Model: {model.__class__.__name__}",
    f"Tokenizer: {tokenizer.__class__.__name__}",
    sep="\n",
)

## Training Setup

In [None]:
from transformers import TrainingArguments
from trl import SFTTrainer

def create_trainer(model, tokenizer, train_dataset, config: dict):
    """Build an ``SFTTrainer`` for ``train_dataset`` from the values in ``config``.

    Args:
        model: Model to fine-tune.
        tokenizer: Tokenizer handed to the trainer.
        train_dataset: Dataset whose ``"text"`` field holds the training strings.
        config: Mapping with the training hyperparameters and ``max_seq_length``.

    Returns:
        The configured ``SFTTrainer``. Its output directory is fixed to
        ``./gemma3-270m-mr-extraction``.
    """
    # These config entries map one-to-one onto TrainingArguments keywords
    forwarded_keys = (
        "per_device_train_batch_size",
        "gradient_accumulation_steps",
        "warmup_steps",
        "max_steps",
        "learning_rate",
        "fp16",
        "bf16",
        "logging_steps",
        "optim",
        "weight_decay",
        "lr_scheduler_type",
        "seed",
    )
    training_args = TrainingArguments(
        **{key: config[key] for key in forwarded_keys},
        output_dir="./gemma3-270m-mr-extraction",
    )

    return SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=config["max_seq_length"],
        dataset_num_proc=2,
        packing=False,
        args=training_args,
    )


# Create trainer
# `train_dataset` only exists when the data-preparation cell found examples,
# so its presence decides whether a trainer can be built.
if 'train_dataset' not in locals():
    print("Skipping trainer creation - no training dataset")
else:
    trainer = create_trainer(model, tokenizer, train_dataset, CONFIG)
    print("Trainer created successfully!")

## Training

Start the fine-tuning process.

In [None]:
# Train model
# Runs only when the previous cell managed to create `trainer`.
if 'trainer' not in locals():
    print("Skipping training - no trainer available")
else:
    logger.info("Starting training...")
    trainer.train()
    print("Training completed!")

## Evaluation

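Test the fine-tuned model on a few sample abstracts. Note that these abstracts are taken from the end of the list that was also used for training, so the outputs are a sanity check rather than a held-out evaluation.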

In [None]:
def _generate_model_reply(model, tokenizer, instruction: str, max_new_tokens: int = 512) -> str:
    """Send one single-turn instruction to the model and return only the generated text."""
    prompt = f"<bos><start_of_turn>user\n{instruction}<end_of_turn>\n<start_of_turn>model\n"

    # The prompt already begins with a literal <bos>. Gemma tokenizers normally
    # prepend a BOS token themselves, so automatic special-token insertion is
    # turned off here to avoid a doubled BOS.
    inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)  # follow the model's device instead of hard-coding "cuda"

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        temperature=0.1,
        do_sample=True,
    )

    # generate() returns the prompt tokens followed by the continuation.
    # Decode only the continuation: splitting the decoded text on
    # "<start_of_turn>model\n" is unreliable, because skip_special_tokens=True
    # can strip that marker and leave the whole prompt in the result.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


def evaluate_model(model, tokenizer, test_abstracts: List[str], max_examples: int = 3):
    """Evaluate the fine-tuned model on test examples.

    For each of the first ``max_examples`` abstracts, the model is prompted
    twice (metadata extraction, then MR results extraction). A 200-character
    preview of each reply is printed and the full replies are collected.

    Args:
        model: Fine-tuned model; it is switched to Unsloth inference mode.
        tokenizer: Tokenizer matching ``model``.
        test_abstracts: Abstract texts to run through the model.
        max_examples: Maximum number of abstracts to evaluate (default 3).

    Returns:
        A list with one dict per evaluated abstract, with keys ``abstract``
        (first 200 characters plus ``"..."``), ``metadata_extraction`` and
        ``results_extraction``.
    """
    logger.info("Evaluating model...")

    FastLanguageModel.for_inference(model)

    selected = test_abstracts[:max_examples]
    results = []

    for position, abstract in enumerate(selected, start=1):
        # Report against the real number of examples, which can be below the cap
        print(f"\nEvaluating example {position}/{len(selected)}...")

        # Test metadata extraction
        metadata_instruction = (
            "Extract metadata from the following Mendelian randomization abstract. "
            "Return valid JSON with fields: title, authors, journal, year, pmid, doi, abstract.\n\n"
            f"Abstract: {abstract}\n\n"
            "Metadata JSON:"
        )
        metadata_extracted = _generate_model_reply(model, tokenizer, metadata_instruction)
        print(f"Metadata extraction: {metadata_extracted[:200]}...")

        # Test results extraction
        results_instruction = (
            "Extract MR results from the following Mendelian randomization abstract. "
            "Return valid JSON with fields: exposures, outcomes, mr_methods, effect_estimates, "
            "pvalues, confidence_intervals, sample_sizes, populations, study_design.\n\n"
            f"Abstract: {abstract}\n\n"
            "Results JSON:"
        )
        results_extracted = _generate_model_reply(model, tokenizer, results_instruction)
        print(f"Results extraction: {results_extracted[:200]}...")

        results.append({
            "abstract": abstract[:200] + "...",
            "metadata_extraction": metadata_extracted,
            "results_extraction": results_extracted,
        })

    return results


# Evaluate model
# NOTE: the evaluation abstracts are sliced from the same `abstracts` list that
# the training examples were built from, so this is a sanity check, not a
# held-out test.
if 'abstracts' not in locals() or len(abstracts) == 0:
    print("Skipping evaluation - no test data available")
else:
    test_abstracts = abstracts[-5:]  # Use last 5 as test
    evaluation_results = evaluate_model(model, tokenizer, test_abstracts)

    # Save results next to the model artifacts, creating the folder if needed
    output_dir = Path("./gemma3-270m-mr-extraction")
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "evaluation_results.json", "w") as f:
        json.dump(evaluation_results, f, indent=2)

    print(f"\nEvaluation results saved to: {output_dir / 'evaluation_results.json'}")

## Model Export

Save the fine-tuned model for future use.

In [None]:
def save_model(model, tokenizer, output_dir: Path):
    """Write the model and tokenizer to ``output_dir``, plus a best-effort GGUF export.

    Both objects are saved with their own ``save_pretrained`` method. A GGUF
    export (quantization ``q4_k_m``) is then attempted into
    ``output_dir / "gguf"``; if that attempt raises, the error is logged as a
    warning and the function still returns normally.
    """
    logger.info(f"Saving model to {output_dir}")

    # Same call on both artifacts, model first
    for artifact in (model, tokenizer):
        artifact.save_pretrained(output_dir)

    # Also save in GGUF format for inference (optional: failure is only logged)
    try:
        gguf_dir = output_dir / "gguf"
        model.save_pretrained_gguf(gguf_dir, tokenizer, quantization_method="q4_k_m")
        logger.info(f"Saved GGUF model to {output_dir / 'gguf'}")
    except Exception as e:
        logger.warning(f"Failed to save GGUF model: {e}")


# Save model
# Writes into the same directory the trainer uses as its output directory.
output_dir = Path("./gemma3-270m-mr-extraction")
save_model(model=model, tokenizer=tokenizer, output_dir=output_dir)
print(f"Model saved to: {output_dir}")

## Inference Test

Test the fine-tuned model (the in-memory instance, not one reloaded from disk) with a sample abstract.

In [None]:
# Test inference with a sample abstract
sample_abstract = """
Background: The relationship between body mass index (BMI) and cardiovascular disease 
has been extensively studied, but the causal nature remains unclear. We used Mendelian 
randomization to investigate the causal effect of BMI on coronary artery disease.

Methods: We used genetic variants associated with BMI as instrumental variables in a 
two-sample Mendelian randomization analysis. GWAS summary statistics were obtained 
from the GIANT consortium (n=339,224) for BMI and CARDIoGRAMplusC4D (n=184,305) 
for coronary artery disease.

Results: Higher BMI was causally associated with increased risk of coronary artery 
disease (OR = 1.27, 95% CI: 1.15-1.40, P = 2.3e-6). The effect remained significant 
after adjustment for potential confounders.

Conclusions: This study provides evidence for a causal relationship between higher 
BMI and increased coronary artery disease risk.
"""

print("Testing inference on sample abstract...")
print(f"Sample abstract: {sample_abstract[:200]}...")

# Test metadata extraction
metadata_instruction = (
    "Extract metadata from the following Mendelian randomization abstract. "
    "Return valid JSON with fields: title, authors, journal, year, pmid, doi, abstract.\n\n"
    f"Abstract: {sample_abstract}\n\n"
    "Metadata JSON:"
)

FastLanguageModel.for_inference(model)

# The prompt already begins with a literal <bos>. Gemma tokenizers normally
# prepend a BOS token themselves, so automatic special-token insertion is
# turned off here to avoid a doubled BOS.
inputs = tokenizer(
    [f"<bos><start_of_turn>user\n{metadata_instruction}<end_of_turn>\n<start_of_turn>model\n"],
    return_tensors="pt",
    add_special_tokens=False,
).to(model.device)  # follow the model's device instead of hard-coding "cuda"

outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    use_cache=True,
    temperature=0.1,
    do_sample=True,
)

# generate() returns the prompt tokens followed by the continuation. Decode
# only the continuation: splitting the decoded text on "<start_of_turn>model\n"
# is unreliable, because skip_special_tokens=True can strip that marker and
# leave the whole prompt in the result.
prompt_length = inputs["input_ids"].shape[1]
response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
metadata_result = response

print(f"\nMetadata extraction result:\n{metadata_result}")

print("\n" + "="*50)
# Only claim success when a trainer was actually created; earlier cells skip
# training when no data was found.
if 'trainer' in locals():
    print("Fine-tuning completed successfully!")
else:
    print("Inference test finished (fine-tuning was skipped - no trainer available).")
print(f"Model saved to: {output_dir}")
print("Use the saved model for inference on new MR abstracts.")