In [3]:
import os, random, time, io
from pathlib import Path
from typing import List

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets, transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, accuracy_score
import kagglehub






ModuleNotFoundError: No module named 'torch'

In [None]:
# Reproducibility + device
SEED = 56


def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with `seed`."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)


_seed_everything(SEED)

# Ask cuDNN for deterministic algorithms and turn off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:

# Pick the compute device: the GPU when CUDA is available, otherwise the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Device being used:", device)

In [None]:
# Download dataset (Kaggle)
# The dataset handle is kept in one constant so it is easy to spot and change.
KAGGLE_DATASET = "alistairking/recyclable-and-household-waste-classification"

print("Downloading dataset...")
ds_path = kagglehub.dataset_download(KAGGLE_DATASET)
print("Dataset downloaded to:", ds_path)


In [None]:
# Some Kaggle datasets have extra folders like "images/images".
# Probe the likely locations, most deeply nested first, and keep the first one
# that is a directory containing at least one sub-directory (the category folders).
# The download cell stores the dataset location in `ds_path`, so that is the
# name used here (the previous `path` was never defined -> NameError).
dataset_dir = Path(ds_path)
candidates = [
    dataset_dir / "images" / "images",
    dataset_dir / "images",
    dataset_dir,
]

DATA_ROOT = None
for c in candidates:
    # is_dir() rather than exists(): a plain file named like a candidate is
    # skipped instead of making iterdir() raise NotADirectoryError.
    if c.is_dir() and any(d.is_dir() for d in c.iterdir()):
        DATA_ROOT = c
        break

# Fail loudly here rather than later with a confusing ImageFolder error.
if DATA_ROOT is None:
    raise FileNotFoundError(f"Could not find image folder in: {ds_path}")

print("Image root folder found:", DATA_ROOT)

In [None]:
# Build full ImageFolder (PNG-only filter)
def png_only(p):
    """Return True when the path `p` ends with ".png", ignoring case."""
    return str(p).lower().endswith(".png")


full_ds = datasets.ImageFolder(
    root=str(DATA_ROOT),
    transform=None,
    is_valid_file=png_only,
)
class_names: List[str] = full_ds.classes
num_classes = len(class_names)

# Show at most the first ten class names; a trailing "..." marks truncation.
more_marker = "..." if num_classes > 10 else ""
print(f"Classes ({num_classes}):", class_names[:10], more_marker)
print("Total PNG images:", len(full_ds.samples))

In [None]:
# --- Check what categories (folders) exist ---
# Every sub-directory of DATA_ROOT counts as one category; names are sorted.
class_dirs = sorted(entry.name for entry in DATA_ROOT.iterdir() if entry.is_dir())
print("Number of categories found:", len(class_dirs))
print("Example categories:", class_dirs[:10])

In [None]:
# --- Define image transformations (resizing and data augmentation) ---
# Training-time preprocessing: a fixed resize, light random augmentation,
# then conversion to a normalised tensor.

IMG_SIZE = 224  # side length (in pixels) of the final square crop

train_steps = [
    transforms.Resize((256, 256)),          # bring every image to 256x256 first
    transforms.RandomResizedCrop(           # random crop, resized to IMG_SIZE
        IMG_SIZE, scale=(0.8, 1.0), ratio=(0.9, 1.1)
    ),
    transforms.RandomHorizontalFlip(),      # random left-right flip
    transforms.RandomRotation(10),          # small random rotation (±10°)
    transforms.ColorJitter(                 # mild colour perturbation
        brightness=0.10, contrast=0.10, saturation=0.10, hue=0.05
    ),
    transforms.ToTensor(),                  # image -> tensor
    transforms.Normalize(                   # per-channel mean/std scaling
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    ),
]
train_tf = transforms.Compose(train_steps)


In [None]:
# Validation/testing transformations: same resize and normalisation as the
# training pipeline, but with a fixed centre crop and no random changes.
eval_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
eval_tf = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    eval_normalize,
])

# Filter function: only include .png or .PNG images
def png_only(p):
    """Return True when the path `p` ends with ".png", ignoring case."""
    return str(p).lower().endswith(".png")

In [None]:
# --- Load the dataset ---
# ImageFolder labels each image with the name of the folder it sits in.
image_folder_kwargs = dict(
    root=str(DATA_ROOT),
    transform=None,          # transforms are applied later
    is_valid_file=png_only,  # keep PNG files only
)
full_ds = datasets.ImageFolder(**image_folder_kwargs)


In [None]:
# Print a quick summary: category count, a sample of names, and image count
category_names = full_ds.classes
num_classes = len(category_names)
print(f"Total number of categories: {num_classes}")
print("First few category names:", category_names[:10])
print("Total number of images:", len(full_ds.samples))