In [None]:
# =========================
# IMAGE PREPROCESSING PDF REPORT (Strategy B) — Improved Charts + Simple Aug Modes + Expanded Steps
# =========================

import os, tempfile, datetime
from collections import Counter

import numpy as np
import torch

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from PIL import Image as PILImage

from reportlab.platypus import (
    SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak,
    Image as RLImage
)
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib import colors
from reportlab.lib.units import cm


# -------------------------
# Paragraph styles shared by the whole report
# (kept visually consistent with the sensor report).
# alignment=1 is ReportLab's centred alignment.
# -------------------------
styles = getSampleStyleSheet()

# Section headings and running text
H1 = ParagraphStyle("H1", parent=styles["Heading1"],
                    spaceAfter=10, alignment=1)
H2 = ParagraphStyle("H2", parent=styles["Heading2"],
                    spaceAfter=8, spaceBefore=6)
body = ParagraphStyle("body", parent=styles["BodyText"],
                      leading=13, spaceAfter=6)
small = ParagraphStyle("small", parent=styles["BodyText"],
                       leading=11, fontSize=9, spaceAfter=6)
caption = ParagraphStyle("cap", parent=styles["BodyText"],
                         leading=11, fontSize=9, spaceAfter=10, alignment=1)

# Cover page
cover_title = ParagraphStyle("cover_title", parent=styles["Title"],
                             leading=30, fontSize=24, spaceAfter=18, alignment=1)
cover_subtitle = ParagraphStyle("cover_subtitle", parent=styles["Heading2"],
                                leading=20, fontSize=16, spaceAfter=12, alignment=1)
cover_meta = ParagraphStyle("cover_meta", parent=styles["Normal"],
                            textColor=colors.gray, fontSize=10, spaceAfter=24, alignment=1)
cover_desc = ParagraphStyle("cover_desc", parent=styles["Normal"],
                            leading=16, fontSize=12, spaceAfter=0, alignment=1)


# -------------------------
# Plot helpers (publication-friendly)
# -------------------------
def _set_pub_rcparams():
    """Apply compact, print-oriented matplotlib defaults (small fonts, thin axes, 300-dpi saves)."""
    pub_params = {
        # typography
        "font.size": 9,
        "axes.titlesize": 11,
        "axes.labelsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 8,
        # on-screen vs. saved resolution
        "figure.dpi": 150,
        "savefig.dpi": 300,
        # frame
        "axes.linewidth": 0.8,
    }
    plt.rcParams.update(pub_params)


# Applied once at import so every chart in this module shares the same look.
_set_pub_rcparams()

def _style_axes(ax, grid_axis="y"):
    """
    Give a matplotlib Axes a clean, publication-style look and return it.

    grid_axis selects dashed gridlines: "y", "x", "both", or any other value
    for no grid. Top/right spines are hidden; left/bottom spines are thinned
    and softened; ticks are shortened.
    """
    ax.set_axisbelow(True)  # keep gridlines behind bars/markers

    # (axis, alpha): the y-grid is drawn slightly stronger than the x-grid.
    for axis_name, grid_alpha in (("y", 0.35), ("x", 0.25)):
        if grid_axis == axis_name or grid_axis == "both":
            ax.grid(True, axis=axis_name, linestyle="--", linewidth=0.6, alpha=grid_alpha)

    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    for side in ("left", "bottom"):
        spine = ax.spines[side]
        spine.set_linewidth(0.8)
        spine.set_alpha(0.7)

    ax.tick_params(axis="both", which="both", length=3, width=0.8)
    return ax

def _save_fig_to_png(fig) -> str:
    """
    Render a matplotlib figure into a new temporary PNG and return its path.

    The figure is always closed, even when saving fails, so pyplot does not keep
    it alive. The caller owns the returned file and must delete it. If saving
    fails, the temp file is removed and the original error is re-raised.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    tmp.close()  # only a unique path is needed; savefig reopens it by name
    try:
        fig.savefig(tmp.name, dpi=300, bbox_inches="tight", pad_inches=0.06, facecolor="white")
    except Exception:
        # Don't orphan an empty/partial temp file when rendering fails.
        try:
            os.remove(tmp.name)
        except OSError:
            pass
        raise
    finally:
        plt.close(fig)
    return tmp.name

def _fit_rl_image(img_path: str, max_w: float, max_h: float) -> RLImage:
    """
    Build a ReportLab Image flowable for img_path, scaled uniformly so it fits
    inside a max_w x max_h box (ReportLab units) while keeping its aspect ratio.
    Note the image can be scaled up as well as down.
    """
    with PILImage.open(img_path) as src:
        px_w, px_h = src.size
    factor = min(max_w / float(px_w), max_h / float(px_h))
    return RLImage(img_path, width=px_w * factor, height=px_h * factor)

def _convert_logo_to_png(logo_path: str) -> str | None:
    try:
        img = PILImage.open(logo_path).convert("RGBA")
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        tmp.close()
        img.save(tmp.name, format="PNG")
        return tmp.name
    except Exception:
        return None


# -------------------------
# Chart: improved class distribution (clean labels + stable ordering)
# -------------------------
def _chart_class_counts(class_counts: dict) -> str | None:
    if not class_counts:
        return None

    # Stable ordering:
    # - If keys look numeric ("0","1",...), sort numerically; else alphabetically
    keys = list(class_counts.keys())
    def _is_intlike(s):
        try:
            int(str(s))
            return True
        except Exception:
            return False

    if all(_is_intlike(k) for k in keys):
        items = sorted(class_counts.items(), key=lambda x: int(str(x[0])))
    else:
        items = sorted(class_counts.items(), key=lambda x: str(x[0]).lower())

    labels = [str(k) for k, _ in items]
    values = [int(v) for _, v in items]
    total = sum(values)
    if total <= 0:
        return None

    fig, ax = plt.subplots(figsize=(7.2, 3.8))
    x = np.arange(len(labels))
    bars = ax.bar(x, values, width=0.65)

    ax.set_title("Class distribution")
    ax.set_ylabel("Images")
    ax.set_xticks(x)

    # rotate only if labels are long
    max_len = max(len(l) for l in labels) if labels else 0
    rot = 0 if max_len <= 10 and len(labels) <= 12 else 30
    ax.set_xticklabels(labels, rotation=rot, ha="right" if rot else "center")

    _style_axes(ax, grid_axis="y")

    # Clean bar labels: count + percent
    try:
        ax.bar_label(
            bars,
            labels=[f"{v} ({(v/total)*100:.1f}%)" for v in values],
            padding=3,
            fontsize=8
        )
    except Exception:
        pass

    fig.tight_layout()
    return _save_fig_to_png(fig)


# -------------------------
# Resolution sample + smarter resolution chart
# - If constant resolution -> skip chart (and we'll print a sentence instead)
# - Else -> scatter of (W,H)
# -------------------------
def _sample_resolution_pairs_for_report(train_loader, max_images=256):
    """
    Extract approximate original sizes is hard after transforms.
    Instead, we visualize post-transform tensor shape, and report
    precomputed dataset stats (if available). Here we only add a small
    'post-transform' confirmation.
    """
    try:
        xb, _ = next(iter(train_loader))
        # xb shape: [B,C,H,W]
        h = int(xb.shape[2])
        w = int(xb.shape[3])
        return {"post_transform_shape": f"{w}×{h}"}
    except Exception:
        return {}


def _chart_resolution_scatter(res_stats: dict) -> str | None:
    """
    Expects res_stats to include either:
      - sample_pairs: list[(w,h)]  (best)
    Otherwise, will skip.
    """
    pairs = res_stats.get("sample_pairs", None)
    if not pairs:
        return None

    ws = np.array([p[0] for p in pairs], dtype=float)
    hs = np.array([p[1] for p in pairs], dtype=float)

    # If constant resolution, skip
    if ws.min() == ws.max() and hs.min() == hs.max():
        return None

    fig, ax = plt.subplots(figsize=(6.9, 4.0))
    ax.scatter(ws, hs, s=14, alpha=0.6)
    ax.set_title("Resolution scatter (sampled)")
    ax.set_xlabel("Width (px)")
    ax.set_ylabel("Height (px)")
    _style_axes(ax, grid_axis="both")

    # Put mean marker
    try:
        ax.scatter([ws.mean()], [hs.mean()], s=60, marker="x")
        ax.text(ws.mean(), hs.mean(), f"  mean≈{ws.mean():.0f}×{hs.mean():.0f}", va="center", fontsize=8)
    except Exception:
        pass

    fig.tight_layout()
    return _save_fig_to_png(fig)


# -------------------------
# Visual grid of augmented samples (train_loader already includes Strategy-B aug)
# -------------------------
def _make_aug_grid_png(train_loader, max_images=12) -> str | None:
    try:
        xb, yb = next(iter(train_loader))
    except Exception:
        return None

    # Unnormalize for display (ImageNet normalization used in your transforms)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)

    x = xb[:max_images].detach().cpu()
    if x.shape[1] == 3:
        x = x * std + mean
    x = x.clamp(0, 1)

    n = x.size(0)
    cols = 6
    rows = int(np.ceil(n / cols))

    fig = plt.figure(figsize=(cols*2.05, rows*2.05))
    for i in range(n):
        ax = plt.subplot(rows, cols, i+1)
        img = x[i].permute(1,2,0).numpy()
        ax.imshow(img)
        ax.set_title(str(int(yb[i])), fontsize=9)
        ax.axis("off")

    fig.suptitle("Augmented training samples (Strategy B)", fontsize=12)
    fig.tight_layout()
    return _save_fig_to_png(fig)


# -------------------------
# Augmentation explanation (simple, report-friendly)
# -------------------------
def _augmentation_explanation_paragraphs(strategy_b: dict) -> list[str]:
    """
    Returns short, simple explanation paragraphs for all modes + highlight chosen mode.
    """
    chosen_family = str(strategy_b.get("family", "")).lower()
    chosen_level = strategy_b.get("level", "")

    # Simple mode explanations (always included, like a mini glossary)
    mode_text = []

    mode_text.append(
        "<b>Augmentation modes (simple explanation):</b> "
        "Augmentations apply random, label-preserving changes to training images to reduce overfitting. "
        "They are <b>never</b> applied to validation/test to keep evaluation fair."
    )
    mode_text.append(
        "<b>• None:</b> No random changes. Useful when data is abundant or when you want maximum determinism, "
        "but it can overfit on small datasets."
    )
    mode_text.append(
        "<b>• Basic:</b> A light set of common transforms (typically horizontal flips and small random crops). "
        "Think “minor camera framing changes”."
    )
    mode_text.append(
        "<b>• TrivialAugmentWide:</b> Applies <b>one</b> randomly chosen transform (e.g., rotate, brightness, contrast) "
        "with random strength. Strong regularization without any tuning."
    )
    mode_text.append(
        "<b>• RandAugment:</b> Applies <b>N</b> random transforms sequentially, all with a shared strength <b>M</b>. "
        "This is stronger than Basic, but still cheap because it does not perform any search."
    )

    # Highlight chosen mode
    if chosen_family == "randaugment":
        mode_text.append(
            f"<b>Chosen in this run:</b> <b>RandAugment</b> (level=<b>{chosen_level}</b>, "
            f"N=<b>{strategy_b.get('randaugment_N')}</b>, M=<b>{strategy_b.get('randaugment_M')}</b>)."
        )
    elif chosen_family == "trivialaugment":
        mode_text.append(
            f"<b>Chosen in this run:</b> <b>TrivialAugmentWide</b> (level=<b>{chosen_level}</b>)."
        )
    elif chosen_family == "basic":
        mode_text.append(
            f"<b>Chosen in this run:</b> <b>Basic</b> (level=<b>{chosen_level}</b>)."
        )
    else:
        mode_text.append(
            f"<b>Chosen in this run:</b> <b>{strategy_b.get('family','')}</b> (level=<b>{chosen_level}</b>)."
        )

    return mode_text


# -------------------------
# Main: generate report
# -------------------------
def generate_image_preprocessing_report(
    report: dict,
    train_loader=None,
    path: str = "image_prep_report.pdf",
    project_name: str = "Automata AI - Image Preprocessing Report",
    logo_path: str | None = None
):
    """
    Build the image-preprocessing PDF report and write it to `path`.

    Args:
        report: snapshot dict such as LAST_IMAGE_PREP_REPORT filled by
            get_dataloaders. Reads keys including "class_counts",
            "imbalance_ratio", "resolution_stats", "strategy_b", "splits",
            "loader", "img_size", "normalization" and "timestamp"; missing
            keys render as empty values.
        train_loader: optional loader; its first batch feeds the
            "Post-transform tensor shape" row and the augmented-sample grid.
        path: output PDF path.
        project_name: shown on the cover page and in every page header.
        logo_path: optional image for the cover and header; skipped if it does
            not exist or cannot be converted.

    Temporary PNGs (charts, sample grid, converted logo) are always deleted,
    including when building the PDF fails.
    """
    doc = SimpleDocTemplate(
        path,
        pagesize=A4,
        rightMargin=2*cm,
        leftMargin=2*cm,
        topMargin=3.1*cm,
        bottomMargin=2.0*cm
    )

    story = []
    tmp_files = []  # every temp PNG created below; removed in the finally clause

    try:
        # Logo handling (the converted PNG is itself a temp file)
        logo_png = None
        if logo_path and os.path.exists(logo_path):
            logo_png = _convert_logo_to_png(logo_path)
            if logo_png:
                tmp_files.append(logo_png)

        def draw_header_footer(canvas, doc_):
            # Per-page decoration: optional logo, project name and rule in the
            # header; page number and copyright line in the footer.
            canvas.saveState()
            header_top = doc_.pagesize[1] - 1.0*cm
            header_bottom = doc_.pagesize[1] - doc_.topMargin + 0.25*cm

            if logo_png:
                lw, lh = (1.25*cm, 1.25*cm)
                x = doc_.leftMargin
                y = header_top - lh
                canvas.drawImage(logo_png, x, y, width=lw, height=lh, preserveAspectRatio=True, mask="auto")

            canvas.setFont("Helvetica-Bold", 10)
            canvas.drawRightString(
                doc_.pagesize[0] - doc_.rightMargin,
                doc_.pagesize[1] - 1.35*cm,
                project_name
            )

            canvas.setLineWidth(0.4)
            canvas.setStrokeColor(colors.grey)
            canvas.line(doc_.leftMargin, header_bottom, doc_.pagesize[0] - doc_.rightMargin, header_bottom)

            canvas.setFont("Helvetica", 8)
            canvas.setFillColor(colors.black)
            canvas.drawString(doc_.leftMargin, 1.15*cm, f"Page {doc_.page}")
            canvas.drawRightString(
                doc_.pagesize[0] - doc_.rightMargin,
                1.15*cm,
                f"© {datetime.datetime.now().year} Automata AI — All rights reserved"
            )
            canvas.restoreState()

        # -------------------------
        # Cover page
        # -------------------------
        story.append(Spacer(1, 2.6 * cm))

        if logo_png:
            cover_logo = _fit_rl_image(logo_png, max_w=6.5*cm, max_h=6.5*cm)
            cover_logo.hAlign = "CENTER"
            story.append(cover_logo)

        story.append(Spacer(1, 1.2 * cm))
        story.append(Paragraph(project_name, cover_title))
        story.append(Paragraph("Automated Preprocessing Report (Image Modality)", cover_subtitle))
        story.append(Paragraph(f"Generated on {report.get('timestamp','')}", cover_meta))
        story.append(Spacer(1, 1.2 * cm))
        story.append(Paragraph(
            "This report summarizes dataset characteristics, preprocessing decisions, and the resulting training-ready loaders for image classification.",
            cover_desc
        ))
        story.append(PageBreak())

        # -------------------------
        # 1) Dataset Overview
        # -------------------------
        story.append(Paragraph("1. Dataset Overview", H2))

        class_counts = report.get("class_counts", {}) or {}
        counts_only = list(class_counts.values()) if class_counts else []
        imbalance_ratio = report.get("imbalance_ratio", None)
        if imbalance_ratio is None and counts_only:
            # largest / smallest class; max(1, ...) guards against a zero count
            imbalance_ratio = (max(counts_only) / max(1, min(counts_only)))

        res = report.get("resolution_stats", {}) or {}
        sb = report.get("strategy_b", {}) or {}
        splits = report.get("splits", {}) or {}
        loader_cfg = report.get("loader", {}) or {}

        # For readability
        dataset_mode = report.get("dataset_mode", "")
        data_root = report.get("data_root", "")
        n_images = report.get("num_images", "")
        n_classes = report.get("num_classes", "")

        overview_rows = [
            ["Dataset mode", str(dataset_mode)],
            ["Data root", str(data_root)],
            ["# Images", str(n_images)],
            ["# Classes", str(n_classes)],
            ["Imbalance ratio", f"{float(imbalance_ratio):.3f}" if imbalance_ratio is not None else ""],
            ["Chosen img_size", str(report.get("img_size", ""))],
            ["Normalization", str(report.get("normalization", ""))],
            ["Batch size", str(loader_cfg.get("batch_size", ""))],
            ["Num workers", str(loader_cfg.get("num_workers", ""))],
            ["Seed", str(splits.get("seed", report.get("seed", "")))],
        ]

        # Add resolution stats if available
        if "width_mean" in res:
            overview_rows += [
                ["Mean resolution (sampled)", f"{res.get('width_mean',0):.1f} × {res.get('height_mean',0):.1f}"],
                ["Median resolution (sampled)", f"{res.get('width_median',0):.1f} × {res.get('height_median',0):.1f}"],
                ["Resolution min/max (sampled)", f"{res.get('width_min','?')}–{res.get('width_max','?')} × {res.get('height_min','?')}–{res.get('height_max','?')}"],
                ["Unreadable in scan", f"{res.get('resolution_bad_count',0)} / {res.get('resolution_scan_limit',0)}"],
            ]

        # Post-transform confirmation (optional)
        if train_loader is not None:
            post_shape = _sample_resolution_pairs_for_report(train_loader)
            if post_shape.get("post_transform_shape"):
                overview_rows.append(["Post-transform tensor shape", post_shape["post_transform_shape"]])

        t = Table(overview_rows, colWidths=[7.5*cm, 8.5*cm])
        t.setStyle(TableStyle([
            ("GRID", (0,0), (-1,-1), 0.3, colors.grey),
            ("BACKGROUND", (0,0), (-1,0), colors.whitesmoke),
            ("FONTNAME", (0,0), (-1,0), "Helvetica-Bold"),
            ("VALIGN", (0,0), (-1,-1), "TOP"),
            ("ROWBACKGROUNDS", (0,1), (-1,-1), [colors.white, colors.Color(0.97,0.97,0.97)]),
            ("LEFTPADDING", (0,0), (-1,-1), 6),
            ("RIGHTPADDING", (0,0), (-1,-1), 6),
            ("TOPPADDING", (0,0), (-1,-1), 4),
            ("BOTTOMPADDING", (0,0), (-1,-1), 4),
        ]))
        story.append(t)
        story.append(Spacer(1, 0.4*cm))

        # Charts: register each PNG for cleanup as soon as it exists, so a
        # failure in a later chart cannot leak an earlier one.
        chart_paths = []

        p1 = _chart_class_counts(class_counts)
        if p1:
            tmp_files.append(p1)
            chart_paths.append(("Target class distribution", p1))

        # Optional resolution scatter if sample_pairs exists
        p2 = _chart_resolution_scatter(res)
        if p2:
            tmp_files.append(p2)
            chart_paths.append(("Resolution scatter (sampled)", p2))

        for title, p in chart_paths:
            story.append(_fit_rl_image(p, max_w=doc.width, max_h=8.2*cm))
            story.append(Paragraph(title, caption))

        # If the sampled resolution is constant, say so instead of plotting it
        if "width_min" in res and "width_max" in res and "height_min" in res and "height_max" in res:
            if res["width_min"] == res["width_max"] and res["height_min"] == res["height_max"]:
                story.append(Paragraph(
                    f"All sampled images share the same resolution: <b>{res['width_min']}×{res['height_min']}</b>. "
                    "A scatter plot is omitted because it would be uninformative.",
                    small
                ))

        story.append(PageBreak())

        # -------------------------
        # Configuration Snapshot (like sensor report)
        # -------------------------
        story.append(Paragraph("Configuration Snapshot", H2))

        cfg_rows = [
            ["STRONG_AUG_THRESHOLD", str(sb.get("strong_aug_threshold", ""))],
            ["MODERATE_AUG_THRESHOLD", str(sb.get("moderate_aug_threshold", ""))],
            ["RandAugment N", str(sb.get("randaugment_N", ""))],
            ["RandAugment M", str(sb.get("randaugment_M", ""))],
            ["Augmentation family", str(sb.get("family", ""))],
            ["Augmentation level", str(sb.get("level", ""))],
            ["Validation split", str(splits.get("val_split", ""))],
            ["Train / Val / Test samples", f"{splits.get('train_samples','')} / {splits.get('val_samples','')} / {splits.get('test_samples','')}"],
            ["Batch size", str(loader_cfg.get("batch_size", ""))],
            ["Num workers", str(loader_cfg.get("num_workers", ""))],
            ["Pin memory", str(loader_cfg.get("pin_memory", ""))],
        ]

        tc = Table(cfg_rows, colWidths=[7.5*cm, 8.5*cm])
        tc.setStyle(TableStyle([
            ("GRID", (0,0), (-1,-1), 0.25, colors.grey),
            ("BACKGROUND", (0,0), (-1,0), colors.whitesmoke),
            ("FONTNAME", (0,0), (-1,0), "Helvetica-Bold"),
            ("ROWBACKGROUNDS", (0,1), (-1,-1), [colors.white, colors.Color(0.97,0.97,0.97)]),
            ("LEFTPADDING", (0,0), (-1,-1), 6),
            ("RIGHTPADDING", (0,0), (-1,-1), 6),
            ("TOPPADDING", (0,0), (-1,-1), 4),
            ("BOTTOMPADDING", (0,0), (-1,-1), 4),
        ]))
        story.append(tc)
        story.append(PageBreak())

        # -------------------------
        # 2) Preprocessing Steps Applied (expanded)
        # -------------------------
        story.append(Paragraph("2. Preprocessing Steps Applied", H2))

        story.append(Paragraph(
            "<b>• Input decoding & integrity:</b> Images are loaded and decoded into a consistent in-memory representation. "
            "For file-based datasets, unreadable images can silently break training; therefore, the pipeline optionally performs "
            "a sampled scan to detect decode failures early.",
            body
        ))

        story.append(Paragraph(
            f"<b>• Standardization (shape):</b> All images are resized to <b>{report.get('img_size','')}</b> × "
            f"<b>{report.get('img_size','')}</b>. This produces fixed tensor shapes, stable batching, and predictable compute cost "
            "across datasets—important when comparing architectures fairly in NAS.",
            body
        ))

        story.append(Paragraph(
            f"<b>• Tensor conversion & normalization:</b> Images are converted to floating-point tensors and normalized using "
            f"<b>{report.get('normalization','')}</b> statistics. Normalization stabilizes optimization and makes training behavior "
            "more consistent across datasets and architectures (especially when using pretrained backbones).",
            body
        ))

        # Strategy B explanation: bullet paragraphs in body style, intro/chosen in small
        for p in _augmentation_explanation_paragraphs(sb):
            story.append(Paragraph(p, body if p.startswith("<b>•") else small))

        # Split explanation
        story.append(Paragraph(
            f"<b>• Splitting & evaluation fairness:</b> The dataset is split into train/validation/test "
            f"(<b>{splits.get('train_samples','')}</b> / <b>{splits.get('val_samples','')}</b> / <b>{splits.get('test_samples','')}</b>). "
            "Augmentations are applied to training batches only. Validation and test pipelines remain deterministic to ensure "
            "fair comparison between candidate architectures.",
            body
        ))

        # Loader settings explanation
        story.append(Paragraph(
            f"<b>• DataLoader settings:</b> batch_size=<b>{loader_cfg.get('batch_size','')}</b>, "
            f"num_workers=<b>{loader_cfg.get('num_workers','')}</b>, pin_memory=<b>{loader_cfg.get('pin_memory','')}</b>. "
            "These settings control input throughput and help keep the GPU utilized during training.",
            body
        ))

        story.append(PageBreak())

        # -------------------------
        # 3) Visual Sanity Check
        # -------------------------
        story.append(Paragraph("3. Visual Sanity Check", H2))
        story.append(Paragraph(
            "The grid below shows a sample of training images after the selected Strategy B augmentation policy. "
            "This is a quick sanity check that augmentations are label-preserving and not overly destructive.",
            body
        ))

        if train_loader is not None:
            grid = _make_aug_grid_png(train_loader, max_images=12)
            if grid:
                tmp_files.append(grid)
                story.append(_fit_rl_image(grid, max_w=doc.width, max_h=16*cm))
                story.append(Paragraph("Augmented samples (train loader)", caption))
            else:
                story.append(Paragraph("Could not render augmented sample grid (train_loader unavailable or empty).", small))
        else:
            story.append(Paragraph("No train_loader provided; skipping sample grid.", small))

        story.append(Paragraph("End of report.", ParagraphStyle("end", fontSize=9, alignment=1)))

        # -------------------------
        # Build
        # -------------------------
        doc.build(story, onFirstPage=draw_header_footer, onLaterPages=draw_header_footer)
    finally:
        # Best-effort cleanup that also runs when building fails, so temp PNGs
        # never accumulate on disk.
        for f in tmp_files:
            try:
                os.remove(f)
            except OSError:
                pass

    print(f"[INFO] Image preprocessing report saved to {path}")


In [None]:
import os
import random, numpy as np
import torch
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as T
from PIL import Image  # needed for ImageFolder resolution scan

# ----------------------------
# Reproducibility
# ----------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

# ----------------------------
# Dataset / Loader config
# ----------------------------
DATA_ROOT = "./data"     # CIFAR-10 will download here by default.
BATCH_SIZE = 128
NUM_WORKERS = 2
VAL_SPLIT = 0.1          # (train -> train/val) split
# Pinned host memory only speeds up host->GPU copies. On CPU-only machines it
# is wasted work and DataLoader warns about it, so enable it only with CUDA.
PIN_MEMORY = torch.cuda.is_available()

# If you want to use your own dataset (ImageFolder), set:
# DATASET_MODE = "imagefolder"
# and set DATA_ROOT to folder:
# DATA_ROOT/
#   class_a/xxx.jpg
#   class_b/yyy.jpg
DATASET_MODE = "cifar10"   # "cifar10" or "imagefolder"

# ----------------------------
# Strategy B thresholds & params
# ----------------------------
# Dataset size (number of images) below which strong / moderate augmentation is used.
STRONG_AUG_THRESHOLD = 5_000
MODERATE_AUG_THRESHOLD = 50_000

# RandAugment knobs (torchvision scale)
RAND_N = 2
RAND_M_STRONG = 18
RAND_M_MODERATE = 10
RAND_M_LIGHT = 6

# Default (overwritten once the dataset is loaded)
NUM_CLASSES = 10

# Populated by get_dataloaders(...); consumed by the PDF report generator.
LAST_IMAGE_PREP_REPORT = None


def _decide_strategy_b(num_images: int):
    """
    Choose the Strategy B augmentation policy from the dataset size.

    Returns (level, family, rand_n, rand_m):
      num_images <  STRONG_AUG_THRESHOLD   -> ("strong",   "trivialaugment", RAND_N, RAND_M_STRONG)
      num_images <= MODERATE_AUG_THRESHOLD -> ("moderate", "randaugment",    RAND_N, RAND_M_MODERATE)
      otherwise                            -> ("light",    "basic",          RAND_N, RAND_M_LIGHT)

    Smaller datasets get stronger augmentation. rand_n/rand_m are returned for
    every level, but make_transforms only uses them for "randaugment".
    """
    if num_images < STRONG_AUG_THRESHOLD:
        # Small data: TrivialAugmentWide is a strong, tuning-free default.
        return "strong", "trivialaugment", RAND_N, RAND_M_STRONG
    if num_images <= MODERATE_AUG_THRESHOLD:
        return "moderate", "randaugment", RAND_N, RAND_M_MODERATE
    # Large data: light flips/crops ("none" is also accepted by make_transforms).
    return "light", "basic", RAND_N, RAND_M_LIGHT


def make_transforms(img_size: int, family: str, rand_n: int, rand_m: int):
    """
    Build the (train_transform, eval_transform) pair.

    Both transforms resize to img_size x img_size, convert to a tensor and apply
    ImageNet mean/std normalization. Only the train transform adds Strategy B
    augmentation, selected by `family`:
      "trivialaugment" -> TrivialAugmentWide
      "randaugment"    -> RandAugment(num_ops=rand_n, magnitude=rand_m)
      "basic"          -> RandomHorizontalFlip + RandomResizedCrop(scale 0.85-1.0)
      "none"           -> no augmentation
    Val/test should use the eval transform so evaluation stays deterministic.

    Raises:
        ValueError: if `family` is not one of the values above.
    """
    # ImageNet statistics, matching the pretrained backbones used later.
    imagenet_norm = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

    if family == "trivialaugment":
        augment = [T.TrivialAugmentWide()]
    elif family == "randaugment":
        augment = [T.RandAugment(num_ops=rand_n, magnitude=rand_m)]
    elif family == "basic":
        augment = [
            T.RandomHorizontalFlip(p=0.5),
            T.RandomResizedCrop(img_size, scale=(0.85, 1.0)),
        ]
    elif family == "none":
        augment = []
    else:
        raise ValueError(f"Unknown Strategy B family: {family}")

    train_tf = T.Compose([
        T.Resize((img_size, img_size)),
        *augment,
        T.ToTensor(),
        imagenet_norm,
    ])
    eval_tf = T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        imagenet_norm,
    ])
    return train_tf, eval_tf


def _class_counts_from_dataset(ds, num_classes: int):
    """
    Returns class counts as dict[str,int], using class names if available.
    """
    if hasattr(ds, "targets"):  # CIFAR10
        y = np.array(ds.targets, dtype=int)
        counts = np.bincount(y, minlength=num_classes)
        return {str(i): int(counts[i]) for i in range(len(counts))}
    if hasattr(ds, "samples"):  # ImageFolder
        y = np.array([lbl for _, lbl in ds.samples], dtype=int)
        counts = np.bincount(y, minlength=num_classes)
        if hasattr(ds, "classes"):
            return {str(ds.classes[i]): int(counts[i]) for i in range(len(counts))}
        return {str(i): int(counts[i]) for i in range(len(counts))}
    return {}


def _sample_resolution_stats(ds, limit: int = 600):
    """
    Estimate image-resolution statistics from a random sample of ds.

    Draws up to `limit` distinct indices with a SEED-seeded RNG, so the sample
    is repeatable. For CIFAR10 the size is read from the in-memory ds.data entry
    (indexed as [H, W, ...]). Any other dataset is treated as ImageFolder-like:
    each sampled path in ds.samples is opened with PIL to read its size, and
    files that fail to open are counted as bad.

    Returns a dict that always has "resolution_scan_limit" (the sample size)
    and "resolution_bad_count". When at least one size was read, it also has
    mean/median (float) and min/max (int) of width and height.
    """
    rng = np.random.default_rng(SEED)
    total = len(ds)
    sample_size = min(limit, total)
    picked = rng.choice(total, size=sample_size, replace=False)

    sizes = []       # (width, height) pairs that could be read
    unreadable = 0

    if isinstance(ds, torchvision.datasets.CIFAR10):
        for idx in picked:
            arr = ds.data[int(idx)]
            sizes.append((int(arr.shape[1]), int(arr.shape[0])))
    else:
        for idx in picked:
            img_path, _ = ds.samples[int(idx)]
            try:
                with Image.open(img_path) as im:
                    w, h = im.size
                sizes.append((int(w), int(h)))
            except Exception:
                unreadable += 1

    stats = {
        "resolution_scan_limit": int(sample_size),
        "resolution_bad_count": int(unreadable),
    }
    if sizes:
        widths = [w for w, _ in sizes]
        heights = [h for _, h in sizes]
        stats.update({
            "width_mean": float(np.mean(widths)),
            "height_mean": float(np.mean(heights)),
            "width_median": float(np.median(widths)),
            "height_median": float(np.median(heights)),
            "width_min": int(np.min(widths)),
            "height_min": int(np.min(heights)),
            "width_max": int(np.max(widths)),
            "height_max": int(np.max(heights)),
        })
    return stats


def get_dataloaders(img_size: int):
    """
    Build train/val/test DataLoaders for the configured dataset and record a
    report snapshot.

    - DATASET_MODE "cifar10": uses the official CIFAR-10 train/test split and
      holds out VAL_SPLIT of the train split for validation.
    - DATASET_MODE "imagefolder": splits the single folder at DATA_ROOT into
      test (10%), then validation (VAL_SPLIT of the remainder), then train.

    The Strategy B augmentation policy is chosen from the dataset size and is
    applied to the training subset only. Val/test use the deterministic eval
    transform, via separate dataset instances indexed with the same images.

    Side effects: sets the module globals NUM_CLASSES and LAST_IMAGE_PREP_REPORT
    (the latter is what generate_image_preprocessing_report consumes).

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        FileNotFoundError: imagefolder mode and DATA_ROOT is not a directory.
        ValueError: unknown DATASET_MODE, or an ImageFolder too small to form
            non-empty train/val/test splits.
    """
    global NUM_CLASSES, LAST_IMAGE_PREP_REPORT
    import datetime  # only used for the report timestamp

    mode = DATASET_MODE.lower()

    # ----------------------------
    # Load dataset (NO transforms yet)
    # ----------------------------
    if mode == "cifar10":
        full_train = torchvision.datasets.CIFAR10(
            root=DATA_ROOT, train=True, download=True, transform=None
        )
        test_set = torchvision.datasets.CIFAR10(
            root=DATA_ROOT, train=False, download=True, transform=None
        )
        NUM_CLASSES = 10

    elif mode == "imagefolder":
        if not os.path.isdir(DATA_ROOT):
            raise FileNotFoundError(f"DATA_ROOT not found: {DATA_ROOT}")

        full_train = torchvision.datasets.ImageFolder(DATA_ROOT, transform=None)
        test_set = None  # will be created below
        NUM_CLASSES = len(full_train.classes)

    else:
        raise ValueError(f"Unknown DATASET_MODE: {DATASET_MODE} (use 'cifar10' or 'imagefolder')")

    # ----------------------------
    # Compute dataset stats for report
    # ----------------------------
    n_images = int(len(full_train))
    class_counts = _class_counts_from_dataset(full_train, NUM_CLASSES)
    counts_only = list(class_counts.values()) if class_counts else []
    imbalance_ratio = (max(counts_only) / max(1, min(counts_only))) if counts_only else 1.0
    res_stats = _sample_resolution_stats(full_train, limit=600)

    # ----------------------------
    # Strategy B decision based on dataset size
    # ----------------------------
    level, family, rand_n, rand_m = _decide_strategy_b(n_images)

    # Build transforms (train has Strategy B aug; eval is deterministic)
    train_tf, eval_tf = make_transforms(img_size, family, rand_n, rand_m)

    # Assign transforms AFTER decision
    full_train.transform = train_tf

    # ----------------------------
    # Splits
    # ----------------------------
    if mode == "cifar10":
        # CIFAR-10 already has a separate test split
        test_set.transform = eval_tf

        val_len = int(len(full_train) * VAL_SPLIT)
        train_len = len(full_train) - val_len
        train_set, val_set = random_split(
            full_train, [train_len, val_len],
            generator=torch.Generator().manual_seed(SEED)
        )

        # Val must be deterministic: re-index the same images in a second
        # CIFAR10 instance that uses the eval transform. val_set.indices index
        # full_train directly, so they are valid for val_base too.
        val_base = torchvision.datasets.CIFAR10(
            root=DATA_ROOT, train=True, download=True, transform=eval_tf
        )
        val_set = torch.utils.data.Subset(val_base, val_set.indices)

    else:
        # ImageFolder: create train/val/test splits from the single dataset.
        # test split is fixed to 10% here; change if you want.
        TEST_SPLIT = 0.1
        test_len = max(1, int(TEST_SPLIT * len(full_train)))
        remain_len = len(full_train) - test_len

        remain_set, test_set = random_split(
            full_train, [remain_len, test_len],
            generator=torch.Generator().manual_seed(SEED)
        )

        val_len = max(1, int(remain_len * VAL_SPLIT))
        train_len = remain_len - val_len
        if train_len < 1:
            raise ValueError(
                f"Too few images in {DATA_ROOT} ({len(full_train)}) to build "
                "non-empty train/val/test splits"
            )

        train_set, val_set = random_split(
            remain_set, [train_len, val_len],
            generator=torch.Generator().manual_seed(SEED)
        )

        # random_split on a Subset yields indices INTO THAT SUBSET, not into
        # full_train. Map val positions back through remain_set; otherwise the
        # eval copy would select unrelated images that overlap train/test.
        val_indices = [remain_set.indices[i] for i in val_set.indices]
        test_indices = list(test_set.indices)  # already relative to full_train

        # Deterministic eval datasets over the same images (ImageFolder scans
        # DATA_ROOT in the same order, so indices line up).
        eval_base = torchvision.datasets.ImageFolder(DATA_ROOT, transform=eval_tf)
        val_set = torch.utils.data.Subset(eval_base, val_indices)
        test_set = torch.utils.data.Subset(eval_base, test_indices)

    # ----------------------------
    # DataLoaders
    # ----------------------------
    train_loader = DataLoader(
        train_set, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
    )
    val_loader = DataLoader(
        val_set, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
    )
    test_loader = DataLoader(
        test_set, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
    )

    # ----------------------------
    # Store report snapshot for PDF generation (sensor-like)
    # ----------------------------
    LAST_IMAGE_PREP_REPORT = {
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "dataset_mode": DATASET_MODE,
        "data_root": DATA_ROOT,
        "num_images": int(n_images),
        "num_classes": int(NUM_CLASSES),
        "class_counts": class_counts,
        "imbalance_ratio": float(imbalance_ratio),
        "img_size": int(img_size),
        "normalization": "imagenet",
        "strategy_b": {
            "level": level,
            "family": family,
            "randaugment_N": int(rand_n),
            "randaugment_M": int(rand_m),
            "strong_aug_threshold": int(STRONG_AUG_THRESHOLD),
            "moderate_aug_threshold": int(MODERATE_AUG_THRESHOLD),
        },
        "resolution_stats": res_stats,
        "splits": {
            "val_split": float(VAL_SPLIT),
            "train_samples": int(len(train_set)),
            "val_samples": int(len(val_set)),
            "test_samples": int(len(test_set)),
            "train_batches": int(len(train_loader)),
            "val_batches": int(len(val_loader)),
            "test_batches": int(len(test_loader)),
            "seed": int(SEED),
        },
        "loader": {
            "batch_size": int(BATCH_SIZE),
            "num_workers": int(NUM_WORKERS),
            "pin_memory": bool(PIN_MEMORY),
        }
    }

    # Quick, helpful log (won't break anything)
    print(f"[Strategy B] n={n_images} → level={level}, family={family}"
          + (f", RandAug(N={rand_n}, M={rand_m})" if family == "randaugment" else "")
          + f" | img_size={img_size} | NUM_CLASSES={NUM_CLASSES}")

    return train_loader, val_loader, test_loader


In [None]:
# Build 64x64 loaders (this also fills LAST_IMAGE_PREP_REPORT), then render the PDF.
train_loader, val_loader, test_loader = get_dataloaders(img_size=64)

generate_image_preprocessing_report(
    report=LAST_IMAGE_PREP_REPORT,
    train_loader=train_loader,                  # enables the augmented grid and post-transform shape row
    path="image_prep_report.pdf",
    project_name="Automata AI - Preprocessing Report",
    logo_path=None,                             # set to an image path to add a logo
)


In [None]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

# Prefer the GPU whenever CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Sweep configs: each trial's "name" is derived as <backbone>_unfreeze<N>_res<img_size>.
SWEEP = [
    {
        "name": f"{backbone}_unfreeze{unfreeze}_res{res}",
        "backbone": backbone,
        "unfreeze_blocks": unfreeze,
        "img_size": res,
        "lr": lr,
    }
    for backbone, unfreeze, res, lr in (
        ("mnetv3_small",      0, 160, 3e-3),
        ("mnetv3_small",      1, 160, 2e-3),
        ("shufflenetv2_x0_5", 1, 160, 2e-3),
        ("squeezenet1_1",     1, 160, 2e-3),
    )
]

EPOCHS_PER_TRIAL = 3  # short training budget per sweep trial
FINAL_EPOCHS = 5      # budget for re-training the winning config

# Model builders
def build_backbone(backbone: str, num_classes: int) -> nn.Module:
    """Build a torchvision backbone with its default pretrained weights and a fresh head.

    The final classification layer is replaced by a newly initialised layer with
    ``num_classes`` outputs; every other layer keeps the pretrained weights. The
    new layer's input width is always read from the layer it replaces, so no
    architecture constant is hard-coded here.

    Args:
        backbone: one of ``"mnetv3_small"``, ``"shufflenetv2_x0_5"``,
            ``"squeezenet1_1"``.
        num_classes: number of output logits for the new head.

    Returns:
        The model; callers in this file move it with ``.to(device)``.

    Raises:
        ValueError: if ``backbone`` is not a supported name.
    """
    if backbone == "mnetv3_small":
        weights = torchvision.models.MobileNet_V3_Small_Weights.DEFAULT
        model = torchvision.models.mobilenet_v3_small(weights=weights)
        old_head = model.classifier[-1]
        model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
        return model

    if backbone == "shufflenetv2_x0_5":
        weights = torchvision.models.ShuffleNet_V2_X0_5_Weights.DEFAULT
        model = torchvision.models.shufflenet_v2_x0_5(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    if backbone == "squeezenet1_1":
        weights = torchvision.models.SqueezeNet1_1_Weights.DEFAULT
        model = torchvision.models.squeezenet1_1(weights=weights)
        # SqueezeNet's output layer is a 1x1 Conv2d at classifier[1], not a Linear.
        # Take its input channel count from the pretrained conv instead of hard-coding 512.
        old_conv = model.classifier[1]
        model.classifier[1] = nn.Conv2d(old_conv.in_channels, num_classes, kernel_size=1)
        return model

    raise ValueError(f"Unknown backbone: {backbone}")

# Freeze / unfreeze
def unfreeze_module(m: nn.Module):
    """Make every parameter of ``m`` (including those of its submodules) trainable."""
    # Module.requires_grad_ walks m.parameters() and flips requires_grad on each one.
    m.requires_grad_(True)

def get_head_module(model: nn.Module) -> nn.Module:
    """Return the model's classifier head: ``.classifier`` if present, else ``.fc``.

    Raises:
        ValueError: if the model has neither attribute.
    """
    for attr_name in ("classifier", "fc"):  # probe order matters: .classifier wins
        if hasattr(model, attr_name):
            return getattr(model, attr_name)
    raise ValueError("Could not find classifier head (expected .classifier or .fc)")

def get_block_list(backbone: str, model: nn.Module):
    """Return the backbone's feature stages as an ordered list.

    ``unfreeze_last_n_blocks`` unfreezes entries taken from the END of this list.
    NOTE(review): parameter-free entries (e.g. ShuffleNet's ``maxpool``) presumably
    still count toward N.

    Raises:
        ValueError: for a backbone name without a known stage layout.
    """
    # MobileNetV3 and SqueezeNet both keep their stages in a `features` container.
    if backbone in ("mnetv3_small", "squeezenet1_1"):
        return [stage for stage in model.features]
    if backbone == "shufflenetv2_x0_5":
        stage_names = ("conv1", "maxpool", "stage2", "stage3", "stage4", "conv5")
        return [getattr(model, name) for name in stage_names]
    raise ValueError(f"Unsupported backbone for block unfreezing: {backbone}")

def freeze_all_but_head(model: nn.Module):
    """Freeze every parameter of ``model``, then re-enable gradients for its head only."""
    model.requires_grad_(False)  # freeze everything first ...
    unfreeze_module(get_head_module(model))  # ... then thaw .classifier / .fc

def unfreeze_last_n_blocks(model: nn.Module, backbone: str, n_blocks: int):
    """Train only the head plus the last ``n_blocks`` stages of the backbone.

    ``n_blocks <= 0`` leaves just the head trainable; a value larger than the
    number of stages unfreezes all of them.
    """
    freeze_all_but_head(model)
    if n_blocks > 0:
        # A negative-start slice clamps at the list start, matching max(0, len - n).
        for block in get_block_list(backbone, model)[-n_blocks:]:
            unfreeze_module(block)

def count_trainable_params(model: nn.Module) -> int:
    """Return the total element count of parameters with ``requires_grad`` set."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(lambda param: param.numel(), trainable))

# Train / eval
@torch.no_grad()
def evaluate(model, loader, criterion, device=None):
    """Compute the mean loss and top-1 accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and run without autograd. Per-batch losses
    are weighted by batch size, which yields the per-sample mean when
    ``criterion`` averages over the batch (as the default
    ``nn.CrossEntropyLoss()`` used in this file does).

    Args:
        model: classifier returning logits; the argmax over dim 1 is the prediction.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss callable ``criterion(logits, targets)``.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU for a parameter-free model), so no module-level
            ``device`` global is needed.

    Returns:
        ``(mean_loss, accuracy)`` over all samples, or ``(nan, nan)`` when the
        loader yields no samples (e.g. an empty validation split) instead of
        raising ZeroDivisionError.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        batch_size = x.size(0)
        total_loss += criterion(logits, y).item() * batch_size
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += batch_size

    if total == 0:
        # No samples seen: there is no meaningful loss/accuracy to report.
        return float("nan"), float("nan")
    return total_loss / total, correct / total

def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Run one optimisation pass over ``loader`` and return ``(mean_loss, accuracy)``.

    For each batch the steps are: zero grads, forward, loss, backward, ``optimizer.step()``.
    Reported metrics come from the logits computed *before* each batch's update
    and are weighted by batch size. That gives the per-sample mean when
    ``criterion`` averages over the batch, as the default
    ``nn.CrossEntropyLoss()`` used in this file does.

    Args:
        model: module to train; switched to train mode.
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over (a subset of) ``model``'s parameters.
        criterion: loss callable ``criterion(logits, targets)``.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU for a parameter-free model), so no module-level
            ``device`` global is needed.

    Returns:
        ``(mean_loss, accuracy)`` over all samples, or ``(nan, nan)`` when the
        loader yields no samples, instead of raising ZeroDivisionError.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += batch_size

    if total == 0:
        # Nothing was trained on; report undefined metrics rather than crash.
        return float("nan"), float("nan")
    return total_loss / total, correct / total

def make_optimizer(model, lr: float):
    """Create an AdamW optimizer (weight_decay=1e-4) over the trainable parameters only.

    Frozen parameters (``requires_grad=False``) are excluded.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return optim.AdamW(list(trainable), lr=lr, weight_decay=1e-4)

# Sweep + final fine-tune
criterion = nn.CrossEntropyLoss()  # shared by every sweep trial and by the final run

results = []     # one summary dict per trial, in SWEEP order
best_cfg = None  # summary of the trial with the highest best_val_acc (earliest wins ties)

for trial in SWEEP:
    print("\n" + "=" * 80)
    print("Trial:", trial["name"])
    print(trial)
    t0 = time.time()

    # Loaders are rebuilt per trial so each config gets its own img_size.
    train_loader, val_loader, test_loader = get_dataloaders(trial["img_size"])

    model = build_backbone(trial["backbone"], num_classes=NUM_CLASSES).to(device)
    unfreeze_last_n_blocks(model, backbone=trial["backbone"], n_blocks=trial["unfreeze_blocks"])

    total_params = sum(p.numel() for p in model.parameters())
    trainable = count_trainable_params(model)
    print(f"Trainable params: {trainable:,} / {total_params:,} ({100*trainable/total_params:.2f}%)")

    optimizer = make_optimizer(model, lr=trial["lr"])

    # Track the best validation accuracy seen across this trial's epochs.
    best_val = 0.0
    for epoch in range(1, EPOCHS_PER_TRIAL + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion)
        va_loss, va_acc = evaluate(model, val_loader, criterion)
        if va_acc > best_val:
            best_val = va_acc
        print(
            f"Epoch {epoch:02d}/{EPOCHS_PER_TRIAL} | "
            f"train loss {tr_loss:.4f} acc {tr_acc:.4f} | "
            f"val loss {va_loss:.4f} acc {va_acc:.4f}"
        )

    # Test accuracy is recorded for reporting only; selection uses best_val_acc.
    te_loss, te_acc = evaluate(model, test_loader, criterion)
    elapsed = time.time() - t0

    out = {key: trial[key] for key in ("name", "backbone", "unfreeze_blocks", "img_size", "lr")}
    out.update(best_val_acc=float(best_val), test_acc=float(te_acc), seconds=float(elapsed))
    results.append(out)

    print(f"Trial done in {elapsed/60:.2f} min | best val acc {best_val:.4f} | test acc {te_acc:.4f}")

    if best_cfg is None or out["best_val_acc"] > best_cfg["best_val_acc"]:
        best_cfg = out

# Leaderboard: trials ordered by best validation accuracy (stable for ties).
print("\n" + "#" * 80)
print("Sweep results (sorted by best_val_acc):")
results_sorted = sorted(results, key=lambda row: row["best_val_acc"], reverse=True)
for r in results_sorted:
    print(" | ".join((
        f"{r['name']:40s}",
        f"val {r['best_val_acc']:.4f}",
        f"test {r['test_acc']:.4f}",
        f"{r['seconds']:.0f}s",
    )))

print("\nBest config:", best_cfg)

# Final fine-tune on best config
print("\n" + "=" * 80)
print("Final fine-tune best config for a few more epochs...")

# Rebuild the winning config from its pretrained weights and train it with a cosine LR schedule.
train_loader, val_loader, test_loader = get_dataloaders(best_cfg["img_size"])
best_model = build_backbone(best_cfg["backbone"], num_classes=NUM_CLASSES).to(device)
unfreeze_last_n_blocks(best_model, backbone=best_cfg["backbone"], n_blocks=best_cfg["unfreeze_blocks"])

optimizer = make_optimizer(best_model, lr=best_cfg["lr"])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=FINAL_EPOCHS)

best_val, best_state = 0.0, None

for epoch in range(1, FINAL_EPOCHS + 1):
    tr_loss, tr_acc = train_one_epoch(best_model, train_loader, optimizer, criterion)
    va_loss, va_acc = evaluate(best_model, val_loader, criterion)
    scheduler.step()

    if va_acc > best_val:
        best_val = va_acc
        # Cloned CPU snapshot, so later training steps do not alter the saved weights.
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in best_model.state_dict().items()
        }

    print(
        f"[Final] Epoch {epoch:02d}/{FINAL_EPOCHS} | "
        f"train loss {tr_loss:.4f} acc {tr_acc:.4f} | "
        f"val loss {va_loss:.4f} acc {va_acc:.4f}"
    )

# Roll back to the best-validation snapshot (if any epoch beat 0.0) before testing.
if best_state is not None:
    best_model.load_state_dict(best_state)

final_test_loss, final_test_acc = evaluate(best_model, test_loader, criterion)

print("\n" + "#" * 80)
print(f"FINAL Best Val Acc: {best_val:.4f}\nFINAL Test Acc:     {final_test_acc:.4f}")
print("#" * 80)

In [None]:
import os
import torch

# Persist only the weights (state_dict) of the final model and report the file size.
save_path = "best_model.pth"
torch.save(best_model.state_dict(), save_path)

size_mb = os.stat(save_path).st_size / 1024 ** 2

print("Saved:", save_path)
print(f"Best model test accuracy: {final_test_acc:.4f}")
print(f"Saved .pth size: {size_mb:.2f} MB")