In [1]:
#pip install kagglehub

In [2]:
import os, random, time, io
from pathlib import Path
from typing import List

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets, transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, accuracy_score
import kagglehub





In [3]:
# Environment sanity check: installed PyTorch build (e.g., 2.x.x+cu124),
# the CUDA toolkit it was compiled against, and whether a GPU is reachable.
print(
    torch.__version__,
    f"Built with CUDA: {torch.version.cuda}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

2.5.1+cu121
Built with CUDA: 12.1
CUDA available: True


In [4]:
# Reproducibility + device
SEED = 56

# Seed every RNG in play: Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices).
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)

# Ask cuDNN for deterministic algorithms and switch off its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda


In [5]:

# Pick the compute device: the GPU (CUDA) if PyTorch can reach one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device being used:", device)

Device being used: cuda


In [6]:
# Fetch the Kaggle dataset "alistairking/recyclable-and-household-waste-classification"
# through kagglehub and remember where it ended up on disk.
print("Downloading dataset...")
ds_path = kagglehub.dataset_download("alistairking/recyclable-and-household-waste-classification")
print("Dataset downloaded to:", ds_path)


Downloading dataset...
Dataset downloaded to: C:\Users\angus\.cache\kagglehub\datasets\alistairking\recyclable-and-household-waste-classification\versions\1


In [7]:
# Some Kaggle datasets have extra folders like "images/images".
# Probe the likely roots from most to least nested and keep the first one
# that is a directory containing at least one sub-directory (the class folders).
candidates = [
    Path(ds_path) / "images" / "images",
    Path(ds_path) / "images",
    Path(ds_path),
]

DATA_ROOT = None
for c in candidates:
    # is_dir() (rather than exists()) so a plain file at this path is skipped
    # instead of making iterdir() raise.
    if c.is_dir() and any(d.is_dir() for d in c.iterdir()):
        DATA_ROOT = c
        break

if DATA_ROOT is None:
    # Report the download location that was searched (`ds_path`).
    raise FileNotFoundError(f"Could not find image folder in: {ds_path}")

print("Image root folder found:", DATA_ROOT)

Image root folder found: C:\Users\angus\.cache\kagglehub\datasets\alistairking\recyclable-and-household-waste-classification\versions\1\images\images


In [8]:
# Build full ImageFolder (PNG-only filter)
def png_only(p):
    """Return True when the path ends in ".png", ignoring case."""
    return str(p).lower().endswith(".png")


# No transform is attached here; samples come back as loaded by ImageFolder.
full_ds = datasets.ImageFolder(str(DATA_ROOT), transform=None, is_valid_file=png_only)

# ImageFolder derives the class list from the sub-folder names.
class_names: List[str] = full_ds.classes
num_classes = len(class_names)
print(f"Classes ({num_classes}):", class_names[:10], "..." if num_classes > 10 else "")
print("Total PNG images:", len(full_ds.samples))

Classes (30): ['aerosol_cans', 'aluminum_food_cans', 'aluminum_soda_cans', 'cardboard_boxes', 'cardboard_packaging', 'clothing', 'coffee_grounds', 'disposable_plastic_cutlery', 'eggshells', 'food_waste'] ...
Total PNG images: 15000


In [9]:
# --- Check what categories (folders) exist ---
# Every sub-directory of DATA_ROOT is treated as one waste category.
class_dirs = sorted(entry.name for entry in DATA_ROOT.iterdir() if entry.is_dir())
print("Number of categories found:", len(class_dirs))
print("Example categories:", class_dirs[:10])

Number of categories found: 30
Example categories: ['aerosol_cans', 'aluminum_food_cans', 'aluminum_soda_cans', 'cardboard_boxes', 'cardboard_packaging', 'clothing', 'coffee_grounds', 'disposable_plastic_cutlery', 'eggshells', 'food_waste']


In [10]:
# --- Define image transformations (resizing and data augmentation) ---
# Training pipeline: resize, apply light random augmentation, then convert to a
# normalised tensor.

IMG_SIZE = 224  # side length (pixels) of the crop fed to the model

train_tf = transforms.Compose(
    # Geometry: fixed 256x256 resize, random crop back to IMG_SIZE, flip, small rotation.
    [
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
    ]
    # Colour: mild brightness / contrast / saturation / hue jitter.
    + [
        transforms.ColorJitter(brightness=0.10, contrast=0.10, saturation=0.10, hue=0.05),
    ]
    # Tensor conversion followed by per-channel normalisation.
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


In [11]:
# Validation/testing pipeline: deterministic resize + centre crop, then the same
# tensor conversion and normalisation as training (no random augmentation).
eval_tf = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(IMG_SIZE),
    ]
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def png_only(p):
    """File filter: True only for paths ending in .png (case-insensitive)."""
    return str(p).lower().endswith(".png")

In [12]:
# --- Load the dataset ---
# ImageFolder labels each image by the folder it sits in. Only files accepted by
# `png_only` are indexed, and no transform is attached at this point.
full_ds = datasets.ImageFolder(str(DATA_ROOT), transform=None, is_valid_file=png_only)


In [13]:
# Quick summary of what ImageFolder indexed: class count, a sample of names, image count.
num_classes = len(full_ds.classes)
print(
    f"Total number of categories: {num_classes}",
    f"First few category names: {full_ds.classes[:10]}",
    f"Total number of images: {len(full_ds.samples)}",
    sep="\n",
)

Total number of categories: 30
First few category names: ['aerosol_cans', 'aluminum_food_cans', 'aluminum_soda_cans', 'cardboard_boxes', 'cardboard_packaging', 'clothing', 'coffee_grounds', 'disposable_plastic_cutlery', 'eggshells', 'food_waste']
Total number of images: 15000


In [14]:
#pip install datasets

In [15]:
# preprocess_trashnet.py
"""
Preprocessing & DataLoaders for Hugging Face dataset: garythung/trashnet

Implementation notes (workarounds for known pitfalls):
- Handle HuggingFace `set_transform` batch semantics (the transform receives lists).
- Force RGB before resizing to avoid mixed channel counts.
- Custom collate_fn that tolerates single-item lists from `set_transform`.
- No Dataset.with_format('torch'); we stack tensors ourselves.
- CLI uses parse_known_args() so Jupyter's injected --f arg won't crash.

Quick use (Notebook/Python):
    from preprocess_trashnet import build_dataloaders
    loaders, classes = build_dataloaders(batch_size=32, image_size=224, num_workers=0)
    xb, yb = next(iter(loaders["train"]))
    print(xb.shape, yb.shape, classes)

CLI:
    python preprocess_trashnet.py --batch-size 64 --image-size 224 --num-workers 0
"""

from dataclasses import dataclass
from typing import Dict, Tuple, Optional, List
import os
import random

import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms import InterpolationMode
from PIL import Image
from datasets import load_dataset, Dataset, DatasetDict, Features
from tqdm import tqdm


# ---------- Config ----------

@dataclass
class Config:
    """Settings shared by the preprocessing helpers and `build_dataloaders`."""
    # Side length in pixels; images are resized to (image_size, image_size).
    image_size: int = 224
    # Fractions of the base split held out for validation / test when the
    # dataset does not already provide those splits.
    val_pct: float = 0.10
    test_pct: float = 0.10
    # Passed to `_set_seed` and to the stratified train/val/test splitting.
    seed: int = 42
    # Forwarded to every DataLoader.
    batch_size: int = 32
    num_workers: int = 0  # safest on Windows; bump if you want
    # NOTE: this default is evaluated once, when the class body runs.
    pin_memory: bool = torch.cuda.is_available()
    persistent_workers: bool = False  # build_dataloaders overwrites this: True only if num_workers > 0
    # When False, the training transform omits the random augmentations.
    augment: bool = True
    export_dir: Optional[str] = None  # optional: export class-folder JPGs


IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)


# ---------- Utils ----------

def _set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _pick_base_split(ds: DatasetDict) -> Dataset:
    return ds["train"] if "train" in ds else ds[next(iter(ds.keys()))]


def _stratified_splits(base: Dataset, val_pct: float, test_pct: float, seed: int) -> DatasetDict:
    assert 0 < val_pct < 1 and 0 < test_pct < 1 and val_pct + test_pct < 1, "Invalid val/test percentages."
    holdout_pct = val_pct + test_pct
    tmp = base.train_test_split(test_size=holdout_pct, stratify_by_column="label", seed=seed)
    train, holdout = tmp["train"], tmp["test"]
    test_frac_of_holdout = test_pct / (val_pct + test_pct)
    hold = holdout.train_test_split(test_size=test_frac_of_holdout, stratify_by_column="label", seed=seed)
    return DatasetDict(train=train, val=hold["train"], test=hold["test"])


def _maybe_make_splits(ds: DatasetDict, cfg: Config) -> DatasetDict:
    """
    Return a DatasetDict keyed "train" / "val" / "test".

    If `ds` already has train and test splits plus a validation split (named
    "validation" or, failing that, "val"), those are reused as-is. Otherwise
    the base split is divided with `_stratified_splits` using `cfg`'s
    percentages and seed.
    """
    available = set(ds.keys())
    if {"train", "test"} <= available:
        # "validation" takes precedence over "val" when both exist.
        for val_key in ("validation", "val"):
            if val_key in available:
                return DatasetDict(train=ds["train"], val=ds[val_key], test=ds["test"])
    base = _pick_base_split(ds)
    return _stratified_splits(base, cfg.val_pct, cfg.test_pct, cfg.seed)


def _build_transforms(cfg: Config):
    """
    Build the (train, eval) torchvision pipelines.

    Both start by converting to RGB and resizing to a square of side
    `cfg.image_size`, and both end with ToTensor + ImageNet normalisation.
    The train pipeline additionally inserts random flip / colour jitter /
    affine augmentation when `cfg.augment` is true.

    Returns:
        (train_transform, eval_transform) as `T.Compose` objects.
    """
    def _tensor_tail():
        # Fresh instances per pipeline, as in a hand-written list.
        return [T.ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]

    # Force RGB first (some images may be grayscale), then resize.
    shared_head = [
        T.Lambda(lambda im: im.convert("RGB")),
        T.Resize((cfg.image_size, cfg.image_size), interpolation=InterpolationMode.BILINEAR),
    ]

    augmentations = []
    if cfg.augment:
        augmentations = [
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)], p=0.5),
            T.RandomAffine(degrees=12, translate=(0.05, 0.05), scale=(0.95, 1.05)),
        ]

    train_tfm = T.Compose(shared_head + augmentations + _tensor_tail())
    eval_tfm = T.Compose(shared_head + _tensor_tail())
    return train_tfm, eval_tfm


def _set_transform(ds: Dataset, tfm):
    """
    HuggingFace `set_transform` applies on *batches*: the dict values are lists.
    We must map the torchvision transform over each image in the list.
    """
    def _apply(batch):
        imgs = batch["image"]
        if not isinstance(imgs, list):  # be robust if it's a single example
            imgs = [imgs]
        batch["image"] = [tfm(img) for img in imgs]
        return batch
    ds.set_transform(_apply)
    return ds


def _class_names_from_features(feats: Features) -> List[str]:
    if "label" in feats and hasattr(feats["label"], "names") and feats["label"].names:
        return list(feats["label"].names)
    return []


def _export_split(split: Dataset, split_name: str, class_names: List[str], out_dir: str, image_size: int):
    """
    Write every example of `split` to `out_dir/<split_name>/<class>/<index>.jpg`.

    Each image is converted to RGB and resized to `image_size` x `image_size`
    before being saved as JPEG. The class folder is the entry of `class_names`
    at the example's label, or the label number itself when `class_names` is
    empty. File names are the zero-padded (6 digits) index within the split.
    """
    os.makedirs(out_dir, exist_ok=True)

    to_rgb_square = T.Compose([
        T.Lambda(lambda im: im.convert("RGB")),
        T.Resize((image_size, image_size), interpolation=InterpolationMode.BILINEAR),
    ])

    progress = tqdm(range(len(split)), desc=f"Exporting {split_name}")
    for idx in progress:
        example = split[idx]
        picture: Image.Image = example["image"]
        label_id = int(example["label"])

        folder_name = class_names[label_id] if class_names else str(label_id)
        target_dir = os.path.join(out_dir, split_name, folder_name)
        os.makedirs(target_dir, exist_ok=True)

        target_file = os.path.join(target_dir, f"{idx:06d}.jpg")
        to_rgb_square(picture).save(target_file, format="JPEG", quality=92, optimize=True)


def _collate_batch(batch):
    """
    Each item `b` may come with b["image"] as a Tensor OR a single-item list [Tensor]
    depending on HF internals. Be tolerant.
    """
    imgs, labels = [], []
    for b in batch:
        img = b["image"]
        if isinstance(img, list):
            img = img[0]
        lbl = b["label"]
        if isinstance(lbl, list):
            lbl = lbl[0]
        imgs.append(img)
        labels.append(int(lbl))
    return torch.stack(imgs, dim=0), torch.tensor(labels, dtype=torch.long)


# ---------- Public API ----------

def build_dataloaders(
    batch_size: int = 32,
    image_size: int = 224,
    val_pct: float = 0.10,
    test_pct: float = 0.10,
    seed: int = 42,
    num_workers: Optional[int] = None,
    augment: bool = True,
    export_dir: Optional[str] = None,
) -> Tuple[Dict[str, DataLoader], List[str]]:
    """
    Load garythung/trashnet and return train/val/test DataLoaders plus class names.

    Args:
        batch_size: Batch size used by all three loaders.
        image_size: Images are resized to (image_size, image_size).
        val_pct: Fraction held out for validation when splits must be created.
        test_pct: Fraction held out for testing when splits must be created.
        seed: Seed for `_set_seed` and for the stratified splitting.
        num_workers: DataLoader worker count; None is treated as 0.
        augment: Whether the training transform includes random augmentation.
        export_dir: If set, resized JPEGs of every split are written under it.

    Returns:
        (loaders, class_names) where `loaders` has keys "train", "val", "test"
        and `class_names` comes from the dataset's "label" feature (may be empty).
    """
    cfg = Config(
        image_size=image_size,
        val_pct=val_pct,
        test_pct=test_pct,
        seed=seed,
        batch_size=batch_size,
        num_workers=(num_workers if num_workers is not None else 0),
        augment=augment,
        export_dir=export_dir,
    )
    # persistent_workers is only meaningful when worker processes exist.
    cfg.persistent_workers = bool(cfg.num_workers and cfg.num_workers > 0)

    _set_seed(cfg.seed)

    # 1) Load dataset
    ds = load_dataset("garythung/trashnet")

    # 2) Standardize splits
    splits = _maybe_make_splits(ds, cfg)

    # 3) Class names
    class_names = _class_names_from_features(_pick_base_split(ds).features)

    # 4) Optional export. This must run BEFORE transforms are attached:
    #    `_set_transform` registers the transform on the split object itself, so
    #    once it has run the split no longer yields the raw images that
    #    `_export_split` needs. When the dataset ships its own splits,
    #    re-deriving them afterwards returns those same objects.
    if cfg.export_dir:
        for part in ("train", "val", "test"):
            _export_split(splits[part], part, class_names, cfg.export_dir, cfg.image_size)

    # 5) Transforms: augmentation only for train, deterministic pipeline otherwise.
    train_tfm, eval_tfm = _build_transforms(cfg)
    split_transforms = {"train": train_tfm, "val": eval_tfm, "test": eval_tfm}

    # 6) Attach transforms (batch-aware) and build one DataLoader per split.
    loaders: Dict[str, DataLoader] = {}
    for part, tfm in split_transforms.items():
        loaders[part] = DataLoader(
            _set_transform(splits[part], tfm),
            batch_size=cfg.batch_size,
            shuffle=(part == "train"),  # only the training split is shuffled
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            persistent_workers=cfg.persistent_workers,
            collate_fn=_collate_batch,
        )

    return loaders, class_names


# ---------- CLI ----------

# Entry point: parse options, build the loaders, and print a short sanity report.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Preprocess garythung/trashnet and build loaders / optional export.")
    # Each option mirrors the build_dataloaders parameter of the same name.
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--image-size", type=int, default=224)
    parser.add_argument("--val-pct", type=float, default=0.10)
    parser.add_argument("--test-pct", type=float, default=0.10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--no-augment", action="store_true", help="Disable training-time data augmentation.")
    parser.add_argument("--export-dir", type=str, default=None, help="If set, saves resized JPGs to this folder.")
    # parse_known_args (not parse_args) so unrecognised arguments are ignored
    # rather than fatal, e.g. the extra argument Jupyter injects.
    args, _unknown = parser.parse_known_args()

    loaders, class_names = build_dataloaders(
        batch_size=args.batch_size,
        image_size=args.image_size,
        val_pct=args.val_pct,
        test_pct=args.test_pct,
        seed=args.seed,
        num_workers=args.num_workers,
        augment=not args.no_augment,  # the flag is negative, the parameter positive
        export_dir=args.export_dir,
    )

    # Report how many examples landed in each split.
    n_train = len(loaders["train"].dataset)
    n_val   = len(loaders["val"].dataset)
    n_test  = len(loaders["test"].dataset)
    print(f"Splits -> train: {n_train}, val: {n_val}, test: {n_test}")
    if class_names:
        print(f"Classes ({len(class_names)}): {class_names}")

    # Pull a single training batch to confirm the collate function's output shapes.
    xb, yb = next(iter(loaders["train"]))
    print(f"Batch shapes -> images: {tuple(xb.shape)}, labels: {tuple(yb.shape)}")


Splits -> train: 4043, val: 505, test: 506
Classes (6): ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
Batch shapes -> images: (32, 3, 224, 224), labels: (32,)
