
scatterPlot Enhancement #43

Open
Tanvi-Jain01 opened this issue Jun 21, 2023 · 1 comment
Labels
priority1 Priority Level 1


Tanvi-Jain01 commented Jun 21, 2023

scatterPlot: Not generalized to handle any pollutant

I was trying to plot a bivariate scatter plot of pm25 and o3, but the function asked for pm10 and nox, which are not present in my DataFrame. The source code appears to be specific to certain pollutants.
The function should be generalized so that it can handle any attribute or pollutant column in the DataFrame.

Code:

import numpy as np
import pandas as pd
np.random.seed(42)  

start_date = pd.to_datetime('2022-01-01')
end_date = pd.to_datetime('2022-12-31')

dates = pd.date_range(start_date, end_date)

pm25_values = np.random.rand(365)  # Generate 365 random values
o3_values = np.random.rand(365) 
ws_values = np.random.rand(365)
wd_values = np.random.rand(365)

df = pd.DataFrame({
    'date': dates,
    'pm25': pm25_values,
    'o3':o3_values,
    'ws': ws_values,
    'wd': wd_values
})

df['date'] = df['date'].dt.strftime('%Y-%m-%d')  # Convert date format to 'YYYY-MM-DD'
print(df)
from vayu.scatterPlot import scatterPlot

pollutants=['o3','pm25']

for pollutant in pollutants:
    scatterPlot(df, 'o3', pollutant)

Error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[6], line 6
      3 pollutants=['o3','pm25']
      5 for pollutant in pollutants:
----> 6     scatterPlot(df, 'o3', pollutant)

File ~\anaconda3\lib\site-packages\vayu\scatterPlot.py:22, in scatterPlot(df, x, y, **kwargs)
     19 import matplotlib.cm as cm
     20 from math import pi
---> 22 pm10 = df.pm10
     23 o3 = df.o3
     24 ws = df.ws

File ~\anaconda3\lib\site-packages\pandas\core\generic.py:5902, in NDFrame.__getattr__(self, name)
   5895 if (
   5896     name not in self._internal_names_set
   5897     and name not in self._metadata
   5898     and name not in self._accessors
   5899     and self._info_axis._can_hold_identifiers_and_holds_name(name)
   5900 ):
   5901     return self[name]
-> 5902 return object.__getattribute__(self, name)

AttributeError: 'DataFrame' object has no attribute 'pm10'

Source Code:

vayu/vayu/scatterPlot.py

Lines 22 to 27 in ef99aef

pm10 = df.pm10
o3 = df.o3
ws = df.ws
wd = df.wd
nox = df.nox
no2 = df.no2

vayu/vayu/scatterPlot.py

Lines 42 to 57 in ef99aef

if x == "nox":
    x = nox
elif x == "no2":
    x = no2
elif x == "o3":
    x = o3
elif x == "pm10":
    x = pm10
if y == "nox":
    y = nox
elif y == "no2":
    y = no2
elif y == "o3":
    y = o3
elif y == "pm10":
    y = pm10
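Besides reading a fixed list of columns, the snippets above access them as attributes (`df.pm10`). The sketch below (column names invented for illustration) shows two properties of pandas attribute access that make it a poor fit for a function meant to accept arbitrary column names: a missing column raises `AttributeError`, and a column whose name matches a DataFrame method (here `count`) returns the method rather than the column. Bracket indexing with the caller's column name, as in `df[x]`, looks up the label directly.

```python
import pandas as pd

# One ordinary column and one column whose name collides with a DataFrame method.
df = pd.DataFrame({"o3": [0.1, 0.2, 0.3], "count": [5, 6, 7]})

# Attribute access raises AttributeError for a column that does not exist,
# so hasattr() reports False.
print(hasattr(df, "pm10"))   # False

# For a name that is already a DataFrame attribute, attribute access returns
# that attribute (the DataFrame.count method), not the column.
print(callable(df.count))    # True

# Bracket indexing looks up the column label itself.
print(df["count"].tolist())  # [5, 6, 7]
print(df["o3"].tolist())     # [0.1, 0.2, 0.3]
```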

Reason:

The code above is hard-coded for a fixed set of columns: it reads pm10, o3, ws, wd, nox and no2 unconditionally, so it fails on any DataFrame that lacks one of them, even when those columns are not the ones being plotted, and its if/elif chain only maps x and y to nox, no2, o3 and pm10. It should instead work with any pollutant present in the DataFrame, such as pm25, without each one having to be listed explicitly in the source code.

Solution:

As a solution, we can index the DataFrame directly with the column names passed in:

sns.jointplot(x=df[x].values, y=df[y].values, kind="hex") 

In this generalized form, df[x].values and df[y].values extract the columns named by the caller, so the plot can be created for any pair of columns present in the DataFrame, without a hard-coded list of pollutants.
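A minimal sketch of the generalized column selection, assuming the function needs only the two columns the caller names. `select_xy` is a hypothetical helper, not part of vayu; the missing-column check is an assumed addition that replaces the `AttributeError` shown in the traceback with a message listing the available columns. The two arrays it returns could be passed to `sns.jointplot(x=..., y=..., kind="hex")` as in the proposed solution.

```python
import numpy as np
import pandas as pd


def select_xy(df, x, y):
    """Return the values of columns x and y as NumPy arrays.

    Hypothetical helper (not part of vayu): replaces the hard-coded
    if/elif chain by indexing df with whatever column names the caller
    passes, and checks only those two columns instead of a fixed list.
    """
    missing = [col for col in (x, y) if col not in df.columns]
    if missing:
        raise ValueError(
            f"Column(s) {missing} not found; available columns: {list(df.columns)}"
        )
    return df[x].to_numpy(), df[y].to_numpy()


df = pd.DataFrame({"o3": [0.1, 0.2, 0.3], "pm25": [12.0, 15.5, 9.8]})

# Works without pm10, nox or no2 being present.
x_vals, y_vals = select_xy(df, "o3", "pm25")
print(x_vals, y_vals)

# A missing column produces a descriptive error instead of AttributeError.
try:
    select_xy(df, "o3", "pm10")
except ValueError as err:
    print(err)
```

Because only `x` and `y` are checked, a DataFrame containing just `o3` and `pm25` is accepted.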

Example:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def scatterPlot(df, x, y, **kwargs):
    # Wind vector components from wind speed (ws) and direction in degrees (wd).
    # Computed locally so the caller's DataFrame is not modified, and so the
    # two components never share a column name (even when x == y).
    u = df['ws'] * np.sin(np.radians(df['wd']))
    v = df['ws'] * np.cos(np.radians(df['wd']))

    fig, ax = plt.subplots(figsize=(8, 8), dpi=80)
    ax.scatter(u, v, alpha=0.35)
    ax.set_xlabel("ws * sin(wd)")
    ax.set_ylabel("ws * cos(wd)")
    plt.show()

    # Hexbin joint plot of any two columns present in df.
    g = sns.jointplot(x=df[x].values, y=df[y].values, kind="hex")
    g.set_axis_labels(x, y)  # label the joint axes, not the last active subplot
    plt.show()


pollutants = ['o3', 'pm25']

for pollutant in pollutants:
    scatterPlot(df, 'o3', pollutant)
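The first plot in the example scatters two wind components derived from `ws` and `wd`. A sketch of that arithmetic, assuming `wd` is a direction in degrees (which the `pi / 180.0` conversion implies); `wind_components` is a hypothetical helper name, and the sample values are invented:

```python
import numpy as np
import pandas as pd


def wind_components(ws, wd_deg):
    """Split wind speed and direction (degrees) into two components.

    Hypothetical helper: the first component is ws * sin(wd) and the
    second is ws * cos(wd), with wd converted from degrees to radians.
    """
    rad = np.radians(wd_deg)
    return ws * np.sin(rad), ws * np.cos(rad)


df = pd.DataFrame({"ws": [2.0, 3.0, 1.0], "wd": [0.0, 90.0, 180.0]})
u, v = wind_components(df["ws"], df["wd"])
print(u.round(6).tolist())  # [0.0, 3.0, 0.0]
print(v.round(6).tolist())  # [2.0, 0.0, -1.0]
```

Returning the components as new Series leaves the caller's DataFrame unchanged.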

Output:

[Image: scatterplot simple]
@patel-zeel
Member

Looks good to me. Go for a PR.
