# ACT-X Dataset Consistency Check

이 노트북은 ACT-X 원본 JSON과 변환된 `*_llada.json`을 각각 로드한 뒤, 
새로 설계한 FIM 레이블 조립 규칙이 두 데이터셋에서 동일한 결과를 만드는지 검증합니다.

In [None]:
from pathlib import Path
import json
from copy import deepcopy
from typing import Dict, List, Tuple
from transformers import AutoTokenizer

# Directory holding the ACT-X JSON files; the split-specific file names are
# joined onto it where the datasets are declared.
BASE_DIR = Path('/home/20223206/ACT-X')
# Directory handed to AutoTokenizer.from_pretrained below.
TOKENIZER_DIR = Path('/home/20223206/LLaDA-V/train/llada_v_prepare/files')
# Question text placed after the "<image>" tag in the human turn of every record.
QUESTION_TEMPLATE = "What activity is happening in this image?"
# Number of token slots in the fixed-length answer block at the start of each label.
ANSWER_BLOCK_SIZE = 20

print('Loading tokenizer...')
# NOTE(review): trust_remote_code=True permits the loader to run Python code
# shipped alongside the tokenizer files -- keep TOKENIZER_DIR pointed at a
# trusted location.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
# Token used to fill the answer-block slots not taken by prefix/answer tokens.
RESERVED_TOKEN = '<|reserved_token_1|>'
# NOTE(review): presumably a token missing from the vocabulary resolves to the
# unknown-token id rather than raising -- check the id printed below.
RESERVED_TOKEN_ID = tokenizer.convert_tokens_to_ids(RESERVED_TOKEN)
print('Tokenizer:', tokenizer.__class__.__name__)
print('Reserved token id:', RESERVED_TOKEN_ID)

In [None]:
def load_raw_actx(path: Path) -> List[Dict]:
    """Load a raw ACT-X JSON file and return its records as a list of dicts.

    Two top-level layouts are accepted:

    * a JSON array, which is returned exactly as parsed;
    * a JSON object, where every key/value pair becomes one record made of
      the value's fields plus an ``'id'`` field holding the key. A value
      that already carries its own ``'id'`` keeps it.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The list of record dicts.

    Raises:
        ValueError: If the top-level JSON value is neither an array nor an
            object.
    """
    # Read as UTF-8 explicitly instead of relying on the locale default, and
    # close the handle deterministically rather than leaking it.
    with path.open(encoding='utf-8') as handle:
        payload = json.load(handle)
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        # A dict display (instead of ``dict(id=k, **v)``) avoids the
        # "multiple values for keyword argument 'id'" TypeError when a record
        # already contains an 'id' field.
        return [{'id': key, **record} for key, record in payload.items()]
    raise ValueError(f'Unsupported payload type: {type(payload)}')


def expand_actx_records(raw_records: List[Dict], image_prefix: str = 'images') -> List[Dict]:
    """Flatten raw ACT-X entries into one conversation record per explanation.

    For each entry the identifier is taken from ``'image_id'`` (falling back
    to ``'id'``), the answer from ``'answers'`` (falling back to
    ``'answer'``), and the explanations from ``'explanation'``. Blank
    explanations are dropped; an entry left with none still yields a single
    record whose explanation is the empty string.

    Args:
        raw_records: Entries as returned by the raw loader.
        image_prefix: Directory component prepended to each image file name.

    Returns:
        A list of dicts with the keys ``id``, ``image``, ``answer``,
        ``explanation``, ``explanation_index`` and ``conversations``. The
        first explanation keeps the base id; later ones get ``_<index>``
        appended. The ``gpt`` turn of every conversation is left empty.
    """
    rows = []
    for entry in raw_records:
        base_id = entry.get('image_id') or entry.get('id')
        answer = (entry.get('answers') or entry.get('answer') or '').strip()

        raw_explanations = entry.get('explanation') or ['']
        if not isinstance(raw_explanations, list):
            raw_explanations = [raw_explanations]
        # Keep only non-blank explanations; fall back to one empty placeholder.
        texts = [text.strip() for text in raw_explanations if text and text.strip()] or ['']

        image_path = str(Path(image_prefix) / entry.get('image_name', f"{base_id}.jpg"))

        for idx, text in enumerate(texts):
            record_id = f"{base_id}_{idx}" if idx else base_id
            rows.append({
                'id': record_id,
                'image': image_path,
                'answer': answer,
                'explanation': text,
                'explanation_index': idx,
                # Fresh list and dicts per record so records never share turns.
                'conversations': [
                    {'from': 'human', 'value': f"<image>\n{QUESTION_TEMPLATE}"},
                    {'from': 'gpt', 'value': ''},
                ],
            })
    return rows


def build_fim_label(answer: str, explanation: str) -> Tuple[List[int], str]:
    """Assemble the FIM label for one answer/explanation pair.

    The label starts with a block of exactly ``ANSWER_BLOCK_SIZE`` token ids:
    the tokens of ``'The answer is '`` followed by the tokens of the stripped
    answer, truncated to the block size, with any unused slots filled by
    ``RESERVED_TOKEN_ID``. When the stripped explanation is non-empty, the
    tokens of ``' because'`` and of the explanation (given a trailing ``'.'``
    unless it already ends in ``.``, ``!`` or ``?``) are appended after the
    block.

    Args:
        answer: Answer text; ``None`` or empty leaves only the prefix in the block.
        explanation: Explanation text; ``None`` or blank adds nothing after the block.

    Returns:
        A ``(label_ids, label_text)`` pair, where ``label_text`` is the
        decoded form of ``label_ids`` with special tokens kept.
    """
    answer_text = (answer or '').strip()
    prefix_ids = tokenizer('The answer is ', add_special_tokens=False).input_ids
    answer_ids = tokenizer(answer_text, add_special_tokens=False).input_ids if answer_text else []

    # Prefix then answer, cut to the block size; pad the rest with the reserved id.
    content_ids = (list(prefix_ids) + list(answer_ids))[:ANSWER_BLOCK_SIZE]
    block_ids = content_ids + [RESERVED_TOKEN_ID] * (ANSWER_BLOCK_SIZE - len(content_ids))

    tail_ids: List[int] = []
    explanation_text = (explanation or '').strip()
    if explanation_text:
        if not explanation_text.endswith(('.', '!', '?')):
            explanation_text += '.'
        # NOTE(review): ' because' and the explanation are tokenized
        # separately, so the spacing between them in the decoded text depends
        # on the tokenizer -- confirm the decoded label reads as intended.
        connector_ids = tokenizer(' because', add_special_tokens=False).input_ids
        body_ids = tokenizer(explanation_text, add_special_tokens=False).input_ids
        tail_ids = connector_ids + body_ids

    label_ids = block_ids + tail_ids
    label_text = tokenizer.decode(
        label_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
    )
    return label_ids, label_text


def attach_labels(records: List[Dict]) -> List[Dict]:
    """Return copies of ``records`` with the FIM label written into the reply turn.

    Each record's ``'answer'`` and ``'explanation'`` are passed to
    ``build_fim_label`` and the resulting label text replaces the ``'value'``
    of the second conversation turn. The input records are not modified.

    Args:
        records: Records that contain a ``'conversations'`` list with at
            least two turns.

    Returns:
        A new list of deep-copied records carrying the assembled labels.
    """
    def _with_label(record: Dict) -> Dict:
        label_text = build_fim_label(record.get('answer'), record.get('explanation'))[1]
        clone = deepcopy(record)
        turns = deepcopy(record['conversations'])
        turns[1]['value'] = label_text
        clone['conversations'] = turns
        return clone

    return [_with_label(record) for record in records]


def compare_datasets(baseline: List[Dict], target: List[Dict]) -> Dict:
    """Compare the reply-turn labels of two record lists.

    Records are matched on ``(id, explanation_index)``, with a missing
    ``explanation_index`` treated as ``0``. If several records in one list
    share a key, the last one is the one compared.

    Args:
        baseline: Reference records.
        target: Records checked against the reference.

    Returns:
        A dict with:
          * ``baseline_count`` / ``target_count``: number of distinct keys
            on each side;
          * ``missing``: set of keys present only in ``baseline``;
          * ``extra``: set of keys present only in ``target``;
          * ``mismatched``: list of ``(key, baseline_label, target_label)``
            tuples, in sorted key order, for shared keys whose labels differ.
    """
    def _key(record):
        return (record['id'], record.get('explanation_index', 0))

    def _label(record):
        return record['conversations'][1]['value']

    base_by_key = {_key(record): record for record in baseline}
    target_by_key = {_key(record): record for record in target}

    mismatched = []
    for key in sorted(set(base_by_key).intersection(target_by_key)):
        base_label = _label(base_by_key[key])
        target_label = _label(target_by_key[key])
        if base_label != target_label:
            mismatched.append((key, base_label, target_label))

    summary = {
        'baseline_count': len(base_by_key),
        'target_count': len(target_by_key),
        'missing': set(base_by_key) - set(target_by_key),
        'extra': set(target_by_key) - set(base_by_key),
        'mismatched': mismatched,
    }
    return summary


In [None]:
# Raw ACT-X file and its already-converted `*_llada.json` counterpart, per split.
datasets = {
    'train': {
        'raw': BASE_DIR / 'actX_train.json',
        'converted': BASE_DIR / 'actX_test_llada.json'.replace('test', 'train'),
    },
    'test': {
        'raw': BASE_DIR / 'actX_test.json',
        'converted': BASE_DIR / 'actX_test_llada.json',
    },
}

# Per-split result of compare_datasets, kept for inspection after the loop.
comparison_summary = {}
for split, paths in datasets.items():
    raw_records = load_raw_actx(paths['raw'])
    expanded_raw = expand_actx_records(raw_records)
    # Read as UTF-8 and close the handle deterministically instead of
    # leaving the open file to the garbage collector.
    with paths['converted'].open(encoding='utf-8') as handle:
        expanded_converted = json.load(handle)

    # Labels are rebuilt on both sides with the same rule, so any difference
    # comes from the answer/explanation fields of the two sources.
    labeled_raw = attach_labels(expanded_raw)
    labeled_converted = attach_labels(expanded_converted)

    cmp = compare_datasets(labeled_raw, labeled_converted)
    comparison_summary[split] = cmp

    print(f"[{split}] baseline={cmp['baseline_count']} target={cmp['target_count']} mismatched={len(cmp['mismatched'])}")
    if cmp['missing']:
        print('  Missing keys:', list(cmp['missing'])[:3])
    if cmp['extra']:
        print('  Extra keys:', list(cmp['extra'])[:3])
    if cmp['mismatched']:
        key, base_val, target_val = cmp['mismatched'][0]
        print('  Example mismatch:', key)
        # Show both labels (repr keeps whitespace differences visible) so the
        # mismatch can be diagnosed without re-running the comparison.
        print('    baseline:', repr(base_val))
        print('    target  :', repr(target_val))

# Bare expression so the notebook renders the summary as the cell output.
comparison_summary

In [None]:
# Inspect a couple of assembled labels
sample_ids = ['020934932', '008364309_1', '026558760']
for split, paths in datasets.items():
    # Read as UTF-8 and close the handle deterministically instead of
    # leaving the open file to the garbage collector.
    with paths['converted'].open(encoding='utf-8') as handle:
        records = json.load(handle)
    # Keyed by (id, explanation_index); a later duplicate replaces an earlier one.
    indexed = {
        (rec['id'], rec.get('explanation_index', 0)): rec
        for rec in records
    }
    print(f"=== {split.upper()} ===")
    for sid in sample_ids:
        matches = [rec for key, rec in indexed.items() if key[0] == sid]
        if not matches:
            # This sample id does not occur in the current split.
            continue
        for rec in attach_labels(matches):
            idx = rec.get('explanation_index', 0)
            print(f"{sid} (exp #{idx}) -> {rec['conversations'][1]['value']}")
    print()
