# **Import Libraries**

In [None]:
import json
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import cv2
from tqdm import tqdm

# Apply seaborn's "whitegrid" theme globally so the plots drawn later in this
# notebook share the same look.
sns.set_theme(style="whitegrid")

# **Configuration**

In [None]:

# The dataset folder is expected to sit in the current working directory.
current_dir = os.getcwd()
DATASET_ROOT = os.path.join(current_dir, "ZJU-Leaper")

# Group-definition file that this notebook analyses.
GROUP_JSON_FILE = "group1.json"

# Image and label folders live directly under the dataset root.
SOURCE_IMAGES_DIR, SOURCE_LABELS_DIR = (
    os.path.join(DATASET_ROOT, subdir) for subdir in ("images", "Label")
)

# Full path of the group file: <root>/ImageSets/Groups/<GROUP_JSON_FILE>.
json_path = os.path.join(
    os.path.join(DATASET_ROOT, 'ImageSets', 'Groups'),
    GROUP_JSON_FILE,
)

print("Configuration set.")

# **Load Data & Get Statistics**

In [None]:
# Load the group definition file. The encoding is pinned to UTF-8 so the
# result does not depend on the platform's locale default (which open() uses
# when no encoding is given).
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Pull out the per-class, per-split entries. Each entry is later used as an
# image file stem (f"{stem}.jpg"). Note that the group file's 'test' split is
# what this notebook refers to as the validation set.
train_normal_stems = data['normal']['train']
train_defect_stems = data['defect']['train']
val_normal_stems = data['normal']['test']
val_defect_stems = data['defect']['test']

def _share_of_total(part, total):
    """Return part / total as a fraction, or 0.0 when total is zero."""
    return part / total if total else 0.0


# Number of images in each (split, class) cell.
counts = {
    'train_normal': len(train_normal_stems),
    'train_defect': len(train_defect_stems),
    'val_normal': len(val_normal_stems),
    'val_defect': len(val_defect_stems),
}

# Marginal totals derived from the four cells above.
counts['total_train'] = counts['train_normal'] + counts['train_defect']
counts['total_val'] = counts['val_normal'] + counts['val_defect']
counts['total_normal'] = counts['train_normal'] + counts['val_normal']
counts['total_defect'] = counts['train_defect'] + counts['val_defect']
counts['total_images'] = counts['total_train'] + counts['total_val']

# Print the summary statistics table.
print("--- Dataset Summary Statistics (Group 1) ---")
print(f"Total Images:   {counts['total_images']}")
print(f"  Training Set:   {counts['total_train']} images")
print(f"  Validation Set: {counts['total_val']} images")
print("\n--- Class Distribution ---")
# The shares go through _share_of_total so an empty group file prints 0.0%
# instead of raising ZeroDivisionError.
print(f"Total Normal:   {counts['total_normal']} images ({_share_of_total(counts['total_normal'], counts['total_images']):.1%})")
print(f"Total Defect:   {counts['total_defect']} images ({_share_of_total(counts['total_defect'], counts['total_images']):.1%})")

# Overall class totals, consumed by the class-distribution plot.
df_class = pd.DataFrame({
    'Class': ['Normal', 'Defect'],
    'Count': [counts['total_normal'], counts['total_defect']]
})

# Long-format table (one row per split/class pair) for the grouped bar plot.
df_split = pd.DataFrame([
    {'Set': 'Train', 'Class': 'Normal', 'Count': counts['train_normal']},
    {'Set': 'Train', 'Class': 'Defect', 'Count': counts['train_defect']},
    {'Set': 'Validation', 'Class': 'Normal', 'Count': counts['val_normal']},
    {'Set': 'Validation', 'Class': 'Defect', 'Count': counts['val_defect']},
])

# **Plot 1 - Class Distribution (Bar Chart)**

In [None]:
# One bar per class showing the overall image count.
plt.figure(figsize=(8, 6))
ax = sns.barplot(data=df_class, x='Class', y='Count')
ax.set_title('Overall Class Distribution (Group 1)', fontsize=16)
ax.set_xlabel('Sample Type', fontsize=12)
ax.set_ylabel('Number of Images', fontsize=12)

# Write each bar's count just above its top edge (9 points up from the
# top-centre of the bar).
for bar in ax.patches:
    bar_height = bar.get_height()
    bar_center_x = bar.get_x() + bar.get_width() / 2.
    ax.annotate(
        f'{int(bar_height)}',
        (bar_center_x, bar_height),
        ha='center',
        va='center',
        xytext=(0, 9),
        textcoords='offset points',
    )
plt.show()

# **Plot 2 - Train/Validation Split**

In [None]:
# Grouped bars: one group per dataset split, one bar per class.
plt.figure(figsize=(10, 7))
class_colors = {'Normal': 'g', 'Defect': 'r'}
ax = sns.barplot(
    data=df_split,
    x='Set',
    y='Count',
    hue='Class',
    palette=class_colors,
)
ax.set_title('Train/Validation Split by Class', fontsize=16)
ax.set_xlabel('Dataset Split', fontsize=12)
ax.set_ylabel('Number of Images', fontsize=12)
ax.legend(title='Class')
plt.show()

# **Plot 3 - Visual EDA (Sample Images)**

In [None]:
import random


def _load_rgb_image(path):
    """Read the image at *path* with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. cv2.imread signals
            failure by returning None rather than raising, which would
            otherwise surface as an opaque error inside cv2.cvtColor.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _draw_label_boxes(img, label_path):
    """Draw every box listed in *label_path* onto *img* (modified in place).

    Each non-blank line is parsed as five whitespace-separated numbers:
    ``class_id x_center y_center width height``, where the last four are
    multiplied by the image width/height to obtain pixel coordinates.
    Blank lines are skipped, so an empty label file simply draws nothing.
    """
    img_h, img_w = img.shape[:2]
    with open(label_path, 'r') as f:
        for raw_line in f:
            fields = raw_line.split()
            if not fields:
                continue
            # The class id is parsed (to validate the line) but not needed
            # for drawing.
            _, x_c, y_c, w_n, h_n = (float(value) for value in fields)

            xmin = int((x_c - w_n / 2) * img_w)
            xmax = int((x_c + w_n / 2) * img_w)
            ymin = int((y_c - h_n / 2) * img_h)
            ymax = int((y_c + h_n / 2) * img_h)

            # The image is already RGB here, so (0, 255, 0) is green.
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)


# Pick two random training images per class to display.
sample_defect_stems = random.sample(train_defect_stems, 2)
sample_normal_stems = random.sample(train_normal_stems, 2)

fig, axes = plt.subplots(2, 2, figsize=(12, 12))
fig.suptitle('Exploratory Data Analysis: Image Samples', fontsize=20)

# Top row: defect samples with their labelled boxes drawn on the image.
for i, stem in enumerate(sample_defect_stems):
    img = _load_rgb_image(os.path.join(SOURCE_IMAGES_DIR, f"{stem}.jpg"))
    _draw_label_boxes(img, os.path.join(SOURCE_LABELS_DIR, f"{stem}.txt"))

    axes[0, i].imshow(img)
    axes[0, i].set_title(f'Defect Sample: {stem}.jpg', color='red')
    axes[0, i].axis('off')

# Bottom row: normal samples, shown as-is.
for i, stem in enumerate(sample_normal_stems):
    img = _load_rgb_image(os.path.join(SOURCE_IMAGES_DIR, f"{stem}.jpg"))

    axes[1, i].imshow(img)
    axes[1, i].set_title(f'Normal Sample: {stem}.jpg', color='green')
    axes[1, i].axis('off')

# Leave room at the top of the figure for the suptitle.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()