In [3]:
import os
from getpass import getpass

from huggingface_hub import login

# Never hardcode access tokens in a notebook: they end up in version control and in
# shared copies. Read the token from the HF_TOKEN environment variable and fall back
# to an interactive prompt. NOTE: the token previously written here has been exposed
# and should be revoked at https://huggingface.co/settings/tokens.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(token=hf_token)

In [None]:
# --- Notebook Block 1: Setup & configuration ---

from pathlib import Path
import json
import re
from typing import List, Tuple, Optional

from transformers import AutoTokenizer

# >> Change if needed <<
INPUT_JSON_PATH = Path("virtualHome_raw.json")   # your input file

# Make the tokenizer easy to change: any Hugging Face tokenizer id that has a *fast*
# implementation works (character offset mappings are required below),
# e.g. "google/gemma-3-4b-pt" or "google/gemma-3-4b-it".
TOKENIZER_NAME = "google/gemma-3-4b-pt"

# Derive the output file name from the tokenizer id so that switching tokenizers does
# not silently write results under another tokenizer's name.
# "google/gemma-3-4b-pt" -> "virtualHome_gemma-3-4b-pt.json" (same as the previous hardcoded value).
OUTPUT_JSON_PATH = Path(f"virtualHome_{TOKENIZER_NAME.rstrip('/').split('/')[-1]}.json")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# return_offsets_mapping (used by encode_with_offsets) is only supported by fast
# tokenizers; fail here with a clear message instead of deep inside the batch run.
if not tokenizer.is_fast:
    raise ValueError(f"{TOKENIZER_NAME!r} has no fast tokenizer; character offsets are unavailable.")

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [5]:
# --- Notebook Block 2: Utilities: matching & span mapping ---

def find_plan_start(text: str) -> int:
    """
    Locate the character offset where the plan body begins.

    Searches for the first 'Plan:' marker (starting at a word boundary) that is
    followed by optional whitespace ending in a newline. Returns the offset just
    past the last newline of that whitespace run. When no such marker exists,
    0 is returned so the caller scans the whole text.
    """
    plan_header = re.search(r"\bPlan:\s*\n", text)
    if plan_header is None:
        return 0
    return plan_header.end()

def sequential_find(text: str, needles: List[str], start_at: int) -> List[Tuple[int, int]]:
    """
    Locate every needle, in order, without backtracking.

    The search for each needle begins where the previous match ended (the first
    one begins at ``start_at``), so matches never overlap and appear in the same
    order as ``needles``.

    Returns:
        One half-open (char_start, char_end) span per needle, or an empty list
        as soon as any needle cannot be found.
    """
    found: List[Tuple[int, int]] = []
    cursor = start_at
    for needle in needles:
        hit = text.find(needle, cursor)
        if hit == -1:
            return []
        cursor = hit + len(needle)
        found.append((hit, cursor))
    return found

def char_to_token_span(text: str, match_span: Tuple[int, int], offsets: List[Tuple[int, int]]) -> Tuple[int, int]:
    """
    Map a half-open character span [start_char, end_char) to a half-open token
    range [start_tok, end_tok).

    The returned tokens cover *at least* the character span. Token boundaries
    need not line up with the match, so the first/last token may carry extra
    characters (e.g. a leading space merged into the first token: ' <chair>').

    Selection rules (offsets are assumed ordered by position, as produced by a
    tokenizer's offset mapping):
      - start_tok: the first token whose [offset_start, offset_end) contains
        start_char. If there is none, a zero-width token located exactly at
        start_char. If there is none of those either, the first token that
        overlaps the span at all. The last rule covers tokenizers that emit no
        token for some characters (e.g. whitespace), where start_char can fall
        in a gap between tokens.
      - end_tok: the index of the first token at or after start_tok whose
        offset_start >= end_char, or len(offsets) if there is no such token.

    Args:
        text: The tokenized text. It is not used in the computation and is kept
            for API compatibility.
        match_span: (start_char, end_char) of the match in ``text``.
        offsets: Per-token (start_char, end_char) pairs.

    Returns:
        (start_tok, end_tok), with end_tok exclusive.

    Raises:
        ValueError: If no token can be associated with the start of the span.
    """
    m_start, m_end = match_span

    def first_token(predicate) -> Optional[int]:
        # Index of the first token whose (offset_start, offset_end) satisfies predicate, else None.
        return next((ti for ti, (cs, ce) in enumerate(offsets) if predicate(cs, ce)), None)

    # 1) Normal case: the token containing the first character of the match.
    start_tok = first_token(lambda cs, ce: cs <= m_start < ce)
    if start_tok is None:
        # 2) A zero-width token sitting exactly at the match start.
        start_tok = first_token(lambda cs, ce: cs == m_start and cs == ce)
    if start_tok is None:
        # 3) The match starts in a gap no token covers: take the first token
        #    that overlaps the span. This only runs where the old code raised.
        start_tok = first_token(lambda cs, ce: ce > m_start and cs < m_end)
    if start_tok is None:
        raise ValueError("Could not map start char to a token index.")

    # Exclusive end: the first token (from start_tok on) that starts at or after
    # the match end. If there is none, the span runs through the last token.
    end_tok_exclusive = next(
        (ti for ti in range(start_tok, len(offsets)) if offsets[ti][0] >= m_end),
        len(offsets),
    )
    return start_tok, end_tok_exclusive

def encode_with_offsets(text: str, tok=None):
    """
    Tokenize ``text`` without special tokens and return ``(input_ids, offsets)``.

    Args:
        text: The string to tokenize.
        tok: Tokenizer to use. It defaults to the module-level ``tokenizer`` from
            the setup cell, so existing callers are unaffected. Hugging Face only
            provides offset mappings for "fast" tokenizers.

    Returns:
        input_ids: The token ids.
        offsets: A list of (start_char, end_char) tuples, one per token, as
            half-open indices into ``text``.

    NOTE(review): Because of ``add_special_tokens=False``, indices computed from
    these offsets do not account for BOS/EOS tokens. If downstream code
    tokenizes with the Hugging Face default (``add_special_tokens=True``), its
    token positions may be shifted. Confirm against the consumer.
    """
    active_tok = tokenizer if tok is None else tok
    enc = active_tok(text, return_offsets_mapping=True, add_special_tokens=False)
    # Normalize to tuples so callers get a uniform (start, end) pair type.
    offsets = [tuple(pair) for pair in enc["offset_mapping"]]
    return enc["input_ids"], offsets

In [6]:
# --- Notebook Block 3: Per-sample processing ---

def _empty_indices(sample: dict) -> dict:
    """Return a copy of ``sample`` with empty index lists (targets unmatched, unmappable or malformed)."""
    return {**sample, "start_token_idx": [], "end_token_idx": []}


def process_sample(sample: dict) -> dict:
    """
    Recompute ``start_token_idx`` / ``end_token_idx`` for one sample using the
    configured tokenizer.

    Each string in ``sample["tar_eq"]`` is located, in order, in the part of
    ``sample["text"]`` after the 'Plan:' header line. Its character span is then
    mapped to a half-open token range [start, end). Token indices refer to the
    tokenization produced by ``encode_with_offsets``, which adds no special
    tokens.

    Returns:
        A NEW dict (the input is not mutated) with all original keys plus the two
        index lists. Both lists are empty in any of these cases:
          - the text is not a string;
          - the targets are missing, not a list, or contain non-strings;
          - any target is not found;
          - any span cannot be mapped to tokens.
    """
    text = sample.get("text", "")
    targets = sample.get("tar_eq", [])

    # Report malformed samples the same way as unmatched ones, instead of letting a
    # TypeError (e.g. text=None or a numeric target) abort the whole batch.
    if not isinstance(text, str) or not isinstance(targets, list) or not targets:
        return _empty_indices(sample)
    if not all(isinstance(target, str) for target in targets):
        return _empty_indices(sample)

    # Only search the region after the 'Plan:' header, matching targets in order.
    plan_start_char = find_plan_start(text)
    match_spans = sequential_find(text, targets, plan_start_char)
    if not match_spans:
        return _empty_indices(sample)

    # Tokenize the entire text once; the spans use global character offsets.
    _, offsets = encode_with_offsets(text)

    start_tok_list: List[int] = []
    end_tok_list: List[int] = []
    for span in match_spans:
        try:
            st, et = char_to_token_span(text, span, offsets)
        except ValueError:
            # One unmappable span invalidates the whole sample.
            return _empty_indices(sample)
        start_tok_list.append(st)
        end_tok_list.append(et)

    # text and tar_eq are carried over unchanged; only the indices are replaced.
    return {
        **sample,
        "start_token_idx": start_tok_list,
        "end_token_idx": end_tok_list,
    }

In [7]:
# --- Notebook Block 4: Batch processing & save ---

def load_dataset(path: Path) -> List[dict]:
    """
    Read ``path`` as UTF-8 JSON and return the decoded top-level list of samples.

    Raises:
        ValueError: If the decoded JSON value is not a list.
    """
    raw_text = path.read_text(encoding="utf-8")
    samples = json.loads(raw_text)
    if isinstance(samples, list):
        return samples
    raise ValueError("Input JSON must be a list of samples.")

def save_dataset(path: Path, data: List[dict]) -> None:
    """Write ``data`` to ``path`` as 4-space-indented UTF-8 JSON, keeping non-ASCII characters unescaped."""
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=4, ensure_ascii=False)

def recompute_indices(input_path: Path, output_path: Path) -> None:
    """
    Recompute token indices for every sample in ``input_path`` and write the
    results to ``output_path``.

    Prints how many samples were saved. If any processed sample ended up with an
    empty ``start_token_idx`` or ``end_token_idx`` list, it also prints how many.
    """
    processed = [process_sample(sample) for sample in load_dataset(input_path)]
    unmatched = sum(
        1
        for sample in processed
        if not (sample.get("start_token_idx") and sample.get("end_token_idx"))
    )

    save_dataset(output_path, processed)
    print(f"Saved {len(processed)} samples to: {output_path}")
    if unmatched:
        print(f"Note: {unmatched} sample(s) had some targets not found or mapping issues; indices left empty.")

# Run the full pass: read INPUT_JSON_PATH, recompute indices, write OUTPUT_JSON_PATH.
recompute_indices(input_path=INPUT_JSON_PATH, output_path=OUTPUT_JSON_PATH)

Saved 247 samples to: virtualHome_gemma-3-4b-pt.json


In [8]:
# --- Notebook Block 5: Sanity check (optional) ---

# Re-run processing on the first raw sample. Each target is printed next to the
# text its token range covers, so the alignment can be checked by eye.
data = load_dataset(INPUT_JSON_PATH)  # closes the file, unlike json.load(open(...))
first = process_sample(data[0])
print("tar_eq:", first["tar_eq"])
print("start_token_idx:", first["start_token_idx"])
print("end_token_idx:", first["end_token_idx"])

# Show token strings for each span for visual verification
text = first["text"]
_, offsets = encode_with_offsets(text)


def tokens_for_span(st, et):
    """Concatenate the source characters of tokens[st:et], using the `text`/`offsets` globals above."""
    return "".join(text[s:e] for (s, e) in offsets[st:et])


for t, st, et in zip(first["tar_eq"], first["start_token_idx"], first["end_token_idx"]):
    span_text = tokens_for_span(st, et)
    print(f"{t:>12} -> tokens[{st}:{et}] == '{span_text}'")
    # The token range may include extra edge characters (e.g. a leading space),
    # but it must always contain the target itself.
    assert t in span_text, f"token span {st}:{et} does not cover {t!r}"

tar_eq: ['[WALK]', '<home_office>', '[WALK]', '<chair>', '[FIND]', '<chair>', '[SIT]', '<chair>', '[FIND]', '<phone>', '[GRAB]', '<phone>', '[END]']
start_token_idx: [117, 121, 127, 487, 491, 494, 498, 502, 506, 697, 701, 705, 709]
end_token_idx: [121, 126, 131, 490, 494, 497, 502, 505, 509, 700, 705, 708, 712]
      [WALK] -> tokens[117:121] == '[WALK]'
<home_office> -> tokens[121:126] == ' <home_office>'
      [WALK] -> tokens[127:131] == '[WALK]'
     <chair> -> tokens[487:490] == ' <chair>'
      [FIND] -> tokens[491:494] == '[FIND]'
     <chair> -> tokens[494:497] == ' <chair>'
       [SIT] -> tokens[498:502] == '[SIT]'
     <chair> -> tokens[502:505] == ' <chair>'
      [FIND] -> tokens[506:509] == '[FIND]'
     <phone> -> tokens[697:700] == ' <phone>'
      [GRAB] -> tokens[701:705] == '[GRAB]'
     <phone> -> tokens[705:708] == ' <phone>'
       [END] -> tokens[709:712] == '[END]'
