Skip to content

Commit

Permalink
NUKE PLOT UTILS!!! (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-l-kong committed Nov 4, 2020
1 parent f4390fb commit c48700b
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 277 deletions.
195 changes: 0 additions & 195 deletions ark/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from skimage.exposure import rescale_intensity


# plotting functions

def plot_overlay(predicted_contour, plotting_tif, alternate_contour=None, path=None):
"""Take in labeled contour data, along with optional mibi tif and second contour,
and overlay them for comparison"
Expand Down Expand Up @@ -109,196 +107,3 @@ def plot_overlay(predicted_contour, plotting_tif, alternate_contour=None, path=N
io.imsave(path, rescaled)
else:
io.imshow(rescaled)


def randomize_labels(label_map):
"""Takes in a labeled matrix and swaps the integers around
so that color gradient has better contrast
Args:
label_map (numpy.ndarray): labeled TIF with each object assigned a unique value
Returns:
numpy.ndarray:
2D array corresponding to a labeled TIF with permuted object labels"""

unique_vals = np.unique(label_map)[1:]
pos_1 = np.random.choice(unique_vals, size=len(unique_vals))
pos_2 = np.random.choice(unique_vals, size=len(unique_vals))

for i in range(len(pos_1)):
swap_1 = pos_1[i]
swap_2 = pos_2[i]
swap_1_mask = label_map == swap_1
swap_2_mask = label_map == swap_2
label_map[swap_1_mask] = swap_2
label_map[swap_2_mask] = swap_1

label_map = label_map.astype('int16')

return label_map


# TODO: make documentation more specific here
def outline_objects(L_matrix, list_of_lists):
"""Takes in an L matrix generated by skimage.label, along with a
list of lists, and returns a mask that has the
pixels for all cells from each list represented as integer values for easy plotting
Args:
L_matrix (numpy.ndarray):
a label map indicating the label of each cell
list_of_lists (list):
each element is a list of cells we wish to plot separately
Returns:
np.ndarray:
an binary mask indicating the regions of cells outlined
"""

L_plot = copy.deepcopy(L_matrix).astype(float)

for idx, val in enumerate(list_of_lists):
mask = np.isin(L_plot, val)

# use a negative value to not interfere with cell labels
L_plot[mask] = -(idx + 2)

L_plot[L_plot > 1] = 1
L_plot = np.absolute(L_plot)
L_plot = L_plot.astype('int16')
return L_plot


def plot_color_map(outline_matrix, names, plotting_colors=None, ground_truth=None, save_path=None):
"""Plot label map with cells of specified category colored the same
Displays plot in window
Args:
outline_matrix (numpy.ndarray):
output of outline_objects function which assigns same value to cells of same class
names (list):
list of names for each category to use for plotting
plotting_colors (list):
list of colors to use for plotting cell categories
ground_truth (numpy.ndarray):
optional argument to supply label map of true segmentation to be plotted alongside
save_path (str):
optional argument to save plot as TIF
"""

if plotting_colors is None:
plotting_colors = ['Black', 'Grey', 'Blue', 'Green',
'Pink', 'moccasin', 'tan', 'sienna', 'firebrick']

num_categories = np.max(outline_matrix)
plotting_colors = plotting_colors[:num_categories + 1]
cmap = mpl.colors.ListedColormap(plotting_colors)

if ground_truth is not None:
fig, ax = plt.subplots(nrows=1, ncols=2)
mat = ax[0].imshow(outline_matrix, cmap=cmap, vmin=np.min(outline_matrix) - .5,
vmax=np.max(outline_matrix) + .5)
swapped = randomize_labels(ground_truth)
ax[1].imshow(swapped)
else:
fig, ax = plt.subplots(nrows=1, ncols=1)
mat = ax.imshow(outline_matrix, cmap=cmap, vmin=np.min(outline_matrix) - .5,
vmax=np.max(outline_matrix) + .5)

# tell the colorbar to tick at integers
cbar = fig.colorbar(mat, ticks=np.arange(np.min(outline_matrix), np.max(outline_matrix) + 1))

cbar.ax.set_yticklabels(names)

fig.tight_layout()
if save_path is not None:
fig.savefig(save_path, dpi=200)


# TODO: make documentation more specific here
def plot_barchart_errors(pd_array, contour_errors, predicted_errors, save_path=None):
"""Plot different error types in a barchart, along with cell-size correlation in a scatter plot
Args:
pd_array (pandas.array):
pandas cell array representing error types for each class of cell
contour_errors (list):
list of contour error types to extract from array
predicted_errors (list):
list of predictive error types to extract from the array
save_path (str):
optional file path to save generated TIF
"""

# make sure all supplied categories are column names
if np.any(~np.isin(contour_errors + predicted_errors, pd_array.columns)):
raise ValueError("Invalid column name")

fig, ax = plt.subplots(2, 1, figsize=(10, 10))

ax[0].scatter(pd_array["contour_cell_size"], pd_array["predicted_cell_size"])
ax[0].set_xlabel("Contoured Cell")
ax[0].set_ylabel("Predicted Cell")

# compute percentage of different error types
errors = np.zeros(len(predicted_errors) + len(contour_errors))
for i in range(len(contour_errors)):
errors[i] = len(set(pd_array.loc[pd_array[contour_errors[i]], "contour_cell"]))

for i in range(len(predicted_errors)):
errors[i + len(contour_errors)] = len(set(pd_array.loc[pd_array[predicted_errors[i]],
"predicted_cell"]))

errors = errors / len(set(pd_array["predicted_cell"]))
position = range(len(errors))
ax[1].bar(position, errors)

ax[1].set_xticks(position)
ax[1].set_xticklabels(contour_errors + predicted_errors)
ax[1].set_title("Fraction of cells misclassified")

if save_path is not None:
fig.savefig(save_path, dpi=200)


def plot_mod_ap(mod_ap_list, thresholds, labels):
df = pd.DataFrame({'iou': thresholds})

for idx, label in enumerate(labels):
df[label] = mod_ap_list[idx]['scores']

fig, ax = plt.subplots()
for label in labels:
ax.plot('iou', label, data=df, linestyle='-', marker='o')

ax.set_xlabel('IOU Threshold')
ax.set_ylabel('mAP')
ax.legend()
fig.show()


def plot_error_types(errors, labels, error_plotting):
data_dict = pd.DataFrame(pd.Series(errors[0])).transpose()

for i in range(1, len(labels)):
data_dict = data_dict.append(errors[i], ignore_index=True)

data_dict['algos'] = labels

fig, axes = plt.subplots(len(error_plotting))
for i in range(len(error_plotting)):
barchart_helper(ax=axes[i], values=data_dict[error_plotting[i]], labels=labels,
title='{} Errors'.format(error_plotting[i]))

fig.show()
fig.tight_layout()


def barchart_helper(ax, values, labels, title):
positions = range(len(values))
ax.bar(positions, values)
ax.set_xticks(positions)
ax.set_xticklabels(labels)
ax.set_title(title)
82 changes: 0 additions & 82 deletions ark/utils/plot_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,85 +46,3 @@ def test_plot_overlay():
plot_utils.plot_overlay(predicted_contour=example_labels, plotting_tif=example_images,
alternate_contour=example_labels,
path=os.path.join(temp_dir, "example_plot3.tiff"))


def test_randomize_labels():
labels = _generate_segmentation_labels((1024, 1024))
randomized = plot_utils.randomize_labels(labels)

assert np.array_equal(np.unique(labels), np.unique(randomized))

# check that all pixels to belong to single cell in newly transformed image
unique_vals = np.unique(labels)
for val in unique_vals:
coords = labels == val
assert len(np.unique(randomized[coords])) == 1


def test_outline_objects():
labels = _generate_segmentation_labels((1024, 1024), num_cells=300)
unique_vals = np.unique(labels)[1:]

# generate a random subset of unique vals to be placed in each list
vals = np.random.choice(unique_vals, size=60, replace=False)
object_list = [vals[:20], vals[20:40], vals[40:]]

outlined = plot_utils.outline_objects(labels, object_list)

# check that cells in same object list were assigned the same label
mask1 = np.isin(labels, object_list[0])
assert len(np.unique(outlined[mask1])) == 1

mask2 = np.isin(labels, object_list[1])
assert len(np.unique(outlined[mask2])) == 1

mask3 = np.isin(labels, object_list[2])
assert len(np.unique(outlined[mask3])) == 1


def test_plot_mod_ap():
labels = ['alg1', 'alg2', 'alg3']
thresholds = np.arange(0.5, 1, 0.1)
mAP_array = [{'scores': [0.9, 0.8, 0.7, 0.4, 0.2]}, {'scores': [0.8, 0.7, 0.6, 0.3, 0.1]},
{'scores': [0.95, 0.85, 0.75, 0.45, 0.25]}]

plot_utils.plot_mod_ap(mAP_array, thresholds, labels)


def test_plot_error_types():
stats_dict = {
'n_pred': 200,
'n_true': 200,
'correct_detections': 140,
'missed_detections': 40,
'gained_detections': 30,
'merge': 20,
'split': 10,
'catastrophe': 20
}

stats_dict1 = {
'n_pred': 210,
'n_true': 210,
'correct_detections': 120,
'missed_detections': 30,
'gained_detections': 50,
'merge': 50,
'split': 30,
'catastrophe': 50
}

stats_dict2 = {
'n_pred': 10,
'n_true': 20,
'correct_detections': 10,
'missed_detections': 70,
'gained_detections': 50,
'merge': 5,
'split': 3,
'catastrophe': 5
}

plot_utils.plot_error_types([stats_dict, stats_dict1, stats_dict2], ['alg1', 'alg2', 'alg3'],
['missed_detections', 'gained_detections', 'merge', 'split',
'catastrophe'])

0 comments on commit c48700b

Please sign in to comment.