In [73]:
class Predicate:
    """An atom ``name(v1, ..., vk)`` appearing in a query, possibly negated.

    Attributes
    --------------------
        name       --  Name of the database table this predicate refers to
        variables  --  List of Variable objects (the predicate's arguments)
        negation   --  True if the predicate is negated, False otherwise
    """

    def __init__(self, name, variables, negation=False):
        self.name, self.variables, self.negation = name, variables, negation
        
class Variable:
    def __init__(self, name, atom=False, quantifier="exist"):
        """
        Variable class: one argument of a Predicate.

        Attributes
        --------------------
            name        --  Name of the variable (e.g. "x1"), or the constant
                            itself when the variable is ground (e.g. "1")
            atom        --  Whether the variable is ground (bound to a constant)
            quantifier  --  "exist" or "universal" (the two values Lift checks);
                            defaults to "exist"
        """
        self.name = name
        self.quantifier = quantifier
        self.atom = atom

class Lift:
    """Lifted probabilistic inference for a single conjunctive query.

    The formulas used below assume a tuple-independent database (distinct
    tuples are independent events). Only some rules are implemented: the
    ground base case (Step0), a placeholder for independent decomposition
    (Step4) and separator-variable elimination (Step5).
    """

    def __init__(self, query, database):
        """
        Attributes
        --------------------
            query     --  List of conjunctions, each a list of Predicate
                          objects; the steps below only handle a single
                          conjunction, i.e. ``[[p1, p2, ...]]``
            database  --  dict: table name -> dict mapping a key to its
                          probability. Keys are tuples of constant names, or a
                          bare constant name for unary tables (e.g. {"1": 0.6})
        """
        self.query = query
        self.database = database

    def findInTable(self, tableName, variables):
        """
        Look up the probability of a ground tuple in a table.

        Parameters
        --------------------
            tableName  --  String, name of the table
            variables  --  Tuple of constant names, e.g. ("1", "2"). A 1-tuple is
                           unwrapped because unary tables use bare-string keys.

        Returns
        --------------------
            p  --  the probability stored for the tuple,
                   0 if the tuple is not in the table (closed world),
                   -400 if the table does not exist
        """
        if tableName not in self.database:
            return -400
        table = self.database[tableName]
        if len(variables) == 1:
            variables = variables[0]
        return table.get(variables, 0)

    def find_separator(self, query):
        """
        Find a separator variable of a single-conjunction query.

        A separator is a non-ground variable occurring in every predicate and,
        among predicates on the same table, at a common argument position, so
        that grounding it to distinct constants touches disjoint tuples.

        Returns
        --------------------
            the first separator name (in order of first occurrence), or -1 if
            the query is not a single conjunction or has no separator
        """
        if len(query) != 1:
            print("not conjunction query")
            return -1
        predicates = list(query[0])
        # Candidate names in first-occurrence order (dict preserves insertion order).
        candidates = dict.fromkeys(
            variable.name
            for predicate in predicates
            for variable in predicate.variables
            if not variable.atom
        )
        for name in candidates:
            if self._is_separator(predicates, name):
                return name
        print("no separator")
        return -1

    @staticmethod
    def _is_separator(predicates, name):
        """Return True if `name` occurs in every predicate at a per-table common position."""
        # Positions of `name`, intersected across predicates sharing a table (self-joins).
        common_positions = {}
        for predicate in predicates:
            positions = {
                index
                for index, variable in enumerate(predicate.variables)
                if not variable.atom and variable.name == name
            }
            if not positions:
                return False
            if predicate.name in common_positions:
                common_positions[predicate.name] &= positions
            else:
                common_positions[predicate.name] = positions
        return all(common_positions.values())

    @staticmethod
    def _ground(predicates, separator, value):
        """Return copies of `predicates` with every free `separator` bound to `value`.

        Inputs are never mutated, so the caller's query and any Variable objects
        shared between predicates stay intact across iterations.
        """
        import copy

        grounded = []
        for predicate in predicates:
            arguments = []
            for variable in predicate.variables:
                if not variable.atom and variable.name == separator:
                    variable = copy.copy(variable)
                    variable.name = value
                    variable.atom = True
                arguments.append(variable)
            new_predicate = copy.copy(predicate)
            new_predicate.variables = arguments
            grounded.append(new_predicate)
        return grounded

    def convert_to_uni(self, query, separator):
        """
        Eliminate a separator by grounding it over its active domain.

        Each constant found in the separator's column(s) yields one subquery
        (all occurrences substituted at once), evaluated recursively by `infer`.

        Returns
        --------------------
            product of subquery probabilities if all free variables are
            "universal"; 1 - product(1 - p) if all are "exist";
            -1 if the query is not a single conjunction, quantifiers are mixed,
            or any subquery could not be evaluated
        """
        if len(query) != 1:
            print("not conjunction query")
            return -1
        predicates = list(query[0])

        # Active domain of the separator; dict dedupes and keeps first-seen order.
        domain = {}
        for predicate in predicates:
            table = self.database.get(predicate.name, {})
            for position, variable in enumerate(predicate.variables):
                if variable.atom or variable.name != separator:
                    continue
                for key in table:
                    # Unary tables are keyed by a bare constant, not a 1-tuple.
                    row = key if isinstance(key, tuple) else (key,)
                    if position < len(row):
                        domain[row[position]] = None

        quantifiers = [
            variable.quantifier
            for predicate in predicates
            for variable in predicate.variables
            if not variable.atom
        ]
        if all(q == "universal" for q in quantifiers):
            universal = True
        elif all(q == "exist" for q in quantifiers):
            universal = False
        else:
            return -1  # mixed quantifiers are not supported

        result = 1
        for value in domain:
            p = self.infer([self._ground(predicates, separator, value)])
            if p is None or p < 0:  # propagate failure instead of multiplying sentinels
                return -1
            result *= p if universal else 1 - p
        return result if universal else 1 - result

    def Step0(self, query):
        """
        Step 0 of the Lifted Inference Algorithm: ground base case.

        Applies when the query is one non-empty conjunction of ground
        predicates; distinct ground atoms are multiplied (duplicates once).
        NOTE(review): Predicate.negation is not consulted.

        Returns
        --------------------
            p  --  probability of the ground conjunction,
                   -1 if not applicable,
                   -400 if the query is empty or a referenced table is missing
        """
        if not query:
            return -400
        if len(query) != 1 or not query[0]:
            return -1
        predicates = query[0]
        if not all(var.atom for predicate in predicates for var in predicate.variables):
            return -1
        probability = 1
        seen = set()
        for predicate in predicates:
            parameter = tuple(var.name for var in predicate.variables)
            if (predicate.name, parameter) in seen:
                continue
            seen.add((predicate.name, parameter))
            p = self.findInTable(predicate.name, parameter)
            if p == -400:
                return -400
            probability *= p
        return probability

    def Step4(self, query):
        """
        Step 4 (independent decomposition): not implemented yet.

        Returns -400 for an empty query, otherwise -1 (not applicable).
        """
        if not query:
            return -400
        # TODO: split the query into independent parts q1, q2 and return
        # infer(q1) * infer(q2) when such a split exists.
        return -1

    def Step5(self, query):
        """
        Step 5: separator-variable elimination.

        Returns -400 for an empty query, -1 if no separator exists, otherwise
        the result of `convert_to_uni`.
        """
        if not query:
            return -400
        separator = self.find_separator(query)
        if separator != -1:
            return self.convert_to_uni(query, separator)
        return -1

    def infer(self, query):
        """
        Compute the probability of `query` by trying Step0, Step4, Step5 in order.

        Returns
        --------------------
            the first applicable step's result; None (after printing
            "not query") if a step reports -400; -1 (after printing "fail")
            if no step applies
        """
        for step in (self.Step0, self.Step4, self.Step5):
            outcome = step(query)
            if outcome == -400:
                print("not query")
                return None
            if outcome != -1:
                return outcome
        print("fail")
        return -1

In [17]:
# Single ground atom S("1", "2"): answered directly from table S.
db = {
    "S": {
        ("1", "1"): 0.6,
        ("2", "1"): 0.4,
        ("1", "2"): 0.8,
    },
}
v1, v2 = (Variable(const, atom=True) for const in ("1", "2"))
p = Predicate("S", [v1, v2])
lift = Lift(query=[[p]], database=db)
lift.infer(lift.query)

0.8

In [74]:
# S and Q are binary tables; H and R are unary tables keyed by a bare string.
db = {
    "S": {("1", "1"): 0.6, ("2", "1"): 0.5, ("1", "2"): 0.7},
    "H": {"1": 0.6},
    "R": {"1": 0.8},
    "Q": {("1", "1"): 0.6, ("1", "2"): 0.5},
}
x1, x2 = Variable("x", atom=False), Variable("y", atom=False)
h, r = Predicate("H", [x1]), Predicate("R", [x1])
s = Predicate("Q", [x1, x2])  # named `s`, but it refers to table Q
# Query Q(x, y) with both variables at the default "exist" quantifier.
lift = Lift([[s]], db)
result = lift.infer(lift.query)
print(result)

('1', '2')
('1', '1')
0.8
