Imports

In [8]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import seaborn as sns

from statsmodels.graphics.gofplots import qqplot
import scipy.stats as stats

import os
from os import listdir
from os.path import isfile, join
import glob
# Display up to 50 columns when rendering DataFrames, so wide frames are not truncated.
pd.set_option('display.max_columns', 50)

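Functions to plot each column of a dataframe against time in a three-wide grid of subplots: `plot_values_excl` plots every column except those whose names start with one of the given prefixes, and `plot_values_incl` plots only the columns whose names start with one of the given prefixes.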

In [7]:
def plot_values_excl(plot_name, df, t, exclude_cols=()):
    """Plot every column of ``df`` against ``t``, skipping columns with an excluded prefix.

    Each selected column gets its own subplot in a grid three subplots wide.
    The figure is saved as ``f"{plot_name}.png"`` (600 dpi) and then shown.

    Parameters
    ----------
    plot_name : str
        File stem for the saved PNG.
    df : pandas.DataFrame
        Data to plot; one subplot per selected column.
    t : array-like
        x values shared by every subplot (presumably one per row of ``df`` --
        TODO confirm against callers).
    exclude_cols : iterable of str, optional
        Column-name *prefixes*: a column is skipped if its name starts with any
        of them. Empty by default, i.e. all columns are plotted.

    Raises
    ------
    ValueError
        If every column is excluded, so there is nothing to plot.
    """
    # A tuple lets str.startswith test all prefixes at once, and an immutable
    # default avoids the shared-mutable-default pitfall of `=[]`.
    prefixes = tuple(exclude_cols)
    cols = [col for col in df.columns if not str(col).startswith(prefixes)]
    if not cols:
        raise ValueError(
            f"plot_values_excl: no columns left to plot after excluding prefixes {prefixes}"
        )

    num_cols = 3
    num_rows = (len(cols) + num_cols - 1) // num_cols  # ceiling division
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols,
                            figsize=(15, 5 * num_rows), squeeze=False)
    axs = axs.flatten()

    for i, (ax, column) in enumerate(zip(axs, cols)):
        # The default colour cycle ("C0".."C9") is reproducible and always visible
        # on white, unlike random hex colours (which can be near-white). It also
        # matches the colours used by plot_columns.
        ax.plot(t, df[column], label=column, color=f"C{i % 10}")
        ax.set_xlabel('Timestep')
        ax.set_ylabel(column)
        ax.set_title(str(column).capitalize())
        ax.grid(True)

    # Remove the empty trailing cells of the last grid row.
    for ax in axs[len(cols):]:
        fig.delaxes(ax)

    fig.tight_layout()
    fig.savefig(f"{plot_name}.png", dpi=600)
    plt.show()

def plot_values_incl(plot_name, df, t, include_cols=()):
    """Plot only the columns of ``df`` whose names start with one of the given prefixes.

    Each selected column gets its own subplot in a grid three subplots wide,
    plotted against ``t``. The figure is saved as ``f"{plot_name}.png"``
    (600 dpi) and then shown.

    Parameters
    ----------
    plot_name : str
        File stem for the saved PNG.
    df : pandas.DataFrame
        Data to plot; one subplot per selected column.
    t : array-like
        x values shared by every subplot (presumably one per row of ``df`` --
        TODO confirm against callers).
    include_cols : iterable of str, optional
        Column-name *prefixes*: a column is plotted only if its name starts with
        one of them. The default (empty) selects nothing, so callers must pass
        at least one prefix.

    Raises
    ------
    ValueError
        If no column matches any prefix (including the default empty selection).
    """
    # A tuple lets str.startswith test all prefixes at once, and an immutable
    # default avoids the shared-mutable-default pitfall of `=[]`.
    prefixes = tuple(include_cols)
    cols = [col for col in df.columns if str(col).startswith(prefixes)]
    if not cols:
        # Previously this fell through to plt.subplots(nrows=0), which fails
        # with an unhelpful matplotlib error.
        raise ValueError(
            f"plot_values_incl: no columns start with any of the prefixes {prefixes}"
        )

    num_cols = 3
    num_rows = (len(cols) + num_cols - 1) // num_cols  # ceiling division
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols,
                            figsize=(15, 5 * num_rows), squeeze=False)
    axs = axs.flatten()

    for i, (ax, column) in enumerate(zip(axs, cols)):
        title = str(column).capitalize()
        # The default colour cycle ("C0".."C9") is reproducible and always visible
        # on white, unlike random hex colours (which can be near-white).
        ax.plot(t, df[column], label=title, color=f"C{i % 10}")
        ax.set_xlabel('Timestep')
        ax.set_ylabel(column)
        ax.set_title(title)
        ax.grid(True)

    # Remove the empty trailing cells of the last grid row.
    for ax in axs[len(cols):]:
        fig.delaxes(ax)

    fig.tight_layout()
    fig.savefig(f"{plot_name}.png", dpi=600)
    plt.show()


Function to plot histograms of one dataframe column, with one subplot per given timestep

In [10]:
def plot_histogram(name,df, cols, times, nrows=2):
    """Plot one histogram of ``cols[0]`` per timestep in ``times`` and save the figure.

    Parameters
    ----------
    name : str
        File stem; the figure is saved as ``f"{name}.png"`` at 600 dpi.
    df : pandas.DataFrame
        Data with a ``'date'`` column and the column ``cols[0]``.
    cols : sequence of str
        Only the first entry is plotted.
    times : sequence
        Values of ``df['date']`` to plot, one subplot each, in order.
    nrows : int, default 2
        Number of subplot rows; the number of grid columns is derived from
        ``len(times)``.

    Raises
    ------
    ValueError
        If ``times`` is empty.
    """
    if len(times) == 0:
        raise ValueError("plot_histogram: 'times' must contain at least one timestep")

    col = cols[0]
    ncols = int(np.ceil(len(times) / nrows))
    # squeeze=False always returns a 2-D array of Axes, so ravel() also works
    # for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 5 * nrows), squeeze=False)
    axs = axs.ravel()

    for ax, time in zip(axs, times):
        sns.histplot(df.loc[df['date'] == time, col], bins=25, ax=ax, kde=False)
        ax.set_xlabel(col)
        ax.set_ylabel('Frequency')
        ax.set_title(f'Distribution of {col} at t={time}')

    # Remove unused subplots (the grid may have more cells than timesteps).
    for ax in axs[len(times):]:
        fig.delaxes(ax)

    fig.suptitle(f" Distribution of {col} across time", fontsize=16)
    fig.tight_layout()
    fig.savefig(f"{name}.png", dpi=600)
    plt.show()


Function to compute summary statistics (mean, standard deviation, min, max) of a dataframe column at each given date, and to plot the corresponding histograms

In [13]:
def analyze_distribution(name,df, col_name, dates, plot=True):
    """Summarise ``df[col_name]`` at each date and, by default, plot its histograms.

    Parameters
    ----------
    name : str
        File stem passed to ``plot_histogram`` (figure saved as ``f"{name}.png"``).
    df : pandas.DataFrame
        Data with a ``'date'`` column and the column ``col_name``.
    col_name : str
        Column to summarise.
    dates : sequence
        Values of ``df['date']`` to summarise, in output order.
    plot : bool, default True
        If True, call ``plot_histogram`` for the same dates. Pass False to get
        only the statistics, without drawing or saving a figure.

    Returns
    -------
    list of dict
        One dict per date, in the order of ``dates``, with keys ``'date'``,
        ``'mean'``, ``'std'`` (sample standard deviation, pandas' default
        ``ddof=1``), ``'min'`` and ``'max'``. Dates with no matching rows yield
        NaN statistics.

    Notes
    -----
    To estimate the bounds of a uniform distribution U(a, b) that may have
    generated the data, use ``mean ± sqrt(3) * std``: a uniform on [a, b] has
    std ``(b - a) / sqrt(12)``. The range ``mean ± 3 * std`` is a normal
    3-sigma range, not uniform bounds.
    """
    results = []
    for date in dates:
        # Select the column once per date instead of re-indexing for every statistic.
        values = df.loc[df['date'] == date, col_name]
        results.append({
            'date': date,
            'mean': values.mean(),
            'std': values.std(),
            'min': values.min(),
            'max': values.max(),
        })

    if plot:
        plot_histogram(name, df, [col_name], dates)

    return results


Function to plot multiple dataframe columns in a single composite figure, each column in its own subplot (not overlaid on the same axes)

In [11]:
import seaborn as sns

def plot_columns(n_rows=2, name="composite_plot", *args):
    """Plot several (DataFrame, column) pairs, each in its own subplot, and save the figure.

    Parameters
    ----------
    n_rows : int, default 2
        Number of subplot rows; the number of grid columns is derived from the
        number of pairs.
    name : str, default "composite_plot"
        File stem; the figure is saved as ``f"{name}.png"`` at 600 dpi.
    *args
        Alternating ``df, column_name`` items, e.g. ``df1, "a", df2, "b"``.
        Because ``*args`` follows the defaulted parameters, callers must pass
        ``n_rows`` and ``name`` positionally before the pairs.

    Raises
    ------
    ValueError
        If ``args`` is empty or has an odd number of items.
    """
    if not args or len(args) % 2:
        raise ValueError(
            "plot_columns expects a non-empty, even-length sequence of (df, column_name) items"
        )
    pairs = list(zip(args[0::2], args[1::2]))

    n_cols = int(np.ceil(len(pairs) / n_rows))
    # squeeze=False always returns a 2-D array of Axes, so flatten() also works
    # for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 8), squeeze=False)
    axes = axes.flatten()

    for k, (ax, (df, col_name)) in enumerate(zip(axes, pairs)):
        # One line per subplot, coloured from the default colour cycle;
        # the x axis is the frame's index.
        sns.lineplot(data=df[col_name], ax=ax, color=f"C{k % 10}")
        ax.set_ylabel(col_name)
        ax.set_xlabel('Time')
        ax.set_title(col_name.capitalize())
        ax.grid(True)

    # If there are fewer plots than grid cells, remove the extras.
    for ax in axes[len(pairs):]:
        fig.delaxes(ax)

    fig.tight_layout()
    fig.savefig(f'{name}.png', dpi=600)
    plt.show()


Function to visually check the normality of a column (histogram and normal Q–Q plot), e.g. for the data at t = 0


In [7]:
def plot_qq_and_hist(df,col):
    """Show a histogram and a normal Q-Q plot of ``df[col]`` side by side.

    Intended as a visual normality check: points on the Q-Q plot lying close to
    the fitted straight line indicate approximately normal data.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data.
    col : str
        Name of the column to check.
    """
    # Drop missing values first. probplot sizes its theoretical quantiles from
    # the full sample length, so NaNs would shift every point and turn the
    # least-squares fit into NaN.
    values = df[col].dropna()

    _, (ax_hist, ax_qq) = plt.subplots(1, 2, figsize=(10, 4))

    ax_hist.hist(values, bins='auto')
    ax_hist.set_title('Histogram of {}'.format(col))
    ax_hist.set_xlabel("Value")
    ax_hist.set_ylabel("Frequency")

    # probplot accepts a matplotlib Axes as `plot` and draws the ordered sample
    # against theoretical normal quantiles, plus a least-squares line.
    stats.probplot(values, dist="norm", plot=ax_qq)
    ax_qq.set_title(col + " Normal Distribution QQ plot")
    plt.show()

In [None]:
# Disabled cell: the code below is wrapped in a string literal, so it does not run.
# NOTE(review): the string is the cell's last expression, so Jupyter will echo it as
# the cell output. Convert it back to real code (or delete the cell) once `df_h` exists.
# Intent: keep the columns of `df_h` that should be normally distributed, then draw a
# histogram + normal Q-Q plot for each one with `plot_qq_and_hist`.
# The filter keeps columns whose sum is non-zero (meant as "not all zeros"; note that a
# signed column can sum to zero without being all zeros) and whose mean is not exactly 1
# (presumably to drop constant-1 columns -- TODO confirm).
# `df_h` is not defined in this part of the notebook -- presumably created elsewhere.
"""df_h_norm = df_h.loc[:,((df_h.sum(axis=0) != 0)&(df_h.mean(axis=0) != 1))]
df_h_norm.head()
for col_name in df_h_norm.columns:
     plot_qq_and_hist(df_h_norm,col_name)"""