In [98]:
# NOTE(review): `parts` is not referenced by any visible cell; this looks like an
# accidental IDE auto-import - confirm and remove if it is unused.
from jupyter_core.version import parts

# Directory that receives every CSV, JSON and PNG artifact written by this notebook.
OUTPUT_BASE_DIR = "./output"
# Root folder that is scanned recursively for image files.
DATA_DIR = "./data"
# Upper bound on how many images are drawn per (class, domain) group for the sample sheets.
SAMPLE_IMAGES_PER_GROUP = 20
# Seed shared by `random`, DataFrame sampling and the train/val/test split.
SEED = 42
# Lower-case file suffixes (leading dot included) that are treated as images.
IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}

In [99]:
import argparse
import os
from pathlib import Path
from collections import Counter, defaultdict
import random
import csv
import json
from PIL import Image, UnidentifiedImageError
import numpy as np
import pandas as pd
from tqdm import tqdm
import imagehash
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

In [100]:
def find_images(root_dir, exts=None):
    """Recursively collect image files located under ``root_dir``.

    Parameters
    ----------
    root_dir : str or Path
        Directory to walk.
    exts : iterable of str, optional
        File suffixes (leading dot included) to accept; compared
        case-insensitively. Defaults to the notebook-level ``IMAGE_EXTS``.

    Returns
    -------
    list of Path
        Matching regular files, sorted so the result is deterministic.
    """
    allowed = IMAGE_EXTS if exts is None else {e.lower() for e in exts}
    root = Path(root_dir)
    return sorted(
        p for p in root.rglob('*')
        # rglob also yields directories; a folder named e.g. "scans.png"
        # must not be reported as an image.
        if p.suffix.lower() in allowed and p.is_file()
    )

In [101]:
def infer_labels_from_path(path: Path, root_dir: Path):
    """Derive (class, domain, multiplicity, folder_name) from an image's parent folder.

    Only the name of the directory that directly contains ``path`` (taken
    relative to ``root_dir``) is inspected, case-insensitively, for keywords.

    Returns a 4-tuple:
      * class        - 'unhealthy', 'healthy' or 'unknown'
      * domain       - 'unshelled', 'shelled' or 'mixed'
      * multiplicity - 'single', 'multiple' or 'unknown'
      * folder_name  - the parent folder name with its original casing
    """
    # Rule order matters: 'unhealthy' contains 'healthy' and 'unshelled'
    # contains 'shelled', so the longer keyword has to be tried first.
    class_rules = (
        ('unhealthy', ('unhealthy', 'diseased', 'sick')),
        ('healthy', ('healthy',)),
    )
    domain_rules = (
        ('unshelled', ('unshelled',)),
        ('shelled', ('shelled',)),
    )
    multiplicity_rules = (
        ('single', ('single',)),
        ('multiple', ('multiple',)),
    )

    def first_match(text, rules, fallback):
        # Return the label of the first rule having any keyword inside `text`.
        for label, keywords in rules:
            if any(keyword in text for keyword in keywords):
                return label
        return fallback

    folder_name = path.relative_to(root_dir).parent.name
    haystack = folder_name.lower()

    cls = first_match(haystack, class_rules, 'unknown')
    domain = first_match(haystack, domain_rules, 'mixed')
    multiplicity = first_match(haystack, multiplicity_rules, 'unknown')
    return cls, domain, multiplicity, folder_name

In [102]:
def is_image_valid(path):
    """Check that the file at ``path`` can be opened and decoded as an image.

    Returns
    -------
    (bool, str or None)
        ``(True, None)`` when both passes succeed, otherwise ``(False, message)``
        where ``message`` describes the exception that was raised.
    """
    try:
        # Pass 1: integrity check on the file.
        with Image.open(path) as probe:
            probe.verify()
        # Pass 2: open the file again and convert to RGB, which forces the
        # pixel data to be decoded (the converted image itself is discarded).
        with Image.open(path) as decoded:
            decoded.convert('RGB')
    except UnidentifiedImageError as exc:
        return False, f"UnidentifiedImageError: {exc}"
    except Exception as exc:
        return False, str(exc)
    return True, None

In [103]:
def compute_phash(path, hash_size=16):
    """Return the perceptual hash of the image at ``path`` as a string.

    ``hash_size`` is forwarded to ``imagehash.phash``. Any exception raised
    while opening or hashing the image results in ``None`` being returned.
    """
    try:
        with Image.open(path) as im:
            digest = str(imagehash.phash(im, hash_size=hash_size))
    except Exception:
        # Best-effort: callers treat None as "no hash available".
        return None
    return digest

In [104]:
# Resolve the input/output locations from the config constants and make sure
# the output folder exists before any artifact is written.
data_dir = Path(DATA_DIR)
out = Path(OUTPUT_BASE_DIR)
out.mkdir(parents=True, exist_ok=True)

# Seed Python's global RNG for reproducibility.
random.seed(SEED)

In [105]:
# Fail fast when the dataset folder is missing: otherwise the scan would
# report 0 images and the notebook would only break several cells later.
if not data_dir.is_dir():
    raise FileNotFoundError(f"Data directory not found: {data_dir.resolve()}")

print("Finding images...")
images = find_images(data_dir)
print(f"Found {len(images)} images.")

Finding images...
Found 8432 images.


In [106]:
# Accumulators filled by the scan below.
records, invalid = [], []          # per-image metadata rows / rejected files
counts = Counter()                 # images per class
domain_counts = Counter()          # images per domain
size_counts = Counter()            # images per (width, height)
phash_map = defaultdict(list)      # perceptual hash -> list of paths

for img_path in tqdm(images, desc="Scanning images"):
    path_str = str(img_path)

    # Files that cannot be decoded are logged and skipped.
    is_ok, reason = is_image_valid(img_path)
    if not is_ok:
        invalid.append({'path': path_str, 'error': reason})
        continue

    cls, domain, multiplicity, folder = infer_labels_from_path(img_path, data_dir)

    # Read the pixel dimensions; a failure here also marks the file invalid.
    try:
        with Image.open(img_path) as im:
            width, height = im.size
    except Exception as exc:
        invalid.append({'path': path_str, 'error': f"open_error:{exc}"})
        continue

    counts[cls] += 1
    domain_counts[domain] += 1
    size_counts[(width, height)] += 1

    # The hash may be None; such images are kept but not bucketed.
    digest = compute_phash(img_path)
    if digest is not None:
        phash_map[digest].append(path_str)

    records.append({
        'path': path_str,
        'class': cls,
        'domain': domain,
        'multiplicity': multiplicity,
        'folder': folder,
        'width': width,
        'height': height,
        'phash': digest,
    })

# Persist the metadata table and the list of rejected files.
df = pd.DataFrame.from_records(records)
print("Summary counts:")
print(counts)
df.to_csv(out / 'tamarind_metadata.csv', index=False)
with (out / 'invalid_images.json').open('w') as fh:
    json.dump(invalid, fh, indent=2)

Scanning images: 100%|██████████| 8432/8432 [00:42<00:00, 198.35it/s]

Summary counts:
Counter({'healthy': 6408, 'unhealthy': 2024})





In [115]:
# Bar chart: number of images per inferred class.
fig, ax = plt.subplots(figsize=(8, 5))
class_names = list(counts.keys())
class_totals = list(counts.values())
sns.barplot(x=class_names, y=class_totals, ax=ax)
ax.set_title('Image counts per class')
ax.set_ylabel('count')
ax.set_xlabel('class')
fig.tight_layout()
# The figure is written to disk rather than shown inline, then released.
fig.savefig(out / 'class_distribution.png', dpi=150)
plt.close(fig)

In [116]:
# Bar chart: number of images per inferred domain.
fig, ax = plt.subplots(figsize=(8, 5))
domain_names = list(domain_counts.keys())
domain_totals = list(domain_counts.values())
sns.barplot(x=domain_names, y=domain_totals, ax=ax)
ax.set_title('Image counts per domain (inferred)')
ax.set_ylabel('count')
ax.set_xlabel('domain')
fig.tight_layout()
# The figure is written to disk rather than shown inline, then released.
fig.savefig(out / 'domain_distribution.png', dpi=150)
plt.close(fig)

In [117]:
# Bar chart: the ten most frequent (width, height) combinations.
sizes_sorted = size_counts.most_common(10)
size_labels = [f"{dims[0]}x{dims[1]}" for dims, _ in sizes_sorted]
size_freqs = [freq for _, freq in sizes_sorted]

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
sns.barplot(x=size_labels, y=size_freqs, ax=ax)
ax.set_title('Top image sizes')
# Tilt the "WxH" tick labels.
plt.xticks(rotation=45)
fig.tight_layout()
# The figure is written to disk rather than shown inline, then released.
fig.savefig(out / 'image_sizes_top10.png', dpi=150)
plt.close(fig)

In [110]:
# Any perceptual-hash bucket that holds more than one path is recorded as a
# group of potential duplicates and dumped to JSON.
dup_groups = {}
for digest, paths in phash_map.items():
    if len(paths) > 1:
        dup_groups[digest] = paths

with (out / 'duplicate_groups.json').open('w') as fh:
    json.dump(dup_groups, fh, indent=2)

In [111]:
# Contact sheets: for every (class, domain) group draw up to
# SAMPLE_IMAGES_PER_GROUP images (reproducibly, via SEED) and save them side by
# side as one PNG for quick visual inspection.
grouped = df.groupby(['class','domain'])
sample_dir = out / 'samples'
sample_dir.mkdir(parents=True, exist_ok=True)
sample_index = []
for (cls, dom), g in grouped:
    g2 = g.sample(min(len(g), SAMPLE_IMAGES_PER_GROUP), random_state=SEED)
    # squeeze=False always returns a 2-D array of axes, so a group with a
    # single image needs no special case.
    fig, axes = plt.subplots(1, len(g2), figsize=(len(g2)*3,3), squeeze=False)
    for ax, (_, row) in zip(axes[0], g2.iterrows()):
        try:
            # Context manager closes the file handle once the pixels are drawn;
            # a bare Image.open(...) would leave it open.
            with Image.open(row['path']) as im:
                ax.imshow(im.convert('RGB'))
        except Exception as e:
            # Keep the sheet usable: show the error in place of the picture.
            ax.text(0.5,0.5,f"err:{e}",ha='center')
        # Hide ticks and frame for images and error placeholders alike.
        ax.axis('off')
    title = f"{cls} | {dom} | n={len(g)}"
    fig.suptitle(title)
    fname = f"sample_{cls}_{dom}.png".replace(' ','_')
    fig.savefig(sample_dir / fname, dpi=150, bbox_inches='tight')
    plt.close(fig)  # release the figure; many are created in this loop
    sample_index.append({'class': cls, 'domain': dom, 'count': len(g), 'sample_image': str(sample_dir / fname)})

# Index of the generated sheets with the size of each group.
pd.DataFrame(sample_index).to_csv(out / 'samples_index.csv', index=False)

# Save a small CSV list of items for further experiments and stratified split.
# 'combined_label' ("class|domain") is the key the split cell stratifies on.
df['combined_label'] = df['class'].astype(str) + '|' + df['domain'].astype(str)
df.to_csv(out / 'tamarind_metadata_full.csv', index=False)

In [113]:
# Stratified 70/15/15 train/val/test split on the combined "class|domain" label:
# 30% is held out first, then that hold-out is halved into validation and test.
# NOTE(review): images that share a perceptual hash are not grouped before
# splitting, so potential duplicates may end up in different splits - confirm
# whether that is acceptable.
items = df.to_dict('records')
labels = [rec['combined_label'] for rec in items]
train, temp = train_test_split(
    items, test_size=0.30, stratify=labels, random_state=SEED
)

temp_labels = [rec['combined_label'] for rec in temp]
val, test = train_test_split(
    temp, test_size=0.5, stratify=temp_labels, random_state=SEED
)

# One CSV per split, with the same columns as the metadata table.
split_files = {
    'split_train.csv': train,
    'split_val.csv': val,
    'split_test.csv': test,
}
for file_name, rows in split_files.items():
    pd.DataFrame(rows).to_csv(out / file_name, index=False)
print("Saved metadata and splits to:", out)

Saved metadata and splits to: output


In [114]:
# Headline numbers of the run, written as a small JSON file.
# 'total_images' is the number of metadata records, i.e. it does not include
# the files counted under 'invalid_images'.
summary = dict(
    total_images=len(records),
    invalid_images=len(invalid),
    class_counts=dict(counts),
    domain_counts=dict(domain_counts),
    duplicate_groups_count=len(dup_groups),
)
with (out / 'summary.json').open('w') as fh:
    json.dump(summary, fh, indent=2)
print("EDA complete. Check the output folder for images and CSVs.")

EDA complete. Check the output folder for images and CSVs.
