In [40]:
# !pip install -r requirements.txt

In [41]:
import sys

sys.path.insert(0, '/Users/jq23948/Documents/GFLOWNET-ILP')

import numpy as np
from src.logic_structures import get_initial_state, theory_to_string
from src.logic_engine import LogicEngine, Example
from src.reward import RewardCalculator
from src.graph_encoder_enhanced import EnhancedGraphConstructor, EnhancedStateEncoder
from src.gflownet_models import HierarchicalGFlowNet
from src.training import GFlowNetTrainer
from src.exploration import get_combined_strategy
from src.visualization import TrainingVisualizer
from matplotlib import pyplot as plt



In [42]:
# Problem setup
background_facts = [
    Example('parent', ('alice', 'bob')),
    Example('parent', ('bob', 'charlie')),
    Example('parent', ('eve', 'frank')),
    Example('parent', ('frank', 'grace')),
    Example('parent', ('diana', 'henry')),
    Example('parent', ('henry', 'irene')),
    Example('parent', ('grace', 'jack'))
]


positive_examples = [
    Example('grandparent', ('alice', 'charlie')),
    Example('grandparent', ('eve', 'grace')),
    Example('grandparent', ('diana', 'irene')),
    Example('grandparent', ('frank', 'jack'))
]

negative_examples = [
    Example('grandparent', ('alice', 'alice')),
    Example('grandparent', ('bob', 'bob')),
    Example('grandparent', ('alice', 'eve')),
    Example('grandparent', ('bob', 'frank')),
    Example('grandparent', ('eve', 'frank')),
]

predicate_vocab = ['parent']
predicate_arities = {'parent': 2}

In [43]:
# Save configuration
config = {
    'problem': 'grandparent',
    'predicate_vocab': predicate_vocab,
    'predicate_arities': predicate_arities,

    'logic_engine_max_depth': 10,
    'num_episodes': 10000,
    'embedding_dim': 32,
    'hidden_dim': 64,
    'num_layers_encoder': 2,
    'learning_rate': 1e-4,
    'max_body_length': 4,

    'use_sophisticated_backward': True,

    'use_f1': True,
    'weight_precision': 0.5,
    'weight_recall': 0.5,
    'weight_simplicity': 0.05,
    'disconnected_var_penalty': 0.2,
    'self_loop_penalty': 0.3,
    'free_var_penalty': 1.0,

    'use_detailed_balance': True,

    'use_replay_buffer': True,
    'replay_probability': 0.5,
    'replay_buffer_capacity': 50,
    'buffer_reward_threshold': 0.5,

    'reward_weighted_loss': False,
    'reward_scale_alpha': 10.0,

    'num_background_facts': len(background_facts),
    'num_positive_examples': len(positive_examples),
    'num_negative_examples': len(negative_examples),
}


In [44]:
print("="*80)
print("METHOD DEMONSTRATION")
print("="*80)
print("\nGoal: Learn grandparent(X, Y) rule from examples")
print(f"\nBackground Knowledge ({len(background_facts)} facts):")
for fact in background_facts:
    print(f"  {fact.predicate_name}({', '.join(fact.args)})")

print(f"\nPositive Examples ({len(positive_examples)}):")
for ex in positive_examples:
    print(f"  {ex.predicate_name}({', '.join(ex.args)})")

print(f"\nNegative Examples ({len(negative_examples)}):")
for ex in negative_examples:
    print(f"  {ex.predicate_name}({', '.join(ex.args)})")

METHOD DEMONSTRATION

Goal: Learn grandparent(X, Y) rule from examples

Background Knowledge (7 facts):
  parent(alice, bob)
  parent(bob, charlie)
  parent(eve, frank)
  parent(frank, grace)
  parent(diana, henry)
  parent(henry, irene)
  parent(grace, jack)

Positive Examples (4):
  grandparent(alice, charlie)
  grandparent(eve, grace)
  grandparent(diana, irene)
  grandparent(frank, jack)

Negative Examples (5):
  grandparent(alice, alice)
  grandparent(bob, bob)
  grandparent(alice, eve)
  grandparent(bob, frank)
  grandparent(eve, frank)


In [45]:
logic_engine = LogicEngine(max_depth=config['logic_engine_max_depth'], background_facts=background_facts)
reward_calc = RewardCalculator(
    logic_engine,
    weight_precision=config['weight_precision'],      # Penalize false positives (covering negatives)
    weight_recall=config["weight_recall"],          # Penalize false negatives (missing positives)
    weight_simplicity=config['weight_simplicity'],      # Small penalty for longer rules
    disconnected_var_penalty=config['disconnected_var_penalty'],
    self_loop_penalty= config['self_loop_penalty'],        # Moderate penalty for self-loops
    free_var_penalty=config['free_var_penalty'],
    use_f1=config['use_f1']                 # Use F1-score for balanced precision-recall
)
graph_constructor = EnhancedGraphConstructor(config['predicate_vocab'])
state_encoder = EnhancedStateEncoder(
    predicate_vocab_size=len(config['predicate_vocab']),
    embedding_dim=config['embedding_dim'],
    num_layers=config['num_layers_encoder']
)
gflownet = HierarchicalGFlowNet(
    embedding_dim=config['embedding_dim'],
    num_predicates=len(config['predicate_vocab']),
    hidden_dim=config['hidden_dim'],
    use_sophisticated_backward=config['use_sophisticated_backward'],
    predicate_vocab=config['predicate_vocab']
)


# exploration = get_combined_strategy("aggressive")


trainer = GFlowNetTrainer(
    state_encoder=state_encoder,
    gflownet=gflownet,
    graph_constructor=graph_constructor,
    reward_calculator=reward_calc,
    predicate_vocab=config['predicate_vocab'],
    predicate_arities=config['predicate_arities'],
    learning_rate=config['learning_rate'],
    exploration_strategy=None,  # No exploration strategy for demo
    use_detailed_balance=config['use_detailed_balance'],
    use_replay_buffer=config['use_replay_buffer'],
    replay_buffer_capacity=config['replay_buffer_capacity'],
    reward_weighted_loss=config['reward_weighted_loss'],
    replay_probability=config['replay_probability'],
    max_body_length=config['max_body_length'],
    buffer_reward_threshold=config['buffer_reward_threshold'],
    reward_scale_alpha=config['reward_scale_alpha']
    
)

# Initialize visualizer
visualizer = TrainingVisualizer(
    experiment_name=config['problem'],
    output_dir="results"
)


visualizer.save_config(config)

Saving results to: results/run_20251021_134418
✓ Saved configuration to results/run_20251021_134418/config.json


# Contrastive Pre-Training

**Problem:** The graph encoder initially produces very similar embeddings for all rules, even those with different semantics. This makes it hard for the GFlowNet to learn which states lead to good rewards.

**Solution:** Pre-train the encoder using contrastive learning BEFORE GFlowNet training:
- **Positive pairs**: Same rule with renamed variables → should have SIMILAR embeddings
- **Negative pairs**: Different variable connections → should have DIFFERENT embeddings

This teaches the encoder to distinguish structural differences in rules.

In [46]:
from contrastive_pretraining import ContrastivePreTrainer, generate_base_rules
from sklearn.metrics.pairwise import cosine_similarity
from src.logic_structures import Rule, Atom, Variable

print("=" * 80)
print("CONTRASTIVE PRE-TRAINING")
print("=" * 80)

# Helper function to create test rules
def create_test_rule(head_pred, head_args, body_atoms_list):
    head_vars = [Variable(id=vid) for vid in head_args]
    head = Atom(predicate_name=head_pred, args=tuple(head_vars))
    body_atoms = []
    for pred_name, var_ids in body_atoms_list:
        vars = [Variable(id=vid) for vid in var_ids]
        body_atoms.append(Atom(predicate_name=pred_name, args=tuple(vars)))
    rule = Rule(head=head, body=tuple(body_atoms))
    return [rule]

def get_test_embedding(theory):
    graph_data = graph_constructor.theory_to_graph(theory)
    state_embedding, _ = state_encoder(graph_data)
    return state_embedding.squeeze(0).detach().numpy()

# Test BEFORE pre-training
print("\nStep 1: Testing encoder BEFORE pre-training")
print("-" * 80)

rule_chain = create_test_rule('grandparent', [0, 1], [('parent', (0, 2)), ('parent', (2, 1))])
rule_convergent = create_test_rule('grandparent', [0, 1], [('parent', (0, 2)), ('parent', (1, 2))])

emb_chain_before = get_test_embedding(rule_chain)
emb_conv_before = get_test_embedding(rule_convergent)
sim_before = cosine_similarity([emb_chain_before], [emb_conv_before])[0, 0]

print(f"Rule 1 (chain):       grandparent(X0, X1) :- parent(X0, X2), parent(X2, X1)")
print(f"Rule 2 (convergent):  grandparent(X0, X1) :- parent(X0, X2), parent(X1, X2)")
print(f"\nSimilarity: {sim_before:.6f}")
print(f"Status: {'❌ TOO SIMILAR (need pre-training)' if sim_before > 0.95 else '✅ Already good'}")

# Generate base rules for pre-training
print("\n\nStep 2: Generating base rules for pre-training")
print("-" * 80)
base_rules = generate_base_rules(predicate_vocab, predicate_arities, num_rules=100)
print(f"Generated {len(base_rules)} diverse base rules")

# Pre-train the encoder
print("\n\nStep 3: Running contrastive pre-training")
print("-" * 80)

pretrainer = ContrastivePreTrainer(
    state_encoder=state_encoder,
    graph_constructor=graph_constructor,
    predicate_vocab=predicate_vocab,
    predicate_arities=predicate_arities
)

# Run pre-training (200 epochs takes ~2-3 minutes)
print("Pre-training for 200 epochs (this may take a few minutes)...")
losses = pretrainer.pretrain(base_rules, num_epochs=200, verbose=True)

# Test AFTER pre-training
print("\n\nStep 4: Testing encoder AFTER pre-training")
print("-" * 80)

emb_chain_after = get_test_embedding(rule_chain)
emb_conv_after = get_test_embedding(rule_convergent)
sim_after = cosine_similarity([emb_chain_after], [emb_conv_after])[0, 0]

print(f"Rule 1 (chain):       grandparent(X0, X1) :- parent(X0, X2), parent(X2, X1)")
print(f"Rule 2 (convergent):  grandparent(X0, X1) :- parent(X0, X2), parent(X1, X2)")
print(f"\nSimilarity: {sim_after:.6f}")
print(f"Status: {'✅ IMPROVED! Can now distinguish semantics' if sim_after < 0.90 else '⚠️ Still too similar'}")

# Visualize improvement
print("\n\nStep 5: Visualizing improvement")
print("-" * 80)

improvement = sim_before - sim_after
print(f"\nImprovement: {improvement:.6f} reduction in similarity")
print(f"  Before: {sim_before:.6f}")
print(f"  After:  {sim_after:.6f}")

# Plot results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curve
axes[0].plot(losses, linewidth=2, color='blue')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Contrastive Loss', fontsize=12)
axes[0].set_title('Pre-training Loss Curve', fontsize=14, fontweight='bold')
axes[0].grid(alpha=0.3)

# Before/After comparison
categories = ['Before\nPre-training', 'After\nPre-training']
similarities = [sim_before, sim_after]
colors = ['red' if s > 0.90 else 'green' for s in similarities]

bars = axes[1].bar(categories, similarities, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
axes[1].axhline(y=0.90, color='orange', linestyle='--', linewidth=2, label='Target (<0.90)')
axes[1].set_ylabel('Cosine Similarity', fontsize=12)
axes[1].set_title('Embedding Similarity:\nChain vs Convergent Pattern', fontsize=14, fontweight='bold')
axes[1].set_ylim([0, 1.05])
axes[1].legend(fontsize=10)
axes[1].grid(alpha=0.3, axis='y')

# Add value labels on bars
for bar, sim in zip(bars, similarities):
    height = bar.get_height()
    axes[1].text(bar.get_x() + bar.get_width()/2., height + 0.02,
                f'{sim:.4f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig(f'{visualizer.run_dir}/contrastive_pretraining_results.png', dpi=300, bbox_inches='tight')
print(f"\n✓ Visualization saved to: {visualizer.run_dir}/contrastive_pretraining_results.png")
plt.show()

# Summary
print("\n" + "=" * 80)
print("PRE-TRAINING SUMMARY")
print("=" * 80)

if sim_after < 0.90:
    print("\n🎉 SUCCESS! Pre-training was effective.")
    print(f"   The encoder can now distinguish semantic differences.")
    print(f"   Similarity reduced from {sim_before:.4f} to {sim_after:.4f}")
    print("\n✓ The pre-trained encoder will now be used in GFlowNet training.")
else:
    print("\n⚠️  Pre-training helped but similarity is still high.")
    print(f"   Similarity: {sim_before:.4f} → {sim_after:.4f}")
    print("\n   Consider:")
    print("   - More pre-training epochs (500-1000)")
    print("   - More diverse base rules")
    print("   - Using improved architecture (see improved_graph_encoder.py)")

print("=" * 80)

CONTRASTIVE PRE-TRAINING

Step 1: Testing encoder BEFORE pre-training
--------------------------------------------------------------------------------
Rule 1 (chain):       grandparent(X0, X1) :- parent(X0, X2), parent(X2, X1)
Rule 2 (convergent):  grandparent(X0, X1) :- parent(X0, X2), parent(X1, X2)

Similarity: 0.999996
Status: ❌ TOO SIMILAR (need pre-training)


Step 2: Generating base rules for pre-training
--------------------------------------------------------------------------------
Generated 101 diverse base rules


Step 3: Running contrastive pre-training
--------------------------------------------------------------------------------
Pre-training for 200 epochs (this may take a few minutes)...
CONTRASTIVE PRE-TRAINING

Training for 200 epochs with 101 base rules


TypeError: can only concatenate list (not "tuple") to list

In [39]:
# Training
num_episodes = config['num_episodes']
initial_state = get_initial_state('grandparent', 2)

print(f"\n" + "="*80)
print(f"TRAINING ({num_episodes} episodes)")
print("="*80)

rewards = []
discovered_rules = {}  # Rule string -> (reward, episode, scores)
recent_rules = []  # Track last 50 rules for analysis

for episode in range(num_episodes):
    metrics = trainer.train_step(initial_state, positive_examples, negative_examples)

    print(episode, metrics)
    if metrics:
        rewards.append(metrics['reward'])

        # Record metrics with visualizer
        visualizer.record_episode(episode, metrics)

        # Sample trajectories periodically to see what rules are being found
        if episode % 10 == 0:
            trajectory, reward = trainer.generate_trajectory(
                initial_state, positive_examples, negative_examples
            )
            theory = trajectory[-1].next_state if trajectory else initial_state
            rule_str = theory_to_string(theory)

            scores = reward_calc.get_detailed_scores(theory, positive_examples, negative_examples)

            # Record with visualizer
            visualizer.record_rule(rule_str, reward, episode, scores)

            # Add detailed metrics to visualizer
            visualizer.record_episode(episode, {
                **metrics,
                'precision': scores['precision'],
                'recall': scores['recall'],
                'f1_score': scores['f1_score'],
                'accuracy': scores['accuracy']
            })

            discovered_rules[rule_str] = (reward, episode, scores)

            recent_rules.append((rule_str, reward, episode, scores))
            if len(recent_rules) > 100:
                recent_rules.pop(0)

        if episode % 100 == 0 and recent_rules:
            mean_reward = np.mean(rewards[-100:])
            print(f"\n--- Episode {episode:4d}: Mean Reward (last 100 episodes) = {mean_reward:.4f} ---")
            latest_rule, latest_reward, _, _ = recent_rules[-1]
            print(f"Episode {episode:4d}: reward={metrics['reward']:.4f}, length={metrics['trajectory_length']}")
            print(f"  Latest sampled rule: {latest_rule}")

    




TRAINING (10000 episodes)
0 {'loss': 18.95529556274414, 'reward': 1e-06, 'trajectory_length': 10, 'log_Z': 0.0, 'replay_used': False}

--- Episode    0: Mean Reward (last 100 episodes) = 0.0000 ---
Episode    0: reward=0.0000, length=10
  Latest sampled rule: grandparent(X0, X0) :- parent(X0, X0).
1 {'loss': 18.83078384399414, 'reward': 1e-06, 'trajectory_length': 10, 'log_Z': 0.0, 'replay_used': False}
2 {'loss': 20.999984741210938, 'reward': 1e-06, 'trajectory_length': 10, 'log_Z': 0.0, 'replay_used': False}
3 {'loss': 41.6654167175293, 'reward': 1e-06, 'trajectory_length': 5, 'log_Z': 0.0, 'replay_used': False}
4 {'loss': 20.983837127685547, 'reward': 1e-06, 'trajectory_length': 10, 'log_Z': 0.0, 'replay_used': False}
5 {'loss': 19.685344696044922, 'reward': 1e-06, 'trajectory_length': 9, 'log_Z': 0.0, 'replay_used': False}
6 {'loss': 35.33517074584961, 'reward': 1e-06, 'trajectory_length': 5, 'log_Z': 0.0, 'replay_used': False}
7 {'loss': 21.055641174316406, 'reward': 1e-06, 'traj

KeyboardInterrupt: 

In [57]:
# Analysis
print("\n" + "="*80)
print("TRAINING RESULTS")
print("="*80)

if rewards:
    final_avg_reward = np.mean(rewards[-100:]) if len(rewards) > 100 else np.mean(rewards)
    max_reward = np.max(rewards)
    high_reward_count = sum(1 for r in rewards if r > 0.8)

    print(f"\nFinal avg reward (last 100): {final_avg_reward:.4f}")
    print(f"Max reward: {max_reward:.4f}")
    print(f"High-reward episodes (>0.8): {high_reward_count}")
else:
    print("No training data was generated.")

print(f"Unique rules discovered: {len(discovered_rules)}")


# Show discovered rules sorted by reward
print("\n" + "="*80)
print("TOP DISCOVERED RULES")
print("="*80)

sorted_rules = sorted(discovered_rules.items(), key=lambda x: x[1][0], reverse=True)

print("\nShowing top 10 rules by reward:\n")

for i, (rule_str, (reward, episode, scores)) in enumerate(sorted_rules[:10], 1):
    pos_total = scores['TP'] + scores['FN']
    neg_total = scores['FP'] + scores['TN']
    print(f"{i}. [Reward: {scores['reward']:.4f}] {rule_str}")
    print(f"   Discovered at Episode: {episode}")
    print(f"   Confusion Matrix: TP={scores['TP']}, FN={scores['FN']}, FP={scores['FP']}, TN={scores['TN']}")
    print(f"   Coverage: {scores['TP']}/{pos_total} positives, "
          f"{scores['FP']}/{neg_total} negatives")
    print(f"   Metrics: Precision={scores['precision']:.4f}, Recall={scores['recall']:.4f}, F1={scores['f1_score']:.4f}")
    print(f"   Penalties: Disconnected={scores['num_disconnected_vars']} (-{scores['disconnected_penalty']:.2f}), "
          f"Self-loops={scores['num_self_loops']} (-{scores['self_loop_penalty']:.2f}), "
          f"Free-vars={scores['num_free_vars']} (-{scores['free_var_penalty']:.2f})")
    print()





TRAINING RESULTS

Final avg reward (last 100): 0.0000
Max reward: 1.0167
High-reward episodes (>0.8): 4
Unique rules discovered: 384

TOP DISCOVERED RULES

Showing top 10 rules by reward:

1. [Reward: 0.8989] grandparent(X2, X1) :- parent(X2, X4), parent(X4, X5), parent(X6, X4), parent(X8, X1).
   Discovered at Episode: 1340
   Confusion Matrix: TP=4, FN=0, FP=1, TN=4
   Coverage: 4/4 positives, 1/5 negatives
   Metrics: Precision=0.8000, Recall=1.0000, F1=0.8889
   Penalties: Disconnected=0 (-0.00), Self-loops=0 (-0.00), Free-vars=0 (-0.00)

2. [Reward: 0.6100] grandparent(X6, X1) :- parent(X2, X1), parent(X4, X2), parent(X6, X7), parent(X8, X9).
   Discovered at Episode: 410
   Confusion Matrix: TP=4, FN=0, FP=0, TN=5
   Coverage: 4/4 positives, 0/5 negatives
   Metrics: Precision=1.0000, Recall=1.0000, F1=1.0000
   Penalties: Disconnected=2 (-0.40), Self-loops=0 (-0.00), Free-vars=0 (-0.00)

3. [Reward: 0.2957] grandparent(X5, X1) :- parent(X5, X3), parent(X4, X5), parent(X6, X5), 

In [58]:
# Analyze replay buffer
print("="*80)
print("REPLAY BUFFER ANALYSIS")
print("="*80)

if trainer.replay_buffer and len(trainer.replay_buffer.buffer) > 0:
    print(f"\nReplay buffer size: {len(trainer.replay_buffer.buffer)}")

    replay_rules = []
    for trajectory, reward in trainer.replay_buffer.buffer:
        theory = trajectory[-1].next_state
        rule_str = theory_to_string(theory)
        scores = reward_calc.get_detailed_scores(theory, positive_examples, negative_examples)
        replay_rules.append((rule_str, reward, scores))

    replay_rules.sort(key=lambda x: x[1], reverse=True)
    print(f"\nTop 10 rules in replay buffer:\n")

    for i, (rule_str, reward, scores) in enumerate(replay_rules[:10], 1):
        pos_total = scores['TP'] + scores['FN']
        neg_total = scores['FP'] + scores['TN']
        print(f"{i}. [Reward: {reward:.4f}] {rule_str}")
        print(f"   Coverage: {scores['TP']}/{pos_total} positives, "
              f"{scores['FP']}/{neg_total} negatives")
        print(f"   Issues: {scores['num_disconnected_vars']} disconnected, "
              f"{scores['num_self_loops']} self-loops, "
              f"{scores['num_free_vars']} free-vars")
        print()

    # Quality statistics
    num_perfect = sum(1 for _, _, s in replay_rules if s['recall'] == 1.0 and s['FP'] == 0)
    num_disconnected = sum(1 for _, _, s in replay_rules if s['num_disconnected_vars'] > 0)
    num_self_loops = sum(1 for _, _, s in replay_rules if s['num_self_loops'] > 0)
    buffer_size = len(replay_rules)

    print("="*80)
    print("REPLAY BUFFER QUALITY STATISTICS")
    print("="*80)
    print(f"\nPerfect rules (100% recall, 0 false positives): {num_perfect}/{buffer_size} ({100*num_perfect/buffer_size:.1f}%)")
    print(f"Rules with disconnected variables: {num_disconnected}/{buffer_size} ({100*num_disconnected/buffer_size:.1f}%)")
    print(f"Rules with self-loops: {num_self_loops}/{buffer_size} ({100*num_self_loops/buffer_size:.1f}%)")
else:
    print("Replay buffer is empty.")



REPLAY BUFFER ANALYSIS

Replay buffer size: 8

Top 10 rules in replay buffer:

1. [Reward: 1.0167] grandparent(X0, X1) :- parent(X2, X1), parent(X0, X2).
   Coverage: 4/4 positives, 0/5 negatives
   Issues: 0 disconnected, 0 self-loops, 0 free-vars

2. [Reward: 1.0167] grandparent(X0, X5) :- parent(X0, X3), parent(X3, X5).
   Coverage: 4/4 positives, 0/5 negatives
   Issues: 0 disconnected, 0 self-loops, 0 free-vars

3. [Reward: 1.0100] grandparent(X6, X5) :- parent(X2, X5), parent(X2, X5), parent(X6, X7), parent(X8, X2).
   Coverage: 4/4 positives, 0/5 negatives
   Issues: 0 disconnected, 0 self-loops, 0 free-vars

4. [Reward: 0.8989] grandparent(X4, X1) :- parent(X2, X3), parent(X4, X2), parent(X6, X7), parent(X6, X1).
   Coverage: 4/4 positives, 1/5 negatives
   Issues: 0 disconnected, 0 self-loops, 0 free-vars

5. [Reward: 0.7439] grandparent(X0, X5) :- parent(X0, X3), parent(X4, X5).
   Coverage: 4/4 positives, 3/5 negatives
   Issues: 0 disconnected, 0 self-loops, 0 free-vars

6.

In [59]:
# Generate all visualizations
visualizer.finalize()

print("\n" + "="*80)
print("TRAINING COMPLETE")
print("="*80)
print(f"Results saved to: {visualizer.run_dir}")
print("\nGenerated files:")
print("  - training_curves.png       : Reward and loss over time")
print("  - metrics_over_time.png     : Precision, recall, F1-score")
print("  - confusion_matrices.png    : Top rules' confusion matrices")
print("  - trajectory_lengths.png    : Trajectory length distribution")
print("  - best_rules.txt            : Top 20 discovered rules")
print("  - summary_dashboard.png     : Comprehensive overview")
print("  - config.json               : Training configuration")
print("="*80)


GENERATING VISUALIZATIONS
✓ Saved training curves to results/run_20251021_110111/training_curves.png
✓ Saved metrics plot to results/run_20251021_110111/metrics_over_time.png
✓ Saved confusion matrices to results/run_20251021_110111/confusion_matrices.png
✓ Saved trajectory lengths to results/run_20251021_110111/trajectory_lengths.png
✓ Saved best rules to results/run_20251021_110111/best_rules.txt
✓ Saved summary dashboard to results/run_20251021_110111/summary_dashboard.png

✓ All visualizations saved to: results/run_20251021_110111

TRAINING COMPLETE
Results saved to: results/run_20251021_110111

Generated files:
  - training_curves.png       : Reward and loss over time
  - metrics_over_time.png     : Precision, recall, F1-score
  - confusion_matrices.png    : Top rules' confusion matrices
  - trajectory_lengths.png    : Trajectory length distribution
  - best_rules.txt            : Top 20 discovered rules
  - summary_dashboard.png     : Comprehensive overview
  - config.json      

# Embedding Analysis: Testing Semantic Equivalence Detection

Now that we have a **trained** graph encoder, let's test whether it can distinguish between semantically different rules while treating semantically equivalent rules similarly.

In [None]:
import torch
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import seaborn as sns
from src.logic_structures import Rule, Atom, Variable

def create_test_rule(head_pred, head_args, body_atoms_list):
    """Create a rule for testing embeddings."""
    head_vars = [Variable(id=vid) for vid in head_args]
    head = Atom(predicate_name=head_pred, args=tuple(head_vars))
    
    body_atoms = []
    for pred_name, var_ids in body_atoms_list:
        vars = [Variable(id=vid) for vid in var_ids]
        body_atoms.append(Atom(predicate_name=pred_name, args=tuple(vars)))
    
    rule = Rule(head=head, body=tuple(body_atoms))
    return [rule]  # Return as Theory (list of rules)

def get_embedding(theory, graph_constructor, state_encoder):
    """Extract embedding for a theory using the trained encoder."""
    graph_data = graph_constructor.theory_to_graph(theory)
    state_embedding, _ = state_encoder(graph_data)
    return state_embedding.squeeze(0).detach().numpy()

print("=" * 80)
print("GRAPH EMBEDDING ANALYSIS (TRAINED ENCODER)")
print("=" * 80)

# Test 1: Variable Renaming (Should be similar)
print("\nTest 1: Variable Renaming (Semantic Equivalence)")
print("-" * 80)
rule1 = create_test_rule('grandparent', [0, 1], [('parent', (0, 2)), ('parent', (2, 1))])
rule2 = create_test_rule('grandparent', [10, 11], [('parent', (10, 12)), ('parent', (12, 11))])

emb1 = get_embedding(rule1, graph_constructor, state_encoder)
emb2 = get_embedding(rule2, graph_constructor, state_encoder)
sim_rename = cosine_similarity([emb1], [emb2])[0, 0]

print(f"Rule 1: grandparent(X0, X1) :- parent(X0, X2), parent(X2, X1)")
print(f"Rule 2: grandparent(X10, X11) :- parent(X10, X12), parent(X12, X11)")
print(f"Cosine Similarity: {sim_rename:.6f}")
print(f"Result: {'✓ PASS' if sim_rename > 0.95 else '✗ FAIL'} (Expected: >0.95)")

# Test 2: Different Semantics (Should be different)
print("\n\nTest 2: Different Semantics (Similar Syntax)")
print("-" * 80)
rule3 = create_test_rule('grandparent', [0, 1], [('parent', (0, 2)), ('parent', (2, 1))])  # Correct
rule4 = create_test_rule('grandparent', [0, 1], [('parent', (0, 2)), ('parent', (1, 2))])  # Wrong (sibling-like)

emb3 = get_embedding(rule3, graph_constructor, state_encoder)
emb4 = get_embedding(rule4, graph_constructor, state_encoder)
sim_different = cosine_similarity([emb3], [emb4])[0, 0]

print(f"Rule 3 (correct): grandparent(X0, X1) :- parent(X0, X2), parent(X2, X1)")
print(f"Rule 4 (wrong):   grandparent(X0, X1) :- parent(X0, X2), parent(X1, X2)")
print(f"Cosine Similarity: {sim_different:.6f}")
print(f"Result: {'✓ PASS' if sim_different < 0.90 else '✗ FAIL'} (Expected: <0.90)")

# Test 3: Predicate Order (Should be similar)
print("\n\nTest 3: Predicate Order Swap (Semantic Equivalence)")
print("-" * 80)
rule5 = create_test_rule('rule', [0, 1], [('parent', (0, 2)), ('parent', (2, 1))])
rule6 = create_test_rule('rule', [0, 1], [('parent', (2, 1)), ('parent', (0, 2))])

emb5 = get_embedding(rule5, graph_constructor, state_encoder)
emb6 = get_embedding(rule6, graph_constructor, state_encoder)
sim_order = cosine_similarity([emb5], [emb6])[0, 0]

print(f"Rule 5: rule(X0, X1) :- parent(X0, X2), parent(X2, X1)")
print(f"Rule 6: rule(X0, X1) :- parent(X2, X1), parent(X0, X2)")
print(f"Cosine Similarity: {sim_order:.6f}")
print(f"Result: {'✓ PASS' if sim_order > 0.90 else '✗ FAIL'} (Expected: >0.90)")

# Test 4: Different Rule Lengths
print("\n\nTest 4: Different Rule Lengths")
print("-" * 80)
rule7 = create_test_rule('rule', [0, 1], [('parent', (0, 1))])  # Short
rule8 = create_test_rule('rule', [0, 1], [('parent', (0, 2)), ('parent', (2, 3)), ('parent', (3, 1))])  # Long

emb7 = get_embedding(rule7, graph_constructor, state_encoder)
emb8 = get_embedding(rule8, graph_constructor, state_encoder)
sim_length = cosine_similarity([emb7], [emb8])[0, 0]

print(f"Rule 7 (short): rule(X0, X1) :- parent(X0, X1)")
print(f"Rule 8 (long):  rule(X0, X1) :- parent(X0, X2), parent(X2, X3), parent(X3, X1)")
print(f"Cosine Similarity: {sim_length:.6f}")
print(f"Result: Informational (different complexity)")

In [None]:
# Create comprehensive similarity matrix visualization
import numpy as np

print("\n\n" + "=" * 80)
print("COMPREHENSIVE SIMILARITY MATRIX")
print("=" * 80)

# Collect all embeddings
all_embeddings = np.array([emb1, emb2, emb3, emb4, emb5, emb6, emb7, emb8])
similarity_matrix = cosine_similarity(all_embeddings)

rule_labels = [
    "R1: GP(X,Y):-P(X,Z),P(Z,Y)",
    "R2: GP(A,B):-P(A,C),P(C,B) [renamed]",
    "R3: GP(X,Y):-P(X,Z),P(Z,Y) [correct]",
    "R4: GP(X,Y):-P(X,Z),P(Y,Z) [wrong]",
    "R5: R(X,Y):-P(X,Z),P(Z,Y)",
    "R6: R(X,Y):-P(Z,Y),P(X,Z) [swapped]",
    "R7: R(X,Y):-P(X,Y) [short]",
    "R8: R(X,Y):-P(X,Z),P(Z,W),P(W,Y) [long]"
]

# Create heatmap
plt.figure(figsize=(12, 10))
sns.heatmap(
    similarity_matrix,
    annot=True,
    fmt='.3f',
    cmap='RdYlGn',
    vmin=0,
    vmax=1,
    xticklabels=rule_labels,
    yticklabels=rule_labels,
    cbar_kws={'label': 'Cosine Similarity'}
)
plt.title('Graph Embedding Similarity Matrix (TRAINED Encoder)\n(Green = More Similar, Red = Less Similar)', 
          fontsize=14, pad=20)
plt.xticks(rotation=45, ha='right', fontsize=9)
plt.yticks(rotation=0, fontsize=9)
plt.tight_layout()
plt.savefig(f'{visualizer.run_dir}/embedding_similarity_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✓ Similarity matrix saved to: {visualizer.run_dir}/embedding_similarity_matrix.png")

# Summary
print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)

test_results = [
    ("Variable Renaming (R1 vs R2)", sim_rename, sim_rename > 0.95, "Should be similar"),
    ("Different Semantics (R3 vs R4)", sim_different, sim_different < 0.90, "Should be different"),
    ("Predicate Order (R5 vs R6)", sim_order, sim_order > 0.90, "Should be similar"),
]

print("\nTest Results:")
for test_name, sim_value, passed, expectation in test_results:
    status = "✓ PASS" if passed else "✗ FAIL"
    print(f"  {status} - {test_name}: {sim_value:.4f} ({expectation})")

all_passed = all(result[2] for result in test_results)

print("\n" + "=" * 80)
if all_passed:
    print("✓ OVERALL: All critical tests passed!")
    print("The TRAINED encoder successfully captures semantic differences.")
else:
    print("✗ OVERALL: Some tests failed")
    print("The encoder may still have issues distinguishing semantic differences.")
print("=" * 80)

print("\n\nKey Insights:")
print(f"  - Same rule with renamed variables: {sim_rename:.4f} similarity")
print(f"  - Different semantics (chain vs convergent): {sim_different:.4f} similarity")
print(f"  - Same rule with swapped order: {sim_order:.4f} similarity")
print(f"  - Short vs long rules: {sim_length:.4f} similarity")
print("\nNote: Training the encoder through GFlowNet helps it learn to distinguish")
print("      structurally different rules based on their reward feedback.")

# Embedding Trajectory Visualization

Visualize how embeddings evolve step-by-step during trajectory generation. This shows:
- Whether embeddings diverge meaningfully as rules are constructed
- Whether high-reward trajectories follow different paths than low-reward ones
- How much each action changes the embedding representation

In [None]:
from visualize_embedding_trajectories import EmbeddingTrajectoryVisualizer

print("=" * 80)
print("EMBEDDING TRAJECTORY VISUALIZATION")
print("=" * 80)
print("\nThis visualization shows how embeddings evolve as rules are constructed.")
print("Each trajectory is a path through embedding space.\n")

# Create visualizer
emb_viz = EmbeddingTrajectoryVisualizer(
    trainer=trainer,
    graph_constructor=graph_constructor,
    state_encoder=state_encoder
)

# Collect trajectories
print("Sampling trajectories to visualize embedding evolution...")
trajectories_data = emb_viz.collect_trajectory_embeddings(
    initial_state=initial_state,
    positives=positive_examples,
    negatives=negative_examples,
    num_trajectories=10,  # Sample 10 trajectories
    max_steps=5          # Up to 5 steps each
)

# Generate all visualizations
print("\nGenerating visualizations...")
figs = emb_viz.visualize_all(
    trajectories_data,
    output_dir=visualizer.run_dir,
    prefix='embedding_trajectory'
)

# Show first two plots
plt.show()

In [None]:
# Analyze trajectory patterns
print("\n" + "=" * 80)
print("TRAJECTORY ANALYSIS")
print("=" * 80)

# Compute statistics
high_reward_trajs = [t for t in trajectories_data if t['reward'] > 0.5]
low_reward_trajs = [t for t in trajectories_data if t['reward'] <= 0.5]

print(f"\nSampled {len(trajectories_data)} trajectories:")
print(f"  - High reward (>0.5): {len(high_reward_trajs)}")
print(f"  - Low reward (≤0.5):  {len(low_reward_trajs)}")

# Average trajectory length
avg_length = np.mean([t['length'] for t in trajectories_data])
print(f"\nAverage trajectory length: {avg_length:.2f} steps")

# Compute average distance traveled in embedding space
distances_traveled = []
for traj in trajectories_data:
    embeddings = traj['embeddings']
    total_distance = 0.0
    for i in range(1, len(embeddings)):
        total_distance += np.linalg.norm(embeddings[i] - embeddings[i-1])
    distances_traveled.append(total_distance)

avg_distance = np.mean(distances_traveled)
print(f"Average distance traveled in embedding space: {avg_distance:.4f}")

# Compare high vs low reward trajectories
if high_reward_trajs and low_reward_trajs:
    high_reward_distances = []
    low_reward_distances = []
    
    for traj in high_reward_trajs:
        embeddings = traj['embeddings']
        dist = sum(np.linalg.norm(embeddings[i] - embeddings[i-1]) 
                   for i in range(1, len(embeddings)))
        high_reward_distances.append(dist)
    
    for traj in low_reward_trajs:
        embeddings = traj['embeddings']
        dist = sum(np.linalg.norm(embeddings[i] - embeddings[i-1]) 
                   for i in range(1, len(embeddings)))
        low_reward_distances.append(dist)
    
    print(f"\nEmbedding space traveled:")
    print(f"  - High-reward trajectories: {np.mean(high_reward_distances):.4f}")
    print(f"  - Low-reward trajectories:  {np.mean(low_reward_distances):.4f}")
    
    if np.mean(high_reward_distances) > np.mean(low_reward_distances):
        print("\n  → High-reward trajectories explore more of the embedding space")
    else:
        print("\n  → Low-reward trajectories explore more of the embedding space")

# Action distribution
all_actions = []
for traj in trajectories_data:
    all_actions.extend([a for a in traj['actions'] if a != 'FINAL'])

from collections import Counter
action_counts = Counter(all_actions)

print(f"\nAction distribution across all trajectories:")
for action, count in action_counts.most_common():
    pct = 100 * count / len(all_actions)
    print(f"  - {action}: {count} ({pct:.1f}%)")

print("\n" + "=" * 80)
print("INTERPRETATION GUIDE")
print("=" * 80)
print("\nWhat to look for in the visualizations:")
print("\n1. PCA/t-SNE plots:")
print("   ✓ GOOD: Trajectories fan out and explore different regions")
print("   ✗ BAD:  All trajectories cluster together (embeddings too similar)")

print("\n2. Distance evolution plots:")
print("   ✓ GOOD: High-reward trajectories follow different paths than low-reward")
print("   ✗ BAD:  All trajectories follow similar paths regardless of reward")

print("\n3. Similarity heatmap:")
print("   ✓ GOOD: Block structure visible (different trajectories are different)")
print("   ✗ BAD:  Uniformly high similarity (all states look the same)")

print("\n4. Action-colored plot:")
print("   ✓ GOOD: Different action types cluster in different regions")
print("   ✗ BAD:  All action types overlap completely")

print("\n" + "=" * 80)

# Policy Graph Visualization

Visualize the trained GFlowNet policy as a directed graph where:
- **Nodes** = States (rules under construction)
- **Edges** = Actions (ADD_ATOM, UNIFY_VARIABLES, TERMINATE)
- **Edge width** = Action probability
- **Colors**: Blue (ADD_ATOM), Green (UNIFY_VARIABLES), Red (TERMINATE)

In [None]:
from visualize_gflownet_graph import GFlowNetGraphVisualizer

print("=" * 80)
print("POLICY GRAPH VISUALIZATION")
print("=" * 80)

# Create visualizer
policy_viz = GFlowNetGraphVisualizer(
    trainer=trainer,
    predicate_vocab=predicate_vocab,
    predicate_arities=predicate_arities,
    max_body_length=config['max_body_length']
)

# Explore the policy graph starting from initial state
print("\nExploring policy graph from initial state...")
print("Parameters:")
print(f"  - Max depth: 3")
print(f"  - Min probability threshold: 0.05")
print(f"  - Max branches per state: 3")

policy_viz.explore_from_state(
    initial_state=initial_state,
    max_depth=3,
    min_prob=0.05,
    max_branches=3
)

print(f"\n✓ Explored {len(policy_viz.graph.nodes)} states")
print(f"✓ Found {len(policy_viz.graph.edges)} actions")

# Visualize the graph
output_path = f'{visualizer.run_dir}/policy_graph.png'
policy_viz.visualize(output_path=output_path, figsize=(20, 14))

# Print top probability paths
policy_viz.print_paths(max_paths=5)

In [None]:
# Analyze policy behavior at different states
print("\n" + "=" * 80)
print("POLICY BEHAVIOR ANALYSIS")
print("=" * 80)

# Analyze initial state
print("\n1. Initial State Analysis")
print("-" * 80)
action_probs, _, _, _ = policy_viz.get_action_probabilities(initial_state)
print(f"Initial state: {theory_to_string(initial_state)}")
print(f"\nAction probabilities:")
print(f"  - ADD_ATOM:        {action_probs[0]:.4f}")
print(f"  - UNIFY_VARIABLES: {action_probs[1]:.4f}")
print(f"  - TERMINATE:       {action_probs[2]:.4f}")

if action_probs[0] > 0.5:
    print("\n✓ Policy prefers adding atoms (exploration)")
    atom_probs = policy_viz.get_atom_probabilities(initial_state)
    top_pred_idx = np.argmax(atom_probs)
    top_pred = predicate_vocab[top_pred_idx]
    print(f"  Most likely predicate to add: {top_pred} (prob: {atom_probs[top_pred_idx]:.4f})")

# Sample a trajectory and analyze intermediate states
print("\n\n2. Trajectory Analysis")
print("-" * 80)
print("Sampling a trajectory to see policy behavior...")

trajectory, reward = trainer.generate_trajectory(
    initial_state, positive_examples, negative_examples, max_steps=5
)

print(f"\nTrajectory length: {len(trajectory)} steps")
print(f"Final reward: {reward:.4f}\n")

for i, step in enumerate(trajectory):
    state_str = theory_to_string(step.state)
    action_type = step.action_type
    
    print(f"Step {i+1}:")
    print(f"  State: {state_str if state_str else '[empty]'}")
    print(f"  Action taken: {action_type}")
    print(f"  Log probability: {step.log_pf:.4f}")
    
    # Get action distribution at this state
    action_probs, _, _, _ = policy_viz.get_action_probabilities(step.state)
    print(f"  Action distribution: ADD={action_probs[0]:.3f}, UNIFY={action_probs[1]:.3f}, TERM={action_probs[2]:.3f}")
    print()

final_state_str = theory_to_string(trajectory[-1].next_state) if trajectory else "N/A"
print(f"Final state: {final_state_str}")
print(f"Final reward: {reward:.4f}")

# Statistics on action preferences
print("\n\n3. Overall Policy Statistics")
print("-" * 80)

action_types = [step.action_type for step in trajectory]
add_count = action_types.count('ADD_ATOM')
unify_count = action_types.count('UNIFY_VARIABLES')
term_count = action_types.count('TERMINATE')

print(f"Action distribution in sampled trajectory:")
print(f"  - ADD_ATOM:        {add_count}/{len(trajectory)} ({100*add_count/len(trajectory):.1f}%)")
print(f"  - UNIFY_VARIABLES: {unify_count}/{len(trajectory)} ({100*unify_count/len(trajectory):.1f}%)")
print(f"  - TERMINATE:       {term_count}/{len(trajectory)} ({100*term_count/len(trajectory):.1f}%)")

print("\n" + "=" * 80)

# Loss vs Reward Mismatch Analysis

Diagnose the "zero flow problem" where the loss function is minimized but the policy doesn't find good rules. This happens when the model learns to assign uniformly low probabilities everywhere.

In [None]:
from analyze_loss_reward_mismatch import LossRewardAnalyzer

# Create analyzer
analyzer = LossRewardAnalyzer(trainer)

# Prepare states for analysis
analysis_states = {
    'initial': initial_state,
    'positives': positive_examples,
    'negatives': negative_examples
}

print("=" * 80)
print("COMPREHENSIVE LOSS/REWARD DIAGNOSIS")
print("=" * 80)
print("\nThis analysis checks for common GFlowNet training problems:")
print("  1. Zero Flow Problem - Model assigns low probabilities everywhere")
print("  2. Low Reward Problem - Policy doesn't find good rules")
print("  3. Vanishing Gradients - Learning has stalled")
print("\nRunning diagnosis...")

# Run comprehensive diagnosis
diagnosis = analyzer.diagnose_zero_flow_problem(
    states=analysis_states,
    num_samples=50
)

# Save figures
flow_fig = diagnosis['flow_fig']
reward_fig = diagnosis['reward_fig']

flow_fig.savefig(f'{visualizer.run_dir}/flow_analysis.png', dpi=300, bbox_inches='tight')
reward_fig.savefig(f'{visualizer.run_dir}/reward_analysis.png', dpi=300, bbox_inches='tight')

print(f"\n✓ Flow analysis saved to: {visualizer.run_dir}/flow_analysis.png")
print(f"✓ Reward analysis saved to: {visualizer.run_dir}/reward_analysis.png")

plt.show()

In [None]:
# Detailed interpretation of results
print("\n" + "=" * 80)
print("DETAILED INTERPRETATION")
print("=" * 80)

# Check diagnosis results
has_zero_flow = diagnosis['zero_flow_problem']
has_low_reward = diagnosis['low_reward_problem']

flow_stats = diagnosis['flow_stats']
reward_stats = diagnosis['reward_stats']

print("\n1. Flow Analysis Results:")
print(f"   - Learned log Z: {flow_stats['log_Z']:.4f}")
print(f"   - Mean log P_F: {flow_stats['mean_log_pf']:.4f}")
print(f"   - Std log P_F:  {flow_stats['std_log_pf']:.4f}")

if has_zero_flow:
    print("   ⚠️  Zero flow problem detected!")
    print("      The model is assigning very low probabilities to all actions.")
else:
    print("   ✓ Flow values look reasonable")

print("\n2. Reward Analysis Results:")
print(f"   - Mean reward: {np.mean(reward_stats['rewards']):.4f}")
print(f"   - Max reward:  {np.max(reward_stats['rewards']):.4f}")
print(f"   - % Zero rewards: {100 * np.sum(reward_stats['rewards'] < 1e-6) / len(reward_stats['rewards']):.1f}%")
print(f"   - % High rewards (>0.8): {100 * np.sum(reward_stats['rewards'] > 0.8) / len(reward_stats['rewards']):.1f}%")

if has_low_reward:
    print("   ⚠️  Low reward problem detected!")
    print("      The policy is not finding good rules.")
else:
    print("   ✓ Policy is finding decent rules")

print("\n3. Connection to Training Results:")
print(f"   - Final average reward (from training): {final_avg_reward:.4f}")
print(f"   - Max reward achieved: {max_reward:.4f}")
print(f"   - High-reward episodes (>0.8): {high_reward_count}/{num_episodes}")

# Determine root cause
print("\n" + "=" * 80)
print("ROOT CAUSE ANALYSIS")
print("=" * 80)

if has_zero_flow and has_low_reward:
    print("\n🔴 CRITICAL: Both zero flow AND low reward problems detected")
    print("\nLikely causes:")
    print("  1. Encoder produces similar embeddings for all states (see embedding analysis)")
    print("  2. Reward signal is too sparse or weak")
    print("  3. Learning rate may be too high, causing instability")
    print("  4. Model has converged to a local minimum (uniform low probabilities)")
    
elif has_low_reward and not has_zero_flow:
    print("\n🟡 Low rewards but flow values are reasonable")
    print("\nLikely causes:")
    print("  1. Policy needs more exploration")
    print("  2. Encoder cannot distinguish good from bad partial rules")
    print("  3. Max body length may be too restrictive")
    
elif has_zero_flow and not has_low_reward:
    print("\n🟡 Zero flow detected but rewards are okay")
    print("\nLikely causes:")
    print("  1. Model is under-confident (could benefit from temperature tuning)")
    print("  2. Log probabilities are scaled differently than expected")
    
else:
    print("\n🟢 Training appears healthy!")
    print("\nThe model shows:")
    print("  ✓ Reasonable flow values")
    print("  ✓ Good reward distribution")
    print("  Continue training or experiment with harder problems.")

print("\n" + "=" * 80)

# Final Summary: Complete Pipeline with Advanced Diagnostics

This notebook demonstrates a complete GFlowNet-ILP pipeline with state-of-the-art diagnostics:

## Pipeline Stages:
1. **✓ Contrastive Pre-Training** - Teaches encoder to distinguish semantic differences
2. **✓ GFlowNet Training** - Learns to generate high-quality logic rules
3. **✓ Embedding Analysis** - Verifies encoder can distinguish semantics (post-training)
4. **✓ Embedding Trajectory Visualization** ← NEW! - Shows how embeddings evolve step-by-step
5. **✓ Policy Visualization** - Shows what the GFlowNet has learned
6. **✓ Loss/Reward Diagnosis** - Identifies any training problems

## Key Innovations:

### 1. Contrastive Pre-Training
Solves the critical problem where untrained encoders produce nearly identical embeddings for all rules, making GFlowNet learning impossible.

### 2. Embedding Trajectory Visualization
**NEW!** Shows how embeddings evolve as rules are constructed:
- **6 different visualizations** (PCA 2D/3D, t-SNE, similarity heatmap, distance evolution, action clustering)
- Reveals whether high-reward and low-reward trajectories take different paths through embedding space
- Diagnoses whether the encoder is learning meaningful representations during construction

## Generated Visualizations:
The notebook produces **18 visualization files** organized into 4 categories:
- Pre-training (1 file)
- Training (7 files)
- Embedding Trajectories (6 files) ← NEW!
- Diagnostics (4 files)

Check the results directory for all generated visualizations!

In [None]:
import os

# List all generated files
print("=" * 80)
print("GENERATED FILES SUMMARY")
print("=" * 80)
print(f"\nAll results saved to: {visualizer.run_dir}\n")

print("Pre-training Visualizations:")
print("  ✓ contrastive_pretraining_results.png - Before/after pre-training comparison")

print("\nTraining Visualizations:")
print("  ✓ training_curves.png       - Reward and loss over time")
print("  ✓ metrics_over_time.png     - Precision, recall, F1-score")
print("  ✓ confusion_matrices.png    - Top rules' confusion matrices")
print("  ✓ trajectory_lengths.png    - Trajectory length distribution")
print("  ✓ summary_dashboard.png     - Comprehensive overview")
print("  ✓ best_rules.txt            - Top 20 discovered rules")
print("  ✓ config.json               - Training configuration")

print("\nEmbedding Trajectory Visualizations:")
print("  ✓ embedding_trajectory_pca_2d.png       - 2D PCA trajectory paths")
print("  ✓ embedding_trajectory_pca_3d.png       - 3D PCA trajectory paths")
print("  ✓ embedding_trajectory_tsne.png         - t-SNE trajectory paths")
print("  ✓ embedding_trajectory_similarity.png   - Step-by-step similarity heatmap")
print("  ✓ embedding_trajectory_distance_evolution.png - Distance evolution plots")
print("  ✓ embedding_trajectory_by_action.png    - Embeddings colored by action type")

print("\nDiagnostic Visualizations:")
print("  ✓ embedding_similarity_matrix.png - Encoder semantic equivalence test")
print("  ✓ policy_graph.png                - GFlowNet policy as state-action graph")
print("  ✓ flow_analysis.png               - Flow value distribution")
print("  ✓ reward_analysis.png             - Reward distribution analysis")

print("\n" + "=" * 80)
print("ANALYSIS COMPLETE!")
print("=" * 80)

# Count files
if os.path.exists(visualizer.run_dir):
    files = [f for f in os.listdir(visualizer.run_dir) if not f.startswith('.')]
    print(f"\nTotal files generated: {len(files)}")
    print(f"Results directory: {visualizer.run_dir}")
else:
    print("\nResults directory not found!")