In [None]:
import sys
sys.path.append('../src')

In [None]:
import sys
sys.path.append('../src')

import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import seaborn as sns

from utils.config import *
from data.preprocessing import create_segmentation_mask

%matplotlib inline
sns.set_style('whitegrid')

"""
## 1. Load and Visualize Images
"""

In [None]:
# Collect every JPEG/PNG in the raw-data directory. Directory globbing gives no
# ordering guarantee, so sort the paths to keep the preview grid (and any later
# image_files[0] lookup) stable across runs and machines.
image_files = sorted(
    list(RAW_DATA_DIR.glob('*.jpg')) + list(RAW_DATA_DIR.glob('*.png'))
)
print(f"Found {len(image_files)} images")

# Display up to eight sample images in a fixed 2x4 grid
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

# Hide every panel's frame/ticks up front so that slots without an image
# (fewer than 8 files, or a file that fails to decode) stay blank instead of
# rendering as empty framed axes.
for ax in axes:
    ax.axis('off')

# zip() stops at the shorter sequence, so at most len(axes) == 8 images are shown
for ax, img_path in zip(axes, image_files):
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread reports an unreadable file by returning None, not raising
        ax.set_title(f"{img_path.name} (unreadable)")
        continue
    # OpenCV decodes to BGR; matplotlib expects RGB
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax.set_title(img_path.name)

plt.tight_layout()
plt.show()

"""
## 2. Analyze Image Statistics
"""

In [None]:
# Gather per-image dimensions for every file found earlier in `image_files`.
heights, widths, channels = [], [], []
unreadable_files = []

for img_path in image_files:
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread returns None (no exception) when a file cannot be decoded;
        # record it and move on rather than crashing on `.shape`.
        unreadable_files.append(img_path)
        continue
    # NOTE(review): with imread's default flag the array is decoded as
    # 3-channel BGR, so `c` describes the decoded array and may not match the
    # file's on-disk channel count -- confirm this is the statistic wanted.
    h, w, c = img.shape
    heights.append(h)
    widths.append(w)
    channels.append(c)

if unreadable_files:
    print(f"Skipped {len(unreadable_files)} unreadable file(s): "
          f"{[p.name for p in unreadable_files]}")

if heights:
    print("Image dimensions:")
    print(f"  Height: min={min(heights)}, max={max(heights)}, mean={np.mean(heights):.1f}")
    print(f"  Width: min={min(widths)}, max={max(widths)}, mean={np.mean(widths):.1f}")
    print(f"  Channels: {set(channels)}")
else:
    # min()/max() raise ValueError on an empty list, so report instead
    print("No readable images found; nothing to summarize.")

"""
## 3. Visualize Segmentation Masks
"""

In [None]:
# Build segmentation masks for one sample image and show them next to it.

# Fail early with a readable message instead of a bare IndexError / OpenCV
# assertion when there is nothing to load or the file cannot be decoded.
if not image_files:
    raise FileNotFoundError(f"No .jpg/.png images found in {RAW_DATA_DIR}")

test_image_path = image_files[0]
image = cv2.imread(str(test_image_path))
if image is None:
    # cv2.imread returns None rather than raising for unreadable files
    raise ValueError(f"Could not decode image: {test_image_path}")

# OpenCV decodes to BGR; convert once so plots and later cells see RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# cv2.resize takes the target size as (width, height)
image_resized = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))

# Create masks using different methods.
# NOTE(review): n_clusters is passed to both calls; whether the 'watershed'
# method actually uses it depends on create_segmentation_mask -- confirm.
mask_kmeans = create_segmentation_mask(image_resized, method='kmeans', n_clusters=5)
mask_watershed = create_segmentation_mask(image_resized, method='watershed', n_clusters=5)

# Visualize: (title, array, colormap) per panel; cmap=None is imshow's default
panels = [
    ('Original Image', image_resized, None),
    ('K-Means Segmentation', mask_kmeans, 'tab10'),
    ('Watershed Segmentation', mask_watershed, 'tab10'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, (title, data, cmap) in zip(axes, panels):
    ax.imshow(data, cmap=cmap)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.axis('off')

plt.tight_layout()
plt.show()

"""
## 4. Color Distribution Analysis
"""

In [None]:
# Per-channel intensity histograms of `image_resized`, which an earlier cell
# converted to RGB -- so plane 0/1/2 is red/green/blue respectively.
colors = ['red', 'green', 'blue']
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for plane, channel_name in enumerate(colors):
    panel = axes[plane]
    # Flatten the 2-D plane into a 1-D sequence of pixel values
    pixel_values = image_resized[:, :, plane].ravel()

    panel.hist(pixel_values, bins=50, color=channel_name, alpha=0.7)
    panel.set_title(
        f'{channel_name.capitalize()} Channel Distribution',
        fontsize=12,
        fontweight='bold',
    )
    panel.set_xlabel('Pixel Value')
    panel.set_ylabel('Frequency')
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

"""
## 5. Segment Distribution
"""

In [None]:
# Count how many pixels the K-Means mask assigns to each distinct label
unique, counts = np.unique(mask_kmeans, return_counts=True)

# Bar chart of pixel counts per label, drawn through the explicit Axes API
seg_fig, seg_ax = plt.subplots(figsize=(10, 6))
seg_ax.bar(unique, counts, color='skyblue', edgecolor='navy', linewidth=2)
seg_ax.set_xlabel('Segment ID', fontsize=12)
seg_ax.set_ylabel('Number of Pixels', fontsize=12)
seg_ax.set_title('Pixel Distribution Across Segments', fontsize=14, fontweight='bold')
seg_ax.grid(axis='y', alpha=0.3)
plt.show()

# Text summary: each label's pixel count and its share of the whole mask
total_pixels = mask_kmeans.size
print("\nSegment Statistics:")
for seg_id, count in zip(unique, counts):
    percentage = (count / total_pixels) * 100
    print(f"  Segment {seg_id}: {count} pixels ({percentage:.2f}%)")