In [None]:
import os
import cv2
import numpy as np
import imageio
from collections import defaultdict
import matplotlib.pyplot as plt


def read_pfm(file):
    """Read a Portable Float Map (PFM) image into a NumPy array.

    The header is three text lines: a type tag (``PF`` for 3-channel colour,
    ``Pf`` for single-channel), ``width height``, and a scale factor whose
    sign selects the byte order (negative means little-endian). The 32-bit
    float samples follow the header.

    PFM stores scanlines bottom-to-top, so the rows are reversed before
    returning; row 0 of the result is the top of the image.

    Args:
        file: Path to the ``.pfm`` file.

    Returns:
        A 32-bit float array of shape ``(height, width, 3)`` for ``PF`` files
        or ``(height, width)`` for ``Pf`` files, top row first.

    Raises:
        ValueError: If the type tag is neither ``PF`` nor ``Pf``, or the file
            holds fewer samples than the header announces.
    """
    with open(file, "rb") as f:
        header = f.readline().decode('utf-8').rstrip()
        if header not in ('PF', 'Pf'):
            # Previously any unknown tag was silently treated as grayscale.
            raise ValueError(f"not a PFM file (unexpected header {header!r}): {file}")
        color = header == 'PF'

        dims = f.readline().decode('utf-8').rstrip()
        width, height = map(int, dims.split())

        scale = float(f.readline().decode('utf-8').rstrip())
        endian = '<' if scale < 0 else '>'

        shape = (height, width, 3) if color else (height, width)
        expected = height * width * (3 if color else 1)
        # Read exactly the announced number of samples so trailing bytes
        # after the raster do not break the reshape.
        data = np.fromfile(f, dtype=endian + 'f', count=expected)
        if data.size != expected:
            raise ValueError(
                f"truncated PFM file: expected {expected} samples, got {data.size}: {file}"
            )

        # Rows are stored bottom-to-top; reverse them so row 0 is the top.
        # The copy gives callers a contiguous array instead of a reversed view.
        return np.reshape(data, shape)[::-1].copy()


def get_image_metadata(path):
    """Load one image file and summarise its basic properties.

    The loader is chosen from the lower-cased file extension: ``.pfm`` goes
    through ``read_pfm``; ``.png``/``.jpg``/``.jpeg``/``.tif``/``.tiff`` go
    through ``imageio.imread``.

    Args:
        path: Path of the image file.

    Returns:
        A dict that always has ``"file"`` and ``"status"``. ``status`` is
        ``"ok"``, ``"unsupported format"``, or ``"error: <message>"`` when
        loading or inspecting the image raised. On success the dict also has
        ``"format"`` (the extension), ``"size"`` as ``(width, height)``,
        ``"channels"`` (1 for 2-D arrays, otherwise the third dimension) and
        ``"bit_depth"`` (the array dtype rendered as a string).
    """
    raster_exts = (".png", ".jpg", ".jpeg", ".tif", ".tiff")
    ext = os.path.splitext(path)[-1].lower()

    try:
        if ext in raster_exts:
            img = imageio.imread(path)
        elif ext == ".pfm":
            img = read_pfm(path)
        else:
            return {"file": path, "status": "unsupported format"}

        dims = img.shape
        n_channels = dims[2] if len(dims) != 2 else 1
        dtype_name = str(img.dtype)
        width_height = (dims[1], dims[0])

        return {
            "file": path,
            "status": "ok",
            "format": ext,
            "size": width_height,
            "channels": n_channels,
            "bit_depth": dtype_name,
        }
    except Exception as e:
        # Best-effort profiling: a file that cannot be read is reported, not raised.
        return {"file": path, "status": f"error: {e!s}"}


def spectral_profile(img):
    """Report the minimum and maximum value of each channel of an image.

    Args:
        img: Image array. A 2-D array counts as one channel; otherwise the
            channel count is ``img.shape[2]``.

    Returns:
        A dict mapping a channel label to ``{"min": float, "max": float}``.
        A single-channel image gets the one label ``"LWIR"`` (computed over
        the whole array); otherwise the labels are ``"channel_0"``,
        ``"channel_1"``, ...
    """
    def _extent(values):
        # Plain Python floats keep the report free of NumPy scalar types.
        return {"min": float(np.min(values)), "max": float(np.max(values))}

    n_channels = img.shape[2] if len(img.shape) != 2 else 1

    if n_channels == 1:
        return {"LWIR": _extent(img)}

    return {
        f"channel_{idx}": _extent(img[:, :, idx])
        for idx in range(n_channels)
    }


def check_data_quality(img):
    """Compute simple quality indicators for one image array.

    Args:
        img: Image array.

    Returns:
        A dict with:
        ``"dead_pixels"`` - number of array elements equal to 0 (counted per
        element, so each channel value of a multi-channel image counts
        separately);
        ``"nan_pixels"`` - number of NaN elements;
        ``"mean"`` and ``"std"`` - mean and standard deviation over all
        elements, as Python floats.
    """
    zero_mask = img == 0
    nan_mask = np.isnan(img)

    return {
        "dead_pixels": int(np.sum(zero_mask)),
        "nan_pixels": int(np.sum(nan_mask)),
        "mean": float(np.mean(img)),
        "std": float(np.std(img)),
    }


def dataset_structure(root_dir):
    """Count the files in each standard split folder of a dataset.

    Looks for ``train``, ``val`` and ``test`` sub-directories directly under
    ``root_dir`` and counts the regular files that are immediate children of
    each one (nested folders are not descended into). If none of the split
    directories exists, the regular files directly under ``root_dir`` are
    counted instead.

    Args:
        root_dir: Dataset root directory.

    Returns:
        A dict mapping each split name that was found to its file count, or
        ``{"unsplit": n}`` when no split directory is present.

    Raises:
        OSError: If a directory that has to be listed cannot be read (for
            example ``root_dir`` does not exist and has no split folders).
    """
    def _count_files(directory):
        # Only immediate children that are regular files are counted.
        return sum(
            1 for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    structure = {}
    for split in ("train", "val", "test"):
        split_path = os.path.join(root_dir, split)
        # isdir rather than exists: a regular file that happens to be named
        # like a split would make os.listdir raise NotADirectoryError.
        if os.path.isdir(split_path):
            structure[split] = _count_files(split_path)

    if not structure:
        structure["unsplit"] = _count_files(root_dir)
    return structure


def dataset_statistics(imgs):
    """Average the per-image mean and standard deviation over a set of images.

    Note that both results are unweighted averages of per-image values:
    ``"std_global"`` is the mean of the individual standard deviations, not
    the standard deviation of all pixels pooled together.

    Args:
        imgs: Iterable of image arrays; it is traversed once.

    Returns:
        ``{"mean_global": float, "std_global": float}``.
    """
    per_image = [(np.mean(frame), np.std(frame)) for frame in imgs]
    image_means = [mean for mean, _ in per_image]
    image_stds = [std for _, std in per_image]

    return {
        "mean_global": float(np.mean(image_means)),
        "std_global": float(np.mean(image_stds)),
    }


def detect_modality(metadata_list):
    """Guess the dataset modality from the channel counts of its images.

    Only entries whose ``"status"`` is ``"ok"`` are considered.

    Args:
        metadata_list: List of metadata dicts with a ``"status"`` key and,
            for ``"ok"`` entries, a ``"channels"`` key.

    Returns:
        One of:
        ``"LWIR only"`` - every readable image has 1 channel;
        ``"RGB + LWIR"`` - at least one 3-channel and one 1-channel image;
        ``"Multispectral"`` - some image has more than 3 channels;
        ``"Unknown"`` - anything else, including the case where no image
        could be read at all.
    """
    channels = [m["channels"] for m in metadata_list if m["status"] == "ok"]

    # all() over an empty list is True, which used to label a dataset with
    # no readable images as "LWIR only".
    if not channels:
        return "Unknown"

    if all(c == 1 for c in channels):
        return "LWIR only"
    elif any(c == 3 for c in channels) and any(c == 1 for c in channels):
        return "RGB + LWIR"
    elif any(c > 3 for c in channels):
        return "Multispectral"
    else:
        return "Unknown"


def profile_dataset(root_dir, sample_limit=20):
    """Build a profiling report for the image files directly under a folder.

    Args:
        root_dir: Directory whose immediate regular files are profiled.
        sample_limit: Maximum number of files to inspect; files are taken in
            sorted name order.

    Returns:
        A dict with:
        ``"metadata"`` - one ``get_image_metadata`` result per sampled file;
        ``"spectral"``, ``"quality"`` - one entry per successfully loaded
        image (present only if at least one image loaded);
        ``"stats"`` - result of ``dataset_statistics`` (same condition);
        ``"structure"`` - result of ``dataset_structure(root_dir)``;
        ``"modality"`` - result of ``detect_modality``.

    Raises:
        OSError: If ``root_dir`` cannot be listed.
    """
    report = {}

    # os.listdir returns names in arbitrary order; sorting makes the
    # sample_limit cut pick the same files on every run.
    files = [
        os.path.join(root_dir, name) for name in sorted(os.listdir(root_dir))
        if os.path.isfile(os.path.join(root_dir, name))
    ]
    files = files[:sample_limit]

    metadata_list = [get_image_metadata(f) for f in files]
    report["metadata"] = metadata_list

    # Reload the pixel data of every file that was read successfully.
    imgs = []
    for m in metadata_list:
        if m["status"] != "ok":
            continue
        ext = m["format"]
        if ext in [".png", ".jpg", ".jpeg", ".tif", ".tiff"]:
            imgs.append(imageio.imread(m["file"]))
        elif ext == ".pfm":
            imgs.append(read_pfm(m["file"]))

    if imgs:
        report["spectral"] = [spectral_profile(img) for img in imgs]
        report["quality"] = [check_data_quality(img) for img in imgs]
        report["stats"] = dataset_statistics(imgs)

    report["structure"] = dataset_structure(root_dir)
    report["modality"] = detect_modality(metadata_list)

    return report


In [None]:
def print_report(report):
    print("\n================ DATASET PROFILE REPORT ================\n")
    
    print("📂 Dataset Structure:")
    for k, v in report["structure"].items():
        print(f"  - {k}: {v} files")

    print(f"\n🎨 Detected Modality: {report['modality']}")

    print("\n📑 Sample Metadata (first 5 files):")
    for m in report["metadata"][:5]:
        print(f"  - File: {os.path.basename(m['file'])}")
        print(f"    Format: {m.get('format','?')}, Size: {m.get('size','?')}, "
              f"Channels: {m.get('channels','?')}, Bit Depth: {m.get('bit_depth','?')}, Status: {m['status']}")

    if "stats" in report:
        print("\n📊 Dataset Statistics:")
        print(f"  - Global Mean: {report['stats']['mean_global']:.3f}")
        print(f"  - Global Std : {report['stats']['std_global']:.3f}")

    if "spectral" in report:
        print("\n🌈 Spectral Profile (sample images):")
        for i, sp in enumerate(report["spectral"][:3]):  # فقط ۳ نمونه نشون بدیم
            print(f"  Image {i+1}:")
            for ch, vals in sp.items():
                print(f"    {ch}: min={vals['min']:.2f}, max={vals['max']:.2f}")

    if "quality" in report:
        print("\n🔎 Data Quality (sample images):")
        for i, q in enumerate(report["quality"][:3]):  # فقط ۳ نمونه
            print(f"  Image {i+1}: dead_pixels={q['dead_pixels']}, nan_pixels={q['nan_pixels']}, "
                  f"mean={q['mean']:.2f}, std={q['std']:.2f}")

    print("\n========================================================\n")


In [None]:
# Directory to profile. An empty string makes os.listdir() raise
# FileNotFoundError, so this defaults to the current working directory;
# point it at the dataset root before running.
DATASET_ROOT = "."
# Upper bound on the number of files inspected in one run.
SAMPLE_LIMIT = 1000

report = profile_dataset(DATASET_ROOT, sample_limit=SAMPLE_LIMIT)
print_report(report)