In [3]:
import seaborn as sns
import matplotlib.pyplot as plt
import os
import numpy as np

In [4]:
def outlier_vars(data, show_plot=False, save_img=os.getcwd()+'/outliers.png'):
    """
    Find the numeric columns of ``data`` that contain outliers (IQR rule).

    A value counts as an outlier when it lies strictly below
    ``Q1 - 1.5 * IQR`` or strictly above ``Q3 + 1.5 * IQR`` of its own
    column. Only columns selected by ``select_dtypes(include='number')``
    are inspected; every other column is ignored.

    Parameters
    ----------
    data : pandas.DataFrame
        The dataset to inspect. It is not modified.
    show_plot : bool, default False
        When True, draw a seaborn pairplot of the outlier columns, print the
        per-column outlier flags and save the figure to ``save_img``.
    save_img : str
        Path of the PNG written when ``show_plot`` is True. The default is
        built from ``os.getcwd()`` once, when the function is defined.

    Returns
    -------
    pandas.DataFrame or seaborn.PairGrid
        ``data`` restricted to the outlier columns when ``show_plot`` is
        False, otherwise the pairplot object.
    """
    num_data = data.select_dtypes(include='number')

    # Quantiles are taken on the numeric subset only: calling quantile() on a
    # frame that still holds text columns fails on pandas 2.x, and it would
    # also misalign the comparison below.
    q1 = num_data.quantile(0.25)
    q3 = num_data.quantile(0.75)
    iqr = q3 - q1

    is_outlier = (num_data < (q1 - 1.5 * iqr)) | (num_data > (q3 + 1.5 * iqr))
    result = dict(is_outlier.any())
    outliers = [col for col, flagged in result.items() if flagged]

    if show_plot:
        pair_plot = sns.pairplot(data[outliers])
        print(f'{result},\n\n Visualization of outlier columns')
        # The pairplot figure is the current figure at this point.
        plt.savefig(fname=save_img, format='png')
        return pair_plot

    return data[outliers]

In [5]:
def outlier_treatment(data, col_list, type='median_replacement'):
    """
    Treat outliers in the given columns using one of four methods.

    Supported values of ``type``:

        1. ``"median_replacement"`` (default) - values outside the IQR
           fences ``[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`` are replaced by the
           column median.

        2. ``"quantile_flooring"`` - values below the 10th percentile are
           raised to it and values above the 90th percentile are lowered
           to it.

        3. ``"trimming"`` - values outside the IQR fences are blanked out
           (set to NaN) in that column; the affected index labels are
           printed per column.

        4. ``"log_transformations"`` - positive values are replaced by their
           natural logarithm, non-positive values by 0.

    Parameters
    ----------
    data : pandas.DataFrame
        The dataset to treat. It is modified IN PLACE.
    col_list : iterable of column labels
        The columns to treat.
    type : str, default "median_replacement"
        One of the four method names above.

    Returns
    -------
    pandas.DataFrame
        The same ``data`` object, after treatment.

    Raises
    ------
    ValueError
        If ``type`` is not one of the supported method names.
    """

    if type == "median_replacement":
        for col in col_list:
            median = data[col].quantile(0.50)
            q1 = data[col].quantile(0.25)
            q3 = data[col].quantile(0.75)
            iqr = q3 - q1
            high = q3 + 1.5 * iqr
            low = q1 - 1.5 * iqr
            # The median always lies inside the fences, so both sides can be
            # replaced in one pass.
            outside = (data[col] > high) | (data[col] < low)
            data[col] = np.where(outside, median, data[col])

    elif type == "quantile_flooring":
        for col in col_list:
            # Both bounds are computed before the column is touched.
            q_10 = data[col].quantile(0.10)
            q_90 = data[col].quantile(0.90)
            data[col] = np.where(data[col] < q_10, q_10, data[col])
            data[col] = np.where(data[col] > q_90, q_90, data[col])

    elif type == "trimming":
        for col in col_list:
            q1 = data[col].quantile(0.25)
            q3 = data[col].quantile(0.75)
            iqr = q3 - q1
            # Fences are Q3 + 1.5*IQR and Q1 - 1.5*IQR; the multiplication
            # applies to the IQR only.
            high = q3 + 1.5 * iqr
            low = q1 - 1.5 * iqr
            # Strict comparison, as in the median branch: values sitting
            # exactly on a fence are kept, so a constant column (IQR == 0)
            # is not wiped out.
            index = data[(data[col] > high) | (data[col] < low)].index
            print(col, '\n', index)
            # Re-assigning the shortened series aligns on the index, which
            # leaves NaN in the trimmed positions of this column.
            data[col] = data[col].drop(index)

    elif type == "log_transformations":
        for col in col_list:
            data[col] = data[col].map(lambda i: np.log(i) if i > 0 else 0)

    else:
        raise ValueError(
            f"Unknown outlier treatment type {type!r}; expected one of "
            "'median_replacement', 'quantile_flooring', 'trimming', "
            "'log_transformations'"
        )

    return data

In [6]:
def plot_univariate(data, x=None, y=None, color='r', save=False,
                    title='New Chart', chart_type='hist', xlabel='', ylabel='',
                    save_to=os.getcwd(), log_normalise=False):
    """
    Make a univariate plot of ``data``.

    Only ``chart_type='hist'`` (a seaborn distribution plot) is implemented
    at the moment; any other value raises ``ValueError`` before a figure is
    created.

    Parameters
    ----------
    data : array-like
        The values to plot.
    x, y : optional
        Accepted for interface compatibility; not used by the histogram.
    color : str, default 'r'
        Colour passed to seaborn.
    save : bool, default False
        When True, save the figure as ``<save_to>/<title>.png``.
    title : str, default 'New Chart'
        Figure title; also used as the file name when saving.
    chart_type : str, default 'hist'
        Kind of chart to draw. Only ``'hist'`` is supported.
    xlabel, ylabel : str
        Axis labels.
    save_to : str
        Directory for the saved image. The default is taken from
        ``os.getcwd()`` once, when the function is defined.
    log_normalise : bool, default False
        When True, plot ``np.log(data)`` instead of ``data``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes returned by seaborn.

    Raises
    ------
    ValueError
        If ``chart_type`` is not ``'hist'``.
    """
    # Reject unsupported chart types up front, so no empty figure is left
    # behind and the caller gets a clear message instead of an
    # UnboundLocalError on ``plot``.
    if chart_type != 'hist':
        raise ValueError(
            f"Unsupported chart_type {chart_type!r}; only 'hist' is "
            "currently implemented"
        )

    plt.subplots(figsize=(10, 7))
    plt.title(title, fontsize=18)
    plt.xlabel(xlabel, fontsize=15)
    plt.ylabel(ylabel, fontsize=15)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    if log_normalise:
        data = np.log(data)

    # NOTE(review): seaborn documents distplot as deprecated in favour of
    # histplot/displot; kept here so the rendered chart does not change.
    plot = sns.distplot(a=data, color=color)

    if save:
        plt.savefig(fname=os.path.join(save_to, f'{title}.png'), format='png')

    return plot

In [None]:
def plot_bivariate(data, x=None, y=None, hue=None, 
                  color='r',save=False,
                title='New Chart', chart_type='hist',
                   xlabel='', ylabel='',
                    save_to=os.getcwd(), img_name = " ", 
                   palette={'use':False, "size":1}, log_normalise=False,
                  kind_joint_plot = 'scatter', kind_pair_plot="scatter", figsize=(10,7)):
    
    """
    Make a bivariate plot of any of the selcted types:
    
    1. bar - barchart
    
    2. scatter  - scatter plot
    
    3. cat  - catplot
    
    4. count - countplot
    
    5 joint - jointplot 
    
    6  pair - pairplot
    
    7  corr - corr_plot
    
    When calling joint_plot:
        
        kind_joint_plot is default to `scatter`
        other types include "reg", "reside", "kde", "hex"
        
    When calling pair_plot:
        
        kind_pair_plot is default to `scatter`
        other types include 'reg'
    """
    def plt_tweaks():
        plt.subplots(figsize= figsize)
        plt.title(title, fontsize=18)
        plt.xlabel(xlabel, fontsize=15)
        plt.ylabel(ylabel, fontsize=15)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
    
    
    # define helper functions
    
    def use_palette():
        palettes = []
#        palette_to_use=[]
        if palette['use'] == True:
            palette_to_use = [palettes[i] for i in range(palette['size'])]
            
            return palette_to_use

    def log_norm():
        if log_normalise and y != None:
            y = np.log(y)
        elif log_normalise and y == None:
            data = np.log(data)
            
    def save_image():
        if save:
            if img_name != " ":
                plt.savefig(fname=save_to+"/"+img_name+'.png', format='png')
            else:
                plt.savefig(fname=save_to+f'/{title}.png', format='png')
                
        
    # make plots
    
    if chart_type == "joint":
        log_norm()
        plot = sns.jointplot(x=x, y=y, data=data,
                            height=6, ratio=5, space=0.2, kind=kind_joint_plot)
        
        save_image()
        
    if chart_type == "pair":
       # try:
        log_norm()
        if palette['use'] == True:
            palette_to_use = use_palette()
            plot = sns.pairplot(data, palette=palette_to_use, 
                            kind= kind_pair_plot,height=3, aspect=1, hue=hue)
        else:
             plot = sns.pairplot(data, 
                            kind= kind_pair_plot,height=2.5, aspect=1, hue=hue, )
        save_image()
        
    if chart_type  == "corr":
        plt_tweaks()
        corr_data = data.corr()
        corr_plot = sns.heatmap(corr_data,annot=True, fmt='.2g', center=0) 