In [26]:
import torch
import torch.nn as nn
import ltn
import numpy as np
import matplotlib.pyplot as plt

### Código para Gerar o Dataset

Este código cria a classe CLEVR_Generator e gera os vetores exatamente no formato [x, y, r, g, b, s].

In [42]:
def get_clevr_data_expanded(n_samples=100):
    """
    Gera um dataset CLEVR sintético estendido para regras complexas.
    Retorna tensor de formato (n_samples, 11).
    
    ESTRUTURA DO VETOR [11 features]:
    -------------------------------------------------
    [0, 1]    : Posição x, y (0.0 a 1.0)
    [2, 3, 4] : Cores One-Hot (Vermelho, Verde, Azul)
    [5, 6, 7, 8, 9] : Formas One-Hot (Círculo, Quadrado, Cilindro, Cone, Triângulo)
    [10]      : Tamanho (0.0 = Pequeno, 1.0 = Grande)
    -------------------------------------------------
    """
    data = []
    labels = []
    
    # Mapeamento para labels legíveis
    shapes_names = ["Circle", "Square", "Cylinder", "Cone", "Triangle"]
    colors_names = ["Red", "Green", "Blue"]
    sizes_names  = ["Small", "Large"]
    
    for _ in range(n_samples):
        # 1. Posição (x, y)
        x = np.random.rand()
        y = np.random.rand()
        
        # 2. Cor (One-hot 3 cores)
        color_idx = np.random.randint(0, 3)
        color_vec = [0.0] * 3
        color_vec[color_idx] = 1.0
        
        # 3. Forma (One-hot 5 formas)
        shape_idx = np.random.randint(0, 5)
        shape_vec = [0.0] * 5
        shape_vec[shape_idx] = 1.0
        
        # 4. Tamanho (Binário com ruído leve para realismo)
        # Se < 0.5 é Pequeno, se > 0.5 é Grande
        is_large = np.random.rand() > 0.5
        size_val = 1.0 if is_large else 0.0
        
        # Construção do Vetor
        # [x, y] + [r, g, b] + [s1...s5] + [size]
        vector = [x, y] + color_vec + shape_vec + [size_val]
        data.append(vector)
        
        # Label para debug
        desc = f"{sizes_names[int(size_val)]} {colors_names[color_idx]} {shapes_names[shape_idx]} at ({x:.2f}, {y:.2f})"
        labels.append(desc)
        
    tensor_data = torch.tensor(data, dtype=torch.float32)
    return tensor_data, labels

# --- USO ---
# Gerar dados
data, texts = get_clevr_data_expanded(150)
objects = ltn.Variable("objects", data)

print(f"Dataset gerado. Shape: {data.shape}")
print(f"Exemplo: {texts[0]}")
print(f"Vetor: {data[0]}")

Dataset gerado. Shape: torch.Size([150, 11])
Exemplo: Small Red Triangle at (0.13, 0.29)
Vetor: tensor([0.1318, 0.2926, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        1.0000, 0.0000])


In [29]:
# --- Exemplo de Uso para os Alunos ---

# Gerar 50 objetos
objects_data = get_clevr_data(50)

# Criar a variável LTN (que representa "todos os objetos")
objects = ltn.Variable("objects", objects_data)

print(f"Dataset gerado com sucesso! Formato: {objects_data.shape}")
print("Exemplo de um objeto (x, y, r, g, b, s):")
print(objects_data[0])

Dataset gerado com sucesso! Formato: torch.Size([50, 6])
Exemplo de um objeto (x, y, r, g, b, s):
tensor([0.0489, 0.1118, 0.0000, 0.0000, 1.0000, 1.0000])


In [30]:
# Configure LTNtorch to use the appropriate device
ltn.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {ltn.device}")

Using device: cpu


In [31]:
# 1. DEFINE LOGICAL CONNECTIVES AND QUANTIFIERS (REQUIRED IN LTNTORCH)
Not = ltn.Connective(ltn.fuzzy_ops.NotStandard())
And = ltn.Connective(ltn.fuzzy_ops.AndProd())
Or = ltn.Connective(ltn.fuzzy_ops.OrProbSum())
Implies = ltn.Connective(ltn.fuzzy_ops.ImpliesReichenbach())
Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f")
Exists = ltn.Quantifier(ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e")
sat_agg = ltn.fuzzy_ops.SatAgg()

In [4]:
# 2. CREATE LTN OBJECTS (Constants and Variables)
# Constants represent specific objects/entities
# Variables represent quantifiable sets of objects

# Example constants (specific objects)
red_ball = ltn.Constant(torch.tensor([1.0, 0.0, 0.0]), trainable=False)  # RGB color vector
blue_cube = ltn.Constant(torch.tensor([0.0, 0.0, 1.0]), trainable=False)  # RGB color vector

# Example variables (sets of objects)
balls = ltn.Variable("balls", torch.tensor([
    [1.0, 0.0, 0.0],  # red ball
    [0.0, 1.0, 0.0],  # green ball  
    [0.5, 0.0, 0.5],  # purple ball
]))

cubes = ltn.Variable("cubes", torch.tensor([
    [0.0, 0.0, 1.0],  # blue cube
    [1.0, 1.0, 0.0],  # yellow cube
    [0.5, 0.5, 0.5],  # gray cube
]))


In [5]:
# 3. DEFINE PREDICATES (Logical properties/relations)
class ColorPredicate(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(3, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
            nn.Sigmoid()  # Ensure output in [0,1] range
        )
    
    def forward(self, x):
        # Handle LTNObject inputs
        if hasattr(x, 'value'):
            x = x.value
        return self.network(x).squeeze()

In [6]:
# Create predicates
is_red = ltn.Predicate(ColorPredicate())
is_blue = ltn.Predicate(ColorPredicate())
is_round = ltn.Predicate(ColorPredicate())  # Simplified example

In [7]:
# 4. DEFINE LOGICAL AXIOMS (Knowledge base)
def axioms():
    # Basic facts about specific objects
    facts = [
        is_red(red_ball),           # The red ball is red
        is_blue(blue_cube),         # The blue cube is blue
        is_round(red_ball),         # The red ball is round
        Not(is_round(blue_cube))    # The blue cube is not round
    ]
    
    # General rules about categories
    rules = [
        # All balls are round
        Forall(balls, is_round(balls)),
        
        # If something is red, then it's likely to be a ball (simplified rule)
        Forall(ltn.Variable("x", torch.randn(10, 3, device=ltn.device)), 
               Implies(is_red(ltn.Variable("x", torch.randn(10, 3, device=ltn.device))), 
                      is_round(ltn.Variable("x", torch.randn(10, 3, device=ltn.device)))))
    ]
    
    # Combine all axioms using And
    all_axioms = facts + rules
    return sat_agg(*all_axioms)


In [8]:
# 5. SATISFIABILITY CHECKING (Reasoning)
def check_satisfiability():
    # Get the satisfaction level of our knowledge base
    sat_level = axioms()
    print(f"Knowledge base satisfaction level: {sat_level.item():.4f}")
    
    return sat_level

In [32]:
# 6. TRAINING THE PREDICATES (Learning from data)
def train_predicates(epochs=100):
    # Collect all parameters from predicates
    parameters = list(is_red.model.parameters()) + \
                 list(is_blue.model.parameters()) + \
                 list(is_round.model.parameters())
    
    optimizer = torch.optim.Adam(parameters, lr=0.01)
    
    print("\n=== TRAINING PREDICATES ===")
    for epoch in range(epochs):
        optimizer.zero_grad()
        
        # Get satisfaction loss (1 - satisfaction)
        sat_loss = 1 - axioms()
        
        # Backpropagate
        sat_loss.backward()
        optimizer.step()
        
        if epoch % 20 == 0:
            print(f"Epoch {epoch}: Satisfaction = {1-sat_loss.item():.4f}")


In [33]:
# 7. QUERYING THE KNOWLEDGE BASE (Inference)
def query_knowledge_base():
    print("\n=== QUERYING KNOWLEDGE BASE ===")
    
    # Query 1: Is the red ball red?
    query1 = is_red(red_ball)
    print(f"Q1: Is the red ball red? A: {query1.value.item():.4f}")
    
    # Query 2: Is the blue cube round?
    query2 = is_round(blue_cube)
    print(f"Q2: Is the blue cube round? A: {query2.value.item():.4f}")
    
    # Query 3: Are all balls round? (Universal quantification)
    all_balls_round = Forall(balls, is_round(balls))
    print(f"Q3: Are all balls round? A: {all_balls_round.value.item():.4f}")


In [34]:
# 8. RUN THE COMPLETE EXAMPLE
print("=== LTN LOGICAL REASONING EXAMPLE ===")

# Check initial satisfiability
initial_sat = check_satisfiability()

# Train the predicates to better satisfy our axioms
train_predicates(epochs=100)

# Check final satisfiability after training
final_sat = check_satisfiability()

# Query the trained knowledge base
query_knowledge_base()

print("\n=== KEY INSIGHTS ABOUT LTNObjects ===")
print("1. LTNObjects wrap tensors with logical meaning")
print("2. Constants represent specific entities")
print("3. Variables represent quantifiable sets")
print("4. Predicates map LTNObjects to truth values [0,1]")
print("5. Logical connectives (And, Or, Not, Implies) work on truth values")
print("6. Quantifiers (Forall, Exists) aggregate over variables")
print("7. In LTNTorch, you must explicitly define all logical connectives")
print("8. The SatAgg function combines multiple axioms into a single satisfaction score")

=== LTN LOGICAL REASONING EXAMPLE ===
Knowledge base satisfaction level: 0.9519

=== TRAINING PREDICATES ===
Epoch 0: Satisfaction = 0.9372
Epoch 20: Satisfaction = 0.9752
Epoch 40: Satisfaction = 0.9151
Epoch 60: Satisfaction = 0.9356
Epoch 80: Satisfaction = 0.9858
Knowledge base satisfaction level: 0.9835

=== QUERYING KNOWLEDGE BASE ===
Q1: Is the red ball red? A: 0.9735
Q2: Is the blue cube round? A: 0.0300
Q3: Are all balls round? A: 0.9960

=== KEY INSIGHTS ABOUT LTNObjects ===
1. LTNObjects wrap tensors with logical meaning
2. Constants represent specific entities
3. Variables represent quantifiable sets
4. Predicates map LTNObjects to truth values [0,1]
5. Logical connectives (And, Or, Not, Implies) work on truth values
6. Quantifiers (Forall, Exists) aggregate over variables
7. In LTNTorch, you must explicitly define all logical connectives
8. The SatAgg function combines multiple axioms into a single satisfaction score


## Setup for Equality Experiments

In [35]:
print("\n" + "="*50)
print("EXTENSION: DIFFERENT FORMS OF EQUALITY IN LTN")
print("="*50)

# Define logical connectives for equality experiments
Not_eq = ltn.Connective(ltn.fuzzy_ops.NotStandard())
And_eq = ltn.Connective(ltn.fuzzy_ops.AndProd())
Or_eq = ltn.Connective(ltn.fuzzy_ops.OrProbSum())
Implies_eq = ltn.Connective(ltn.fuzzy_ops.ImpliesReichenbach())
Forall_eq = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f")
Exists_eq = ltn.Quantifier(ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e")
sat_agg_eq = ltn.fuzzy_ops.SatAgg()

# Create sample data for equality experiments
# Two objects that should be considered equal (same color)
red_object1 = ltn.Constant(torch.tensor([1.0, 0.0, 0.0], device=ltn.device), trainable=True)
red_object2 = ltn.Constant(torch.tensor([0.95, 0.05, 0.0], device=ltn.device), trainable=True)

# Two objects that should be considered different
blue_object = ltn.Constant(torch.tensor([0.0, 0.0, 1.0], device=ltn.device), trainable=True)

# Variables for quantification
objects = ltn.Variable("objects", torch.tensor([
    [1.0, 0.0, 0.0],    # red
    [0.95, 0.05, 0.0],  # similar red  
    [0.0, 0.0, 1.0],    # blue
    [0.0, 1.0, 0.0],    # green
], device=ltn.device))

print("Equality experiment setup complete!")


EXTENSION: DIFFERENT FORMS OF EQUALITY IN LTN
Equality experiment setup complete!


## 1 LTN Diagonal Equality Implementation

In [36]:
print("\n" + "-"*30)
print("1. LTN DIAGONAL EQUALITY")
print("-"*30)

class DiagonalEquality(nn.Module):
    """Implements diagonal equality using LTN's diag function"""
    def forward(self, x, y):
        # Handle LTNObject inputs
        x_val = x.value if hasattr(x, 'value') else x
        y_val = y.value if hasattr(y, 'value') else y
        
        # Create diagonal pairs and compute similarity
        # For simplicity, we use cosine similarity on the diagonal
        cos_sim = torch.nn.functional.cosine_similarity(x_val, y_val, dim=-1)
        return 0.5 * (cos_sim + 1.0)  # Normalize to [0,1]

# Create diagonal equality predicate
Equal_Diag = ltn.Predicate(DiagonalEquality().to(ltn.device))

def axioms_withEquality_Diag():
    """Knowledge base using diagonal equality"""
    # Basic equality axioms
    axioms = [
        # Reflexivity: every object should be equal to itself
        Forall_eq(objects, Equal_Diag(objects, objects)),
        
        # Symmetry example: if object1 equals object2, then object2 equals object1
        Equal_Diag(red_object1, red_object2),
        Equal_Diag(red_object2, red_object1),
        
        # Transitivity example (simplified)
        Implies_eq(
            And_eq(Equal_Diag(red_object1, red_object2), Equal_Diag(red_object2, red_object1)),
            Equal_Diag(red_object1, red_object1)
        ),
        
        # Different objects should not be equal
        Not_eq(Equal_Diag(red_object1, blue_object))
    ]
    
    return sat_agg_eq(*axioms)

# Test diagonal equality
print("Testing Diagonal Equality...")
sat_diag = axioms_withEquality_Diag()
print(f"Diagonal Equality KB Satisfaction: {sat_diag.item():.4f}")

# Query specific equalities
query1 = Equal_Diag(red_object1, red_object2)
query2 = Equal_Diag(red_object1, blue_object)
print(f"Q: Are the two red objects equal? A: {query1.value.item():.4f}")
print(f"Q: Is red object equal to blue object? A: {query2.value.item():.4f}")


------------------------------
1. LTN DIAGONAL EQUALITY
------------------------------
Testing Diagonal Equality...
Diagonal Equality KB Satisfaction: 0.7764
Q: Are the two red objects equal? A: 0.9993
Q: Is red object equal to blue object? A: 0.5000


## 2 - Cosine Equality Implementation

In [37]:

class CosineEquality(nn.Module):
    """Implements cosine similarity-based equality with learnable temperature parameter"""
    def __init__(self):
        super().__init__()
        # Learnable temperature parameter to control sharpness of similarity
        self.temperature = nn.Parameter(torch.tensor(1.0, device=ltn.device))
        
    def forward(self, x, y):
        # Handle LTNObject inputs
        x_val = x.value if hasattr(x, 'value') else x
        y_val = y.value if hasattr(y, 'value') else y
        
        # Compute cosine similarity
        cos_sim = torch.nn.functional.cosine_similarity(x_val, y_val, dim=-1)
        
        # Apply temperature scaling and normalize to [0,1] range
        scaled_sim = torch.sigmoid(self.temperature * (cos_sim + 1.0) / 2.0)
        return scaled_sim

# Create cosine equality predicate
Equal_Cos = ltn.Predicate(CosineEquality().to(ltn.device))

def axioms_withEquality_Cos():
    """Knowledge base using cosine equality"""
    # Get learnable parameters (now includes temperature)
    parameters = list(Equal_Cos.model.parameters())
    optimizer = torch.optim.Adam(parameters, lr=0.1)  # Higher LR for temperature
    
    print("Training Cosine Equality...")
    for epoch in range(200):
        optimizer.zero_grad()
        
        # Same axioms as diagonal equality but with cosine predicate
        axioms = [
            Forall_eq(objects, Equal_Cos(objects, objects)),  # Reflexivity
            Equal_Cos(red_object1, red_object2),               # Similar reds should be equal
            Equal_Cos(red_object2, red_object1),               # Symmetry
            Not_eq(Equal_Cos(red_object1, blue_object))        # Different colors not equal
        ]
        
        sat = sat_agg_eq(*axioms)
        loss = 1 - sat
        loss.backward()
        optimizer.step()
        
        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Cosine Equality Satisfaction = {sat.item():.4f}, Temp = {Equal_Cos.model.temperature.item():.4f}")
    
    # Return final satisfaction
    final_axioms = [
        Forall_eq(objects, Equal_Cos(objects, objects)),
        Equal_Cos(red_object1, red_object2),
        Not_eq(Equal_Cos(red_object1, blue_object))
    ]
    return sat_agg_eq(*final_axioms)



In [38]:
print("\n" + "-"*30)
print("2. COSINE EQUALITY (FIXED)")
print("-"*30)

# Test cosine equality
sat_cos = axioms_withEquality_Cos()
print(f"Final Cosine Equality KB Satisfaction: {sat_cos.item():.4f}")

# Query specific equalities
query1 = Equal_Cos(red_object1, red_object2)
query2 = Equal_Cos(red_object1, blue_object)
print(f"Q: Are the two red objects equal (cosine)? A: {query1.value.item():.4f}")
print(f"Q: Is red object equal to blue object (cosine)? A: {query2.value.item():.4f}")


------------------------------
2. COSINE EQUALITY (FIXED)
------------------------------
Training Cosine Equality...
Epoch 0: Cosine Equality Satisfaction = 0.6112, Temp = 1.1000
Epoch 50: Cosine Equality Satisfaction = 0.6255, Temp = 1.5738
Epoch 100: Cosine Equality Satisfaction = 0.6255, Temp = 1.5680
Epoch 150: Cosine Equality Satisfaction = 0.6255, Temp = 1.5687
Final Cosine Equality KB Satisfaction: 0.5793
Q: Are the two red objects equal (cosine)? A: 0.8275
Q: Is red object equal to blue object (cosine)? A: 0.6866


## 3 - Euclidean Equality Implementation


In [39]:
print("\n" + "-"*30)
print("3. EUCLIDEAN EQUALITY (FIXED)")
print("-"*30)

class EuclideanEquality(nn.Module):
    """Implements Euclidean distance-based equality with learnable gamma parameter"""
    def __init__(self):
        super().__init__()
        # Learnable gamma parameter for strictness
        self.gamma = nn.Parameter(torch.tensor(0.5, device=ltn.device))
        
    def forward(self, x, y):
        # Handle LTNObject inputs
        x_val = x.value if hasattr(x, 'value') else x
        y_val = y.value if hasattr(y, 'value') else y
        
        # Compute Euclidean distance
        distance = torch.sum(torch.square(x_val - y_val), dim=-1)
        
        # Apply exponential decay with learnable gamma
        similarity = torch.exp(-self.gamma * distance)
        
        return similarity

# Create Euclidean equality predicate
Equal_Eucl = ltn.Predicate(EuclideanEquality().to(ltn.device))

def axioms_withEquality_Eucl():
    """Knowledge base using Euclidean equality"""
    parameters = list(Equal_Eucl.model.parameters())
    optimizer = torch.optim.Adam(parameters, lr=0.1)  # Higher LR for gamma
    
    print("Training Euclidean Equality...")
    for epoch in range(150):
        optimizer.zero_grad()
        
        axioms = [
            Forall_eq(objects, Equal_Eucl(objects, objects)),  # Reflexivity
            Equal_Eucl(red_object1, red_object2),               # Similar objects
            Equal_Eucl(red_object2, red_object1),               # Symmetry
            Not_eq(Equal_Eucl(red_object1, blue_object))        # Different objects
        ]
        
        sat = sat_agg_eq(*axioms)
        loss = 1 - sat
        loss.backward()
        optimizer.step()
        
        if epoch % 30 == 0:
            print(f"Epoch {epoch}: Euclidean Equality Satisfaction = {sat.item():.4f}, Gamma = {Equal_Eucl.model.gamma.item():.4f}")
    
    # Final evaluation
    final_axioms = [
        Forall_eq(objects, Equal_Eucl(objects, objects)),
        Equal_Eucl(red_object1, red_object2),
        Not_eq(Equal_Eucl(red_object1, blue_object))
    ]
    return sat_agg_eq(*final_axioms)

# Test Euclidean equality
sat_eucl = axioms_withEquality_Eucl()
print(f"Final Euclidean Equality KB Satisfaction: {sat_eucl.item():.4f}")

# Query specific equalities
query1 = Equal_Eucl(red_object1, red_object2)
query2 = Equal_Eucl(red_object1, blue_object)
print(f"Q: Are the two red objects equal (Euclidean)? A: {query1.value.item():.4f}")
print(f"Q: Is red object equal to blue object (Euclidean)? A: {query2.value.item():.4f}")


------------------------------
3. EUCLIDEAN EQUALITY (FIXED)
------------------------------
Training Euclidean Equality...
Epoch 0: Euclidean Equality Satisfaction = 0.8160, Gamma = 0.6000
Epoch 30: Euclidean Equality Satisfaction = 0.9903, Gamma = 2.2782
Epoch 60: Euclidean Equality Satisfaction = 0.9905, Gamma = 2.4738
Epoch 90: Euclidean Equality Satisfaction = 0.9905, Gamma = 2.4658
Epoch 120: Euclidean Equality Satisfaction = 0.9905, Gamma = 2.4479
Final Euclidean Equality KB Satisfaction: 0.9917
Q: Are the two red objects equal (Euclidean)? A: 0.9879
Q: Is red object equal to blue object (Euclidean)? A: 0.0076


## 4 - Learnable Manifold Equality Implementation

In [40]:
print("\n" + "-"*30)
print("4. LEARNABLE MANIFOLD EQUALITY (FIXED)")
print("-"*30)

class LearnableEquality(nn.Module):
    """Implements learnable equality using neural network"""
    def __init__(self, input_dim=3):
        super(LearnableEquality, self).__init__()
        # Input dim is 2 * feature_dim because of concatenation
        self.net = nn.Sequential(
            nn.Linear(input_dim * 2, 64),
            nn.ELU(),
            nn.Linear(64, 32),
            nn.ELU(),
            nn.Linear(32, 1),
            nn.Sigmoid()  # Critical: output must be in [0,1]
        )
    
    def forward(self, x, y):
        # Handle LTNObject inputs
        x_val = x.value if hasattr(x, 'value') else x
        y_val = y.value if hasattr(y, 'value') else y
        
        # Concatenate along the feature dimension
        cat_inputs = torch.cat([x_val, y_val], dim=-1)
        return self.net(cat_inputs).squeeze()

# Create learnable equality predicate
Equal_Learned = ltn.Predicate(LearnableEquality(input_dim=3).to(ltn.device))

def axioms_withEquality_Learned():
    """Knowledge base using learnable manifold equality"""
    parameters = list(Equal_Learned.model.parameters())
    optimizer = torch.optim.Adam(parameters, lr=0.01)
    
    print("Training Learnable Equality Network...")
    for epoch in range(300):
        optimizer.zero_grad()
        
        # Complex equality axioms that require learning
        axioms = [
            # Reflexivity: every object equals itself
            Forall_eq(objects, Equal_Learned(objects, objects)),
            
            # Similar objects should be equal (red objects)
            Equal_Learned(red_object1, red_object2),
            Equal_Learned(red_object2, red_object1),
            
            # Different objects should not be equal
            Not_eq(Equal_Learned(red_object1, blue_object)),
            Not_eq(Equal_Learned(blue_object, red_object1)),
            
            # Transitivity example (if A=B and B=C then A=C)
            Implies_eq(
                And_eq(Equal_Learned(red_object1, red_object2), Equal_Learned(red_object2, red_object1)),
                Equal_Learned(red_object1, red_object1)
            )
        ]
        
        sat = sat_agg_eq(*axioms)
        loss = 1 - sat
        loss.backward()
        optimizer.step()
        
        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Learnable Equality Satisfaction = {sat.item():.4f}")
    
    # Final evaluation with more complex queries
    final_axioms = [
        Forall_eq(objects, Equal_Learned(objects, objects)),
        Equal_Learned(red_object1, red_object2),
        Not_eq(Equal_Learned(red_object1, blue_object))
    ]
    return sat_agg_eq(*final_axioms)

# Test learnable equality
sat_learned = axioms_withEquality_Learned()
print(f"Final Learnable Equality KB Satisfaction: {sat_learned.item():.4f}")

# Query specific equalities
query1 = Equal_Learned(red_object1, red_object2)
query2 = Equal_Learned(red_object1, blue_object)
query3 = Equal_Learned(red_object1, red_object1)  # Reflexivity check

print(f"Q: Are the two red objects equal (Learned)? A: {query1.value.item():.4f}")
print(f"Q: Is red object equal to blue object (Learned)? A: {query2.value.item():.4f}")
print(f"Q: Is red object equal to itself (Learned)? A: {query3.value.item():.4f}")


------------------------------
4. LEARNABLE MANIFOLD EQUALITY (FIXED)
------------------------------
Training Learnable Equality Network...
Epoch 0: Learnable Equality Satisfaction = 0.5278
Epoch 50: Learnable Equality Satisfaction = 0.7959
Epoch 100: Learnable Equality Satisfaction = 0.7959
Epoch 150: Learnable Equality Satisfaction = 0.7959
Epoch 200: Learnable Equality Satisfaction = 0.7959
Epoch 250: Learnable Equality Satisfaction = 0.7959
Final Learnable Equality KB Satisfaction: 0.7113
Q: Are the two red objects equal (Learned)? A: 1.0000
Q: Is red object equal to blue object (Learned)? A: 0.0002
Q: Is red object equal to itself (Learned)? A: 1.0000


In [22]:
print(objects_data[20])

tensor([0.6269, 0.7298, 1.0000, 0.0000, 0.0000, 1.0000])
