In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# [SETTING WARNINGS]
# Suppress every warning: `Warning` is the base class of all built-in warning
# categories, so this filter matches all of them (presumably to keep the
# notebook output free of library noise).
# NOTE(review): this also hides numeric RuntimeWarnings, e.g. a possible
# overflow when images are cast to float16 below -- consider narrowing the
# filter to the specific categories/modules that are actually noisy.
import warnings
warnings.simplefilter(action='ignore', category=Warning)

# import modules and components
from FEXT.commons.utils.dataloader.serializer import DataSerializer, get_images_path
from FEXT.commons.utils.process.splitting import DataSplit
from FEXT.commons.constants import CONFIG, IMG_DATA_PATH, RESULTS_PATH

In [None]:
serializer = DataSerializer(CONFIG)
# sample_size=None presumably disables subsampling, i.e. every image found
# under IMG_DATA_PATH is listed -- confirm against get_images_path
images_paths = get_images_path(IMG_DATA_PATH, CONFIG, sample_size=None)

# Read each image through the serializer and keep it in memory as a float16
# NumPy array (2 bytes per value), with a progress bar over the paths
images = []
for img_path in tqdm(images_paths):
    images.append(np.asarray(serializer.load_image(img_path), dtype=np.float16))

# Evaluate image dataset

---

### Pixel intensity

The pixel intensity distribution of the entire image dataset is evaluated by plotting a histogram of all pixel values, together with the mean pixel intensity computed across all images.

In [None]:
# np.concatenate([]) would raise an opaque ValueError; fail with a clearer one
if not images:
    raise ValueError('No images were loaded from IMG_DATA_PATH; nothing to plot')

# Pool the pixels of every image into one flat float16 array. ravel() returns a
# view where possible, avoiding the extra per-image copy that flatten() makes
# (concatenate copies the data once anyway).
pixel_intensities = np.concatenate([image.ravel() for image in tqdm(images)], dtype=np.float16)
# Accumulate the mean in float64 so summing a very large float16 array stays precise
mean_intensity = float(np.mean(pixel_intensities, dtype=np.float64))

fig, ax = plt.subplots(figsize=(16, 14))
ax.hist(pixel_intensities, bins='auto', alpha=0.7, color='blue', label='Dataset')
# Dashed vertical line at the dataset-wide mean pixel intensity
ax.axvline(mean_intensity, color='red', linestyle='dashed', linewidth=1.5,
           label=f'Mean = {mean_intensity:.2f}')
ax.set_title('Pixel Intensity Histogram', fontsize=16)
ax.set_xlabel('Pixel Intensity', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.legend()
fig.tight_layout()
plt.show()

## Compare train and validation datasets

---

In [None]:
splitter = DataSplit(images_paths, CONFIG)
train_data, validation_data = splitter.split_train_and_validation()
# Normalise falsy splits to empty lists *before* calling len(), so the sample
# counts and the loaders below handle an empty subset the same way.
# NOTE(review): assumes the splitter may return None/[] for an empty subset --
# DataSplit is not visible from this notebook.
train_data = train_data or []
validation_data = validation_data or []
print(f'Number of train samples: {len(train_data)}')
print(f'Number of validation samples: {len(validation_data)}')


def load_images_float16(paths):
    """Load each path with `serializer.load_image` as a float16 NumPy array.

    Returns an empty list, without showing a progress bar, when `paths` is empty.
    """
    if not paths:
        return []
    return [np.asarray(serializer.load_image(pt), dtype=np.float16) for pt in tqdm(paths)]


train_images = load_images_float16(train_data)
validation_images = load_images_float16(validation_data)
datasets = {'train': train_images, 'validation': validation_images}

### Pixel intensity

The pixel intensity distribution is now used to compare the train and validation datasets by plotting their histograms overlaid on the same axes, together with the mean pixel intensity of each dataset.

In [None]:
# Overlay one histogram per split, each with a dashed line at that split's mean
fig, ax = plt.subplots(figsize=(16, 14))
for idx, (name, image_set) in enumerate(datasets.items()):
    # A split can be empty (the loading cell yields [] for an empty split), and
    # np.concatenate([]) raises ValueError -- report and skip it instead
    if not image_set:
        print(f'Skipping {name} dataset: no images')
        continue
    # Fixed colour per split so its mean line matches its histogram
    color = f'C{idx}'
    # Per-split name, so the whole-dataset `pixel_intensities` from the earlier
    # cell is not overwritten by the last split processed here
    split_intensities = np.concatenate([image.ravel() for image in tqdm(image_set)], dtype=np.float16)
    # float64 accumulator keeps the mean of a large float16 array precise
    split_mean = float(np.mean(split_intensities, dtype=np.float64))
    # density=True scales each histogram to unit area, so splits of different
    # sizes are compared by distribution shape rather than by raw pixel counts
    ax.hist(split_intensities, bins='auto', alpha=0.5, density=True, color=color, label=name)
    ax.axvline(split_mean, color=color, linestyle='dashed', linewidth=1.5,
               label=f'{name} mean = {split_mean:.2f}')
ax.set_title('Pixel Intensity Histogram', fontsize=16)
ax.set_xlabel('Pixel Intensity', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.legend()
fig.tight_layout()
plt.show()