In [7]:
# preprocessing Gary Thung trashnet.py dataset 
# --- Binary TrashNet dataloaders with minority-only extra aug & class imbalance handling ---
# --- BASIC SETUP: make sure Python path and required packages are ready ---
import sys, subprocess, pkgutil
import importlib.util
print("Using Python:", sys.executable)

# Make sure torch/torchvision/kagglehub are importable in THIS interpreter.
# Only modules that are actually missing are installed (and pip is only upgraded
# when an install is about to happen), so re-running the cell needs no network.
_missing_core = [
    mod for mod in ("torch", "torchvision", "kagglehub")
    if importlib.util.find_spec(mod) is None
]
if _missing_core:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "pip"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", *_missing_core])
    # The import system caches directory listings; refresh so `import torch`
    # below can see packages installed during this run.
    importlib.invalidate_caches()

import torch
print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

# Install light deps if missing (safe in notebooks)
import sys, subprocess, importlib
def _ensure(pkg, pip_name=None):
    pip_name = pip_name or pkg
    try:
        importlib.import_module(pkg)
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])

# Light dependencies as (importable module, pip distribution or None when identical).
for _module, _dist in (("datasets", None), ("PIL", "pillow")):
    _ensure(_module, _dist)

import os, random
from dataclasses import dataclass
from typing import Dict, Tuple, Optional, List
from collections import Counter

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch.nn as nn
from torchvision import transforms as T
from torchvision.transforms import InterpolationMode
from PIL import Image
from datasets import load_dataset, Dataset, DatasetDict, Features

# ---------------- Config ----------------
@dataclass
class Config:
    """Settings used by build_dataloaders to construct the binary TrashNet loaders."""
    image_size: int = 224  # images are resized to image_size x image_size
    val_pct: float = 0.10  # fraction of the base split held out for validation
    test_pct: float = 0.10  # fraction of the base split held out for test
    seed: int = 42  # seeds random/torch and both stratified splits
    batch_size: int = 32
    num_workers: int = 0  # DataLoader worker processes
    pin_memory: bool = torch.cuda.is_available()  # evaluated once, when the class is defined
    persistent_workers: bool = False  # build_dataloaders recomputes this from num_workers
    augment: bool = True  # random augmentation for the training split
    minority_extra_aug: bool = True  # extra aug only for the minority class

# Per-channel (R, G, B) ImageNet statistics passed to T.Normalize in
# _build_transform_parts; images are converted to RGB before ToTensor, so the
# channel order matches.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# ---------------- Utils ----------------
def _set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def _pick_base_split(ds: DatasetDict) -> Dataset:
    return ds["train"] if "train" in ds else ds[next(iter(ds.keys()))]

def _stratified_splits(base: Dataset, val_pct: float, test_pct: float, seed: int) -> DatasetDict:
    assert 0 < val_pct < 1 and 0 < test_pct < 1 and val_pct + test_pct < 1
    holdout_pct = val_pct + test_pct
    tmp = base.train_test_split(test_size=holdout_pct, stratify_by_column="label", seed=seed)
    train, holdout = tmp["train"], tmp["test"]
    test_frac_of_holdout = test_pct / (val_pct + test_pct)
    hold = holdout.train_test_split(test_size=test_frac_of_holdout, stratify_by_column="label", seed=seed)
    return DatasetDict(train=train, val=hold["train"], test=hold["test"])

def _class_names_from_features(feats: Features) -> List[str]:
    if "label" in feats and hasattr(feats["label"], "names") and feats["label"].names:
        return list(feats["label"].names)
    return []

# ------------- 6 → 2 mapping -------------
# Original class names grouped by binary target; _to_binary matches these
# against the dataset's label names. Names in RECYCLE_NAMES become label 1
# ("recycling"); names in WASTE_NAMES become label 0 ("waste").
RECYCLE_NAMES = {"glass", "paper", "cardboard", "plastic", "metal"}
WASTE_NAMES   = {"trash"}  # maps to class 0

def _to_binary(ds: DatasetDict, feats: Features) -> DatasetDict:
    """Collapse the original multi-class label into 1=recycling / 0=waste on every split.

    The original class id is kept in a new ``orig_label`` int64 column. The
    ``label`` column is re-declared as ``ClassLabel(names=["waste", "recycling"])``.
    This keeps it a ClassLabel, which stratified splitting relies on, while
    replacing the stale original class names that would otherwise describe the
    new 0/1 values. Only the ``label`` column is handed to the mapping function,
    so image data is not passed through Python.

    Args:
        ds: splits to convert.
        feats: features whose "label" names identify each original class id.

    Raises:
        RuntimeError: if no recycling or no waste class can be found, or if some
            class name belongs to neither RECYCLE_NAMES nor WASTE_NAMES (it would
            otherwise be silently labelled as waste).
    """
    from datasets import ClassLabel, Value

    orig = _class_names_from_features(feats)
    recycle_ids = {i for i, n in enumerate(orig) if n in RECYCLE_NAMES}
    waste_ids = {i for i, n in enumerate(orig) if n in WASTE_NAMES}
    if not recycle_ids or not waste_ids:
        raise RuntimeError(f"Could not build mapping from names {orig}")
    unknown = [n for n in orig if n not in RECYCLE_NAMES and n not in WASTE_NAMES]
    if unknown:
        raise RuntimeError(
            f"Class names {unknown} are in neither RECYCLE_NAMES nor WASTE_NAMES"
        )

    def map_fn(lid):
        lid = int(lid)
        # 1=recycling, 0=waste
        return {"label": 1 if lid in recycle_ids else 0, "orig_label": lid}

    out = DatasetDict()
    for k, split in ds.items():
        new_feats = split.features.copy()
        new_feats["label"] = ClassLabel(names=["waste", "recycling"])
        new_feats["orig_label"] = Value("int64")
        out[k] = split.map(
            map_fn,
            input_columns="label",
            features=new_feats,
            desc=f"Map 6->2 for {k}",
        )
    return out

# ------------- Transforms (fixed order) -------------
def _build_transform_parts(image_size: int, augment: bool):
    """Build the individual transform stages for the train/eval pipelines.

    Args:
        image_size: every image is resized to (image_size, image_size).
        augment: when False, EVERY augmentation stage is an identity
            ``T.Compose([])``, including the tensor-side RandomErasing meant for
            the minority class. Training images then only get
            resize/ToTensor/Normalize.

    Returns:
        ``(pre, aug_common_pil, extra_minority_pil, to_tensor, normalize,
        extra_minority_tensor)``, intended to be applied in that order: the PIL
        stages first, then the tensor stages after ``to_tensor``.
    """
    identity = T.Compose([])  # no-op stage used whenever augmentation is disabled

    # PIL -> PIL: force 3 channels and a fixed size.
    pre = T.Compose([
        T.Lambda(lambda im: im.convert("RGB")),
        T.Resize((image_size, image_size), interpolation=InterpolationMode.BILINEAR),
    ])

    # Augmentation for every training image, applied while still PIL.
    aug_common_pil = T.Compose([
        T.RandomHorizontalFlip(p=0.5),
        T.RandomApply([T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)], p=0.5),
        T.RandomAffine(degrees=12, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    ]) if augment else identity

    # EXTRA augmentation for the minority class only, still PIL.
    extra_minority_pil = T.Compose([
        T.RandomApply([T.RandomPerspective(distortion_scale=0.3)], p=0.35),
        T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.35),
    ]) if augment else identity

    # PIL -> Tensor, then ImageNet normalisation.
    to_tensor = T.ToTensor()
    normalize = T.Normalize(IMAGENET_MEAN, IMAGENET_STD)

    # Tensor-only extra aug for the minority class. RandomErasing must run after
    # ToTensor. It is gated on `augment` like the PIL stages, so augment=False
    # really means no random augmentation at all.
    extra_minority_tensor = (
        T.RandomErasing(p=0.25, scale=(0.02, 0.08), ratio=(0.3, 3.3), value="random")
        if augment else identity
    )

    return pre, aug_common_pil, extra_minority_pil, to_tensor, normalize, extra_minority_tensor

def _set_transform_train(ds: Dataset, pre, aug_common_pil, extra_minority_pil,
                         to_tensor, normalize, extra_minority_tensor,
                         minority_label: Optional[int]):
    # HuggingFace set_transform: batch dicts with list values
    def _apply(batch):
        imgs = batch["image"]; labels = batch["label"]
        if not isinstance(imgs, list): imgs = [imgs]
        if not isinstance(labels, list): labels = [labels]
        out_imgs = []
        for im, lbl in zip(imgs, labels):
            x = pre(im)                  # PIL
            x = aug_common_pil(x)        # PIL
            if minority_label is not None and int(lbl) == int(minority_label):
                x = extra_minority_pil(x)  # PIL (minority only)
            x = to_tensor(x)             # Tensor
            x = normalize(x)             # Tensor
            if minority_label is not None and int(lbl) == int(minority_label):
                x = extra_minority_tensor(x)  # Tensor-only (minority only)
            out_imgs.append(x)
        batch["image"] = out_imgs
        return batch
    ds.set_transform(_apply)
    return ds

def _set_transform_eval(ds: Dataset, pre, to_tensor, normalize):
    def _apply(batch):
        imgs = batch["image"]
        if not isinstance(imgs, list): imgs = [imgs]
        processed = []
        for im in imgs:
            x = pre(im)
            x = to_tensor(x)
            x = normalize(x)
            processed.append(x)
        batch["image"] = processed
        return batch
    ds.set_transform(_apply)
    return ds

# ------------- Collate -------------
def _collate_batch(batch):
    imgs, labels = [], []
    for b in batch:
        img = b["image"]; lbl = b["label"]
        if isinstance(img, list): img = img[0]
        if isinstance(lbl, list): lbl = lbl[0]
        imgs.append(img)
        labels.append(int(lbl))
    return torch.stack(imgs, dim=0), torch.tensor(labels, dtype=torch.long)

# ------------- Balancing helpers -------------
def _minority_label_from_counts(counts: Dict[int,int]) -> int:
    return min(counts, key=counts.get)

def _make_weighted_sampler(train_split: Dataset) -> WeightedRandomSampler:
    labels = list(train_split["label"])
    c = Counter(labels)
    w_per_class = {cls: 1.0 / cnt for cls, cnt in c.items()}
    weights = torch.as_tensor([w_per_class[int(lbl)] for lbl in labels], dtype=torch.double)
    return WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)

def _loss_weights_from_counts(counts: Dict[int,int]) -> Dict[str, torch.Tensor]:
    n0, n1 = counts.get(0, 0), counts.get(1, 0)  # 0=waste, 1=recycling
    N = n0 + n1
    ce_w = torch.tensor([N/(2*max(n0,1)), N/(2*max(n1,1))], dtype=torch.float)
    # NOTE: For BCEWithLogitsLoss, "positive" is class 1 (recycling) here.
    pos_weight = torch.tensor([max(n0,1)/max(n1,1)], dtype=torch.float)
    return {"ce": ce_w, "bce_pos": pos_weight}

# ------------- Public API -------------
def build_dataloaders(
    batch_size: int = 32,
    image_size: int = 224,
    val_pct: float = 0.10,
    test_pct: float = 0.10,
    seed: int = 42,
    num_workers: Optional[int] = None,
    augment: bool = True,
    minority_extra_aug: bool = True,
):
    """Build train/val/test DataLoaders for binary (waste vs recycling) TrashNet.

    Loads ``garythung/trashnet`` with ``load_dataset`` and maps its classes to
    0=waste / 1=recycling. It then makes stratified train/val/test splits and
    attaches on-the-fly transforms. The train loader draws through a
    class-balancing WeightedRandomSampler; val/test are iterated in order.

    Args:
        batch_size, image_size, val_pct, test_pct, seed: see ``Config``.
        num_workers: DataLoader workers; ``None`` means 0. Persistent workers are
            enabled whenever this is > 0.
        augment: random augmentation for the training split. When False, the
            minority-only extras are disabled too.
        minority_extra_aug: extra augmentation for the least frequent training class.

    Returns:
        ``(loaders, class_names, balance, train_counts)``:
        ``loaders`` maps "train"/"val"/"test" to DataLoaders yielding
        ``(images, labels)``; ``class_names`` is ``["waste", "recycling"]``;
        ``balance`` holds the "ce" and "bce_pos" loss weights; ``train_counts``
        is a Counter of training labels.
    """
    cfg = Config(
        image_size=image_size, val_pct=val_pct, test_pct=test_pct, seed=seed,
        batch_size=batch_size, num_workers=(num_workers if num_workers is not None else 0),
        augment=augment, minority_extra_aug=minority_extra_aug,
    )
    cfg.persistent_workers = bool(cfg.num_workers and cfg.num_workers > 0)
    _set_seed(cfg.seed)

    # 1) Load & map to binary
    ds_full = load_dataset("garythung/trashnet")
    feats = _pick_base_split(ds_full).features
    ds_bin = _to_binary(ds_full, feats)

    # 2) Stratified splits on binary label
    splits = _stratified_splits(_pick_base_split(ds_bin), cfg.val_pct, cfg.test_pct, cfg.seed)

    # 3) Balancing tools from TRAIN counts
    train_counts = Counter(splits["train"]["label"])
    sampler = _make_weighted_sampler(splits["train"])
    balance = _loss_weights_from_counts(train_counts)
    # Minority-only extras are augmentation as well: augment=False disables them.
    use_minority_extras = cfg.augment and cfg.minority_extra_aug
    minority_lbl = _minority_label_from_counts(train_counts) if use_minority_extras else None

    # 4) Transforms (fixed PIL -> Tensor order, minority-only extras on train)
    pre, aug_common_pil, extra_minority_pil, to_tensor, normalize, extra_minority_tensor = \
        _build_transform_parts(cfg.image_size, cfg.augment)

    splits["train"] = _set_transform_train(
        splits["train"], pre, aug_common_pil, extra_minority_pil, to_tensor, normalize,
        extra_minority_tensor, minority_lbl
    )
    splits["val"] = _set_transform_eval(splits["val"], pre, to_tensor, normalize)
    splits["test"] = _set_transform_eval(splits["test"], pre, to_tensor, normalize)

    # 5) DataLoaders: identical settings except for how samples are ordered.
    def _loader(split, **order_kwargs):
        return DataLoader(
            split, batch_size=cfg.batch_size,
            num_workers=cfg.num_workers, pin_memory=cfg.pin_memory,
            persistent_workers=cfg.persistent_workers, collate_fn=_collate_batch,
            **order_kwargs,
        )

    loaders = {
        "train": _loader(splits["train"], sampler=sampler),
        "val": _loader(splits["val"], shuffle=False),
        "test": _loader(splits["test"], shuffle=False),
    }
    class_names = ["waste", "recycling"]  # 0, 1
    return loaders, class_names, balance, train_counts

# ---------------- Smoke test (run this cell) ----------------
if __name__ == "__main__":
    # Runs when executed as a notebook cell or as a script (both have
    # __name__ == "__main__"). Importing the file as a module therefore no longer
    # downloads and preprocesses TrashNet as a side effect.
    loaders, classes, balance, counts = build_dataloaders(batch_size=64, image_size=224, num_workers=0)
    print("Binary classes:", classes)
    print("Train label counts (0=waste, 1=recycling):", dict(counts))
    print("CrossEntropy weights:", balance["ce"].tolist(), " | BCE pos_weight (pos=1):", float(balance["bce_pos"]))
    xb, yb = next(iter(loaders["train"]))
    print("Batch shapes:", xb.shape, yb.shape)

    # What this pipeline provides:
    #   (a) the WeightedRandomSampler serves roughly 50/50 batches;
    #   (b) `balance` holds loss weights that up-weight the rarer class (waste here);
    #   (c) extra augmentation is applied only to the minority class during training.
    # NOTE(review): (a) and (b) both correct for the imbalance. Using the weighted
    # loss on top of the balanced sampler will likely over-correct, so usually
    # pick one.


Using Python: c:\Users\61459\anaconda3\python.exe
Torch: 2.8.0+cpu | CUDA available: False
Binary classes: ['waste', 'recycling']
Train label counts (0=waste, 1=recycling): {1: 3824, 0: 219}
CrossEntropy weights: [9.23059368133545, 0.5286349654197693]  | BCE pos_weight (pos=1): 0.0572698749601841
Batch shapes: torch.Size([64, 3, 224, 224]) torch.Size([64])
