In [None]:
!pip install matplotlib

In [None]:
import os
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, UnidentifiedImageError

In [None]:
# Root folder containing one sub-folder per class. Set the CROP_DATA_DIR
# environment variable to point elsewhere instead of editing this local path.
data_dir = os.environ.get(
    "CROP_DATA_DIR",
    "C:/Users/OWNER/Downloads/SDS-CP028-smart-leaf/submissions/team-members/Samsudeen/CropDiseaseClasses",
)
shapes = []                      # (filepath, (height, width, channels)) per readable image
class_counts = defaultdict(int)  # class folder name -> number of readable images
removed = []                     # paths of files deleted as corrupted

In [None]:
print("🔍 Verifying image shapes and removing corrupted files...\n")

# Start from empty accumulators so re-running this cell does not double-count
# images or record the same removals twice.
shapes.clear()
class_counts.clear()
removed.clear()

for class_name in sorted(os.listdir(data_dir)):
    class_path = os.path.join(data_dir, class_name)
    if not os.path.isdir(class_path):
        continue  # loose files at the top level are not class folders

    for filename in sorted(os.listdir(class_path)):
        filepath = os.path.join(class_path, filename)
        if not os.path.isfile(filepath):
            # Nested folders are not images and must never be passed to os.remove.
            continue

        try:
            # verify() checks integrity without decoding pixel data; the image
            # object is unusable afterwards, so the file is opened again below.
            with Image.open(filepath) as img:
                img.verify()

            # convert() loads and decodes every pixel, which can surface
            # truncated data that verify() may not detect.
            with Image.open(filepath) as img:
                rgb = img.convert("RGB")
            shape = np.array(rgb).shape  # (height, width, channels)
            shapes.append((filepath, shape))
            print(f"{filepath} — Shape: {shape}")  # example: (224, 224, 3)

            class_counts[class_name] += 1

        except PermissionError as err:
            # Unreadable (e.g. locked by another process) is not corrupted:
            # report it and leave the file alone.
            print(f"⚠️ Skipping unreadable file: {filepath} ({err})")

        except (UnidentifiedImageError, OSError, ValueError, SyntaxError):
            # SyntaxError is what Pillow's PNG plugin raises for bad chunk
            # checksums during verify().
            print(f"❌ Removing corrupted file: {filepath}")
            try:
                os.remove(filepath)
            except OSError as err:
                # Do not let one failed delete abort the whole scan, and do not
                # report the file as removed when it is still on disk.
                print(f"⚠️ Could not remove {filepath}: {err}")
            else:
                removed.append(filepath)


In [None]:
print("\n✅ Done checking.")
print(f"🧹 Removed corrupted files: {len(removed)}")

# Optional: print shape summary, most common shape first.
# A single Counter pass replaces re-scanning `shapes` once per distinct shape,
# and gives a stable order instead of arbitrary set iteration order.
print("\n📏 Unique image shapes found:")
shape_counts = Counter(shape for _, shape in shapes)
unique_shapes = set(shape_counts)  # kept for any later cell that reads it
for shape, count in shape_counts.most_common():
    print(f"{shape}: {count} images")

In [None]:
# Optional: display class counts — one "class: count" line per class folder
# that produced at least one readable image, in the order folders were scanned.
print("\n📊 Class image counts:")
for label in class_counts:
    print(f"{label}: {class_counts[label]}")

In [None]:
# Plot class distribution: one bar per class folder with its count of readable images.
fig, ax = plt.subplots(figsize=(12, 6))
# Materialize the dict views so bar() receives plain sequences.
ax.bar(list(class_counts.keys()), list(class_counts.values()), color='teal')
ax.set_title("Crop Disease Distribution")
ax.set_xlabel("Class")
ax.set_ylabel("Number of Images")
# Right-align rotated names (anchored at the tick) so each label ends under its own bar.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
fig.tight_layout()
plt.show()


In [None]:
# Save the class-distribution plot to disk without displaying it in the notebook.
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(list(class_counts.keys()), list(class_counts.values()), color='teal')
ax.set_title("Crop Disease Distribution")
ax.set_xlabel("Class")
ax.set_ylabel("Number of Images")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
# Without a tight layout the long rotated class names extend past the default
# bottom margin and are clipped from the saved PNG.
fig.tight_layout()
fig.savefig("crop_disease_distribution.png", dpi=300)
plt.close(fig)  # release this figure so it is not rendered as cell output