In [217]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import json
from ontolearn.knowledge_base import KnowledgeBase
from model import *
import helper as h

In [295]:
path = './testing_data/family/Data.json'
with open(path, 'r') as f:
    data_family = json.load(f)
data_family 

[['Brother ⊔ (Parent ⊓ (∀ hasSibling.Father))',
  {'positive examples': ['family#F5F61',
    'family#F7M110',
    'family#F9M166',
    'family#F6M95',
    'family#F6F93',
    'family#F2F10',
    'family#F6F72',
    'family#F2F22',
    'family#F6M100',
    'family#F10M184',
    'family#F10M196',
    'family#F5M60',
    'family#F7M104',
    'family#F3M45',
    'family#F10M183',
    'family#F10M171',
    'family#F7M131',
    'family#F2M34',
    'family#F6F70',
    'family#F3M50',
    'family#F9M139',
    'family#F9F154',
    'family#F6M88',
    'family#F10M187',
    'family#F6M73',
    'family#F2F24',
    'family#F2M21',
    'family#F8F137',
    'family#F1M4',
    'family#F9M162',
    'family#F10M199',
    'family#F5F65',
    'family#F5M66',
    'family#F9F140',
    'family#F9M151',
    'family#F9M146',
    'family#F7M128',
    'family#F3M44',
    'family#F3M47',
    'family#F2M18',
    'family#F10M182',
    'family#F6F74',
    'family#F7M122',
    'family#F9F141',
    'family#F7M117',
  

In [219]:
!ls ../generated_data/

animals  family  lymphography  nctrer  suramin


In [220]:
emb = pd.read_csv('../generated_data/family/family_emb.csv', index_col = 0)
emb.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,30,31,32,33,34,35,36,37,38,39
family,-0.510082,-0.145016,-0.009557,-0.078788,-0.019206,-0.010803,0.349008,0.209991,-0.088038,0.116868,...,-0.439352,0.280357,0.469358,0.083962,-0.110305,-0.296295,-0.071612,0.038979,0.453106,-0.148915
Brother,-0.631511,-0.240898,-0.2507,-0.362371,-0.040183,0.015462,0.141325,0.325979,-0.335817,0.45222,...,-0.307253,0.310754,0.059858,0.328623,-0.105266,-0.234409,-0.476253,0.088189,0.438635,-0.342936
Child,-0.542442,-0.296409,0.082527,-0.274644,-0.006012,0.078398,0.180037,0.444667,0.043876,0.313792,...,0.021881,0.181143,0.132985,0.013914,-0.077513,-0.27381,-0.150218,0.106706,0.416187,-0.546767
Daughter,-0.507977,-0.480087,0.355015,-0.252879,-0.054749,0.083345,0.268233,0.03425,-0.265779,0.377001,...,-0.434662,0.158338,0.053465,-0.056371,-0.070795,-0.387119,0.017065,0.1143,0.385666,-0.424619
family#F10F172,-0.346503,-0.025178,0.056965,-0.381369,0.04077,0.59951,-0.136524,0.235721,-0.040183,0.545229,...,-0.272744,0.309809,0.310318,-0.102748,-0.298615,-0.497718,-0.021591,0.238793,0.062246,0.198001


In [221]:
# Negation input_size, hidden_size, output_size, batch_size
neg_path = './trained_models/family/family_model_Negation.pth'
conj_path = './trained_models/family/family_model_Conjunction.pth'
disj_path = './trained_models/family/family_model_Disjunction.pth'
exist_path = './trained_models/family/family_model_Existential.pth'
uni_path = './trained_models/family/family_model_Universal.pth'

In [222]:
# !ls trained_models/family
# !ls trained_models/family/

In [223]:
neg = Negation(40, 40, 202, 16)
neg.load_state_dict(torch.load(neg_path))
neg.eval()

conj = Conjunction(40, 40, 202, 16)
conj.load_state_dict(torch.load(conj_path))
conj.eval()

disj = Disjunction(40, 40, 202, 16)
disj.load_state_dict(torch.load(disj_path))
disj.eval()

exist = Existential(40, 40, 202, 16)
exist.load_state_dict(torch.load(exist_path))
exist.eval()

uni = Universal(40, 40, 202, 16)
uni.load_state_dict(torch.load(uni_path))
uni.eval()

Universal(
  (fc1): Linear(in_features=80, out_features=40, bias=True)
  (fc2): Linear(in_features=40, out_features=40, bias=True)
  (fc3): Linear(in_features=40, out_features=40, bias=True)
  (head): Linear(in_features=40, out_features=202, bias=True)
  (activation): ReLU()
)

In [224]:
models = {
    'neg': neg,
    'conj': conj,
    'disj': disj,
    'exist': exist,
    'uni': uni
}

In [225]:
print(models['neg'])

Negation(
  (fc1): Linear(in_features=40, out_features=40, bias=True)
  (fc2): Linear(in_features=40, out_features=40, bias=True)
  (fc3): Linear(in_features=40, out_features=40, bias=True)
  (head): Linear(in_features=40, out_features=202, bias=True)
  (activation): ReLU()
)


In [226]:
def convert_to_labels(kb_path, expr):
    # datapoint = []
    KB = KnowledgeBase(path=kb_path) # update file path Done :)
    L = sorted([ind.str.split("/")[-1] for ind in KB.individuals()])
    print(len(L))
    datapoints = [] 
    name_to_ids = {name: idx for idx, name in enumerate(L)} 
    

    y_C = np.zeros(len(L))
    pos = expr[1]['positive examples']
    Ind_pos = list(map(name_to_ids.get, pos))
    y_C[Ind_pos] = 1
    datapoint = (expr[0], list(y_C))
    return datapoint

In [227]:
kb_path = '../NCESData/family/family.owl'
expr = data_family[0]

In [228]:
input = convert_to_labels(kb_path, expr)
input[0]

202


'Brother ⊔ (Parent ⊓ (∀ hasSibling.Father))'

In [229]:
hasSibling = emb.loc['hasSibling'].values
father = emb.loc['Father'].values
parent = emb.loc['Parent'].values
brother = emb.loc['Brother'].values

In [230]:
encoding = models['uni'].encode(torch.FloatTensor(hasSibling).unsqueeze(0), torch.FloatTensor(father).unsqueeze(0))

In [231]:
encoding

tensor([[2.5860, 0.0000, 1.6543, 6.7100, 3.9401, 0.0000, 4.1829, 5.3426, 6.1417,
         4.2385, 2.5556, 3.7432, 3.8188, 2.8912, 6.0619, 0.0000, 0.0000, 6.4436,
         0.0000, 2.9545, 2.1965, 5.9600, 0.0000, 0.0000, 3.3867, 3.1742, 0.0000,
         4.6995, 1.6832, 0.0000, 2.8607, 0.0000, 3.2073, 5.6859, 4.6913, 3.1838,
         1.3843, 6.6142, 4.4438, 4.5054]], grad_fn=<ReluBackward0>)

In [232]:
inter_p = models['conj'].encode(torch.FloatTensor(parent).unsqueeze(0), encoding)

In [233]:
list(inter_p.detach().numpy()[0])

[83.412704,
 -89.638275,
 -54.758286,
 38.977333,
 -51.5317,
 81.157265,
 59.06265,
 -38.610428,
 124.44534,
 57.69741,
 -24.797762,
 67.90391,
 -20.528208,
 53.779522,
 83.910065,
 68.03709,
 -56.935196,
 67.95903,
 84.29726,
 -78.37951,
 96.45507,
 -44.90122,
 65.84144,
 -50.694546,
 -78.67968,
 -92.47435,
 95.54845,
 8.059974,
 58.439354,
 48.626846,
 99.924576,
 82.127174,
 -81.5899,
 72.488716,
 79.00285,
 -86.302475,
 51.04128,
 74.47499,
 -67.17391,
 53.448994]

In [234]:
encoding.shape

torch.Size([1, 40])

In [235]:
y_pred = models['disj'](torch.FloatTensor(brother).unsqueeze(0), inter_p)

In [236]:
y_pred 

tensor([[0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1.,
         0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 1.,
         0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1.,
         0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 0., 0., 1., 0., 1., 1., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 0.,
         1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 0., 1., 1., 0., 0., 1.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0.,
         1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 0.,
         1., 0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 0.,
         0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 1., 1., 1.,
         0., 1., 1., 0.]], grad_fn=<SigmoidBackward0>)

In [237]:
pred = (y_pred > 0.5).numpy().astype(int)


In [238]:
label = input[1]

In [239]:
accuracy_score(pred.squeeze(), label)

0.25742574257425743

In [240]:
f1_score(pred.squeeze(), label)

0.31818181818181823

In [241]:
my_dict = {'0': 2}
'0' in my_dict

True

In [262]:
def neg(exp):
        exp_b = exp.replace('¬', '')
        b_emb = emb.loc[exp_b].values
        return models['neg'].encode(torch.FloatTensor(b_emb).unsqueeze(0))
    
def process(ele, eval):
        ele = ele.strip().replace('(', '').replace(')', '')
        if ele in eval:
            exp_emb = eval[ele]
        else:
            exp_emb = neg(ele) if '¬' in ele else torch.FloatTensor(emb.loc[ele].values).unsqueeze(0)
        return exp_emb

def exp_three(exp, eval, last=False):
    if ' ' not in exp:
        if '¬' in exp:
            # print('neg', exp)
            return neg(exp)
        else:
            return torch.FloatTensor(emb.loc[exp].values).unsqueeze(0)
        
    elif (h.quant(exp) or h.const(exp)) and h.top_bot(exp):
        if (' ⊔ ' in exp):
            parts = exp.split(' ⊔ ')
            parts = [process(part, eval) for part in parts]
            enc = models['disj'](parts[0], parts[1]) if last else models['disj'].encode(parts[0], parts[1]) 
            return enc
                
        elif (' ⊓ ' in exp):
            parts = exp.split(' ⊓ ')
            parts = [process(part, eval) for part in parts]
            enc = models['conj'](parts[0], parts[1]) if last else models['disj'].encode(parts[0], parts[1])
            return enc
            
        else:
            if (exp.startswith('∀')):
                parts = exp.split(' ')[1].split('.')
                parts = [process(part, eval) for part in parts]
                enc = models['uni'](parts[0], torch.FloatTensor(parts[1])) if last else models['uni'].encode(parts[0], parts[1])
                return enc
                
            if (exp.startswith('∃')):
                parts = exp.split(' ')[1].split('.')
                parts = [process(part, eval) for part in parts]
                enc = models['exist'](parts[0], torch.FloatTensor(parts[1])) if last else models['exist'].encode(parts[0], parts[1])
                return enc
            else:
                print('Failed', exp)
    else:
        print('comp', exp)

In [288]:
def evaluate_parentheses(expression):
    # exp_list = list(expression)
    count_exp = 0
    print('expression', expression)
    computed = {}
    i = 0
    while '(' in expression:
        open_idx = None
        close_idx = None

        for idx, item in enumerate(expression):
            if item == '(':
                open_idx = idx
            if item == ')':
                close_idx = idx
                break
            
        # print('evaluating', expression[open_idx + 1: close_idx])
        inner_results = exp_three(expression[open_idx + 1: close_idx], computed)  


        expression = list(expression)
        computed[str(i)] = inner_results
        expression[open_idx: close_idx + 1] = str(i)
        expression = "".join(expression)
        
        i+= 1

    # print('inner_results', inner_results)
    print('expression', expression)
    if (not (h.quant(expression) or h.const(expression))) and h.top_bot(expression):
        print('complex', expression)
        parts = expression.split(" ⊓ ")
        first = " ⊓ ".join(parts[:2])
        len_first = len(first)
        inner_results = exp_three(first, computed)
        expression = list(expression)
        computed[str(i)] = inner_results
        expression[:len_first] = str(i)
        expression = "".join(expression)
        print(expression)
    
    final = exp_three(expression, computed, True)
    # print('computed', computed)
    return final

# parts = input_string.split(" ⊓ ")
# result = " ⊓ ".join(parts[:2])

# print(result) Grandmother ⊓ 1 ⊓ 3

In [289]:
# Example usage:
# expression = '(Person ⊓ (Male ⊔ (∃ hasChild.(¬Grandparent)))) ⊔ (∃ hasParent.Female)'
expression = convert_to_labels(kb_path, data_family[2])[0]
label = convert_to_labels(kb_path, data_family[2])[1]
y_pred = evaluate_parentheses(expression)

# print(y_pred)
pred = (y_pred > 0.5).numpy().astype(int)
pred
accuracy_score(pred.squeeze(), label)
f1_score(pred.squeeze(), label)

expression Grandmother ⊓ (∃ hasChild.(¬Grandfather)) ⊓ (∀ hasChild.(¬Brother))
expression Grandmother ⊓ 1 ⊓ 3
complex Grandmother ⊓ 1 ⊓ 3
4 ⊓ 3


0.0

In [294]:
test_acc = []
test_f1 = []

for pair in data_family:
    
    converted = convert_to_labels(kb_path, pair)
    expression = converted[0]
    label = converted[1]
    # print(expression)
    if not h.top_bot(expression):
        print("skipping", expression)
        continue
    y_pred = evaluate_parentheses(expression)
    pred = (y_pred > 0.5).numpy().astype(int)
    # pred
    test_acc.append(accuracy_score(pred.squeeze(), label))
    test_f1.append(f1_score(pred.squeeze(), label))
print("\n ############ Done ############### \n")

expression Brother ⊔ (Parent ⊓ (∀ hasSibling.Father))
expression Brother ⊔ 1
expression Sister ⊔ (∃ hasParent.(Brother ⊓ (∃ hasParent.PersonWithASibling)))
expression Sister ⊔ 2
expression Grandmother ⊓ (∃ hasChild.(¬Grandfather)) ⊓ (∀ hasChild.(¬Brother))
expression Grandmother ⊓ 1 ⊓ 3
complex Grandmother ⊓ 1 ⊓ 3
4 ⊓ 3
expression Person ⊓ (∀ hasSibling.(Grandchild ⊓ (∃ hasSibling.(¬Father))))
expression Person ⊓ 3
expression Son ⊔ (∀ married.Grandmother)
expression Son ⊔ 0
expression Grandson ⊔ (∀ hasParent.(¬Grandfather))
expression Grandson ⊔ 1
skipping Person ⊓ (Brother ⊔ (∀ hasChild.(¬Mother))) ⊓ (∀ hasSibling.⊤)
expression PersonWithASibling ⊓ (∀ hasSibling.Mother)
expression PersonWithASibling ⊓ 0
expression Grandparent ⊔ (∃ married.(Male ⊓ (∀ hasParent.(¬Daughter))))
expression Grandparent ⊔ 3
expression Brother ⊔ (Person ⊓ (Granddaughter ⊔ (∀ married.Father)))
expression Brother ⊔ 2
expression Grandparent ⊓ (∀ hasChild.Father)
expression Grandparent ⊓ 0
expression Person ⊓ (∀ 

In [293]:
test_f1

[0.8221343873517787,
 0.2452830188679245,
 0.0,
 0.2318840579710145,
 0.5443037974683543,
 0.13592233009708737,
 0.2702702702702703,
 0.8151658767772512,
 0.6210045662100456,
 0.07407407407407407,
 0.1827956989247312,
 0.11235955056179775,
 0.6099290780141844,
 0.6666666666666667,
 0.25471698113207547,
 0.23776223776223773,
 0.0625,
 0.5577689243027888,
 0.5770750988142292,
 0.3781094527363184,
 0.0,
 0.5020242914979758,
 0.47804878048780497,
 0.0,
 0.39784946236559143,
 0.5125628140703518,
 0.23255813953488372,
 0.07246376811594203,
 0.0,
 0.6938775510204082,
 0.041237113402061855,
 0.6160337552742616,
 0.21138211382113825,
 0.47120418848167533,
 0.7237354085603113,
 0.7575757575757575,
 0.06666666666666667,
 0.28571428571428575,
 0.0,
 0.7443609022556391,
 0.7719298245614035,
 0.7557003257328989,
 0.12598425196850396,
 0.5896414342629482,
 0.39999999999999997,
 0.0,
 0.6926070038910506,
 0.2395209580838323,
 0.6456692913385828,
 0.25,
 0.16666666666666669,
 0.5,
 0.6567164179104478,
