Skip to content

Commit

Permalink
Upgrade matplotlib and for plot-file-logger set agg as backend
Browse files Browse the repository at this point in the history
  • Loading branch information
dzimmerer committed Oct 23, 2020
1 parent 01ecc68 commit f5df496
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
colorlover>=0.2.1
Flask>=0.12.2
graphviz>=0.8.4
matplotlib>=2.2.2
matplotlib>=3.3.2
numpy>=1.14.5
seaborn>=0.8.1
scipy>=0.19.1
Expand Down
9 changes: 7 additions & 2 deletions trixi/logger/file/numpyplotfilelogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@ class NumpyPlotFileLogger(NumpySeabornPlotLogger):
"""

def __init__(self, img_dir, plot_dir, **kwargs):
def __init__(self, img_dir, plot_dir, switch_backend=True, **kwargs):
"""
Initializes a numpy plot file logger to plot images and plots into an image and plot directory
Args:
img_dir: The directory to store images in
plot_dir: The directory to store plots in
switch_backend: If true switchtes backend to agg
"""
super(NumpyPlotFileLogger, self).__init__(**kwargs)
self.img_dir = img_dir
self.plot_dir = plot_dir
if switch_backend:
import matplotlib.pyplot as plt

plt.switch_backend("Agg")

@convert_params
def show_image(self, image, name, file_format=".png", *args, **kwargs):
Expand Down Expand Up @@ -66,7 +71,7 @@ def show_value(self, value, name, counter=None, tag=None, file_format=".png", *a
else:
outname = os.path.join(self.plot_dir, tag) + file_format
os.makedirs(os.path.dirname(outname), exist_ok=True)
threaded(savefig_and_close)(figure, outname)
savefig_and_close(figure, outname)

@convert_params
def show_barplot(self, array, name, file_format=".png", *args, **kwargs):
Expand Down
10 changes: 6 additions & 4 deletions trixi/logger/plt/numpyseabornplotlogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def show_boxplot(self, array, name, show=True, *args, **kwargs):

handles, _ = ax.get_legend_handles_labels()
try:
legend = kwargs['opts']['legend']
legend = kwargs["opts"]["legend"]
ax.legend(handles, legend)
except KeyError: # if no legend is defined
except KeyError: # if no legend is defined
pass
if show:
plt.show(block=False)
Expand Down Expand Up @@ -202,8 +202,10 @@ def show_scatterplot(self, array, name=None, show=True, *args, **kwargs):
"""

if not isinstance(array, np.ndarray):
raise TypeError("Array must be numpy arrays (this class is called NUMPY seaborn logger, and seaborn"
" can only handle numpy arrays -.- .__. )")
raise TypeError(
"Array must be numpy arrays (this class is called NUMPY seaborn logger, and seaborn"
" can only handle numpy arrays -.- .__. )"
)
if len(array.shape) != 2:
raise ValueError("Array must be 2D for scatterplot")
if array.shape[1] != 2:
Expand Down

0 comments on commit f5df496

Please sign in to comment.