In [45]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np


In [46]:
# Folder name of the tea-leaf image dataset.
# NOTE(review): the other cells shown here hard-code this same string instead of
# reading `dataset` — confirm whether any later cell actually uses it.
dataset = "Tea Leaf Disease Dataset"

In [47]:
from pathlib import Path
import os, csv
from collections import Counter
from sklearn.model_selection import train_test_split

IMAGE_EXTS = (".jpg", ".png")  # what you kept after cleaning

def list_images_with_labels(root_dir):
    """Collect image paths and class labels from a one-folder-per-class tree.

    Every immediate sub-directory of ``root_dir`` is treated as a class, except
    those whose name starts with "_" (skipped). Inside a class folder only
    entries whose lower-cased name ends with one of ``IMAGE_EXTS`` are kept.

    Args:
        root_dir: Dataset root containing the class folders.

    Returns:
        ``(paths, labels)`` — two parallel lists. Each path is
        ``<class>/<filename>`` relative to ``root_dir``; each label is the
        class folder name. Classes and filenames are visited in sorted order.
    """
    paths, labels = [], []
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if class_name.startswith("_") or not os.path.isdir(class_dir):
            continue
        image_names = [
            name
            for name in sorted(os.listdir(class_dir))
            if name.lower().endswith(IMAGE_EXTS)
        ]
        # Paths are stored relative to the dataset root.
        paths.extend(os.path.join(class_name, name) for name in image_names)
        labels.extend([class_name] * len(image_names))
    return paths, labels

def write_index_csv(rows, out_csv):
    """Write ``rows`` to ``out_csv`` as a CSV file with a ``path,label`` header.

    The parent directory of ``out_csv`` is created if missing. An existing
    file at ``out_csv`` is overwritten.

    Args:
        rows: Iterable of row sequences, written after the header.
        out_csv: Destination file path.
    """
    parent = os.path.dirname(out_csv)
    if not parent:
        # A bare filename has no directory part; fall back to the CWD.
        parent = "."
    os.makedirs(parent, exist_ok=True)

    with open(out_csv, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["path", "label"])
        for row in rows:
            writer.writerow(row)

def stratified_split_index(data_dir, out_dir, test_size=0.20, val_size=0.10, random_state=42):
    """Split the labelled image index into train/val/test and write one CSV each.

    The split is done in two stratified steps: first ``test_size + val_size``
    of the data is held out from training, then that hold-out is divided so
    that validation receives ``val_size / (test_size + val_size)`` of it and
    test receives the rest. Both steps use the same ``random_state``.

    Args:
        data_dir: Dataset root passed to ``list_images_with_labels``.
        out_dir: Directory receiving ``train_index.csv``, ``val_index.csv``
            and ``test_index.csv``.
        test_size: Fraction of all images for the test split.
        val_size: Fraction of all images for the validation split.
        random_state: Seed forwarded to both ``train_test_split`` calls.

    Raises:
        FileNotFoundError: No images were found under ``data_dir``.
        ValueError: Some class has fewer than two images.
    """
    X, y = list_images_with_labels(data_dir)
    if not X:
        raise FileNotFoundError(f"No images found under {data_dir} with {IMAGE_EXTS}")

    counts = Counter(y)
    smallest_class = min(counts.values())
    if smallest_class < 2:
        raise ValueError(f"Some classes too small: {counts}")

    # Step 1: carve the combined val+test hold-out away from train.
    holdout = test_size + val_size
    X_tr, X_tmp, y_tr, y_tmp = train_test_split(
        X, y, test_size=holdout, stratify=y, random_state=random_state
    )

    # Step 2: divide the hold-out; `rel_val` is validation's share of it.
    rel_val = val_size / holdout
    X_val, X_te, y_val, y_te = train_test_split(
        X_tmp, y_tmp, test_size=1 - rel_val, stratify=y_tmp, random_state=random_state
    )

    splits = (
        ("train_index.csv", X_tr, y_tr),
        ("val_index.csv", X_val, y_val),
        ("test_index.csv", X_te, y_te),
    )
    for filename, split_paths, split_labels in splits:
        write_index_csv(list(zip(split_paths, split_labels)), os.path.join(out_dir, filename))

    print(f"Train: {len(X_tr)}, Val: {len(X_val)}, Test: {len(X_te)}")

# --- Resolve dataset path from Notebooks/ ---
# The dataset folder is looked up next to the current working directory, next
# to its parent, and under Data/Raw/ in either place; the first existing
# directory wins.
cwd = Path.cwd()
candidates = [
    base / "Tea Leaf Disease Dataset"
    for base in (
        cwd,
        cwd.parent,
        cwd / "Data" / "Raw",
        cwd.parent / "Data" / "Raw",
    )
]
dataset_dir = next(filter(Path.is_dir, candidates), None)
if dataset_dir is None:
    raise FileNotFoundError("Couldn't find dataset. Tried:\n" + "\n".join(map(str, candidates)))
print("Using dataset at:", dataset_dir)

# Run: write the train/val/test index CSVs into ./splits
stratified_split_index(str(dataset_dir), "splits", test_size=0.20, val_size=0.10)

Using dataset at: /Users/tharukakumarasiri/Desktop/AI_ML_Project/Notebooks/Tea Leaf Disease Dataset
Train: 22694, Val: 3242, Test: 6485


In [49]:
# === TEA DATASET QUALITY CONTROL & CLEAN (final) ===
# Works from project root OR from Notebooks/. Saves a CSV log.

import os, csv, hashlib
from pathlib import Path
from collections import defaultdict, Counter
from PIL import Image, UnidentifiedImageError, ImageFile

# ---------------- CONFIG ----------------
# Where is the dataset? Candidate locations relative to the current working
# directory, tried in this order by resolve_dataset_dir(); the first existing
# directory is used.
CANDIDATES = [
    Path("Tea Leaf Disease Dataset"),
    Path("../Tea Leaf Disease Dataset"),
    Path("Data/Raw/Tea Leaf Disease Dataset"),
    Path("../Data/Raw/Tea Leaf Disease Dataset"),
]

# Formats you will actually use for training.
# NOTE(review): not referenced by the code visible in this cell — the scan
# relies on SCAN_EXTS / DISALLOWED_EXTS instead.
ALLOWED_EXTS   = (".jpg", ".png")
# Filename suffixes (compared lower-cased) that the scan picks up at all;
# deliberately wider than the allowed set so disallowed formats get reported.
SCAN_EXTS      = (".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff")
# Flagged on extension alone, before the file is opened. Note that ".jpeg" is
# in this list, so such files are flagged even if they decode fine.
DISALLOWED_EXTS = (".bmp", ".webp", ".tif", ".tiff", ".jpeg")

# Structural/image-quality constraints
MIN_SIDE   = 64           # flag if min(width, height) < MIN_SIDE (pixels)
MAX_ASPECT = 3.0          # flag if longer side / shorter side > MAX_ASPECT
REQUIRE_RGB = True        # flag any image whose PIL mode is not exactly "RGB"

# Duplicates (both checks compare files within the same class only)
CHECK_EXACT_DUPLICATES   = True   # flag files whose MD5 digest matches an earlier file
CHECK_NEAR_DUP_SAME_PHASH = True  # flag files whose average hash equals an earlier file's (aggressive)
PHASH_SIZE = 8                     # aHash grid size (8x8 -> 64 bits)

# Action
DRY_RUN  = True                    # True: report only (the CSV log is still written); False: apply ACTION
ACTION   = "quarantine"            # "quarantine" or "delete"; move_or_delete() treats any value other than "delete" as quarantine
# ----------------------------------------

# Optional: allow truncated files to load (usually keep False)
# ImageFile.LOAD_TRUNCATED_IMAGES = True

# --------- Path resolve ----------
def resolve_dataset_dir() -> Path:
    """Return the resolved path of the first entry in ``CANDIDATES`` that is a directory.

    Raises:
        FileNotFoundError: None of the candidates is an existing directory.
            The message lists the current working directory and every
            resolved candidate that was tried.
    """
    found = next((candidate for candidate in CANDIDATES if candidate.is_dir()), None)
    if found is not None:
        return found.resolve()

    attempted = [str(candidate.resolve()) for candidate in CANDIDATES]
    tried = "\n - ".join(attempted)
    raise FileNotFoundError(
        "Couldn't find the dataset folder.\n"
        f"CWD: {Path.cwd()}\nTried:\n - {tried}\n"
        "Tip: if running from Notebooks/, '../Tea Leaf Disease Dataset' is typical."
    )

# Resolved dataset root plus the two artefacts this cell keeps inside it. Both
# names start with "_", which the class-folder loops in this notebook skip.
DATASET_DIR = resolve_dataset_dir()
QUARANTINE_DIR = DATASET_DIR / "_quarantine_corrupt"
LOG_CSV = DATASET_DIR / "_clean_log.csv"
# Create the quarantine folder only when files will really be moved into it.
# A dry run is report-only and must not add a directory to the dataset tree.
if ACTION == "quarantine" and not DRY_RUN:
    QUARANTINE_DIR.mkdir(parents=True, exist_ok=True)

# --------- Helpers ----------
def move_or_delete(src: Path):
    """Apply the configured action to one flagged file and report what was done.

    Behaviour depends on the module-level ``DRY_RUN`` and ``ACTION`` settings.

    Args:
        src: File inside ``DATASET_DIR`` that should be removed from the dataset.

    Returns:
        ``None`` in dry-run mode (the file is left untouched);
        ``"DELETED"`` when ``ACTION == "delete"``;
        otherwise the quarantine destination path as a string.
    """
    if DRY_RUN:
        return None
    if ACTION == "delete":
        src.unlink(missing_ok=True)
        # Return a truthy marker rather than None: callers log
        # `dst or "DRY_RUN"`, so None would record a real deletion as a dry run.
        return "DELETED"

    # Quarantine: mirror the file's position under the dataset root so the
    # class sub-folder is preserved.
    rel = src.relative_to(DATASET_DIR)
    dst = QUARANTINE_DIR / rel
    dst.parent.mkdir(parents=True, exist_ok=True)

    # Never overwrite an earlier quarantined file: append __dup1, __dup2, ...
    stem, ext = os.path.splitext(dst)
    final = dst
    k = 1
    while final.exists():
        final = Path(f"{stem}__dup{k}{ext}")
        k += 1
    src.replace(final)
    return str(final)

def file_md5(path: Path, chunk=8192) -> str:
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is read in binary mode, ``chunk`` bytes at a time, so it is never
    held in memory as a whole.
    """
    digest = hashlib.md5()
    with path.open("rb") as stream:
        while True:
            block = stream.read(chunk)
            if not block:  # empty bytes -> end of file
                break
            digest.update(block)
    return digest.hexdigest()

def phash_hex(path: Path, hash_size=PHASH_SIZE) -> str:
    """Return a simple average hash (aHash) of the image at ``path`` as hex.

    The image is converted to greyscale and resized to ``hash_size`` x
    ``hash_size``. Each pixel contributes one bit: 1 if it is brighter than
    the mean of all pixels, else 0.

    Args:
        path: Image file to hash.
        hash_size: Side length of the reduced grid; the hash has
            ``hash_size ** 2`` bits.

    Returns:
        The bits as a zero-padded lowercase hex string, wide enough to hold
        every bit (16 digits for the default 8x8 grid).
    """
    with Image.open(path) as im:
        im = im.convert("L").resize((hash_size, hash_size), Image.LANCZOS)
        pixels = list(im.getdata())
    avg = sum(pixels) / len(pixels)
    bits = ''.join('1' if p > avg else '0' for p in pixels)
    # Four bits per hex digit, rounded up. A fixed width of 16 only matched
    # hash_size == 8 and gave inconsistent padding for other grid sizes.
    hex_width = (hash_size * hash_size + 3) // 4
    return f"{int(bits, 2):0{hex_width}x}"

# --------- Main scan ----------
issues = []                  # [class, path, reason, action]
before_counts = {}           # class -> number of scanned files found
after_counts  = {}           # class -> number of files that passed every pass-1 check
per_class_errors = defaultdict(int)

all_valid_files = []         # survivors of pass 1
class_of = {}                # map path -> class


def _pass1_reject_reason(fpath: Path):
    """Return why ``fpath`` fails the pass-1 checks, or ``None`` if it passes.

    Checks run in this order and the first failure wins: disallowed
    extension, zero-byte file, image cannot be verified/decoded, wrong colour
    mode, shorter side below ``MIN_SIDE``, aspect ratio above ``MAX_ASPECT``.
    """
    ext = fpath.suffix.lower()

    # 0) disallowed extension
    if ext in DISALLOWED_EXTS:
        return f"Disallowed extension: {ext}"

    # 1) zero-byte
    if fpath.stat().st_size == 0:
        return "Zero-byte file"

    # 2) readability + full decode. The file is opened twice because the
    #    image object is not reused after verify().
    try:
        with Image.open(fpath) as im:
            im.verify()
        with Image.open(fpath) as im:
            im.load()
            mode = im.mode
            w, h = im.size
    except UnidentifiedImageError:
        return "Unidentified/unsupported image"
    except OSError as e:
        return f"OSError: {e}"
    except Exception as e:
        return f"Other error: {type(e).__name__}: {e}"

    # 3) mode rule
    if REQUIRE_RGB and mode != "RGB":
        return f"Invalid mode: {mode} (need RGB)"

    # 4) size/aspect rules
    if min(w, h) < MIN_SIDE:
        return f"Too small: {w}x{h} (min<{MIN_SIDE})"

    ar = max(w, h) / max(1, min(w, h))
    if ar > MAX_ASPECT:
        return f"Extreme aspect ratio: {w}x{h} (>{MAX_ASPECT}:1)"

    return None


for cls in sorted(os.listdir(DATASET_DIR)):
    cls_path = DATASET_DIR / cls
    if not cls_path.is_dir() or cls.startswith("_"):
        continue

    # Sorted so the scan order is reproducible: os.listdir() returns entries
    # in arbitrary order, and the duplicate passes keep the first file they
    # see in each group.
    files = sorted(f for f in os.listdir(cls_path) if f.lower().endswith(SCAN_EXTS))
    before_counts[cls] = len(files)
    valid = 0

    for fname in files:
        fpath = cls_path / fname
        reason = _pass1_reject_reason(fpath)
        if reason is not None:
            per_class_errors[cls] += 1
            dst = move_or_delete(fpath)
            issues.append([cls, str(fpath), reason, dst or "DRY_RUN"])
            continue

        # Survived pass 1
        valid += 1
        all_valid_files.append(fpath)
        class_of[fpath] = cls

    after_counts[cls] = valid

# --------- Duplicates ----------
# Paths flagged by a duplicate pass. In DRY_RUN mode nothing is moved or
# deleted, so `Path.exists()` cannot tell the average-hash pass that the MD5
# pass already counted a file. Without this set every exact duplicate would be
# logged and subtracted from `after_counts` twice in a dry run.
flagged_paths = set()

# Exact duplicates by MD5 (within class)
if CHECK_EXACT_DUPLICATES and all_valid_files:
    by_class = defaultdict(list)
    for p in all_valid_files:
        if p.exists():
            by_class[class_of[p]].append(p)
    for cls, paths in by_class.items():
        # digest -> files with that digest, in scan order
        md5_map = defaultdict(list)
        for p in paths:
            try:
                md5_map[file_md5(p)].append(p)
            except Exception as e:
                flagged_paths.add(p)
                per_class_errors[cls] += 1
                dst = move_or_delete(p)
                issues.append([cls, str(p), f"MD5 failed: {e}", dst or "DRY_RUN"])
                after_counts[cls] -= 1

        for group in md5_map.values():
            if len(group) > 1:
                # Keep the first file of each group, flag the rest.
                keep = group[0]
                for g in group[1:]:
                    if not g.exists():
                        continue
                    flagged_paths.add(g)
                    per_class_errors[cls] += 1
                    dst = move_or_delete(g)
                    issues.append([cls, str(g), f"Exact duplicate of {keep.relative_to(DATASET_DIR)}", dst or "DRY_RUN"])
                    after_counts[cls] -= 1

# Near-duplicates by identical pHash (aggressive; within class)
if CHECK_NEAR_DUP_SAME_PHASH and all_valid_files:
    by_class = defaultdict(list)
    for p in all_valid_files:
        # Skip files the MD5 pass already flagged, whether or not they are
        # still on disk.
        if p.exists() and p not in flagged_paths:
            by_class[class_of[p]].append(p)

    for cls, paths in by_class.items():
        # hash -> files sharing that hash, in scan order
        buckets = defaultdict(list)
        for p in paths:
            try:
                h = phash_hex(p)
                buckets[h].append(p)
            except Exception as e:
                flagged_paths.add(p)
                per_class_errors[cls] += 1
                dst = move_or_delete(p)
                issues.append([cls, str(p), f"pHash failed: {e}", dst or "DRY_RUN"])
                after_counts[cls] -= 1

        for bucket_paths in buckets.values():
            if len(bucket_paths) <= 1:
                continue
            keep = bucket_paths[0]
            for g in bucket_paths[1:]:
                if not g.exists():
                    continue
                flagged_paths.add(g)
                per_class_errors[cls] += 1
                dst = move_or_delete(g)
                issues.append([cls, str(g), f"Near-duplicate of {keep.relative_to(DATASET_DIR)} (same pHash)", dst or "DRY_RUN"])
                after_counts[cls] -= 1

# --------- Report ----------
# Total number of files flagged across all classes (scanned minus surviving).
removed_total = sum(
    n_before - after_counts[name] for name, n_before in before_counts.items()
)

print("Dataset root:", DATASET_DIR)
print("\n=== Summary per class ===")
print(f"{'Class':<30} {'Before':>8} {'Valid':>8} {'Removed':>8}")
print("-" * 60)
for cls in sorted(before_counts):
    n_before = before_counts[cls]
    n_valid = after_counts[cls]
    removed = n_before - n_valid
    print(f"{cls:<30} {n_before:>8} {n_valid:>8} {removed:>8}")

print("\nTotal removed:", removed_total)
if DRY_RUN:
    mode_text = "DRY RUN (no files moved/deleted)"
else:
    mode_text = f"ACTION = {ACTION.upper()}"
print("Mode:", mode_text)
if ACTION == "quarantine":
    print("Quarantine folder:", QUARANTINE_DIR)

# Save CSV audit log: one row per flagged file, written in every mode.
with LOG_CSV.open("w", newline="", encoding="utf-8") as log_file:
    log_writer = csv.writer(log_file)
    log_writer.writerow(["class", "path", "reason", "action"])
    log_writer.writerows(issues)
print("Log saved to:", LOG_CSV)

Dataset root: /Users/tharukakumarasiri/Desktop/AI_ML_Project/Notebooks/Tea Leaf Disease Dataset

=== Summary per class ===
Class                            Before    Valid  Removed
------------------------------------------------------------
algal_spot                         5497     5135      362
brown_blight                       4908     4810       98
gray_blight                        5537     5121      416
healthy                            5492     4990      502
helopeltis                         5482     5254      228
red_spot                           5505     5266      239

Total removed: 1845
Mode: DRY RUN (no files moved/deleted)
Quarantine folder: /Users/tharukakumarasiri/Desktop/AI_ML_Project/Notebooks/Tea Leaf Disease Dataset/_quarantine_corrupt
Log saved to: /Users/tharukakumarasiri/Desktop/AI_ML_Project/Notebooks/Tea Leaf Disease Dataset/_clean_log.csv
