SummaryPlot: Enhancement & Code Optimization #50

Open
Tanvi-Jain01 opened this issue Jul 9, 2023 · 0 comments


Tanvi-Jain01 commented Jul 9, 2023

@nipunbatra , @patel-zeel

Description:

I went through the full source code of summaryPlot and think it can be made more general and optimized for better performance.

Code:

import xarray as xr
import numpy as np
import pandas as pd
import geopandas as gpd
import plotly.express as px
import matplotlib.pyplot as plt

np.random.seed(42)  

start_date = pd.to_datetime('2022-01-01')
end_date = pd.to_datetime('2022-12-31')

dates = pd.date_range(start_date, end_date)

pm25_values = np.random.rand(365)  # Generate 365 random values
o3_values = np.random.rand(365) 
nox_values = np.random.rand(365)
co_values = np.random.rand(365)
pm10_values = np.random.rand(365)

# Candidate pollutants: "pm10", "pm25", "sox", "co", "o3", "nox", "pb", "nh3"
df = pd.DataFrame({
    'date': dates,
    'pm25': pm25_values,
    'o3': o3_values,
    'nox': nox_values,
    'co': co_values,
    'pm10': pm10_values
})

df['date'] = df['date'].dt.strftime('%Y-%m-%d')  # Convert date format to 'YYYY-MM-DD'

print(df)

from vayu.summaryPlot import summaryPlot

print(df.columns)
summaryPlot(df)

Error:

KeyError                                  Traceback (most recent call last)
File ~\anaconda3\lib\site-packages\pandas\core\indexes\base.py:3802, in Index.get_loc(self, key, method, tolerance)
   3801 try:
-> 3802     return self._engine.get_loc(casted_key)
   3803 except KeyError as err:

File ~\anaconda3\lib\site-packages\pandas\_libs\index.pyx:138, in pandas._libs.index.IndexEngine.get_loc()

File ~\anaconda3\lib\site-packages\pandas\_libs\index.pyx:165, in pandas._libs.index.IndexEngine.get_loc()

File pandas\_libs\hashtable_class_helper.pxi:5745, in pandas._libs.hashtable.PyObjectHashTable.get_item()

File pandas\_libs\hashtable_class_helper.pxi:5753, in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 'so2'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Cell In[4], line 15
     12 #benzene.dropna(inplace=True)
     13 benzene.fillna(0, inplace=True)
---> 15 summaryPlot(benzene)

File ~\anaconda3\lib\site-packages\vayu\summaryPlot.py:126, in summaryPlot(df)
    124 plt.subplot(9, 3, sub)
    125 sub = sub + 1
--> 126 a = df_all[dataPoints[i]].plot.line(color="gold")
    127 a.axes.get_xaxis().set_visible(False)
    128 a.yaxis.set_label_position("left")

File ~\anaconda3\lib\site-packages\pandas\core\frame.py:3807, in DataFrame.__getitem__(self, key)
   3805 if self.columns.nlevels > 1:
   3806     return self._getitem_multilevel(key)
-> 3807 indexer = self.columns.get_loc(key)
   3808 if is_integer(indexer):
   3809     indexer = [indexer]

File ~\anaconda3\lib\site-packages\pandas\core\indexes\base.py:3804, in Index.get_loc(self, key, method, tolerance)
   3802     return self._engine.get_loc(casted_key)
   3803 except KeyError as err:
-> 3804     raise KeyError(key) from err
   3805 except TypeError:
   3806     # If we have a listlike key, _check_indexing_error will raise
   3807     #  InvalidIndexError. Otherwise we fall through and re-raise
   3808     #  the TypeError.
   3809     self._check_indexing_error(key)

KeyError: 'so2'

OUTPUT:

[Image: summaryplot error]

Explanation:

The code expects an so2 column to be present in the DataFrame, which is not always the case. If so2 or any other expected column is missing, the function should still run without raising an error.

ISSUE-1: Making the function general

Source Code:

df_all = df

This line assigns the value of the variable df to a new variable df_all. It creates a new reference to the same DataFrame object.
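
A minimal sketch of what "a new reference to the same DataFrame object" means in practice. The tiny DataFrame and the variable names here are made up for illustration; the point is only that plain assignment does not copy, while `DataFrame.copy()` does.

```python
import pandas as pd

# Illustrative data, not taken from the vayu source
df = pd.DataFrame({"pm25": [10.0, 20.0]})

df_alias = df        # plain assignment: a second name for the same object
df_copy = df.copy()  # independent copy of the data

# Mutating through the alias also changes what `df` sees
df_alias.loc[0, "pm25"] = 99.0

alias_is_same = df_alias is df            # both names point at one object
original_value = df.loc[0, "pm25"]        # changed via the alias
copy_value = df_copy.loc[0, "pm25"]       # unaffected by the mutation

print(alias_is_same, original_value, copy_value)
```

So any in-place change made to df_all inside the function is also visible to the caller through df, unless a copy is taken first.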

dataPoints = ["pm25", "co", "so2", "pm10", "o3", "no2", "nox", "wd", "ws"]

This line creates a list called dataPoints containing strings representing different data points or columns in the DataFrame. Each string corresponds to a specific column name in the DataFrame.

for column in df_all.columns:

This line initiates a loop that iterates over each column name in the df_all DataFrame. The loop assigns each column name to the variable column in each iteration.

a = df_all[dataPoints[i]].plot.line(color="gold")

This line selects a column from df_all by looking up the name at index i in the dataPoints list.
Because the lookup goes through dataPoints, every pollutant listed there must exist in the DataFrame; if any of them is missing from the user's DataFrame, a KeyError is raised.

Hence this line should be written as follows:

Solution:

 
a = df_all[df_all.columns[i]].plot.line(color="gold") 

The above line uses only the columns that are actually present in the user's DataFrame and plots those, so no error is raised.
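
Another way to get the same effect, sketched here as an assumption rather than as the change proposed above, is to keep the hard-coded list and filter it against the columns that actually exist. The sample DataFrame and the names `available` and `missing` are hypothetical.

```python
import pandas as pd

# Hard-coded list from the original source
data_points = ["pm25", "co", "so2", "pm10", "o3", "no2", "nox", "wd", "ws"]

# Hypothetical user data that lacks several of the expected columns
df_all = pd.DataFrame({
    "pm25": [12.0, 30.5],
    "co": [0.4, 0.9],
    "o3": [40.0, 55.0],
})

# Keep only the expected names that are really present, preserving list order
available = [name for name in data_points if name in df_all.columns]
missing = [name for name in data_points if name not in df_all.columns]

for name in available:
    series = df_all[name]  # safe: the column is known to exist
    print(name, series.mean())

print("skipped:", missing)
```

Filtering keeps the plotting order fixed by the list, whereas indexing df_all.columns directly follows whatever order the user's DataFrame has.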

ISSUE-2: Code Optimization

vayu/vayu/summaryPlot.py

Lines 22 to 39 in ef99aef

pm10_s = 0
pm10_m = 0
pm10_h = 0
pm25_s = 0
pm25_m = 0
pm25_h = 0
so2_s = 0
so2_m = 0
so2_h = 0
co_s = 0
co_m = 0
co_h = 0
o3_s = 0
o3_m = 0
o3_h = 0
no2_s = 0
no2_m = 0
no2_h = 0

Scalability: If we need to work with additional pollutants or categories in the future, we would need to manually add new variables to the code. This can be cumbersome and error-prone, especially as the number of pollutants and categories increases.

Solution:

Using a list can help decrease the amount of code and repetitiveness.

pollutants = ["pm10", "pm25", "sox", "co", "o3", "nox", "pb", "nh3"]
categories = ["s", "m", "h"]
.....
.....

ISSUE-3: Adding other pollutants that contribute to the AQI

https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8137507/

The above article states:
"Details of air quality index along with range of concentrations of criteria pollutant PM2.5, PM10, NOX, SOX, O3, CO, Pb and NH3."
This means the above pollutants contribute to the AQI.
Hence I added Pb and NH3, which were not present before, and removed wd and ws, which do not contribute to the AQI.

pollutants = ["pm10", "pm25", "sox", "co", "o3", "nox", "pb", "nh3"]
categories = ["s", "m", "h"]

....
........
elif pollutant == "pb":
    if data <= 1:
        counts[pollutant]["s"] += 1
    elif data <= 2:
        counts[pollutant]["m"] += 1
    else:
        counts[pollutant]["h"] += 1
elif pollutant == "nh3":
    if data <= 400:
        counts[pollutant]["s"] += 1
    elif data <= 800:
        counts[pollutant]["m"] += 1
    else:
        counts[pollutant]["h"] += 1
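
As a possible further step (a sketch only, not part of the code proposed in this issue), the per-pollutant elif branches could be driven by a lookup table, so that adding a pollutant becomes a one-line change. The names `THRESHOLDS`, `classify`, and `count_categories` are hypothetical; the pb and nh3 limits are the ones given above, and both boundaries are treated as inclusive.

```python
import pandas as pd

# (safe_limit, moderate_limit) per pollutant, using the pb and nh3 values above.
# Hypothetical table: other pollutants would be added as further entries.
THRESHOLDS = {
    "pb": (1, 2),
    "nh3": (400, 800),
}

def classify(pollutant, value):
    """Return "s", "m", or "h" for a single reading (inclusive limits)."""
    safe_limit, moderate_limit = THRESHOLDS[pollutant]
    if value <= safe_limit:
        return "s"
    if value <= moderate_limit:
        return "m"
    return "h"

def count_categories(df):
    """Tally categories for every known pollutant present in df."""
    counts = {}
    for pollutant in THRESHOLDS:
        if pollutant not in df.columns:
            continue  # skip pollutants the user did not supply
        tally = {"s": 0, "m": 0, "h": 0}
        for value in df[pollutant].dropna():
            tally[classify(pollutant, value)] += 1
        counts[pollutant] = tally
    return counts

# Made-up readings for illustration
example = pd.DataFrame({
    "pb": [0.5, 1.0, 1.5, 3.0],
    "nh3": [100, 500, 900, 1200],
})
result = count_categories(example)
print(result)
```

Note that this sketch drops missing readings before counting; pollutants that use strict `<` comparisons would need a flag or a separate table entry to keep their exact boundary behaviour.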

Now, this is what the whole code looks like:

Improved Code:

def summaryPlot(df):
    import pandas as pd
    import matplotlib.pyplot as plt

    # Initialize variables
    pollutants = ["pm10", "pm25", "sox", "co", "o3", "nox", "pb", "nh3"]
    categories = ["s", "m", "h"]

    counts = {pollutant: {category: 0 for category in categories} for pollutant in pollutants}

    
    # Work on a copy so the caller's DataFrame is not modified
    df = df.copy()
    df.index = pd.to_datetime(df["date"])
    df = df.drop("date", axis=1)
    # Forward-fill missing values
    df_all = df.ffill()
    #print(df_all.columns)

    # Calculate counts for each pollutant category
    for pollutant in pollutants:
        if pollutant in df_all.columns:
            column_data = df_all[pollutant]
            #print(df_all)
            for _, data in column_data.items():
                if pollutant in ["pm10", "pm25"]:
                    if data < 100:
                        counts[pollutant]["s"] += 1
                    elif data < 250:
                        counts[pollutant]["m"] += 1
                    else:
                        counts[pollutant]["h"] += 1
                elif pollutant == "co":
                    if data < 2:
                        counts[pollutant]["s"] += 1
                    elif data < 10:
                        counts[pollutant]["m"] += 1
                    else:
                        counts[pollutant]["h"] += 1
                elif pollutant == "sox":
                    if data <= 80:
                        counts[pollutant]["s"] += 1
                    elif data <= 380:
                        counts[pollutant]["m"] += 1
                    else:
                        counts[pollutant]["h"] += 1
                elif pollutant == "o3":
                    if data < 100:
                        counts[pollutant]["s"] += 1
                    elif data < 168:
                        counts[pollutant]["m"] += 1
                    else:
                        counts[pollutant]["h"] += 1
                elif pollutant == "nox":
                    if data < 80:
                        counts[pollutant]["s"] += 1
                    elif data < 180:
                        counts[pollutant]["m"] += 1
                    else:
                        counts[pollutant]["h"] += 1
                elif pollutant == "pb":
                    if data <= 1:
                        counts[pollutant]["s"] += 1
                    elif data <= 2:
                        counts[pollutant]["m"] += 1
                    else:
                        counts[pollutant]["h"] += 1
                elif pollutant == "nh3":
                    if data <= 400:
                        counts[pollutant]["s"] += 1
                    elif data <= 800:
                        counts[pollutant]["m"] += 1
                    else:
                        counts[pollutant]["h"] += 1
         
                

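    # Plot line, histogram, and pie charts for each known pollutant in the data
    plot_columns = [col for col in df_all.columns if col in counts]
    # squeeze=False keeps axes two-dimensional even for a single pollutant
    fig, axes = plt.subplots(len(plot_columns), 3, figsize=(25, 25), squeeze=False)

    for i, pollutant in enumerate(plot_columns):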
        ax_line = axes[i, 0]
        ax_hist = axes[i, 1]
        ax_pie = axes[i, 2]

        df_all[pollutant].plot.line(ax=ax_line, color="gold")
        ax_line.axes.get_xaxis().set_visible(False)
        ax_line.yaxis.set_label_position("left")
        ax_line.set_ylabel(pollutant, fontsize=30, bbox=dict(facecolor="whitesmoke"))

        ax_hist.hist(df_all[pollutant], bins=50, color="green")

        labels = ["Safe", "Moderate", "High"]
        sizes = [counts[pollutant][category] for category in categories]
        explode = [0, 0, 1]

        ax_pie.pie(sizes, explode=explode, labels=labels, autopct="%1.1f%%", shadow=False, startangle=90)
        ax_pie.axis("equal")

        ax_pie.set_xlabel("Statistics")
      
        print(f"{pollutant}\nmin = {df_all[pollutant].min():.2f}\nmax = {df_all[pollutant].max():.2f}\nmissing = {df_all[pollutant].isna().sum()}\nmean = {df_all[pollutant].mean():.2f}\nmedian = {df_all[pollutant].median():.2f}\n95th percentile = {df_all[pollutant].quantile(0.95):.2f}\n")

    plt.savefig("summaryPlot.png", dpi=300, format="png")
    plt.show()
    print("Your plot has also been saved")
    plt.close()

summaryPlot(df)

NOTE: I'm also adding plt.savefig("summaryPlot.png", dpi=300, format="png") to save the figure.

OUTPUT:

[Image: summaryplot]

[Image: summaryPlot]

@Tanvi-Jain01 Tanvi-Jain01 changed the title SummaryPlot: Code Optimization SummaryPlot: Enhancement & Code Optimization Jul 9, 2023