In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [None]:
# Load the raw CSV and select the subset of rows that every plot below draws from.
# Later cells read these module-level names:
#   database_name         -- sub-directory under "graphs/" where figures are written
#   db_main_column        -- filtered DataFrame used by all plots
#   db_main_column_name   -- column whose values split the data into plotted groups
#   db_main_column_value  -- that column as a Series (used to build row masks)
#   db_main_column_unique -- distinct values of that column, in order of appearance
# Figure style guidance: https://urbaninstitute.github.io/graphics-styleguide/
CSV_PATH = "cases-deaths-tests.csv"
database_name = "test/"

db_main_column_name = "area"
# Value of db_main_column_name to keep. The filter is built from the variable so the
# filtered column and the grouping column used downstream cannot drift apart.
FILTER_VALUE = "San Diego"

df = pd.read_csv(CSV_PATH)
# .copy() so any later column assignment acts on an independent frame, not a view of df.
db_main_column = df.loc[df[db_main_column_name] == FILTER_VALUE].copy()
assert not db_main_column.empty, f"no rows where {db_main_column_name} == {FILTER_VALUE!r}"

db_main_column_value = db_main_column[db_main_column_name]
db_main_column_unique = db_main_column_value.unique()

Line plots: each measurement column over time, one line per group

In [None]:
def plot_graph(x_axis, y_axis, filename):
    # Loop through each unique value in demographic_value
    for unique_value in db_main_column_unique[:5]:
        # Plot the x-axis and y-axis values for the current demographic_value
        plt.plot(db_main_column[db_main_column[db_main_column_name] == unique_value][x_axis],
                 db_main_column[db_main_column[db_main_column_name] == unique_value][y_axis])

    # Customize the x-axis labels to show every 30th date and rotate them for readability
    plt.xticks(range(0, len(db_main_column[db_main_column_value == db_main_column_unique[0]][x_axis]), 30), 
           db_main_column[db_main_column_value == db_main_column_unique[0]][x_axis][::30], rotation=90)
    
    # Label the x-axis and y-axis
    plt.xlabel(x_axis)
    plt.ylabel(y_axis)

    # Title the graph
    plt.title(y_axis + " by Age Group")
    # Add a legend to the graph with the unique demographic_value values
    plt.legend(db_main_column_unique)

    # Create a directory to save the graphs if it doesn't already exist
    if not os.path.exists("graphs/" + database_name):
        os.mkdir("graphs/" + database_name)
    
    # Save the graph to the "graphs" directory with the specified filename
    plt.savefig("graphs/"+ filename, format="png", dpi=300, bbox_inches="tight")
    # Show the graph
    plt.show() 


# Loop through each column in the db_main_column DataFrame (excluding desired columns)
for value in db_main_column.columns[3:]:
    # Call the plot_graph function for the current column, using "report_date" as the x-axis and the current column as the y-axis
    plot_graph("date", value, database_name + value + "_by_db_main_column.png")

Scatter plot: deaths vs. cases for each group

In [None]:
def plot_graph(x_axis, y_axis, value, filename):
    # Create a list of unique values in the demographic_value column
    plt.scatter(db_main_column[db_main_column_value == value][x_axis], db_main_column[db_main_column_value == value][y_axis])
    
    # Label the x-axis and y-axis
    plt.xlabel(x_axis)
    plt.ylabel(y_axis)

    # Title the graph
    plt.title(y_axis+" by "+x_axis + " for " + value)

    # Create a directory to save the graphs if it doesn't already exist
    if not os.path.exists("graphs/" + database_name):
        os.mkdir("graphs/" + database_name)
    
    # Save the graph to the "graphs" directory with the specified filename
    plt.savefig("graphs/"+ filename +"_for_"+value+".png", format="png", dpi=300, bbox_inches="tight")
    plt.show()

y_axis = "deaths"
x_axis = "cases"

# Loop through each unique value in demographic_value
for value in db_main_column_unique:
    # Call the plot_graph function for the current demographic_value
    plot_graph(x_axis, y_axis, value, database_name + y_axis+"_by_"+x_axis)


Pair plot: scatter-plot matrix of all numeric columns

In [None]:
# Create a scatter plot matrix
sns.pairplot(db_main_column, hue=db_main_column_name, diag_kind = "hist", height=2.5)

# Create a directory to save the graphs if it doesn't already exist
if not os.path.exists("graphs/" + database_name):
    os.mkdir("graphs/" + database_name)

plt.savefig("graphs/"+ database_name + "scatter_plot_matrix_nohue.png", format="png", dpi=300, bbox_inches="tight")
plt.show()