In [1]:
import os

from huggingface_hub import login

# SECURITY: never commit an access token to a notebook. The token is read from
# the HF_TOKEN environment variable instead; when it is unset, `token=None`
# makes `login` fall back to its interactive prompt.
# NOTE(review): the token that was previously hard-coded here must be treated
# as leaked and revoked in the Hugging Face account settings.
login(token=os.environ.get("HF_TOKEN"))

In [2]:
# --- Block 1: Setup & configuration ---

from pathlib import Path
import json
import re
from typing import List, Tuple, Optional

from transformers import AutoTokenizer

# Tokenizer (easy to change): every downstream index is computed with this one.
TOKENIZER_NAME = "google/gemma-3-4b-pt"

# Paths
INPUT_JSON_PATH = Path("virtualHome_raw.json")
# The output name is derived from the tokenizer's short name (the part after
# the last "/") so that switching TOKENIZER_NAME cannot silently write indices
# for one tokenizer into a file labelled with another.
OUTPUT_JSON_PATH = Path(f"virtualHome_{TOKENIZER_NAME.rsplit('/', 1)[-1]}.json")

# Loaded once at module level; `encode_with_offsets` reads this global.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)


In [3]:
# --- Block 2: Tokenization & mapping utilities ---

def encode_with_offsets(text: str):
    """
    Tokenize ``text`` with the module-level ``tokenizer`` (no special tokens).

    Returns:
        A pair ``(input_ids, offset_mapping)`` taken from the tokenizer output.
        Callers in this notebook treat each offset entry as a
        ``(char_start, char_end)`` pair into ``text``.
    """
    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        add_special_tokens=False,
    )
    token_ids = encoding["input_ids"]
    char_offsets = encoding["offset_mapping"]
    return token_ids, char_offsets

def char_to_token_span(match_span: Tuple[int, int], offsets: List[Tuple[int, int]]) -> Tuple[int, int]:
    """
    Convert a character span into a half-open token span ``(start, end_exclusive)``.

    Args:
        match_span: ``(char_start, char_end)`` of the text to locate.
        offsets: Per-token ``(char_start, char_end)`` pairs, in token order.

    Returns:
        ``(start_tok, end_tok_excl)`` where ``start_tok`` is the first token whose
        character range contains ``char_start`` and ``end_tok_excl`` is the first
        token at or after ``start_tok`` that begins at or beyond ``char_end``
        (``len(offsets)`` if there is none).

    Raises:
        ValueError: if no token can be associated with ``char_start``.
    """
    span_start, span_end = match_span

    # Preferred: the first token whose [lo, hi) range covers the start character.
    first = next(
        (idx for idx, (lo, hi) in enumerate(offsets) if lo <= span_start < hi),
        None,
    )
    if first is None:
        # Boundary fallback: a token that merely *begins* at the start character
        # (only reachable for tokens the containment test above rejected).
        first = next(
            (idx for idx, (lo, _hi) in enumerate(offsets) if lo == span_start),
            None,
        )
    if first is None:
        raise ValueError("Start char not mapped to any token.")

    # Advance until a token starts at/after the end of the span, or tokens run out.
    stop = first
    total = len(offsets)
    while stop < total and offsets[stop][0] < span_end:
        stop += 1
    return first, stop

def find_plan_start(text: str) -> int:
    """
    Return the index just past the first ``Plan:`` header in ``text``.

    Whitespace directly after the header is skipped as part of the match.
    Returns 0 when no header is found.
    """
    header = re.search(r"\bPlan:\s*\n?", text)
    if header is None:
        return 0
    return header.end()

def sequential_find(text: str, needles: List[str], start_at: int) -> List[Tuple[int, int]]:
    """
    Locate each needle in order, each search resuming where the previous hit ended.

    Returns:
        One ``(start, end)`` character span per needle, or ``[]`` as soon as any
        needle cannot be found at or after the current position.
    """
    found: List[Tuple[int, int]] = []
    cursor = start_at
    for needle in needles:
        begin = text.find(needle, cursor)
        if begin == -1:
            # All-or-nothing: a single miss invalidates the whole alignment.
            return []
        cursor = begin + len(needle)
        found.append((begin, cursor))
    return found


In [4]:
# --- Block 3: Task parsing & exact Plan matching ---

# Header that opens a task section, e.g. "Task 4:" (case-insensitive).
TASK_HDR_RE = re.compile(r"(Task\s*\d+:)", flags=re.IGNORECASE)

def split_tasks(text: str) -> List[Tuple[str, int, int]]:
    """
    Cut ``text`` into task sections, one per task header.

    Each section runs from its header up to the next header (or the end of the
    text). Anything before the first header is not part of any section.

    Returns:
        A list of ``(task_text, start_char, end_char)`` with offsets into ``text``;
        empty when no header is present.
    """
    starts = [hdr.start() for hdr in TASK_HDR_RE.finditer(text)]
    # Every section ends where the next one begins; the last ends with the text.
    ends = starts[1:] + [len(text)]
    return [(text[lo:hi], lo, hi) for lo, hi in zip(starts, ends)]

def extract_plan_block(task_text: str) -> str:
    """
    Return everything in ``task_text`` that follows the first ``Plan:`` header.

    The header and the whitespace right after it are excluded. Returns an empty
    string when the task has no ``Plan:`` header.
    """
    header = re.search(r"\bPlan:\s*\n?", task_text)
    return task_text[header.end():] if header else ""

def parse_plan_tags(plan_text: str) -> List[str]:
    """
    Collect the plan's tags in order of appearance.

    A tag is either a bracketed token such as ``[ACTION]`` or an angle-bracketed
    token such as ``<object>``; the delimiters are kept in the returned strings.
    """
    tag_matches = re.finditer(r"(\[[^\]]+\]|<[^>]+>)", plan_text)
    return [hit.group(1) for hit in tag_matches]

def choose_exact_task(text: str, tar_eq: List[str]) -> Optional[Tuple[str, int]]:
    """
    Pick the first task in ``text`` whose Plan tags are exactly ``tar_eq``.

    Returns:
        ``(task_text, plan_start)`` for the first matching task, where
        ``plan_start`` is the offset within ``task_text`` just past its ``Plan:``
        header; ``None`` when no task matches.
    """
    for candidate, _start, _end in split_tasks(text):
        candidate_tags = parse_plan_tags(extract_plan_block(candidate))
        if candidate_tags != tar_eq:
            continue
        # Offset is relative to the kept task text, not to the full input.
        return candidate, find_plan_start(candidate)
    return None


In [5]:
# --- Block 4: Per-sample processing ---

def process_sample(sample: dict) -> dict:
    """
    Trim a sample to the task whose Plan matches ``tar_eq`` and locate each tag's tokens.

    Args:
        sample: Record read with ``.get``: ``"text"`` (default ``""``) and
            ``"tar_eq"`` (default ``[]``). A ``tar_eq`` value that is not a
            list or tuple is treated as empty.

    Returns:
        A shallow copy of ``sample`` with these keys set:

        - ``"text"``: the matching task's text, or the original text when no
          task matches.
        - ``"start_token_idx"`` / ``"end_token_idx"``: one entry per tag (end is
          exclusive); both empty when matching or mapping fails.
        - ``"strict_match"``: one bool per tag, all ``True`` on success and all
          ``False`` on any failure.
    """
    text: str = sample.get("text", "")

    # Validate BEFORE coercing. Calling list() first made the isinstance check
    # unreachable: None raised TypeError and a string was split into characters.
    raw_tar_eq = sample.get("tar_eq", [])
    tar_eq: List[str] = list(raw_tar_eq) if isinstance(raw_tar_eq, (list, tuple)) else []

    def _result(kept_text: str, starts: List[int], ends: List[int], matched: bool) -> dict:
        # Single definition of the output schema shared by every exit path.
        return {
            **sample,
            "text": kept_text,
            "start_token_idx": starts,
            "end_token_idx": ends,
            "strict_match": [matched] * len(tar_eq),
        }

    # Pick the first task whose Plan tags equal tar_eq exactly.
    choice = choose_exact_task(text, tar_eq)
    if choice is None:
        # No strict match: keep the text unchanged and report empty indices.
        return _result(text, [], [], False)

    kept_task_text, plan_start = choice

    # Character spans of the tags, searched only AFTER the "Plan:" header.
    spans = sequential_find(kept_task_text, tar_eq, plan_start)
    if not spans:
        # Not expected when the tags were parsed from this task; guard anyway.
        return _result(kept_task_text, [], [], False)

    # Tokenize only once the tags are known to be locatable.
    _, offsets = encode_with_offsets(kept_task_text)

    start_token_idx: List[int] = []
    end_token_idx: List[int] = []
    try:
        for span in spans:
            st, et = char_to_token_span(span, offsets)
            start_token_idx.append(st)
            end_token_idx.append(et)
    except ValueError:
        # A tag could not be mapped to tokens: discard partial indices.
        return _result(kept_task_text, [], [], False)

    return _result(kept_task_text, start_token_idx, end_token_idx, True)


In [6]:
# --- Block 5: Batch processing & save ---

def load_dataset(path: Path) -> List[dict]:
    """
    Read a UTF-8 JSON file and return its top-level list.

    Raises:
        ValueError: if the top-level JSON value is not a list.
    """
    data = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(data, list):
        return data
    raise ValueError("Input JSON must be a list of samples.")

def save_dataset(path: Path, data: List[dict]) -> None:
    """
    Write ``data`` to ``path`` as UTF-8 JSON (non-ASCII kept as-is, 4-space indent).

    Raises:
        TypeError: if ``data`` contains a value ``json`` cannot serialize; in
            that case ``path`` is left untouched.
    """
    # Serialize fully in memory before opening the file. Opening with "w"
    # truncates immediately, so streaming json.dump into it would leave a
    # partial, invalid file (and destroy any previous output) if a sample
    # turned out not to be serializable.
    payload = json.dumps(data, ensure_ascii=False, indent=4)
    with path.open("w", encoding="utf-8") as f:
        f.write(payload)

def recompute_indices(input_path: Path, output_path: Path) -> None:
    """
    Process every sample from ``input_path`` and save the results to ``output_path``.

    Prints how many samples were saved and, when non-zero, how many did not end
    up with an all-``True`` ``strict_match``.
    """
    processed = [process_sample(sample) for sample in load_dataset(input_path)]
    # A sample counts as unmatched when any strict_match flag is False.
    unmatched = sum(
        1 for item in processed if not all(item.get("strict_match", []))
    )

    save_dataset(output_path, processed)
    print(f"Saved {len(processed)} samples to: {output_path}")
    if unmatched:
        print(f"Note: {unmatched} sample(s) had no exact Task/Plan match or mapping issues (strict_match not all True).")

# Run the full pipeline: load INPUT_JSON_PATH, process every sample, and write
# the results to OUTPUT_JSON_PATH (an existing output file is overwritten).
recompute_indices(INPUT_JSON_PATH, OUTPUT_JSON_PATH)


Saved 247 samples to: virtualHome_gemma-3-4b-pt.json


In [7]:
# --- Block 6: Sanity check (optional) ---
# Re-processes the first input sample in memory and prints its indices plus the
# first 400 characters of the kept text; nothing is written to disk.

data = load_dataset(INPUT_JSON_PATH)
if data:
    first_out = process_sample(data[0])
    for label, key in (
        ("strict_match:", "strict_match"),
        ("start_token_idx:", "start_token_idx"),
        ("end_token_idx:", "end_token_idx"),
    ):
        print(label, first_out.get(key))
    print("\nKept text preview:\n", first_out.get("text")[:400])


strict_match: [True, True, True, True, True, True, True, True, True, True, True, True, True]
start_token_idx: [129, 133, 139, 143, 147, 150, 154, 158, 162, 165, 169, 173, 177]
end_token_idx: [133, 138, 143, 146, 150, 153, 158, 161, 165, 168, 173, 176, 180]

Kept text preview:
 Task 4:
I am in ['dining_room']. The objects I can manipulate are ['check', 'chair', 'mouse', 'cupboard', 'bedroom', 'home_office', 'food_food', 'freezer', 'keyboard', 'faucet', 'mail', 'television', 'novel', 'light', 'couch', 'desk', 'phone', 'computer', 'table', 'dining_room', 'sink', 'bathroom', 'bed'].
Goal:
Pick up phone
Hint:
When the phone rings, I pick up the call and give response to the 
