In [1]:
import os
import pandas as pd
from datasketch import HyperLogLog

# Points
For now i am assuming that primary key is only a single attribute

In [2]:
class IND:
    """
    A class representing Inclusion Dependencies between attributes.
    
    Parameters
    ----------
    dependent : Attribute
        The dependent attribute in the inclusion dependency
    reference : Attribute
        The reference attribute in the inclusion dependency
    """
    
    def __init__(self, dependent, reference):
        self.dependent = dependent
        self.reference = reference

In [3]:
class Attribute:
    """
    A class representing a database table attribute with metadata for primary key detection.
    
    Parameters
    ----------
    table_name : str
        Name of the database table
    attribute_name : str
        Name of the attribute/column
    values : list
        List of values in the attribute
        
    Attributes
    ----------
    fullName : str
        Fully qualified name (table_name.attribute_name)
    uniquness : float
        Estimated uniqueness score using HyperLogLog
    cardinality : int
        Cardinality score (default 1)
    value_length : float
        Score based on value lengths
    position : float 
        Score based on column position 
    suffix : int
        Score based on common primary key suffixes
    pkScore : float
        Combined score for primary key likelihood
    
    Methods
    -------
    estUniqueness()
        Estimates uniqueness of values using HyperLogLog
    check_suffix(suffix_list)
        Checks if attribute name contains common primary key suffixes
    """

    def __init__(self, table_name, attribute_name, values):
        self.table_name = table_name
        self.attribute_name = attribute_name
        self.values = values
        self.fullName=f"{self.table_name}.{self.attribute_name}"

        self.uniquness = self.estUniqueness()
        self.cardinality=1
        self.value_length = 1/max(1, max([len(x) for x in values]) - 8) #8 is a hyper-parameter which penalties primary key candidates which has a value with length > 8
        self.position = 0
        self.suffix = self.check_suffix()

        self.pkScore = 0
        self.pkScore += self.uniquness
        self.pkScore += self.cardinality
        self.pkScore += self.value_length
        self.pkScore += self.position
        self.pkScore += self.suffix

    def estUniqueness(self):
        """
        Estimate uniqueness of attribute values using HyperLogLog algorithm.
        
        Returns
        -------
        float
            Ratio of unique values to total values
        """

        hll = HyperLogLog()
        total = 0
        
        for value in self.values:
            hll.update(str(value).encode('utf8'))
            total +=1
        
        return hll.count() / total
    
    def check_suffix(self, suffix_list=["key", 'id', 'nr', 'no']):
        """
        Check if attribute name contains common primary key suffixes.
        
        Parameters
        ----------
        suffix_list : list of str, optional
            List of common primary key suffix strings to check
            
        Returns
        -------
        int
            1 if suffix found, 0 otherwise
        """
        for suffix in suffix_list:
            if suffix in self.attribute_name:
                return 1
            else:
                return 0


In [None]:
def load_csv_files(directory_path):
    """
    Load CSV files from a directory and create Attribute objects.
    
    Parameters
    ----------
    directory_path : str
        Path to directory containing CSV files
        
    Returns
    -------
    dict
        Dictionary mapping "table.column" strings to Attribute objects
        
    Notes
    -----
    Prints progress information during loading
    """
    attributes = {}

    csv_files = [f for f in os.listdir(directory_path)]

    print(f"Found {len(csv_files)} \n CSV files: {csv_files}")    

    for filename in csv_files:
        file_path = os.path.join(directory_path, filename)
        table_name = os.path.splitext(filename)[0]

        df = pd.read_csv(file_path)
        print(f"Processing {filename}: {df.shape[0]} rows, {df.shape[1]} columns")

        for i, column in enumerate(df.columns):
            non_null_values = df[column].astype(str).dropna().tolist()
            if non_null_values:
                attr = Attribute(table_name, column, non_null_values)
                attr.position = 1/(i+1)
                attributes[f"{table_name}.{column}"] = attr
                print(f"Added attribute: {attr.table_name}.{attr.attribute_name} Total Values: {len(attr.values)}")
                
    return attributes
    
attributes = load_csv_files("/home/haseeb/Desktop/EKAI/ERD_automation/Dataset/train/northwind-db")            
            

Found 11 
 CSV files: ['employee_territories.csv', 'products.csv', 'orders.csv', 'customers.csv', 'territories.csv', 'orders_details.csv', 'suppliers.csv', 'employees.csv', 'categories.csv', 'shippers.csv', 'regions.csv']
Processing employee_territories.csv: 49 rows, 2 columns
Added attribute: employee_territories.employeeid Total Values: 49
1.0
Added attribute: employee_territories.territoryid Total Values: 49
0.5
Processing products.csv: 77 rows, 10 columns
Added attribute: products.productid Total Values: 77
1.0
Added attribute: products.productname Total Values: 77
0.5
Added attribute: products.supplierid Total Values: 77
0.3333333333333333
Added attribute: products.categoryid Total Values: 77
0.25
Added attribute: products.quantityperunit Total Values: 77
0.2
Added attribute: products.unitprice Total Values: 77
0.16666666666666666
Added attribute: products.unitsinstock Total Values: 77
0.14285714285714285
Added attribute: products.unitsonorder Total Values: 77
0.125
Added attribut

In [None]:
def extractPrimaryKeys(attributes):
    """
    Extract primary keys from attributes based on their pkScore.

    Parameters
    ----------
    attributes : dict
        Dictionary mapping "table.column" strings to Attribute objects

    Returns
    -------
    dict
        Dictionary mapping table names to tuples of (column name, pkScore)
    """
    pk_table = {}  # {table name: (column name, score)}

    for key, value in attributes.items():
        table_name = key.split(".")[0]
        current_pk = pk_table.get(table_name)
        if not current_pk or current_pk[1] < value.pkScore:
            pk_table[table_name] = (value.fullName, value.pkScore)
    return pk_table

pk_table = extractPrimaryKeys(attributes=attributes)

In [6]:
def read_IND(file_path):
    """
    Read inclusion dependencies from a file.
    
    Parameters
    ----------
    file_path : str
        Path to file containing inclusion dependencies
        
    Returns
    -------
    list of IND
        List of inclusion dependency objects
        
    Notes
    -----
    File format should be one dependency per line as: dependent=reference
    """
    inds = []
    with open(file_path, "r") as f:
        for line in f:
            vars = line.strip().split("=")
            inds.append(IND(attributes[vars[0]], attributes[vars[1]]))
    return inds
inds = read_IND("/home/haseeb/Desktop/EKAI/ERD_automation/codes/inclusionDependencyWithSpider/spider_results/northwind.txt")
print(f"{len(inds)=}")

len(inds)=111


# FK candidate prerfiltering

In [7]:
def prefiltering(inds):
    """
    Filter inclusion dependencies based on primary key and null value criteria.
    
    Parameters
    ----------
    inds : list of IND
        List of inclusion dependency objects to filter
        
    Returns
    -------
    list of IND
        Filtered list of inclusion dependencies that meet criteria:
        - Reference attribute is a primary key
        - Neither dependent nor reference is all null values
        
    Notes
    -----
    Uses global pk_table dictionary for primary key lookup
    """
    pruned_inds = []
    for ind in inds:
        #Checking if reference variable is a primary key
        is_pk = False
        for table_name, pk in pk_table.items():
            if pk[0].split(".")[1] == ind.reference.attribute_name:
                is_pk=True
                break

        #Checking if either all of the dependent or reference attribute is null
        dependent_all_null = True
        reference_all_null = True
        for value in ind.reference.values:
            if value != "nan":
                reference_all_null = False
        for value in ind.dependent.values:
            if value !="nan":
                dependent_all_null = False
        
        if is_pk and (not reference_all_null) and (not dependent_all_null):
            pruned_inds.append(ind)

    return pruned_inds

inds = prefiltering(inds=inds)
print(len(inds))


85
