In [None]:
!pip install batchgenerators -q

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.7/61.7 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m96.4/96.4 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for batchgenerators (setup.py) ... [?25l[?25hdone


In [None]:
import torch
import torch.nn as nn
import os
import matplotlib
#from batchgenerators.utilities.file_and_folder_operations import join
matplotlib.use('agg')
import seaborn as sns
import matplotlib.pyplot as plt
from collections import OrderedDict
import numpy as np
import cv2

# Curvas de aprendizaje

In [None]:
def plot_progress_png(output_path, epoca, train_losses, val_losses, mean_dice, ema_dice):
    """Draw an nnU-Net-style progress plot (losses + pseudo Dice) and save it to ``output_path``.

    Losses go on the left y axis; pseudo Dice and its moving average go on a second (right) y axis.

    :param output_path: file the figure is written to (format inferred by matplotlib from the extension).
    :param epoca: index of the last epoch to plot; entries ``0..epoca`` (inclusive) of each series are used.
    :param train_losses: per-epoch training losses.
    :param val_losses: per-epoch validation losses.
    :param mean_dice: per-epoch pseudo Dice.
    :param ema_dice: per-epoch moving average of the pseudo Dice.
    """
    # Plot at most epoca + 1 points, but never more than the shortest series holds; otherwise
    # x and y would differ in length and matplotlib would raise.
    n_points = min(epoca + 1, len(train_losses), len(val_losses), len(mean_dice), len(ema_dice))
    x_values = list(range(n_points))

    sns.set(font_scale=2.5)  # note: this changes the global seaborn/matplotlib style for later plots too
    fig, ax = plt.subplots(figsize=(30, 18))
    ax2 = ax.twinx()  # right-hand axis so Dice and loss share the epoch axis
    ax.plot(x_values, train_losses[:n_points], color='b', ls='-', label="loss_tr", linewidth=4)
    ax.plot(x_values, val_losses[:n_points], color='r', ls='-', label="loss_val", linewidth=4)
    ax2.plot(x_values, mean_dice[:n_points], color='g', ls='dotted', label="pseudo dice",
             linewidth=3)
    ax2.plot(x_values, ema_dice[:n_points], color='g', ls='-', label="pseudo dice (mov. avg.)",
             linewidth=4)
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax2.set_ylabel("pseudo dice")
    ax.legend(loc=(0, 1))
    ax2.legend(loc=(0.2, 1))

    fig.tight_layout()
    # Save directly to output_path. `join` is not imported when this cell runs (its import in the first
    # cell is commented out), and joining a single path returns it unchanged anyway.
    fig.savefig(output_path)
    plt.close(fig)  # close this specific figure to release its memory

In [None]:
# Load both training checkpoints onto the CPU; only the logged curves are needed, so no GPU is required.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True; if these checkpoints store
# non-tensor Python/numpy objects the load may fail there — confirm against the installed torch version.
checkpoint_uformer = torch.load(
    '/content/drive/MyDrive/TFG/Uformer/V2/logs/CholecSeg8k/Uformer_T_/models/model_final.pth',
    map_location='cpu',
)
checkpoint_reprec = torch.load(
    '/content/drive/MyDrive/TFG/Parte2/FineTune/logs/CholecSeg8k/models/model_latest.pth',
    map_location='cpu',
)

In [None]:
# Learning curves of both models, written as PNGs into the shared results folder.
plot_progress_png(
    output_path='/content/drive/MyDrive/TFG/Resultados/learning_curves/uformer_learning-curve.png',
    epoca=checkpoint_uformer['epoch'],
    train_losses=checkpoint_uformer['train_losses'],
    val_losses=checkpoint_uformer['val_losses'],
    mean_dice=checkpoint_uformer['mean_dice'],
    ema_dice=checkpoint_uformer['ema_dice'],
)
plot_progress_png(
    output_path='/content/drive/MyDrive/TFG/Resultados/learning_curves/reprec_learning-curve.png',
    epoca=checkpoint_reprec['epoch'],
    train_losses=checkpoint_reprec['train_losses'],
    val_losses=checkpoint_reprec['val_losses'],
    mean_dice=checkpoint_reprec['mean_dice'],
    ema_dice=checkpoint_reprec['ema_dice'],
)

# Métricas

## Funciones Previas

In [None]:
from abc import ABC, abstractmethod
from typing import Tuple, Union, List
import numpy as np


class BaseReaderWriter(ABC):
    """Abstract interface for reading images/segmentations into numpy arrays and writing segmentations back."""

    @staticmethod
    def _check_all_same(input_list):
        """Return True if every entry of ``input_list`` is numerically equal to the first one.

        Entries are typically shape or spacing tuples. Entries of different length are reported as
        different (instead of raising a numpy broadcasting error). Empty and single-element lists are
        trivially "all the same".
        """
        if len(input_list) <= 1:
            return True
        reference = input_list[0]
        return all(np.shape(item) == np.shape(reference) and np.allclose(item, reference)
                   for item in input_list[1:])

    @staticmethod
    def _check_all_same_array(input_list):
        """Return True if all arrays in ``input_list`` have the first array's shape and values."""
        # compare all entries to the first
        for i in input_list[1:]:
            if i.shape != input_list[0].shape or not np.allclose(i, input_list[0]):
                return False
        return True

    @abstractmethod
    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
        """
        Reads a sequence of images and returns a 4d (!) np.ndarray along with a dictionary. The 4d array must have the
        modalities (or color channels, or however you would like to call them) in its first axis, followed by the
        spatial dimensions (so shape must be c,x,y,z where c is the number of modalities (can be 1)).
        Use the dictionary to store necessary meta information that is lost when converting to numpy arrays, for
        example the Spacing, Orientation and Direction of the image. This dictionary will be handed over to write_seg
        for exporting the predicted segmentations, so make sure you have everything you need in there!

        IMPORTANT: dict MUST have a 'spacing' key with a tuple/list of length 3 with the voxel spacing of the np.ndarray.
        Example: my_dict = {'spacing': (3, 0.5, 0.5), ...}. This is needed for planning and
        preprocessing. The ordering of the numbers must correspond to the axis ordering in the returned numpy array. So
        if the array has shape c,x,y,z and the spacing is (a,b,c) then a must be the spacing of x, b the spacing of y
        and c the spacing of z.

        In the case of 2D images, the returned array should have shape (c, 1, x, y) and the spacing should be
        (999, sp_x, sp_y). Make sure 999 is larger than sp_x and sp_y! Example: shape=(3, 1, 224, 224),
        spacing=(999, 1, 1)

        For images that don't have a spacing, set the spacing to 1 (2d exception with 999 for the first axis still applies!)

        :param image_fnames: paths of the images to read; all must have the same spatial shape.
        :return:
            1) a np.ndarray of shape (c, x, y, z) where c is the number of image channels (can be 1) and x, y, z are
            the spatial dimensions (set x=1 for 2D! Example: (3, 1, 224, 224) for RGB image).
            2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a
            is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set
            a=999 (largest spacing value! Make it larger than b and c)

        """
        pass

    @abstractmethod
    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
        """
        Same requirements as BaseReaderWriter.read_images. Returned segmentations must have shape 1,x,y,z. Multiple
        segmentations are not (yet?) allowed

        If images and segmentations can be read the same way you can just `return self.read_images((seg_fname,))`
        :param seg_fname: path of the segmentation to read.
        :return:
            1) a np.ndarray of shape (1, x, y, z) where x, y, z are
            the spatial dimensions (set x=1 for 2D! Example: (1, 1, 224, 224) for 2D segmentation).
            2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a
            is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set
            a=999 (largest spacing value! Make it larger than b and c)
        """
        pass

    @abstractmethod
    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
        """
        Export the predicted segmentation to the desired file format. The given seg array will have the same shape and
        orientation as the corresponding image data, so you don't need to do any resampling or whatever. Just save :-)

        properties is the same dictionary you created during read_images/read_seg so you can use the information here
        to restore metadata

        IMPORTANT: Segmentations are always 3D! If your input images were 2d then the segmentation will have shape
        (1, x, y). You need to catch that and export accordingly (for 2d images you need to convert the 3d segmentation
        to 2d via seg = seg[0])!

        :param seg: A segmentation (np.ndarray, integer) of shape (x, y, z). For 2D segmentations this will be (1, x, y)!
        :param output_fname: path the segmentation is written to.
        :param properties: the dictionary that you created in read_images (the ones this segmentation is based on).
        Use this to restore metadata
        :return: None
        """
        pass

In [None]:
from typing import Tuple, Union, List
import numpy as np
from skimage import io


class NaturalImage2DIO(BaseReaderWriter):
    """
    Reader/writer for 2D natural-image files (PNG, BMP, TIFF) based on skimage.io.

    ONLY SUPPORTS 2D IMAGES!!! Arrays are returned with shape (c, 1, X, Y) and spacing (999, 1, 1).
    """

    # there are surely more we could add here. Everything that can be read by skimage.io should be supported
    supported_file_endings = [
        '.png',
        # '.jpg',
        # '.jpeg', # jpg not supported because we cannot allow lossy compression! segmentation maps!
        '.bmp',
        '.tif'
    ]

    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
        """Read one or more 2D images and stack them along the channel axis.

        Grayscale images contribute one channel, RGB(A) images three (four).

        :param image_fnames: paths of the images to read.
        :return: (float32 array of shape (c, 1, X, Y), {'spacing': (999, 1, 1)})
        :raises RuntimeError: if an image is neither 2D nor 3D, or if the images differ in shape.
        """
        images = []
        for f in image_fnames:
            npy_img = io.imread(f)
            if npy_img.ndim == 3:
                # rgb image, last dimension should be the color channel and the size of that channel should be 3
                # (or 4 if we have alpha)
                assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, "If image has three dimensions then the last " \
                                                                         "dimension must have shape 3 or 4 " \
                                                                         f"(RGB or RGBA). Image shape here is {npy_img.shape}"
                # move RGB(A) to front, add additional dim so that we have shape (c, 1, X, Y), where c is either 3 or 4
                images.append(npy_img.transpose((2, 0, 1))[:, None])
            elif npy_img.ndim == 2:
                # grayscale image -> (1, 1, X, Y)
                images.append(npy_img[None, None])
            else:
                # Anything else (e.g. a multi-page stack) is not a single 2D image. Fail here instead of
                # silently dropping the file and failing later with an unrelated error.
                raise RuntimeError(f"Unsupported image with {npy_img.ndim} dimensions (shape {npy_img.shape}) "
                                   f"in file {f}. NaturalImage2DIO only supports 2D grayscale or RGB(A) images.")

        if not self._check_all_same([i.shape for i in images]):
            print('ERROR! Not all input images have the same shape!')
            print('Shapes:')
            print([i.shape for i in images])
            print('Image files:')
            print(image_fnames)
            raise RuntimeError()
        return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': (999, 1, 1)}

    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
        """Read a 2D segmentation; same format as read_images with a single file."""
        return self.read_images((seg_fname, ))

    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
        """Write a (1, X, Y) segmentation as a uint8 2D image (the leading axis is dropped)."""
        io.imsave(output_fname, seg[0].astype(np.uint8, copy=False), check_contrast=False)

In [None]:
from collections.abc import Iterable

import numpy as np
import torch


def fix_types_iterable(iterable, output_type):
    """Return ``output_type(...)`` built from ``iterable`` with numpy scalars converted to native Python types.

    Helper for ``recursive_fix_for_json_export``. Nested dicts are fixed in place. Nested numpy arrays
    become (nested) lists via ``tolist``. Other nested iterables are rebuilt recursively with their own type.
    """
    converted = []
    for item in iterable:
        if isinstance(item, np.bool_):
            converted.append(bool(item))
        elif isinstance(item, np.integer):
            converted.append(int(item))
        elif isinstance(item, np.floating):
            converted.append(float(item))
        elif isinstance(item, dict):
            recursive_fix_for_json_export(item)
            converted.append(item)
        elif isinstance(item, torch.device):
            converted.append(str(item))
        elif isinstance(item, str):
            # strings are iterable too; keep them as they are
            converted.append(item)
        elif isinstance(item, np.ndarray):
            # np.ndarray(list) would interpret the list as a shape; tolist() yields native Python values
            converted.append(item.tolist())
        elif isinstance(item, Iterable):
            converted.append(fix_types_iterable(item, type(item)))
        else:
            converted.append(item)
    return output_type(converted)


def recursive_fix_for_json_export(my_dict: dict):
    """Make ``my_dict`` JSON-serializable in place by converting numpy/torch values to native Python types.

    - numpy integer keys become ``int`` keys
    - nested dicts are processed recursively
    - 1d numpy arrays, lists and tuples are converted element-wise (arrays become lists)
    - numpy bool/integer/floating scalars become bool/int/float
    - ``torch.device`` values become strings
    Anything else is left untouched and must already be serializable.

    :raises AssertionError: if a value is a numpy array with more than one dimension.
    """
    keys = list(my_dict.keys())  # snapshot: keys may be replaced while iterating
    for k in keys:
        if isinstance(k, np.integer):
            my_dict[int(k)] = my_dict.pop(k)
            k = int(k)

        value = my_dict[k]
        if isinstance(value, dict):
            recursive_fix_for_json_export(value)
        elif isinstance(value, np.ndarray):
            assert value.ndim == 1, 'only 1d arrays are supported'
            my_dict[k] = fix_types_iterable(value, output_type=list)
        elif isinstance(value, np.bool_):
            my_dict[k] = bool(value)
        elif isinstance(value, np.integer):
            my_dict[k] = int(value)
        elif isinstance(value, np.floating):
            my_dict[k] = float(value)
        elif isinstance(value, list):
            my_dict[k] = fix_types_iterable(value, output_type=type(value))
        elif isinstance(value, tuple):
            my_dict[k] = fix_types_iterable(value, output_type=tuple)
        elif isinstance(value, torch.device):
            my_dict[k] = str(value)
        else:
            pass  # pray it can be serialized


In [None]:
import multiprocessing
import os
from copy import deepcopy
from multiprocessing import Pool
from typing import Tuple, List, Union, Optional

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \
    isfile

def label_or_region_to_key(label_or_region: Union[int, Tuple[int]]):
    """Serialize a label (int) or region (tuple of ints) into a string usable as a JSON key.

    Example: 3 -> '3', (1, 2) -> '(1, 2)'. Inverse of key_to_label_or_region.
    """
    key = f"{label_or_region}"
    return key


def key_to_label_or_region(key: str):
    """Parse a JSON key back into a label (int) or region (tuple of ints).

    Example: '3' -> 3, '(1, 2)' -> (1, 2), '(5,)' -> (5,). Inverse of label_or_region_to_key.
    """
    try:
        label = int(key)
    except ValueError:
        # not a plain integer: strip the tuple parentheses and parse the comma-separated entries
        inner = key.replace('(', '').replace(')', '')
        return tuple(int(token) for token in inner.split(',') if token)
    return label


def save_summary_json(results: dict, output_file: str):
    """Write a metrics summary to ``output_file`` as JSON.

    JSON object keys must be strings, but the summary is keyed by labels (ints) and regions (tuples).
    The keys of 'mean' and of every per-case 'metrics' dict are therefore stringified on a deep copy
    first, so ``results`` itself is not modified.
    """
    def _stringify_keys(metrics: dict) -> dict:
        return {label_or_region_to_key(k): v for k, v in metrics.items()}

    results_converted = deepcopy(results)
    results_converted['mean'] = _stringify_keys(results['mean'])
    for case_idx, case in enumerate(results["metric_per_case"]):
        results_converted["metric_per_case"][case_idx]['metrics'] = _stringify_keys(case['metrics'])
    # sort_keys=True will make foreground_mean the first entry and thus easy to spot
    save_json(results_converted, output_file, sort_keys=True)


def load_summary_json(filename: str):
    """Load a summary written by save_summary_json and turn its string keys back into labels/regions.

    The keys of 'mean' and of every per-case 'metrics' dict are converted with key_to_label_or_region.
    """
    results = load_json(filename)
    results['mean'] = {key_to_label_or_region(k): v for k, v in results['mean'].items()}
    for case in results["metric_per_case"]:
        case['metrics'] = {key_to_label_or_region(k): v for k, v in case['metrics'].items()}
    return results


def labels_to_list_of_regions(labels: List[int]):
    """Wrap each label into a single-element region tuple: [1, 2] -> [(1,), (2,)]."""
    # zip over a single iterable yields 1-tuples
    return list(zip(labels))


def region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray:
    """Return a boolean mask of the pixels that belong to a label or to a region.

    A scalar selects pixels equal to that label. A tuple (region) selects pixels equal to any of its labels.
    """
    if not np.isscalar(region_or_label):
        combined = np.zeros_like(segmentation, dtype=bool)
        for member_label in region_or_label:
            combined |= segmentation == member_label
        return combined
    return segmentation == region_or_label


def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None):
    """Count true/false positives and negatives between two boolean masks.

    Pixels where ``ignore_mask`` is True are excluded from every count.

    :return: (tp, fp, fn, tn) as numpy integer sums.
    """
    valid = np.ones_like(mask_ref, dtype=bool) if ignore_mask is None else ~ignore_mask
    ref_positive = mask_ref & valid
    ref_negative = ~mask_ref & valid
    tp = np.sum(ref_positive & mask_pred)
    fp = np.sum(ref_negative & mask_pred)
    fn = np.sum(ref_positive & ~mask_pred)
    tn = np.sum(ref_negative & ~mask_pred)
    return tp, fp, fn, tn


def compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter,
                    labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]],
                    ignore_label: int = None) -> dict:
    """Compute per-label/region Dice, IoU and confusion counts for one reference/prediction pair.

    :param reference_file: path of the ground-truth segmentation.
    :param prediction_file: path of the predicted segmentation.
    :param image_reader_writer: reader used to load both files via ``read_seg``.
    :param labels_or_regions: labels (ints) or regions (tuples of labels) to evaluate.
    :param ignore_label: reference pixels with this value are excluded from all counts (None = keep all).
    :return: {'reference_file', 'prediction_file', 'metrics': {label_or_region: {Dice, IoU, FP, TP, FN, TN,
             n_pred, n_ref}}}. Dice/IoU are NaN when the label is absent from both reference and prediction.
    :raises ValueError: if the two segmentations have different shapes.
    """
    seg_ref, _ref_properties = image_reader_writer.read_seg(reference_file)
    seg_pred, _pred_properties = image_reader_writer.read_seg(prediction_file)
    # The element-wise comparisons below would silently broadcast mismatching shapes (e.g. an RGB-colored
    # prediction with 3 channels against a single-channel reference) and produce inflated counts.
    if seg_ref.shape != seg_pred.shape:
        raise ValueError(f"Shape mismatch: reference {reference_file} has shape {seg_ref.shape} but "
                         f"prediction {prediction_file} has shape {seg_pred.shape}")

    ignore_mask = seg_ref == ignore_label if ignore_label is not None else None

    results = {}
    results['reference_file'] = reference_file
    results['prediction_file'] = prediction_file
    results['metrics'] = {}
    for r in labels_or_regions:
        mask_ref = region_or_label_to_mask(seg_ref, r)
        mask_pred = region_or_label_to_mask(seg_pred, r)
        tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask)
        metrics = {}
        if tp + fp + fn == 0:
            # label absent in both reference and prediction: overlap is undefined
            metrics['Dice'] = np.nan
            metrics['IoU'] = np.nan
        else:
            metrics['Dice'] = 2 * tp / (2 * tp + fp + fn)
            metrics['IoU'] = tp / (tp + fp + fn)
        metrics['FP'] = fp
        metrics['TP'] = tp
        metrics['FN'] = fn
        metrics['TN'] = tn
        metrics['n_pred'] = fp + tp
        metrics['n_ref'] = fn + tp
        results['metrics'][r] = metrics
    return results


def _pair_reference_and_prediction_files(folder_ref: str, folder_pred: str, file_ending: str,
                                         chill: bool) -> List[Tuple[str, str]]:
    """Return (reference_path, prediction_path) pairs for the ``file_ending`` files of the two folders.

    If every prediction file name also exists in ``folder_ref``, files are paired by name, which stays
    correct when some predictions are missing. Otherwise the folders use different naming schemes, and
    files are paired by position in the name-sorted listings (batchgenerators' subfiles sorts by
    default). That is only valid when both folders hold the same number of files.
    """
    names_pred = subfiles(folder_pred, suffix=file_ending, join=False)
    names_ref = subfiles(folder_ref, suffix=file_ending, join=False)
    if not chill:
        present = [isfile(join(folder_pred, i)) for i in names_ref]
        assert all(present), "Not all files in folder_ref exist in folder_pred"

    ref_name_set = set(names_ref)
    if names_pred and all(name in ref_name_set for name in names_pred):
        return [(join(folder_ref, name), join(folder_pred, name)) for name in names_pred]

    # Positional pairing: zip() would silently drop the surplus files of the longer listing and shift
    # every following pair, so refuse to continue when the counts differ.
    if len(names_ref) != len(names_pred):
        raise ValueError(f"Cannot pair files by position: {len(names_ref)} reference files in {folder_ref} "
                         f"but {len(names_pred)} prediction files in {folder_pred}")
    return [(join(folder_ref, r), join(folder_pred, p)) for r, p in zip(names_ref, names_pred)]


def compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: str,
                              image_reader_writer: BaseReaderWriter,
                              file_ending: str,
                              regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]],
                              ignore_label: int = None,
                              num_processes: int = 6,
                              chill: bool = True) -> dict:
    """Evaluate every prediction in ``folder_pred`` against its reference in ``folder_ref``.

    :param folder_ref: folder with the ground-truth segmentations.
    :param folder_pred: folder with the predicted segmentations.
    :param output_file: JSON summary to write (must end with .json); None to skip writing.
    :param image_reader_writer: reader used to load both segmentations.
    :param file_ending: only files with this suffix are considered.
    :param regions_or_labels: labels (ints) or regions (tuples of labels) to evaluate.
    :param ignore_label: reference pixels with this value are excluded from all counts.
    :param num_processes: currently unused; cases are evaluated sequentially.
    :param chill: if False, assert that every reference file name also exists in ``folder_pred``.
    :return: {'metric_per_case': [...], 'mean': {label: {metric: value}}, 'foreground_mean': {metric: value}}
             where 'foreground_mean' averages the per-label means over all labels except 0.
    :raises ValueError: if there is nothing to evaluate or the files cannot be paired.
    """
    if output_file is not None:
        assert output_file.endswith('.json'), 'output_file should end with .json'
    if len(regions_or_labels) == 0:
        raise ValueError('regions_or_labels must contain at least one label or region')

    print('empezando a leer archivos')
    file_pairs = _pair_reference_and_prediction_files(folder_ref, folder_pred, file_ending, chill)
    print('archivos leidos')
    if len(file_pairs) == 0:
        raise ValueError(f"No '{file_ending}' files found in {folder_pred}")

    print('Calculando métricas')
    results = []
    for case_number, (ref, pred) in enumerate(file_pairs, start=1):
        results.append(compute_metrics(ref, pred, image_reader_writer, regions_or_labels, ignore_label))
        print(case_number)

    # mean of every metric per label/region; nanmean skips cases where Dice/IoU are NaN
    # (label absent in both reference and prediction)
    metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys())
    print('Calculando medias')
    means = {r: {m: np.nanmean([case['metrics'][r][m] for case in results]) for m in metric_list}
             for r in regions_or_labels}

    # foreground mean: average of the per-label means, skipping label 0
    print('Calculando foreground')
    foreground_keys = [k for k in means if not (k == 0 or k == '0')]
    foreground_mean = {m: np.mean([means[k][m] for k in foreground_keys]) for m in metric_list}

    for case in results:
        recursive_fix_for_json_export(case)
    recursive_fix_for_json_export(means)
    recursive_fix_for_json_export(foreground_mean)
    result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean}
    if output_file is not None:
        save_summary_json(result, output_file)
    return result

## Cálculo

In [None]:
# Evaluate the Uformer predictions on the CholecSeg8k test split (class ids 0..12).
folder_ref = '/content/drive/MyDrive/TFG/Uformer/V2/dataV3/CholecSeg8k/test/y'
folder_pred = '/content/drive/MyDrive/TFG/Uformer/V2/logs/CholecSeg8k/Uformer_T_/results_no_color/png'
output_file = '/content/drive/MyDrive/TFG/Resultados/metrics/uformer_summary.json'
image_reader_writer = NaturalImage2DIO()
file_ending = '.png'
labels = list(range(13))
ignore_label = None
num_processes = 6
compute_metrics_on_folder(
    folder_ref=folder_ref,
    folder_pred=folder_pred,
    output_file=output_file,
    image_reader_writer=image_reader_writer,
    file_ending=file_ending,
    regions_or_labels=labels,
    ignore_label=ignore_label,
    num_processes=num_processes,
)

In [None]:
# Evaluate the fine-tuned (RepRec) predictions on its test split (class ids 0..12).
folder_ref = '/content/drive/MyDrive/TFG/Parte2/target_data/test/y'
folder_pred = '/content/drive/MyDrive/TFG/Parte2/FineTune/logs/CholecSeg8k/results_no_color/png'
output_file = '/content/drive/MyDrive/TFG/Resultados/metrics/reprec_summary.json'
image_reader_writer = NaturalImage2DIO()
file_ending = '.png'
labels = list(range(13))
ignore_label = None
num_processes = 6
compute_metrics_on_folder(
    folder_ref=folder_ref,
    folder_pred=folder_pred,
    output_file=output_file,
    image_reader_writer=image_reader_writer,
    file_ending=file_ending,
    regions_or_labels=labels,
    ignore_label=ignore_label,
    num_processes=num_processes,
)

# Matriz de confusión

In [None]:
from sklearn.metrics import confusion_matrix

def load_images_from_folder(folder, file_ending=None):
    """Read every image file in ``folder`` in sorted file-name order.

    :param folder: directory to read.
    :param file_ending: optional suffix filter (e.g. '.png'); with None every regular file is read.
    :return: list of arrays as returned by ``skimage.io.imread``.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        # Skip sub-directories (e.g. '.ipynb_checkpoints'), which imread cannot open.
        if not os.path.isfile(path):
            continue
        if file_ending is not None and not filename.endswith(file_ending):
            continue
        # skimage's imread raises on unreadable files, so no None check is needed here.
        images.append(io.imread(path))
    return images

def flatten_images(images):
    """Flatten every image and concatenate them into one 1-D array, keeping image and pixel order."""
    return np.concatenate(tuple(image.flatten() for image in images))

# Folders with the predicted and the reference label maps (single-channel PNGs with class ids).
# TODO: these are placeholders; point them at real folders before running, otherwise this cell fails.
predictions_folder = 'path/to/predictions'
labels_folder = 'path/to/labels'

# Class ids to report (the same 13 ids evaluated in the metrics cells). Passing them to confusion_matrix
# fixes the row/column order and the matrix size even if a class never appears. Pixels with ids outside
# this list are ignored by confusion_matrix.
class_ids = list(range(13))

# Load the prediction and label images (sorted by file name, paired by position)
predictions = load_images_from_folder(predictions_folder)
labels = load_images_from_folder(labels_folder)

# Both folders must contain the same number of images
assert len(predictions) == len(labels), "El número de imágenes de predicciones y etiquetas debe ser igual."
# Pixels are compared position by position after flattening, so each prediction must have exactly the
# shape of its label map; otherwise pixels would silently be misaligned.
for idx, (pred_img, label_img) in enumerate(zip(predictions, labels)):
    assert pred_img.shape == label_img.shape, \
        f"Image pair {idx}: prediction shape {pred_img.shape} != label shape {label_img.shape}"

# Flatten all images into 1-D pixel vectors
predictions_flat = flatten_images(predictions)
labels_flat = flatten_images(labels)

# Rows = true class, columns = predicted class
conf_matrix = confusion_matrix(labels_flat, predictions_flat, labels=class_ids)

print("Matriz de confusión:")
print(conf_matrix)



# Segmentaciones generadas

In [None]:
def colorear_seg_maps(seg_maps, palette=None):
    """Convert a 2D label map into an RGB image with one fixed color per class id.

    :param seg_maps: 2D array of class ids (any height/width).
    :param palette: optional sequence of RGB triples indexed by class id; defaults to the 13-color palette below.
    :return: float64 array of shape (*seg_maps.shape, 3). Pixels whose id has no palette entry keep
             their id as a gray value (id, id, id).
    """
    if palette is None:
        palette = (
            (127, 127, 127),  # 0
            (210, 140, 140),  # 1
            (255, 114, 114),  # 2
            (231, 70, 156),   # 3
            (186, 183, 75),   # 4
            (170, 255, 0),    # 5
            (255, 85, 0),     # 6
            (255, 0, 0),      # 7
            (255, 255, 0),    # 8
            (169, 255, 184),  # 9
            (255, 160, 165),  # 10
            (0, 50, 128),     # 11
            (111, 74, 0),     # 12
        )
    seg_maps = np.asarray(seg_maps)
    # Start from a gray image (id, id, id) so ids without a palette entry stay visible.
    colored = np.repeat(seg_maps[..., None], 3, axis=-1).astype(np.float64)
    # Masks are computed on the untouched label map, so an assigned color can never be re-matched as an id.
    for class_id, color in enumerate(palette):
        colored[seg_maps == class_id] = color
    return colored

In [None]:
num = 7720

directorios = {'input': '/content/drive/MyDrive/TFG/V2/nnUNet_raw/Dataset212_CholecSeg8kV2/imagesTs',
               'groundtruth': '/content/drive/MyDrive/TFG/V2/evaluationV2/labelsTs_colored',
               'nnunet': '/content/drive/MyDrive/TFG/V2/evaluationV2/predictionPP_Colored',
               'uformer': '/content/drive/MyDrive/TFG/Uformer/V2/logs/CholecSeg8k/Uformer_T_/results/png',
               'reprec': '/content/drive/MyDrive/TFG/Parte2/FineTune/logs/CholecSeg8k/results/png',
               }

# File name of the selected case in each folder; ids are zero-padded to 4 digits as in the
# multi-example cell.
nombres_imagenes = {'input': f'CS_{num:04d}_0000.png',
                    'groundtruth': f'CS_{num:04d}.png',
                    'nnunet': f'CS_{num:04d}.png',
                    'uformer': f'test_prediction-{num:04d}.png',
                    'reprec': f'test_prediction-{num:04d}.png',
                    }


fig, axs = plt.subplots(1, 5, figsize=(15, 5))

for i, (tipo, directorio) in enumerate(directorios.items()):
    imagen_path = os.path.join(directorio, nombres_imagenes[tipo])
    imagen = io.imread(imagen_path)
    if tipo == 'groundtruth':
        # NOTE(review): this folder is named *_colored, yet the image is colored again here;
        # colorear_seg_maps expects a 2D label map, so confirm these files are single-channel.
        imagen = colorear_seg_maps(imagen).astype(np.uint8)
        # nearest-neighbour interpolation keeps the palette colors exact while resizing to 854x480
        imagen = cv2.resize(imagen, (854, 480), interpolation=cv2.INTER_NEAREST)
    axs[i].imshow(imagen)
    axs[i].set_title(tipo.capitalize())
    axs[i].axis('off')

plt.tight_layout()
plt.show()

In [None]:
import random
import os
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage import io


directorios = {
    'input': '/content/drive/MyDrive/TFG/V2/nnUNet_raw/Dataset212_CholecSeg8kV2/imagesTs',
    'groundtruth': '/content/drive/MyDrive/TFG/V2/evaluationV2/labelsTs_colored',
    'nnunet': '/content/drive/MyDrive/TFG/V2/evaluationV2/predictionPP_Colored',
    'uformer': '/content/drive/MyDrive/TFG/Uformer/V2/logs/CholecSeg8k/Uformer_T_/results/png',
    'reprec': '/content/drive/MyDrive/TFG/Parte2/FineTune/logs/CholecSeg8k/results/png',
}


n_examples = 10


input_dir = Path(directorios['input'])
# Path.glob yields files in file-system order, which is not guaranteed to be stable. Sort the case ids
# so the seeded random.sample below picks the same cases on every run.
all_files = sorted(input_dir.glob('CS_*.png'))
all_nums = sorted(int(f.stem.split('_')[1]) for f in all_files)


random.seed(75)
selected_nums = random.sample(all_nums, n_examples)


# Each panel keeps the 854x480 aspect ratio of the frames.
fig, axs = plt.subplots(n_examples, 5, figsize=((854/225) * 5, (480/225) * n_examples))

titles = ['Input', 'Groundtruth', 'nnUnet', 'Uformer', 'RepRec']

for col, title in enumerate(titles):
    axs[0, col].set_title(title, fontsize=20)
    axs[0, col].axis('off')

for row, num in enumerate(selected_nums):
    nombres_imagenes = {
        'input': f'CS_{num:04d}_0000.png',
        'groundtruth': f'CS_{num:04d}.png',
        'nnunet': f'CS_{num:04d}.png',
        'uformer': f'test_prediction-{num:04d}.png',
        'reprec': f'test_prediction-{num:04d}.png',
    }

    for col, (tipo, directorio) in enumerate(directorios.items()):
        imagen_path = os.path.join(directorio, nombres_imagenes[tipo])
        imagen = io.imread(imagen_path)
        axs[row, col].imshow(imagen, aspect='auto')
        axs[row, col].axis('off')

# Save the figure as a PNG
output_path = '/content/drive/MyDrive/TFG/Resultados/segmentation_examples/10_examples_6.png'
plt.tight_layout()
plt.savefig(output_path, format='png', bbox_inches='tight', pad_inches=0)

# Show the figure, then release it
plt.show()
plt.close(fig)