In [10]:
import numpy as np
import pandas as pd
import os
import cv2
from pathlib import Path
import random
from tqdm import tqdm
from skimage.feature import hog
from skimage import color
from pathlib import Path




In [11]:
# --- Configuration ---------------------------------------------------------
# The dataset is resolved relative to the kernel's current working directory,
# so the notebook has to be started from the folder that contains `dataset/`.
BASE_DIR = Path.cwd()
DATASET_PATH = BASE_DIR / 'dataset'
# Class names; each one is looked up as a sub-directory of DATASET_PATH.
CATEGORIES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
# Number of .jpg images every category is topped up to by augmentation.
TARGET_COUNT = 500
# Seed the stdlib RNG so the random pick of source images during augmentation
# can be reproduced from one run to the next (given the same file listing).
SEED = 42
random.seed(SEED)


In [12]:
def getImages(categoryPath):
    """Return the ``.jpg`` files directly inside ``categoryPath``, sorted by path.

    Args:
        categoryPath: ``pathlib.Path`` of the directory to scan (not recursive).

    Returns:
        A list of ``pathlib.Path`` objects matching ``*.jpg``. The list is
        sorted because ``glob`` yields entries in an arbitrary,
        filesystem-dependent order, which would otherwise make any seeded
        random selection over the result irreproducible.
    """
    return sorted(categoryPath.glob('*.jpg'))


def printCategoryCounts():
    """Print a small table with the number of ``.jpg`` images per category.

    Walks ``CATEGORIES`` in order, counts the images that ``getImages`` finds
    in the matching sub-directory of ``DATASET_PATH`` and prints one line per
    category under a header and a dashed rule.
    """
    print("\nImage counts per category:")
    print("-" * 30)
    tallies = [
        (label, len(getImages(DATASET_PATH / label)))
        for label in CATEGORIES
    ]
    for label, total in tallies:
        print(f"{label:15s}: {total} images")


def validateAndCleanDataset():
    """Delete every file under the dataset that OpenCV cannot decode.

    Each sub-directory of ``DATASET_PATH`` is treated as a category. Every
    regular file in it is opened with ``cv2.imread``; a ``None`` result is
    taken to mean the file is corrupted and the file is removed from disk.
    A per-category summary and the final per-category image counts are
    printed.

    NOTE(review): this is destructive and applies to *any* file for which
    ``cv2.imread`` returns ``None``, not only ``.jpg`` files.

    Raises:
        FileNotFoundError: if ``DATASET_PATH`` does not exist.
    """
    if not DATASET_PATH.exists():
        raise FileNotFoundError(f"Dataset not found at: {DATASET_PATH}")
    print("Validating dataset for corrupted images...")
    # Sorted so the log and the processing order are stable between runs.
    categoryDirectories = sorted(
        directory for directory in DATASET_PATH.iterdir() if directory.is_dir()
    )
    print(f"Found {len(categoryDirectories)} categories: {[category.name for category in categoryDirectories]}\n")
    totalRemovedImages = 0
    for categoryDirectory in categoryDirectories:
        imageFiles = sorted(file for file in categoryDirectory.iterdir() if file.is_file())
        corruptedCount = 0  # files OpenCV could not decode
        removedCount = 0    # corrupted files that were actually deleted
        for imagePath in imageFiles:
            if cv2.imread(str(imagePath)) is not None:
                continue
            corruptedCount += 1
            print(f"Removing: {imagePath.name} as it is corrupted.")
            try:
                os.remove(imagePath)
                removedCount += 1
            except OSError as error:
                # Best effort: keep going, but do not count the file as removed.
                print(f"Warning: Could not remove: {error}")
        # A corrupted file is invalid whether or not its deletion succeeded,
        # so the valid count is based on corruptedCount, not removedCount.
        numberOfValidImages = len(imageFiles) - corruptedCount
        print(f"{categoryDirectory.name}: {numberOfValidImages}/{len(imageFiles)} valid images")
        totalRemovedImages += removedCount
    print(f"\nValidation complete. Removed {totalRemovedImages} corrupted file(s).")
    printCategoryCounts()


def augmentImage(image):
    """Build the fixed set of augmented variants of one image.

    Args:
        image: image array as returned by ``cv2.imread`` (height and width
            are read from ``image.shape[:2]``).

    Returns:
        A list of nine images, always in this order:
        horizontal flip, vertical flip, 90-degree rotation, 180-degree
        rotation, 270-degree rotation (all counter-clockwise), brighter /
        higher-contrast copy, darker / lower-contrast copy, 5x5 Gaussian
        blur, and a 1.2x zoom cropped back to the original size.

        The 90- and 270-degree variants have height and width swapped
        relative to the input; nothing is cropped or padded.
    """
    height, width = image.shape[:2]
    augmentations = [
        cv2.flip(image, 1),  # Horizontal flip
        cv2.flip(image, 0),  # Vertical flip
        # cv2.rotate is an exact pixel transposition. Warping a 90/270 degree
        # rotation into a (width, height) canvas would crop non-square images
        # and leave black bands, and a (w/2, h/2) centre shifts the 180 degree
        # result by one pixel.
        cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE),  # 90 degrees
        cv2.rotate(image, cv2.ROTATE_180),                  # 180 degrees
        cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE),         # 270 degrees
        cv2.convertScaleAbs(image, alpha=1.3, beta=30),   # Gain 1.3, offset +30
        cv2.convertScaleAbs(image, alpha=0.7, beta=-30),  # Gain 0.7, offset -30
        cv2.GaussianBlur(image, (5, 5), 0),               # Gaussian blur 5x5
    ]
    # Zoom: upscale by 20% and take the centre crop at the original size.
    scale = 1.2
    newHeight, newWidth = int(height * scale), int(width * scale)
    resized = cv2.resize(image, (newWidth, newHeight))
    startHeight, startWidth = (newHeight - height) // 2, (newWidth - width) // 2
    augmentations.append(resized[startHeight:startHeight + height, startWidth:startWidth + width])
    return augmentations


def augmentDataset():
    """Top up every category to ``TARGET_COUNT`` images with augmented copies.

    For each name in ``CATEGORIES`` the existing ``.jpg`` files are counted.
    If there are fewer than ``TARGET_COUNT``, source images are picked at
    random and their ``augmentImage`` variants are written next to them as
    ``<stem>_augmented_<n>_<variant>.jpg`` until the target is reached.
    Progress and the final per-category counts are printed.

    Categories that already hold enough images, or that have no readable
    source image, are skipped. If an image cannot be written, a warning is
    printed and the rest of that category is abandoned.
    """
    print(f"\nAugmenting images to reach {TARGET_COUNT} per category...")
    for categoryName in CATEGORIES:
        categoryPath = DATASET_PATH / categoryName
        imagePaths = getImages(categoryPath)
        currentImageCount = len(imagePaths)
        imagesNeeded = TARGET_COUNT - currentImageCount
        print(f"\n{categoryName}: {currentImageCount} images", end="")
        if imagesNeeded <= 0:
            print(" - Already sufficient")
            continue
        print(f" (Need {imagesNeeded} more)")
        # Load original images using tqdm for progress
        loadedImages = [(cv2.imread(str(imageFile)), imageFile.stem) for imageFile in tqdm(imagePaths, desc="Loading")]
        # cv2.imread returns None for files it cannot decode; drop those so
        # augmentImage never receives None.
        originalImages = [(image, name) for image, name in loadedImages if image is not None]
        if not originalImages:
            # random.choice would raise IndexError on an empty pool.
            print("Warning: No readable source images - skipping")
            continue
        # Generate augmented images
        generatedCount = 0
        writeFailed = False
        while generatedCount < imagesNeeded and not writeFailed:
            image, imageName = random.choice(originalImages)
            for augmentationIndex, augmentedImage in enumerate(augmentImage(image)):
                if generatedCount >= imagesNeeded:
                    break
                savePath = categoryPath / f"{imageName}_augmented_{generatedCount}_{augmentationIndex}.jpg"
                # Only count files that were really written; stop instead of
                # looping forever if the folder is not writable.
                if not cv2.imwrite(str(savePath), augmentedImage):
                    print(f"Warning: Could not write {savePath.name} - stopping early for {categoryName}")
                    writeFailed = True
                    break
                generatedCount += 1
        print(f"Generated {generatedCount} augmented images")
    print("\n" + "=" * 50)
    print("Augmentation complete.")
    printCategoryCounts()



In [13]:
# Destructive step: files that cv2.imread cannot decode are deleted from DATASET_PATH.
validateAndCleanDataset()


Validating dataset for corrupted images...
Found 6 categories: ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

Removing: 2ec9d19b-8027-4c77-a13f-5eee033b9868.jpg as it is corrupted.
Removing: 31381a44-38d6-4a44-9384-7690727801bc.jpg as it is corrupted.
Removing: 345bdb67-4190-4235-a16f-b60c1556a28d.jpg as it is corrupted.
Removing: 38b3e4da-738c-4694-a946-55101b25ad53.jpg as it is corrupted.
Removing: 4840d678-7af4-4a2d-bda1-338c2f2a59c5.jpg as it is corrupted.
Removing: 509251d8-4e3a-4f1e-aabc-4d034b0f2455.jpg as it is corrupted.
Removing: 5b7da318-c2ab-4c29-8ace-19895a890840.jpg as it is corrupted.
Removing: 8617221e-dc90-48fe-a116-46350b5f814e.jpg as it is corrupted.
Removing: 88ce5fbf-e9c7-40ad-87a6-deffe95d8ee8.jpg as it is corrupted.
Removing: bff223bf-1a84-4d38-a486-c3f4c9bfef5e.jpg as it is corrupted.
Removing: ce8a4c3d-2a08-4e78-9e5a-16b69719e505.jpg as it is corrupted.
Removing: d5856b01-c157-4e34-b921-80f29252976a.jpg as it is corrupted.
cardboard: 247/259 vali

In [14]:
# Writes augmented .jpg files into each category folder that holds fewer than TARGET_COUNT images.
augmentDataset()



Augmenting images to reach 500 per category...

cardboard: 247 images (Need 253 more)


Loading: 100%|██████████| 247/247 [00:00<00:00, 726.11it/s]


Generated 253 augmented images

glass: 385 images (Need 115 more)


Loading: 100%|██████████| 385/385 [00:00<00:00, 535.29it/s]


Generated 115 augmented images

metal: 315 images (Need 185 more)


Loading: 100%|██████████| 315/315 [00:00<00:00, 558.79it/s]


Generated 185 augmented images

paper: 449 images (Need 51 more)


Loading: 100%|██████████| 449/449 [00:00<00:00, 623.01it/s]


Generated 51 augmented images

plastic: 363 images (Need 137 more)


Loading: 100%|██████████| 363/363 [00:00<00:00, 416.37it/s]


Generated 137 augmented images

trash: 106 images (Need 394 more)


Loading: 100%|██████████| 106/106 [00:00<00:00, 417.34it/s]


Generated 394 augmented images

Augmentation complete.

Image counts per category:
------------------------------
cardboard      : 500 images
glass          : 500 images
metal          : 500 images
paper          : 500 images
plastic        : 500 images
trash          : 500 images


In [16]:
# Feature Extraction Step
# Every readable image becomes one feature row: the HOG descriptor of its
# 128x128 grayscale version followed by a flattened 8x8x8 colour histogram.
ROWS = []
labels = []
# Sorted so the row order (and therefore X / y) is the same on every run.
for category_dir in sorted(DATASET_PATH.iterdir()):  # For each category directory
    if not category_dir.is_dir():  # Ensure it's a directory
        continue
    category = category_dir.name  # Get category name
    for img in sorted(category_dir.iterdir()):  # For each entry in the category directory
        image = cv2.imread(str(img))  # Read the image
        if image is None:
            # cv2.imread returns None for anything it cannot decode (stray
            # files, sub-directories); cv2.resize would fail on it.
            print(f"Skipping unreadable file: {img}")
            continue
        image = cv2.resize(image, (128, 128))
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # Only the descriptor is needed, so the HOG visualisation image is
        # not requested (it was computed and discarded for every sample).
        hog_features = hog(gray, pixels_per_cell=(8, 8),
                           cells_per_block=(2, 2))
        hist = cv2.calcHist([image], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
        hist = cv2.normalize(hist, hist).flatten()
        features = np.hstack([hog_features, hist])

        ROWS.append(features)
        labels.append(category)

X = np.array(ROWS)
y = np.array(labels)

print("Feature matrix shape:", X.shape)
print("Labels shape:", y.shape)

Feature matrix shape: (3000, 8612)
Labels shape: (3000,)


In [18]:
# Tabulate the feature rows (one row per image, one column per feature)
# and attach each row's class name as a trailing "label" column.
df = pd.DataFrame(ROWS).assign(label=y)

# Preview the first few rows.
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,8603,8604,8605,8606,8607,8608,8609,8610,8611,label
0,0.220498,0.195286,0.157054,0.220153,0.153531,0.058062,0.021371,0.124173,0.075496,0.220498,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.001472,0.015653,cardboard
1,0.236059,0.131783,0.167829,0.100275,0.236059,0.063842,0.206441,0.057195,0.018632,0.236059,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,cardboard
2,0.221509,0.081774,0.092141,0.114815,0.221509,0.221509,0.221509,0.167012,0.048867,0.201401,...,0.000221,0.0,0.0,0.0,0.0,0.0,0.0,0.006841,0.095337,cardboard
3,0.234457,0.0,0.083137,0.106276,0.234457,0.185383,0.065328,0.12652,0.223676,0.234457,...,0.0,0.0,0.0,0.0,0.000295,0.003093,0.002946,0.00486,0.015906,cardboard
4,0.241221,0.12142,0.078683,0.141422,0.235001,0.148732,0.125599,0.098984,0.084878,0.241221,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.004539,cardboard
