In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter

# PSO Related

class Node:
    """One node of the PSO-built decision tree.

    A node built with ``value`` set is a leaf and carries the predicted label.
    A node built with ``particle``/``threshold`` is an internal node that
    routes samples to ``left`` or ``right``.
    """

    def __init__(self, threshold=None, left=None, right=None, info_gain=None,
                 value=None, particle=None, min_val=None, max_val=None,
                 samples=None):
        # Leaf payload.
        self.value = value

        # Split description: full particle plus its last entry (the threshold).
        self.particle = particle
        self.threshold = threshold
        self.info_gain = info_gain

        # Range of the projected values seen when the split was fitted.
        self.min_val = min_val
        self.max_val = max_val

        # Children and the number of samples that reached this node.
        self.left = left
        self.right = right
        self.samples = samples
        
def gini_index(y):
    """Return the Gini impurity of the labels in ``y``.

    Parameters:
        y: array-like of labels; any shape, it is flattened first.

    Returns:
        float: sum over classes of p * (1 - p), where p is the class
        frequency. An empty input yields 0.0.
    """
    labels = np.asarray(y).reshape(-1)
    if labels.size == 0:
        return 0.0

    # One counting pass instead of rescanning the labels once per class.
    _, counts = np.unique(labels, return_counts=True)
    proportions = counts / labels.size
    return float(np.sum(proportions * (1.0 - proportions)))
        
def entropy(y):
    """Return the base-2 Shannon entropy of the labels in ``y``."""
    _, class_counts = np.unique(y, return_counts=True)
    n_total = class_counts.sum()
    probs = class_counts / n_total
    weighted_logs = probs * np.log2(probs)
    return -np.sum(weighted_logs)

def information_gain(particle, X, y):
    """Return the entropy-based information gain of the split encoded by ``particle``.

    Parameters:
        particle: 1-D array; all entries but the last are feature weights,
            the last entry is the split threshold.
        X: 2-D feature array.
        y: label array with one entry per row of ``X``.

    Returns:
        Parent entropy minus the size-weighted entropy of the two children.
    """
    # Delegate the partitioning to ``split`` so this function uses the same
    # projection, normalisation and threshold rule as tree construction.
    # ``split`` takes (X, y, positions) and returns six values; the last two
    # (projection min/max) are not needed here.
    _, y_left, _, y_right, _, _ = split(X, y, particle)

    parent_entropy = entropy(y)
    left_entropy = entropy(y_left)
    right_entropy = entropy(y_right)
    n = len(y)
    n_left, n_right = len(y_left), len(y_right)

    child_entropy = (n_left / n) * left_entropy + (n_right / n) * right_entropy
    return parent_entropy - child_entropy

def information_gain_gini(parent_y, left_y, right_y):
    """Return the drop in Gini impurity achieved by a parent -> (left, right) split.

    Each child's impurity is weighted by its share of the parent's samples.
    """
    n_parent = len(parent_y)
    impurity_before = gini_index(parent_y)
    impurity_after = (
        (len(left_y) / n_parent) * gini_index(left_y)
        + (len(right_y) / n_parent) * gini_index(right_y)
    )
    return impurity_before - impurity_after

def objective_function(particle):
    """Return the negated information gain of ``particle``.

    Evaluated against the module-level ``X`` and ``y``, so those globals
    must exist before this is called.
    """
    gain = information_gain(particle, X, y)
    return -gain

def information_gain_gini_given_particle(X, y, a_particle):
    """Return the Gini gain of the oblique split encoded by ``a_particle``.

    Parameters:
        X: 2-D feature array.
        y: label array with one entry per row of ``X``.
        a_particle: 1-D array; ``a_particle[:-1]`` are feature weights and
            ``a_particle[-1]`` is the threshold applied to the min-max
            normalised projection ``X @ weights``.

    Returns:
        float: impurity decrease of the split, or 0.0 when the projection is
        constant (or not finite) and therefore cannot separate any samples.
    """
    weights = a_particle[:-1]
    threshold = a_particle[-1]
    projections = np.dot(X, weights)

    lowest = np.min(projections)
    span = np.max(projections) - lowest
    # A zero (or NaN) range cannot be normalised. Previously this produced
    # NaNs, both children came out empty and the "split" was scored with the
    # full parent impurity, i.e. the best possible fitness.
    if not span > 0:
        return 0.0

    normalised = (projections - lowest) / span
    # Same rule as ``split``: strictly below the threshold goes left, the
    # rest (including ties) goes right, so no sample is dropped.
    goes_left = normalised < threshold
    return information_gain_gini(y, y[goes_left], y[~goes_left])

class Particle:
    """A swarm member: current position/velocity plus its personal best."""

    def __init__(self, dim):
        # Position is drawn before velocity; both are uniform in [0, 1).
        start = np.random.rand(dim)
        self.position = start
        self.velocity = np.random.rand(dim)

        # Nothing has been evaluated yet, so any real fitness beats this.
        self.best_fitness = -float("inf")
        self.best_position = np.array(start, copy=True)

def apply_PSO(num_particles, num_epochs, X, y, inertia_weight=0.25, cognitive_weight=0.5, social_weight=0.5, mask=None):
    """Search for a high-Gini-gain oblique split with particle swarm optimisation.

    Each particle has ``X.shape[1] + 1`` coordinates: one weight per feature
    followed by the split threshold.

    Parameters:
        num_particles: swarm size.
        num_epochs: number of evaluate-then-move iterations.
        X: 2-D feature array.
        y: label array with one entry per row of ``X``.
        inertia_weight: factor applied to the previous velocity.
        cognitive_weight: pull towards the particle's own best position.
        social_weight: pull towards the swarm's best position.
        mask: optional array-like with one multiplier per feature (0 disables
            a feature); ``None`` keeps every feature.

    Returns:
        tuple: ``(global_best_position, global_best_fitness)``.
    """
    n_features = X.shape[1]
    if mask is None:
        mask = np.ones(n_features)
    else:
        # Accept lists/tuples too instead of silently ignoring them.
        mask = np.asarray(mask, dtype=float)

    dim = n_features + 1
    swarm = [Particle(dim) for _ in range(num_particles)]

    # Mask the starting positions as well. Otherwise the first evaluation
    # scores unmasked weights, and such a position could be returned as the
    # global best even though it uses disabled features.
    for particle in swarm:
        particle.position[:-1] *= mask
        particle.best_position = particle.position.copy()

    global_best_position = Particle(dim).position
    global_best_position[:-1] *= mask
    global_best_fitness = float('-inf')

    for epoch in range(num_epochs):
        # Evaluation pass: refresh personal and global bests.
        for particle in swarm:
            fitness = information_gain_gini_given_particle(X, y, particle.position)
            if fitness > global_best_fitness:
                global_best_position = particle.position.copy()
                global_best_fitness = fitness
            if fitness > particle.best_fitness:
                particle.best_position = particle.position.copy()
                particle.best_fitness = fitness

        # Movement pass: standard velocity update, then re-apply the mask so
        # disabled feature weights stay at zero.
        for particle in swarm:
            r1 = np.random.rand(dim)
            r2 = np.random.rand(dim)

            particle.velocity = (inertia_weight * particle.velocity +
                                 cognitive_weight * r1 * (particle.best_position - particle.position) +
                                 social_weight * r2 * (global_best_position - particle.position))
            particle.position += particle.velocity
            particle.position[:-1] = particle.position[:-1] * mask

    return global_best_position, global_best_fitness

def split(X, y, positions):
    """Partition ``X``/``y`` with the oblique split encoded by ``positions``.

    Parameters:
        X: 2-D feature array.
        y: label array with one entry per row of ``X``.
        positions: 1-D array; ``positions[:-1]`` are feature weights and
            ``positions[-1]`` is the threshold on the min-max normalised
            projection ``X @ weights``.

    Returns:
        tuple: ``(X_left, y_left, X_right, y_right, min_products, max_products)``
        where the last two are the raw projection extremes, needed to apply
        the same normalisation at prediction time.
    """
    weights = positions[:-1]
    threshold = positions[-1]
    products = np.dot(X, weights)

    min_products = np.min(products)
    max_products = np.max(products)
    span = max_products - min_products

    if span > 0:
        normalised = (products - min_products) / span
    else:
        # Constant projection: normalising would divide by zero and turn
        # every value into NaN, dropping all samples from both children.
        # Treat every sample as 0 so they all land on the same side.
        normalised = np.zeros_like(products, dtype=float)

    # Strictly below the threshold goes left; everything else goes right.
    goes_left = normalised < threshold
    goes_right = ~goes_left

    return (X[goes_left], y[goes_left], X[goes_right], y[goes_right],
            min_products, max_products)

def calculate_leaf_value(Y):
    """Return the most frequent label in ``Y``.

    Ties are resolved in favour of the label that appears first in ``Y``.

    Raises:
        ValueError: if ``Y`` is empty.
    """
    labels = list(Y)
    if not labels:
        raise ValueError("cannot compute a leaf value from an empty label set")
    # Single counting pass; ``max(labels, key=labels.count)`` rescans the
    # whole list for every element.
    return Counter(labels).most_common(1)[0][0]

def build_tree(X, y, curr_depth=0, max_depth=5, max_num_particles=10, max_num_epochs=10, min_split_size=3, mask=None):
    """Recursively grow a decision tree whose splits are found by PSO.

    A node becomes a leaf (majority label) when the depth limit is reached,
    when it holds ``min_split_size`` samples or fewer, when only one class is
    present, or when the best split found leaves one side empty. The PSO
    budget (particles and epochs) grows linearly with depth and is never
    below 10.

    Returns:
        Node: root of the (sub)tree built from ``X``/``y``.
    """
    depth_factor = curr_depth + 1
    num_particles = max(max_num_particles * depth_factor, 10)
    num_epochs = max(max_num_epochs * depth_factor, 10)

    distinct_labels = len(Counter(y))

    splittable = (curr_depth < max_depth
                  and X.shape[0] > min_split_size
                  and distinct_labels != 1)
    if not splittable:
        return Node(value=calculate_leaf_value(y))

    best_position, best_fitness = apply_PSO(num_particles, num_epochs, X, y, mask=mask)
    X_left, y_left, X_right, y_right, min_products, max_products = split(X, y, best_position)
    print("depth",curr_depth,X.shape,"next",curr_depth+1,X_left.shape,X_right.shape)
    print("\n",Counter(list(y.reshape(-1))),"\n",Counter(list(y_left.reshape(-1))),"\n",Counter(list(y_right.reshape(-1))))

    # A split that sends everything to one side is useless: stop here.
    if X_left.shape[0] == 0 or X_right.shape[0] == 0:
        return Node(value=calculate_leaf_value(y))

    # Settings shared by both recursive calls.
    child_kwargs = dict(
        curr_depth=depth_factor,
        max_depth=max_depth,
        max_num_particles=max_num_particles,
        max_num_epochs=max_num_epochs,
        min_split_size=min_split_size,
        mask=mask,
    )
    left_subtree = build_tree(X_left, y_left, **child_kwargs)
    right_subtree = build_tree(X_right, y_right, **child_kwargs)

    return Node(
        threshold=best_position[-1],
        left=left_subtree,
        right=right_subtree,
        info_gain=best_fitness,
        particle=best_position,
        min_val=min_products,
        max_val=max_products,
        samples=X.shape[0],
    )

def make_prediction(x, tree):
    """Route a single sample ``x`` down ``tree`` and return the leaf's label.

    At every internal node the sample is projected onto the node's weights,
    rescaled with the node's stored min/max, and sent left when the result
    is strictly below the node's threshold, right otherwise.
    """
    node = tree
    while node.value is None:
        projection = np.dot(x, node.particle[:-1])
        scaled = (projection - node.min_val) / (node.max_val - node.min_val)
        node = node.left if scaled < node.threshold else node.right
    return node.value

def predict(X, root):
    """Return a list with the tree's prediction for every row of ``X``."""
    predictions = []
    for sample in X:
        predictions.append(make_prediction(sample, root))
    return predictions

def score(predicted_labels, y):
    """Return the fraction of positions where ``predicted_labels`` equals ``y``."""
    n_total = len(y)
    hits = 0
    for idx in range(n_total):
        if predicted_labels[idx] == y[idx]:
            hits += 1
    return hits / n_total

def traverse_get_weights_samples(root):
    """Collect the split weights and sample counts of every internal node.

    Nodes are visited depth-first, parent before children and left before
    right. Only nodes whose ``particle`` is a numpy array contribute.

    Parameters:
        root: root node of the tree, or a falsy value for an empty tree.

    Returns:
        tuple: ``(weights, populations)``. ``weights[i]`` is the node's
        particle without its trailing threshold and ``populations[i]`` is the
        node's ``samples``. Both lists are empty for an empty tree, so the
        result can always be unpacked into two names.
    """
    weights, populations = [], []
    if not root:
        return weights, populations

    stack = [root]
    while stack:
        node = stack.pop()
        if not node:
            continue
        if isinstance(node.particle, np.ndarray):
            weights.append(node.particle[:-1])
            populations.append(node.samples)
        # Push right first so the left child is popped (visited) first.
        stack.append(node.right)
        stack.append(node.left)
    return weights, populations

def calculate_weighted_average_particles(weights, populations, prune_rate=1.2):
    """Build a binary keep/drop feature mask from the tree's split weights.

    The per-node weight vectors are averaged, weighted by each node's sample
    count, min-max normalised to [0, 1], and thresholded at
    ``prune_rate * std(normalised)``.

    Parameters:
        weights: non-empty sequence of equal-length 1-D weight arrays.
        populations: sample count for each entry of ``weights``.
        prune_rate: multiplier on the standard deviation; larger values
            drop more features.

    Returns:
        numpy.ndarray: float array with 1.0 for kept and 0.0 for dropped
        features. When every feature has the same average weight nothing can
        be ranked, so all features are kept.
    """
    total = 0
    vals = np.zeros(weights[0].shape[0])
    for node_weights, population in zip(weights, populations):
        vals += np.asarray(node_weights) * population
        total += population
    vals = vals / total

    span = np.max(vals) - np.min(vals)
    if span == 0:
        # Min-max scaling would divide by zero and produce a NaN mask.
        return np.ones_like(vals)

    vals = (vals - np.min(vals)) / span
    thresh = prune_rate * np.std(vals)
    # Single pass instead of two in-place assignments that depend on order.
    return np.where(vals >= thresh, 1.0, 0.0)

def calculate_feat_importance(weights, populations):
    """Return per-feature importance scores scaled to [0, 1].

    The per-node weight vectors are averaged, weighted by each node's sample
    count, and the result is min-max normalised.

    Parameters:
        weights: non-empty sequence of equal-length 1-D weight arrays.
        populations: sample count for each entry of ``weights``.

    Returns:
        numpy.ndarray: one score per feature. When every feature has the same
        average weight they are all reported as 1.0 instead of NaN.
    """
    total = 0
    vals = np.zeros(weights[0].shape[0])
    for node_weights, population in zip(weights, populations):
        vals += np.asarray(node_weights) * population
        total += population
    vals = vals / total

    span = np.max(vals) - np.min(vals)
    if span == 0:
        # Min-max scaling would divide by zero here.
        return np.ones_like(vals)
    return (vals - np.min(vals)) / span


In [None]:

# Parameters shared by the tree-building cells below.
# depth: passed as max_depth to build_tree.
depth = 6
# Base PSO budget; build_tree multiplies both by (curr_depth + 1), and the
# initial (unpruned) tree is built with twice these values.
max_num_epochs = 100
max_num_particles = 100
# build_tree only splits a node holding more than this many samples.
min_split_size = 5
# Multiplier on the standard deviation used as the pruning threshold in
# calculate_weighted_average_particles.
prune_rate = 1.0

# Load data
# NOTE(review): absolute, machine-specific path; this cell fails on any
# other machine unless the path is adjusted.
data = pd.read_csv('/Users/dr.ashhadulislam/projects/postDoc_HBKU/new_ideas/Explainability/nip2022_interpretable_dts/interpretable-dts/data/audit_data/audit_risk_formatted.csv')
# data = pd.read_csv('alger_forest_fire.csv')
print("data shape", data.shape)
print(data.columns)

# Assuming the last column is the label; every other column is a feature.
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values
# Replace NaNs in the features with 0 so the dot products stay finite.
X=np.nan_to_num(X)
# Split data into train and test sets (80/20, fixed seed for repeatability)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(X_train.shape,X_test.shape,y_train.shape,y_test.shape)


In [None]:
# Bare expression: in a notebook this displays the training labels.
y_train

In [None]:
# Build and evaluate the initial tree
# No mask is passed, so every feature may receive a weight. The PSO budget
# is doubled relative to the pruned tree built later.
tree = build_tree(X_train, y_train, max_depth=depth, max_num_epochs=max_num_epochs * 2, max_num_particles=max_num_particles * 2, min_split_size=min_split_size)
y_pred = predict(X_test, tree)
# Accuracy on the held-out test set, kept for comparison with the pruned tree.
pre_acc = accuracy_score(y_test, y_pred)

print(f"Pre-Pruning Accuracy: {pre_acc:.2f}")



In [None]:
# Pruning
# Derive a 0/1 feature mask from the initial tree's split weights, then
# rebuild the tree with the masked-out features disabled.
# NOTE(review): if the initial tree has no internal nodes, `weights` is empty
# and calculate_weighted_average_particles fails on `weights[0]`.
weights, populations = traverse_get_weights_samples(tree)
mask = calculate_weighted_average_particles(weights, populations, prune_rate=prune_rate)
# Fraction of features removed by the mask.
compression = len(np.where(mask == 0)[0]) / mask.size

if compression >= 1:
    # Every feature was masked out; nothing is left to build a tree on.
    print(f"Prune rate {prune_rate} is too high. Consider selecting a lower prune rate.")
else:
    tree_pruned = build_tree(X_train, y_train, max_depth=depth, max_num_epochs=max_num_epochs, max_num_particles=max_num_particles, min_split_size=min_split_size, mask=mask)
    weights_pruned, populations = traverse_get_weights_samples(tree_pruned)
    feat_importance = calculate_feat_importance(weights_pruned, populations)
    # Force the importance of masked-out features to zero.
    feat_importance = feat_importance * mask

    # Mapping weights to column names (all columns except the label)
    importance_df = pd.DataFrame({
        'Feature': data.columns[:-1],
        'Importance': feat_importance
    })

    print(importance_df.head())

    # Sort by importance, most important first
    importance_df = importance_df.sort_values(by='Importance', ascending=False)

    # Plotting the feature importance
    print("Feature Importance After Pruning")
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x='Importance', y='Feature', data=importance_df, ax=ax)
    ax.set_title("Feature Importance")
    plt.show()

    # Evaluate the pruned tree on the same held-out test set.
    y_pred_pruned = predict(X_test, tree_pruned)
    post_acc = accuracy_score(y_test, y_pred_pruned)

    print(f"Post-Pruning Accuracy: {post_acc:.2f}")

    # Precision, Recall, F1 Score (support-weighted average over classes)
    precision = precision_score(y_test, y_pred_pruned, average="weighted")
    recall = recall_score(y_test, y_pred_pruned, average="weighted")
    f1 = f1_score(y_test, y_pred_pruned, average="weighted")

    print(f"Precision: {precision:.2f}")
    print(f"Recall: {recall:.2f}")
    print(f"F1 Score: {f1:.2f}")

    # Confusion Matrix
    cm = confusion_matrix(y_test, y_pred_pruned)
    print("Confusion Matrix")
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    plt.show()
