# 📸 EDA on Train / Test / Pretrain Images

<img src='../assets/stranger-sections-2.png'>

## 📚 Libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import random
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import plotly.express as px
import numpy as np
import cv2
from glob import glob
from tqdm.notebook import tqdm
import torch

## 🔬 Analysis

In [None]:
# Recursive glob pattern matching every .JPG file below the raw data folder.
# NOTE(review): the extension is upper-case, so on a case-sensitive filesystem
# lower-case `.jpg` files are presumably not matched — confirm this is intended.
folder_path = "../data/raw/**/*.JPG"

# Collect the (height, width) of every image. PIL reads the header lazily, so
# `size` is available without decoding the pixels, and the context manager
# closes each file handle as soon as the image has been inspected.
shapes = []
for image_path in tqdm(glob(folder_path, recursive=True)):
    with Image.open(image_path) as image:
        width, height = image.size
    # Same (rows, columns) ordering as the shape of the decoded numpy array
    shapes.append((height, width))

# Count how many images share each distinct shape and plot the counts
values, counts = np.unique(shapes, return_counts=True, axis=0)
values = [str(shape) for shape in values]
px.bar(x=values, y=counts)

## 🎞️ Data visualisation

In [None]:
# Folders holding the training JPG images and their .npy label masks
image_folder_path = "../data/raw/train/image"
label_folder_path = "../data/raw/train/label"
labels_name = ['Inertinite', 'Vitrinite', 'Liptinite']
# JPG files available in the image folder
jpg_files = [file for file in os.listdir(image_folder_path) if file.endswith(".JPG")]

jpg_file = random.choice(jpg_files)
# Images carrying several labels (uncomment one to inspect it instead):
# jpg_file = 'grqhu2.JPG'
# jpg_file = 'tya5k0.JPG'
# jpg_file = 'tpb83i.JPG'
# jpg_file = 'hsa12q.JPG'

# Load the selected image; the context manager releases the file handle
with Image.open(os.path.join(image_folder_path, jpg_file)) as pil_image:
    jpg_image = np.asarray(pil_image)

# Load the matching label mask (<name>_gt.npy)
npy_file = jpg_file.replace(".JPG", "_gt.npy")
npy_data = np.load(os.path.join(label_folder_path, npy_file))

# Label values found in the mask, shifted to 0-based indexes into labels_name.
# Value 0 is assumed to be background and is filtered out explicitly rather than
# by dropping the first unique value, which would discard a real class whenever
# a mask contains no background pixel.
present_values = np.unique(npy_data)
label_idxs = (present_values[present_values > 0] - 1).astype(int)
present_names = " and ".join(labels_name[label_idx] for label_idx in label_idxs)

# Three panels: raw image, label mask, and the mask overlaid on the image
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(jpg_image)
ax[0].set_title(f'{present_names} Image')
ax[0].axis('off')
ax[1].imshow(npy_data)
ax[1].set_title(f'{present_names} Label')
ax[1].axis('off')
ax[2].imshow(jpg_image)
ax[2].set_title(f'{present_names} Image with Label')
ax[2].axis('off')
ax[2].imshow(npy_data, alpha=0.5)

plt.show()


In [None]:
# Folder holding the pretraining JPG images
pretrain_folder_path = "../data/raw/pretrain"
# JPG files available in the folder
jpg_files = [file for file in os.listdir(pretrain_folder_path) if file.endswith(".jpg")]

k = 9
# Side of the square grid; derived once from k so the figure layout and the
# number of sampled images can never disagree.
grid_size = int(np.sqrt(k))

# random.sample raises ValueError when asked for more items than available,
# so never request more images than the folder contains.
jpg_files = random.sample(jpg_files, k=min(grid_size ** 2, len(jpg_files)))

# Load the sampled images, closing each file handle once it is converted
jpg_images = []
for jpg_file in jpg_files:
    with Image.open(os.path.join(pretrain_folder_path, jpg_file)) as pil_image:
        jpg_images.append(np.asarray(pil_image))

# squeeze=False keeps `axs` two-dimensional even for a 1x1 grid
fig, axs = plt.subplots(grid_size, grid_size, figsize=(15, 15), squeeze=False)
for cell in axs.flat:
    # Hide the frame everywhere so cells left without an image stay blank
    cell.axis('off')
for cell, jpg_file, jpg_image in zip(axs.flat, jpg_files, jpg_images):
    cell.imshow(jpg_image)
    cell.set_title(f'{jpg_file}')

plt.show()


## 🎨 Color Histogram

In [None]:
def get_image_path_by_label(
    image_folder="../data/raw/train/image",
    label_folder="../data/raw/train/label",
):
    """Group the training image paths by the label values found in their masks.

    For every ``<name>.JPG`` file in ``image_folder`` the mask
    ``<name>_gt.npy`` is loaded from ``label_folder``. The name of every image
    whose mask contains more than one non-zero label value is printed.

    Args:
        image_folder: Folder containing the ``.JPG`` images.
        label_folder: Folder containing the ``_gt.npy`` label masks.

    Returns:
        A tuple of three lists with the paths of the images whose mask
        contains label value 1, 2 and 3 respectively. An image whose mask
        contains several of these values appears in several lists.
    """
    image_paths_by_class = {1: [], 2: [], 3: []}

    for image_file in tqdm(os.listdir(image_folder)):
        # Only JPG images are considered
        if not image_file.endswith(".JPG"):
            continue

        image_path = os.path.join(image_folder, image_file)
        label_path = os.path.join(label_folder, image_file.replace(".JPG", "_gt.npy"))

        # Compute the set of label values once and reuse it for every check,
        # instead of rescanning the whole mask for each class.
        present_values = np.unique(np.load(label_path))
        # Value 0 is assumed to be background; dropping it explicitly keeps
        # the multi-label check correct for masks without background pixels.
        present_classes = present_values[present_values != 0]
        if len(present_classes) > 1:
            print(image_file)

        for class_id, paths in image_paths_by_class.items():
            if class_id in present_classes:
                paths.append(image_path)

    return image_paths_by_class[1], image_paths_by_class[2], image_paths_by_class[3]


def compute_pixel_distribution(list_image_path):
    """Count how often each exact RGB colour occurs across a set of images.

    Args:
        list_image_path: Iterable of image file paths readable by PIL.

    Returns:
        An int32 array of shape (256, 256, 256) where entry ``[r, g, b]`` is
        the number of pixels with that colour over all the images.
    """
    # One counter per possible (R, G, B) triple.
    # NOTE(review): int32 counters would wrap if a single colour exceeded
    # ~2.1e9 pixels — confirm this cannot happen on the largest image set.
    pixel_distribution = np.zeros((256, 256, 256), dtype=np.int32)

    for image_path in tqdm(list_image_path):
        with Image.open(image_path) as image:
            # Force three channels: reshaping a greyscale, RGBA or CMYK image
            # to (-1, 3) would mix channels or fail outright.
            pixels = np.asarray(image.convert("RGB")).reshape(-1, 3)

        # Occurrences of each distinct colour in this image; the rows of
        # `unique` are distinct, so the fancy-indexed += touches each counter once.
        unique, counts = np.unique(pixels, axis=0, return_counts=True)
        pixel_distribution[unique[:, 0], unique[:, 1], unique[:, 2]] += counts

    return pixel_distribution


# The function returns the paths for label values 1, 2 and 3 in that order;
# they are bound here to inertinite, vitrinite and liptinite respectively.
inertinite_image_path, vitrinite_image_path, liptinite_image_path = get_image_path_by_label()

In [None]:
def reduce_rgb_cube_size(rgb_values, reduction=16):
    """Downsample a 3D colour-count cube by summing non-overlapping blocks.

    Each output cell holds the sum of a ``reduction``-sided cubic block of the
    input. Trailing entries that do not fill a complete block are ignored.

    Args:
        rgb_values: 3D array of counts (e.g. 256 x 256 x 256).
        reduction: Side length of the blocks that are summed together.

    Returns:
        A float32 array with ``dim // reduction`` cells along each axis. The
        result always stays three-dimensional, even when an axis has one cell.

    Raises:
        ValueError: If ``rgb_values`` is not 3D, or ``reduction`` is smaller
            than 1 or larger than one of the cube's dimensions.
    """
    rgb_values = np.asarray(rgb_values)
    if rgb_values.ndim != 3:
        raise ValueError(f"expected a 3D array, got {rgb_values.ndim} dimension(s)")
    if reduction < 1 or reduction > min(rgb_values.shape):
        raise ValueError(
            f"reduction must be between 1 and {min(rgb_values.shape)}, got {reduction}"
        )

    n_r, n_g, n_b = (dim // reduction for dim in rgb_values.shape)

    # Keep only the part of the cube that splits into complete blocks
    cropped = rgb_values[: n_r * reduction, : n_g * reduction, : n_b * reduction]

    # Give every block its own axis, then collapse the three intra-block axes
    blocks = cropped.reshape(n_r, reduction, n_g, reduction, n_b, reduction)

    # Accumulate in int64 so large counts are summed exactly before the
    # final conversion to float32.
    return blocks.sum(axis=(1, 3, 5), dtype=np.int64).astype(np.float32)

def display_rgb_distribution(rgb_values: np.ndarray, class_name, reduction=16, nb_bar=10):
    """Show a bar chart of the most frequent colour bins of an RGB count cube.

    The cube is first downsampled with ``reduce_rgb_cube_size``; the
    ``nb_bar`` bins with the highest counts are then drawn in decreasing
    order, each bar filled with the colour at the centre of its bin.

    Args:
        rgb_values: 3D array of per-colour pixel counts.
        class_name: Name inserted in the figure title.
        reduction: Block side length passed to ``reduce_rgb_cube_size``.
        nb_bar: Maximum number of bars to display.
    """
    rgb_values = reduce_rgb_cube_size(rgb_values, reduction)
    counts = rgb_values.flatten()

    # argpartition fails when asked for more elements than exist
    nb_bar = min(nb_bar, counts.size)

    # Select the nb_bar largest bins, then order them by decreasing count
    top_indexes = np.argpartition(counts, -nb_bar)[-nb_bar:]
    top_indexes = top_indexes[np.argsort(-counts[top_indexes])]

    def index_to_rgb(idx):
        # Channel value at the centre of the bin
        return int(idx) * reduction + reduction // 2

    # Build colour strings only for the selected bins: convert each flat index
    # back to its (r, g, b) bin coordinates in the reduced cube.
    r_bins, g_bins, b_bins = np.unravel_index(top_indexes, rgb_values.shape)
    rgb_colors = [
        f'rgb({index_to_rgb(r)}, {index_to_rgb(g)}, {index_to_rgb(b)})'
        for r, g, b in zip(r_bins, g_bins, b_bins)
    ]

    # Build the Plotly figure
    fig = go.Figure(data=go.Bar(
        x=rgb_colors,
        y=counts[top_indexes],
        marker=dict(
            color=rgb_colors
        )
    ))

    # Axis titles are set on the 2D axes: `scene` only applies to 3D traces
    # and has no effect on a bar chart.
    fig.update_layout(
        xaxis_title='RGB bin centre',
        yaxis_title='Pixel count',
        title=f'Color Histogram for {class_name}',
    )

    # Display the figure
    fig.show()

In [None]:
# Colour histogram over the training images whose mask contains liptinite.
# Every pixel of those images is counted, not only the liptinite-labelled ones.
liptinite_distribution = compute_pixel_distribution(liptinite_image_path)
display_rgb_distribution(liptinite_distribution, 'liptinite', reduction=32, nb_bar=15)

In [None]:
# Colour histogram over the training images whose mask contains inertinite.
# Every pixel of those images is counted, not only the inertinite-labelled ones.
inertinite_distribution = compute_pixel_distribution(inertinite_image_path)
display_rgb_distribution(inertinite_distribution, 'inertinite', reduction=32, nb_bar=15)

In [None]:
# Colour histogram over the training images whose mask contains vitrinite.
# Every pixel of those images is counted, not only the vitrinite-labelled ones.
vitrinite_distribution = compute_pixel_distribution(vitrinite_image_path)
display_rgb_distribution(vitrinite_distribution, 'vitrinite', reduction=32, nb_bar=15)

In [None]:
# Colour histogram of the whole training set, computed directly from the image
# folder. Summing the three per-class distributions would count an image once
# per class it contains and would leave out images that contain no class.
train_distribution = compute_pixel_distribution(glob('../data/raw/train/image/*.JPG'))
display_rgb_distribution(train_distribution, 'Train data', reduction=32, nb_bar=15)

In [None]:
# Colour histogram over every .JPG image of the test folder
test_distribution = compute_pixel_distribution(glob('../data/raw/test/image/*.JPG'))
display_rgb_distribution(test_distribution, 'Test data', reduction=32, nb_bar=15)

In [None]:
# Colour histogram over every .jpg image of the pretrain folder
# (these files use a lower-case extension, unlike the train/test images)
pretrain_distribution = compute_pixel_distribution(glob('../data/raw/pretrain/*.jpg'))
display_rgb_distribution(pretrain_distribution, 'Pretrain data', reduction=32, nb_bar=15)

## 🎯 Heatmap

In [None]:
def get_heatmap_by_class(label_folder="../data/raw/train/label", shape=(1024, 1360)):
    """Count, per pixel position, how many label masks carry each class.

    Args:
        label_folder: Folder containing the ``.npy`` label masks.
        shape: (height, width) shared by all the masks.

    Returns:
        A tuple of three int32 arrays of shape ``shape``, for label values 1,
        2 and 3 respectively. Each cell holds the number of masks in which
        that pixel carries the corresponding label value.
    """
    # One accumulator per label value (1, 2, 3)
    heatmaps = [np.zeros(shape, np.int32) for _ in range(3)]

    for label_file in tqdm(os.listdir(label_folder)):
        # Only .npy label masks are considered
        if not label_file.endswith(".npy"):
            continue

        label_array = np.load(os.path.join(label_folder, label_file))

        for class_id, heatmap in enumerate(heatmaps, start=1):
            # Add 1 per matching pixel. Adding the label value itself would
            # inflate the maps of classes 2 and 3 by a factor of 2 and 3.
            heatmap += label_array == class_id

    return tuple(heatmaps)

def plot_heatmap(data, class_name):
    """Display a 2D array as a Plotly heatmap with image-like orientation.

    Args:
        data: 2D array to display; rows are drawn from top to bottom.
        class_name: Name inserted in the figure title.
    """
    # Height/width ratio of the data, used to keep the figure proportions
    aspect_ratio = data.shape[0] / data.shape[1]

    # Fix an arbitrary figure height and derive the width from the ratio
    fig_height = 600
    fig_width = int(fig_height / aspect_ratio)

    fig = go.Figure(data=go.Heatmap(z=data, hoverinfo='none'))
    fig.update_layout(
        title=f'{class_name} Heatmap',
        width=fig_width,
        height=fig_height,
        # Heatmap traces draw row 0 at the bottom by default; reversing the
        # y axis puts row 0 at the top, matching how the images are displayed.
        yaxis=dict(autorange='reversed'),
    )

    fig.show()


# The function returns the heatmaps for label values 1, 2 and 3 in that order;
# they are bound here to inertinite, vitrinite and liptinite respectively.
inertinite_heatmap, vitrinite_heatmap, liptinite_heatmap = get_heatmap_by_class()

In [None]:
# Where label value 1 (inertinite) occurs across the training masks
plot_heatmap(inertinite_heatmap, 'Inertinite')

In [None]:
# Where label value 2 (vitrinite) occurs across the training masks
plot_heatmap(vitrinite_heatmap, 'Vitrinite')

In [None]:
# Where label value 3 (liptinite) occurs across the training masks
plot_heatmap(liptinite_heatmap, 'Liptinite')