# Eva-4B Inference Notebook
Runs inference with FutureMa/Eva-4B (Qwen3 4B Instruct) on the QEvasion test split and evaluates the predictions against the dataset's clarity labels.

**Classification Labels:**
- `direct`: answers the core question with specific information
- `intermediate`: provides related information but sidesteps the core question
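- `fully_evasive`: does not address the question (refusal, redirection, non-response)

QEvasion's `clarity_label` values are mapped onto these labels: `Clear Reply` → `direct`, `Ambivalent` → `intermediate`, `Clear Non-Reply` → `fully_evasive`.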

## Installation

In [None]:
%pip install transformers accelerate datasets scikit-learn pandas tqdm -q

## Imports & Configuration

In [None]:
import pandas as pd
import torch
import warnings
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    f1_score,
    classification_report,
    confusion_matrix
)

# Silence any warning whose message mentions attention_mask.
warnings.filterwarnings("ignore", message=".*attention_mask.*")

# List every visible CUDA device with its name and total memory (GiB).
gpu_count = torch.cuda.device_count()
print(f"Available GPUs: {gpu_count}")
for gpu_index in range(gpu_count):
    gpu_props = torch.cuda.get_device_properties(gpu_index)
    gib = gpu_props.total_memory / 1024**3
    print(f"  GPU {gpu_index}: {gpu_props.name} ({gib:.1f} GB)")

In [None]:
# === CONFIGURATION ===
MODEL_NAME = "FutureMa/Eva-4B"  # Qwen3 4B Instruct model
HF_DATASET = "ailsntua/QEvasion"  # Hugging Face dataset id; its "test" split is evaluated
OUTPUT_PATH = "eva_4b_predictions.csv"  # predictions CSV written after inference

# Generation parameters
MAX_NEW_TOKENS = 512  # upper bound on generated tokens per example
TEMPERATURE = 0.1  # > 0 enables sampling; 0 switches generation to greedy decoding
MAX_INPUT_LENGTH = 3500  # prompts longer than this many tokens are skipped

# With TEMPERATURE > 0 generation samples, so results change from run to run
# unless the RNG is seeded. torch.manual_seed seeds the CPU and all CUDA
# devices. Re-run this cell before re-running inference to reproduce a run
# (non-deterministic GPU kernels may still cause tiny differences).
SEED = 42
torch.manual_seed(SEED)

# Set to a number for testing, None for full run
LIMIT = None

## Prompt Template & Label Extraction

In [None]:
# System prompt for classification; format_messages() sends it as the
# "system" message of every conversation.
SYSTEM_PROMPT = "You are an expert political discourse analyst. Classify response clarity."

# User prompt template listing the three classification labels.
# format_messages() fills {interview_question}, {interview_answer} and
# {question} from the dataset row's columns of the same names. The trailing
# "Classification:" cue presumably prompts the model to answer with a label.
# NOTE(review): model behaviour may depend on this exact wording/whitespace
# (e.g. if it mirrors the prompt the model was trained with) — confirm before
# editing the text.
USER_TEMPLATE = """Classify the response to the question as one of:
1. direct - answers the core question with specific information
2. intermediate - provides related information but sidesteps the core question
3. fully_evasive - does not address the question (refusal, redirection, non-response)

### Interview Question ###
{interview_question}

### Interview Answer ###
{interview_answer}

### Specific Question ###
{question}

Classification:"""


def extract_label(response: str) -> str:
    """Extract the classification label from a model response.

    The label that appears *earliest* in the response wins. The prompt ends
    with "Classification:", so the model normally states its label first.
    Any explanation that follows may mention the other labels ("... rather
    than fully evasive") and must not override the stated answer.

    Labels are matched case-insensitively as whole words, so words such as
    "indirect" or "directly" are not mistaken for ``direct``. ``fully_evasive``
    is also recognised when written as "fully evasive" or "fully-evasive".

    Args:
        response: Raw decoded model output.

    Returns:
        ``"direct"``, ``"intermediate"``, ``"fully_evasive"``, or
        ``"PARSE_ERROR"`` when no label is found (including empty input).
    """
    import re  # local import keeps this notebook helper self-contained

    text = response.lower()
    label_patterns = (
        ("fully_evasive", r"\bfully[\s_-]+evasive\b"),
        ("intermediate", r"\bintermediate\b"),
        ("direct", r"\bdirect\b"),
    )

    best_label = "PARSE_ERROR"
    best_pos = len(text) + 1  # sentinel: any real match starts before this
    for label, pattern in label_patterns:
        match = re.search(pattern, text)
        if match is not None and match.start() < best_pos:
            best_label, best_pos = label, match.start()
    return best_label


def format_messages(row) -> list:
    """Build the two-message chat (system + user) for one dataset row.

    The row's "interview_question", "interview_answer" and "question" values
    are converted with ``str`` and substituted into USER_TEMPLATE.
    """
    fields = {
        "interview_question": str(row["interview_question"]),
        "interview_answer": str(row["interview_answer"]),
        "question": str(row["question"]),
    }
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": USER_TEMPLATE.format(**fields)}
    return [system_message, user_message]


# QEvasion "clarity_label" value -> this notebook's three-way label.
_LABEL_PAIRS = (
    ("Clear Reply", "direct"),
    ("Ambivalent", "intermediate"),
    ("Clear Non-Reply", "fully_evasive"),
)
LABEL_MAPPING = dict(_LABEL_PAIRS)

print("Prompt template and helper functions defined")

## Load Model

In [None]:
print("=" * 60)
print(f"Loading model: {MODEL_NAME}")

# Choose the weight dtype.
# float16 has a far smaller exponent range than bfloat16 (max finite value
# ~65504). Checkpoints trained in bfloat16 can therefore overflow to inf/NaN
# when run in float16. This Qwen3-based model is presumably one of them —
# confirm against its config.json.
# Policy: bfloat16 where the GPU supports it, float16 on older GPUs, and
# float32 on CPU, where half-precision kernels are slow or unavailable.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    MODEL_DTYPE = torch.bfloat16
elif torch.cuda.is_available():
    MODEL_DTYPE = torch.float16
else:
    MODEL_DTYPE = torch.float32

print(f"Loading weights in {MODEL_DTYPE}...")
print("=" * 60)

# Load tokenizer; reuse EOS as the padding token if none is defined.
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load model weights; device_map="auto" lets accelerate place them on the
# available device(s).
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=MODEL_DTYPE,
    # NOTE(review): trust_remote_code executes Python shipped in the Hub repo.
    # Recent transformers versions support Qwen3 natively, so this flag may be
    # unnecessary — confirm the repo ships no custom modeling code and drop it.
    trust_remote_code=True,
)

print("\nModel loaded successfully!")
print(f"Model device map: {getattr(model, 'hf_device_map', 'auto')}")

## Load Test Dataset

In [None]:
# Load the "test" split of the dataset from the Hugging Face Hub.
print(f"Loading test dataset from: {HF_DATASET}")
test_dataset = load_dataset(HF_DATASET, split="test")

print(f"Loaded {len(test_dataset)} test examples")
print(f"Columns: {test_dataset.column_names}")

# Convert to a pandas DataFrame; the inference loop iterates its rows.
test_df = test_dataset.to_pandas()

# LIMIT=None means "use every example"; any integer (including 0) truncates.
# An explicit None check stops LIMIT=0 from silently meaning "no limit".
if LIMIT is not None:
    print(f"Limiting to first {LIMIT} examples")
    test_df = test_df.head(LIMIT)

# Preview one row so the expected columns can be checked before inference.
if test_df.empty:
    print("\nNo rows selected; nothing to preview.")
else:
    print("\nSample row:")
    print(test_df.iloc[0][["question", "clarity_label"]].to_dict())

## Run Inference

In [None]:
def _ground_truth(row) -> str:
    """Return row's clarity_label mapped through LABEL_MAPPING ("" if missing/unmapped)."""
    return LABEL_MAPPING.get(row.get("clarity_label", ""), "")


def _record(idx, row, prediction: str, raw_output: str) -> dict:
    """Build one entry of `results` (one row of the predictions CSV)."""
    return {
        "index": idx,
        "question": row["question"],
        "prediction": prediction,
        "ground_truth": _ground_truth(row),
        "raw_output": raw_output,
    }


def _input_ids(chat_output):
    """Return the token-id tensor from apply_chat_template's output.

    apply_chat_template may return either a bare tensor or a BatchEncoding
    holding "input_ids"; both forms are accepted.
    """
    if isinstance(chat_output, torch.Tensor):
        return chat_output
    return chat_output["input_ids"]


# Decoding settings shared by every example.
gen_kwargs = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "use_cache": True,
    "pad_token_id": tokenizer.eos_token_id,
    "do_sample": TEMPERATURE > 0,
}
# temperature only affects sampling, so pass it only when sampling is on.
if gen_kwargs["do_sample"]:
    gen_kwargs["temperature"] = TEMPERATURE
# Keep the stop tokens from the model's generation_config, which may list
# several ids (e.g. end-of-turn and end-of-text). Overriding them with the
# tokenizer's single EOS id could let generation run past the assistant turn.
# Use the tokenizer's EOS only when the config defines none.
if model.generation_config.eos_token_id is None:
    gen_kwargs["eos_token_id"] = tokenizer.eos_token_id

results = []
total = len(test_df)

print("\n" + "=" * 60)
print(f"Starting inference on {total} examples")
print("=" * 60 + "\n")

# `position` is the 1-based progress counter. `idx` is the DataFrame index
# label stored in the results; the two differ if test_df is filtered or
# re-indexed.
rows = tqdm(test_df.iterrows(), total=total, desc="Evaluating")
for position, (idx, row) in enumerate(rows, start=1):
    inputs = _input_ids(tokenizer.apply_chat_template(
        format_messages(row),
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ))
    input_len = inputs.shape[-1]

    # Skip prompts longer than MAX_INPUT_LENGTH tokens instead of generating.
    if input_len > MAX_INPUT_LENGTH:
        tqdm.write(f"[{position}/{total}] SKIPPED (input too long: {input_len} tokens)")
        results.append(_record(idx, row, "SKIPPED", f"Input too long: {input_len} tokens"))
        continue

    try:
        inputs = inputs.to(model.device)
        # Single unpadded sequence: every position is a real token.
        attention_mask = torch.ones_like(inputs)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs,
                attention_mask=attention_mask,
                **gen_kwargs,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        full_response = tokenizer.decode(
            outputs[0, input_len:],
            skip_special_tokens=True,
        ).strip()

        results.append(_record(idx, row, extract_label(full_response), full_response))

    except Exception as e:  # record the failure as an ERROR row and continue
        tqdm.write(f"[{position}/{total}] ERROR: {e}")
        results.append(_record(idx, row, "ERROR", str(e)))

    # Periodically release cached GPU memory.
    if position % 25 == 0:
        torch.cuda.empty_cache()

print("\nInference complete!")

## Save Results

In [None]:
# Save results to CSV. Passing the column list explicitly fixes the column
# order and keeps the header (and the "prediction"/"ground_truth" columns the
# metric cells read) even when `results` is empty.
RESULT_COLUMNS = ["index", "question", "prediction", "ground_truth", "raw_output"]
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS)
results_df.to_csv(OUTPUT_PATH, index=False)

# Summary statistics: tally every prediction value in a single pass.
prediction_counts = results_df["prediction"].value_counts()
parse_errors = int(prediction_counts.get("PARSE_ERROR", 0))
skipped = int(prediction_counts.get("SKIPPED", 0))
errors = int(prediction_counts.get("ERROR", 0))
successful = len(results) - parse_errors - skipped - errors

print("\n" + "=" * 60)
print("INFERENCE COMPLETE")
print("=" * 60)

print(f"\nSaved {len(results)} predictions to {OUTPUT_PATH}")

print("\nSummary:")
print(f"  Total:         {len(results)}")
print(f"  Successful:    {successful}")
print(f"  Parse errors:  {parse_errors}")
print(f"  Skipped:       {skipped}")
print(f"  Errors:        {errors}")

print("\nPrediction distribution:")
for label in ["direct", "intermediate", "fully_evasive"]:
    count = int(prediction_counts.get(label, 0))
    pct = count / len(results) * 100 if results else 0
    print(f"  {label}: {count} ({pct:.1f}%)")

## Evaluation Metrics

In [None]:
# Restrict evaluation to rows where both the prediction and the mapped ground
# truth are one of the three class labels. Dropped rows (PARSE_ERROR, SKIPPED
# or ERROR predictions, or an unmapped "" ground truth) are NOT counted as
# mistakes, so the metrics below cover only this valid subset.
VALID_LABELS = {"direct", "intermediate", "fully_evasive"}

y_pred, y_true = results_df["prediction"], results_df["ground_truth"]

pred_ok = y_pred.isin(VALID_LABELS)
true_ok = y_true.isin(VALID_LABELS)
mask = pred_ok & true_ok

y_pred_filtered, y_true_filtered = y_pred[mask], y_true[mask]

kept, dropped = mask.sum(), (~mask).sum()
print(f"Valid predictions: {kept} / {len(mask)}")
print(f"Filtered out: {dropped} (PARSE_ERROR, SKIPPED, ERROR, or missing ground truth)")

In [None]:
# Headline metrics over the valid-only predictions.
if len(y_pred_filtered) == 0:
    print("No valid predictions to evaluate!")
else:
    accuracy = accuracy_score(y_true_filtered, y_pred_filtered)

    # Macro averaging gives each class equal weight regardless of its support.
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true_filtered, y_pred_filtered, average="macro", zero_division=0
    )

    # Weighted averaging weights each class's F1 by its support in y_true.
    f1_weighted = f1_score(
        y_true_filtered, y_pred_filtered, average="weighted", zero_division=0
    )

    banner = "=" * 60
    print(banner)
    print("EVALUATION METRICS (Eva-4B on QEvasion Test Set)")
    print(banner)
    print()
    for metric_name, metric_value in (
        ("Accuracy", accuracy),
        ("Precision (Mac)", precision_macro),
        ("Recall (Mac)", recall_macro),
        ("F1 (Macro)", f1_macro),
        ("F1 (Weighted)", f1_weighted),
    ):
        print(f"{metric_name:<16}: {metric_value:.4f}")

In [None]:
# Per-class precision / recall / F1 for the valid-only predictions.
# The rows are listed in a fixed order. `list(VALID_LABELS)` would follow set
# iteration order, which for strings changes between interpreter sessions
# (hash randomization) and would reorder the report from run to run.
if len(y_pred_filtered) > 0:
    report_labels = ["direct", "intermediate", "fully_evasive"]
    print("\nClassification Report:")
    print("-" * 60)
    print(classification_report(
        y_true_filtered,
        y_pred_filtered,
        labels=report_labels,
        zero_division=0
    ))

In [None]:
# Confusion matrix over the valid-only predictions: rows are ground-truth
# labels and columns are predicted labels, both in `labels` order.
if len(y_pred_filtered) > 0:
    labels = ["direct", "intermediate", "fully_evasive"]
    cm = confusion_matrix(y_true_filtered, y_pred_filtered, labels=labels)

    rule = "-" * 60
    print("\nConfusion Matrix:")
    print(rule)
    print(f"{'':20} {'Predicted':^45}")
    print(f"{'Actual':20} " + " ".join(f"{name:^15}" for name in labels))
    print(rule)
    for row_label, counts in zip(labels, cm):
        print(f"{row_label:20} " + " ".join(f"{count:^15}" for count in counts))
    print(rule)

## Sample Outputs

In [None]:
# Show the first few predictions next to their ground truth, with a one-line
# preview of the raw model output.
PREVIEW_CHARS = 80
print("Sample predictions (first 5):")
print("-" * 60)
for i, r in enumerate(results[:5]):
    # Collapse newlines and whitespace runs so each preview stays on one line.
    # Append "..." only when the output was actually truncated.
    raw = " ".join(str(r.get('raw_output', 'N/A')).split())
    if len(raw) > PREVIEW_CHARS:
        raw = raw[:PREVIEW_CHARS] + "..."
    match = "CORRECT" if r['prediction'] == r['ground_truth'] else "WRONG"
    print(f"{i+1}. Pred: {r['prediction']:15} | GT: {r['ground_truth']:15} | {match}")
    print(f"   Raw: '{raw}'")
    print()