In [1]:
import sys
import os
import random
from typing import Dict, Tuple, List, Set, Union, Type, Literal
from itertools import product
from dataclasses import dataclass
from collections import Counter
import re
from pathlib import Path
from collections import defaultdict

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from pathlib import Path

# --- Make the project packages importable ---
# The notebook runs from ICTCS_notebooks, two directories below the project root
# (ICTCS_notebooks → theorem_prover_core → project root). Prepend that root to
# sys.path once so the project-level imports below resolve.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if project_root not in sys.path:
    sys.path[:0] = [project_root]

from logic_utils import Normalizer, Metavariable, Normalizer, CustomTokenizer, assign_embedding_indices, print_tree_with_embeddings
from data_setup import (generate_normalized_dataset, add_new_tautologies_to_dataset, parse_dimacs_files, prepare_formula_dataset, 
                        FormulaDataset, FormulaTreeNode, prepare_balanced_tree_dataloaders, prepare_tree_dataloaders)
from train_utils import (set_seeds, compute_vocab_size, train, save_results, save_model, train_tree, eval_model, analyze_model_errors,
                         evaluate_confusion_matrix, count_parameters)
from models import AsymmetricFocalLoss, TreeLSTMClassifierV1, TreeLSTMClassifierV2

from theorem_prover_core.formula import (Formula, Letter, Falsity, Conjunction, Disjunction, Implication,
                                         Negation, BinaryConnectiveFormula, UnaryConnectiveFormula, bottom)
from theorem_prover_core.sequent import Sequent 
from theorem_prover_core.proofgraph import ProofGraph
from theorem_prover_core.mcts import MCTS
from models import TreeLSTMClassifierV2, TreeLSTMClassifierV3
from logic_utils import FormulaTreeNode, CustomTokenizer, assign_embedding_indices, parse_formula_string

---
#### **Some Tests on Specific Sequents for MCTS**

In [2]:
# --- MCTS Test - Trivial Sequent ---
# The simplest provable sequent is an axiom: the conclusion already appears
# among the premises (A ⊢ A).
A = Letter(0)
axiom_sequent = Sequent(conclusion=A, premises=(A,))

def dummy_policy(sequent):
    """Placeholder policy/value function: no move priors and the maximum value.

    An axiom needs no real policy, so the `sequent` argument is ignored.
    """
    move_priors = []
    value = 1.0
    return move_priors, value

# Wrap the axiom in a proof graph and an MCTS agent driven by the dummy policy.
pg = ProofGraph(axiom_sequent)
mcts = MCTS(policy_value_fn=dummy_policy, proof_graph=pg, n_playout=10)

# Setting the root alone should be enough for MCTS to recognise the axiom.
mcts.set_root(axiom_sequent)
assert mcts._proof_complete, "Root should be promptly marked as proven"

print("Test passed: the sequent axiom was recognized correctly.")

Test passed: the sequent axiom was recognized correctly.


In [3]:
# --- Test OR in Premises - Provable sequent ---
# Goal: A ∨ ⊥ ⊢ A. OR-left on the single premise produces one child per
# disjunct (A ⊢ A and ⊥ ⊢ A).
A = Letter(1)
disj = Disjunction(A, bottom)
goal = Sequent(premises=(disj,), conclusion=A)

# Proof graph rooted at the goal sequent.
pg = ProofGraph(goal)
root = pg.root

# Pick the move whose position is 0 (the disjunctive premise), apply it and
# attach the resulting child sequents under the root.
move = next(candidate for candidate in root.sequent.moves() if candidate.pos == 0)
children = root.sequent.rule(move)
child_nodes = pg.add_children(root, move, children)

# Close every child that is already an axiom.
for node in child_nodes:
    if node.sequent.is_axiom():
        pg.set_proved(node)

# Inspect the resulting graph.
print("Root sequent:", root.sequent)
print("Root num_moves:", root.num_moves.n)
print("Child sequents and costs:")
for c in child_nodes:
    print("  ", c.sequent, ", cost:", c.num_moves.n)

Root sequent: A1 ∨ ⊥ ⊢ A1
Root num_moves: 0
Child sequents and costs:
   A1 ⊢ A1 , cost: 0
   ⊥ ⊢ A1 , cost: 0


In [4]:
# --- Test AND in Premises - Provable Sequent ---
# Goal: A ∧ B ⊢ A. AND-left on the conjunctive premise should yield the single
# child A, B ⊢ A, which is an axiom.
A = Letter(1)
B = Letter(2)
conj = Conjunction(A, B)
target = Sequent(premises=(conj,), conclusion=A)

# --- Construct the proof graph ---
pg = ProofGraph(target)
root = pg.root

# --- Apply AND-left ---
# Select the move whose principal formula is the premise A ∧ B. Moves on the
# conclusion use pos == -1 and must be excluded explicitly: premises[-1] would
# silently index the *last premise* and could match the wrong move.
moves = root.sequent.moves()
move = next(m for m in moves if m.pos >= 0 and root.sequent.premises[m.pos] == conj)
children = root.sequent.rule(move)

# Add the resulting child sequents as children of the root in the proof graph
child_nodes = pg.add_children(root, move, children)

# --- Mark as proved every child that is an axiom ---
for node in child_nodes:
    if node.sequent.is_axiom():
        pg.set_proved(node)

# --- Print results for inspection ---
print("Root sequent:", root.sequent)
print("Root num_moves:", root.num_moves.n)
print("Child sequents and costs:")
for c in child_nodes:
    print("  ", c.sequent, ", cost:", c.num_moves.n)

Root sequent: A1 ∧ A2 ⊢ A1
Root num_moves: 0
Child sequents and costs:
   A1, A2 ⊢ A1 , cost: 0


---
#### **Load the Tree-LSTM Classifier to guide the MCTS**

In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [6]:
# --- Fit the Tokenizer on the dataset ---
# The tokenizer is fitted on every formula of the synthetic dataset; the
# embedding indices assigned to formula trees later come from it.
datapath = "datasets/extended_dataset_with_tautologies.csv"
dataset = pd.read_csv(datapath)

# Parse each formula string into a Formula object
formulas = list(map(parse_formula_string, dataset["formula"]))

tokenizer = CustomTokenizer()
tokenizer.fit(formulas)

In [7]:
dataset.shape[0]  # number of formulas (rows) in the dataset

13000

In [8]:
# --- Load Tree-LSTM model ---

# Hyperparameters: they must match the ones used to train the checkpoint,
# otherwise load_state_dict fails on shape mismatches.
VOCAB_SIZE = compute_vocab_size(tokenizer)
print(f"Vocabulary size (including padding token): {VOCAB_SIZE}")

EMBEDDING_DIM = 32
HIDDEN_SIZE = 128
FC_SIZE = 32
MODEL_PATH = "models/DROP_New_Tree_lstm_with_dropout.pth"

# Seed for reproducibility
set_seeds()

# Tree-LSTM Model
model = TreeLSTMClassifierV3(vocab_size=VOCAB_SIZE,
                             embedding_dim=EMBEDDING_DIM,
                             hidden_size=HIDDEN_SIZE,
                             fc_size=FC_SIZE)

# Load the saved state_dict. map_location lets a checkpoint saved on a GPU be
# loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (weights_only=True is available on recent torch versions).
model.load_state_dict(torch.load(f=MODEL_PATH, map_location=device))

# Send model to the target device
model = model.to(device)

# The model is only used for inference in this notebook. Disable its dropout
# layer now so every later prediction (including the MCTS policy) is deterministic.
model.eval()
model

Vocabulary size (including padding token): 108


TreeLSTMClassifierV3(
  (encoder): TreeLSTMEncoder(
    (embedding): Embedding(108, 32, padding_idx=0)
    (cell): BinaryTreeLSTMCell(
      (W_iou): Linear(in_features=32, out_features=384, bias=True)
      (U_iou): Linear(in_features=256, out_features=384, bias=True)
      (W_f): Linear(in_features=32, out_features=256, bias=True)
      (U_f): Linear(in_features=256, out_features=256, bias=True)
    )
  )
  (fc1): Linear(in_features=128, out_features=32, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.4, inplace=False)
  (fc2): Linear(in_features=32, out_features=1, bias=True)
)

In [9]:
# --- Formula Test with Tokenizer ---
# Sanity check on a known tautology, ¬(A ∧ ¬A); a well-trained model should
# give it a probability close to 1.
A = Letter(0)
formula = ~(A & ~A)

# Formula tree with embedding indices from the fitted tokenizer
tree = FormulaTreeNode(formula)
assign_embedding_indices(tree, tokenizer)

# Model prediction (eval mode, no autograd)
model.eval()
with torch.inference_mode():
    logit = model(tree)
    pred = logit.sigmoid().item()

print(f"\nEstimate probability (tautology) = {pred:.4f} for the formula: {formula}")


Estimate probability (tautology) = 0.9316 for the formula: ¬(A0 ∧ ¬A0)


In [26]:
# --- Tree-LSTM Model Policy ---

from functools import reduce

def build_implication(premises, conclusion):
    """
    Collapse a sequent's premises and conclusion into a single formula.

    The result has the shape expected by the Tree-LSTM model:

        no premises         ->  B
        one premise A1      ->  A1 → B
        premises A1, ..., An ->  ((A1 ∧ A2) ∧ ... ∧ An) → B   (left-nested conjunction)

    Args:
        premises (Iterable[Formula]): The premise formulas, in order. Any iterable
            is accepted (tuple, list, generator); it is consumed exactly once.
        conclusion (Formula): The conclusion formula.

    Returns:
        Formula: ``conclusion`` itself when there are no premises, otherwise the
        implication from the conjunction of the premises to ``conclusion``.
    """
    # Materialize once so one-shot iterables (generators) work and can be
    # tested for emptiness.
    premises = tuple(premises)
    if not premises:
        return conclusion
    # reduce() over a single element returns it unchanged, so the one-premise
    # case needs no special branch; for several premises it is a left fold.
    antecedent = reduce(Conjunction, premises)
    return Implication(antecedent, conclusion)


def _predict_tautology_prob(premises, conclusion) -> float:
    """Sigmoid of the Tree-LSTM logit for the sequent ``premises ⊢ conclusion``.

    The sequent is collapsed with `build_implication`, turned into a
    `FormulaTreeNode` and indexed with the global `tokenizer` before being fed
    to the global `model`.
    """
    formula = build_implication(premises, conclusion)
    tree = FormulaTreeNode(formula)
    assign_embedding_indices(tree, tokenizer)
    with torch.inference_mode():
        return torch.sigmoid(model(tree)).item()


def _aggregate_subgoal_probs(principal, pos, subgoal_probs):
    """Combine the subgoal probabilities of one move into a single move score.

    Args:
        principal (Formula): The formula the move acts on.
        pos (int): The move position (-1 = conclusion, >= 0 = premise index).
        subgoal_probs (list of float): Non-empty list of subgoal probabilities.

    Returns:
        float or None: The move score, or None when the principal connective is
        not supported (the caller then skips the move).
    """
    if isinstance(principal, Conjunction):
        return min(subgoal_probs)
    if isinstance(principal, Disjunction):
        if pos == -1:
            # OR on the right: the subgoals are scored as alternatives.
            return max(subgoal_probs)
        if pos >= 0:
            # OR on the left: one branch per disjunct, all must be closed.
            return min(subgoal_probs)
        return None
    if isinstance(principal, (Implication, Negation)):
        # Every premise of a sequent rule must be derived, so the move is only as
        # good as its weakest subgoal. For single-subgoal rules this equals the
        # only probability. The old code used subgoal_probs[0], which ignored the
        # other subgoals of multi-premise rules.
        return min(subgoal_probs)
    return None


def tree_model_policy_value(sequent: Sequent, temperature: float = 1.0) -> Tuple[List[Tuple[Sequent.Move, float]], float]:
    """
    Predicts the policy and value for a given sequent using the Tree-LSTM model.

    The value is the estimated probability that the full sequent is a tautology.
    The policy assigns a prior probability to each legal inference move, based on
    the model's predictions for the subgoal sequents produced by that move.

    Args:
        sequent (Sequent): The sequent to evaluate.
        temperature (float): Softmax temperature applied to the move scores. The
            scores are probabilities in [0, 1], so with the default 1.0 the priors
            differ by at most a factor e; values < 1 sharpen them. Must be > 0.

    Returns:
        - List of (move, probability) tuples: prior over moves (empty when no move
          could be scored).
        - Float: value estimate for the full sequent.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Dropout must be disabled, otherwise the same sequent gets different scores.
    model.eval()

    # --- Value ---
    value = _predict_tautology_prob(sequent.premises, sequent.conclusion)

    # --- Policy ---
    move_scores = []
    for move in sequent.moves():
        try:
            subgoals = sequent.rule(move)
        except Exception:
            continue  # best-effort: a move whose rule cannot be applied is not a candidate
        if subgoals is None:
            continue

        subgoal_probs = []
        for sg in subgoals:
            try:
                subgoal_probs.append(_predict_tautology_prob(sg.premises, sg.conclusion))
            except Exception as e:
                print(f"[WARN] Subgoal prediction failed for move {move}: {e}")
        if not subgoal_probs:
            continue

        principal = sequent.conclusion if move.pos == -1 else sequent.premises[move.pos]
        score = _aggregate_subgoal_probs(principal, move.pos, subgoal_probs)
        if score is None:
            print(f"[SKIP] Move {move} ignored — unsupported formula type: {type(principal).__name__}")
            continue
        move_scores.append((move, score))

    # No move could be scored: no applicable moves
    if not move_scores:
        return [], value

    # --- Normalization ---
    scores = torch.tensor([score for _, score in move_scores]) / temperature
    probs = torch.softmax(scores, dim=0).tolist()
    move_priors = [(move, prob) for (move, _), prob in zip(move_scores, probs)]
    return move_priors, value

---
#### **Tests for Model Policy and MCTS Agent**

In [21]:
# --- Tree Model Policy Test ---
# ⊢ ¬(A ∧ ¬A): the law of non-contradiction, a tautology. A single move (the
# right rule on the negation) is expected, so it should receive prior 1.0.
A = Letter(0)
form = ~(A & ~A)
seq = Sequent(premises=(), conclusion=form)

priors, value = tree_model_policy_value(seq)
print("Policy:", priors)
print("Value:", value)

[DEBUG] Sequent:  ⊢ ¬(A0 ∧ ¬A0) → 1 moves
Policy: [(Sequent.Move(pos=-1, param='right'), 1.0)]
Value: 0.9315673112869263


In [12]:
# --- MCTS guided by Tree-LSTM model Test ---
# Three candidate formulas (two tautologies, one non-tautology); pick one below.
A = Letter(1)
B = Letter(2)

f1 = A.implies(A)          # tautology:      A → A
f2 = (A & B).implies(A)    # tautology:      (A ∧ B) → A
f3 = A.implies(A & B)      # non-tautology:  A → (A ∧ B)

chosen = f2  # swap in f1 or f3 to try the other cases

sequent = Sequent(premises=(), conclusion=chosen)
print(f"\nSequent to prove: {sequent}")


Sequent to prove:  ⊢ A1 ∧ A2 → A1


In [13]:
# --- Build the proof graph and the Tree-LSTM-guided MCTS agent ---
proof_graph = ProofGraph(sequent)
mcts = MCTS(
    policy_value_fn=tree_model_policy_value,
    proof_graph=proof_graph,
    c_puct=5,
    n_playout=1600,
    verbose=True,
)
mcts.set_root(sequent)

# --- Run MCTS from the root sequent ---
probs = mcts.get_move_probs(sequent)

[MCTS start] root sequent ≡  ⊢ A1 ∧ A2 → A1
[expand] Node:  ⊢ A1 ∧ A2 → A1, Moves: [(Sequent.Move(pos=-1, param='right'), 1.0)]
[MCTS start] root sequent ≡  ⊢ A1 ∧ A2 → A1
[expand] Node: A1 ∧ A2 ⊢ A1, Moves: [(Sequent.Move(pos=0, param='left'), 1.0)]
[MCTS start] root sequent ≡  ⊢ A1 ∧ A2 → A1
Complete proof found during MCTS simulation.


In [14]:
# --- Print the Proof Tree ---
# Shows every explored sequent with its proof status, value estimate Q and visit count N.
mcts._root.print_tree()

-  ⊢ A1 ∧ A2 → A1 [proved] (Q=0.17, N=3)
  |_ Move: Sequent.Move(pos=-1, param='right')
    - A1 ∧ A2 ⊢ A1 [proved] (Q=0.50, N=2)
      |_ Move: Sequent.Move(pos=0, param='left')
        - A1, A2 ⊢ A1 [proved] (Q=1.00, N=1)


In [15]:
# --- 3 playouts are enough to prove (A ∧ B) → A ---
A = Letter(0)
B = Letter(1)
formula = (A & B).implies(A)
sequent = Sequent(premises=(), conclusion=formula)

pg = ProofGraph(sequent)
mcts = MCTS(
    policy_value_fn=tree_model_policy_value,
    proof_graph=pg,
    n_playout=3,
    verbose=True,
)
mcts.set_root(sequent)
mcts.get_move_probs(sequent)

# Whether the search closed the whole proof
proved = mcts._proof_complete

print(f"\nProved: {proved}")
print("\n=== MCTS Tree ===")
mcts._root.print_tree()

[MCTS start] root sequent ≡  ⊢ A0 ∧ A1 → A0
[expand] Node:  ⊢ A0 ∧ A1 → A0, Moves: [(Sequent.Move(pos=-1, param='right'), 1.0)]
[MCTS start] root sequent ≡  ⊢ A0 ∧ A1 → A0
[expand] Node: A0 ∧ A1 ⊢ A0, Moves: [(Sequent.Move(pos=0, param='left'), 1.0)]
[MCTS start] root sequent ≡  ⊢ A0 ∧ A1 → A0
Complete proof found during MCTS simulation.

Proved: True

=== MCTS Tree ===
-  ⊢ A0 ∧ A1 → A0 [proved] (Q=0.29, N=3)
  |_ Move: Sequent.Move(pos=-1, param='right')
    - A0 ∧ A1 ⊢ A0 [proved] (Q=0.58, N=2)
      |_ Move: Sequent.Move(pos=0, param='left')
        - A0, A1 ⊢ A0 [proved] (Q=1.00, N=1)


In [16]:
A, B, C = Letter(0), Letter(1), Letter(2)

# Sequents probed directly with the value network:
s1 = Sequent(premises=(A, B), conclusion=A)  # A, B ⊢ A  (valid; formula A ∧ B → A)
s2 = Sequent(premises=(A,), conclusion=B)    # A ⊢ B     (not valid; formula A → B)
s3 = Sequent(premises=(), conclusion=A)      # ⊢ A       (not valid)

def sequent_to_formula(sequent):
    """
    Convert a sequent into the single formula fed to the Tree-LSTM.

    Delegates to ``build_implication`` (defined in the model-policy cell), so this
    probe encodes sequents exactly like ``tree_model_policy_value`` does. The result
    is the conclusion alone when there are no premises, otherwise
    ``(P1 ∧ ... ∧ Pn) → conclusion`` with a left-nested conjunction.

    Args:
        sequent (Sequent): The sequent to encode.

    Returns:
        Formula: The encoded formula.
    """
    return build_implication(sequent.premises, sequent.conclusion)

# Score each probe sequent with the value network and show its formula encoding.
for s in (s1, s2, s3):
    f = sequent_to_formula(s)
    tree = FormulaTreeNode(f)
    assign_embedding_indices(tree, tokenizer)
    with torch.inference_mode():
        value = model(tree).sigmoid().item()
    print(f"Seq: {s}, Form: {f},  Prob {value}")

Seq: A0, A1 ⊢ A0, Form: A0 ∧ A1 → A0,  Prob 0.15390053391456604
Seq: A0 ⊢ A1, Form: A0 → A1,  Prob 0.0008999303099699318
Seq:  ⊢ A0, Form: A0,  Prob 0.0006949595408514142


---
#### **Load Test Set**

In [17]:
# --- Test-set configuration ---
DATA_PATH = "datasets/extended_dataset_with_tautologies.csv"
BATCH_SIZE = 1
TEST_SIZE = 0.2
SEED = 42

# --- Reload the synthetic dataset and rebuild the train/test split ---
df = pd.read_csv(DATA_PATH)

# NOTE(review): this rebinds the global `tokenizer` that `tree_model_policy_value`
# uses to index formulas (and that sized the model's embedding via VOCAB_SIZE).
# Confirm it yields the same vocabulary mapping as the tokenizer fitted earlier;
# otherwise predictions made after this cell silently change.
_, test_loader, tokenizer = prepare_tree_dataloaders(
    df,
    test_size=TEST_SIZE,
    batch_size=BATCH_SIZE,
    seed=SEED,
)

#### **Evaluate MCTS Agent**

In [22]:
def evaluate_mcts(test_loader, policy_value_fn, device, n_playout, max_steps=100):
    """
    Accuracy of an MCTS agent used as a tautology classifier.

    NOTE: a later cell redefines ``evaluate_mcts``. In a top-to-bottom run that
    later definition is the one in effect.

    For each formula, the agent starts from ``⊢ formula``. It repeatedly commits to
    the move with the highest probability returned by ``mcts.get_move_probs``,
    stopping when the proof is complete, no moves are returned, or ``max_steps``
    moves have been committed. The prediction is 1 iff the proof was completed.

    Args:
        test_loader: Iterable of ``(roots, labels)`` batches. Each root exposes
            ``.formula`` and each label supports ``.item()``.
        policy_value_fn: Policy/value function handed to MCTS.
        device: Unused; kept for interface compatibility.
        n_playout (int): Playouts per MCTS call.
        max_steps (int): Upper bound on committed moves per formula. Without it the
            loop never ends when MCTS keeps returning moves without closing the proof.

    Returns:
        float: Fraction of formulas whose prediction equals their label.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    y_true = []
    y_pred = []

    for roots, labels in tqdm(test_loader):
        for root, label in zip(roots, labels):
            sequent = Sequent(premises=(), conclusion=root.formula)

            pg = ProofGraph(sequent)
            mcts = MCTS(policy_value_fn=policy_value_fn,
                        proof_graph=pg,
                        n_playout=n_playout,
                        verbose=False)
            mcts.set_root(sequent)

            steps = 0
            while not mcts._proof_complete and steps < max_steps:
                move_probs = mcts.get_move_probs(mcts._root.sequent)
                if not move_probs:
                    break  # No valid moves

                # AlphaGo-style: pick move with highest visit-based probability
                best_move = max(move_probs, key=lambda p: p[1])[0]

                # Update tree with chosen move (advance root)
                mcts.update_with_move(best_move)
                steps += 1

            y_pred.append(1 if mcts._proof_complete else 0)
            y_true.append(int(label.item()))

    if not y_true:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)

In [33]:
def evaluate_mcts(test_loader, policy_value_fn, device, n_playout, max_steps=100):
    """
    Accuracy of an MCTS agent used as a tautology classifier.

    This definition supersedes the earlier ``evaluate_mcts`` cell in this notebook.

    For each formula, the agent starts from ``⊢ formula``. It commits to the move
    with the highest probability from ``mcts.get_move_probs``, stopping when the
    proof is complete, no well-formed ``(move, prob)`` pairs are returned, or
    ``max_steps`` moves have been committed. The formula is predicted as a
    tautology (1) iff the proof was completed, otherwise 0.

    Args:
        test_loader: Iterable of ``(roots, labels)`` batches. Each root exposes
            ``.formula`` and each label supports ``.item()``.
        policy_value_fn: Policy/value function handed to MCTS.
        device: Unused; kept for interface compatibility.
        n_playout (int): Playouts per MCTS call.
        max_steps (int): Upper bound on committed moves per formula.

    Returns:
        float: Fraction of formulas whose prediction equals their label.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    y_true = []
    y_pred = []

    for roots, labels in tqdm(test_loader):
        for root, label in zip(roots, labels):
            sequent = Sequent(premises=(), conclusion=root.formula)

            # Fresh proof graph and agent for every formula
            pg = ProofGraph(sequent)
            mcts = MCTS(
                policy_value_fn=policy_value_fn,
                proof_graph=pg,
                n_playout=n_playout,
                verbose=False
            )
            mcts.set_root(sequent)

            steps = 0
            while not mcts._proof_complete and steps < max_steps:
                move_probs = mcts.get_move_probs(mcts._root.sequent)

                # Stop when MCTS offers no moves or anything other than (move, prob) pairs
                if not move_probs or not all(isinstance(mp, tuple) and len(mp) == 2 for mp in move_probs):
                    break

                best_move = max(move_probs, key=lambda p: p[1])[0]
                mcts.update_with_move(best_move)
                steps += 1

            y_pred.append(1 if mcts._proof_complete else 0)
            y_true.append(int(label.item()))

    if not y_true:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)

In [None]:
# Evaluate the Tree-LSTM-guided MCTS agent on the held-out split (n_playout=10).
acc = evaluate_mcts(
    test_loader=test_loader,
    policy_value_fn=tree_model_policy_value,
    device=device,
    n_playout=10,
)
print(f"MCTS Agent Accuracy: {acc:.4f}")

  0%|          | 0/2600 [00:00<?, ?it/s]