In [1]:
import os
import random
from collections import Counter

import torch
from sklearn.model_selection import train_test_split
from deepproblog.dataset import DataLoader
from deepproblog.engines import ExactEngine
from deepproblog.model import Model
from deepproblog.network import Network
from deepproblog.train import train_model

from data_manager import loading_abstraction_extended, BinaryTargetDataset, BinaryTargetInterface, AnimalCategorizer
from neural import OptimizedMLP
from conf_matrix import get_confusion_matrix_and_errors



In [2]:
# ---- Experiment configuration ------------------------------------------------
TEST_SPLIT_RATIO = 0.2   # fraction of samples held out for the test split
RANDOM_SEED = 42

# Input files (AwA2 features / per-sample labels / binary predicate matrix).
DATA_PATH = 'Animals_with_Attributes2/ResNet101/AwA2-features.txt'
LABEL_PATH = 'Animals_with_Attributes2/ResNet101/AwA2-labels.txt'
BINARY_TARGET_PATH = 'Animals_with_Attributes2/predicate-matrix-binary.txt'

# The three lists below are aligned position-by-position: entry k of INDICES and of
# KB_FILENAMES belongs to QUERY_NAMES[k].
QUERY_NAMES = ['is_terrestrial', 'is_aquatic', 'is_ungulate', 'is_carnivore']
# Attribute indices whose per-attribute networks (named net<i>) each query uses.
INDICES = [[32, 42], [10, 11, 27, 43], [18, 22, 32, 35, 42], [33, 34, 37]]
# ProbLog knowledge base defining each query.
KB_FILENAMES = ['knowledge_bases/kb_terr_test.pl', 'knowledge_bases/kb_aqua_test.pl',
                'knowledge_bases/kb_ungulates.pl', 'knowledge_bases/kb_carnivore.pl']

# Query-name lookups derived from the aligned lists, so they cannot drift apart.
QUERY_TO_INDICES = dict(zip(QUERY_NAMES, INDICES))
QUERY_TO_KB = dict(zip(QUERY_NAMES, KB_FILENAMES))

In [3]:
# Extract only the attribute columns that at least one query's networks use.
# NOTE(review): attributes 32 and 42 belong to two queries, so this flattened list
# contains duplicates — confirm loading_abstraction_extended handles repeated indices
# as intended (vs. sorted(set(...))).
indices_to_extract = sorted(attr for group in INDICES for attr in group)
data, label_indices, binary_vectors = loading_abstraction_extended(
    DATA_PATH,
    LABEL_PATH,
    BINARY_TARGET_PATH,
    attribute_indices_to_extract=indices_to_extract,
)
print(len(label_indices))

37322


In [4]:
# Per-query target labels, computed from each sample's class index in `label_indices`.
# The class groups below are hand-curated lists of class indices.


def _group_labels(class_ids, code_to_group, default=0):
    """Encode every class id as the code of the group that contains it.

    Args:
        class_ids: iterable of class indices, one per sample.
        code_to_group: mapping {code: iterable of class indices}. Groups must be
            pairwise disjoint so each class id gets exactly one code.
        default: code for class ids that belong to no group.

    Returns:
        A list with one code per element of `class_ids`.

    Raises:
        ValueError: if a class index is listed under two different codes.
    """
    code_of = {}
    for code, group in code_to_group.items():
        for cls in group:
            # Overlapping groups would make the label depend on check order; reject them.
            if cls in code_of and code_of[cls] != code:
                raise ValueError(f"class {cls} assigned to both code {code_of[cls]} and code {code}")
            code_of[cls] = code
    return [code_of.get(cls, default) for cls in class_ids]


# is_terrestrial: 1 unless the class is listed as non-terrestrial.
non_terrestrial = [2, 8, 16, 17, 23, 29, 35, 46, 49]
terrestrial_labels = _group_labels(label_indices, {0: non_terrestrial}, default=1)

# is_aquatic: 0 = not aquatic, 1 = hairless aquatic, 2 = furry aquatic.
aquatic = [2, 3, 8, 13, 17, 23, 35, 44, 46, 49]
hairless_aquatic = [2, 8, 13, 17, 23, 46, 49]
furry_aquatic = [3, 35, 44]
# `aquatic` is not used for labelling; keep it consistent with its two sub-groups.
assert set(aquatic) == set(hairless_aquatic) | set(furry_aquatic), "aquatic groups out of sync"
aquatic_labels = _group_labels(label_indices, {1: hairless_aquatic, 2: furry_aquatic})

# is_ungulate: 0 = not an ungulate, 1 = unhorned ungulate, 2 = horned ungulate.
unhorned_ungulates = [6, 22, 37, 41]
horned_ungulates = [0, 15, 20, 27, 30, 36, 39, 48]
ungulate_labels = _group_labels(label_indices, {1: unhorned_ungulates, 2: horned_ungulates})

# is_carnivore: 1 if the class is listed as carnivorous.
# NOTE(review): hand-curated; the carnivore evaluation's misses concentrate on classes
# 33, 45, 29 and 32 — worth re-checking their membership in this list.
carnivores = [1, 2, 7, 9, 12, 14, 21, 23, 29, 31, 33, 34, 35, 40, 42, 44, 45]
carnivore_labels = _group_labels(label_indices, {1: carnivores})

QUERY_TO_LABELS = {
    'is_terrestrial': terrestrial_labels,
    'is_aquatic': aquatic_labels,
    'is_ungulate': ungulate_labels,
    'is_carnivore': carnivore_labels,
}

In [5]:
# Stratified split over sample positions: stratifying on the class index keeps each
# class's share similar in both splits, and RANDOM_SEED makes the split reproducible.
original_sample_indices = list(range(len(label_indices)))
class_per_sample = [label_indices[pos] for pos in original_sample_indices]
train_indices, test_indices = train_test_split(
    original_sample_indices,
    stratify=class_per_sample,
    random_state=RANDOM_SEED,
    test_size=TEST_SPLIT_RATIO,
)
print(f"Training samples: {len(train_indices)}, Testing samples: {len(test_indices)}")

Training samples: 29857, Testing samples: 7465


In [6]:
# Class histogram of the held-out split (sanity check for the stratified split above).
test_distribution = [label_indices[idx] for idx in test_indices]
test_class_counts = Counter(test_distribution)
print(f'Test set distribution: {test_class_counts}')

Test set distribution: Counter({6: 329, 22: 284, 39: 269, 48: 268, 26: 240, 30: 240, 37: 234, 28: 218, 0: 209, 18: 208, 7: 207, 45: 206, 42: 204, 23: 198, 49: 189, 36: 179, 12: 175, 38: 175, 19: 174, 44: 174, 1: 170, 25: 156, 35: 152, 5: 149, 20: 146, 24: 146, 14: 144, 41: 143, 17: 142, 15: 141, 27: 139, 13: 137, 21: 133, 40: 126, 31: 118, 32: 113, 4: 110, 47: 102, 9: 100, 29: 76, 33: 62, 16: 58, 2: 58, 34: 54, 46: 43, 3: 38, 10: 37, 43: 37, 8: 35, 11: 20})


In [9]:
# Evaluate the ungulate classifier assembled from *pretrained* per-attribute MLPs
# (no training happens in this cell), then count misclassified samples per class.
# NOTE(review): this cell uses kb_ungulates_updated.pl, while QUERY_TO_KB maps
# 'is_ungulate' to kb_ungulates.pl — confirm which KB is intended and align the config.
UNGULATE_KB_PATH = 'knowledge_bases/kb_ungulates_updated.pl'

for query in QUERY_NAMES[2:3]:  # only 'is_ungulate'
    labels = QUERY_TO_LABELS[query]
    dataset = BinaryTargetDataset(data, labels, binary_vectors)
    # The interface wraps the full dataset; the splits below select samples by index.
    interface = BinaryTargetInterface(dataset)

    train_set = AnimalCategorizer(
        dataset_name=dataset,
        function_name=query,
        seed=RANDOM_SEED,
        sample_indices=train_indices,
    )
    test_set = AnimalCategorizer(
        dataset_name=dataset,
        function_name=query,
        seed=RANDOM_SEED,
        sample_indices=test_indices,
    )

    networks = {}
    for i in QUERY_TO_INDICES[query]:
        network = OptimizedMLP()
        # The checkpoint is fed straight into load_state_dict, i.e. it is a plain state
        # dict; weights_only=True avoids unpickling arbitrary objects while loading it.
        state_dict = torch.load(
            f'pretrained_mlps/net{i}.pth',
            map_location=torch.device('cpu'),
            weights_only=True,
        )
        network.load_state_dict(state_dict)
        net = Network(network, f"net{i}", batching=True)
        net.optimizer = torch.optim.AdamW(network.parameters(), lr=1e-4)
        networks[f"net{i}"] = net
    model = Model(UNGULATE_KB_PATH, list(networks.values()))
    model.set_engine(ExactEngine(model), cache=True)
    # Register the tensor source the knowledge base reads its inputs from.
    model.add_tensor_source("dataset", interface)

    _, errors_indices = get_confusion_matrix_and_errors(model, test_set, verbose=1)
    # errors_indices index into label_indices; report which classes were missed.
    miss_labels = [label_indices[x] for x in errors_indices]
    print(Counter(miss_labels))

Caching ACs


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  return -self.eps <= float(a) <= self.eps


         	 	    	Actual	   
         	 	   0	     2	  1
Predicted	0	4863	    55	 40
         	2	   9	  1500	 16
         	1	  12	    36	934
Accuracy:  0.9774949765572672
Accuracy 0.9774949765572672
Counter({15: 25, 48: 25, 41: 23, 6: 20, 36: 13, 27: 12, 22: 10, 39: 8, 1: 5, 13: 5, 0: 5, 18: 4, 37: 3, 30: 2, 19: 1, 7: 1, 40: 1, 45: 1, 42: 1, 31: 1, 20: 1, 28: 1})


In [None]:
# Train the aquatic classifier end-to-end with DeepProbLog from freshly initialised
# per-attribute MLPs, save a snapshot, and evaluate it on the held-out split.
BATCH_SIZE = 24
N_EPOCHS = 10
LEARNING_RATE = 1e-4
SNAPSHOT_DIR = 'snapshot'

for query in QUERY_NAMES[1:2]:  # only 'is_aquatic'
    labels = QUERY_TO_LABELS[query]
    dataset = BinaryTargetDataset(data, labels, binary_vectors)
    # The interface wraps the full dataset; the splits below select samples by index.
    interface = BinaryTargetInterface(dataset)

    train_set = AnimalCategorizer(
        dataset_name=dataset,
        function_name=query,
        seed=RANDOM_SEED,
        sample_indices=train_indices,
    )
    test_set = AnimalCategorizer(
        dataset_name=dataset,
        function_name=query,
        seed=RANDOM_SEED,
        sample_indices=test_indices,
    )

    # Seed weight initialisation and batch shuffling so training is reproducible.
    # NOTE(review): assumes deepproblog's DataLoader shuffles via Python's `random` — confirm.
    torch.manual_seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)

    networks = {}
    for i in QUERY_TO_INDICES[query]:
        network = OptimizedMLP()
        net = Network(network, f"net{i}", batching=True)
        net.optimizer = torch.optim.AdamW(network.parameters(), lr=LEARNING_RATE)
        networks[f"net{i}"] = net
    model = Model(QUERY_TO_KB[query], list(networks.values()))
    model.set_engine(ExactEngine(model), cache=True)
    # Register the tensor source the knowledge base reads its inputs from.
    model.add_tensor_source("dataset", interface)

    # Create the snapshot folder *before* the long training run, so a missing
    # directory cannot make the final save fail after all epochs have run.
    os.makedirs(SNAPSHOT_DIR, exist_ok=True)

    loader = DataLoader(train_set, BATCH_SIZE, shuffle=True)
    train = train_model(model, loader, N_EPOCHS, log_iter=50, profile=0)
    # Later evaluation cells reload this file by the same query-derived name.
    model.save_state(os.path.join(SNAPSHOT_DIR, f"scratch_model_{query}.pth"))

    _, errors_indices = get_confusion_matrix_and_errors(model, test_set, verbose=1)
    # errors_indices index into label_indices; report which classes were missed.
    miss_labels = [label_indices[x] for x in errors_indices]
    print(Counter(miss_labels))

Caching ACs
Training  for 10 epoch(s)
Epoch 1
Iteration:  50 	s:8.9783 	Average Loss:  0.7301214514904859
Iteration:  100 	s:7.5383 	Average Loss:  0.32127998276013386
Iteration:  150 	s:10.8749 	Average Loss:  0.2619075731005978
Iteration:  200 	s:26.8869 	Average Loss:  0.17706948293047844
Iteration:  250 	s:29.4375 	Average Loss:  0.14438578516799452
Iteration:  300 	s:29.9381 	Average Loss:  0.14478316725816967
Iteration:  350 	s:20.1766 	Average Loss:  0.1379188561191654
Iteration:  400 	s:24.4318 	Average Loss:  0.12316168004940892
Iteration:  450 	s:23.5154 	Average Loss:  0.12148968469966576
Iteration:  500 	s:14.3378 	Average Loss:  0.13075506510375307
Iteration:  550 	s:4.7991 	Average Loss:  0.177791486276069
Iteration:  600 	s:6.6518 	Average Loss:  0.12242058071691077
Iteration:  650 	s:6.5300 	Average Loss:  0.14307049435502905
Iteration:  700 	s:6.5191 	Average Loss:  0.11831600891825018
Iteration:  750 	s:6.8952 	Average Loss:  0.10730638166843164
Iteration:  800 	s:22.

In [7]:
# Evaluate the carnivore classifier restored from the snapshot that the training
# cell saves as snapshot/scratch_model_<query>.pth, then count misses per class.
for query in QUERY_NAMES[3:]:  # only 'is_carnivore'
    labels = QUERY_TO_LABELS[query]
    dataset = BinaryTargetDataset(data, labels, binary_vectors)
    # The interface wraps the full dataset; the splits below select samples by index.
    interface = BinaryTargetInterface(dataset)

    train_set = AnimalCategorizer(
        dataset_name=dataset,
        function_name=query,
        seed=RANDOM_SEED,
        sample_indices=train_indices,
    )
    test_set = AnimalCategorizer(
        dataset_name=dataset,
        function_name=query,
        seed=RANDOM_SEED,
        sample_indices=test_indices,
    )

    networks = {}
    for i in QUERY_TO_INDICES[query]:
        network = OptimizedMLP()
        net = Network(network, f"net{i}", batching=True)
        # Not stepped during evaluation; kept in case load_state also restores
        # optimizer state — TODO confirm whether it is needed.
        net.optimizer = torch.optim.AdamW(network.parameters(), lr=1e-4)
        networks[f"net{i}"] = net
    model = Model(QUERY_TO_KB[query], list(networks.values()))
    # Derive the path from `query` so it always matches the name the training cell saved.
    model.load_state(f'snapshot/scratch_model_{query}.pth')
    model.set_engine(ExactEngine(model), cache=True)
    # Register the tensor source the knowledge base reads its inputs from.
    model.add_tensor_source("dataset", interface)

    _, errors_indices = get_confusion_matrix_and_errors(model, test_set, verbose=1)
    # errors_indices index into label_indices; report which classes were missed.
    miss_labels = [label_indices[x] for x in errors_indices]
    print(Counter(miss_labels))

Caching ACs
         	 	Actual	    
         	 	     0	   1
Predicted	0	  5018	 154
         	1	    90	2203
Accuracy:  0.9673141326188881
Accuracy 0.9673141326188881
Counter({33: 32, 45: 19, 29: 13, 32: 13, 1: 12, 43: 11, 9: 11, 35: 11, 2: 11, 23: 11, 40: 9, 7: 9, 3: 8, 5: 7, 34: 7, 15: 7, 6: 5, 46: 5, 47: 4, 25: 3, 41: 3, 36: 3, 39: 3, 48: 2, 31: 2, 44: 2, 42: 2, 10: 2, 16: 2, 4: 2, 14: 2, 22: 2, 24: 2, 17: 1, 13: 1, 11: 1, 0: 1, 21: 1, 49: 1, 27: 1})
