# Data exploration notebook

A quick overview of the land-cover training split (`dataset/train/images` and `dataset/train/masks`): file counts, label distribution, per-channel statistics, and a few visual examples.

In [None]:
from pathlib import Path
import random
import numpy as np
import matplotlib.pyplot as plt
import tifffile

# Resolve the repository root. `__file__` exists when this code runs as a script,
# but a Jupyter kernel does not define it, so fall back to the kernel's working
# directory (by default Jupyter starts the kernel in the notebook's own folder).
# NOTE(review): assumes the notebook lives one directory below the repo root — confirm.
try:
    notebook_dir = Path(__file__).resolve().parent
except NameError:
    notebook_dir = Path.cwd().resolve()
repo_root = notebook_dir.parent
ds_root = repo_root / "dataset"
images_dir = ds_root / "train" / "images"
masks_dir = ds_root / "train" / "masks"

# Raw mask labels 1..9 are remapped to contiguous 0-based class ids 0..8.
label_map = {lab: lab - 1 for lab in range(1, 10)}  # raw -> 0..8
num_classes = len(label_map)

## Basic info

In [None]:
image_files = sorted(images_dir.glob("*.tif"))
mask_files = sorted(masks_dir.glob("*.tif"))
print("#images", len(image_files))
print("#masks ", len(mask_files))
print("First image path:", image_files[0] if image_files else "None")

# A wrong dataset root shows up only as zero counts above, so say where we looked.
if not image_files:
    print(f"No .tif images found under {images_dir}")

# Later cells pair each image with the mask of the same filename; surface any
# file lacking a partner here instead of failing deep inside the plotting code.
image_names = {p.name for p in image_files}
mask_names = {p.name for p in mask_files}
images_without_mask = sorted(image_names - mask_names)
masks_without_image = sorted(mask_names - image_names)
if images_without_mask:
    print(f"{len(images_without_mask)} image(s) without a mask, e.g. {images_without_mask[:3]}")
if masks_without_image:
    print(f"{len(masks_without_image)} mask(s) without an image, e.g. {masks_without_image[:3]}")

## Label distribution (train masks)
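Scans the first 500 masks in filename order (or all of them if there are fewer) to estimate class frequencies. Percentages are relative to pixels that carry one of the known raw labels 1–9.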

In [None]:
max_masks = min(500, len(mask_files))  # adjust if you want more/less
# First N masks in sorted (filename) order: deterministic, not a random sample.
chosen_masks = mask_files[:max_masks]

freq = np.zeros(num_classes, dtype=np.int64)
# Pixels whose raw value is not a key of label_map (e.g. nodata); they are
# excluded from `freq`, so count them to make that exclusion visible.
unmapped_pixels = 0
for p in chosen_masks:
    m = tifffile.imread(p)
    mapped_here = 0
    for raw, dst in label_map.items():
        n_raw = int((m == raw).sum())
        freq[dst] += n_raw
        mapped_here += n_raw
    unmapped_pixels += m.size - mapped_here

total_pixels = freq.sum()  # mapped pixels only; percentages below are relative to this
print("Scanned masks:", len(chosen_masks), "| total pixels:", total_pixels)
print("Pixels with unmapped labels (excluded below):", unmapped_pixels)
for cls_id, count in enumerate(freq):
    pct = 100 * count / max(total_pixels, 1)
    print(f"class {cls_id}: {count} pixels ({pct:.2f}%)")

fig, ax = plt.subplots(figsize=(6, 3))
ax.bar(range(num_classes), freq)
ax.set_xticks(range(num_classes))  # one tick per class id
ax.set_xlabel("class id (0-based)")
ax.set_ylabel("pixel count")
ax.set_title(f"Class pixel distribution ({len(chosen_masks)} of {len(mask_files)} masks)")
fig.tight_layout()
plt.show()
plt.tight_layout()

## Per-channel statistics (mean/std)
Computed over all pixels of a random subset of up to 100 training images.

In [None]:
STATS_SEED = 0  # fixed seed so the sampled subset (and therefore the stats) is reproducible
sample_n = min(100, len(image_files))
sample_files = random.Random(STATS_SEED).sample(image_files, sample_n) if sample_n else []

# Running per-channel sums; allocated lazily once the channel count is known.
sum_c = None
sum_sq_c = None
count_pix = 0

for p in sample_files:
    img = tifffile.imread(p).astype(np.float64)
    if img.ndim == 2:
        img = img[..., np.newaxis]  # single-band image -> treat as one channel
    # NOTE(review): assumes channel-last (H, W, C); a channel-first TIFF would be misread — confirm.
    pixels = img.reshape(-1, img.shape[-1])  # (num_pixels, C)
    if sum_c is None:
        sum_c = np.zeros(pixels.shape[1], dtype=np.float64)
        sum_sq_c = np.zeros(pixels.shape[1], dtype=np.float64)
    sum_c += pixels.sum(axis=0)
    sum_sq_c += (pixels ** 2).sum(axis=0)
    count_pix += pixels.shape[0]

if count_pix == 0:
    # Without this guard, `sum_c` is still None here and the division below fails.
    print("No images sampled; channel statistics unavailable.")
    mean = np.zeros(0)
    std = np.zeros(0)
else:
    mean = sum_c / count_pix
    # E[x^2] - E[x]^2 can round slightly below zero; clamp so sqrt never yields NaN.
    std = np.sqrt(np.maximum(sum_sq_c / count_pix - mean**2, 0.0))

print("Samples used:", sample_n)
for c, (mu, sigma) in enumerate(zip(mean, std)):
    print(f"channel {c}: mean={mu:.4f}, std={sigma:.4f}")

## Qualitative samples
Show an image next to its mask (matched by filename). Call `plot_sample(idx)` with other indices to browse more pairs.

In [None]:
def plot_sample(idx=0):
    """Show training image ``idx`` side by side with its remapped mask.

    The mask is read from ``masks_dir`` under the same filename as the image,
    and its raw labels are converted to class ids with ``label_map``. Pixels
    whose raw value is not in ``label_map`` are left transparent rather than
    being painted in a class colour.

    Args:
        idx: Index into ``image_files``; out-of-range values wrap around.
    """
    from matplotlib.colors import ListedColormap

    if not image_files:
        print("No images found")
        return
    img_path = image_files[idx % len(image_files)]
    mask_path = masks_dir / img_path.name
    img = tifffile.imread(img_path)
    mask = tifffile.imread(mask_path)
    mapped = np.full(mask.shape, fill_value=-1, dtype=np.int64)
    for raw, dst in label_map.items():
        mapped[mask == raw] = dst
    # Hide the -1 sentinel. Otherwise vmin clips it and it is drawn in class 0's colour.
    mapped_disp = np.ma.masked_less(mapped, 0)

    # Display first 3 channels (if more than 3, drop the rest). imshow cannot
    # render 1- or 2-channel stacks, so those fall back to the first channel in grey.
    if img.ndim == 3 and img.shape[-1] >= 3:
        img_disp = img[..., :3].astype(np.float64)
        img_cmap = None
    else:
        img_disp = (img if img.ndim == 2 else img[..., 0]).astype(np.float64)
        img_cmap = "gray"
    # np.ptp rather than ndarray.ptp: the method form was removed in NumPy 2.0.
    img_disp = (img_disp - img_disp.min()) / (np.ptp(img_disp) + 1e-6)

    # One discrete colour per class (tab10 has 10 colours, enough for the 9 classes),
    # with bins centred on the integer class ids so colourbar bands line up with ticks.
    class_cmap = ListedColormap(plt.cm.tab10.colors[:num_classes])

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(img_disp, cmap=img_cmap)
    axes[0].set_title(f"Image: {img_path.name}")
    axes[0].axis("off")
    im = axes[1].imshow(
        mapped_disp,
        cmap=class_cmap,
        vmin=-0.5,
        vmax=num_classes - 0.5,
        interpolation="nearest",
    )
    axes[1].set_title(f"Mask (mapped 0..{num_classes - 1})")
    axes[1].axis("off")
    fig.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04, ticks=range(num_classes))
    plt.tight_layout()

plot_sample(0)