## Setup & Installation

In [None]:
# Install necessary libraries (Commented out potentially conflicting versions)
# !pip install fiftyone wandb open-clip-torch
# !pip install ftfy regex tqdm torchmetrics torch-fidelity
# !pip install git+https://github.com/openai/CLIP.git

# NOTE: The following line tries to downgrade PyTorch to 1.7.1 which conflicts with our current setup.
# !pip install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0

## Part 1: Image Generation and Embedding Extraction

In [None]:
# cd drive/MyDrive/Hands-on-CV3  # Commented out Google Colab specific path

In [None]:
import os
import sys
from pathlib import Path

import torch

# Import our new utility module
import utils
from utils import ClassicLeNet5, CustomTorchImageDataset, train_epoch, val_epoch, evaluate_idk_performance
# Diffusion helpers (project-local modules, imported the same way in the bonus section)
import UNet_utils
import ddpm_utils

# Report the environment so results can be tied to a library version.
print("Torch version: " + str(torch.__version__))
print("Utils imported successfully.")

In [None]:
import clip

# OpenAI CLIP provides the text conditioning for the diffusion model.
CLIP_MODEL_NAME = "ViT-B/32"
clip_model, clip_preprocess = clip.load(CLIP_MODEL_NAME)
clip_model.eval()
# Width of the CLIP text embedding; the U-Net below is built with c_embed_dim=512 to match.
CLIP_FEATURES = 512

Sanity Check of the model

In [None]:
import matplotlib.pyplot as plt

# --- 1. CONFIGURATION ---
N = 200     # Based on resource limitations
BATCH_SIZE = 8
IMG_SIZE = 32
IMG_CH = 3
INPUT_SIZE = (IMG_CH, IMG_SIZE, IMG_SIZE)
B_start = 0.0001
B_end = 0.02
NUM_STEPS = 400

# Define paths
project_root = Path.cwd()
if str(Path(".").resolve()) not in sys.path:  # avoid appending a duplicate entry on every re-run
    sys.path.append(str(Path(".").resolve()))
OUT_DIR = Path("generated_flowers")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 2. MODEL SETUP ---
# Architecture arguments reuse the config constants (and the CLIP text width from the
# CLIP cell) so they cannot drift apart; they must still match the checkpoint below.
model = UNet_utils.UNet(
    T=NUM_STEPS, img_ch=IMG_CH, img_size=IMG_SIZE, down_chs=(256, 256, 512),
    t_embed_dim=8, c_embed_dim=CLIP_FEATURES
).to(device)

# Google Drive (Colab) location of the trained flower checkpoint; adjust for other environments.
weights_path = '/content/drive/MyDrive/Hands-on-CV3/flowerDiff.pth'
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()

# --- 3. SANITY CHECK: one image per test prompt ---
test_prompts = [
    "A photo of a red rose",
    "A photo of a white daisy",
    "A photo of a yellow sunflower"
]

# Guidance weight actually passed to the sampler. The labels previously claimed w=4.0
# while sampling used w=1; set SANITY_W = 4.0 to run the experiment those labels describe.
SANITY_W = 1

print(f"Generating {len(test_prompts)} test images with Guidance Scale w={SANITY_W}...")

# Text conditioning for the sampler; inference only, so no autograd graph is needed.
with torch.no_grad():
    text_tokens = clip.tokenize(test_prompts).to(device)
    c = clip_model.encode_text(text_tokens).float()

B = torch.linspace(B_start, B_end, NUM_STEPS).to(device)
ddpm = ddpm_utils.DDPM(B, device)


def _to_01(x: torch.Tensor) -> torch.Tensor:
    """Rescale to [0, 1]: shift from [-1, 1] if any value is negative, then clamp."""
    if x.min() < 0:
        x = (x + 1) / 2.0
    return x.clamp(0, 1)


# Sample
x_test, _ = ddpm_utils.sample_w(
    model,
    ddpm,
    INPUT_SIZE,
    NUM_STEPS,
    c,
    device,
    w_tests=[SANITY_W]
)

# Rescale the whole batch at once (as the generation loop does) so every image gets the
# same mapping; per-image rescaling would leave an image without negative pixels unshifted.
x_vis = _to_01(x_test[:len(test_prompts)].detach()).cpu()

# 3. Visualize Results
fig, axes = plt.subplots(1, len(test_prompts), figsize=(5 * len(test_prompts), 5), squeeze=False)
for ax, prompt, img in zip(axes[0], test_prompts, x_vis):
    ax.imshow(img.permute(1, 2, 0).numpy())  # [C, H, W] -> [H, W, C] for matplotlib
    ax.set_title(f"Prompt: {prompt}\n(Guidance w={SANITY_W})")
    ax.axis("off")

plt.show()

In [None]:
import sys
from pathlib import Path
import numpy as np
import torch
from torchvision.utils import save_image

# --- 1. CONFIGURATION
N = 200     # Based on resource limitations
BATCH_SIZE = 8
IMG_SIZE = 32
IMG_CH = 3
INPUT_SIZE = (IMG_CH, IMG_SIZE, IMG_SIZE)
B_start = 0.0001
B_end = 0.02
NUM_STEPS = 400

# Define paths
project_root = Path.cwd()
if str(Path(".").resolve()) not in sys.path:  # avoid appending a duplicate entry on every re-run
    sys.path.append(str(Path(".").resolve()))
OUT_DIR = Path("generated_flowers")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 2. MODEL SETUP ---
# Build the U-Net from the config constants (text width = CLIP_FEATURES from the CLIP
# cell) so the architecture cannot drift from the sampling settings; it must still
# match the checkpoint loaded below.
model = UNet_utils.UNet(
    T=NUM_STEPS, img_ch=IMG_CH, img_size=IMG_SIZE, down_chs=(256, 256, 512),
    t_embed_dim=8, c_embed_dim=CLIP_FEATURES
).to(device)

# Google Drive (Colab) location of the trained checkpoint; adjust for other environments.
weights_path = '/content/drive/MyDrive/Hands-on-CV3/flowerDiff.pth'
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()

# DDPM Setup: linear beta schedule with one beta per diffusion step
B = torch.linspace(B_start, B_end, NUM_STEPS).to(device)
ddpm = ddpm_utils.DDPM(B, device)

# --- 3. HOOKS FOR EMBEDDINGS ---
embeddings_storage = {}


def get_embedding_hook(name):
    """Build a forward hook that stores a detached copy of a module's output.

    Every forward call overwrites ``embeddings_storage[name]``, so after a full
    sampling run the entry holds the output of the most recent call.
    """
    def _store_output(module, inputs, output):
        embeddings_storage[name] = output.detach()

    return _store_output

# Capture the output of the U-Net's `down2` block on every forward pass.
# Keeping the handle allows `down2_hook_handle.remove()` if the hook is no longer wanted.
down2_hook_handle = model.down2.register_forward_hook(get_embedding_hook('down2'))

# --- 4. HELPER FUNCTIONS ---
def _to_01(x: torch.Tensor) -> torch.Tensor:
    if x.min() < 0:
        x = (x + 1) / 2.0
    return x.clamp(0, 1)

def cycle_prompts(prompts, n):
    """Repeat ``prompts`` round-robin and truncate to exactly ``n`` entries.

    Example: cycle_prompts(["a", "b"], 3) -> ["a", "b", "a"].
    """
    repeats = -(-n // len(prompts))  # integer ceil(n / len(prompts))
    return (prompts * repeats)[:n]

@torch.no_grad()
def sample_flowers_with_embeddings(prompt_list, w=2):
    """Generate one image per prompt and return it with its pooled ``down2`` feature.

    Args:
        prompt_list: text prompts; one image is generated per prompt.
        w: classifier-free guidance weight passed to ``ddpm_utils.sample_w``
            (default 2, the value this notebook has always used).

    Returns:
        (x_gen, down2_vec): the generated images (sampler output range; see
        ``_to_01``) and a [len(prompt_list), C] tensor of globally average-pooled
        ``down2`` activations captured by the forward hook.

    Raises:
        RuntimeError: if the ``down2`` hook did not fire during sampling (e.g. it is
            registered on a different ``model`` instance).
    """
    embeddings_storage.clear()

    text_tokens = clip.tokenize(prompt_list).to(device)
    c = clip_model.encode_text(text_tokens).float()

    x_gen, _ = ddpm_utils.sample_w(
        model,
        ddpm,
        INPUT_SIZE,
        NUM_STEPS,
        c,
        device,
        w_tests=[w]
    )

    # The hook overwrites its entry on every U-Net call, so this is the activation
    # from the last forward pass of the sampling loop.
    down2 = embeddings_storage.get("down2")
    if down2 is None:
        raise RuntimeError(
            "The 'down2' forward hook did not fire; register it on the current `model` before sampling."
        )
    down2_vec = down2.mean(dim=(2, 3))           # [B, C, H, W] -> [B, C] Global Average Pooling

    # NOTE(review): the model batch can be larger than the prompt batch (presumably
    # sample_w stacks conditional and unconditional copies); keep the first
    # len(prompt_list) rows - confirm the ordering against ddpm_utils.sample_w.
    n_prompts = len(prompt_list)
    return x_gen[:n_prompts], down2_vec[:n_prompts]

# --- 5. MAIN GENERATION LOOP ---
# Seed the RNG so the generated set (and every metric computed from it) is reproducible.
SEED = 0
torch.manual_seed(SEED)

text_prompts = [
    "A photo of a red rose",
    "A photo of a white daisy",
    "A photo of a yellow sunflower"
]

# Repeat the prompts round-robin so there is exactly one prompt per image.
all_prompts = cycle_prompts(text_prompts, N)

image_paths = []
prompt_per_image = []
unet_embs = []

print(f"Starting generation of {N} images...")

for idx in range(0, N, BATCH_SIZE):
    # Slice the prompts for this batch (the last batch may be shorter)
    batch_prompts = all_prompts[idx : idx + BATCH_SIZE]

    # Generate
    x_gen, emb_vec = sample_flowers_with_embeddings(batch_prompts)

    # Move to CPU once per batch: images rescaled to [0, 1], embeddings as numpy
    x01 = _to_01(x_gen).cpu()
    emb_np = emb_vec.detach().cpu().numpy()

    for j, prompt in enumerate(batch_prompts):
        fp = OUT_DIR / f"gen_{idx + j:06d}.png"
        save_image(x01[j], fp)

        image_paths.append(str(fp))
        prompt_per_image.append(prompt)
        unet_embs.append(emb_np[j].astype(np.float32))

    print(f"Generated {idx + len(batch_prompts)}/{N}")

unet_embs_np = np.stack(unet_embs, axis=0)  # [N, C]
# Every image must have exactly one prompt and one embedding row.
assert len(image_paths) == len(prompt_per_image) == unet_embs_np.shape[0] == N
print("Done!")
print("Generated:", len(image_paths), "Embeddings:", unet_embs_np.shape)

## Part 2: Evaluation with CLIP Score and Fréchet Inception Distance (FID)

First, download the real TF-Flowers photos that serve as the reference set for FID.

In [None]:
import tarfile
import urllib.request
from pathlib import Path
import shutil

# 1. Define URL and Paths
# HTTPS location used by the TensorFlow image-loading tutorial.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
archive_path = Path("flower_photos.tgz")
data_dir = Path("flower_photos")  # This is where images will be extracted

# 2. Download to a temporary name first so an interrupted download is never
#    mistaken for a complete archive on the next run.
if not archive_path.exists():
    print(f"Downloading TF-Flowers from {dataset_url}...")
    partial_path = archive_path.with_name(archive_path.name + ".part")
    urllib.request.urlretrieve(dataset_url, partial_path)
    partial_path.replace(archive_path)
    print("Download complete.")

# 3. Extract
if not data_dir.exists():
    print("Extracting images...")
    with tarfile.open(archive_path, "r:gz") as tar:
        # The "data" filter rejects absolute paths, ".." traversal and special files
        # in downloaded archives; it is only available on newer Python releases.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=".", filter="data")
        else:
            tar.extractall(path=".")
    print(f"Extracted to {data_dir.resolve()}")
else:
    print(f"Data already exists at {data_dir.resolve()}")

# 4. Quick verification
jpg_count = len(list(data_dir.glob("**/*.jpg")))
print(f"Found {jpg_count} images in dataset.")

In [None]:
# Real reference images for FID: the TF-Flowers photos extracted above
# (earlier versions of this notebook used a local "flower_data/train" split).
real_data_dir = Path("flower_photos")

In [None]:
import open_clip
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from scipy.linalg import sqrtm
from pathlib import Path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==========================================
# 1. CLIP Score Evaluation
# ==========================================
print("--- Starting CLIP Evaluation ---")

# Evaluation uses an OpenCLIP model (LAION-2B weights), separate from the OpenAI CLIP
# that produced the generator's conditioning. Load it once, outside the scoring loop.
# 'ViT-B-32' is standard; 'ViT-L-14' may score more accurately if VRAM allows.
clip_model_name = "ViT-B-32"
clip_pretrained = "laion2b_s34b_b79k"

try:
    model_clip, _, preprocess_clip = open_clip.create_model_and_transforms(clip_model_name, pretrained=clip_pretrained)
    model_clip = model_clip.to(device).eval()
    tokenizer = open_clip.get_tokenizer(clip_model_name)
except Exception as e:
    print(f"Error loading OpenCLIP: {e}. Make sure open_clip_torch is installed.")
    # Scoring below cannot work without the model; stop here instead of failing later
    # with a confusing NameError (or silently using a stale model from an earlier run).
    raise

@torch.no_grad()
def calculate_single_clip_score(image_path, text_prompt):
    """Cosine similarity between one image and one prompt under the OpenCLIP model.

    Both embeddings are L2-normalised, so the value is the raw cosine in [-1, 1]
    (not the 100 * max(cos, 0) scaling used by the reference CLIPScore metric).
    """
    with Image.open(image_path) as pil_img:
        pixels = preprocess_clip(pil_img.convert("RGB")).unsqueeze(0).to(device)
    tokens = tokenizer([text_prompt]).to(device)

    img_emb = model_clip.encode_image(pixels)
    txt_emb = model_clip.encode_text(tokens)

    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

    similarity = img_emb @ txt_emb.T  # [1, 1]
    return float(similarity.item())

# Pair every generated image with the prompt it was generated from
# (prompt_per_image and image_paths come from the generation loop).
if 'prompt_per_image' not in locals():
    print("Warning: prompt_per_image not found. Using text_prompts (this limits eval to 3 images).")
    prompt_list_to_use = text_prompts
else:
    prompt_list_to_use = prompt_per_image

# zip() below stops at the shorter list; make any truncation visible.
if len(prompt_list_to_use) != len(image_paths):
    print(
        f"Warning: {len(image_paths)} images but {len(prompt_list_to_use)} prompts; "
        f"only the first {min(len(image_paths), len(prompt_list_to_use))} pairs will be scored."
    )

# Calculate scores
clip_scores = []
for i, (p, t) in enumerate(zip(image_paths, prompt_list_to_use), start=1):
    clip_scores.append(calculate_single_clip_score(p, t))
    if i % 50 == 0:
        print(f"Evaluated {i} images...")

print("Mean CLIP Score:", float(np.mean(clip_scores)) if clip_scores else float("nan"))


# ==========================================
# 2. FID (Frechet Inception Distance)
# ==========================================
print("\n--- Starting FID Evaluation ---")

# Resize to Inception's 299x299 input and normalise with ImageNet mean/std.
inception_transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load Inception V3. torchvision's IMAGENET1K_V1 Inception weights expect inputs
# scaled to [-1, 1]; transform_input=True makes the model convert ImageNet-normalised
# input (as produced above) into that range. With transform_input=False the features
# were computed on mis-scaled input.
# NOTE: these are torchvision's ImageNet weights, not the TF Inception weights used by
# reference FID implementations, so scores are only comparable within this notebook.
inception = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1, transform_input=True)
inception.fc = torch.nn.Identity()  # drop the classification head -> pooled 2048-d features
inception = inception.to(device).eval()

@torch.no_grad()
def get_inception_embeddings(image_paths, batch_size=32):
    """Return Inception pooled features for the given image files.

    Images are processed in batches of ``batch_size``. An unreadable image is
    reported and skipped on its own; the other images in its batch are still
    embedded (previously one bad file discarded the whole batch).

    Returns:
        A [n_readable, 2048] numpy array, or an empty array if nothing could be read.
    """
    embs = []
    for start in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[start:start + batch_size]

        imgs = []
        for p in batch_paths:
            try:
                with Image.open(p) as img:
                    imgs.append(inception_transform(img.convert("RGB")))
            except Exception as e:
                print(f"Error reading {p}: {e}")

        if not imgs:
            continue

        x = torch.stack(imgs).to(device)
        y = inception(x)  # [B, 2048]
        embs.append(y.cpu().numpy())

    if embs:
        return np.concatenate(embs, axis=0)
    return np.array([])

def calculate_fid(real_embeddings, gen_embeddings, eps=1e-6):
    """Frechet distance between Gaussians fitted to two embedding sets.

    FID = ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 Tr(sqrt(S1 S2)).

    Args:
        real_embeddings, gen_embeddings: [n, d] arrays (rows are samples).
        eps: diagonal offset added to both covariances if the matrix square root is
            not finite, which happens when n is small relative to d (singular covariances).

    Returns:
        The FID as a float, or ``inf`` if either set has fewer than two samples
        (a covariance cannot be estimated from a single sample).
    """
    real = np.asarray(real_embeddings, dtype=np.float64)
    gen = np.asarray(gen_embeddings, dtype=np.float64)
    if len(real) < 2 or len(gen) < 2:
        return float('inf')

    mu1, sigma1 = real.mean(axis=0), np.atleast_2d(np.cov(real, rowvar=False))
    mu2, sigma2 = gen.mean(axis=0), np.atleast_2d(np.cov(gen, rowvar=False))

    diff = mu1 - mu2
    ssdiff = float(diff.dot(diff))

    # Square root of the product of covariances
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Rank-deficient covariances: regularise the diagonal and retry
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical noise can leave a tiny imaginary component
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    return float(fid)


In [None]:
# --- Define Paths ---
gen_dir = Path("generated_flowers")

# Point to the extracted TF-Flowers folder
real_data_dir = Path("flower_photos")

# Collect paths (TF-Flowers has one subfolder per class).
# NOTE: the glob picks up every PNG in gen_dir, including leftovers from an earlier run with a larger N.
gen_paths = sorted(str(p) for p in gen_dir.glob("*.png"))
real_paths = sorted(str(p) for p in real_data_dir.glob("**/*.jpg"))

# Match the number of real images to the number of generated ones. Sample them at random
# with a fixed seed: the sorted list is grouped by class folder, so taking the first
# len(gen_paths) entries would only draw from the alphabetically first class folder(s).
FID_SEED = 0
if len(real_paths) > len(gen_paths):
    rng = np.random.default_rng(FID_SEED)
    keep = np.sort(rng.choice(len(real_paths), size=len(gen_paths), replace=False))
    real_paths = [real_paths[i] for i in keep]

print(f"Computing embeddings... Real: {len(real_paths)}, Gen: {len(gen_paths)}")

# Stays NaN if FID cannot be computed, so later logging still has a value to report.
fid_value = float("nan")
if len(real_paths) > 0 and len(gen_paths) > 0:
    real_emb = get_inception_embeddings(real_paths)
    gen_emb = get_inception_embeddings(gen_paths)

    fid_value = calculate_fid(real_emb, gen_emb)
    print(f"FID Score: {fid_value:.4f}")
else:
    print("Skipping FID: Real or Generated image lists are empty. Check your paths.")

## Part 3: Embedding Analysis with FiftyOne Brain

In [None]:
import fiftyone as fo
import fiftyone.brain as fob
import numpy as np

print("--- Starting FiftyOne Setup ---")

# 1. Clean up previous runs
dataset_name = "generated_flowers_with_embeddings"
if dataset_name in fo.list_datasets():
    fo.delete_dataset(dataset_name)

dataset = fo.Dataset(name=dataset_name)

# 2. Add Samples
# zip() stops at the shortest input; surface a mismatch instead of silently dropping samples.
input_lengths = {
    "images": len(image_paths),
    "prompts": len(prompt_per_image),
    "clip_scores": len(clip_scores),
    "embeddings": len(unet_embs_np),
}
if len(set(input_lengths.values())) != 1:
    print(f"Warning: mismatched input lengths {input_lengths}; extra items will be ignored.")

samples = []
for fp, prompt, score, emb in zip(image_paths, prompt_per_image, clip_scores, unet_embs_np):
    s = fo.Sample(filepath=fp)

    # Store metadata
    s["prompt"] = fo.Classification(label=prompt)
    s["clip_score"] = float(score)
    s["unet_embedding"] = emb.tolist()  # stored as a plain list of floats

    samples.append(s)

dataset.add_samples(samples)
dataset.save()
print(f"Created dataset with {len(dataset)} samples.")

# 3. Brain Computations (Uniqueness & Representativeness)
# Uniqueness uses FiftyOne's default image model; representativeness uses the U-Net embeddings.
print("Computing uniqueness...")
fob.compute_uniqueness(dataset)

print("Computing representativeness...")
fob.compute_representativeness(dataset, embeddings="unet_embedding")

# 4. Visualization (UMAP)
print("Computing UMAP visualization...")
# Produces a 2D scatter plot of the U-Net embeddings in the App
fob.compute_visualization(
    dataset,
    embeddings="unet_embedding",
    method="umap",
    brain_key="umap_vis"
)

# 5. Launch App
session = fo.launch_app(dataset)


The results look quite nice. While they are a bit noisy, the guidance scale makes every image clearly classifiable as one of the three prompted flower types! 🌹🌻🌼


In [None]:
## Part 4 - W&B logging

In [None]:
import wandb

print("--- Starting WandB Logging ---")

# 1. Login
wandb.login()

# 2. Initialize Run
run = wandb.init(
    project="Hands-on-CV-Project3",
    name="flower_generation_run",
    config={
        "num_steps": NUM_STEPS,
        "beta_start": B_start,
        "beta_end": B_end,
        "num_prompts": len(text_prompts),
        "total_images": N,
        "model_architecture": "UNet_32x32"
    },
)

# Everything after init is wrapped so the run is always closed, even if a step fails.
try:
    # 3. Log Scalar Metrics (Summary)
    run.log({
        "global_clip_mean": float(np.mean(clip_scores)),
        "global_fid_score": fid_value,
    })

    # 4. Create Rich Table
    # Iterate over the FiftyOne dataset so the table includes the computed brain scores
    table = wandb.Table(columns=[
        "generated_image",
        "prompt",
        "clip_score",
        "uniqueness_score",
        "representativeness_score"
    ])

    print("Populating WandB Table...")

    for s in dataset:
        # Default to 0.0 when a brain field is missing or unset for this sample
        u_score = s.get_field("uniqueness") if s.has_field("uniqueness") else None
        r_score = s.get_field("representativeness") if s.has_field("representativeness") else None

        table.add_data(
            wandb.Image(s.filepath),
            s["prompt"].label,
            s["clip_score"],
            0.0 if u_score is None else u_score,
            0.0 if r_score is None else r_score,
        )

    # 5. Log Table
    run.log({"generation_results": table})
finally:
    run.finish()

print("WandB logging complete 🚀🚀🚀 Check your dashboard!")

Publish the dataset on Hugging Face

In [None]:
!pip install huggingface_hub

import fiftyone as fo
from huggingface_hub import login
# --- SETUP ---
my_hf_REDACTED = "hf_REDACTED" # PUT YOUR TOKEN HERE
login(token=my_hf_REDACTED)

# --- LOAD DATASET ---
print("1. Attempting to load dataset...", flush=True)

# Ensure we aren't using a cached variable by accident
if 'dataset' in locals():
    del dataset

dataset = fo.load_dataset("generated_flowers_with_embeddings")
print("   Dataset loaded successfully.", flush=True)

# --- UPLOAD ---
hf_REDACTED_name = "Consscht/FlowerDiff"

print(f"2. Preparing to upload to: {hf_REDACTED_name}...", flush=True)
print("   (This may take a moment while FiftyOne prepares the files...)", flush=True)

dataset.push_to_hub(
    repo_id=hf_REDACTED_name,
    private=False,
    dataset_type="image"
)

print("3. Success! Your dataset is now published ðŸš€", flush=True)
print(f"View it here: https://huggingface.co/datasets/{hf_REDACTED_name}")

In [None]:
# This cell used to repeat the Hugging Face login and upload from the previous cell
# almost verbatim (a second pip install, imports duplicated twice, a second push to the
# same repository). Running both pushed the same dataset twice, so the duplicate has
# been removed; re-run the upload cell above if the push needs to be retried.

## Bonus: "I don't know" (IDK) classifier on generated MNIST digits

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from torchvision.utils import save_image
import fiftyone as fo
from PIL import Image

# Import utils (ensure utils.py is in the same directory)
import utils
from utils import ClassicLeNet5 

# Note: UNet_utils and ddpm_utils must be available from earlier validation cells
# If they are not in the path, ensure the corresponding files exist or cells are run.
# from utils import UNet_utils, ddpm_utils  <-- These are not in utils.py

# --- 1. SETUP & CONFIG ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
THRESHOLD = 0.8         # confidence threshold passed to predict_with_idk (not enforced there at the moment)
N_SAMPLES = 50          # number of digits to generate
OUT_DIR = Path("generated_mnist_bonus")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --- 2. LOAD MODELS ---

# A. The 11-class IDK classifier (digits 0-9 plus "IDK") trained in IDK_Model_Training.ipynb
classifier_path = "best_lenet_idk.pth"
classifier = ClassicLeNet5(num_classes=11)

if not Path(classifier_path).exists():
    print(f"WARNING: Classifier checkpoint '{classifier_path}' not found!")
else:
    classifier.load_state_dict(torch.load(classifier_path, map_location=device))
    print(f"Loaded IDK Classifier from {classifier_path}")

classifier = classifier.to(device).eval()

# B. Load the MNIST Generator (U-Net)
# IMPORTANT: Use img_ch=1 and img_size=28 for MNIST!
# Import the diffusion helpers explicitly. Previously a failed lookup only printed a
# warning and the cell carried on with the `model`/`ddpm` left over from the flower
# section (hidden state), so a missing module now fails here.
import UNet_utils
import ddpm_utils

MNIST_T = 400  # number of diffusion steps
model = UNet_utils.UNet(
    T=MNIST_T, img_ch=1, img_size=28, down_chs=(256, 256, 512),
    t_embed_dim=8, c_embed_dim=512
).to(device)

# The generator needs weights trained on MNIST; without them the U-Net is randomly
# initialised and sampling will only produce noise.
MNIST_GENERATOR_PATH = Path("my_mnist_generator.pth")
if MNIST_GENERATOR_PATH.exists():
    model.load_state_dict(torch.load(MNIST_GENERATOR_PATH, map_location=device))
    print(f"Loaded MNIST generator from {MNIST_GENERATOR_PATH}")
else:
    print(f"WARNING: MNIST generator checkpoint '{MNIST_GENERATOR_PATH}' not found - samples will be noise.")
model.eval()

# Setup DDPM for generation (same linear beta schedule as the flower model)
B = torch.linspace(0.0001, 0.02, MNIST_T).to(device)
ddpm = ddpm_utils.DDPM(B, device)

# --- 3. DEFINE PREDICTION FUNCTION ---
def predict_with_idk(image, model, threshold):
    """Classify one image with the 11-class (digits + IDK) model.

    Class indices 0-9 are digits; index 10 is the explicit "IDK" class.
    ``threshold`` is accepted but not enforced (the low-confidence override is
    kept below as a disabled option).

    Returns:
        (label, confidence): ``label`` is "IDK" or the digit as a string and
        ``confidence`` is the softmax probability of the predicted class.
    """
    idk_index = 10
    with torch.inference_mode():
        probs = F.softmax(model(image), dim=1)
        top_prob, top_idx = probs.max(dim=1)
        pred = top_idx.item()
        confidence = top_prob.item()

    # Optional: force IDK when a digit prediction is not confident enough
    # if confidence < threshold:
    #     return "IDK", confidence
    label = "IDK" if pred == idk_index else str(pred)
    return label, confidence

# --- 4. GENERATE, CLASSIFY & BUILD DATASET ---

# Initialize FiftyOne Dataset
dataset_name = "mnist_idk_experiment"
if dataset_name in fo.list_datasets():
    fo.delete_dataset(dataset_name)
dataset = fo.Dataset(name=dataset_name)

# Sampling context: one 512-d vector per image (the generator is built with
# c_embed_dim=512). Previously a noise *image* was passed as the context, and the
# default w_tests produced several images per call. The sampler draws its own
# starting noise. An all-zero context with guidance weight 0 should amount to
# unconditional sampling - NOTE(review): confirm against UNet_utils/ddpm_utils, and
# build `mnist_context` from the generator's training conditioning if it had one.
MNIST_CONTEXT_DIM = 512
MNIST_GUIDANCE_W = 0.0

print(f"Generating {N_SAMPLES} digits and classifying...")

samples = []

# Ensure we can run generation
if 'ddpm' in globals() and 'model' in globals():
    mnist_context = torch.zeros(N_SAMPLES, MNIST_CONTEXT_DIM, device=device)

    # A. Generate all images in one sampling run (one image per context row)
    try:
        x_gen, _ = ddpm_utils.sample_w(
            model, ddpm, (1, 28, 28), 400, mnist_context, device,
            w_tests=[MNIST_GUIDANCE_W]
        )
    except Exception as e:
        print(f"Generation failed: {e}")
        x_gen = None

    if x_gen is not None:
        # Generator output is treated as [-1, 1]: clamp, then map to [0, 1].
        imgs_01 = (x_gen[:N_SAMPLES].clamp(-1, 1) + 1) / 2.0

        for i in range(len(imgs_01)):
            img_01 = imgs_01[i:i + 1]  # keep the batch dim: [1, 1, 28, 28]

            # B. Classify. The LeNet was trained on images normalized with
            #    mean=0.1307, std=0.3081, so apply the same normalization.
            img_norm = (img_01 - 0.1307) / 0.3081
            label, confidence = predict_with_idk(img_norm, classifier, THRESHOLD)

            # C. Save the [0, 1] image to disk (FiftyOne samples need a file)
            fp = OUT_DIR / f"mnist_{i:04d}.png"
            save_image(img_01, fp)

            # D. Create FiftyOne Sample with the prediction and a result tag
            sample = fo.Sample(filepath=str(fp))
            sample["prediction"] = fo.Classification(
                label=label,
                confidence=confidence
            )
            sample.tags.append("idk_predicted" if label == "IDK" else "digit_predicted")

            samples.append(sample)

    # Add samples to dataset
    if samples:
        dataset.add_samples(samples)
        dataset.save()
        print(f"Done! Created dataset '{dataset_name}' with {len(samples)} samples.")

        # --- 5. VISUALIZE ---
        # This opens the App inside the notebook
        session = fo.launch_app(dataset)
else:
    print("Skipping generation loop because model/ddpm not initialized.")