# Checkpoint 1 — Demo Notebook

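Interactive walkthrough of the full pipeline:
1. Inspect the sampled dataset
2. Visualise a few (original, edited, prompt) triplets
3. Load the precomputed embeddings
4. Review labels and baseline results (loads saved results; this notebook does not train)
5. Analyse prompt–failure correlations
6. Check cross-lingual readiness (optional translations)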

In [None]:
import sys, json
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

# NOTE(review): assumes the kernel was started from the repository root;
# adjust PROJECT_ROOT if the notebook is opened from a subdirectory.
PROJECT_ROOT = Path.cwd()
_scripts_dir = str(PROJECT_ROOT / 'scripts')
# Notebook cells are often re-run; only insert the path once so sys.path
# does not accumulate duplicate entries.
if _scripts_dir not in sys.path:
    sys.path.insert(0, _scripts_dir)

from utils.io import load_jsonl
from utils.schema import ADHERENCE_LABELS, ERROR_TYPES, LabelRecord
from utils.prompt_features import extract_prompt_features, PROMPT_FEATURE_NAMES

## 1. Inspect the sampled dataset

In [None]:
DATA_DIR = PROJECT_ROOT / 'data' / 'sample'

# Presumably a list of dicts, one per sample (len() and meta[0] are used
# below) — confirm against utils.io.load_jsonl.
meta = load_jsonl(DATA_DIR / 'metadata.jsonl')
print(f'Total samples: {len(meta)}')
if meta:
    print('\nFirst record:')
    # ensure_ascii=False prints non-ASCII characters (e.g. in prompts) as-is
    # instead of as \uXXXX escapes.
    print(json.dumps(meta[0], indent=2, ensure_ascii=False))

## 2. Visualise a few triplets

In [None]:
def show_triplet(rec, data_dir):
    """Display the original and edited images of one sample side by side.

    Args:
        rec: One metadata record. Reads the 'orig_path', 'edited_path' and
            'prompt' keys; the two paths are resolved relative to data_dir.
        data_dir: Directory the image paths are relative to (str or Path).
    """
    data_dir = Path(data_dir)
    panels = []
    for key in ('orig_path', 'edited_path'):
        # Image.open keeps the file open lazily; the context manager releases
        # it. copy()/convert() force the pixel data to load before closing.
        with Image.open(data_dir / rec[key]) as img:
            if img.mode in ('RGB', 'RGBA'):
                panels.append(img.copy())
            else:
                # Single-channel ('L', 'P', ...) images would otherwise be
                # colour-mapped by imshow, and 4-channel CMYK would be misread
                # as RGBA.
                panels.append(img.convert('RGB'))

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, panel, title in zip(axes, panels, ('Original', 'Edited')):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle(f'Prompt: {rec["prompt"]}', fontsize=11, y=0.02)
    plt.tight_layout()
    plt.show()

# Preview at most the first three samples; with no metadata the loop body
# never runs.
for record in (meta or [])[:3]:
    show_triplet(record, DATA_DIR)

## 3. Load embeddings

In [None]:
emb_path = DATA_DIR / 'embeddings.npz'
if emb_path.exists():
    # allow_pickle stays at its default (False): the arrays inspected here
    # are numeric embeddings, and unpickling a file can execute arbitrary
    # code. The context manager closes the underlying .npz archive.
    with np.load(str(emb_path)) as npz:
        print('Embedding arrays:', list(npz.keys()))
        for key in ('emb_orig', 'emb_edit', 'emb_text'):
            if key in npz.files:
                print(f'{key} shape:', npz[key].shape)
            else:
                print(f'{key} missing from {emb_path.name}')
else:
    print('No embeddings found. Run: python scripts/extract_embeddings.py --data data/sample')

## 4. Labels and baseline results

In [None]:
LABELS_PATH = PROJECT_ROOT / 'data' / 'annotations' / 'labels.jsonl'
# Treat the annotation file as optional, like every other artifact in this
# notebook, so the cell still runs before any labels have been collected.
labels_found = LABELS_PATH.exists()
labels = load_jsonl(LABELS_PATH) if labels_found else []
print(f'Labels collected: {len(labels)}')
if not labels_found:
    print(f'No labels file found at {LABELS_PATH}')

if labels:
    df_labels = pd.DataFrame(labels)
    print('\nAdherence distribution:')
    # dropna=False also counts records whose adherence label is missing,
    # which would otherwise be silently left out of the distribution.
    print(df_labels['adherence'].value_counts(dropna=False))

metrics_path = PROJECT_ROOT / 'runs' / 'baseline' / 'metrics.json'
# Baseline metrics are optional; the cell prints nothing when they are absent.
if metrics_path.exists():
    metrics = json.loads(metrics_path.read_text())
    metrics_report = json.dumps(metrics, indent=2)
    print('\n--- Baseline metrics ---\n' + metrics_report)

cm_path = PROJECT_ROOT / 'runs' / 'baseline' / 'confusion_matrix.png'
if cm_path.exists():
    # Read the pixels inside the context manager so the file handle that
    # Image.open holds lazily is released immediately.
    with Image.open(cm_path) as img:
        cm_pixels = np.asarray(img)
    plt.figure(figsize=(5, 4))
    plt.imshow(cm_pixels)
    plt.axis('off')
    plt.title('Confusion Matrix')
    plt.show()

## 5. Prompt–failure analysis

In [None]:
analysis_dir = PROJECT_ROOT / 'runs' / 'baseline' / 'analysis'
corr_path = analysis_dir / 'correlations.csv'

# `display` is not imported here; it is supplied by the IPython/Jupyter kernel.
if corr_path.exists():
    corr = pd.read_csv(corr_path, index_col=0)
    print('Correlation matrix (prompt features × failure indicators):')
    # Symmetric vmin/vmax keep the diverging RdBu_r scale centred on zero.
    corr_styler = corr.style.background_gradient(
        cmap='RdBu_r',
        vmin=-1,
        vmax=1,
    )
    display(corr_styler)

# Show each pre-rendered analysis figure that exists; missing ones are skipped.
for png_name in ('heatmap_correlations.png', 'error_type_frequency.png',
                 'wordcount_by_adherence.png', 'adherence_distribution.png'):
    png_path = analysis_dir / png_name
    if not png_path.exists():
        continue
    # Load pixels inside the context manager so each file handle is closed
    # instead of leaking one per figure.
    with Image.open(png_path) as img:
        pixels = np.asarray(img)
    plt.figure(figsize=(10, 5))
    plt.imshow(pixels)
    plt.axis('off')
    plt.title(png_name)
    plt.show()

## 6. Cross-lingual readiness check

In [None]:
from utils.text_encoder import load_translations

trans_path = DATA_DIR / 'translations.csv'
# Check for the file explicitly, so the messages below can distinguish an
# absent file from a present but empty one. Previously both cases were
# reported as "not found".
has_translations_file = trans_path.exists()
trans = load_translations(trans_path) if has_translations_file else {}
if trans:
    # trans is used as a mapping of sample id -> translations (see .items()).
    print(f'Translations loaded for {len(trans)} samples:')
    for sid, langs in list(trans.items())[:5]:
        print(f'  {sid}: {langs}')
elif has_translations_file:
    print(f'{trans_path.name} was found but yielded no translations.')
else:
    print('No translations.csv found (optional for checkpoint 1).')