# This notebook demonstrates how to do inference with LogicLLaMA and parse the text FOL rule into a syntax tree that can be used elsewhere

# INIT

In [None]:
import torch
from functools import partial
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel, prepare_model_for_kbit_training
from utils import TranslationDataPreparer
from generatev2 import llama_batch_generate
import json
import time
import re
from tqdm import tqdm

In [2]:
base_model=r'Z:\WORK\LogicLLaMA\Llama-2-7b-hf' # TODO: fill in with the path to the llama-7b model
prompt_template_path=r'Z:\WORK\LogicLLaMA\data\prompt_templates'
load_in_8bit = True
max_output_len = 256

In [3]:
tokenizer = LlamaTokenizer.from_pretrained(base_model)
tokenizer.padding_side = "left"# Allow batched inference
tokenizer.add_special_tokens({
    "eos_token": "</s>",
    "bos_token": "<s>",
    "unk_token": '<unk>',
    "pad_token": '<unk>',
})  

generation_config = GenerationConfig(
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=1
)

llama_model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=load_in_8bit,
    torch_dtype=torch.float16,
    device_map='auto',
)
llama_model = prepare_model_for_kbit_training(llama_model)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:24<00:00, 12.15s/it]


In [None]:
def has_abcd_pattern(s: str) -> bool:
    """
    Returns True if `s` contains, in order, on separate lines:
      - a line starting with "A" 
      - then a line starting with "B"
      - then a line starting with "C"
      - then a line starting with "D"
    """
    # 
    # Explanation of the pattern:
    #  \nA[^\n]*     – a newline + “A” + anything up to the next newline
    #  \nB[^\n]*     – then newline + “B” + anything up to its newline
    #  \nC[^\n]*     – likewise for “C”
    #  \nD[^\n]*:    – then newline + “D” + anything, ending with a colon
    #
    pattern = r"\nA[^\n]*\nB[^\n]*\nC[^\n]*\nD[^\n]*"
    return bool(re.search(pattern, s))
def has_comma_and_pattern(s: str) -> bool:
    """
    Returns True if `s` contains the substring ", and" anywhere.
    """
    # simple regex for a comma + space + "and"
    pattern = r", and"
    return bool(re.search(pattern, s))
def split_question_options(s: str):
    # Capture groups:
    # 1: question (lazy up to the line before A)
    # 2: text after "A"
    # 3: text after "B"
    # 4: text after "C"
    # 5: text after "D" (colon is matched but not included)
    capture = (
        r"^(.*?)\r?\n"       # 1: question (anything up to first newline before A)
        r"A\s*([^\n]*)\r?\n"  # 2: A-line content
        r"B\s*([^\n]*)\r?\n"  # 3: B-line content
        r"C\s*([^\n]*)\r?\n"  # 4: C-line content
        r"D\s*([^\n]*)"      # 5: D-line content (colon out of capture)
    )
    m = re.search(capture, s, flags=re.DOTALL)
    if not m:
        raise ValueError("Failed to parse question/options despite matching the pattern")

    question = m.group(1).strip()
    opts = [m.group(i).strip() for i in range(2, 6)]
    return [question, opts[0], opts[1], opts[2], opts[3]]
def split_double_question(parts):
    return parts.split(", and")
def combine_double_question(parts):
    return "<q>".join(parts)
def combine_question_options(parts):
    """
    Given a list of exactly five strings:
      [question, optionA, optionB, optionC, optionD]
    returns a single string formatted as:

      question
      A optionA
      B optionB
      C optionC
      D optionD
    """
    q, a, b, c, d = parts
    return "\n".join([
        q.strip(),
        f"A {a.strip()}",
        f"B {b.strip()}",
        f"C {c.strip()}",
        f"D {d.strip()}:"
    ])


In [None]:
def retry_fill(fol_list, data_list, generate_fn):
    """
    Repeatedly call `generate_fn` on any positions where fol_list[i] is None,
    pulling the same NL inputs from data_list until no slots remain None.
    """
    none_idxs = [i for i, v in enumerate(fol_list) if v is None]
    while none_idxs:
        print(f"GOT NONE at positions: {none_idxs}")
        retry_input = [data_list[i] for i in none_idxs]
        _, retry_parts = generate_fn(input_str=retry_input)
        # retry_parts is a list of (inp_dict, fol_str)
        for orig_idx, (_, new_fol) in zip(none_idxs, retry_parts):
            fol_list[orig_idx] = new_fol
        none_idxs = [i for i, v in enumerate(fol_list) if v is None]
    return fol_list

# LogicLLaMA Translation

## INIT

In [None]:
peft_path='LogicLLaMA-7b-direct-translate-delta-v0.1'
model = PeftModel.from_pretrained(
    llama_model,
    peft_path,
    torch_dtype=torch.float16
)
model.to('cuda')

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lo

In [None]:
data_preparer = TranslationDataPreparer(
    prompt_template_path,
    tokenizer,
    False,
    256 # just a filler number
)

prepare_input = partial(
    data_preparer.prepare_input,
    **{"nl_key": "NL"},
    add_eos_token=False,
    eval_mode=True,
    return_tensors='pt'
)
batch_simple_generate = partial(
    llama_batch_generate,
    llama_model=model,
    data_preparer=data_preparer,
    max_new_tokens=max_output_len,
    generation_config=generation_config,
    prepare_input=prepare_input,
    return_tensors=False
)

Z:\WORK\LogicLLaMA\data\prompt_templates\continuous_correct_prompt_template.json
Z:\WORK\LogicLLaMA\data\prompt_templates\correct_prompt_template.json
Z:\WORK\LogicLLaMA\data\prompt_templates\paraphrase_prompt_template.json
Z:\WORK\LogicLLaMA\data\prompt_templates\translate_prompt_template.json


## Batch

In [None]:


# Set your starting index here
start_idx = 0 

# Load your data
with open(r'Z:\WORK\LogicLLaMA\demo.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

output_path = r'Z:\WORK\LogicLLaMA\demo_v2.json'

# Only loop from the specified index
for idx in tqdm(range(start_idx, len(data)), desc="Processing samples"):
    sample = data[idx]

    premises = sample.get("premises-NL", [])
    raw_questions = sample.get("questions", [])

    # 1) Flatten questions, record where we expanded MCQ vs “, and”
    flat_qs = []
    mcq_positions = []      # list of (start_idx, option_count)
    comma_and_positions = []  # list of start_idx

    for q in raw_questions:
        if has_abcd_pattern(q):
            parts = split_question_options(q)
            mcq_positions.append((len(flat_qs), len(parts)))
            flat_qs.extend(parts)
        elif has_comma_and_pattern(q):
            parts = split_double_question(q)
            comma_and_positions.append(len(flat_qs))
            flat_qs.extend(parts)
        else:
            flat_qs.append(q)

    # 2) Build data_list in one go
    data_list = (
        [{"NL": p} for p in premises] +
        [{"NL": q} for q in flat_qs]
    )
    sep_idx = len(premises)

    # 3) Generate and retry‐fill
    full_str, resp_parts = batch_simple_generate(input_str=data_list)
    llm_fol = [fol for _, fol in resp_parts]
    llm_fol = retry_fill(llm_fol, data_list, batch_simple_generate)

    # 4) Slice out premises vs question‐FOL
    sample['LLM-FOL'] = llm_fol[:sep_idx]
    ques_fol = llm_fol[sep_idx:]

    # 5) Combine back in **reverse** order so earlier splices don't shift later ones
    for start, count in sorted(mcq_positions, reverse=True):
        slice_ = ques_fol[start:start+count]
        merged = combine_question_options(slice_)
        ques_fol[start:start+count] = [merged]

    for start in sorted(comma_and_positions, reverse=True):
        slice_ = ques_fol[start:start+2]
        merged = combine_double_question(slice_)
        ques_fol[start:start+2] = [merged]

    sample['question-FOL'] = ques_fol

    # Save progress after each sample
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
