# Explore segmentation results (all classes)

The goal of this notebook is to display each normalized image next to overlays of its
segmentation result (one panel per class), so that you can verify that nothing went wrong
at this stage of the pipeline.

In Jupyter notebooks, you can run a cell by pressing `Shift+Enter`. All cells need
to be run in order, from top to bottom. Alternatively, use `Run All` in the `Cell` menu
(`Run > Run All Cells` in JupyterLab).

Places where you need to fill in your own values are marked with the following comment:
``` python
#####################################
####### Parameters to change ########
#####################################
data = ...
```

In [None]:
# Import libraries we need
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt
import tifffile

In [None]:
#####################################
####### Parameters to change ########
#####################################

# --- Paths -----------------------------------------------------------------
# Root of the analysis folder; every other folder below is resolved from it.
root = Path("test_pipeline")
# Folder holding the normalized images (*.tif) -- use your image folder name.
raw_folder = root / "raw-normalized"
# Folder holding the segmentation results (<image name>_segmentation.tif).
seg_folder = root / "result"

# --- Labels ----------------------------------------------------------------
# Class names, in the same order as in Labkit: the i-th name is used for the
# pixels whose segmentation value is i.
labels = ["background", "tissue", "collagen", "cell"]

# --- Saving ----------------------------------------------------------------
save_plots = True  # set to False to only display the plots without saving them
save_path = root / "plots"  # folder where the plots are written

Now, we show each image with one overlay per class, side by side. If `save_plots` is `True`, each figure is also saved to `save_path`.

In [None]:
# Overlay every segmentation class on its normalized image, one subplot per
# class, so that segmentation problems can be spotted by eye.
#
# For every `<name>.tif` in `raw_folder`, a matching `<name>_segmentation.tif`
# is expected in `seg_folder`; its pixel value i is shown as class `labels[i]`.

# list tiffs in raw folder
raw_files = sorted(raw_folder.glob("*.tif"))
if not raw_files:
    print(f"No .tif files found in {raw_folder}")

# number of labels
n_labels = len(labels)

# Create the output folder once, before the loop. parents=True also creates
# missing intermediate folders; exist_ok=True makes re-running the cell safe.
if save_plots:
    save_path.mkdir(parents=True, exist_ok=True)

# loop over the files
for f in raw_files:
    seg_name = f.stem + "_segmentation.tif"
    seg_path = seg_folder / seg_name
    if not seg_path.exists():
        # Report and skip instead of crashing, so the other images are still checked.
        print(f"{seg_name} not found in {seg_folder}, skipping {f.name}")
        continue

    raw = tifffile.imread(f)
    seg = tifffile.imread(seg_path)

    # Warn when the mask's highest label differs from the last class index.
    label_max = np.max(seg)
    if label_max != n_labels - 1:
        print(f"{seg_name} has max value {label_max}, expected {n_labels - 1}")

    # Plot the raw image with one class-mask overlay per subplot.
    # squeeze=False keeps `axes` 2-D even when there is a single label.
    fig, axes = plt.subplots(1, n_labels, figsize=(20, 10), squeeze=False)
    for i, (ax, label) in enumerate(zip(axes[0], labels)):
        ax.imshow(raw)
        ax.imshow(seg == i, alpha=0.2)
        ax.set_title(f"{f.stem} ({label})")

    if save_plots:
        fig.savefig(save_path / f"{f.stem}_qc_classes.png")

    # Display the figure now and release it. Without this, every figure stays
    # open until the cell ends (matplotlib warns once more than 20 are open).
    plt.show()
    plt.close(fig)
  
