In [2]:
# First, verify we're running in Colab
try:
    import google.colab
    IN_COLAB = True
except:
    IN_COLAB = False

if not IN_COLAB:
    print("WARNING: Not running in Google Colab!")
else:
    print("Running in Google Colab")

    # Install nvidia-smi if needed
    !apt-get update -qq && apt-get install -qq nvidia-utils-470 > /dev/null

    # Check CUDA availability
    import torch
    print("\nPyTorch CUDA Settings:")
    print("CUDA Available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("Current CUDA Device:", torch.cuda.current_device())
        print("CUDA Device Name:", torch.cuda.get_device_name())
        print("CUDA Device Count:", torch.cuda.device_count())
        print("CUDA Version:", torch.version.cuda)

    print("\nGPU Information:")
    !nvidia-smi

Running in Google Colab
W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)

PyTorch CUDA Settings:
CUDA Available: True
Current CUDA Device: 0
CUDA Device Name: NVIDIA A100-SXM4-80GB
CUDA Device Count: 1
CUDA Version: 12.6

GPU Information:
Wed Oct 15 21:44:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB         

In [3]:
# Configure GPU settings for Colab
if IN_COLAB and torch.cuda.is_available():
    # Clear any existing allocations
    torch.cuda.empty_cache()

    # Set device
    device = torch.device("cuda")

    # Print initial memory state
    print("Initial GPU Memory State:")
    print(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.1f} GB")
    print(f"Cached: {torch.cuda.memory_reserved() / 1024**3:.1f} GB")

    # Optional: Set memory growth
    if hasattr(torch.cuda, 'set_per_process_memory_fraction'):
        torch.cuda.set_per_process_memory_fraction(0.9)  # Use up to 90% of GPU memory

    print("\nGPU is ready for training!")
else:
    print("WARNING: GPU not available! Please make sure to:")
    print("1. Runtime → Change runtime type")
    print("2. Set 'Hardware accelerator' to 'GPU'")
    print("3. Set 'GPU type' to 'A100' (if available)")
    print("4. Runtime → Restart runtime")

Initial GPU Memory State:
Total GPU Memory: 79.3 GB
Allocated: 0.0 GB
Cached: 0.0 GB

GPU is ready for training!


In [5]:
!pip install -q --upgrade pip
!pip install -q torch>=2.2.0 \
    transformers>=4.37.0 \
    datasets \
    accelerate>=0.27.0 \
    peft>=0.7.0 \
    bitsandbytes>=0.41.1 \
    trl>=0.7.4 \
    evaluate \
    wandb \
    scikit-learn \
    pandas \
    numpy \
    tqdm \
    sentencepiece \
    jsonlines

# Import and configure warnings
import warnings
warnings.filterwarnings('ignore')

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m86.4 MB/s[0m eta [36m0:00:00[0m
[?25h

## Setup Environment

In [8]:
import os
import psutil
import GPUtil
from IPython.display import HTML, display
import ipywidgets as widgets

def get_gpu_memory():
    """Get GPU memory usage."""
    try:
        gpu = GPUtil.getGPUs()[0]
        return f"GPU Memory: {gpu.memoryUsed:.0f}MB / {gpu.memoryTotal:.0f}MB"
    except:
        return "No GPU found"

def get_ram_usage():
    """Get RAM usage."""
    process = psutil.Process(os.getpid())
    return f"RAM Usage: {process.memory_info().rss / 1024 / 1024:.0f}MB"

def create_progress_bar():
    """Create a progress bar widget."""
    return widgets.FloatProgress(
        value=0,
        min=0,
        max=100,
        description='Progress:',
        bar_style='info',
        orientation='horizontal'
    )

# Create output widgets for monitoring
memory_widget = widgets.HTML(value="Memory Usage: Initializing...")
gpu_widget = widgets.HTML(value="GPU: Initializing...")
display(memory_widget, gpu_widget)

# Update monitoring info
def update_monitoring():
    memory_widget.value = get_ram_usage()
    gpu_widget.value = get_gpu_memory()

# Create checkpoint directory
!mkdir -p checkpoints

HTML(value='Memory Usage: Initializing...')

HTML(value='GPU: Initializing...')

## Dataset Classes

In [9]:
import torch
from torch.utils.data import Dataset
from typing import Dict, List, Optional, Any
import json

class PubMedQADataset(Dataset):
    def __init__(self, data: Dict[str, Dict], tokenizer):
        # Convert dict to list of tuples (id, data)
        self.data = list(data.items())
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        _, item = self.data[idx]

        # Format: Question + Contexts + Final Decision + Long Answer
        text = f"Question: {item['QUESTION']}\n\n"
        text += "Context:\n"
        for ctx in item['CONTEXTS']:
            text += f"{ctx}\n"
        text += f"\nAnswer: {item['final_decision']}\n"
        text += f"Explanation: {item['LONG_ANSWER']}"

        encoded = self.tokenizer(
            text,
            max_length=MAX_SEQ_LEN,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": encoded["input_ids"].squeeze(),
            "attention_mask": encoded["attention_mask"].squeeze(),
            "labels": encoded["input_ids"].squeeze()
        }

class MedMCQADataset(Dataset):
    def __init__(self, data: Dict[str, Dict], tokenizer):
        self.data = list(data.values())  # Convert dict to list of items
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # MedMCQA format
        text = f"Question: {item['question']}\nExplanation: {item['exp']}"

        encoded = self.tokenizer(
            text,
            max_length=MAX_SEQ_LEN,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": encoded["input_ids"].squeeze(),
            "attention_mask": encoded["attention_mask"].squeeze(),
            "labels": encoded["input_ids"].squeeze()
        }

class MedQADataset(Dataset):
    def __init__(self, data: Dict[str, Dict], tokenizer):
        # Convert dict to list of tuples (id, data)
        self.data = list(data.items())
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        _, item = self.data[idx]

        # Format: Question + Options + Answer
        text = f"Question: {item['question']}\n\n"
        text += "Options:\n"
        for opt_key, opt_value in item['options'].items():
            text += f"{opt_key}) {opt_value}\n"
        text += f"\nAnswer: {item['options'][item['answer_idx']]}"

        encoded = self.tokenizer(
            text,
            max_length=MAX_SEQ_LEN,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": encoded["input_ids"].squeeze(),
            "attention_mask": encoded["attention_mask"].squeeze(),
            "labels": encoded["input_ids"].squeeze()
        }

## Training Configuration

In [10]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, PeftModel
import torch

# Constants optimized for A100
MAX_SAMPLES_PER_STAGE = 50
MAX_SEQ_LEN = 512
EPOCHS = 2
LEARNING_RATE = 5e-5
MICRO_BATCH_SIZE = 4  # Increased for A100
GRADIENT_ACCUMULATION_STEPS = 2  # Adjusted for larger batch size
WARMUP_RATIO = 0.05
SEED = 7
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"

# LoRA Configuration
LORA_CONFIG = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Training stages configuration
STAGES = [
    {"name": "pubmedqa", "data_path": "/content/step1.json"},
    {"name": "medmcqa", "data_path": "/content/step1_medmcqa.json"},
    {"name": "medqa", "data_path": "/content/step1_medqa.json"}
]

def setup_model_and_tokenizer():
    """Initialize model with QLoRA configuration optimized for A100."""
    # Quantization config optimized for A100
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,  # Changed to float16 for A100
        bnb_4bit_use_double_quant=True,
    )

    print("Loading model...")
    # Load model with quantization
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)

    print("Loading tokenizer...")
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def load_dataset(stage: Dict[str, str], tokenizer, max_samples: int) -> Dataset:
    file_path = stage["data_path"]
    stage_name = stage["name"]

    print(f"Loading {stage_name} dataset...")
    # Load raw data
    with open(file_path, 'r') as f:
        data = json.load(f)

    # Create appropriate dataset based on stage
    if stage_name == "pubmedqa":
        return PubMedQADataset(data, tokenizer)
    elif stage_name == "medmcqa":
        return MedMCQADataset(data, tokenizer)
    elif stage_name == "medqa":
        return MedQADataset(data, tokenizer)
    else:
        raise ValueError(f"Unknown stage: {stage_name}")

## Training Pipeline

In [12]:
from transformers import Trainer, TrainerCallback
import logging
from tqdm.notebook import tqdm

class MonitorCallback(TrainerCallback):
    """Custom callback for monitoring training progress."""
    def __init__(self, progress_bar):
        self.progress_bar = progress_bar

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 10 == 0:
            # Update progress bar
            progress = (state.global_step / state.max_steps) * 100
            self.progress_bar.value = progress
            # Update monitoring
            update_monitoring()

def train_stage(
    stage: Dict[str, str],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prev_adapter_path: Optional[str] = None
) -> str:
    """Train a single stage and return the path to saved adapter."""
    stage_name = stage["name"]
    output_dir = f"checkpoints/{stage_name}"

    print(f"\nStarting {stage_name} stage...")

    # Load dataset for this stage
    train_dataset = load_dataset(stage, tokenizer, MAX_SAMPLES_PER_STAGE)

    # Apply LoRA if not already applied
    if not isinstance(model, PeftModel):
        print("Applying LoRA...")
        model = get_peft_model(model, LORA_CONFIG)
    elif prev_adapter_path:
        print(f"Loading adapter from {prev_adapter_path}...")
        model.load_adapter(prev_adapter_path, adapter_name="default")

    # Create progress bar
    progress_bar = create_progress_bar()
    display(progress_bar)

    # Training arguments optimized for A100
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        warmup_ratio=WARMUP_RATIO,
        logging_steps=1,
        save_strategy="epoch",
        fp16=True,  # Enable mixed precision training
        gradient_checkpointing=True,  # Enable gradient checkpointing
        seed=SEED,
        report_to="none",  # Disable wandb logging
    )

    # Initialize trainer with monitoring
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        callbacks=[MonitorCallback(progress_bar)]
    )

    print(f"Training {stage_name}...")
    trainer.train()

    print(f"Saving adapter to {output_dir}...")
    model.save_pretrained(output_dir)

    return output_dir

def main():
    """Run sequential fine-tuning pipeline."""
    logging.basicConfig(level=logging.INFO)

    print("Setting up model and tokenizer...")
    model, tokenizer = setup_model_and_tokenizer()
    prev_adapter_path = None

    # Train each stage sequentially
    for stage in STAGES:
        adapter_path = train_stage(stage, model, tokenizer, prev_adapter_path)
        prev_adapter_path = adapter_path
        print(f"Completed {stage['name']} stage. Adapter saved to: {adapter_path}\n")

In [14]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [15]:
if __name__ == "__main__":
    main()

Setting up model and tokenizer...
Loading model...


config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

Loading tokenizer...


tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]


Starting pubmedqa stage...
Loading pubmedqa dataset...
Applying LoRA...


FloatProgress(value=0.0, bar_style='info', description='Progress:')

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Training pubmedqa...


Step,Training Loss
1,1.7194
2,1.9597
3,1.6651
4,1.4808
5,1.4275
6,1.585
7,1.2945
8,1.3692
9,1.5666
10,1.4232


Saving adapter to checkpoints/pubmedqa...
Completed pubmedqa stage. Adapter saved to: checkpoints/pubmedqa


Starting medmcqa stage...
Loading medmcqa dataset...
Applying LoRA...


FloatProgress(value=0.0, bar_style='info', description='Progress:')

Training medmcqa...


Step,Training Loss
1,3.2839
2,3.4755
3,3.5223
4,4.3192
5,2.3706
6,2.1449
7,1.748
8,1.3489
9,1.4447
10,1.2101


Saving adapter to checkpoints/medmcqa...
Completed medmcqa stage. Adapter saved to: checkpoints/medmcqa


Starting medqa stage...
Loading medqa dataset...
Applying LoRA...


FloatProgress(value=0.0, bar_style='info', description='Progress:')

Training medqa...


Step,Training Loss
1,2.1215
2,3.7756
3,2.3192
4,2.5197
5,2.6633
6,1.4827
7,0.9635
8,1.1718
9,0.9929
10,0.8447


Step,Training Loss
1,2.1215
2,3.7756
3,2.3192
4,2.5197
5,2.6633
6,1.4827
7,0.9635
8,1.1718
9,0.9929
10,0.8447


Saving adapter to checkpoints/medqa...
Completed medqa stage. Adapter saved to: checkpoints/medqa



## Model Evaluation

In [17]:
from transformers import GenerationConfig
import random
import pandas as pd
from IPython.display import display, HTML

def format_qa_pair(question, predicted, actual):
    """Format Q&A pair for display"""
    return f"""
    <div style="margin-bottom: 20px; padding: 10px; border: 1px solid #ddd; border-radius: 5px;">
        <p><strong>Question:</strong> {question}</p>
        <p><strong>Predicted:</strong> {predicted}</p>
        <p><strong>Actual:</strong> {actual}</p>
    </div>
    """

def evaluate_model(model_path: str, dataset_name: str, tokenizer, n_samples=5):
    """Evaluate model on a specific dataset"""
    print(f"\nEvaluating {dataset_name}...")

    # Load the model with adapter
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(model, model_path)
    model.eval()

    # Load dataset
    stage = next(s for s in STAGES if s["name"] == dataset_name)
    with open(stage["data_path"], 'r') as f:
        data = json.load(f)

    # Sample random examples
    if isinstance(data, dict):
        samples = random.sample(list(data.items()), min(n_samples, len(data)))
    else:
        samples = random.sample(data, min(n_samples, len(data)))

    generation_config = GenerationConfig(
        max_length=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )

    results = []
    for idx, sample in enumerate(samples):
        if isinstance(sample, tuple):
            _, sample = sample

        # Format input based on dataset
        if dataset_name == "pubmedqa":
            question = sample["QUESTION"]
            context = "\n".join(sample["CONTEXTS"])
            actual = f"{sample['final_decision']}\nExplanation: {sample['LONG_ANSWER']}"
            prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        elif dataset_name == "medmcqa":
            question = sample["question"]
            actual = sample["exp"]
            prompt = f"Question: {question}\nExplanation:"
        else:  # medqa
            question = sample["question"]
            options = "\n".join([f"{k}) {v}" for k, v in sample["options"].items()])
            actual = sample["options"][sample["answer_idx"]]
            prompt = f"Question: {question}\nOptions:\n{options}\nAnswer:"

        # Generate prediction
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        outputs = model.generate(
            **inputs,
            generation_config=generation_config,
        )
        predicted = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Remove the prompt from prediction
        predicted = predicted[len(prompt):].strip()

        results.append({
            "question": question,
            "predicted": predicted,
            "actual": actual
        })

    # Display results
    html_output = "<div style='max-width: 800px;'>"
    for result in results:
        html_output += format_qa_pair(
            result["question"],
            result["predicted"],
            result["actual"]
        )
    html_output += "</div>"

    display(HTML(html_output))

    return results

# Load tokenizer first
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token

# Evaluate each stage
stages_to_evaluate = [
    ("pubmedqa", "checkpoints/pubmedqa/checkpoint-14"),  # Using checkpoint-14 based on your folder structure
    ("medmcqa", "checkpoints/medmcqa/checkpoint-14"),
    ("medqa", "checkpoints/medqa/checkpoint-14")
]

print("\nStarting evaluation...")
for dataset_name, model_path in stages_to_evaluate:
    try:
        print(f"\nEvaluating {dataset_name} using {model_path}")
        results = evaluate_model(model_path, dataset_name, tokenizer)
    except Exception as e:
        print(f"Error evaluating {dataset_name}: {str(e)}")
        import traceback
        traceback.print_exc()

Loading tokenizer...

Starting evaluation...

Evaluating pubmedqa using checkpoints/pubmedqa/checkpoint-14

Evaluating pubmedqa...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Evaluating medmcqa using checkpoints/medmcqa/checkpoint-14

Evaluating medmcqa...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Evaluating medqa using checkpoints/medqa/checkpoint-14

Evaluating medqa...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


In [18]:
!pip install -q sentence-transformers rouge-score nltk

  Preparing metadata (setup.py) ... [?25l[?25hdone
[33m  DEPRECATION: Building 'rouge-score' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'rouge-score'. Discussion can be found at https://github.com/pypa/pip/issues/6334[0m[33m
[0m  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone


In [21]:
from transformers import GenerationConfig
import random
import pandas as pd
import numpy as np
from IPython.display import display, HTML
import torch
from sentence_transformers import SentenceTransformer
from rouge_score import rouge_scorer
import nltk
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Download required NLTK data
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True) # Add this line to download punkt_tab

# Initialize sentence transformer for semantic similarity
semantic_model = SentenceTransformer('pritamdeka/S-PubMedBert-MS-MARCO')
rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

def safe_divide(a, b):
    """Safely divide two numbers, return 0 if denominator is 0."""
    try:
        return a / b if b != 0 else 0
    except:
        return 0

def calculate_metrics(predicted, actual):
    """Calculate various metrics between predicted and actual text."""
    metrics = {}

    try:
        # Basic text statistics
        metrics['pred_length'] = len(predicted.split())
        metrics['actual_length'] = len(actual.split())
        metrics['length_ratio'] = safe_divide(metrics['pred_length'], metrics['actual_length'])
    except Exception as e:
        print(f"Error calculating text statistics: {str(e)}")
        metrics['pred_length'] = 0
        metrics['actual_length'] = 0
        metrics['length_ratio'] = 0.0

    # Semantic Similarity
    try:
        pred_emb = semantic_model.encode(predicted, convert_to_tensor=True)
        actual_emb = semantic_model.encode(actual, convert_to_tensor=True)
        similarity = torch.nn.functional.cosine_similarity(pred_emb.unsqueeze(0), actual_emb.unsqueeze(0)).item()
        metrics['semantic_similarity'] = similarity
    except Exception as e:
        print(f"Error calculating semantic similarity: {str(e)}")
        metrics['semantic_similarity'] = 0.0

    # Word overlap metrics
    try:
        pred_words = set(word_tokenize(predicted.lower()))
        actual_words = set(word_tokenize(actual.lower()))

        intersection = len(pred_words & actual_words)
        union = len(pred_words | actual_words)

        metrics['word_overlap'] = safe_divide(intersection, union)
        metrics['word_coverage'] = safe_divide(intersection, len(actual_words))
    except Exception as e:
        print(f"Error calculating word overlap: {str(e)}")
        metrics['word_overlap'] = 0.0
        metrics['word_coverage'] = 0.0

    # ROUGE Scores
    try:
        rouge_scores = rouge.score(predicted, actual)
        metrics['rouge1_f1'] = rouge_scores['rouge1'].fmeasure
        metrics['rouge2_f1'] = rouge_scores['rouge2'].fmeasure
        metrics['rougeL_f1'] = rouge_scores['rougeL'].fmeasure

        # Add precision and recall for ROUGE-1
        metrics['rouge1_precision'] = rouge_scores['rouge1'].precision
        metrics['rouge1_recall'] = rouge_scores['rouge1'].recall
    except Exception as e:
        print(f"Error calculating ROUGE scores: {str(e)}")
        metrics['rouge1_f1'] = 0.0
        metrics['rouge2_f1'] = 0.0
        metrics['rougeL_f1'] = 0.0
        metrics['rouge1_precision'] = 0.0
        metrics['rouge1_recall'] = 0.0

    # BLEU Score
    try:
        pred_tokens = word_tokenize(predicted.lower())
        actual_tokens = word_tokenize(actual.lower())
        bleu = sentence_bleu([actual_tokens], pred_tokens, weights=(0.25, 0.25, 0.25, 0.25))
        metrics['bleu'] = bleu
    except Exception as e:
        print(f"Error calculating BLEU score: {str(e)}")
        metrics['bleu'] = 0.0

    # For Yes/No/Maybe questions (PubMedQA)
    pred_lower = predicted.lower()
    actual_lower = actual.lower()

    # Extract first word as answer
    pred_answer = next((word for word in ['yes', 'no', 'maybe'] if word in pred_lower.split()), 'unknown')
    actual_answer = next((word for word in ['yes', 'no', 'maybe'] if word in actual_lower.split()), 'unknown')

    metrics['exact_match'] = 1.0 if pred_answer == actual_answer else 0.0

    return metrics

def format_qa_pair(question, predicted, actual, metrics):
    """Format Q&A pair and metrics for display"""
    metrics_html = "".join([
        f"<tr><td>{k}</td><td>{v:.3f}</td></tr>"
        for k, v in metrics.items()
    ])

    return f"""
    <div style="margin-bottom: 20px; padding: 10px; border: 1px solid #ddd; border-radius: 5px;">
        <p><strong>Question:</strong> {question}</p>
        <p><strong>Predicted:</strong> {predicted}</p>
        <p><strong>Actual:</strong> {actual}</p>
        <div style="margin-top: 10px;">
            <strong>Metrics:</strong>
            <table style="margin-left: 20px;">
                {metrics_html}
            </table>
        </div>
    </div>
    """

def evaluate_model(model_path: str, dataset_name: str, tokenizer, n_samples=5):
    """Evaluate model on a specific dataset"""
    print(f"\nEvaluating {dataset_name}...")

    # Load the model with adapter
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(model, model_path)
    model.eval()

    # Load dataset
    stage = next(s for s in STAGES if s["name"] == dataset_name)
    with open(stage["data_path"], 'r') as f:
        data = json.load(f)

    # Sample random examples
    if isinstance(data, dict):
        samples = random.sample(list(data.items()), min(n_samples, len(data)))
    else:
        samples = random.sample(data, min(n_samples, len(data)))

    generation_config = GenerationConfig(
        max_length=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )

    results = []
    for idx, sample in enumerate(samples):
        if isinstance(sample, tuple):
            _, sample = sample

        # Format input based on dataset
        if dataset_name == "pubmedqa":
            question = sample["QUESTION"]
            context = "\n".join(sample["CONTEXTS"])
            actual = f"{sample['final_decision']}\nExplanation: {sample['LONG_ANSWER']}"
            prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        elif dataset_name == "medmcqa":
            question = sample["question"]
            actual = sample["exp"]
            prompt = f"Question: {question}\nExplanation:"
        else:  # medqa
            question = sample["question"]
            options = "\n".join([f"{k}) {v}" for k, v in sample["options"].items()])
            actual = sample["options"][sample["answer_idx"]]
            prompt = f"Question: {question}\nOptions:\n{options}\nAnswer:"

        # Generate prediction
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        outputs = model.generate(
            **inputs,
            generation_config=generation_config,
        )
        predicted = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Remove the prompt from prediction
        predicted = predicted[len(prompt):].strip()

        # Calculate metrics
        metrics = calculate_metrics(predicted, actual)

        results.append({
            "question": question,
            "predicted": predicted,
            "actual": actual,
            "metrics": metrics
        })

    # Calculate average metrics
    avg_metrics = {
        metric: np.mean([r["metrics"][metric] for r in results])
        for metric in results[0]["metrics"].keys()
    }

    # Display results
    html_output = "<div style='max-width: 800px;'>"
    html_output += "<h3>Individual Results:</h3>"
    for result in results:
        html_output += format_qa_pair(
            result["question"],
            result["predicted"],
            result["actual"],
            result["metrics"]
        )

    # Add average metrics
    html_output += "<h3>Average Metrics:</h3>"
    html_output += "<table style='margin-left: 20px; margin-bottom: 20px;'>"
    for metric, value in avg_metrics.items():
        html_output += f"<tr><td><b>{metric}</b></td><td>{value:.3f}</td></tr>"
    html_output += "</table>"
    html_output += "</div>"

    display(HTML(html_output))

    return results

# Load tokenizer first
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token

# Evaluate each stage
stages_to_evaluate = [
    ("pubmedqa", "checkpoints/pubmedqa/checkpoint-14"),  # Using checkpoint-14 based on your folder structure
    ("medmcqa", "checkpoints/medmcqa/checkpoint-14"),
    ("medqa", "checkpoints/medqa/checkpoint-14")
]

print("\nStarting evaluation...")
for dataset_name, model_path in stages_to_evaluate:
    try:
        print(f"\nEvaluating {dataset_name} using {model_path}")
        results = evaluate_model(model_path, dataset_name, tokenizer)
    except Exception as e:
        print(f"Error evaluating {dataset_name}: {str(e)}")
        import traceback
        traceback.print_exc()

Loading tokenizer...

Starting evaluation...

Evaluating pubmedqa using checkpoints/pubmedqa/checkpoint-14

Evaluating pubmedqa...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

0,1
pred_length,35.0
actual_length,39.0
length_ratio,0.897
semantic_similarity,0.921
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.162
rouge2_f1,0.028
rougeL_f1,0.054
rouge1_precision,0.154

0,1
pred_length,0.0
actual_length,44.0
length_ratio,0.0
semantic_similarity,0.884
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0

0,1
pred_length,0.0
actual_length,63.0
length_ratio,0.0
semantic_similarity,0.884
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0

0,1
pred_length,0.0
actual_length,65.0
length_ratio,0.0
semantic_similarity,0.88
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0

0,1
pred_length,0.0
actual_length,59.0
length_ratio,0.0
semantic_similarity,0.869
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0

0,1
pred_length,7.0
actual_length,54.0
length_ratio,0.179
semantic_similarity,0.888
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.032
rouge2_f1,0.006
rougeL_f1,0.011
rouge1_precision,0.031



Evaluating medmcqa using checkpoints/medmcqa/checkpoint-14

Evaluating medmcqa...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

0,1
pred_length,35.0
actual_length,212.0
length_ratio,0.165
semantic_similarity,0.971
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.262
rouge2_f1,0.096
rougeL_f1,0.127
rouge1_precision,0.153

0,1
pred_length,12.0
actual_length,202.0
length_ratio,0.059
semantic_similarity,0.913
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.061
rouge2_f1,0.009
rougeL_f1,0.043
rouge1_precision,0.032

0,1
pred_length,95.0
actual_length,350.0
length_ratio,0.271
semantic_similarity,0.958
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.284
rouge2_f1,0.069
rougeL_f1,0.155
rouge1_precision,0.182

0,1
pred_length,76.0
actual_length,116.0
length_ratio,0.655
semantic_similarity,0.966
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.373
rouge2_f1,0.126
rougeL_f1,0.176
rouge1_precision,0.308

0,1
pred_length,86.0
actual_length,82.0
length_ratio,1.049
semantic_similarity,0.97
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.376
rouge2_f1,0.083
rougeL_f1,0.176
rouge1_precision,0.39

0,1
pred_length,60.8
actual_length,192.4
length_ratio,0.44
semantic_similarity,0.956
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.271
rouge2_f1,0.077
rougeL_f1,0.136
rouge1_precision,0.213



Evaluating medqa using checkpoints/medqa/checkpoint-14

Evaluating medqa...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Error calculating word overlap: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt_tab/english/[0m

  Searched in:
    - '/root/nltk_data'
    - '/usr/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************

Error calculating BLEU score: 
**********************************************************************
  Resource [93mpunkt_tab[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt_tab')
  [0m
  For more information see: https://

0,1
pred_length,0.0
actual_length,16.0
length_ratio,0.0
semantic_similarity,0.881
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0

0,1
pred_length,0.0
actual_length,20.0
length_ratio,0.0
semantic_similarity,0.884
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0

0,1
pred_length,0.0
actual_length,10.0
length_ratio,0.0
semantic_similarity,0.869
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0

0,1
pred_length,0.0
actual_length,14.0
length_ratio,0.0
semantic_similarity,0.868
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0

0,1
pred_length,0.0
actual_length,24.0
length_ratio,0.0
semantic_similarity,0.885
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0

0,1
pred_length,0.0
actual_length,16.8
length_ratio,0.0
semantic_similarity,0.877
word_overlap,0.0
word_coverage,0.0
rouge1_f1,0.0
rouge2_f1,0.0
rougeL_f1,0.0
rouge1_precision,0.0
