# Slightly more in-depth example_usage of some functions

Imports to run before the code below

In [2]:
# Imports to run before running individual codeblocks below
import os
import pandas as pd
import matplotlib.pyplot as plt

### Wrangle

More advanced examples of the Wrangle module

In [None]:
# Import wrangle module functions
from vistool.wrangle import clean_data, filter_data


# Example 1: Cleaning Data & Handling of NaN Values
#
# Loads the A&E attendances CSV, shows the first 20 rows and then asks the
# user interactively how NaN values should be handled. Choices 1 and 3 store
# their result in `cleaned_data`; choice 2 updates `data` itself.
data = pd.read_csv('data/Monthly_AE_Attendances_Nov_2024.csv')

print("Original Data Top 20 Rows:")
print(data.head(20))

# Ask the user how they want to handle NaN values
print("\nHow would you like to handle NaN values?")
print("1. Remove rows or columns with NaN values.")
print("2. Fill NaN values with the column mean.")
print("3. Fill NaN values with the row mean.")
print("4. Do nothing (keep NaN values).")

# Strip surrounding whitespace so that e.g. " 1" is still accepted
choice = input("Enter your choice (1/2/3/4): ").strip()

# Choice 1: Remove rows or columns with NaN values
if choice == "1":
    # Ask whether the user wants to remove rows or columns
    remove_choice = input(
        "Would you like to remove rows or columns with NaN values? (Enter 'rows' or 'columns'): "
    ).strip().lower()

    if remove_choice == "rows":
        # Ask if the user wants to remove all rows or specific rows by index
        rows_choice = input(
            "Would you like to remove all rows or a specific row? (Enter 'all' or 'specific'): "
        ).strip().lower()

        if rows_choice == "all":
            # Remove all rows with NaN values
            cleaned_data = data.dropna(axis=0)
            print("\nCleaned Data (all rows with NaN values removed):")
            print(cleaned_data.head(20))

        elif rows_choice == "specific":
            # Ask the user to specify the row indices to remove
            try:
                rows_to_remove = input(
                    "Enter row indices to remove (comma-separated): "
                ).split(",")
                rows_to_remove = [int(row.strip()) for row in rows_to_remove]
            except ValueError:
                print("Invalid input. Please enter valid row indices (e.g., 0, 1, 2).")
            else:
                # DataFrame.drop raises KeyError for labels that are not in the
                # index, so report unknown indices instead of crashing.
                missing_rows = [row for row in rows_to_remove if row not in data.index]
                if missing_rows:
                    print(f"Invalid input. Row indices not found in the dataset: {missing_rows}")
                else:
                    cleaned_data = data.drop(rows_to_remove)
                    print("\nCleaned Data (specified rows removed):")
                    print(cleaned_data.head(20))
        else:
            print("Invalid choice. No rows removed.")

    elif remove_choice == "columns":
        # Ask if the user wants to remove all columns or specific columns
        columns_choice = input(
            "Would you like to remove all columns or a specific column? (Enter 'all' or 'specific'): "
        ).strip().lower()

        if columns_choice == "all":
            # Remove all columns with NaN values
            cleaned_data = data.dropna(axis=1)
            print("\nCleaned Data (all columns with NaN values removed):")
            print(cleaned_data.head(20))

        elif columns_choice == "specific":
            # Ask the user to specify the columns to remove
            columns_to_remove = input(
                "Enter column names to remove (comma-separated): "
            ).split(",")
            columns_to_remove = [col.strip() for col in columns_to_remove]

            # DataFrame.drop raises KeyError for unknown column names, so
            # validate the names first and report the ones that do not exist.
            missing_columns = [col for col in columns_to_remove if col not in data.columns]
            if missing_columns:
                print(f"Invalid input. Columns not found in the dataset: {missing_columns}")
            else:
                cleaned_data = data.drop(columns=columns_to_remove)
                print("\nCleaned Data (specified columns removed):")
                print(cleaned_data.head(20))
        else:
            print("Invalid choice. No columns removed.")

    else:
        print("Invalid choice. No rows or columns removed.")

# Choice 2: Fill NaN values with the column mean
elif choice == "2":
    specific_columns = input(
        "Enter column names to fill NaN with mean (comma-separated, or leave empty for all numeric columns): "
    ).split(",")
    # Drop empty entries so that an empty answer means "all numeric columns"
    specific_columns = [col.strip() for col in specific_columns if col.strip()]

    if specific_columns:
        # Fill NaN only in specified columns
        for col in specific_columns:
            if col not in data.columns:
                print(f"Warning: Column '{col}' does not exist in the dataset.")
            elif not pd.api.types.is_numeric_dtype(data[col]):
                # Taking the mean of a text column raises TypeError; skip it.
                print(f"Warning: Column '{col}' is not numeric and was skipped.")
            else:
                data[col] = data[col].fillna(data[col].mean())
                print(f"Filled NaN values in column '{col}' with its mean.")
    else:
        # Fill NaN in all numeric columns
        data = data.fillna(data.mean(numeric_only=True))
        print("Filled NaN values in all numeric columns with their mean.")

    print("\nCleaned Data (NaN values filled with column mean):")
    # Round numeric columns to two decimal places for display
    data = data.round(2)
    print(data)

# Choice 3: Fill NaN values with the row mean
elif choice == "3":
    # Automatically apply to rows, no need for further user input
    cleaned_data = clean_data(data, fill_with="mean", apply_to="rows")
    print("\nCleaned Data (NaN values filled with row mean):")
    print(cleaned_data.round(2))

# Choice 4: Do nothing (keep NaN values)
elif choice == "4":
    print("\nNo cleaning applied. Data remains unchanged:")
    print(data.head(20))

else:
    print("\nInvalid choice. No cleaning applied.")
    

More advanced examples of Filtering Data

In [None]:
# Example 2: Filtering the Data
#
# Normalises the column dtypes of `data` in place, then applies a
# user-supplied condition through vistool's `filter_data`.

# The first four columns are treated as text ...
for text_col in data.columns[:4]:
    data[text_col] = data[text_col].astype(str)

# ... and every remaining column is coerced to a numeric dtype; values that
# cannot be parsed become NaN because of errors='coerce'.
for numeric_col in data.columns[4:]:
    data[numeric_col] = pd.to_numeric(data[numeric_col], errors='coerce')

# Ask the user for the filter condition
condition = input("Enter condition to filter data (e.g., 'other_emergency_admissions > 100'): ")

# Run the filter and show the first 20 matching rows; any failure is reported
# as a message instead of a traceback.
try:
    filtered_data = filter_data(data, condition)
    print("\nFiltered Data (20):")
    print(filtered_data.head(20))
except Exception as e:
    print(f"An error occurred while filtering data: {e}")

### Visualize

More advanced examples of the Visualize module

In [None]:
# Import functions

from vistool.visualize import (
    plot_histogram,
    plot_scatter,
    plot_correlation_matrix,
    plot_line,
    plot_overlay
)

# Example Data
data = pd.read_csv('data/Monthly_AE_Attendances_Nov_2024.csv')

# Lower-case the column names so they match the lower-cased user input below
data.columns = [col.lower() for col in data.columns]

# For demonstration purposes, due to the dataset having nearly 200 rows
data = data.head(20)


def _offer_to_save(plot_func, plot_args, example_filename):
    """Ask whether the plot just shown should be saved and, if so, save it.

    Args:
        plot_func: The vistool plotting function that produced the plot.
        plot_args: Tuple of positional arguments the plot was created with.
        example_filename: File name shown as an example in the path prompt.

    Saving is done by calling ``plot_func`` a second time with the entered
    path appended as the last positional argument.
    NOTE(review): whether that second call redraws the figure depends on
    vistool, which is not visible here - confirm.
    """
    save_option = input(
        "Would you like to save the plot? (yes/no): "
    ).strip().lower()
    if save_option == "yes":
        save_path = input(
            f"Enter file path to save the plot (e.g., '{example_filename}'): "
        ).strip()
        plot_func(*plot_args, save_path)
        print(f"Plot saved to {save_path}")
    else:
        print("Plot not saved.")


def _xy_columns_are_valid(frame, x_column, y_column):
    """Return True if both columns exist in ``frame``; otherwise print why not."""
    x_missing = x_column not in frame.columns
    y_missing = y_column not in frame.columns

    if x_missing and y_missing:
        print(f"Error: Both '{x_column}' and '{y_column}' are not valid columns.")
    elif x_missing:
        print(f"Error: '{x_column}' is not a valid X-axis column.")
    elif y_missing:
        print(f"Error: '{y_column}' is not a valid Y-axis column.")

    return not (x_missing or y_missing)


# Graph types that take an X and a Y column: plotting function and the
# example file name used in the save prompt.
_XY_PLOTS = {
    "scatter": (plot_scatter, "scatter_plot.png"),
    "line": (plot_line, "line_chart.png"),
}

# Example 1: Plotting different graph types
# Display available columns
print("Available Columns for plotting:", list(data.columns))
graph_type = input(
    "Choose graph type (histogram/scatter/correlation/line): "
).strip().lower()

# Choice 1: Histogram
if graph_type == "histogram":
    column = input("Enter column name for the histogram: ").strip().lower()
    if column not in data.columns:
        print(f"Error: {column} is not a valid column.")
    else:
        plot_histogram(data, column)
        _offer_to_save(plot_histogram, (data, column), "histogram.png")

# Choices 2 and 4: Scatter and Line (both need an X and a Y column)
elif graph_type in _XY_PLOTS:
    xy_plot_func, xy_example_filename = _XY_PLOTS[graph_type]
    x_column = input("Enter X-axis column: ").strip().lower()
    y_column = input("Enter Y-axis column: ").strip().lower()

    if _xy_columns_are_valid(data, x_column, y_column):
        xy_plot_func(data, x_column, y_column)
        _offer_to_save(xy_plot_func, (data, x_column, y_column), xy_example_filename)

# Choice 3: Correlation
elif graph_type == "correlation":
    plot_correlation_matrix(data)
    _offer_to_save(plot_correlation_matrix, (data,), "correlation_matrix.png")

else:
    print("Invalid graph type. Please choose histogram, scatter, correlation, or line.")
  


More advanced example of Overlay Plot

In [None]:
# Example 2: Graph Overlay
#
# Lets the user pick several columns of `data` and a plot type for each, then
# draws them together with vistool's `plot_overlay`.
print("Available Columns for Graph Overlaying:", list(data.columns))

# Prompt user for overlay details; a non-numeric answer is reported instead
# of raising ValueError.
try:
    num_columns = int(input("How many columns would you like to overlay? "))
except ValueError:
    print("Invalid number. Please enter a whole number (e.g., 2).")
    num_columns = 0

columns = []
plot_types = []

for i in range(num_columns):
    # Column names were lower-cased when the data was loaded, so normalise the
    # input the same way the other plotting prompts do.
    col = input(f"Enter column {i + 1}: ").strip().lower()
    if col not in data.columns:
        print(f"Error: {col} is not a valid column.")
        continue
    columns.append(col)

    plot_type = input(f"Choose plot type for {col} (line/bar): ").strip().lower()
    if plot_type not in ["line", "bar"]:
        print("Invalid plot type. Defaulting to 'line'.")
        plot_type = "line"
    plot_types.append(plot_type)

if not columns:
    # Every entry was rejected (or none was requested): there is nothing to draw
    print("No valid columns selected. Nothing to plot.")
else:
    # Plot overlay
    plot_overlay(data, columns, plot_types)

    # Ask user if they want to save the plot
    save_option = input("Would you like to save the plot? (yes/no): ").strip().lower()
    if save_option == "yes":
        save_path = input(
            "Enter file path to save the plot (e.g., 'overlay_plot.png'): "
        ).strip()
        plot_overlay(data, columns, plot_types, save_path=save_path)
        print(f"Plot saved to {save_path}")
    else:
        print("Plot not saved.")