Skip to content

Commit

Permalink
feat: Add optional input parameter to change data points shape by col…
Browse files Browse the repository at this point in the history
…or column
  • Loading branch information
iamMoid committed Jan 19, 2022
1 parent cbd3b0a commit d83f1a2
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions src/magmaviz/scatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pandas.api.types import is_numeric_dtype
import re

def scatterplot(df, x, y, c="", t="", xtitle="", ytitle="", ctitle="", xzero=False, yzero=False):
def scatterplot(df, x, y, c="", t="", xtitle="", ytitle="", ctitle="", xzero=False, yzero=False, shapes=True):
"""Plot a scatterplot on the dataframe with the magma color scheme.
Parameters
Expand Down Expand Up @@ -36,6 +36,9 @@ def scatterplot(df, x, y, c="", t="", xtitle="", ytitle="", ctitle="", xzero=Fal
yzero : boolean
Scale the y-axis to start from 0 by specifying True
Default value is set to False
shapes : boolean
Assigns the color column to shapes attribute of the plot if True
Default value is set to False
Returns
-------
Expand Down Expand Up @@ -88,6 +91,10 @@ def scatterplot(df, x, y, c="", t="", xtitle="", ytitle="", ctitle="", xzero=Fal
if not isinstance(yzero, bool):
raise TypeError("Invalid value passed to 'ctitle' variable: Assign boolean True to begin y axis from zero.")

# check if shapes is a boolean
if not isinstance(shapes, bool):
raise TypeError("Invalid value passed to 'shapes' variable: Assign boolean True to show different shapes for each color category.")

# check if column name assigned to x-axis is present in the dataframe
assert x in list(
df.columns
Expand Down Expand Up @@ -137,13 +144,24 @@ def scatterplot(df, x, y, c="", t="", xtitle="", ytitle="", ctitle="", xzero=Fal
alt.Y(y, title=ytitle.capitalize(), scale=alt.Scale(zero=yzero))
)
else:
plot = alt.Chart(
data=df, title=t
).mark_point(
).encode(
alt.X(x, title=xtitle.capitalize(), scale=alt.Scale(zero=xzero)),
alt.Y(y, title=ytitle.capitalize(), scale=alt.Scale(zero=yzero)),
alt.Color(c, title=ctitle.capitalize(), scale=alt.Scale(scheme="magma"))
)
if shapes is False:
plot = alt.Chart(
data=df, title=t
).mark_point(
).encode(
alt.X(x, title=xtitle.capitalize(), scale=alt.Scale(zero=xzero)),
alt.Y(y, title=ytitle.capitalize(), scale=alt.Scale(zero=yzero)),
alt.Color(c, title=ctitle.capitalize(), scale=alt.Scale(scheme="magma"))
)
else:
plot = alt.Chart(
data=df, title=t
).mark_point(
).encode(
alt.X(x, title=xtitle.capitalize(), scale=alt.Scale(zero=xzero)),
alt.Y(y, title=ytitle.capitalize(), scale=alt.Scale(zero=yzero)),
alt.Color(c, title=ctitle.capitalize(), scale=alt.Scale(scheme="magma")),
shape=c
)

return plot

0 comments on commit d83f1a2

Please sign in to comment.