In [7]:
import os
import re
from collections import defaultdict
from dataclasses import dataclass
from itertools import combinations
from string import Template

import hypernetx as hnx
import networkx as nx
import psycopg2


In [8]:
# SQL skeleton used to collect tuples that take part in a denial-constraint (DC)
# violation: it selects every row t1 of $table for which some row t2 of the same
# table satisfies $dc_desc. $dc_desc is filled in by DataPruner with the output of
# parse_rule_to_where_clause.
dc_violation_template=Template("SELECT t1.* FROM $table t1 WHERE EXISTS (SELECT t2.* FROM $table AS t2 WHERE ($dc_desc));")
# Matches the two DC predicate operators handled in this notebook:
# IQ (rendered as `!=`) and EQ (rendered as `=`).
ops = re.compile(r'IQ|EQ')   

@dataclass
class Complaint:
	"""
	A user complaint that drives rule pruning.

	complain_type:
		kind of rule the complaint is about, 'DC' or 'LF'
		(RulePruner only acts on 'DC'). Note the field is spelled
		`complain_type`, not `complaint_type`.
	attr_name:
		DC only: name of the attribute (column) the complained-about
		cell belongs to. Defaults to the placeholder 'foo'.
	tid:
		DC only: defaults to -1. Presumably the id of the complained-about
		tuple; it is not read by the code in this notebook -- TODO confirm.
	"""

	complain_type:str 
	attr_name:str='foo' # DC only
	tid:int=-1 # DC only
        
class RulePruner:
    """
    Prune a set of rules down to the ones relevant to a user complaint.

    Only denial constraints (complaint type 'DC') are handled: a DC is kept
    when it directly references the attribute the complaint is about, i.e.
    when ``t1.<attr>`` or ``t2.<attr>`` appears in the rule text.

    Not implemented here:
        * following attributes transitively (keeping DCs that are only
          connected to the complaint attribute through other DCs);
        * LF pruning (dropping LFs that give the complaint a trivial
          ABSTAIN label).
    """

    def __init__(self, ):
        pass

    def prune_and_return(self, complaint, rules):
        """
        Return the rules that mention the complaint's attribute.

        Parameters
        ----------
        complaint:
            object exposing ``complain_type`` and ``attr_name``
            (e.g. a ``Complaint``).
        rules:
            iterable of DC strings such as
            ``'t1&t2&IQ(t1.sex,t2.sex)&EQ(t1.age,t2.age)'``.

        Returns
        -------
        ``(count, kept_rules)`` for a 'DC' complaint, where ``kept_rules``
        preserves the input order with duplicate rule strings collapsed.
        ``None`` for any other complaint type (no pruning is implemented
        for those).
        """
        if complaint.complain_type != 'DC':
            return None

        kept = []
        # dict.fromkeys drops repeated rule strings while keeping first-seen order.
        for rule in dict.fromkeys(rules):
            # Attribute names may contain hyphens (e.g. hours-per-week), hence [-\w].
            rule_attrs = set(re.findall(r't[12]\.([-\w]+)', rule))
            if complaint.attr_name in rule_attrs:
                kept.append(rule)
        return len(kept), kept
# 			rule_graph = nx.Graph()
# 			# undirected graph
# 			rule_dict = {}
# 			for r in rules:
# 				rule_nodes=list(set(re.findall(r't[12]\.(\w+)', r)))
# 				rule_edges = list(combinations(rule_nodes, 2))
# 				rule_graph.add_edges_from(rule_edges)
# 				rule_dict[r]=rule_nodes
# 			start_node=complaint.attr_name
# 			useful_nodes=[start_node] + [v for u, v in nx.bfs_edges(rule_graph, start_node)]
# 			set_useful_nodes=set(useful_nodes)
# 			print(f"set_useful_nodes: {set_useful_nodes}")
# 			useless_rules = [k for (k,v) in rule_dict.items() if not set(v).intersection(set_useful_nodes)]
# 			print(f'useless rules: {useless_rules}')
# 			return rule_graph, [k for (k,v) in rule_dict.items() if set(v).intersection(set_useful_nodes)]
# 		else:
# 			pass

class DataPruner:
	"""
	Given a set of rules, prune data (tuples/sentences) based on whether
	the data points have any effect on the rules.

	Only the DC path is implemented: tuples that take part in a violation of
	at least one rule are copied into ``<target_table>_pruned``.
	"""
	def __init__(self,):
		pass

	def dc_prune_and_return(self, db_conn, target_table, pruned_rules):
		"""
		Rebuild ``<target_table>_pruned`` with the distinct rows of
		``target_table`` that violate at least one rule in ``pruned_rules``.

		Parameters
		----------
		db_conn:
			open DB-API connection (psycopg2 in this notebook). Nothing is
			committed here, so the caller is expected to run with autocommit
			or commit afterwards.
		target_table:
			name of the source table; it is interpolated into the SQL text
			as-is, so it must be a trusted identifier.
		pruned_rules:
			iterable of DC strings understood by parse_rule_to_where_clause.

		Side effects: drops and recreates ``<target_table>_pruned`` and
		``<target_table>_pruned_intermediate``, prints every INSERT statement
		it runs and the row counts before/after pruning. Returns None.
		"""
		pruned_table = f"{target_table}_pruned"
		intermediate_table = f"{target_table}_pruned_intermediate"
		cur = db_conn.cursor()
		try:
			cur.execute(f"drop table if exists {pruned_table}")
			cur.execute(f"drop table if exists {intermediate_table}")
			# `limit 0` copies the column layout of the source table without any rows.
			cur.execute(f"create table {pruned_table} as select * from {target_table} limit 0")
			cur.execute(f"create table {intermediate_table} as select * from {target_table} limit 0")
			for r in pruned_rules:
				r_q = dc_violation_template.substitute(table=target_table, dc_desc=self.parse_rule_to_where_clause(r))
				insert_q = f"INSERT INTO {intermediate_table} {r_q}"
				cur.execute(insert_q)
				print(insert_q)
			# Look the columns up for the table actually being pruned (this used to be
			# hard-coded to 'adult'), in table order: the INSERT ... SELECT below maps
			# columns by position, so the order must match the table definition.
			cur.execute(
				"SELECT column_name FROM information_schema.columns "
				"WHERE table_schema = 'public' AND table_name = %s "
				"ORDER BY ordinal_position",
				(target_table,),
			)
			col_names = ', '.join([f'"{x[0]}"' for x in cur.fetchall()])
			# A tuple violating several rules was inserted once per rule; grouping on
			# every column collapses those duplicates.
			q_insert_distinct = (
				f"WITH distincts AS (SELECT COUNT(*) AS cnt, {col_names} from {intermediate_table} "
				f"GROUP BY {col_names}) INSERT INTO {pruned_table} SELECT {col_names} FROM distincts"
			)
			cur.execute(q_insert_distinct)
			cur.execute(f"select count(*) from {target_table}")
			cnt_before = cur.fetchone()[0]
			cur.execute(f"select count(*) from {pruned_table}")
			cnt_after = cur.fetchone()[0]
		finally:
			# Release the cursor even when one of the statements fails.
			cur.close()
		print(f"before pruning data: {target_table} has {cnt_before} rows")
		print(f"after pruning data: {target_table}_pruned has {cnt_after} rows")

	def lf_prune_and_return(self, ):
		"""Placeholder for LF-based data pruning; not implemented."""
		pass

	def parse_rule_to_where_clause(self, rule):
		"""
		Translate a DC string into a SQL boolean expression.

		``'t1&t2&IQ(t1.sex,t2.sex)&EQ(t1.age,t2.age)'`` becomes
		``'t1."sex"!=t2."sex" AND t1."age"=t2."age"'``.

		Only EQ and IQ predicates are supported so far; '&'-separated parts
		containing neither operator (such as the leading ``t1``/``t2``) are
		skipped. Returns an empty string when nothing could be translated.
		"""
		signs = {'EQ': '=', 'IQ': '!='}
		res = []
		for xl in rule.split('&'):
			op_match = ops.search(xl)
			if op_match is None:
				continue
			sign = signs[op_match.group()]
			bracket_content = re.findall(r'\((.*)\)', xl)[0]
			# Double-quote attribute names so hyphenated columns such as
			# hours-per-week are valid SQL identifiers.
			quoted = re.sub(r'(t[12]\.)([-\w]+)', r'\1"\2"', bracket_content)
			res.append(sign.join(quoted.split(',')))
		return ' AND '.join(res)


In [9]:
# Sample denial constraints, one per line, in the `t1&t2&OP(t1.attr,t2.attr)&...` form.
# Despite `hospital` in the variable name, the attributes referenced here
# (race, sex, income, ...) are evaluated against the `adult` table further down.
test_rules_hospital="""t1&t2&IQ(t1.native-country,t2.native-country)&EQ(t1.marital-status,t2.marital-status)&EQ(t1.workclass,t2.workclass)
t1&t2&IQ(t1.race,t2.race)&EQ(t1.hours-per-week,t2.hours-per-week)&IQ(t1.income,t2.income)
t1&t2&IQ(t1.race,t2.race)&IQ(t1.native-country,t2.native-country)&IQ(t1.income,t2.income)
t1&t2&IQ(t1.sex,t2.sex)&EQ(t1.hours-per-week,t2.hours-per-week)&IQ(t1.income,t2.income)"""

test_rules = test_rules_hospital.split('\n')
# print(test_rules)
# Complaint about the `sex` attribute: only rules mentioning t1.sex / t2.sex survive.
c=Complaint(complain_type='DC', attr_name='sex', tid=495)
rp=RulePruner()
# `res` is the number of surviving rules, `res_rules` the rules themselves
# (res_rules is reused by the data-pruning cell).
res, res_rules = rp.prune_and_return(complaint=c,rules=test_rules)
print(f"before pruning: we have {len(test_rules)} rules ")
print(f"after pruning: we have {res} rules ")
res_rules

before pruning: we have 4 rules 
after pruning: we have 1 rules 


['t1&t2&IQ(t1.sex,t2.sex)&EQ(t1.hours-per-week,t2.hours-per-week)&IQ(t1.income,t2.income)']

In [10]:
# Connection settings can be overridden through the environment so credentials do
# not have to be edited into the notebook. The fallbacks are the values this cell
# always used, so the notebook still runs unchanged when nothing is set.
# NOTE(review): the fallback password is still stored in the notebook; remove the
# default once HOLO_DB_PASSWORD is provided everywhere this runs.
conn = psycopg2.connect(
    dbname=os.environ.get("HOLO_DB_NAME", "holo"),
    user=os.environ.get("HOLO_DB_USER", "holocleanuser"),
    password=os.environ.get("HOLO_DB_PASSWORD", "abcd1234"),
)
# Every statement is committed immediately; DataPruner never calls commit itself.
conn.autocommit = True

In [11]:
# Materialise `adult_pruned` from the rules that survived rule pruning. The call
# prints each INSERT it runs plus the row counts before and after pruning.
dp=DataPruner()
dp.dc_prune_and_return(db_conn=conn,target_table='adult',pruned_rules=res_rules)

INSERT INTO adult_pruned_intermediate SELECT t1.* FROM adult t1 WHERE EXISTS (SELECT t2.* FROM adult AS t2 WHERE (t1."sex"!=t2."sex" AND t1."hours-per-week"=t2."hours-per-week" AND t1."income"!=t2."income"));
before pruning data: adult has 500 rows
after pruning data: adult_pruned has 421 rows
