In [2]:
import os
os.chdir("/data/npl/ICEK/News/SymbolicResoning/LogicLLaMA")

In [3]:
import torch
from functools import partial
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
from peft import PeftModel, prepare_model_for_kbit_training
import json
import time

In [4]:
base_model='/data/npl/ViInfographicCaps/Contest/demo_contest/xai/Llama-2-7b-chat-hf' # TODO: fill in with the path to the llama-7b model
prompt_template_path='/data/npl/ICEK/News/SymbolicResoning/LogicLLaMA/data/prompt_templates'
load_in_8bit = True
max_output_len = 128

In [5]:
import huggingface_hub
huggingface_hub.login(token="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")

In [6]:
# Load tokenizer
tokenizer = LlamaTokenizer.from_pretrained(base_model)
tokenizer.add_special_tokens({
    "eos_token": "</s>",
    "bos_token": "<s>",
    "unk_token": "<unk>",
    "pad_token": "<unk>",
})
tokenizer.padding_side = "left"

# ✅ Config tối ưu cho A100
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # nf4: tốt hơn fp4
    bnb_4bit_use_double_quant=True,         # giúp nén tốt hơn
    bnb_4bit_compute_dtype=torch.bfloat16   # A100 hỗ trợ tốt
)

# ✅ Load model dùng A100
llama_model = LlamaForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,           # ưu tiên bfloat16 thay vì float16
    device_map="auto",                    # tự chia lên multi-GPU nếu có
    low_cpu_mem_usage=True,
    trust_remote_code=True,               # nếu dùng repo custom
)

# ✅ Chuẩn bị model để huấn luyện 4bit (nếu cần PEFT)
llama_model = prepare_model_for_kbit_training(llama_model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
peft_path='yuan-yang/LogicLLaMA-7b-direct-translate-delta-v0'

model = PeftModel.from_pretrained(
    llama_model,
    peft_path,
    torch_dtype=torch.float16
)
model



PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora

In [8]:
merged_model = model.merge_and_unload()
merged_model.save_pretrained("clean_llm")
tokenizer.save_pretrained("clean_llm")



('clean_llm/tokenizer_config.json',
 'clean_llm/special_tokens_map.json',
 'clean_llm/tokenizer.model',
 'clean_llm/added_tokens.json')

In [9]:
from transformers import LlamaForCausalLM, LlamaTokenizer

merged_model = LlamaForCausalLM.from_pretrained(
    "clean_llm",
    torch_dtype=torch.float16,
    device_map="auto"
)

merged_tokenizer = LlamaTokenizer.from_pretrained("clean_llm")

In [10]:
import json
import re

def read_json(path):
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

dataset = read_json("/data/npl/ICEK/News/SymbolicResoning/data/hard_samples_v3.json")

In [11]:
dataset[0].get("LLM-FOL", [])

['∀x (Student(x) ∧ (Year1(x) ∨ Year2(x)) ∧ RemainingCreditsGreaterThan(x, 10) → CanWithdrawUpTo3Courses(x)) ∧ (Year3Plus(x) ∧ RemainingCreditsGreaterThan(x, 8) → CanWithdrawUpTo3Courses(x)) ∧ DeductsCredits(x, 0.5)\n\nExplanation:\nThe NL statement is translated into FOL as follows:\n\n∀x (Student(x) ∧ (Year1(x) ∨ Year2(x)) ∧ RemainingCreditsGreaterThan(x, 10) → CanWithdrawUpTo3Courses(x))\n\nThis rule states that for any student x, if they are in Year 1 or Year 2 and have remaining credits greater than 10, they can withdraw up to 3 courses.\n\nSimilarly, for any student x in Year 3 or later, and have remaining cred',
 '∀x (Course(x) → (HasCredits(x, 3) ∨ HasCredits(x, 4) ∨ HasCredits(x, 5))) ∧ ∀y (WithdrawnCourse(y) → ContributesZeroCreditsToGPA(y))',
 '∀x (Student(x) → (RegisteredCredits(x, 12) ∧ (RegisteredCredits(x, 13) ∨ RegisteredCredits(x, 14)) ∧ (RegisteredCredits(x, 15) ∨ RegisteredCredits(x, 16)) ∧ ¬Withdrawals(x))) ∧ ∀y (Withdrawal(y) → ReducesSemesterCredits(y)))',
 '∀x (Ac

In [12]:
def extract_PredicatesIndividuals(sample: dict) -> list:
    premises_nl = sample.get("premises-NL", [])
    questions_nl = sample.get("questions", [])
    premises_fol = sample.get("LLM-FOL", [])
    questions_fol = sample.get("question-FOL", [])

    all_nl = premises_nl + questions_nl
    # Find all predicate names and individuals (constants)
    all_statements = premises_fol + questions_fol

    predicates_entities = []
    
    for stmt in all_statements:
        # Find predicate names
        temp = []
        pred_matches = re.findall(r'([a-zA-Z_]+)\(([^)]+)\)', stmt)
        for pred_name, args in pred_matches:
            temp.append(f"{pred_name}({args})")
        predicates_entities.append(temp)

    return all_nl, list(predicates_entities)

In [13]:
for sample in dataset:
    all_nl, predicates_entities = extract_PredicatesIndividuals(sample)
    print(f"Natural Language: {all_nl}")
    print(f"Predicates Entities: {predicates_entities}")
    break

Natural Language: ['A student can withdraw up to 3 courses per academic year if remaining credits ≥ 10 (Year 1, 2) or ≥ 8 (Year 3+); each withdrawal deducts 0.5 credits from total accumulated credits.', 'Courses have 3, 4, or 5 credits; withdrawn courses contribute 0 credits to semester GPA.', 'Students must register for 12–18 credits per semester; withdrawals reduce semester credits.', 'No regulation limits total withdrawals, but max 3 per year.', 'A student with < 8 accumulated credits cannot withdraw courses.', 'A student (Year 2) has 45 credits, withdrew 2 courses in Year 1 (penalty 1 credit), including C1 (4 credits, withdrawn), C2 (3 credits, withdrawn).', 'In Year 2, semester S1, the student registered for 15 credits, withdrew C3 (5 credits), C4 (3 credits), attempted C5 (4 credits, passed).', 'What is the student’s total accumulated credits after Year 2, semester S1?', 'How many credits were withdrawn in Year 2, semester S1, and can the student withdraw another course in semest

In [None]:
example = """<s>[INST]
### Task: Define the meaning of each FOL predicate individually based on the information in the natural language (NL) statement.

You are given:
- One Natural Language (NL) statement describing a situation, domain, or context.
- A list of First-Order Logic (FOL) predicates extracted from that context.

Please follow these instructions carefully:
1. Interpret the NL statement: Understand the general context and concepts described.
2. Define each FOL predicate: For each predicate in the given list:
  - Explain its meaning explicitly in natural language.
  - Focus only on the specific predicate, using the context provided by the NL statement.
  - Do not invent additional information not mentioned or implied by the NL statement.
3. Use the required output format: For each predicate, output its definition in the following structure:
Predicate ::: Natural Language Description

### Input:
- NL Statement: {input_statement}
- FOL Predicates: {input_predicates}
[/INST]
Output: </s>"""

In [15]:
from transformers import pipeline
pipe = pipeline(task="text-generation", model=merged_model, tokenizer=tokenizer, max_length=2048)
all_nl, predicates_entities = extract_PredicatesIndividuals(dataset[0])
result = pipe(example)
print(result[0]['generated_text'])

Device set to use cuda:0
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


<s>[INST]
### Task: For each statement, define the meaning of its associated FOL predicates.

Each block contains:
- A natural language statement.
- A list of predicates derived from that statement.

Write each predicate definition in the format:
Predicate ::: Natural Language Description

### Input:

1. Statement: A student can withdraw up to 3 courses per academic year if remaining credits ≥ 10 (Year 1, 2) or ≥ 8 (Year 3+); each withdrawal deducts 0.5 credits from total accumulated credits.
   Predicates: 
   - Student(x)
   - RemainingCreditsGreaterThan(x, 10)
   - Courses(x)
   - Plus(x)
   - RemainingCreditsGreaterThan(x, 8)
   - DeductsCredits(x, 0.5)

2. Statement: Courses have 3, 4, or 5 credits; withdrawn courses contribute 0 credits to semester GPA.
   Predicates:
   - Course(x)
   - HasCredits(x, 3)
   - HasCredits(x, 4)
   - HasCredits(x, 5)
   - WithdrawnCourse(y)
   - ContributesZeroCreditsToGPA(y)


[/INST]
Output: </s>
1. Statement: A student can withdraw up to 3 course

In [17]:
def extract_predicate_definitions_grouped(llm_output: str) -> list:
    lines = llm_output.splitlines()
    predicate_definitions = []
    current_block = []

    for line in lines:
        line = line.strip()

        # Bắt đầu block mới khi gặp dòng bắt đầu bằng số (ví dụ "1. Statement: ...")
        if re.match(r'^\d+\.\s*Statement:', line):
            # Nếu đã có block trước đó, thêm vào danh sách
            if current_block:
                combined = ''.join(current_block)
                predicate_definitions.append(combined)
                current_block = []
        
        # Dòng định nghĩa predicate
        elif line.startswith("*") and "::: " in line:
            current_block.append(line)

    # Đừng quên block cuối
    if current_block:
        combined = ''.join(current_block)
        predicate_definitions.append(combined)

    return predicate_definitions
llm_output = """
1. Statement: ...

Predicate Definitions:

* Student(x) ::: x is a student.
* RemainingCreditsGreaterThan(x, 10) ::: x has remaining credits greater than 10.
* Courses(x) ::: x is a student.

2. Statement: ...

Predicate Definitions:

* Course(x) ::: x is a course.
* HasCredits(x, 3) ::: x is a course and n is a positive integer.
"""

output = extract_predicate_definitions_grouped(llm_output)
print(output)


['* Student(x) ::: x is a student.* RemainingCreditsGreaterThan(x, 10) ::: x has remaining credits greater than 10.* Courses(x) ::: x is a student.', '* Course(x) ::: x is a course.* HasCredits(x, 3) ::: x is a course and n is a positive integer.']


In [None]:
import re
import json
from transformers import pipeline
from tqdm import tqdm  # ✅ Thêm tqdm vào

# Khởi tạo pipeline từ mô hình bạn đã load sẵn
pipe = pipeline(task="text-generation",
                 model=merged_model,
                 tokenizer=tokenizer,
                 temperature=0.0,
                 do_sample=False,
                #  return_full_text=False,
                   max_length=2048)

# --- HÀM TIỆN ÍCH ---

def extract_predicates(sample: dict):
    premises_nl = sample.get("premises-NL", [])
    questions_nl = sample.get("question", [])
    LLM_fol = sample.get("LLM-FOL", [])
    questions_fol = sample.get("question-FOL", [])

    all_nl = premises_nl + questions_nl
    all_statements = LLM_fol + questions_fol

    predicates_per_statement = []

    for stmt in all_statements:
        temp = []
        pred_matches = re.findall(r'([a-zA-Z_]+)\(([^)]+)\)', stmt)
        for pred_name, args in pred_matches:
            temp.append(f"{pred_name}({args})")
        predicates_per_statement.append(temp)

    return all_nl, predicates_per_statement

def build_prompt(nl_statements, predicate_lists):
    blocks = []
    for i, (nl, preds) in enumerate(zip(nl_statements, predicate_lists), 1):
        if not preds: continue
        pred_block = '\n'.join([f"- {p}" for p in preds])
        blocks.append(f"{i}. Statement: {nl}\n   Predicates:\n{pred_block}")
    
    prompt = "<s>[INST]\n### Task: For each statement, define the meaning of its associated FOL predicates.\n\nEach block contains:\n- A natural language statement.\n- A list of predicates derived from that statement.\n\nWrite each predicate definition in the format:\nPredicate ::: Natural Language Description\n\n### Input:\n\n" + "\n\n".join(blocks) + "\n\n[/INST]\nOutput: </s>"
    return prompt

def extract_predicate_definitions_grouped(llm_output: str) -> list:
    lines = llm_output.splitlines()
    predicate_definitions = []
    current_block = []

    for line in lines:  
        line = line.strip()

        if re.match(r'^\d+\.\s*Statement:', line):
            if current_block:
                predicate_definitions.append(''.join(current_block))
                current_block = []
        elif line.startswith("*") and "::: " in line:
            current_block.append(line)

    if current_block:
        predicate_definitions.append(''.join(current_block))

    return predicate_definitions

# --- MAIN PROCESS ---

for sample in tqdm(dataset, desc="🚀 Generating predicate definitions"):
    nl_statements, predicate_lists = extract_predicates(sample)
    
    if not any(predicate_lists):
        sample["predicate-definition"] = []
        continue

    prompt = build_prompt(nl_statements, predicate_lists)
    llm_output = pipe(prompt)[0]["generated_text"]
    predicate_defs = extract_predicate_definitions_grouped(llm_output)

    sample["predicate-definition"] = predicate_defs
    print(predicate_defs,'\n')

# Ghi ra file mới
with open("/data/npl/ICEK/News/SymbolicResoning/data/updated_hard_samples_v3.json", "w", encoding="utf-8") as f:
    json.dump(dataset, f, indent=2, ensure_ascii=False)

print("✅ Xử lý hoàn tất! File đã được lưu vào dataset_with_predicate_definitions.json")

Device set to use cuda:0
🚀 Generating predicate definitions:   6%|▋         | 1/16 [02:59<44:48, 179.24s/it]

[] 



🚀 Generating predicate definitions:  12%|█▎        | 2/16 [05:24<37:08, 159.18s/it]

[] 



🚀 Generating predicate definitions:  19%|█▉        | 3/16 [08:21<36:17, 167.47s/it]

[] 



🚀 Generating predicate definitions:  25%|██▌       | 4/16 [11:16<34:05, 170.44s/it]

[] 



🚀 Generating predicate definitions:  31%|███▏      | 5/16 [13:36<29:14, 159.48s/it]

[] 



🚀 Generating predicate definitions:  38%|███▊      | 6/16 [16:32<27:28, 164.85s/it]

['* ChoosingMajor(x) ::: Natural Language Description: "x is a student who has chosen a major."* Credits(y) ::: Natural Language Description: "y is a number of credits."* GreaterThanOrEqualTo(y, 30) ::: Natural Language Description: "y is greater than or equal to 30."* GPA(z) ::: Natural Language Description: "z is a number representing a student\'s GPA."* GreaterThanOrEqualTo(z, 2.5) ::: Natural Language Description: "z is greater than or equal to 2.5."* Violations(w) ::: Natural Language Description: "w is a number of violations."* LessThan(w, 2) ::: Natural Language Description: "w is less than 2."* SecondYearStatus(x) ::: Natural Language Description: "x is in their second year of study."', '* FailedCourse(x) ::: Natural Language Description: "x is a course that the student failed."* DeductsCredits(x, 1) ::: Natural Language Description: "x deducts 1 credit."* Violation(y) ::: Natural Language Description: "y is a violation."* ReducesGPA(y, 0.1) ::: Natural Language Description: "y

🚀 Generating predicate definitions:  44%|████▍     | 7/16 [18:31<22:31, 150.17s/it]

[] 



🚀 Generating predicate definitions:  50%|█████     | 8/16 [20:07<17:43, 132.89s/it]

[] 



🚀 Generating predicate definitions:  56%|█████▋    | 9/16 [21:55<14:35, 125.08s/it]

[] 



🚀 Generating predicate definitions:  62%|██████▎   | 10/16 [24:08<12:45, 127.59s/it]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


[] 



🚀 Generating predicate definitions:  69%|██████▉   | 11/16 [27:09<11:58, 143.73s/it]

[] 



🚀 Generating predicate definitions:  75%|███████▌  | 12/16 [28:45<08:36, 129.22s/it]

[] 



🚀 Generating predicate definitions:  81%|████████▏ | 13/16 [31:42<07:11, 143.83s/it]

[] 



🚀 Generating predicate definitions:  88%|████████▊ | 14/16 [33:31<04:26, 133.30s/it]

[] 



🚀 Generating predicate definitions:  94%|█████████▍| 15/16 [35:27<02:07, 127.95s/it]

[] 



🚀 Generating predicate definitions: 100%|██████████| 16/16 [38:39<00:00, 144.98s/it]

[] 

✅ Xử lý hoàn tất! File đã được lưu vào dataset_with_predicate_definitions.json





In [17]:
sample = {
    "premises-NL": [
      "If a Python code is well-tested, then the project is optimized.",
      "If a Python code does not follow PEP 8 standards, then it is not well-tested.",
      "All Python projects are easy to maintain.",
      "All Python code is well-tested.",
      "If a Python code follows PEP 8 standards, then it is easy to maintain.",
      "If a Python code is well-tested, then it follows PEP 8 standards.",
      "If a Python project is well-structured, then it is optimized.",
      "If a Python project is easy to maintain, then it is well-tested.",
      "If a Python project is optimized, then it has clean and readable code.",
      "All Python projects are well-structured.",
      "All Python projects have clean and readable code.",
      "There exists at least one Python project that follows best practices.",
      "There exists at least one Python project that is optimized.",
      "If a Python project is not well-structured, then it does not follow PEP 8 standards."
    ],
    "premises-FOL": [
      "∀x (WT(x) → O(x))",
      "∀x (¬PEP8(x) → ¬WT(x))",
      "∀x (EM(x))",
      "∀x (WT(x))",
      "∀x (PEP8(x) → EM(x))",
      "∀x (WT(x) → PEP8(x))",
      "∀x (WS(x) → O(x))",
      "∀x (EM(x) → WT(x))",
      "∀x (O(x) -> CR(x))",
      "∀x (WS(x))",
      "∀x (CR(x))",
      "∃x (BP(x))",
      "∃x (O(x))",
      "∀x (¬WS(x) → ¬PEP8(x))"
    ],
    "questions": [
      "Which conclusion follows with the fewest premises?\nA. If a Python project is not optimized, then it is not well-tested\nB. If all Python projects are optimized, then all Python projects are well-structured\nC. If a Python project is well-tested, then it must be clean and readable\nD. If a Python project is not optimized, then it does not follow PEP 8 standards",
      "Does it follow that if all Python projects are well-structured, then all Python projects are optimized, according to the premises?"
    ],
    "answers": [
      "A",
      "Yes"
    ],
    "idx": [
      [
        1
      ],
      [
        7,
        10
      ]
    ],
    "explanation": [
      "Premise 1 states that if a Python project is well-tested, it is optimized. By logical contraposition, if a project is not optimized, it is not well-tested, supporting option A with the fewest premises. Option B is false because optimization does not imply well-structured projects. Option C follows from premises 4, 1, and 9 but requires more steps. Option D follows from premises 1 and 6 but is less direct than A.",
      "Premise 10 confirms all Python projects are well-structured. Premise 7 states that well-structured projects are optimized, implying all projects are optimized, so the statement that well-structured projects imply optimized projects holds."
    ],
    "LLM-FOL": [
      "∀x (PythonCode(x) ∧ WellTested(x) → OptimizedProject(x))",
      "∀x (PythonCode(x) ∧ ¬PEP8Standards(x) → ¬WellTested(x))",
      "∀x (PythonProject(x) → EasyToMaintain(x))",
      "∀x (PythonCode(x) → WellTested(x))",
      "∀x (PythonCode(x) ∧ FollowsPEP8Standards(x) → EasyToMaintain(x))",
      "∀x (WellTestedPythonCode(x) → FollowsPEP8Standards(x))",
      "∀x (PythonProject(x) ∧ WellStructured(x) → Optimized(x))",
      "∀x (PythonProject(x) ∧ EasyToMaintain(x) → WellTested(x))",
      "∀x (PythonProject(x) ∧ Optimized(x) → (CleanCode(x) ∧ ReadableCode(x)))",
      "∀x (PythonProject(x) → WellStructured(x)) ∧ ¬∃y (Person(y) ∧ Perfect(y))",
      "∀x (PythonProject(x) → (CleanCode(x) ∧ ReadableCode(x)))",
      "∃x (PythonProject(x) ∧ FollowsBestPractices(x))",
      "∃x (PythonProject(x) ∧ Optimized(x))",
      "∀x (PythonProject(x) ∧ ¬WellStructured(x) → ¬FollowsPEP8Standards(x))"
    ],
    "question-FOL": [
      "∀x (PythonProject(x) ∧ ¬Optimized(x) → ¬WellTested(x))",
      "∀x (PythonProject(x) ∧ WellStructured(x) → Optimized(x))"
    ]
  }

In [18]:
nl_statements, predicate_lists = extract_predicates(sample)

prompt = build_prompt(nl_statements, predicate_lists)
llm_output = pipe(prompt)[0]["generated_text"]
predicate_defs = extract_predicate_definitions_grouped(llm_output)

# sample["predicate-definition"] = predicate_defs
print(predicate_defs,'\n')

[] 



In [18]:
llm_output

'<s>[INST]\n### Task: For each statement, define the meaning of its associated FOL predicates.\n\nEach block contains:\n- A natural language statement.\n- A list of predicates derived from that statement.\n\nWrite each predicate definition in the format:\nPredicate ::: Natural Language Description\n\n### Input:\n\n1. Statement: A student can withdraw up to 3 courses per academic year if remaining credits ≥ 10 (Year 1, 2) or ≥ 8 (Year 3+); each withdrawal deducts 0.5 credits from total accumulated credits.\n   Predicates:\n- Student(x)\n- RemainingCreditsGreaterThan(x, 10)\n- Courses(x)\n- Plus(x)\n- RemainingCreditsGreaterThan(x, 8)\n- Courses(x)\n- DeductsCredits(x, 0.5)\n- Student(x)\n- RemainingCreditsGreaterThan(x, 10)\n- Courses(x)\n\n2. Statement: Courses have 3, 4, or 5 credits; withdrawn courses contribute 0 credits to semester GPA.\n   Predicates:\n- Course(x)\n- HasCredits(x, 3)\n- HasCredits(x, 4)\n- HasCredits(x, 5)\n- WithdrawnCourse(y)\n- ContributesZeroCreditsToGPA(y)\n\