# Test model: food classes (splits_new_v2) and non-food (food_not_food_ds)

This notebook loads the fine-tuned EfficientNet-B3 checkpoint and runs inference on:

- All food classes from `d:\\VSC FILES\\testtrain\\splits_new_v2` (test folders)
- Non-food test images from `d:\\VSC FILES\\testtrain\\food_not_food_ds\\not_food\\test`

Edit variables in the first code cell (ckpt_path, dataset roots, batch_size, threshold) before running.

In [1]:
# Section 1 — Environment & Dependencies

# (Optional) create a minimal requirements file. Edit as needed.
requirements = '''torch torchvision pillow torch-directml numpy'''
with open('requirements.txt', 'w') as f:
    f.write(requirements)

# Example: install packages (uncomment to run). On Windows with bash shell you can run these cells.
# %%bash
# pip install -r requirements.txt

print('requirements.txt written. Review and install packages if needed.')

requirements.txt written. Review and install packages if needed.


In [2]:
# Section 2 — Project structure (show top-level files/folders)
import os
root = r'd:\VSC FILES\testtrain'
print('Workspace root:', root)
for entry in sorted(os.listdir(root))[:80]:
    print(entry)

print('\nNotebook file: d:\\VSC FILES\\testtrain\\test_inference_nonfood_and_food.ipynb')

Workspace root: d:\VSC FILES\testtrain
HF_ready.ipynb
bad_images_by_class.json
class_names.json
export_to_onxx.ipynb
fine_tune.ipynb
fine_tune2.ipynb
fine_tune2_.ipynb
fine_tune2__.ipynb
food-101-split
food_not_food_ds
hf_package
model.onnx
new_model.onnx
new_splits
new_train.ipynb
nutrisight_model_training
onxx
openimages_downloads
requirements.txt
runs
runs copy
splits_new
splits_new_v2
test.ipynb
test_inference_nonfood_and_food.ipynb
test_model.ipynb
test_model_finetune.ipynb
test_model_finetune2.ipynb
thesis-model-resnet50-20250824-230424
thesis_datasets
thesis_datasets_cleaned
thesis_datasets_cleaned_split
thesis_datasets_split
train.ipynb
train_efficientnet_b3_optimized_food_not_food.ipynb

Notebook file: d:\VSC FILES\testtrain\test_inference_nonfood_and_food.ipynb


In [3]:
# Section 3 — Import libraries
import os
import json
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader

# Optional import for DirectML
try:
    import torch_directml
    has_directml = True
except Exception:
    has_directml = False

print('torch:', torch.__version__)
print('torch_directml available:', has_directml)

torch: 2.4.1+cpu
torch_directml available: True


In [4]:
# Section 4 — Device selection
try:
    if has_directml:
        dml = torch_directml.device()
        device = dml
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
except Exception as e:
    print('Device selection error, falling back to CPU:', e)
    device = torch.device('cpu')

print('Using device:', device)

Using device: privateuseone:0


In [5]:
# Section 5 — Paths and checkpoint
# Edit these paths if needed
CKPT_PATH = r'D:\VSC FILES\testtrain\runs\efficientnet_b3_optimized-20251028-231113\best_efficientnet_b3.pth'
SPLITS_DIR = r'D:\VSC FILES\testtrain\splits_new_v2'
# Use the non-food root (contains train/val/test). We'll combine all folds.
NON_FOOD_DIR = r'D:\VSC FILES\testtrain\food_not_food_ds\not_food'
OUTPUT_DIR = r'd:\VSC FILES\testtrain\inference_outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

print('Checkpoint exists?', os.path.exists(CKPT_PATH))

ckpt = torch.load(CKPT_PATH, map_location='cpu')
# Expected keys: 'classes', 'model_state', optionally 'mean','std','image_size'
classes = ckpt.get('classes') or ckpt.get('class_names')
if classes is None:
    raise RuntimeError('Checkpoint does not contain class list under "classes" or "class_names"')

mean = ckpt.get('mean', [0.485, 0.456, 0.406])
std = ckpt.get('std', [0.229, 0.224, 0.225])
image_size = ckpt.get('image_size', 300)  # default used during training
print('Loaded checkpoint. Model classes:', len(classes))

Checkpoint exists? True


  ckpt = torch.load(CKPT_PATH, map_location='cpu')


Loaded checkpoint. Model classes: 2
 2


In [6]:
# Section 6 — Build model
num_classes = len(classes)
model = models.efficientnet_b3(weights=None)
# Replace classifier to match training code pattern
if hasattr(model, 'classifier'):
    # EfficientNet classifier is usually model.classifier[1]
    try:
        in_feats = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(in_feats, num_classes)
    except Exception:
        # fallback
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
else:
    # Generic fallback - replace last linear if found
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, nn.Linear):
            parent = model
            break

# Load weights (support different key names)
state_key = 'model_state' if 'model_state' in ckpt else 'state_dict' if 'state_dict' in ckpt else None
if state_key is None:
    raise RuntimeError('Checkpoint missing model_state/state_dict')
model.load_state_dict(ckpt[state_key])
model = model.to(device)
model.eval()
print('Model loaded and set to eval() on device', device)

Model loaded and set to eval() on device privateuseone:0


In [7]:
# Section 7 — Transforms and Datasets
from torchvision import transforms

eval_tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Helper to detect binary food/not_food checkpoint
def _detect_binary_food_ckpt(ckpt_classes):
    lower = [c.lower().replace(' ', '_').replace('-', '_') for c in ckpt_classes]
    food_idx = None
    not_food_idx = None
    for i, c in enumerate(lower):
        if c == 'food':
            food_idx = i
        if c in ('not_food', 'notfood', 'not_foods', 'not_food', 'nonfood', 'non_food', 'not food'):
            not_food_idx = i
    # also allow when only 'food' and 'not_food' are present in some variant
    if food_idx is not None and not_food_idx is not None:
        return True, food_idx, not_food_idx
    return False, food_idx, not_food_idx

class ImageFolderFlat(Dataset):
    """Collect images from <splits_root>/<class>/(train|val|test) and map labels to indices in ckpt['classes'] so indices align with model outputs.

    If the checkpoint is a binary food/not_food model, this class will treat all images under every class in `splits_root` as `food` (label index = index of 'food' in checkpoint)."""
    def __init__(self, splits_root, ckpt_classes, folds=None, transform=None):
        self.paths = []
        self.labels = []
        self.transform = transform
        self.ckpt_classes = list(ckpt_classes)
        # default folds to combine
        if folds is None:
            folds = ['train', 'val', 'test']
        # detect binary ckpt
        binary_mode, food_idx, not_food_idx = _detect_binary_food_ckpt(self.ckpt_classes)
        available = [d for d in os.listdir(splits_root) if os.path.isdir(os.path.join(splits_root, d))]
        available = set(available)
        skipped = []
        if binary_mode:
            # collect ALL images under every class folder and label them as 'food'
            for cls_dir_name in sorted(available):
                for fold in folds:
                    cls_dir = os.path.join(splits_root, cls_dir_name, fold)
                    if not os.path.isdir(cls_dir):
                        continue
                    for fn in os.listdir(cls_dir):
                        if fn.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.webp')):
                            p = os.path.join(cls_dir, fn)
                            self.paths.append(p)
                            self.labels.append(food_idx)
            print(f"Binary checkpoint detected. Collected {len(self.paths)} images from splits and labeled them as 'food' (idx={food_idx}).")
        else:
            for cls_name in self.ckpt_classes:
                if cls_name not in available:
                    skipped.append(cls_name)
                    continue
                for fold in folds:
                    cls_dir = os.path.join(splits_root, cls_name, fold)
                    if not os.path.isdir(cls_dir):
                        continue
                    for fn in os.listdir(cls_dir):
                        if fn.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.webp')):
                            p = os.path.join(cls_dir, fn)
                            self.paths.append(p)
                            # label index is index in ckpt_classes
                            self.labels.append(self.ckpt_classes.index(cls_name))
            if skipped:
                print(f"Skipped {len(skipped)} classes not found in splits directory: {skipped[:10]}")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx], p

class NonFoodDataset(Dataset):
    """Collect images from root_dir/(train|val|test) combining all folds."""
    def __init__(self, root_dir, folds=None, transform=None, ckpt_classes=None):
        self.paths = []
        if folds is None:
            folds = ['train', 'val', 'test']
        for fold in folds:
            fold_dir = os.path.join(root_dir, fold)
            if not os.path.isdir(fold_dir):
                continue
            for fn in os.listdir(fold_dir):
                if fn.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.webp')):
                    self.paths.append(os.path.join(fold_dir, fn))
        self.transform = transform
        # if ckpt is binary, record indices for mapping when saving/analysis
        self.ckpt_classes = list(ckpt_classes) if ckpt_classes is not None else None

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, p

In [8]:
# Section 8 — Build datasets and dataloaders
BATCH_SIZE = 32
NUM_WORKERS = 0

# Combine train/val/test for food dataset using folds
FOLDS = ['train', 'val', 'test']
food_dataset = ImageFolderFlat(SPLITS_DIR, classes, folds=FOLDS, transform=eval_tfms)
food_loader = DataLoader(food_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
print('Food dataset size (combined folds):', len(food_dataset))

# Non-food dataset: combine all folds under NON_FOOD_DIR
nonfood_dataset = NonFoodDataset(NON_FOOD_DIR, folds=FOLDS, transform=eval_tfms)
nonfood_loader = DataLoader(nonfood_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
print('Non-food dataset size (combined folds):', len(nonfood_dataset))

Binary checkpoint detected. Collected 43400 images from splits and labeled them as 'food' (idx=0).
Food dataset size (combined folds): 43400
Non-food dataset size (combined folds): 5012


In [9]:
# Section 9 — Inference helper
import math

# Determine if checkpoint is binary and get food index
binary_mode, food_idx, not_food_idx = _detect_binary_food_ckpt(classes)
if binary_mode:
    print(f"Binary model detected. food_idx={food_idx}, not_food_idx={not_food_idx}")

def run_inference(model, loader, device, topk=5):
    """Run inference.

    For binary models (food vs not_food) returns per-image dict with 'path', 'gt_idx', 'food_prob' and 'pred_label'.
    For multi-class models returns 'topk' as before.
    """
    results = []
    model.to(device)
    model.eval()
    with torch.no_grad():
        for batch in loader:
            # batch may be (imgs, labels, paths) or (imgs, paths)
            if len(batch) == 3:
                imgs, labels, paths = batch
            else:
                imgs, paths = batch
                labels = [None] * len(paths)
            imgs = imgs.to(device)
            logits = model(imgs)
            probs = torch.softmax(logits, dim=1).cpu().numpy()
            for i in range(probs.shape[0]):
                row = probs[i]
                if binary_mode:
                    # use food index probability
                    food_probability = float(row[food_idx])
                    pred_label = classes[int(row.argmax())]
                    results.append({
                        'path': paths[i],
                        'gt_idx': int(labels[i]) if labels is not None and labels[i] is not None else None,
                        'food_prob': food_probability,
                        'pred_label': pred_label,
                    })
                else:
                    top_idx = row.argsort()[-topk:][::-1]
                    top_probs = row[top_idx]
                    top_classes = [classes[j] for j in top_idx]
                    results.append({
                        'path': paths[i],
                        'gt_idx': int(labels[i]) if labels is not None and labels[i] is not None else None,
                        'topk': list(zip(top_classes, [float(x) for x in top_probs]))
                    })
    return results


Binary model detected. food_idx=0, not_food_idx=1


In [10]:
# Section 10 — Run inference on food dataset and compute metrics
CONF_THRESH = 0.7

food_results = run_inference(model, food_loader, device, topk=5)

if binary_mode:
    # For binary model we expect all food_dataset samples to be positive (we labeled them as food_idx)
    total_food = len(food_results)
    detected_food = 0
    missed_food = []
    food_probs = []
    for r in food_results:
        prob = r.get('food_prob', 0.0)
        food_probs.append(prob)
        if prob >= CONF_THRESH:
            detected_food += 1
        else:
            missed_food.append(r['path'])
    tpr = detected_food / total_food if total_food else 0.0
    print(f'Food detection on combined folds: {total_food} images. Detected (>= {CONF_THRESH}): {detected_food}. TPR: {tpr:.4f}')
    print(f'Collected {len(missed_food)} missed food images (prob < {CONF_THRESH})')
    # Prepare bad_images_by_class-like structure for compatibility (single key 'food')
    bad_images_by_class = {'food': missed_food}
    # Optionally save distribution stats
    import statistics
    print('Food prob mean:', statistics.mean(food_probs) if food_probs else 0.0)
    print('Food prob median:', statistics.median(food_probs) if food_probs else 0.0)
else:
    # Multiclass path (unchanged)
    correct1 = 0
    correct5 = 0
    bad_images_by_class = {}
    for r in food_results:
        gt = r['gt_idx']
        topk = r['topk']  # list of (class_name, prob)
        top_classes = [t[0] for t in topk]
        top_probs = [t[1] for t in topk]
        # top1
        if gt is not None and classes[gt] == top_classes[0]:
            correct1 += 1
        # top5
        if gt is not None and classes[gt] in top_classes:
            correct5 += 1
            idx = top_classes.index(classes[gt])
            if top_probs[idx] < CONF_THRESH:
                bad_images_by_class.setdefault(classes[gt], []).append(r['path'])
        else:
            # GT not in top5
            bad_images_by_class.setdefault(classes[gt] if gt is not None else 'UNKNOWN', []).append(r['path'])

    n = len(food_results)
    acc1 = correct1 / n if n else 0.0
    acc5 = correct5 / n if n else 0.0
    print(f'Food dataset: {n} images. Top-1 acc: {acc1:.4f}, Top-5 acc: {acc5:.4f}')
    print('Bad images (per-class counts):')
    for k,v in bad_images_by_class.items():
        print(k, len(v))


Food detection on combined folds: 43400 images. Detected (>= 0.7): 42046. TPR: 0.9688
Collected 1354 missed food images (prob < 0.7)
Food prob mean: 0.9023973787442914
Food prob median: 0.9271599650382996


In [11]:
# Section 11 — Run inference on non-food and flag confident predictions
import shutil

non_food_flagged = []
model.to(device)
model.eval()
with torch.no_grad():
    for imgs, paths in nonfood_loader:
        imgs = imgs.to(device)
        logits = model(imgs)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        for i in range(probs.shape[0]):
            row = probs[i]
            if binary_mode:
                food_probability = float(row[food_idx])
                pred_label = classes[int(row.argmax())]
                # All images in this dataset are expected to be NOT food (ground-truth = not_food_idx)
                gt_idx = not_food_idx
                if food_probability >= CONF_THRESH:
                    # false positive: non-food image predicted as food with high confidence
                    non_food_flagged.append({
                        'path': paths[i],
                        'food_prob': food_probability,
                        'pred_label': pred_label,
                        'gt_idx': gt_idx,
                    })
            else:
                max_idx = int(row.argmax())
                max_prob = float(row[max_idx])
                if max_prob >= CONF_THRESH:
                    non_food_flagged.append({
                        'path': paths[i],
                        'pred_class': classes[max_idx],
                        'prob': max_prob,
                        'gt_idx': None,
                    })

# Compute non-food metrics: total, false positives, true negatives, FPR, specificity
total_nonfood = len(nonfood_dataset)
false_positives = len(non_food_flagged)
true_negatives = total_nonfood - false_positives
fpr = false_positives / total_nonfood if total_nonfood else 0.0
specificity = true_negatives / total_nonfood if total_nonfood else 0.0
print('Non-food images total:', total_nonfood)
print('Non-food images flagged (confident predictions / false positives):', false_positives)
print(f'False positive rate (FPR): {fpr:.4f}  | Specificity (TN / N): {specificity:.4f}')


Non-food images total: 5012
Non-food images flagged (confident predictions / false positives): 1
False positive rate (FPR): 0.0002  | Specificity (TN / N): 0.9998


In [12]:
# Section 12 — Save results and optionally copy flagged images
bad_json = os.path.join(OUTPUT_DIR, 'bad_images_by_class.json')
nonfood_json = os.path.join(OUTPUT_DIR, 'non_food_confident.json')

with open(bad_json, 'w') as f:
    json.dump(bad_images_by_class, f, indent=2)
with open(nonfood_json, 'w') as f:
    json.dump(non_food_flagged, f, indent=2)

print('Saved:', bad_json)
print('Saved:', nonfood_json)

COPY_FLAGGED = True
if COPY_FLAGGED:
    bad_out_dir = os.path.join(OUTPUT_DIR, 'bad_images_by_class')
    nonfood_out_dir = os.path.join(OUTPUT_DIR, 'non_food_confident')
    os.makedirs(bad_out_dir, exist_ok=True)
    os.makedirs(nonfood_out_dir, exist_ok=True)
    # copy non-food
    for item in non_food_flagged:
        dst = os.path.join(nonfood_out_dir, os.path.basename(item['path']))
        try:
            shutil.copy(item['path'], dst)
        except Exception as e:
            print('Copy error:', e)
    # copy bad images by class
    for cls_name, paths in bad_images_by_class.items():
        cls_dir = os.path.join(bad_out_dir, cls_name)
        os.makedirs(cls_dir, exist_ok=True)
        for p in paths:
            try:
                shutil.copy(p, os.path.join(cls_dir, os.path.basename(p)))
            except Exception as e:
                print('Copy error:', e)
    print('Flagged images copied to:', bad_out_dir, 'and', nonfood_out_dir)

Saved: d:\VSC FILES\testtrain\inference_outputs\bad_images_by_class.json
Saved: d:\VSC FILES\testtrain\inference_outputs\non_food_confident.json
Flagged images copied to: d:\VSC FILES\testtrain\inference_outputs\bad_images_by_class and d:\VSC FILES\testtrain\inference_outputs\non_food_confident
Flagged images copied to: d:\VSC FILES\testtrain\inference_outputs\bad_images_by_class and d:\VSC FILES\testtrain\inference_outputs\non_food_confident


In [13]:
# Section 13 — Summary & quick checks
print('\nSummary:')
print('Food images total:', len(food_dataset))
print('Non-food images total:', len(nonfood_dataset))

# Print appropriate metrics depending on model type
if binary_mode:
    if 'tpr' in globals():
        print(f'Detected (TPR) at threshold {CONF_THRESH}:', f"{detected_food}/{total_food} = {tpr:.4f}")
    else:
        print('Run Section 10 to compute detection metrics (TPR).')
    # Non-food metrics (if computed)
    if 'total_nonfood' in globals():
        print(f'Non-food: Total={total_nonfood}, False positives={false_positives}, FPR={fpr:.4f}, Specificity={specificity:.4f}')
    else:
        print('Run Section 11/11.5 to compute non-food false positive metrics.')
else:
    # multiclass metrics
    if 'acc1' in globals() and 'acc5' in globals():
        print('Top-1 acc:', acc1)
        print('Top-5 acc:', acc5)
    else:
        print('Run Section 10 to compute multiclass accuracy metrics.')

print('Bad images saved to:', bad_json)
print('Non-food flagged saved to:', nonfood_json)

# Quick sanity checks
if not os.path.exists(bad_json):
    print('Warning: bad_images JSON missing')
if not os.path.exists(nonfood_json):
    print('Warning: non-food JSON missing')

print('\nDone.')


Summary:
Food images total: 43400
Non-food images total: 5012
Detected (TPR) at threshold 0.7: 42046/43400 = 0.9688
Non-food: Total=5012, False positives=1, FPR=0.0002, Specificity=0.9998
Bad images saved to: d:\VSC FILES\testtrain\inference_outputs\bad_images_by_class.json
Non-food flagged saved to: d:\VSC FILES\testtrain\inference_outputs\non_food_confident.json

Done.


In [14]:
# Section 12.5 — List images below confidence threshold
# Shows all food images (from splits) where food probability < CONF_THRESH
print(f"Listing images with food probability < {CONF_THRESH}:\n")
if binary_mode:
    low_conf = bad_images_by_class.get('food', [])
    print(f'Found {len(low_conf)} food images below {CONF_THRESH}\n')
    # print full paths; if many, show first 200 to avoid flooding
    max_show = 10000
    for i, p in enumerate(low_conf):
        if i >= max_show:
            print(f"... ({len(low_conf)-max_show} more) ...")
            break
        print(p)
else:
    # For multiclass models, bad_images_by_class already collects low-confidence/misranked images
    total = 0
    for cls, paths in bad_images_by_class.items():
        print(f"Class: {cls} -> {len(paths)} images")
        total += len(paths)
        for p in paths[:200]:
            print(p)
    print(f"Total low-confidence images across classes: {total}")

Listing images with food probability < 0.7:

Found 1354 food images below 0.7

D:\VSC FILES\testtrain\splits_new_v2\adobong_pusit\train\adobong_pusit_160.jpg
D:\VSC FILES\testtrain\splits_new_v2\adobong_pusit\train\adobong_pusit_45.jpg
D:\VSC FILES\testtrain\splits_new_v2\adobong_pusit\val\adobong_pusit_3.jpg
D:\VSC FILES\testtrain\splits_new_v2\adobong_pusit\test\adobong_pusit_175.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\train\apple_112.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\train\apple_121.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\train\apple_141.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\train\apple_148.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\train\apple_160.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\train\apple_169.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\train\apple_198.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\train\apple_229.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\train\apple_275.jpg
D:\VSC FILES\testtrain\splits_new_v2\apple\t

In [15]:
# Section 11.5 — Ensure non-food inference ran and list flagged non-food images
import csv

# If non_food_flagged isn't defined, run the non-food inference loop now
if 'non_food_flagged' not in globals():
    non_food_flagged = []
    model.to(device)
    model.eval()
    with torch.no_grad():
        for imgs, paths in nonfood_loader:
            imgs = imgs.to(device)
            logits = model(imgs)
            probs = torch.softmax(logits, dim=1).cpu().numpy()
            for i in range(probs.shape[0]):
                row = probs[i]
                if binary_mode:
                    food_probability = float(row[food_idx])
                    pred_label = classes[int(row.argmax())]
                    gt_idx = not_food_idx
                    if food_probability >= CONF_THRESH:
                        non_food_flagged.append({
                            'path': paths[i],
                            'food_prob': food_probability,
                            'pred_label': pred_label,
                            'gt_idx': gt_idx,
                        })
                else:
                    max_idx = int(row.argmax())
                    max_prob = float(row[max_idx])
                    if max_prob >= CONF_THRESH:
                        non_food_flagged.append({
                            'path': paths[i],
                            'pred_class': classes[max_idx],
                            'prob': max_prob,
                            'gt_idx': None,
                        })

# Recompute non-food summary metrics if they don't exist
total_nonfood = len(nonfood_dataset)
false_positives = len(non_food_flagged)
true_negatives = total_nonfood - false_positives
fpr = false_positives / total_nonfood if total_nonfood else 0.0
specificity = true_negatives / total_nonfood if total_nonfood else 0.0
print('Non-food images flagged (confident predictions):', false_positives)

# Save CSV with results for easier review
csv_path = os.path.join(OUTPUT_DIR, 'non_food_confident.csv')
with open(csv_path, 'w', newline='', encoding='utf-8') as cf:
    if binary_mode:
        fieldnames = ['path','food_prob','pred_label','gt_idx']
    else:
        fieldnames = ['path','pred_class','prob','gt_idx']
    writer = csv.DictWriter(cf, fieldnames=fieldnames)
    writer.writeheader()
    for item in non_food_flagged:
        writer.writerow(item)
print('Saved non-food flagged CSV to:', csv_path)

# Print first 200 flagged examples (or fewer) for quick inspection
for i, item in enumerate(non_food_flagged[:200]):
    if binary_mode:
        print(i+1, item['path'], f"food_prob={item['food_prob']:.3f}", 'pred=', item.get('pred_label'), 'gt_idx=', item.get('gt_idx'))
    else:
        print(i+1, item['path'], item.get('prob'), item.get('pred_class'))

print(f'Non-food summary -> Total: {total_nonfood}, False positives: {false_positives}, FPR: {fpr:.4f}, Specificity: {specificity:.4f}')

# Optionally copy these flagged non-food images to output folder (already done in Section 12 if COPY_FLAGGED)
if not COPY_FLAGGED:
    print('\nTo copy flagged images to disk, set COPY_FLAGGED = True and re-run Section 12.')

Non-food images flagged (confident predictions): 1
Saved non-food flagged CSV to: d:\VSC FILES\testtrain\inference_outputs\non_food_confident.csv
1 D:\VSC FILES\testtrain\food_not_food_ds\not_food\test\000017_123_da47b22102623a9f809aa3a98987717c.jpg.jpg food_prob=0.821 pred= food gt_idx= 1
Non-food summary -> Total: 5012, False positives: 1, FPR: 0.0002, Specificity: 0.9998
