diff --git a/utils.py b/utils.py index 56d7ff9..e127f15 100644 --- a/utils.py +++ b/utils.py @@ -3,6 +3,7 @@ def color_image(image, num_classes=20): import matplotlib as mpl + import matplotlib.cm norm = mpl.colors.Normalize(vmin=0., vmax=num_classes) mycm = mpl.cm.get_cmap('Set1') return mycm(norm(image))