In [None]:
import pickle
import random
import matplotlib.pyplot as plt
import numpy as np
import zipfile
import os
import random
import datetime
import h5py
import sklearn.metrics 
import cv2
import skimage.measure as measure
import skimage.filters as filters
import skimage.morphology as morphology
import skimage.exposure as exposure
import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.compat.v2 as tf
import seaborn as sns
import matplotlib.patches as patches
from functions import *

In [None]:
# Select which physical GPU is exposed to TensorFlow:
#   "1" -> NVIDIA GeForce RTX 3090
#   "3" -> NVIDIA GeForce RTX 2080 (currently selected)
# NOTE(review): this presumably has to run before TensorFlow first touches
# the GPUs for the selection to take effect - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Query the visible GPUs once, then report their number and the devices.
gpus = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
print(gpus)

# Test set

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    produced by this project.
    """
    with open(path, 'rb') as file:
        return pickle.load(file)


# Test set: images, label maps, annotated coordinates and genus per image.
test_images = _load_pickle('Data/TEST_Images_2.pkl')
test_labels = _load_pickle('Data/TEST_Labels_2.pkl')
test_coordinates = _load_pickle('Data/TEST_Coordinates_2.pkl')
test_genera = _load_pickle('Data/TEST_Genera_2.pkl')

name_model = "Final (M0_6(3x3)_2)"
filepath_dic = "Results/" + name_model
name_test_prediction = filepath_dic + "/test_prediction.pkl"

# Predictions cached by an earlier run of the "Prediction" section. On a
# fresh checkout the file does not exist yet; skip loading instead of
# crashing so the notebook still survives Restart & Run All (the Prediction
# section then computes and writes the predictions).
if os.path.exists(name_test_prediction):
    test_predictions = _load_pickle(name_test_prediction)
else:
    test_predictions = None
    print("No cached predictions found at", name_test_prediction,
          "- run the Prediction section to create them.")

## Data exploration

In [None]:
# Report the element dtype of the image array, then of the label array.
for _arr in (test_images, test_labels):
    print(_arr.dtype)

In [None]:
# The four test-set containers should all have one entry per image.
print("The test set contains", len(test_images), "images.")
print("The test set contains", len(test_labels), "labels.")
print("The length of the list with coordinates is:", len(test_coordinates))
print("The length of the list with genera is:", len(test_genera))

In [None]:
# Axis 1 is reported as the height and axis 2 as the width.
print("The images in the test set have a width of", test_images.shape[2], "and a height of", test_images.shape[1])
print("The labels in the test set have a width of", test_labels.shape[2], "and a height of", test_labels.shape[1])

In [None]:
# Value range of the images and of the label maps.
print("The maximum pixel value of the images is:", str(np.amax(test_images)))
print("The minimum pixel value of the images is:", str(np.amin(test_images)))

print("The maximum pixel value of the labels is:", str(np.amax(test_labels)))
print("The minimum pixel value of the labels is:", str(np.amin(test_labels)))

### Plot

In [None]:
# Show one randomly chosen test image next to its ground-truth label map.
# No random seed is set, so a different image is drawn on every run.
i = random.choice(range(len(test_images)))
plt.figure(figsize=(30,20))

# Left panel: input image with the annotated coordinates in orange.
# zip(*...) splits the coordinate pairs into the x and y sequences scatter expects.
plt.subplot(1,2,1)
plt.imshow(test_images[i],cmap="gray")
plt.axis("off")
plt.title("Annotated image", size=30)
plt.scatter(*zip(*test_coordinates[i]),s=20, c="orange")

# Right panel: ground-truth label map (title spelled as in the later figures).
plt.subplot(1,2,2)
plt.imshow(test_labels[i],cmap="nipy_spectral")
plt.axis("off")
plt.title("Ground truth", size=30)

# Prediction

## Final model

In [None]:
# Name of the trained model and the locations derived from it.
name_model = "Final (M0_6(3x3)_2)"
filepath_dic = f"Results/{name_model}"
filepath_checkpoint_model = f"{filepath_dic}/checkpoint.model.keras"

## Prediction

In [None]:
# Restore the checkpointed model and predict the test images one at a time.
batch_size = 1
model_best = tf.keras.models.load_model(filepath_checkpoint_model)
test_predictions = model_best.predict(test_images, batch_size=batch_size)

In [None]:
# Inspect the shape of the raw model output before it is reshaped below.
test_predictions.shape

In [None]:
name_test_prediction = filepath_dic + "/test_prediction.pkl"

# Store one 2-D map per test image: (n_images, height, width). The target
# shape is taken from test_images instead of the hard-coded
# (126, 2048, 2688), so a test set of a different size or resolution works.
# NOTE(review): assumes the model output has the same height and width as
# its input images - confirm.
test_predictions = test_predictions.reshape(test_images.shape[:3])

# Cache the predictions so later sessions can load them without the model.
with open(name_test_prediction, 'wb') as file:
    pickle.dump(test_predictions, file)

In [None]:
# Show a randomly chosen test image, its ground-truth label map and the
# model prediction side by side. No seed is set, so the image differs per run.
i = random.choice(range(len(test_images)))
plt.figure(figsize=(50,25))
# Panel 1: input image with the annotated coordinates in orange.
plt.subplot(1,3,1)
plt.imshow(test_images[i],cmap="gray")
plt.axis("off")
plt.title("Annotated image", size=50)
plt.scatter(*zip(*test_coordinates[i]),s=50, c="orange")
# Panel 2: ground-truth label map.
plt.subplot(1,3,2)
plt.imshow(test_labels[i],cmap="nipy_spectral")
plt.axis("off")
plt.title("Ground truth", size=50)
# Panel 3: prediction with the annotated coordinates overlaid in white.
plt.subplot(1,3,3)
plt.imshow(test_predictions[i],cmap="nipy_spectral")
plt.axis("off")
plt.title("Prediction", size=50)
plt.scatter(*zip(*test_coordinates[i]),s=5, c="white")

In [None]:
# Display the index of the image drawn in the previous cell.
i

# Post-processing

## Parameters for post-processing

In [None]:
# Settings handed to post_processing() (from `functions`) in the next section.
# NOTE(review): the exact meaning of each setting is defined by
# post_processing(), which is not visible here - confirm there.
blurring = True
blur_kernel_size = 11

kernel_shape = "ellipse"

# Binary thresholding combined with Otsu's method.
# NOTE(review): with cv2.THRESH_OTSU OpenCV chooses the threshold itself, so
# `threshold` is presumably ignored - confirm against post_processing().
threshold_technique = cv2.THRESH_BINARY+cv2.THRESH_OTSU
threshold = 1

# Kernel size and iteration count per morphological operation.
morphological_operations = {"erosion":{"kernel_size":11, "iterations":1},
                            "closing":{"kernel_size":11, "iterations":1}}

# Presumably the order in which the operations above are applied.
order_morphological_operations =  ["closing","erosion"] 

## Post-processing

In [None]:
# Settings shared by every post_processing() call; collected once up front.
post_processing_kwargs = dict(
    blurring=blurring,
    blur_kernel_sz=blur_kernel_size,
    thresh_technique=threshold_technique,
    thresh=threshold,
    kernel_shape=kernel_shape,
    morph_ops=morphological_operations,
    order_morph_ops=order_morphological_operations,
)

post_processed_predictions = []
for prediction in test_predictions:
    # normalization() comes from `functions`; its result is scaled by 255
    # and stored as 8-bit before post-processing.
    normalized = normalization(prediction)
    normalized = (normalized*255).astype("uint8")

    post_processed = post_processing(im=normalized, **post_processing_kwargs)
    post_processed_predictions.append(post_processed)

### Figures

In [None]:
# Draw the test image used for the step-by-step figure below (no seed is set).
n_test_images = len(test_images)
index = random.choice(range(n_test_images))

In [None]:
# Step-by-step illustration of the post-processing chain for test image
# `index`. Every intermediate result is collected in `plots` so the stages
# can be drawn side by side.
# NOTE(review): the kernel size (11), the elliptical kernel and the Otsu
# threshold are hard-coded here and repeat the values from the
# "Parameters for post-processing" cell - keep the two in sync.
plots = []

real_positions = test_coordinates[index]

prediction = test_predictions[index]
plots.append(prediction)

# normalisation: normalization() (from `functions`), then scale by 255 and store as 8-bit
normalised = normalization(prediction)
normalised = (normalised*255).astype("uint8")
plots.append(normalised)

# blurring: 11x11 averaging kernel (all weights equal, summing to 1)
blur_kernel = np.ones((11, 11), np.float32)/11**2
blurred = cv2.filter2D(src=normalised, ddepth=-1, kernel=blur_kernel)
plots.append(blurred)

# thresholding: binary threshold with the Otsu flag; [1] selects the thresholded image
# NOTE(review): with the Otsu flag the threshold value 0 is presumably ignored - confirm.
binary = cv2.threshold(blurred,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)[1]
plots.append(binary)


# morphological operations: closing followed by erosion, both with the same elliptical 11x11 kernel
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(11, 11))
closing = cv2.morphologyEx(binary, cv2.MORPH_CLOSE,kernel, iterations=1)
erosion = cv2.erode(closing, kernel, iterations=1)
plots.append(closing)
plots.append(erosion)

# labelling of the eroded binary image
labelled = morphology.label(erosion)
plots.append(labelled)

# One colormap per entry of `plots`, in the same order.
cmaps = ["nipy_spectral", "nipy_spectral", "nipy_spectral", "gray", "gray", "gray", "nipy_spectral"]
# figure: one panel per stage, each with the annotated coordinates overlaid in white
plt.figure(figsize=(100,80))
for i,plot in enumerate(plots):
    plt.subplot(1,len(plots),i+1)
    plt.imshow(plot, cmap=cmaps[i], interpolation='none')
    plt.scatter(*zip(*real_positions),s=10, c="white")
    plt.axis("off")

In [None]:
# Write the intermediate post-processing images of test image `index` to disk.
# Each entry: (file-name prefix, image, colormap).
post_processing_stages = (
    ("Prediction ", prediction, "nipy_spectral"),
    ("Average blurring ", blurred, "nipy_spectral"),
    ("Thresholding ", binary, "gray"),
    ("Closing ", closing, "gray"),
    ("Erosion ", erosion, "gray"),
    ("Labelling ", labelled, "nipy_spectral"),
)
for stage_prefix, stage_image, stage_cmap in post_processing_stages:
    stage_path = os.path.join("Images thesis/Post-processing", stage_prefix+str(index)+".png")
    plt.imsave(stage_path, arr=stage_image, cmap=stage_cmap, format="png")

# Evaluation

In [None]:
# Genera in the test set; also fixes their order in plots and tables.
genera_unique = [
    "Agapanthus",
    "Geranium",
    "Ilex",
    "Persicaria",
    "Salvia",
    "Thalictrum",
]

In [None]:
# Annotated positions versus the positions centroid() (from `functions`)
# returns for each post-processed prediction.
positions_actual = test_coordinates
positions_predicted = [centroid(processed) for processed in post_processed_predictions]

In [None]:
# Width passed to set_size() for the figures below, in points.
textwidth = 455.24411


def set_size(width, fraction=1):
    """Return a figure size in inches that needs no rescaling in LaTeX.

    The height follows from the width through the golden ratio.

    Parameters
    ----------
    width: float
            Document textwidth or columnwidth in pts
    fraction: float, optional
            Share of ``width`` the figure should occupy

    Returns
    -------
    tuple
            ``(width, height)`` of the figure in inches
    """
    # 72.27 points make up one inch.
    pt_to_inch = 1 / 72.27

    # Golden ratio for an aesthetic height/width proportion.
    golden_ratio = (5 ** 0.5 - 1) / 2

    width_in = width * fraction * pt_to_inch
    height_in = width_in * golden_ratio

    return (width_in, height_in)

### Figure

In [None]:
# Show a randomly chosen test image, its prediction and the post-processed
# prediction with actual and predicted positions overlaid.
# No seed is set, so the image differs per run.
i = random.choice(range(len(test_images)))
plt.figure(figsize=(50,20))
# Panel 1: input image with the actual positions in orange.
plt.subplot(1,3,1)
plt.imshow(test_images[i],cmap="gray")
plt.axis("off")
plt.title("Annotated image", size=50)
plt.scatter(*zip(*positions_actual[i]),s=30, c="orange")
# Panel 2: model prediction.
plt.subplot(1,3,2)
plt.imshow(test_predictions[i], cmap="nipy_spectral")
plt.axis("off")
plt.title("Prediction", size=50)
# Panel 3: post-processed prediction; predicted positions as white crosses,
# actual positions as orange dots.
plt.subplot(1,3,3)
plt.imshow(post_processed_predictions[i], cmap="nipy_spectral")
plt.scatter(*zip(*positions_predicted[i]),s=50, c="white", marker="X")
plt.scatter(*zip(*positions_actual[i]),s=50, c="orange")
plt.axis("off")
plt.title("Post-processed prediction", size=50)

In [None]:
# Compare the chromosome counts for image `i` (set in the previous cell):
# the count is the number of actual / predicted positions.
actual_positions = positions_actual[i]
predicted_positions = positions_predicted[i]
n_chromosomes_real = len(actual_positions)
n_chromosomes_predicted = len(predicted_positions)
print("Real chromosome number:", n_chromosomes_real)
print("Predicted chromosome number:", n_chromosomes_predicted)

## Chromosome number

In [None]:
# Column-wise table with one entry per test image.
dict_chromosome_number = {
    "Genus": [],
    "Actual chromosome number": [],
    "Predicted chromosome number": [],
    "True positives": [],
}

for i, post_processed_prediction in enumerate(post_processed_predictions):
    genus = test_genera[i]
    actual_positions = positions_actual[i]
    predicted_positions = positions_predicted[i]
    n_chromosomes_real = len(actual_positions)
    n_chromosomes_predicted = len(predicted_positions)

    # find_best_matches() and calculate_true_positives() come from `functions`.
    distances_best_matches = find_best_matches(actual_positions, predicted_positions)
    TP = calculate_true_positives(distances_best_matches, genus)

    row = {
        "Genus": genus,
        "Actual chromosome number": n_chromosomes_real,
        "Predicted chromosome number": n_chromosomes_predicted,
        "True positives": TP,
    }
    for column, value in row.items():
        dict_chromosome_number[column].append(value)

In [None]:
# Spot check: actual chromosome number of test image 44.
dict_chromosome_number["Actual chromosome number"][44]

### Scatterplot

In [None]:
# Predicted versus actual chromosome number per test image, coloured by genus.
fig, ax = plt.subplots(figsize=set_size(textwidth))
sns.scatterplot(data=dict_chromosome_number, 
                x="Actual chromosome number",
                y="Predicted chromosome number",
                hue="Genus",
                hue_order=genera_unique,
                ax=ax,
                s=25)
ax.tick_params(axis='both', which='major', labelsize=12)
# The axis labels are overridden with Dutch text.
ax.set_xlabel("Werkelijke chromosoomaantal", fontsize=14)
ax.set_ylabel("Voorspelde chromosoomaantal", fontsize=14)
# Identity line y = x from 0 up to the largest count on either axis;
# points on this line are images whose predicted count equals the actual count.
n = np.linspace(0, max(max(dict_chromosome_number["Actual chromosome number"]), max(dict_chromosome_number["Predicted chromosome number"])), 1000)
ax.plot(n, n, 'k-')
plt.savefig("Images presentation/Scatterplot DLM.pdf", format="pdf", bbox_inches='tight')
plt.show()


### Overestimation/Underestimation

In [None]:
# Signed (actual - predicted) and absolute count difference per test image.
actual_numbers = list(dict_chromosome_number["Actual chromosome number"])
predicted_numbers = list(dict_chromosome_number["Predicted chromosome number"])
difference = [actual - predicted for actual, predicted in zip(actual_numbers, predicted_numbers)]
difference_abs = [abs(signed) for signed in difference]

In [None]:
# Summary statistics of the absolute count difference.
for stat_label, stat_function in (
    ("Min difference:", np.min),
    ("Max difference:", np.max),
    ("Mean difference:", np.mean),
    ("Median difference:", np.median),
):
    print(stat_label, stat_function(difference_abs))

# difference = actual - predicted, so a negative value is an overestimation.
print("Overestimation of the chromosome number:", sum(d < 0 for d in difference))
print("Underestimation of the chromosome number:", sum(d > 0 for d in difference))
print("Correct prediction:", sum(d == 0 for d in difference))

### Mean absolute error

In [None]:
# Mean absolute error between the actual and predicted chromosome numbers.
actual_counts = list(dict_chromosome_number["Actual chromosome number"])
predicted_counts = list(dict_chromosome_number["Predicted chromosome number"])
MAE = sklearn.metrics.mean_absolute_error(actual_counts, predicted_counts)
print("The mean absolute error is:", MAE)

In [None]:
# Print the signed differences in ascending order. sorted() returns a new
# list, so `difference` keeps its per-image order and stays aligned with the
# other per-image lists (an in-place sort would silently break that).
print(sorted(difference))

## Evaluation metrics

### True positives

In [None]:
# Display the true-positive entry stored for every test image.
dict_chromosome_number["True positives"]    

In [None]:
# Per-image recall, precision and F1, stored under "All" and under the
# image's genus. The genus keys are derived from genera_unique so this
# structure cannot drift from the genus list used elsewhere in the notebook.
dict_evaluation_metrics = {
    metric: {group: [] for group in ["All"] + list(genera_unique)}
    for metric in ("recall", "precision", "F1")
}

for i in range(len(test_genera)):
    genus = test_genera[i]
    actual_positions = positions_actual[i]
    predicted_positions = positions_predicted[i]

    # The three evaluation helpers come from `functions`.
    rec = recall_evaluation(actual_positions, predicted_positions,genus)
    prec = precision_evaluation(actual_positions, predicted_positions,genus)
    f1 = F1_evaluation(actual_positions, predicted_positions,genus)

    for metric, value in (("recall", rec), ("precision", prec), ("F1", f1)):
        dict_evaluation_metrics[metric]["All"].append(value)
        dict_evaluation_metrics[metric][genus].append(value)

In [None]:
# Short names for the per-image scores over the whole test set.
rec_all, prec_all, f1_all = (
    dict_evaluation_metrics[metric]["All"] for metric in ("recall", "precision", "F1")
)

### Table

For each genus:

In [None]:
# Mean and standard deviation of recall, precision and F1 per genus,
# followed by the same figures over the whole test set.
# One shared row format keeps the "All" row aligned with the genus rows
# (it previously had an extra space after "|", shifting it one column).
row_format = "{:15s} {:1s}{:^10.2f}{:^10.2f}{:^10.2f}{:^10.2f}{:^10.2f}{:^10.2f} "

print("{:15s} {:1s}{:^10s}{:^10s}{:^10s}{:^10s}{:^10s}{:^10s} ".format("GENUS","|", "RECALL","std", "PRECISION","std", "F1","std"))
print(75*"=")
for genus in genera_unique:
    rec = np.mean(dict_evaluation_metrics["recall"][genus])
    std_rec = np.std(dict_evaluation_metrics["recall"][genus])
    prec = np.mean(dict_evaluation_metrics["precision"][genus])
    std_prec = np.std(dict_evaluation_metrics["precision"][genus])
    f1 = np.mean(dict_evaluation_metrics["F1"][genus])
    std_f1 = np.std(dict_evaluation_metrics["F1"][genus])
    print(row_format.format(genus,"|", rec, std_rec, prec, std_prec, f1, std_f1))
print(75*"-")
print(row_format.format("All","|", np.mean(rec_all), np.std(rec_all), np.mean(prec_all), np.std(prec_all), np.mean(f1_all), np.std(f1_all)))

For the whole test set:

In [None]:
# Mean, standard deviation, median, minimum and maximum of each metric over
# the whole test set.
summary_row = "{:20s}{:1s}{:^10.2f}{:^10.2f}{:^10.2f}{:^10.2f}{:^10.2f}"

print("{:20s}{:1s}{:^10s}{:^10s}{:^10s}{:^10s}{:^10s}".format("","|", "MEAN","std", "MEDIAN","MIN","MAX"))
print(70*"=")
for metric_label, metric_values in (
    ("precision", prec_all),
    ("recall", rec_all),
    ("F1-score", f1_all),
):
    print(summary_row.format(
        metric_label,
        "|",
        np.mean(metric_values),
        np.std(metric_values),
        np.median(metric_values),
        np.min(metric_values),
        np.max(metric_values),
    ))

### Boxplots

In [None]:
# Box plots of the per-image recall, precision and F1 over the whole test set.
fig, ax = plt.subplots(figsize=set_size(textwidth))
data = [rec_all, prec_all, f1_all]
labels = ["Recall", "Precision", "F1"]
# NOTE(review): newer Matplotlib releases rename `labels` to `tick_labels` -
# confirm against the installed version before changing it.
bp = ax.boxplot(
    data,
    patch_artist=True,
    labels=labels,
    boxprops=dict(facecolor="steelblue", color="steelblue"),
    whiskerprops=dict(color="steelblue"),
    capprops=dict(color="steelblue"),
    medianprops=dict(color="orange", linewidth=1.5),
    showfliers=True,
    flierprops=dict(markeredgecolor="darkorange"),
)
# Lighter fill for the boxes; the edges keep the colour set in boxprops.
for box in bp['boxes']:
    box.set_facecolor("lightsteelblue")

# Minimal look: no tick marks, no frame, faint horizontal grid lines.
for axis in (ax.yaxis, ax.xaxis):
    axis.set_ticks_position('none')
ax.set_ylim([0, 1])
ax.tick_params(axis='both', which='major', labelsize=12)
for side in ("top", "right", "left", "bottom"):
    ax.spines[side].set_visible(False)
ax.grid(color='grey', axis='y', linestyle='-', linewidth=0.25, alpha=0.5)
#plt.savefig("Images presentation/Boxplot evaluation DLM.png", format="png", bbox_inches='tight')

## Good/Bad predictions

In [None]:
# Indices of the test images with a perfect score (exactly 1) ...
f1_good = [idx for idx, score in enumerate(f1_all) if score == 1]
rec_good = [idx for idx, score in enumerate(rec_all) if score == 1]
prec_good = [idx for idx, score in enumerate(prec_all) if score == 1]

# ... and of those scoring below 0.8.
f1_bad = [idx for idx, score in enumerate(f1_all) if score < 0.8]
rec_bad = [idx for idx, score in enumerate(rec_all) if score < 0.8]
prec_bad = [idx for idx, score in enumerate(prec_all) if score < 0.8]

In [None]:
#bad_predictions = f1_bad
Agapanthus = [18,92,110]
Salvia = [14,31,54,86]
Thalictrum = [24,41,48,66, 70, 90, 125]

In [None]:
# List all three scores for every image whose F1 is below the 0.8 cut-off.
for i in f1_bad:
    for score_label, score_values in (
        ("F1:", f1_all),
        ("recall:", rec_all),
        ("precision:", prec_all),
    ):
        print(score_label, score_values[i])
    print("")

In [None]:
# Display the indices of the images with an F1 of exactly 1.
f1_good

In [None]:
# Hand-picked test image used by the figure and export cells below.
index = 99
image = test_images[index]
prediction = test_predictions[index]
post_processed_prediction = post_processed_predictions[index]
label = test_labels[index]

# Both coordinates of every position divided by 128; used for the schematic
# figure whose axes run from 0 to 21 and from 0 to 16.45.
positions_actual_resized = [(pos[0] / 128, pos[1] / 128) for pos in positions_actual[index]]
positions_predicted_resized = [(pos[0] / 128, pos[1] / 128) for pos in positions_predicted[index]]

In [None]:
# Figure 1: schematic of actual (orange dots) versus predicted (blue crosses)
# positions, drawn from the down-scaled coordinates on a plain background.
plt.figure(figsize=(21,16.45), facecolor="whitesmoke")
plt.scatter(*zip(*positions_actual_resized),s=200, c="orange")
plt.scatter(*zip(*positions_predicted_resized),s=200, c="steelblue", marker="X")
plt.xlim(0,21)
# The y-limits are given in reverse so the y-axis points down, as in an image.
plt.ylim(16.45,0)
plt.legend(["True chromosome","Predicted chromosome"], fontsize=45, facecolor= "whitesmoke", edgecolor="whitesmoke")
plt.axis("off")
plt.savefig(os.path.join("Images presentation","Ground truth "+str(index)+".pdf"), format="pdf", bbox_inches='tight', pad_inches=0)
# Figure 2: post-processed prediction with the actual (gray) and predicted
# (white crosses) positions overlaid; not saved (savefig is commented out).
plt.figure(figsize=(21,16.45))
plt.imshow(post_processed_prediction , cmap="nipy_spectral")
plt.scatter(*zip(*positions_actual[index]),s=15, c="gray")
plt.scatter(*zip(*positions_predicted[index]),s=20, c="w", marker="X")
plt.axis("off")
plt.legend(["True chromosome","Predicted chromosome"], fontsize=25)
#plt.savefig(os.path.join("Images thesis/Bad predictions DLM","Fused blobs  Ground truth + Prediction "+str(index)+".pdf"), format="pdf", bbox_inches='tight', pad_inches=0)

In [None]:
# Selected test image with the actual positions in orange; the figure is
# only displayed (savefig is commented out).
plt.figure(figsize=(21,16.45))
plt.imshow(image , cmap="gray")
plt.scatter(*zip(*positions_actual[index]),s=100, c="orange")
plt.axis("off")
#plt.savefig(os.path.join("Images thesis/Bad predictions DLM","F1 Annotated image "+str(index)+".pdf"), format="pdf", bbox_inches='tight', pad_inches=0)

In [None]:
# Export the selected image for the presentation. Only the original image is
# written; the exports of the other three arrays are commented out.
plt.imsave(os.path.join("Images presentation","Original image "+str(index)+".png"), arr=image, cmap="gray", format="png")
#plt.imsave(os.path.join("Images presentation","Prediction "+str(index)+".png"), arr=prediction, cmap="nipy_spectral", format="png")
#plt.imsave(os.path.join("Images presentation","Post-processed prediction "+str(index)+".png"), arr=post_processed_prediction, cmap="nipy_spectral", format="png")
#plt.imsave(os.path.join("Images presentation","Label "+str(index)+".png"), arr=label, cmap="nipy_spectral", format="png")

In [None]:
# Export the four views of the selected image to the "good predictions" folder.
# Each entry: (file-name prefix, array, colormap).
good_prediction_exports = (
    ("Original image ", image, "gray"),
    ("Prediction ", prediction, "nipy_spectral"),
    ("Post-processed prediction ", post_processed_prediction, "nipy_spectral"),
    ("Label ", label, "nipy_spectral"),
)
for export_prefix, export_array, export_cmap in good_prediction_exports:
    export_path = os.path.join("Images thesis/Good predictions DLM", export_prefix+str(index)+".png")
    plt.imsave(export_path, arr=export_array, cmap=export_cmap, format="png")

In [None]:
# Export the four views of the selected image to the "bad predictions" folder.
# NOTE(review): the file-name prefixes ("Blob shape", "Fused blobs", "Rec")
# differ per line and look hand-edited per failure case - confirm they are intended.
plt.imsave(os.path.join("Images thesis/Bad predictions DLM","Blob shape Original image "+str(index)+".png"), arr=image, cmap="gray", format="png")
plt.imsave(os.path.join("Images thesis/Bad predictions DLM","Blob shape Prediction "+str(index)+".png"), arr=prediction, cmap="nipy_spectral", format="png")
plt.imsave(os.path.join("Images thesis/Bad predictions DLM","Fused blobs Post-processed prediction "+str(index)+".png"), arr=post_processed_prediction, cmap="nipy_spectral", format="png")
# The "Label" file must contain the ground-truth label map; it was previously
# written from post_processed_prediction, duplicating the file above.
plt.imsave(os.path.join("Images thesis/Bad predictions DLM","Rec Label "+str(index)+".png"), arr=label, cmap="nipy_spectral", format="png")

In [None]:
# The same 600x600 crop of the image, the prediction and the post-processed
# prediction, each with a red rectangle around the region of interest and
# each saved as its own PDF.
crop_region = (slice(1070, 1670), slice(750, 1350))
removed_blob_views = (
    (image, "gray", "Removal blob Original image "),
    (prediction, "nipy_spectral", "Removal blob Prediction "),
    (post_processed_prediction, "nipy_spectral", "Removal blob Post-processed prediction "),
)

for view_array, view_cmap, view_prefix in removed_blob_views:
    plt.figure(figsize=(21,16.45))
    ax = plt.subplot()
    ax.imshow(view_array[crop_region], cmap=view_cmap)
    # Rectangle position and size are given in crop coordinates.
    rect = patches.Rectangle((230,88), 60, 40, linewidth=4, edgecolor="red", facecolor="none")
    ax.add_patch(rect)
    ax.axis("off")
    plt.savefig(os.path.join("Images thesis/Bad predictions DLM", view_prefix+str(index)+".pdf"), format="pdf", bbox_inches='tight', pad_inches=0)

In [None]:
# One three-panel figure per image with an F1 of exactly 1.
# Note: the loop overwrites `index`, `image`, `prediction` and
# `post_processed_prediction` set in the earlier selection cell.
for index in f1_good:
    image = test_images[index]
    prediction = test_predictions[index]
    post_processed_prediction = post_processed_predictions[index]
    
    plt.figure(figsize=(30,20))
    # Panel 1: input image with the actual positions in orange.
    plt.subplot(1,3,1)
    plt.imshow(image , cmap="gray")
    plt.scatter(*zip(*positions_actual[index]), c="orange")
    plt.title("Annotated image "+str(index), size=30)
    plt.axis("off")
    # Panel 2: model prediction.
    plt.subplot(1,3,2)
    plt.imshow(prediction , cmap="nipy_spectral")
    plt.title("Prediction", size=30)
    plt.axis("off")
    # Panel 3: post-processed prediction with actual (gray) and predicted
    # (white crosses) positions.
    plt.subplot(1,3,3)
    plt.imshow(post_processed_prediction , cmap="nipy_spectral")
    plt.scatter(*zip(*positions_actual[index]), c="gray")
    plt.scatter(*zip(*positions_predicted[index]), c="w", marker="X")
    plt.title("Post-processed prediction", size=30)
    plt.axis("off")
    plt.tight_layout()
    plt.legend(["True chromosome","Predicted chromosome"],fontsize=20)
    # Render this figure now and release it, so the large figures do not
    # accumulate in memory while the loop runs.
    plt.show()
    plt.close()

In [None]:
# Spot check: precision of test image 99.
prec_all[99]