In [1]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def plot_dataset(X, y):
    """Draw a two-feature binary dataset as points on the current matplotlib axes.

    Class-0 samples are drawn in '#1f77b4' and class-1 samples in '#ff7f0e'.
    The view uses equal aspect with a grid. It is fixed to x in [-0.2, 1.2]
    and y in [-0.3, 1.3], presumably chosen for features scaled to [0, 1].

    Parameters
    ----------
    X : array-like of shape (n_samples, >=2)
        Feature matrix; only columns 0 and 1 are plotted.
    y : array-like of length n_samples
        Labels; samples equal to 0 or 1 are drawn. A column vector is flattened
        first so the boolean masks index rows of X correctly.
    """
    X = np.asarray(X)
    y = np.asarray(y).reshape(-1)
    colors = ('#1f77b4', '#ff7f0e')

    for label, color in enumerate(colors):
        mask = y == label
        # Without an explicit marker/linestyle, plt.plot joins consecutive
        # samples with line segments; draw individual points instead.
        plt.plot(X[mask, 0], X[mask, 1], linestyle='none', marker='o', color=color)

    plt.axis('equal')
    plt.grid(True)
    plt.xlim([-0.2, 1.2])
    plt.ylim([-0.3, 1.3])
    plt.xlabel('$feature1$')
    plt.ylabel('$feature2$')

def plot_decision_boundary(X, y, model, output_directory, generator_name,
                           plot_name, plot_id, show_plot=False):
    """Plot the dataset with a model's decision boundary and save it as a PNG.

    The model's predicted class (argmax of ``predict_proba``) is evaluated on a
    500 x 500 grid around the data. The boundary between the predicted classes
    is drawn as a black contour at level 0.5 on top of ``plot_dataset(X, y)``.

    Parameters
    ----------
    X : array-like of shape (n_samples, >=2)
        Feature matrix; columns 0 and 1 define the grid extent.
    y : array-like of length n_samples
        Labels, flattened before being passed to ``plot_dataset``.
    model : object
        Anything exposing ``predict_proba`` on an (m, 2) array and returning
        one probability column per class.
    output_directory, generator_name, plot_name, plot_id :
        Used to build the output path
        ``{output_directory}/{generator_name}_{plot_name}_{plot_id:02}.png``.
        ``generator_name`` is also shown upper-cased in the figure title.
    show_plot : bool, default False
        If True, display the figure before it is closed.
    """
    X = np.asarray(X)
    y = np.asarray(y).reshape(-1)

    # Pad the data's bounding box by half its span on each side. The old
    # ``x_min * 1.5`` / ``x_max * 1.5`` scaled from the origin, so for
    # non-negative features the lower edge moved *into* the data (or stayed at
    # 0) and the boundary was cut off on the left/bottom of the plot.
    x_min = X[:, :2].min(axis=0)
    x_max = X[:, :2].max(axis=0)
    margin = 0.5 * (x_max - x_min)
    # A constant feature would give a zero-width grid; use a fixed margin instead.
    margin = np.where(margin > 0, margin, 0.5)
    lower, upper = x_min - margin, x_max + margin

    x0, x1 = np.meshgrid(
        np.linspace(lower[0], upper[0], 500),
        np.linspace(lower[1], upper[1], 500),
    )
    grid = np.c_[x0.ravel(), x1.ravel()]
    predicted = np.argmax(model.predict_proba(grid), axis=1).reshape(x0.shape)

    try:
        plot_dataset(X, y)
        plt.contour(x0, x1, predicted, levels=[0.5], colors='k')

        plt.suptitle(f'Recourse generated by {generator_name.upper()}')
        plt.savefig(f"{output_directory}/{generator_name}_{plot_name}_{plot_id:02}.png",
                    bbox_inches='tight')

        if show_plot:
            plt.show()
    finally:
        # Always release the figure, even if saving fails, so a later call
        # does not draw on top of a stale, half-finished figure.
        plt.close()

In [10]:
def plot_distribution(data, model, output_directory, generator_name, plot_name, plot_id, show_plot=False):
    """Show a model's class-1 probability surface with the dataset on top.

    ``model.predict_proba(...)[:, 1]`` is evaluated on a 0.01-step grid that
    extends 1 unit beyond the data in each feature. It is drawn as a filled
    contour (viridis). Class-0 and class-1 samples are scattered over it with
    black edges. The view is fixed to [-0.25, 1.25] on both axes.

    Parameters
    ----------
    data : pandas.DataFrame
        Columns 0 and 1 are the features and column 2 holds the labels.
    model : object
        Anything exposing ``predict_proba`` on an (m, 2) array.
    output_directory, generator_name, plot_name, plot_id, show_plot :
        NOTE(review): accepted for signature parity with
        ``plot_decision_boundary`` but currently unused. The figure is always
        shown and never saved. Confirm whether saving/``show_plot`` were intended.
    """
    values = data.to_numpy()
    features = values[:, :2]
    labels = values[:, 2].reshape(-1)

    x_min = features.min(axis=0) - 1
    x_max = features.max(axis=0) + 1
    x0, x1 = np.meshgrid(np.arange(x_min[0], x_max[0], 0.01),
                         np.arange(x_min[1], x_max[1], 0.01))

    grid = np.c_[x0.ravel(), x1.ravel()]
    z = model.predict_proba(grid)[:, 1].reshape(x0.shape)

    fig, ax = plt.subplots(dpi=150)
    ax.axis('equal')
    ax.grid(True)
    ax.set_xlim([-0.25, 1.25])
    ax.set_ylim([-0.25, 1.25])
    ax.set_xlabel('$feature1$')
    ax.set_ylabel('$feature2$')

    ax.contourf(x0, x1, z, cmap='viridis', alpha=0.8)

    # Explicit per-class colours. The previous ``cmap='Paired'`` had no ``c``
    # array to map, so matplotlib ignored it (with a warning). 'C0'/'C1' are the
    # colours the property cycle was already assigning to these two calls.
    for label, color in ((0, 'C0'), (1, 'C1')):
        mask = labels == label
        ax.scatter(features[mask, 0], features[mask, 1], s=60, color=color,
                   linewidth=1, edgecolor='black')

    plt.show()