# Libraries

In [None]:
# Pandas for building and sorting the table
import pandas as pd
# Display helpers for rendering HTML
from IPython.display import HTML, display
# PIL - reading/resizing images
from PIL import Image
# Base64 encoding for embedding images into HTML
import base64
# Memory buffer for image data
from io import BytesIO
# System utilities
import os
import csv
from tqdm import tqdm

# Helpers

In [None]:
# Convert image to a small HTML thumbnail
def img_to_html(path, max_size=(128, 128)):
	"""Return an inline ``<img>`` tag holding a PNG thumbnail of *path*.

	The image is shrunk in place with ``Image.thumbnail`` to fit within
	*max_size*, re-encoded as PNG and embedded as a base64 data URI, so the
	HTML table needs no external files.

	Returns the string ``"(preview error)"`` if the file cannot be opened,
	decoded or re-encoded: the preview is best-effort and must not abort the
	table build.
	"""
	try:
		# ``with`` releases the underlying file handle deterministically
		# instead of leaving it to the garbage collector.
		with Image.open(path) as img:
			img.thumbnail(max_size)
			buffer = BytesIO()
			img.save(buffer, format="PNG")
	except Exception:
		# Narrower than a bare ``except:`` so KeyboardInterrupt / SystemExit
		# can still stop a long-running notebook cell.
		return "(preview error)"
	b64 = base64.b64encode(buffer.getvalue()).decode("ascii")
	return f'<img src="data:image/png;base64,{b64}"/>'

# Extract base id such as cb_002 from all filename formats
def extract_base(filename):
	"""Return the group id (e.g. ``cb_002``) encoded in an image filename.

	Everything from the first ``.`` onward is dropped, then at most the first
	two ``_``-separated tokens are kept. So ``cb_002.png``,
	``cb_002_glazed.png`` and ``cb_002_glazed_shaded.png`` all map to
	``cb_002``. A stem without an underscore is returned unchanged.
	"""
	stem = filename.split(".")[0]
	leading_tokens = stem.split("_")[:2]
	return "_".join(leading_tokens)

# Color rules for each perturbation type
def type_color(t):
	"""Return the row background colour (hex string) for perturbation type *t*.

	Unknown types fall back to white, the same colour used for ``"clean"``.
	"""
	palette = (
		("clean", "#FFFFFF"),          # white
		("glazed", "#E4DAFF"),         # light glaze lavender
		("shaded", "#E0EBFF"),         # light nightshade blue
		("glazed_shaded", "#EEFFEE"),  # blended tint
	)
	return next((colour for name, colour in palette if t == name), "#FFFFFF")

# Dictionary grouping by type
***(clean + glazed + shaded + glazed_shaded)***

Images are grouped by the `<style>_<datatype>` folder they live under (e.g. `comic_glazed/train/…` → style `comic`, type `glazed`), so downstream processing can handle each style and perturbation type separately.

In [None]:
# Resolve training directory
current_dir = os.getcwd()
root_dir = os.path.abspath(os.path.join(current_dir, "..", ".."))
training_data_dir = os.path.join(root_dir, "Adversarial-Perturbation-main", "training_data")

print("Current directory:	  ", current_dir)
print("Resolved training path: ", training_data_dir)

# os.walk() silently yields nothing for a missing directory. Without this
# check, a wrong working directory only shows up later as an empty summary.
if not os.path.isdir(training_data_dir):
	print(f"[WARNING] Training directory not found: {training_data_dir}")

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Final structure:
# data[style][datatype] = {
#	"images": [list of paths],
#	"metadata": { filename: text }
# }
data = {}

print("\nScanning training_data...\n")

# First pass: load every <style>_<datatype>/train/metadata.csv
metadata_map = {}   # key = parent folder name, e.g. comic_clean, materials_glazed

for root, dirs, files in os.walk(training_data_dir):

	current_folder = os.path.basename(root).lower()
	parent_folder  = os.path.basename(os.path.dirname(root)).lower()

	# metadata is inside: style_datatype/train/metadata.csv
	if current_folder != "train":
		continue

	# The name is matched case-insensitively, so open the file under its real
	# on-disk spelling (e.g. "Metadata.csv"). A hard-coded "metadata.csv"
	# fails on case-sensitive filesystems.
	meta_name = next((f for f in files if f.lower() == "metadata.csv"), None)
	if meta_name is None:
		continue

	meta_path = os.path.join(root, meta_name)
	folder_key = parent_folder  # comic_clean, materials_glazed, etc.

	metadata_map[folder_key] = {}

	# "utf-8-sig" also accepts a leading byte-order mark (common in
	# spreadsheet exports). Plain "utf-8" would glue the BOM onto the first
	# header and make row["file_name"] raise KeyError.
	with open(meta_path, "r", newline="", encoding="utf-8-sig") as csvfile:
		reader = csv.DictReader(csvfile)
		for row in reader:
			# DictReader fills missing trailing fields of short rows with None
			fname = (row["file_name"] or "").strip()
			text  = (row["text"] or "").strip()
			metadata_map[folder_key][fname] = text


# Second pass: gather images using parent folder (style_datatype)
for root, dirs, files in os.walk(training_data_dir):

	parent_folder = os.path.basename(os.path.dirname(root)).lower()

	if parent_folder == "training_data":
		continue

	if "_" not in parent_folder:
		continue

	# Split on the first underscore only, so "comic_glazed_shaded" gives
	# style "comic" and datatype "glazed_shaded".
	style, datatype = parent_folder.split("_", 1)

	if style not in data:
		data[style] = {}

	if datatype not in data[style]:
		data[style][datatype] = {
			"images": [],
			"metadata": metadata_map.get(parent_folder, {})
		}

	for fname in files:
		if fname.lower().endswith(IMAGE_EXTENSIONS):
			full_path = os.path.join(root, fname)
			data[style][datatype]["images"].append(full_path)

# Summary Print
print("\n=== IMAGE GROUP SUMMARY ===")
for style, typedict in data.items():
	print(f"\nStyle: {style}")
	for datatype, info in typedict.items():
		print(f"  {datatype:20s} --> {len(info['images'])} images, {len(info['metadata'])} metadata entries")

print("\nDone.\n")

# Table of Images 
***(clean + glazed + shaded + glazed_shaded)***

In [None]:
# Build table (preview + name + type + metadata)
from html import escape  # CSV captions / filenames must not be parsed as markup

ordered_types = ["clean", "glazed", "shaded", "glazed_shaded"]
rows = []

for style, typedict in data.items():
	for t in ordered_types:
		if t not in typedict:
			continue

		images = typedict[t]["images"]
		metadata_dict = typedict[t]["metadata"]

		for img_path in images:
			filename = os.path.basename(img_path)
			base = extract_base(filename)
			meta_text = metadata_dict.get(filename, "")

			rows.append({
				"base": base,
				"preview": img_to_html(img_path),
				"filename": filename,
				"style": style,
				"type": t,
				"metadata": meta_text
			})

# Explicit columns keep the frame well-formed even when no images were found.
# pd.DataFrame([]) has no columns at all, so every lookup below (and in later
# cells) would raise KeyError.
df = pd.DataFrame(
	rows,
	columns=["base", "preview", "filename", "style", "type", "metadata"],
)

# Assign clean type when filename has no perturbation suffix
df.loc[df['filename'].str.count('_') == 1, 'type'] = 'clean'

df["type_order"] = df["type"].apply(lambda x: ordered_types.index(x))

# sort by base then type_order to guarantee grouping
df = df.sort_values(by=["base", "type_order"]).drop(columns=["type_order"])


# Build grouped HTML manually to clearly separate each base block.
# Pieces are collected in a list and joined once (repeated += is quadratic).
html_parts = ["""
<style>
.group-header {
	background: #DDE3EE;
	font-weight: bold;
	padding: 8px;
	border-top: 3px solid #2F3A4A;
	font-family: Arial, sans-serif;
}
.rowtable {
	margin-bottom: 15px;
}
</style>
"""]

current_base = None

for _, row in df.iterrows():
	base = row["base"]

	# New group header when base changes
	if base != current_base:
		html_parts.append(f'<div class="group-header">{escape(str(base))}</div>')
		current_base = base

	# Add row for each perturbation
	bg = type_color(row['type'])

	# Text fields are escaped (captions may contain "<" or "&"). The preview
	# is markup produced by img_to_html() and is inserted as-is.
	html_parts.append(f"""
	<div style="
		display:flex;
		align-items:center;
		gap:18px;
		padding:10px;
		background:{bg};
		border-bottom:1px solid #C8C8C8;
		font-family:Arial;
	">
		<div style="width:150px;">{row['preview']}</div>
		<div style="flex:1;">{escape(str(row['filename']))}</div>
		<div style="width:120px;">{escape(str(row['style']))}</div>
		<div style="width:140px;">{escape(str(row['type']))}</div>
		<div style="flex:2;">{escape(str(row['metadata']))}</div>
	</div>
	""")

html_output = "".join(html_parts)

display(HTML(html_output))

# Black-Box Analysis
This module performs black-box perturbation analysis for adversarial-art-protection methods (Glaze, Nightshade, and Glaze+Nightshade).
Since model internals are not accessible, we instead treat each perturbation method as a black-box transformation that converts a clean image into a protected one.

## Black-Box Explanation #1 - Occlusion Sensitivity (Sliding Window Occlusion)
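A square black mask (`window`×`window` pixels, moved in steps of `stride`) slides over the clean image. For each position we compute the mean absolute pixel difference between the occluded clean image and the unchanged perturbed image, and assign that score to the masked region of the heatmap. Caveat: because the mask is black, the score also grows with the brightness of the perturbed image inside the mask, which can outweigh the (typically subtle) perturbation itself, so read the guide below with care: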
- High change → region is strongly modified by the perturbation method
- Low change → region is less affected
- Difference maps reveal perturbation behavior without model introspection
  
**Expected results:**
1. Spatial perturbation intensity heatmaps
2. Comparison across methods (clean vs glaze vs nightshade vs glazed+shaded)
3. Region-level behavioral characterization
4. Metadata-aware alignment if captions differ regionally

In [None]:
import numpy as np
import cv2
print("OpenCV:", cv2.__version__)
import matplotlib.pyplot as plt
from PIL import Image

###############################################
# Utility Functions
###############################################

def load_img(path):
	"""Load an image as an RGB float32 numpy array scaled to [0, 1].

	Returns an array of shape (H, W, 3).
	"""
	# ``with`` closes the file handle as soon as the pixels are decoded
	# (convert() produces an independent image).
	with Image.open(path) as img:
		rgb = img.convert("RGB")
	return np.asarray(rgb, dtype=np.float32) / 255.0

def apply_occlusion(img, x, y, w, h, color=(0,0,0)):
	"""
	Return a copy of *img* whose ``h`` x ``w`` rectangle with top-left corner
	(``x``, ``y``) is filled with *color*; *img* itself is left untouched.
	"""
	rows = slice(y, y + h)
	cols = slice(x, x + w)
	patched = img.copy()
	patched[rows, cols] = color
	return patched

def diff_metric(a, b):
	"""
	Return the mean absolute element-wise difference between *a* and *b*.
	Swap in L2, SSIM, LPIPS, etc. here if a different measure is wanted.
	"""
	delta = a - b
	return np.abs(delta).mean()


###############################################
# Core Occlusion Sensitivity Engine
###############################################

def _window_starts(length, win, stride):
	"""Return patch start offsets along one axis, always including the final
	position ``length - win`` so the far edge of the image is covered."""
	starts = list(range(0, length - win + 1, stride))
	if starts and starts[-1] != length - win:
		starts.append(length - win)
	return starts


def occlusion_map(clean_img, perturbed_img, window=32, stride=16):
	"""
	Generate a heatmap showing how much the protected image differs from the
	CLEAN image when local regions of the CLEAN image are occluded.

	A ``window`` x ``window`` patch (``apply_occlusion``'s default black
	colour) is slid over the clean image in steps of ``stride``. Each
	placement is scored with ``diff_metric(occluded_clean, perturbed_img)``,
	and every output pixel is the mean score of all placements covering it.

	Parameters
	----------
	clean_img, perturbed_img : ndarray, shape (H, W, C)
		Images of identical shape (as returned by ``load_img``).
	window : int
		Patch side length in pixels; clamped to the image size.
	stride : int
		Step between successive patch positions.

	Returns
	-------
	ndarray, shape (H, W)
		Float heatmap; every pixel is covered by at least one window.

	Raises
	------
	ValueError
		If the images differ in shape, or ``window`` / ``stride`` is < 1.
	"""
	if clean_img.shape != perturbed_img.shape:
		raise ValueError(
			f"shape mismatch: clean {clean_img.shape} vs perturbed {perturbed_img.shape}"
		)
	if window < 1 or stride < 1:
		raise ValueError("window and stride must both be >= 1")

	H, W, _ = clean_img.shape
	win_h, win_w = min(window, H), min(window, W)

	score_sum = np.zeros((H, W))
	hits = np.zeros((H, W))

	# Offsets include the last valid position (the old range(0, H - window, ...)
	# stopped one window short, leaving the right/bottom border at 0).
	for y in _window_starts(H, win_h, stride):
		for x in _window_starts(W, win_w, stride):

			# 1. Mask the CLEAN image region
			clean_occ = apply_occlusion(clean_img, x, y, win_w, win_h)

			# 2. Compute diff against the PERTURBED image
			score = diff_metric(clean_occ, perturbed_img)

			# 3. Accumulate, so overlapping windows (stride < window) are
			#    averaged instead of the last one overwriting the others
			score_sum[y:y+win_h, x:x+win_w] += score
			hits[y:y+win_h, x:x+win_w] += 1

	return score_sum / hits


###############################################
# Visualization Helper (Includes Row Label)
###############################################

def show_occlusion_results(clean, perturbed, heatmap, perturb_label):
	"""
	Draw one figure row: a rotated text label, then the clean image, the
	perturbed image and the occlusion heatmap (``hot`` colormap).
	"""
	fig, (label_ax, clean_ax, pert_ax, heat_ax) = plt.subplots(
		1, 4, figsize=(18,5),
		gridspec_kw={'width_ratios':[0.15, 1, 1, 1]}
	)

	# The narrow first column only carries the vertical label text
	label_ax.axis("off")
	label_ax.text(
		0.5, 0.5, perturb_label,
		fontsize=16, fontweight="bold",
		rotation=90, ha="center", va="center"
	)

	panels = (
		(clean_ax, clean, None, "Clean Image"),
		(pert_ax, perturbed, None, f"Perturbed Image\n({perturb_label})"),
		(heat_ax, heatmap, "hot", f"Sensitivity Map\n({perturb_label})"),
	)
	for ax, picture, cmap, title in panels:
		ax.imshow(picture, cmap=cmap)
		ax.set_title(title)
		ax.axis("off")

	plt.tight_layout()
	plt.show()


###############################################
# Retrieve correct file paths from df/data
###############################################

def get_paths_for_base(base_id):
	"""
	Resolve the on-disk image path of each variant of one image group.

	Parameters
	----------
	base_id : str
		Group id as produced by ``extract_base`` (e.g. ``"cb_002"``).

	Returns
	-------
	dict
		Maps ``"clean"``, ``"glazed"``, ``"shaded"`` and ``"glazed_shaded"``
		to full file paths, based on the global ``df`` and ``data``. A variant
		is omitted when ``df`` has zero or several rows for it, or when its
		filename does not identify exactly one stored path.
	"""
	subset = df[df["base"] == base_id]
	paths = {}

	for perturb_type in ["clean", "glazed", "shaded", "glazed_shaded"]:
		rows = subset[subset["type"] == perturb_type]
		if len(rows) != 1:
			continue

		fname = rows["filename"].iloc[0]
		style = rows["style"].iloc[0]

		# .get(): a row re-typed as "clean" by filename may belong to a style
		# that has no "clean" folder; treat that as a missing variant.
		candidates = data[style].get(perturb_type, {}).get("images", [])

		# Compare whole basenames. A suffix test such as
		# p.endswith("b_002.png") would also match ".../cb_002.png".
		match = [p for p in candidates if os.path.basename(p) == fname]

		if len(match) == 1:
			paths[perturb_type] = match[0]

	return paths

In [None]:
# Progress
import time
from datetime import datetime, timedelta
#from IPython.display import clear_output

#############################################################
# Run Occlusion Sensitivity on ALL image groups (timed)
#############################################################

# Collect image groups
all_bases = sorted(df["base"].unique())
total_groups = len(all_bases)

# Every group must provide all four variants to be analysed
required = ["clean", "glazed", "shaded", "glazed_shaded"]

# Time bookkeeping
start_time = time.time()
start_dt = datetime.now()

print("Starting occlusion sensitivity analysis.")
print(f"Found {total_groups} image groups.\n")

# Show start time in both formats
print(f"Start time: {start_dt.strftime('%Y-%m-%d %H:%M:%S')}  ({start_dt.strftime('%I:%M:%S %p')})")

# Estimate runtime based on typical occlusion cost
estimated_seconds_per_group = 6   # adjust if needed
estimated_total = estimated_seconds_per_group * total_groups
estimated_finish = start_dt + timedelta(seconds=estimated_total)

print(f"Approximate time required: {estimated_total:.1f} seconds")

# Show estimated finish time in both formats
print(f"Estimated finish time: {estimated_finish.strftime('%Y-%m-%d %H:%M:%S')}  ({estimated_finish.strftime('%I:%M:%S %p')})\n")

group_counter = 0
processed_groups = 0
skipped_groups = []

for base_id in all_bases:
    group_counter += 1

    print(f"\n===== Processing image group '{base_id}' ({group_counter}/{total_groups}) =====")

    # Retrieve correct image paths
    paths = get_paths_for_base(base_id)

    # Skip incomplete groups
    if not all(pt in paths for pt in required):
        print(f"[WARNING] Missing some variants for {base_id}, skipping.\n")
        skipped_groups.append(base_id)
        continue

    # Load variants
    clean_img   = load_img(paths["clean"])
    glazed_img  = load_img(paths["glazed"])
    shade_img   = load_img(paths["shaded"])
    both_img    = load_img(paths["glazed_shaded"])

    # The pixel-wise comparison needs identical shapes. Skip the group
    # (instead of aborting the whole run) if a variant was saved at another size.
    mismatched = [
        name for name, img in (("glazed", glazed_img), ("shaded", shade_img), ("glazed_shaded", both_img))
        if img.shape != clean_img.shape
    ]
    if mismatched:
        print(f"[WARNING] Size mismatch vs clean for {base_id} ({', '.join(mismatched)}), skipping.\n")
        skipped_groups.append(base_id)
        continue

    # Compute heatmaps
    H_glaze = occlusion_map(clean_img, glazed_img)
    H_shade = occlusion_map(clean_img, shade_img)
    H_both  = occlusion_map(clean_img, both_img)

    # Visualize all results inline
    show_occlusion_results(clean_img, glazed_img, H_glaze,
                           perturb_label=f"{base_id} — Glaze")

    show_occlusion_results(clean_img, shade_img, H_shade,
                           perturb_label=f"{base_id} — Nightshade")

    show_occlusion_results(clean_img, both_img, H_both,
                           perturb_label=f"{base_id} — Glaze + Nightshade")

    processed_groups += 1
    print(f"[COMPLETED] {base_id}\n")

end_time = time.time()
end_dt = datetime.now()
elapsed = end_time - start_time

print("[COMPLETED] Occlusion sensitivity analysis finished.\n")
print(f"Processed {processed_groups}/{total_groups} groups; skipped {len(skipped_groups)}: {skipped_groups}\n")

# Print start/end in both formats
print(f"Start time: {start_dt.strftime('%Y-%m-%d %H:%M:%S')}  ({start_dt.strftime('%I:%M:%S %p')})")
print(f"End time:   {end_dt.strftime('%Y-%m-%d %H:%M:%S')}  ({end_dt.strftime('%I:%M:%S %p')})")
print(f"Total elapsed time: {elapsed:.2f} seconds\n")

## Black-Box Explanation #2 — Fourier / Frequency-Domain Perturbation Analysis

Pixel-space differences between Clean, Glazed, and Nightshade images are often
visually subtle. However, these perturbations leave **strong and measurable
fingerprints in the frequency domain**, which becomes visible once we analyze
the image using the **2D Fast Fourier Transform (FFT)**.

---

### What is the Frequency Domain?

Every image can be interpreted not only as a grid of pixel intensities, but also
as a combination of sinusoidal waves of different **frequencies**, **orientations**,  
and **magnitudes**.

In this representation:

* **Low frequencies**  
  Large-scale structures, smooth lighting changes, color gradients.

* **Mid frequencies**  
  Object contours, shading transitions, macro texture.

* **High frequencies**  
  Fine texture, stippling, dithering, micro-noise —  
  *precisely where many adversarial or protective perturbations operate.*

The **2D FFT** decomposes an image into these components. Instead of asking:

> “How does Glaze or Nightshade change the pixels?”

we instead ask:

> “Which *frequencies* did the perturbation strengthen or weaken?”

This completely avoids needing model gradients or internals —  
**perfect for black-box analysis.**

---

### Why Fourier Analysis?

Both perturbation systems alter the image's texture and structure in ways that
are *difficult to notice in pixel-space but obvious in frequency-space*:

* **Glaze**  
  Introduces controlled **high-frequency texture injections** designed to disrupt
  style-transfer feature alignment.

* **Nightshade**  
  Can shift **mid and low frequencies**, interfering with semantic alignment and
  feature extraction.

* **Glaze + Nightshade**  
  Produces **broadband alterations**, affecting multiple frequency bands at once.

These spectral behaviors act like **fingerprints**, and FFT plots reveal them
directly.


For each clean–perturbed pair, compute:

1. **FFT magnitude spectrum (Clean)**  
   Shows the natural frequency structure of the unmodified artwork.

2. **FFT magnitude spectrum (Perturbed)**  
   Reveals injected frequencies or suppressed structures.

3. **FFT Difference Map (signed: log-magnitude of FFT_perturbed − log-magnitude of FFT_clean)**  
   Highlights frequency regions where energy was modified.

4. **Radial Frequency Profiles**  
   Collapses the 2D FFT into a 1D low→high curve, showing how perturbations  
   affect different frequency bands.

This set of views captures the entire spectral behavior of each perturbation.

---

### What the Fourier Figures Mean (Legend)

Each row of the Fourier analysis displays the following:

1. **Clean Image**  
   The original input image for the given group (e.g., `cb_001`).  
   Serves as the baseline.

2. **Perturbed Image (Glaze / Nightshade / Combined)**  
   The visual result after the perturbation method is applied.

3. **FFT — Clean**  
   Log-magnitude spectrum.  
   Bright areas = strong frequencies.  
   Repeated patterns in the artwork produce symmetric star-like structures.

4. **FFT — Perturbed**  
   Shows how the perturbation shifted or added energy into the frequency space.  
   Extra texture → more high-frequency brightness.

5. **FFT Difference Map**  
   Highlights the regions where the perturbation increased or decreased spectral
   energy.  
   Faint or softly textured difference maps indicate subtle perturbations;  
   strong hotspots reveal concentrated spectral alterations.

---

### Expected Revelations

FFT-based analysis is sensitive to small texture changes. The following are hypotheses to check against the plots, not guaranteed outcomes:

* Glaze increasing **fine-scale (high-frequency)** energy.
* Nightshade introducing **mid-frequency distortions**.
* Combined perturbations broadening the frequency response.
* Unique, stable spectral “signatures” for each method.

Because the FFT does not rely on a model, this method is fully **black-box**  
and applicable across any architecture or decoder.

---

**Expected Outputs:**
1. Clean vs Perturbed images (visual reference)  
2. Full FFT magnitude spectrum  
3. FFT difference heatmap  
4. Radial low/mid/high frequency distribution  
5. Per-method spectral fingerprints

**Overall: Computing the log-magnitude spectrum of the shifted 2D FFT to visualize frequency energy.**

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import fft2, fftshift
from IPython.display import clear_output
from scipy.ndimage import gaussian_filter

###############################################
# Utility: Load + Normalize
###############################################
def load_img(path):
	"""Load an image as an RGB float32 numpy array scaled to [0, 1].

	Returns an array of shape (H, W, 3).
	"""
	# ``with`` closes the file handle as soon as the pixels are decoded
	# (convert() produces an independent image).
	with Image.open(path) as img:
		rgb = img.convert("RGB")
	return np.asarray(rgb, dtype=np.float32) / 255.0

###############################################
# FFT Magnitude Spectrum
###############################################
def fft_magnitude(img):
	"""
	Return the centred log-magnitude spectrum ``log(1 + |FFT|)`` of *img*.

	The channels are averaged first (``img.mean(axis=2)``), so *img* is
	expected to be (H, W, C). The result is a 2-D (H, W) array with the
	zero-frequency term shifted to the centre.
	"""
	spectrum = fftshift(fft2(img.mean(axis=2)))
	magnitude = np.abs(spectrum)
	return np.log(1 + magnitude)

###############################################
# Radial Frequency Profile
###############################################
def radial_profile(mag):
	"""
	Compute the radial average of an FFT magnitude map.

	Pixels are binned by their truncated (integer) Euclidean distance from
	the centre pixel ``(H // 2, W // 2)``, and each bin is averaged.

	Parameters
	----------
	mag : ndarray, shape (H, W)
		Magnitude spectrum with zero frequency shifted to the centre, e.g.
		from ``fft_magnitude``.

	Returns
	-------
	(radii, profile) : tuple of ndarray
		``radii`` is ``0 .. max_radius``; ``profile[k]`` is the mean of
		``mag`` over the pixels whose truncated distance equals ``k``.
	"""
	H, W = mag.shape
	cy, cx = H // 2, W // 2

	y, x = np.ogrid[:H, :W]
	r = np.sqrt((x - cx)**2 + (y - cy)**2).astype(int).ravel()

	# One bincount pass instead of building a full-image boolean mask for
	# every radius (the old loop was O(max_radius * H * W)). Every integer
	# radius up to the maximum occurs, so no bin count is zero.
	sums = np.bincount(r, weights=mag.ravel())
	counts = np.bincount(r)
	radii = np.arange(sums.size)
	return radii, sums / counts

###############################################
# Plot Radial Spectrum
###############################################
def show_radial(clean_img, perturbed_img, base_id, label):
	"""
	Overlay the radial spectrum of the clean image and of a perturbed image
	(each via ``radial_profile(fft_magnitude(...))``) on a single axes.
	"""
	curves = (
		(clean_img, "Clean"),
		(perturbed_img, label),
	)

	plt.figure(figsize=(8,5))
	for picture, legend_text in curves:
		radii, profile = radial_profile(fft_magnitude(picture))
		plt.plot(radii, profile, label=legend_text, linewidth=2)

	plt.title(f"Radial Frequency Spectrum — {base_id}")
	plt.xlabel("Radius (Frequency)")
	plt.ylabel("Magnitude")
	plt.legend()
	plt.grid(True)
	plt.show()


###############################################
# Improved Fourier Analysis Helper
###############################################
def compute_fft_mag(img):
	"""
	Compute the centred log-magnitude spectrum ``log(1 + |FFT|)`` of an image.

	The channels are averaged first (``np.mean(img, axis=2)``), so *img* is
	expected to be (H, W, C). Returns a 2-D (H, W) array with zero frequency
	shifted to the centre.

	Uses the same ``log(1 + |F|)`` scaling as ``fft_magnitude`` (radial
	spectra) and ``fft_magnitude_gray`` (greyscale pipeline), so all Fourier
	plots are directly comparable.
	"""
	gray = np.mean(img, axis=2)
	F = np.fft.fft2(gray)
	Fshift = np.fft.fftshift(F)
	# log(|F| + 1e-6) pushed near-empty frequency bins to about -13.8, so they
	# dominated both the colour range and the signed difference map.
	mag = np.log(1 + np.abs(Fshift))
	return mag

###############################################
# Visualization of FFT + Difference
###############################################
def show_fourier(clean, perturbed, base_id, label):
	"""
	Plot a clean/perturbed pair next to their Fourier spectra.

	Top row (left to right): clean image, perturbed image, log-magnitude
	spectrum of the clean image, log-magnitude spectrum of the perturbed
	image, and their signed difference (perturbed - clean). The bottom row
	holds the colour-bar legends.

	Both spectra share one ``magma`` colour scale, so their brightness is
	directly comparable. The ``coolwarm`` difference scale is symmetric around
	zero, so the neutral colour means "no change", blue "clean > perturbed"
	and red "perturbed > clean".

	Parameters
	----------
	clean, perturbed : ndarray, shape (H, W, 3)
		Images of identical shape, as returned by ``load_img``.
	base_id : str
		Group id, used in the titles.
	label : str
		Perturbation name, used in the titles.
	"""

	fft_clean = compute_fft_mag(clean)
	fft_pert  = compute_fft_mag(perturbed)
	fft_diff  = fft_pert - fft_clean  # signed difference

	# Shared limits so the two magnitude spectra use the same colour scale
	mag_min = min(float(fft_clean.min()), float(fft_pert.min()))
	mag_max = max(float(fft_clean.max()), float(fft_pert.max()))
	# Symmetric limits centre coolwarm's neutral colour on zero. The floor
	# avoids a degenerate vmin == vmax when the two images are identical.
	diff_lim = max(float(np.abs(fft_diff).max()), 1e-12)

	fig, axes = plt.subplots(
		2, 5, figsize=(27,10),
		gridspec_kw={'height_ratios':[12, 1]}
	)

	# ------------------------------------------------------------
	# FIRST ROW: IMAGES + FFT MAPS
	# ------------------------------------------------------------

	# Clean image
	axes[0,0].imshow(clean)
	axes[0,0].set_title(f"{base_id}\nClean Image")
	axes[0,0].axis("off")

	# Perturbed image
	axes[0,1].imshow(perturbed)
	axes[0,1].set_title(f"{label}\nPerturbed Image")
	axes[0,1].axis("off")

	# FFT Clean
	im_fft_clean = axes[0,2].imshow(fft_clean, cmap="magma", vmin=mag_min, vmax=mag_max)
	axes[0,2].set_title("FFT — Clean")
	axes[0,2].axis("off")

	# FFT Perturbed (same colour scale as FFT Clean)
	axes[0,3].imshow(fft_pert, cmap="magma", vmin=mag_min, vmax=mag_max)
	axes[0,3].set_title("FFT — Perturbed")
	axes[0,3].axis("off")

	# FFT Difference
	im_fft_diff = axes[0,4].imshow(fft_diff, cmap="coolwarm", vmin=-diff_lim, vmax=diff_lim)
	axes[0,4].set_title(f"FFT Difference\n{base_id} — {label}")
	axes[0,4].axis("off")

	# ------------------------------------------------------------
	# SECOND ROW: COLOR LEGENDS (VISUAL)
	# ------------------------------------------------------------

	# magma legend (applies to both spectra)
	cbar1 = fig.colorbar(im_fft_clean, cax=axes[1,2], orientation="horizontal")
	cbar1.set_label("FFT Magnitude\n(dark = low, bright = high)", fontsize=10)

	# coolwarm legend
	cbar2 = fig.colorbar(im_fft_diff, cax=axes[1,4], orientation="horizontal")
	cbar2.set_label(
		"FFT Difference\nBlue = Clean > Perturbed | Red = Perturbed > Clean", 
		fontsize=10
	)

	# Hide the legend-row cells that carry no colour bar
	for ax in (axes[1,0], axes[1,1], axes[1,3]):
		ax.axis("off")

	plt.tight_layout()
	plt.show()

In [None]:
# Progress / timing
import time
from datetime import datetime, timedelta

#############################################################
# Run Fourier Analysis on ALL image groups (timed)
#############################################################

all_bases = sorted(df["base"].unique())
total_groups = len(all_bases)

# Every group must provide all four variants to be analysed
required = ["clean", "glazed", "shaded", "glazed_shaded"]

start_time = time.time()
start_dt = datetime.now()

print("Starting Fourier frequency analysis.")
print(f"Found {total_groups} image groups.\n")

print(f"Start time: {start_dt.strftime('%Y-%m-%d %H:%M:%S')}  ({start_dt.strftime('%I:%M:%S %p')})")

# Estimated runtime (adjust if it feels slow/fast)
estimated_seconds_per_group = 3
estimated_total = estimated_seconds_per_group * total_groups
estimated_finish = start_dt + timedelta(seconds=estimated_total)

print(f"Approximate time required: {estimated_total:.1f} seconds")
print(f"Estimated finish time: {estimated_finish.strftime('%Y-%m-%d %H:%M:%S')}  ({estimated_finish.strftime('%I:%M:%S %p')})\n")

group_counter = 0
processed_groups = 0
skipped_groups = []

for base_id in all_bases:
	group_counter += 1

	print(f"\n===== Processing image group '{base_id}' ({group_counter}/{total_groups}) =====")

	paths = get_paths_for_base(base_id)

	if not all(k in paths for k in required):
		print(f"[WARNING] Missing variants for {base_id}, skipping.")
		skipped_groups.append(base_id)
		continue

	# Load images using the same loader as occlusion
	clean = load_img(paths["clean"])
	glaze = load_img(paths["glazed"])
	shade = load_img(paths["shaded"])
	both  = load_img(paths["glazed_shaded"])

	# The FFT difference maps subtract spectra element-wise, which needs
	# identical image sizes. Skip the group instead of aborting the run.
	mismatched = [
		name for name, img in (("glazed", glaze), ("shaded", shade), ("glazed_shaded", both))
		if img.shape != clean.shape
	]
	if mismatched:
		print(f"[WARNING] Size mismatch vs clean for {base_id} ({', '.join(mismatched)}), skipping.")
		skipped_groups.append(base_id)
		continue

	# FFT visualizations
	show_fourier(clean, glaze, base_id, "Glaze")
	show_fourier(clean, shade, base_id, "Nightshade")
	show_fourier(clean, both,  base_id, "Glaze+Nightshade")

	# Radial spectra
	show_radial(clean, glaze, base_id, "Glaze")
	show_radial(clean, shade, base_id, "Nightshade")
	show_radial(clean, both,  base_id, "Glaze+Nightshade")

	processed_groups += 1
	print(f"[COMPLETED] {base_id}")

end_time = time.time()
end_dt = datetime.now()
elapsed = end_time - start_time

print("\n[COMPLETED] Fourier analysis finished.\n")
print(f"Processed {processed_groups}/{total_groups} groups; skipped {len(skipped_groups)}: {skipped_groups}\n")
print(f"Start time: {start_dt.strftime('%Y-%m-%d %H:%M:%S')}  ({start_dt.strftime('%I:%M:%S %p')})")
print(f"End time:   {end_dt.strftime('%Y-%m-%d %H:%M:%S')}  ({end_dt.strftime('%I:%M:%S %p')})")
print(f"Total elapsed time: {elapsed:.2f} seconds\n")

## Black-Box Explanation #2.b — Greyscale FFT Analysis
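This version loads images as **pure grayscale (PIL `convert("L")`, a luminance conversion) before computing FFT magnitude, difference maps,  
and radial spectra**, and also displays the images themselves in grayscale.  
Note that the previous pipeline already computes its spectra on a single channel (the unweighted mean of R, G and B). The two analyses therefore differ in how the channels are combined, not in RGB vs. grayscale FFTs.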

- Removing color channels emphasizes *structural, shading, and texture-level  
frequency changes* ==> Helps **distinguish perturbation signatures** (Glaze, Nightshade, combined)  
that may be obscured in RGB space.

In [None]:
# ================================================================
# Greyscale FFT Analysis — Helper Functions (Separate Pipeline)
# ================================================================

import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import fft2, fftshift

###################################################################
# Greyscale image loader
##################################################################
def load_img_gray(path):
    """
    Load an image as pure greyscale (PIL mode ``"L"``), normalized to [0, 1].

    Returns a 2-D float32 array of shape (H, W).
    """
    # ``with`` closes the file handle as soon as the pixels are decoded
    # (convert() produces an independent image).
    with Image.open(path) as img:
        gray = img.convert("L")
    return np.asarray(gray, dtype=np.float32) / 255.0


##################################################################
# FFT magnitude
##################################################################
def fft_magnitude_gray(img):
    """
    Return the centred log-magnitude spectrum ``log(1 + |FFT|)`` of a 2-D
    greyscale image, with the zero-frequency term shifted to the middle.
    """
    magnitude = np.abs(fftshift(fft2(img)))
    log_magnitude = np.log(magnitude + 1)
    return log_magnitude


##################################################################
# Radial frequency profile
##################################################################
def radial_profile_gray(mag):
    """
    Radial average of a magnitude spectrum (grayscale version).

    Pixels are binned by their truncated (integer) Euclidean distance from
    the centre pixel ``(H // 2, W // 2)``, and each bin is averaged.

    Parameters
    ----------
    mag : ndarray, shape (H, W)
        Centred magnitude spectrum, e.g. from ``fft_magnitude_gray``.

    Returns
    -------
    (radii, profile) : tuple of ndarray
        ``radii`` is ``0 .. max_radius``; ``profile[k]`` is the mean of
        ``mag`` over the pixels whose truncated distance equals ``k``.
    """
    H, W = mag.shape
    cy, cx = H // 2, W // 2

    y, x = np.ogrid[:H, :W]
    r = np.sqrt((x - cx)**2 + (y - cy)**2).astype(int).ravel()

    # One bincount pass instead of a full-image boolean mask per radius
    # (the old loop was O(max_radius * H * W)). Every integer radius up to
    # the maximum occurs, so no bin count is zero.
    sums = np.bincount(r, weights=mag.ravel())
    counts = np.bincount(r)
    radii = np.arange(sums.size)
    return radii, sums / counts


##################################################################
# Visualization: FFT + Differences (grayscale)
##################################################################
def show_fourier_gray(clean, perturbed, base_id, label):
    """
    Grayscale Fourier visualization of a clean/perturbed pair.

    Top row (left to right): clean image, perturbed image, FFT of the clean
    image, FFT of the perturbed image, and the signed difference of the two
    log-magnitude spectra (perturbed - clean). The bottom row holds the
    colour-bar legends.

    Both images are drawn on a fixed [0, 1] grey scale and both spectra on
    one shared ``magma`` scale, so brightness is comparable across panels.
    The ``coolwarm`` difference scale is symmetric around zero, so blue means
    clean > perturbed and red means perturbed > clean.

    Parameters
    ----------
    clean, perturbed : ndarray, shape (H, W)
        Greyscale images of identical shape in [0, 1] (from ``load_img_gray``).
    base_id : str
        Group id, used in the titles.
    label : str
        Perturbation name, used in the titles.
    """
    fft_clean = fft_magnitude_gray(clean)
    fft_pert  = fft_magnitude_gray(perturbed)
    fft_diff  = fft_pert - fft_clean

    # Shared limits so the two magnitude spectra use the same colour scale
    mag_min = min(float(fft_clean.min()), float(fft_pert.min()))
    mag_max = max(float(fft_clean.max()), float(fft_pert.max()))
    # Symmetric limits centre coolwarm's neutral colour on zero. The floor
    # avoids a degenerate vmin == vmax when the two images are identical.
    diff_lim = max(float(np.abs(fft_diff).max()), 1e-12)

    fig, axes = plt.subplots(
        2, 5, figsize=(27,10),
        gridspec_kw={'height_ratios':[12, 1]}
    )

    # Clean image (fixed [0, 1] range instead of per-image autoscaling)
    axes[0,0].imshow(clean, cmap="gray", vmin=0, vmax=1)
    axes[0,0].set_title(f"{base_id}\nClean (Grayscale)")
    axes[0,0].axis("off")

    # Perturbed image
    axes[0,1].imshow(perturbed, cmap="gray", vmin=0, vmax=1)
    axes[0,1].set_title(f"{label}\nPerturbed (Grayscale)")
    axes[0,1].axis("off")

    # FFT Clean
    im1 = axes[0,2].imshow(fft_clean, cmap="magma", vmin=mag_min, vmax=mag_max)
    axes[0,2].set_title("FFT — Clean")
    axes[0,2].axis("off")

    # FFT Perturbed (same colour scale as FFT Clean)
    axes[0,3].imshow(fft_pert, cmap="magma", vmin=mag_min, vmax=mag_max)
    axes[0,3].set_title("FFT — Perturbed")
    axes[0,3].axis("off")

    # FFT Difference
    im2 = axes[0,4].imshow(fft_diff, cmap="coolwarm", vmin=-diff_lim, vmax=diff_lim)
    axes[0,4].set_title(f"FFT Difference\n{base_id} — {label}")
    axes[0,4].axis("off")

    # Legends (colorbars); the magma bar applies to both spectra
    cbar1 = fig.colorbar(im1, cax=axes[1,2], orientation="horizontal")
    cbar1.set_label("FFT Magnitude (dark = low, bright = high)")

    cbar2 = fig.colorbar(im2, cax=axes[1,4], orientation="horizontal")
    cbar2.set_label("FFT Difference (blue = clean>perturbed, red = perturbed>clean)")

    # Hide the legend-row cells that carry no colour bar
    for ax in (axes[1,0], axes[1,1], axes[1,3]):
        ax.axis("off")

    plt.tight_layout()
    plt.show()

##################################################################
# Visualization: Radial Spectrum (grayscale)
##################################################################
def show_radial_gray(clean, perturbed, base_id, label):
    """
    Overlay the greyscale radial spectra of a clean and a perturbed image
    (each via ``radial_profile_gray(fft_magnitude_gray(...))``).
    """
    series = [
        ("Clean (Grayscale)", clean),
        (label, perturbed),
    ]

    plt.figure(figsize=(8,5))
    for legend_text, picture in series:
        radii, profile = radial_profile_gray(fft_magnitude_gray(picture))
        plt.plot(radii, profile, label=legend_text, linewidth=2)

    plt.title(f"Radial Frequency Spectrum — {base_id}")
    plt.xlabel("Radius (Frequency)")
    plt.ylabel("Magnitude")
    plt.grid(True)
    plt.legend()
    plt.show()

In [None]:
import time
from datetime import datetime, timedelta

all_bases = sorted(df["base"].unique())
total_groups = len(all_bases)

# Every group must provide all four variants to be analysed
required = ["clean", "glazed", "shaded", "glazed_shaded"]

start_time = time.time()
start_dt = datetime.now()

print("Starting GREYSCALE Fourier analysis.")
print(f"Found {total_groups} image groups.\n")

print(f"Start time: {start_dt.strftime('%Y-%m-%d %H:%M:%S')} ({start_dt.strftime('%I:%M:%S %p')})")

# Estimate time
estimated_seconds_per_group = 2.5
estimated_total = estimated_seconds_per_group * total_groups
estimated_finish = start_dt + timedelta(seconds=estimated_total)

print(f"Approximate time required: {estimated_total:.1f} seconds")
print(f"Estimated finish time: {estimated_finish.strftime('%Y-%m-%d %H:%M:%S')} ({estimated_finish.strftime('%I:%M:%S %p')})\n")

group_counter = 0
processed_groups = 0
skipped_groups = []

for base_id in all_bases:
    group_counter += 1
    print(f"\n===== Processing (Grayscale) '{base_id}' ({group_counter}/{total_groups}) =====")

    paths = get_paths_for_base(base_id)

    if not all(k in paths for k in required):
        print(f"[WARNING] Missing variants for {base_id}, skipping.")
        skipped_groups.append(base_id)
        continue

    # Load as GREYSCALE
    clean = load_img_gray(paths["clean"])
    glaze = load_img_gray(paths["glazed"])
    shade = load_img_gray(paths["shaded"])
    both  = load_img_gray(paths["glazed_shaded"])

    # The FFT difference maps subtract spectra element-wise, which needs
    # identical image sizes. Skip the group instead of aborting the run.
    mismatched = [
        name for name, img in (("glazed", glaze), ("shaded", shade), ("glazed_shaded", both))
        if img.shape != clean.shape
    ]
    if mismatched:
        print(f"[WARNING] Size mismatch vs clean for {base_id} ({', '.join(mismatched)}), skipping.")
        skipped_groups.append(base_id)
        continue

    # Fourier visualizations
    show_fourier_gray(clean, glaze, base_id, "Glaze")
    show_fourier_gray(clean, shade, base_id, "Nightshade")
    show_fourier_gray(clean, both,  base_id, "Glaze+Nightshade")

    # Radial frequency spectra
    show_radial_gray(clean, glaze, base_id, "Glaze")
    show_radial_gray(clean, shade, base_id, "Nightshade")
    show_radial_gray(clean, both,  base_id, "Glaze+Nightshade")

    processed_groups += 1
    print(f"[COMPLETED] {base_id}")

end_time = time.time()
end_dt = datetime.now()
elapsed = end_time - start_time

print("\n[COMPLETED] Grayscale Fourier analysis finished.\n")
print(f"Processed {processed_groups}/{total_groups} groups; skipped {len(skipped_groups)}: {skipped_groups}\n")
print(f"Start time: {start_dt.strftime('%Y-%m-%d %H:%M:%S')} ({start_dt.strftime('%I:%M:%S %p')})")
print(f"End time:   {end_dt.strftime('%Y-%m-%d %H:%M:%S')} ({end_dt.strftime('%I:%M:%S %p')})")
print(f"Total elapsed time: {elapsed:.2f} seconds\n")


## --