Skip to content

Commit

Permalink
Merge pull request #189 from DHI/fix_density_bins_bug
Browse files Browse the repository at this point in the history
fix small density-bug and added test
  • Loading branch information
daniel-caichac-DHI committed May 11, 2023
2 parents 605b4f8 + 06215d2 commit 3138331
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 37 deletions.
103 changes: 66 additions & 37 deletions fmskill/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
# register_option("plot.scatter.table.show", False, validator=settings.is_bool)
register_option("plot.scatter.legend.fontsize", 12, validator=settings.is_positive)


def scatter(
x,
y,
Expand Down Expand Up @@ -133,7 +134,7 @@ def scatter(
user default units to override default units, eg 'metre', by default None
kwargs
"""
if show_hist == None and show_density == None:
if show_hist is None and show_density is None:
# Default: points density
show_density = True

Expand All @@ -152,14 +153,14 @@ def scatter(

x_sample = x
y_sample = y
sample_warning=False
sample_warning = False
if show_points is None:
# If nothing given, and more than 50k points, 50k sample will be shown
if len(x) < 5e4:
show_points = True
else:
show_points = 50000
sample_warning=True
sample_warning = True
if type(show_points) == float:
if show_points < 0 or show_points > 1:
raise ValueError(" `show_points` fraction must be in [0,1]")
Expand All @@ -170,31 +171,30 @@ def scatter(
)
x_sample = x[ran_index]
y_sample = y[ran_index]
if len(x_sample)<len(x):
sample_warning=True
if len(x_sample) < len(x):
sample_warning = True
# if show_points is an int
elif type(show_points) == int:
np.random.seed(20)
ran_index = np.random.choice(range(len(x)), show_points, replace=False)
x_sample = x[ran_index]
y_sample = y[ran_index]
if len(x_sample)<len(x):
sample_warning=True
if len(x_sample) < len(x):
sample_warning = True
elif type(show_points) == bool:
pass
else:
raise TypeError(" `show_points` must be either bool, int or float")
if sample_warning:
warnings.warn(
message=f'Showing only {len(x_sample)} points in plot. If all scatter points wanted in plot, use `show_points=True`',
stacklevel=2)
message=f"Showing only {len(x_sample)} points in plot. If all scatter points wanted in plot, use `show_points=True`",
stacklevel=2,
)
xmin, xmax = x.min(), x.max()
ymin, ymax = y.min(), y.max()
xymin = min([xmin, ymin])
xymax = max([xmax, ymax])



if quantiles is None:
if len(x) >= 3000:
quantiles = 1000
Expand Down Expand Up @@ -224,10 +224,10 @@ def scatter(
# Remove previous piece of code when nbins and bin_size are deprecated.

if xlim is None:
xlim = [xymin - binsize, xymax+ binsize]
xlim = [xymin - binsize, xymax + binsize]

if ylim is None:
ylim = [xymin - binsize, xymax+ binsize]
ylim = [xymin - binsize, xymax + binsize]

if type(quantiles) == int:
xq = np.quantile(x, q=np.linspace(0, 1, num=quantiles))
Expand All @@ -236,11 +236,11 @@ def scatter(
# if not an int nor None, it must be a squence of floats
xq = np.quantile(x, q=quantiles)
yq = np.quantile(y, q=quantiles)
x_trend= np.array([xlim[0],xlim[1]])
x_trend = np.array([xlim[0], xlim[1]])

if show_hist:
# if histogram is wanted (explicit non-default flag) then density is off
if show_density == True:
if show_density is True:
raise TypeError(
"if `show_hist=True` then `show_density` must be either `False` or `None`"
)
Expand All @@ -251,18 +251,17 @@ def scatter(
"if `show_density=True` then bins must be either float or int"
)
# if point density is wanted, then 2D histogram is not shown
if show_hist == True:
if show_hist is True:
raise TypeError(
"if `show_density=True` then `show_hist` must be either `False` or `None`"
)
# calculate density data
z = _scatter_density(x_sample, y_sample, binsize=binsize)
z = __scatter_density(x_sample, y_sample, binsize=binsize)
idx = z.argsort()
# Sort data by colormaps
x_sample, y_sample, z = x_sample[idx], y_sample[idx], z[idx]
x_sample, y_sample, z = x_sample[idx], y_sample[idx], z[idx]
# scale Z by sample size
z = z * len(x) / len(x_sample)

z = z * len(x) / len(x_sample)

# linear fit
slope, intercept = _linear_regression(obs=x, model=y, reg_method=reg_method)
Expand All @@ -274,8 +273,8 @@ def scatter(
reglabel = f"Fit: y={slope:.2f}x{sign}{intercept:.2f}"

if backend == "matplotlib":
_,ax=plt.subplots(figsize=figsize)
#plt.figure(figsize=figsize)
_, ax = plt.subplots(figsize=figsize)
# plt.figure(figsize=figsize)
plt.plot(
[xlim[0], xlim[1]],
[xlim[0], xlim[1]],
Expand Down Expand Up @@ -328,7 +327,7 @@ def scatter(
plt.xlim([xlim[0], xlim[1]])
plt.ylim([ylim[0], ylim[1]])
plt.minorticks_on()
plt.grid(which="both", axis="both", linewidth="0.2", color="k",alpha=0.6)
plt.grid(which="both", axis="both", linewidth="0.2", color="k", alpha=0.6)
max_cbar = None
if show_hist or (show_density and show_points):
cbar = plt.colorbar(fraction=0.046, pad=0.04)
Expand All @@ -338,7 +337,7 @@ def scatter(

plt.title(title)
# Add skill table
if skill_df != None:
if skill_df is not None:
_plot_summary_table(skill_df, units, max_cbar=max_cbar)
return ax

Expand Down Expand Up @@ -515,8 +514,8 @@ def taylor_diagram(
fig.suptitle(title, size="x-large")


def _scatter_density(x, y, binsize: float = 0.1, method: str = "linear"):
"""Interpolates scatter data on a 2D histogram (gridded) based on data density.
def __hist2d(x, y, binsize):
"""Calculates 2D histogram (gridded) of data.
Parameters
----------
Expand All @@ -525,33 +524,63 @@ def _scatter_density(x, y, binsize: float = 0.1, method: str = "linear"):
y: np.array
Y values e.g observation values, must be same length as x
binsize: float, optional
2D grid resolution, by default = 0.1
method: str, optional
Scipy griddata interpolation method, by default 'linear'
2D histogram (bin) resolution, by default = 0.1
Returns
----------
Z_grid: np.array
Array with the colors based on histogram density
histodata: np.array
2D-histogram data
cxy: np.array
Center points of the histogram bins
exy: np.array
Edges of the histogram bins
"""

# Make linear-grid for interpolation
minxy = min(min(x), min(y))-binsize/2
maxxy = max(max(x), max(y))+binsize/2
minxy = min(min(x), min(y)) - binsize
maxxy = max(max(x), max(y)) + binsize
# Center points of the bins
cxy = np.arange(minxy, maxxy, binsize)
# Edges of the bins
exy = np.arange(minxy - binsize * 0.5, maxxy + binsize * 0.5, binsize)
if exy[-1] <= cxy[-1]:
# sometimes, given the bin size, the edges array ended before (left side) of the bins-center array
# in such case, and extra half-bin is added at the end
exy = np.arange(minxy - binsize * 0.5, maxxy + binsize, binsize)

# Calculate 2D histogram
histodata, exh, eyh = np.histogram2d(x, y, [exy, exy])
histodata, _, _ = np.histogram2d(x, y, [exy, exy])

# Histogram values
hist = []
for j in range(len(cxy)):
for i in range(len(cxy)):
hist.append(histodata[i, j])

return hist, cxy


def __scatter_density(x, y, binsize: float = 0.1, method: str = "linear"):
"""Interpolates scatter data on a 2D histogram (gridded) based on data density.
Parameters
----------
x: np.array
X values e.g model values, must be same length as y
y: np.array
Y values e.g observation values, must be same length as x
binsize: float, optional
2D histogram (bin) resolution, by default = 0.1
method: str, optional
Scipy griddata interpolation method, by default 'linear'
Returns
----------
Z_grid: np.array
Array with the colors based on histogram density
"""

hist, cxy = __hist2d(x, y, binsize)

# Grid-data
xg, yg = np.meshgrid(cxy, cxy)
xg = xg.ravel()
Expand Down Expand Up @@ -591,7 +620,7 @@ def _plot_summary_table(skill_df, units, max_cbar):

text_ = "\n".join(lines)

if max_cbar == None:
if max_cbar is None:
x = 0.93
elif max_cbar < 1e3:
x = 0.99
Expand Down
2 changes: 2 additions & 0 deletions tests/test_multivariable_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def test_mv_mm_scatter(cc):
cc.scatter(
model="SW_1", variable="Wind_speed", observation="F16_wind", skill_table=True
)
cc.scatter(model="SW_1", variable="Wind_speed", show_density=True,bins=19)
cc.scatter(model="SW_1", variable="Wind_speed", show_density=True,bins=21)
assert True
plt.close("all")

Expand Down

0 comments on commit 3138331

Please sign in to comment.