In [None]:
import re
import ast
import torch
import json
from collections import OrderedDict
from pathlib import Path
from transformers import AutoModelForCausalLM
from omegaconf import DictConfig
from typing import List, Dict, Any

from utils.data import get_tokenizer  

###############################################################################
#EVALUATE ONE EXAMPLE
###############################################################################

###############################################################################
# 1) Hardcode the data string
###############################################################################
# Single example parsed by iterative_inference_loop via parse_data_string:
# an "Init_state" mapping (automat index -> value) followed by a "Stack" of
# dicts. The first stack dict is used as the loop's termination target.
HARD_CODED_STRING = (
    "Init_state: [ 0 : 6 , 1 : 5 , 2 : 3 , 3 : 1 , 4 : 0 ] Stack: [  { 0 : 2 , 1 : 6 , 2 : 3 , 3 : 4 , 4 : 5 } ]"
)


###############################################################################
# 2) Define utility functions for parsing / re-building the string
###############################################################################
def _parse_int_pairs(content: str) -> Dict[int, int]:
    """Parse ``"k : v , k : v"`` text into an ``{int: int}`` dict.

    Entries are read left to right; blank entries (for instance the inside of
    an empty ``{  }`` group) are skipped.

    Raises:
        ValueError: If an entry is not of the form ``<int> : <int>``.
    """
    pairs: Dict[int, int] = {}
    for entry in content.split(","):
        entry = entry.strip()
        if not entry:
            continue
        key_text, sep, value_text = entry.partition(":")
        if not sep:
            raise ValueError(f"Malformed pair {entry!r}: expected '<int> : <int>'.")
        try:
            pairs[int(key_text)] = int(value_text)
        except ValueError as err:
            raise ValueError(
                f"Malformed pair {entry!r}: expected '<int> : <int>'."
            ) from err
    return pairs


def parse_data_string(data_str: str):
    """
    Parse a string of the form:
      Init_state: [ 0 : 6 , 1 : 5 , ... ] Stack: [  { 0 : 2 } , { 1 : 3 , 2 : 5 } ]
    into:
      init_state = {0: 6, 1: 5, ...}
      stack = [ {0: 2}, {1: 3, 2: 5}, ...]

    Trailing whitespace after the closing ``]`` of the stack is tolerated.

    Returns:
        A ``(init_state, stack_list)`` tuple.

    Raises:
        ValueError: If either section is missing or a pair is malformed.
    """
    # --- Locate both sections before parsing either, so a missing section is
    #     reported ahead of any malformed pair.
    init_match = re.search(r"Init_state:\s*\[(.*?)\]", data_str)
    if not init_match:
        raise ValueError("Could not find 'Init_state: [...]' in string.")

    # The stack section must run to the end of the string; `\s*` lets a
    # trailing space or newline through instead of failing the match.
    stack_match = re.search(r"Stack:\s*\[(.*?)\]\s*$", data_str)
    if not stack_match:
        raise ValueError("Could not find 'Stack: [...]' in string.")

    init_state = _parse_int_pairs(init_match.group(1))

    # Each `{ ... }` group becomes one dict; an empty group yields `{}`.
    stack_list = [
        _parse_int_pairs(group)
        for group in re.findall(r"\{(.*?)\}", stack_match.group(1))
    ]

    return init_state, stack_list


def data_to_string(init_state: Dict[int,int], stack_list: List[Dict[int,int]]) -> str:
    """
    Serialize init_state and stack_list into the textual form
      "Init_state: [ 0 : 6 , ... ] Stack: [  { 0 : 2 , ...} , { ... } ]"
    Entries are emitted in each mapping's iteration order.
    """
    def render_pairs(mapping) -> str:
        # "k : v" entries joined by " , "; an empty mapping renders as "".
        return " , ".join(f"{key} : {val}" for key, val in mapping.items())

    rendered_stack = " , ".join(
        "{ " + render_pairs(frame) + " }" for frame in stack_list
    )
    rendered_init = render_pairs(init_state)

    return "Init_state: [ " + rendered_init + " ] Stack: [  " + rendered_stack + " ]"


###############################################################################
# 3) Cleanup + Command Application
###############################################################################
def cleanup_stack(stack: List[Dict[int,int]]):
    """Drop leading empty dicts from the stack in place; later empties stay."""
    leading_empty = 0
    for frame in stack:
        if frame:
            break
        leading_empty += 1
    del stack[:leading_empty]

def apply_command(op: str, init_state: Dict[int,int], stack: List[Dict[int,int]],
                  max_level: int = 6):
    """
    Apply one command to init_state and stack, mutating both in place.

    Supported commands:
     - "X D" / "X U": decrement / increment init_state[X], clamped to
       the range 0..max_level.
     - "TF": move the first key/value of the top dict into a new
       single-entry dict pushed on top of the stack.
     - "RL": pop the top dict if every one of its pairs matches init_state.
     - "N {...}": push a dict literal with 1-2 entries on top of the stack.

    Invalid or inapplicable commands never raise: a warning is printed and the
    state is left untouched, so bad text from the caller cannot abort a run.

    Args:
        op: Command text, already stripped of surrounding whitespace.
        init_state: Mapping of automat index to current value.
        stack: List of dicts; index 0 is the top.
        max_level: Upper bound for "U" moves (defaults to 6).
    """

    # Movement commands: "X D" or "X U"
    if op.endswith((" D", " U")):
        parts = op.split()
        if len(parts) != 2:
            print(f"Error: Malformed movement command '{op}'")
            return
        index_text, direction = parts
        try:
            automat_index = int(index_text)
        except ValueError:
            # Non-numeric index: report it like any other malformed move.
            print(f"Error: Malformed movement command '{op}'")
            return
        if automat_index not in init_state:
            print(f"Warning: Unknown automat {automat_index} in command '{op}'.")
            return
        if direction == "D":
            if init_state[automat_index] > 0:
                init_state[automat_index] -= 1
            else:
                print(f"Warning: Automat {automat_index} is already at 0.")
        elif direction == "U":
            if init_state[automat_index] < max_level:
                init_state[automat_index] += 1
            else:
                print(f"Warning: Automat {automat_index} is already at max {max_level}.")
        return

    # TF (Take First)
    if op == "TF":
        cleanup_stack(stack)
        if not stack or not stack[0]:
            print("Warning: TF on empty stack or empty first dict.")
            return
        first_dict = stack[0]
        # Pop the first key-value; the remainder (possibly now empty) stays
        # in place and ends up directly below the new top entry.
        key = next(iter(first_dict))
        value = first_dict.pop(key)
        stack.insert(0, OrderedDict([(key, value)]))
        return

    # RL (Remove if matches init_state)
    if op == "RL":
        cleanup_stack(stack)
        if not stack:
            print("Warning: RL on empty stack.")
            return
        if all(init_state.get(k) == v for k, v in stack[0].items()):
            stack.pop(0)
        else:
            print("Warning: RL mismatch between stack[0] and init_state.")
        return

    # N {...} (Add new dict at the beginning of the stack)
    if op.startswith("N "):
        dict_str = op[2:].strip()  # everything after "N "
        try:
            new_pair = ast.literal_eval(dict_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            print(f"Warning: error parsing 'N' dictionary: {e}")
            return
        if not isinstance(new_pair, dict) or not (1 <= len(new_pair) <= 2):
            print("Warning: N operation requires 1-2 key-value pairs.")
            return
        cleanup_stack(stack)
        stack.insert(0, OrderedDict(new_pair.items()))
        return

    # If we get here, we have an unknown operation
    print(f"Warning: Unknown operation '{op}'.")


###############################################################################
# 4) Load Model & Tokenizer (uses get_tokenizer from your snippet)
###############################################################################
def load_model_and_tokenizer(config) -> (AutoModelForCausalLM, Any):
    """
    Load the tokenizer and the causal LM described by config.

    Reads config.tok_data (passed to get_tokenizer) and
    config.inference.notebook_modelpath (a local model directory; no download
    is attempted). The model is loaded in bfloat16 and moved to CUDA when a
    GPU is available, otherwise in float32 on CPU, and is put in eval mode.

    Returns:
        (model, tokenizer)
    """
    model_path = config.inference.notebook_modelpath
    print(f"Loading model from {model_path}...")

    tokenizer = get_tokenizer(config.tok_data)

    # Decide precision and placement from a single CUDA probe.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.bfloat16
    else:
        device, dtype = "cpu", torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        Path(model_path),
        torch_dtype=dtype,
        local_files_only=True,
    )
    model.to(device)
    model.eval()

    return model, tokenizer


###############################################################################
# 5) Generate a command from the current data
###############################################################################
def generate_command(model, tokenizer, prompt: str, max_length=128) -> str:
    """
    Greedily generate a continuation of prompt and return only the new text.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Callable tokenizer with ``decode`` and ``eos_token_id``.
        prompt: Text to condition on.
        max_length: Maximum number of newly generated tokens.

    Returns:
        The decoded continuation, stripped of surrounding whitespace. Special
        tokens are kept in the text, so the caller must remove any
        end-of-sequence marker itself.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Remove token_type_ids (if present) before generating
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]

    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=prompt_len + max_length,
            pad_token_id=tokenizer.eos_token_id,
            num_beams=1,  # greedy
            do_sample=False
        )

    # Decode only the tokens after the prompt. Cutting the decoded text at
    # len(prompt) is unsafe: decoding need not reproduce the prompt character
    # for character (whitespace normalisation, added special tokens).
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=False).strip()


###############################################################################
# 6) Inference Loop
###############################################################################
def iterative_inference_loop(model, tokenizer, config, max_iters):
    """
    Run the predict-and-apply loop on HARD_CODED_STRING.

    - Parse the HARD_CODED_STRING
    - Repeatedly prompt the model with the current data
    - Apply the command
    - Stop if init_state == original_stack[0] or we hit max_iters

    Args:
        model: Causal LM passed through to generate_command.
        tokenizer: Tokenizer passed through to generate_command.
        config: Accepted for interface compatibility; not read here.
        max_iters: Maximum number of commands to generate and apply.
    """

    # 6.1 Parse the original data
    original_init_state, original_stack = parse_data_string(HARD_CODED_STRING)
    if not original_stack:
        print("Error: Original stack is empty. Cannot continue.")
        return

    # The first dictionary in the original stack is our "target" for termination
    # NOTE(review): the check below is strict dict equality, so it can only
    # succeed when the target lists every key of init_state - confirm intended.
    target_dict = original_stack[0].copy()

    # 6.2 Create current copies
    current_init_state = original_init_state.copy()
    current_stack = [d.copy() for d in original_stack]

    # `bos_token` can exist but be None; fall back to "" so the prompt never
    # starts with the literal text "None".
    bos_token = getattr(tokenizer, "bos_token", None) or ""

    for i in range(max_iters):
        # Convert current data to string
        current_data_str = data_to_string(current_init_state, current_stack)
        prompt = f"{bos_token} {current_data_str} Command:"

        # 6.3 Generate the command
        command_pred = generate_command(model, tokenizer, prompt)
        # e.g. could be "0 D [EOS]" or "1 U"
        command_pred = command_pred.replace("[EOS]", "").strip()

        print(f"\nIteration={i}, current_data='{current_data_str}'")
        print(f"Model predicted command: '{command_pred}'")

        # 6.4 Apply the command; a malformed prediction must not abort the run
        try:
            apply_command(command_pred, current_init_state, current_stack)
        except (ValueError, KeyError) as err:
            print(f"Warning: could not apply command '{command_pred}': {err!r}")

        # 6.5 Check termination condition
        if current_init_state == target_dict:
            print("Termination: 'init_state' matches the first dictionary of the original stack!")
            break
    else:
        print(f"Reached {max_iters} iterations without matching target dict.")


In [None]:
from omegaconf import OmegaConf

def main(config_path: str = "config/config_base.yaml"):
    """
    Entry point for the single-example demo: read the YAML config at
    config_path, load the model and tokenizer from it, then run the
    iterative inference loop for at most 200 iterations.
    """
    config = OmegaConf.load(config_path)

    # load_model_and_tokenizer returns (model, tokenizer), which are exactly
    # the first two positional arguments of iterative_inference_loop.
    iterative_inference_loop(*load_model_and_tokenizer(config), config, max_iters=200)

# Run the single-example demo as soon as this cell executes
# (loads config and model, then starts the inference loop).
main()


In [None]:
import re
import ast
import torch
import json
from collections import OrderedDict
from pathlib import Path
from transformers import AutoModelForCausalLM
from typing import List, Dict, Any

from utils.data import get_tokenizer  

###############################################################################
#EVALUATE FROM JSON
###############################################################################

###############################################################################
# 1) Define utility functions for parsing / re-building the string
###############################################################################
def _parse_int_pairs(content: str) -> Dict[int, int]:
    """Parse ``"k : v , k : v"`` text into an ``{int: int}`` dict.

    Entries are read left to right; blank entries (for instance the inside of
    an empty ``{  }`` group) are skipped.

    Raises:
        ValueError: If an entry is not of the form ``<int> : <int>``.
    """
    pairs: Dict[int, int] = {}
    for entry in content.split(","):
        entry = entry.strip()
        if not entry:
            continue
        key_text, sep, value_text = entry.partition(":")
        if not sep:
            raise ValueError(f"Malformed pair {entry!r}: expected '<int> : <int>'.")
        try:
            pairs[int(key_text)] = int(value_text)
        except ValueError as err:
            raise ValueError(
                f"Malformed pair {entry!r}: expected '<int> : <int>'."
            ) from err
    return pairs


def parse_data_string(data_str: str):
    """
    Parse a string of the form:
      Init_state: [ 0 : 6 , 1 : 5 , ... ] Stack: [  { 0 : 2 } , { 1 : 3 , 2 : 5 } ]
    into:
      init_state = {0: 6, 1: 5, ...}
      stack = [ {0: 2}, {1: 3, 2: 5}, ...]

    Trailing whitespace after the closing ``]`` of the stack is tolerated.

    Returns:
        A ``(init_state, stack_list)`` tuple.

    Raises:
        ValueError: If either section is missing or a pair is malformed.
    """
    # --- Locate both sections before parsing either, so a missing section is
    #     reported ahead of any malformed pair.
    init_match = re.search(r"Init_state:\s*\[(.*?)\]", data_str)
    if not init_match:
        raise ValueError("Could not find 'Init_state: [...]' in string.")

    # The stack section must run to the end of the string; `\s*` lets a
    # trailing space or newline through instead of failing the match.
    stack_match = re.search(r"Stack:\s*\[(.*?)\]\s*$", data_str)
    if not stack_match:
        raise ValueError("Could not find 'Stack: [...]' in string.")

    init_state = _parse_int_pairs(init_match.group(1))

    # Each `{ ... }` group becomes one dict; an empty group yields `{}`.
    stack_list = [
        _parse_int_pairs(group)
        for group in re.findall(r"\{(.*?)\}", stack_match.group(1))
    ]

    return init_state, stack_list


def data_to_string(init_state: Dict[int,int], stack_list: List[Dict[int,int]]) -> str:
    """
    Serialize init_state and stack_list into the textual form
      "Init_state: [ 0 : 6 , ... ] Stack: [  { 0 : 2 , ...} , { ... } ]"
    Entries are emitted in each mapping's iteration order.
    """
    def render_pairs(mapping) -> str:
        # "k : v" entries joined by " , "; an empty mapping renders as "".
        return " , ".join(f"{key} : {val}" for key, val in mapping.items())

    rendered_stack = " , ".join(
        "{ " + render_pairs(frame) + " }" for frame in stack_list
    )
    rendered_init = render_pairs(init_state)

    return "Init_state: [ " + rendered_init + " ] Stack: [  " + rendered_stack + " ]"


###############################################################################
# 2) Cleanup + Command Application
###############################################################################
def cleanup_stack(stack: List[Dict[int,int]]):
    """Drop leading empty dicts from the stack in place; later empties stay."""
    leading_empty = 0
    for frame in stack:
        if frame:
            break
        leading_empty += 1
    del stack[:leading_empty]

def apply_command(op: str, init_state: Dict[int,int], stack: List[Dict[int,int]],
                  max_level: int = 6):
    """
    Apply one command to init_state and stack, mutating both in place.

    Supported commands:
     - "X D" / "X U": decrement / increment init_state[X], clamped to
       the range 0..max_level.
     - "TF": move the first key/value of the top dict into a new
       single-entry dict pushed on top of the stack.
     - "RL": pop the top dict if every one of its pairs matches init_state.
     - "N {...}": push a dict literal with 1-2 entries on top of the stack.

    Invalid or inapplicable commands never raise: a warning is printed and the
    state is left untouched, so bad text from the caller cannot abort a run.

    Args:
        op: Command text, already stripped of surrounding whitespace.
        init_state: Mapping of automat index to current value.
        stack: List of dicts; index 0 is the top.
        max_level: Upper bound for "U" moves (defaults to 6).
    """

    # Movement commands: "X D" or "X U"
    if op.endswith((" D", " U")):
        parts = op.split()
        if len(parts) != 2:
            print(f"Error: Malformed movement command '{op}'")
            return
        index_text, direction = parts
        try:
            automat_index = int(index_text)
        except ValueError:
            # Non-numeric index: report it like any other malformed move.
            print(f"Error: Malformed movement command '{op}'")
            return
        if automat_index not in init_state:
            print(f"Warning: Unknown automat {automat_index} in command '{op}'.")
            return
        if direction == "D":
            if init_state[automat_index] > 0:
                init_state[automat_index] -= 1
            else:
                print(f"Warning: Automat {automat_index} is already at 0.")
        elif direction == "U":
            if init_state[automat_index] < max_level:
                init_state[automat_index] += 1
            else:
                print(f"Warning: Automat {automat_index} is already at max {max_level}.")
        return

    # TF (Take First)
    if op == "TF":
        cleanup_stack(stack)
        if not stack or not stack[0]:
            print("Warning: TF on empty stack or empty first dict.")
            return
        first_dict = stack[0]
        # Pop the first key-value; the remainder (possibly now empty) stays
        # in place and ends up directly below the new top entry.
        key = next(iter(first_dict))
        value = first_dict.pop(key)
        stack.insert(0, OrderedDict([(key, value)]))
        return

    # RL (Remove if matches init_state)
    if op == "RL":
        cleanup_stack(stack)
        if not stack:
            print("Warning: RL on empty stack.")
            return
        if all(init_state.get(k) == v for k, v in stack[0].items()):
            stack.pop(0)
        else:
            print("Warning: RL mismatch between stack[0] and init_state.")
        return

    # N {...} (Add new dict at the beginning of the stack)
    if op.startswith("N "):
        dict_str = op[2:].strip()  # everything after "N "
        try:
            new_pair = ast.literal_eval(dict_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            print(f"Warning: error parsing 'N' dictionary: {e}")
            return
        if not isinstance(new_pair, dict) or not (1 <= len(new_pair) <= 2):
            print("Warning: N operation requires 1-2 key-value pairs.")
            return
        cleanup_stack(stack)
        stack.insert(0, OrderedDict(new_pair.items()))
        return

    # If we get here, we have an unknown operation
    print(f"Warning: Unknown operation '{op}'.")


###############################################################################
# 3) Load Model & Tokenizer (uses get_tokenizer from your snippet)
###############################################################################
def load_model_and_tokenizer(config) -> (AutoModelForCausalLM, Any):
    """
    Load the tokenizer and the causal LM described by config.

    Reads config.tok_data (passed to get_tokenizer) and
    config.inference.notebook_modelpath (a local model directory; no download
    is attempted). The model is loaded in bfloat16 and moved to CUDA when a
    GPU is available, otherwise in float32 on CPU, and is put in eval mode.

    Returns:
        (model, tokenizer)
    """
    model_path = config.inference.notebook_modelpath
    print(f"Loading model from {model_path}...")

    tokenizer = get_tokenizer(config.tok_data)

    # Decide precision and placement from a single CUDA probe.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.bfloat16
    else:
        device, dtype = "cpu", torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        Path(model_path),
        torch_dtype=dtype,
        local_files_only=True,
    )
    model.to(device)
    model.eval()

    return model, tokenizer


###############################################################################
# 4) Generate a command from the current data
###############################################################################
def generate_command(model, tokenizer, prompt: str, max_length=128) -> str:
    """
    Greedily generate a continuation of prompt and return only the new text.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Callable tokenizer with ``decode`` and ``eos_token_id``.
        prompt: Text to condition on.
        max_length: Maximum number of newly generated tokens.

    Returns:
        The decoded continuation, stripped of surrounding whitespace. Special
        tokens are kept in the text, so the caller must remove any
        end-of-sequence marker itself.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Remove token_type_ids (if present) before generating
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]

    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=prompt_len + max_length,
            pad_token_id=tokenizer.eos_token_id,
            num_beams=1,  # greedy
            do_sample=False
        )

    # Decode only the tokens after the prompt. Cutting the decoded text at
    # len(prompt) is unsafe: decoding need not reproduce the prompt character
    # for character (whitespace normalisation, added special tokens).
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=False).strip()


###############################################################################
# 5) Inference Loop evaluation
###############################################################################
def evaluate_example(data_str: str,
                     model: AutoModelForCausalLM,
                     tokenizer: Any,
                     max_iters: int = 200) -> bool:
    """
    Run the predict-and-apply loop on one example.

    Args:
        data_str: Example text in the format accepted by parse_data_string.
        model: Causal LM passed through to generate_command.
        tokenizer: Tokenizer passed through to generate_command.
        max_iters: Maximum number of commands to generate and apply.

    Returns:
        True if init_state becomes equal to the example's first stack dict
        within max_iters; False otherwise, including when the example cannot
        be parsed or has an empty stack.
    """
    try:
        original_init_state, original_stack = parse_data_string(data_str)
    except Exception as e:
        print(f"[ERROR] parsing example: {e}")
        return False

    if not original_stack:
        print("[ERROR] example has empty original stack")
        return False

    # NOTE(review): success is strict dict equality with the first stack dict,
    # so it requires that dict to list every key of init_state - confirm.
    target_dict = original_stack[0].copy()
    current_init_state = original_init_state.copy()
    current_stack = [d.copy() for d in original_stack]

    # `bos_token` can exist but be None; fall back to "" so the prompt never
    # starts with the literal text "None".
    bos = getattr(tokenizer, "bos_token", None) or ""

    for _ in range(max_iters):
        current_data_str = data_to_string(current_init_state, current_stack)
        prompt = f"{bos} {current_data_str} Command:"
        command_pred = generate_command(model, tokenizer, prompt)
        command_pred = command_pred.replace("[EOS]", "").strip()
        # A malformed prediction counts against this example only; it must
        # not abort the evaluation of the remaining examples.
        try:
            apply_command(command_pred, current_init_state, current_stack)
        except (ValueError, KeyError) as err:
            print(f"[WARN] could not apply command '{command_pred}': {err!r}")
        if current_init_state == target_dict:
            return True

    return False


In [None]:
from omegaconf import OmegaConf
import os    

def main(config_path: str = "config/config_base.yaml",
         json_file: str = "data/test_only_first.json",
         max_iters: int = 130):
    """
    Evaluate every example in a JSON file and print a pass/fail summary.

    Args:
        config_path: YAML config used to load the model and tokenizer.
        json_file: Path to a JSON array of objects; each object's "text"
            field (empty string if absent) is handed to evaluate_example.
        max_iters: Per-example iteration budget given to evaluate_example.
    """
    # --- Load config, model, tokenizer ---
    cfg = OmegaConf.load(config_path)
    model, tokenizer = load_model_and_tokenizer(cfg)

    # --- Load JSON of examples ---
    with open(json_file, "r", encoding="utf-8") as jf:
        examples = json.load(jf)

    total = len(examples)
    passed, failed = 0, 0

    # --- Iterate through each example and evaluate ---
    for idx, ex in enumerate(examples, start=1):
        result = evaluate_example(ex.get("text", ""), model, tokenizer, max_iters=max_iters)
        if result:
            passed += 1
        else:
            failed += 1
        print(f"Example {idx}/{total}: {'PASSED' if result else 'FAILED'}")

    # --- Summary ---
    print(f"\nTotal examples: {total}")
    print(f"✔ Passed: {passed}")
    print(f"✘ Failed: {failed}")

# Run the JSON evaluation only when this code executes as the main module
# (which includes running it as a notebook cell), not when it is imported.
if __name__ == "__main__":
    main()