In [None]:
# === Config ===
# Toggle for the (FFT-based) attention maps: caching, dataset loading and SiameseNet.forward
# all read this flag. Use a real bool so it matches the "(USE_DCT_ATTENTION = False)" messages.
USE_DCT_ATTENTION = False  # 🔄 flip to True later if you want attention
IMG_SIZE = 224             # square side length used when resizing input images

In [None]:
import os
import glob

"""
the function below just loads picks up the right paths from the test or the validation directory and
arranges them in a dictionary
The directory structure is that say we pickup a random directory X... under it we will have one or multiple clean images 
and under sub-dir named distortion under which all the distorted images are present
"""

def collect_image_paths(root_dir):
    person_dict = {}

    for person_folder in sorted(os.listdir(root_dir)):
        folder_path = os.path.join(root_dir, person_folder)
        if not os.path.isdir(folder_path):
            continue

        # Undistorted images (exclude anything in subfolders)
        all_jpgs = glob.glob(os.path.join(folder_path, "*.jpg"))
        clean_imgs = [f for f in all_jpgs if "distortion" not in f]

        # Distorted images
        distortion_dir = os.path.join(folder_path, "distortion")
        distortion_imgs = []
        if os.path.exists(distortion_dir):
            distortion_imgs = glob.glob(os.path.join(distortion_dir, "*.jpg"))

        if clean_imgs:
            person_dict[person_folder] = {
                "clean": clean_imgs,  # store list, not just one
                "distorted": distortion_imgs
            }

    return person_dict

In [None]:
# Index the training split: {person_id: {"clean": [...], "distorted": [...]}}
train_dir = os.path.join("/kaggle/input/facecom/Comys_Hackathon5/Task_B", "train")
person_dict = collect_image_paths(train_dir)

In [None]:
# Index the validation split with the same layout as training.
validation_dir = os.path.join("/kaggle/input/facecom/Comys_Hackathon5/Task_B", "val")
val_dict = collect_image_paths(validation_dir)

In [None]:
import random

def generate_balanced_augmented_pairs(person_dict, min_pos_per_id=28, num_neg_per_pos=3, seed=42):
    """
    Generate positive and negative image pairs for face verification.

    - For every identity with at least two images (clean + distorted), sample up to
      ``min_pos_per_id`` positive pairs from all 2-combinations of its images.
      Despite the name this is a per-identity *cap*: identities with fewer possible
      pairs contribute all of them.
    - For each positive pair, draw ``num_neg_per_pos`` other identities and pair the
      positive's first image with one image of each. A distorted image is used when
      that identity has any, otherwise a clean one.

    Sampling uses a private ``random.Random(seed)``, so the global ``random`` state is
    not touched. It draws the same sequence as seeding the global generator would.

    Args:
        person_dict (dict): person ID -> dict with 'clean' (list of paths or a single
            path) and 'distorted' (list of paths).
        min_pos_per_id (int): Maximum number of positive pairs sampled per identity.
        num_neg_per_pos (int): Number of negative pairs to generate per positive pair.
        seed (int): Random seed for reproducibility.

    Returns:
        Tuple ``(all_pairs, positive_pairs, negative_pairs)``. Each is a list of
        ``(img1_path, img2_path, label)`` tuples: label 1 = same identity, 0 = different.
        ``all_pairs`` is the shuffled concatenation of the other two lists.
    """
    rng = random.Random(seed)
    all_ids = list(person_dict.keys())
    positive_pairs = []
    negative_pairs = []

    def _as_list(value):
        # 'clean' may be stored as a list of paths or as a single (possibly empty) path.
        if isinstance(value, list):
            return list(value)
        return [value] if value else []

    for person_id in all_ids:
        # Gather all available images for positive pairing
        images = _as_list(person_dict[person_id]['clean']) + list(person_dict[person_id]['distorted'])

        # Ensure at least 2 images to form a pair
        if len(images) < 2:
            continue

        # All unordered pairs of distinct images of this identity
        all_pos_pairs = [(a, b) for idx, a in enumerate(images) for b in images[idx + 1:] if a != b]
        selected_pos_pairs = rng.sample(all_pos_pairs, min(min_pos_per_id, len(all_pos_pairs)))

        # Same for every positive of this identity, so compute it once.
        other_ids = [pid for pid in all_ids if pid != person_id]

        for img1, img2 in selected_pos_pairs:
            positive_pairs.append((img1, img2, 1))

            sampled_neg_ids = rng.sample(other_ids, min(num_neg_per_pos, len(other_ids)))
            for neg_id in sampled_neg_ids:
                # Fall back to the clean images as a flat list of paths. The previous
                # `[clean_list]` nested the list, so random.choice returned a list.
                neg_candidates = list(person_dict[neg_id]['distorted']) or _as_list(person_dict[neg_id]['clean'])
                if not neg_candidates:
                    continue
                neg_img = rng.choice(neg_candidates)
                # Use img1 from the positive pair to form the negative pair
                negative_pairs.append((img1, neg_img, 0))

    # Combine and shuffle
    all_pairs = positive_pairs + negative_pairs
    rng.shuffle(all_pairs)
    return all_pairs, positive_pairs, negative_pairs

In [None]:
# Training pairs: at most 28 positives per identity and one negative per positive.
all_pairs, positive_pairs, negative_pairs = generate_balanced_augmented_pairs(
    person_dict, min_pos_per_id=28, num_neg_per_pos=1,
)

In [None]:
# Validation pairs: at most 20 positives per identity and one negative per positive.
all_val_pairs, val_pos, val_neg = generate_balanced_augmented_pairs(
    val_dict, min_pos_per_id=20, num_neg_per_pos=1
)

In [None]:
# Step 3: Stats
print(f"✅ Total pairs: {len(all_pairs)}")
print(f"🔵 Positive pairs: {len(positive_pairs)}")
print(f"🔴 Negative pairs: {len(negative_pairs)}")
print("🧾 Sample pairs:", all_pairs[:3])

# Step 4: Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

pair_stats = {"identity": [], "label": []}
for img1, img2, label in all_pairs:
    # Distorted images live in <person>/distortion/, so step one more level up for them.
    # Otherwise every distorted anchor was recorded with identity "distortion".
    anchor_dir = os.path.dirname(img1)
    if os.path.basename(anchor_dir) == "distortion":
        anchor_dir = os.path.dirname(anchor_dir)
    pair_stats["identity"].append(os.path.basename(anchor_dir))
    pair_stats["label"].append(label)

df_pairs = pd.DataFrame(pair_stats)

plt.figure(figsize=(10, 5))
# Pin the category order so the "Negative"/"Positive" tick labels stay correct
# even if one class happens to be absent.
sns.countplot(data=df_pairs, x="label", order=[0, 1], palette="Set2")
plt.xticks([0, 1], ["Negative", "Positive"])
plt.title("Distribution of Positive vs Negative Pairs")
plt.xlabel("Pair Type")
plt.ylabel("Count")
plt.grid(True, axis='y')
plt.show()

In [None]:
# Step 3: Stats (validation split)
print(f"✅ Total pairs: {len(all_val_pairs)}")
print(f"🔵 Positive pairs: {len(val_pos)}")
print(f"🔴 Negative pairs: {len(val_neg)}")
# Show validation samples (this previously printed the *training* pairs by mistake).
print("🧾 Sample pairs:", all_val_pairs[:3])

# Step 4: Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

pair_stats = {"identity": [], "label": []}
for img1, img2, label in all_val_pairs:
    # Distorted images live in <person>/distortion/, so step one more level up for them.
    anchor_dir = os.path.dirname(img1)
    if os.path.basename(anchor_dir) == "distortion":
        anchor_dir = os.path.dirname(anchor_dir)
    pair_stats["identity"].append(os.path.basename(anchor_dir))
    pair_stats["label"].append(label)

df_pairs = pd.DataFrame(pair_stats)

plt.figure(figsize=(10, 5))
# Pin the category order so the tick labels below stay correct if a class is absent.
sns.countplot(data=df_pairs, x="label", order=[0, 1], palette="Set2")
plt.xticks([0, 1], ["Negative", "Positive"])
plt.title("Distribution of Positive vs Negative Pairs")
plt.xlabel("Pair Type")
plt.ylabel("Count")
plt.grid(True, axis='y')
plt.show()

In [None]:
import torch
import torchvision.transforms as T
from PIL import Image

# Grayscale pipeline for the FFT attention inputs: PIL image -> [1, IMG_SIZE, IMG_SIZE].
# Tied to IMG_SIZE (instead of a literal 224) so the attention maps always match the
# resolution of the images they are multiplied with.
to_tensor = T.Compose([
    T.Grayscale(),
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
])

def compute_fft_attention_batch(batch1, batch2):
    """Per-pair attention maps from the spectral difference of two image batches.

    Args:
        batch1, batch2: real tensors of shape [B, H, W] (the per-sample reductions
            below run over dims 1 and 2).

    Returns:
        Tensor [B, 1, H, W]: the real part of ``ifft2(|fft2(batch1) - fft2(batch2)|)``,
        min-max normalised per sample to [0, 1] and then inverted (``1 - x``).
        Only the magnitude of the difference is used, so swapping the batches
        yields the same map.
    """
    spectral_gap = (torch.fft.fft2(batch1) - torch.fft.fft2(batch2)).abs()
    response = torch.fft.ifft2(spectral_gap).real                 # [B, H, W]
    shifted = response - response.amin(dim=(1, 2), keepdim=True)
    scaled = shifted / (shifted.amax(dim=(1, 2), keepdim=True) + 1e-8)
    return (1.0 - scaled).unsqueeze(1)                            # [B, 1, H, W]

In [None]:
import os, hashlib
from torchvision.transforms import Compose, Grayscale, Resize, ToTensor
from PIL import Image
from tqdm.notebook import tqdm
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Attention maps must share the image resolution, so derive it from IMG_SIZE
# rather than repeating the literal 224.
resize = Resize((IMG_SIZE, IMG_SIZE))
attention_root = "/kaggle/working/fft_attention_maps"  # cache directory for _cache_key / batch_cache_attention
os.makedirs(attention_root, exist_ok=True)

# PIL image -> grayscale tensor [1, IMG_SIZE, IMG_SIZE] fed to the FFT attention.
to_tensor = Compose([Grayscale(), resize, ToTensor()])

In [None]:
def _cache_key(path1, path2):
    """Return the cache file path for the attention map of an image pair.

    The file name is the MD5 hex digest of the two absolute paths, sorted and
    joined with '|', so (a, b) and (b, a) map to the same ``.pt`` file under
    ``attention_root``.
    """
    first, second = sorted((os.path.abspath(path1), os.path.abspath(path2)))
    digest = hashlib.md5("|".join((first, second)).encode()).hexdigest()
    return os.path.join(attention_root, digest + ".pt")

In [None]:
def batch_cache_attention(pairs, batch_size=32):
    """Compute and cache FFT attention maps for every pair not already on disk.

    Args:
        pairs: iterable of ``(path1, path2, label)``; the label is ignored.
        batch_size: number of pairs sent through ``compute_fft_attention_batch`` at once.

    Each map is saved as a float16 ``[1, H, W]`` tensor at ``_cache_key(path1, path2)``.
    Pairs whose images fail to load are reported and skipped, and no file is written
    for them. Loading them later therefore fails loudly instead of silently reading
    another pair's map.
    """
    # cache path -> pair. The key is order-independent, so repeated pairs and
    # (a, b)/(b, a) duplicates are computed only once.
    pending = {}
    for p1, p2, _ in pairs:
        out_path = _cache_key(p1, p2)
        if out_path not in pending and not os.path.exists(out_path):
            pending[out_path] = (p1, p2)

    if not pending:
        print("✅ All attention maps already cached.")
        return

    jobs = list(pending.items())
    for i in tqdm(range(0, len(jobs), batch_size), desc="⚡ Batch-caching attention maps"):
        imgs1, imgs2, out_paths = [], [], []

        for out_path, (p1, p2) in jobs[i:i + batch_size]:
            try:
                with Image.open(p1) as im1, Image.open(p2) as im2:
                    first = to_tensor(im1.convert("RGB"))
                    second = to_tensor(im2.convert("RGB"))
            except Exception as e:
                print(f"❌ Failed to load: {p1} or {p2} — {e}")
                continue
            # Append only after BOTH images loaded, so the three lists stay index-aligned.
            # Previously a failure on p2 left imgs1 one element longer, and any skipped
            # pair shifted later maps onto the wrong cache files.
            imgs1.append(first)
            imgs2.append(second)
            out_paths.append(out_path)

        if not out_paths:
            continue

        t1 = torch.stack(imgs1).squeeze(1).to(device)  # [B, H, W]
        t2 = torch.stack(imgs2).squeeze(1).to(device)

        attn_batch = compute_fft_attention_batch(t1, t2).cpu()  # [B, 1, H, W]

        for attn_map, out_path in zip(attn_batch, out_paths):
            torch.save(attn_map.half(), out_path, pickle_protocol=4)

In [None]:
# Assuming: all_to_cache = positive_pairs + negative_pairs
if USE_DCT_ATTENTION:
    batch_cache_attention(all_pairs, batch_size=32)
else:
    print("⚡ Skipping FFT-attention caching (USE_DCT_ATTENTION = False)")

In [None]:
# Assuming: all_to_cache = positive_pairs + negative_pairs
if USE_DCT_ATTENTION:
    batch_cache_attention(all_val_pairs, batch_size=32)
else:
    print("⚡ Skipping FFT-attention caching (USE_DCT_ATTENTION = False)")

In [None]:
import os
import hashlib

def reverse_lookup_path(attn_path, all_pairs):
    """Find the image pair whose cache key produced ``attn_path``.

    Recomputes the MD5 of each pair's sorted absolute paths (the ``_cache_key``
    scheme) and returns the first matching ``(path1, path2)`` in ``all_pairs``
    order, or ``(None, None)`` when nothing matches. Linear scan over all pairs.
    """
    wanted = os.path.basename(attn_path).replace(".pt", "")

    def _digest(first, second):
        lo, hi = sorted([os.path.abspath(first), os.path.abspath(second)])
        return hashlib.md5(f"{lo}|{hi}".encode()).hexdigest()

    matches = ((p1, p2) for p1, p2, _ in all_pairs if _digest(p1, p2) == wanted)
    return next(matches, (None, None))

from PIL import Image
import numpy as np
from scipy.fftpack import dct, idct
import torch

def apply_2d_dct(img):
    """Orthonormal DCT-II applied along axis 0 and then along the last axis."""
    along_rows = dct(img, axis=0, norm='ortho')
    return dct(along_rows, axis=-1, norm='ortho')

def apply_2d_idct(coeffs):
    """Inverse of ``apply_2d_dct``: orthonormal IDCT along axis 0, then the last axis."""
    along_rows = idct(coeffs, axis=0, norm='ortho')
    return idct(along_rows, axis=-1, norm='ortho')

def compute_dct_attention(img1, img2):
    """DCT-domain counterpart of the FFT attention map for a single image pair.

    Both PIL images are converted to grayscale float32 arrays. Their shapes must be
    broadcast-compatible, i.e. in practice the same size. The absolute difference of
    their 2-D orthonormal DCTs is mapped back with the inverse DCT. The result is then
    min-max normalised to [0, 1] and inverted (``1 - x``).

    Returns:
        float32 ``torch.Tensor`` with the grayscale image's (H, W) shape.
    """
    gray_a, gray_b = (np.array(im.convert("L"), dtype=np.float32) for im in (img1, img2))
    spectral_gap = np.abs(apply_2d_dct(gray_a) - apply_2d_dct(gray_b))
    raw = apply_2d_idct(spectral_gap)
    shifted = raw - raw.min()
    scaled = shifted / (shifted.max() + 1e-8)
    return torch.tensor(1.0 - scaled).float()

In [None]:
import matplotlib.pyplot as plt
from torchvision.transforms import Resize

# Sanity check: compare one cached FFT attention map with a DCT-based recomputation.
# `attn_path` used to be referenced without ever being defined (NameError on a fresh
# kernel). Pick the first train/val pair whose map actually exists on disk instead.
attn_path, img1_path, img2_path = None, None, None
for p1, p2, _ in all_pairs + all_val_pairs:
    candidate = _cache_key(p1, p2)
    if os.path.exists(candidate):
        attn_path, img1_path, img2_path = candidate, p1, p2
        break

if attn_path is None:
    print("⚡ No cached attention maps found — skipping sanity check "
          "(run the caching cells with USE_DCT_ATTENTION = True first).")
else:
    # The cache holds plain tensors (saved as float16), so the restricted loader is enough.
    # .float() keeps matplotlib away from float16 image data.
    attn_cached = torch.load(attn_path, weights_only=True).float().squeeze()

    # The cached map was computed on IMG_SIZE x IMG_SIZE grayscale inputs. Resize the
    # originals the same way so both maps are on the same pixel grid.
    with Image.open(img1_path) as im1, Image.open(img2_path) as im2:
        img1 = im1.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
        img2 = im2.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    attn_dct = compute_dct_attention(img1, img2)

    # Resize to match if needed
    if attn_cached.shape != attn_dct.shape:
        attn_dct = Resize(tuple(attn_cached.shape))(attn_dct.unsqueeze(0)).squeeze(0)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = zip(axes, (attn_cached, attn_dct), ("Cached Attention Map", "Recomputed DCT Attention"))
    for ax, attn_map, title in panels:
        ax.imshow(attn_map.cpu().numpy(), cmap='viridis')
        ax.set_title(title)
        ax.axis("off")

    fig.suptitle("Sanity Check: Cached vs DCT Recomputed", fontsize=14)
    plt.show()

In [None]:
import os, hashlib
from torchvision.transforms import Compose, Grayscale, Resize, ToTensor
from PIL import Image
from tqdm.notebook import tqdm
import torch

# Repeats the earlier attention/cache setup cell with the same values.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Derive the attention-map resolution from IMG_SIZE so it cannot drift from the images.
resize = Resize((IMG_SIZE, IMG_SIZE))
attention_root = "/kaggle/working/fft_attention_maps"  # cache directory used by _cache_key
os.makedirs(attention_root, exist_ok=True)

# PIL image -> grayscale tensor [1, IMG_SIZE, IMG_SIZE] fed to the FFT attention.
to_tensor = Compose([Grayscale(), resize, ToTensor()])

In [None]:
from torch.utils.data import Dataset
class FacePairDataset(Dataset):
    """Dataset of ``(path1, path2, label)`` face-verification pairs.

    Each item is a dict containing:
      - ``img1`` / ``img2``: image tensors produced by ``transform`` (default: resize
        to IMG_SIZE x IMG_SIZE, then ToTensor).
      - ``label``: float32 scalar tensor (1.0 = same identity, 0.0 = different).
      - ``attn`` (only while the global ``USE_DCT_ATTENTION`` is truthy): the cached
        attention map for the pair, as float32, loaded from ``_cache_key(path1, path2)``.
    """

    def __init__(self, pairs, transform=None):
        self.pairs = pairs
        self.transform = transform or T.Compose([
            T.Resize((IMG_SIZE, IMG_SIZE)), T.ToTensor()
        ])

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _load_rgb(path):
        # convert() returns a fully loaded copy, so the file handle can be closed
        # right away. This matters with multi-worker DataLoaders.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        p1, p2, label = self.pairs[idx]
        img1_t = self.transform(self._load_rgb(p1))
        img2_t = self.transform(self._load_rgb(p2))
        label_t = torch.tensor(label, dtype=torch.float32)

        if not USE_DCT_ATTENTION:
            return {"img1": img1_t, "img2": img2_t, "label": label_t}

        attn_path = _cache_key(p1, p2)
        if not os.path.exists(attn_path):
            raise FileNotFoundError(
                f"Missing cached attention map for pair ({p1}, {p2}): {attn_path}. "
                "Run batch_cache_attention on these pairs first."
            )
        # The cache stores a bare tensor, so weights_only=True avoids unpickling
        # arbitrary objects. .float() matches the float32 image tensors.
        attn = torch.load(attn_path, weights_only=True).float()
        return {"img1": img1_t, "img2": img2_t, "attn": attn, "label": label_t}

In [None]:
train_dataset = FacePairDataset(all_pairs)
print("Total samples:", len(train_dataset))

# Peek at the first training pair to sanity-check tensor shapes.
sample = train_dataset[0]
img1, img2, label = sample["img1"], sample["img2"], sample["label"]
attn = sample.get("attn")  # only present while USE_DCT_ATTENTION is on

preview_lines = [
    f"Image 1 shape       : {img1.shape}",
    f"Image 2 shape       : {img2.shape}",
    f"Attention map shape : {attn.shape}" if attn is not None
    else "Attention map       : ❌ Not used (USE_DCT_ATTENTION = False)",
    f"Label               : {label}",
]
print("\n".join(preview_lines))

In [None]:
val_dataset = FacePairDataset(all_val_pairs)
print("Total samples:", len(val_dataset))

# Peek at the first validation pair to sanity-check tensor shapes.
sample = val_dataset[0]
img1, img2, label = sample["img1"], sample["img2"], sample["label"]
attn = sample.get("attn")  # only present while USE_DCT_ATTENTION is on

preview_lines = [
    f"Image 1 shape       : {img1.shape}",
    f"Image 2 shape       : {img2.shape}",
    f"Attention map shape : {attn.shape}" if attn is not None
    else "Attention map       : ❌ Not used (USE_DCT_ATTENTION = False)",
    f"Label               : {label}",
]
print("\n".join(preview_lines))

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# SiameseNet.forward reads the global USE_DCT_ATTENTION (set in the config cell) at call time.

class SiameseNet(nn.Module):
    """Siamese verification network with one shared torchvision backbone.

    Both images pass through the same backbone (its final classifier layer removed).
    The two pooled embeddings are concatenated, and an MLP head outputs one logit
    per pair; the training code applies sigmoid > 0.5 as "same identity".

    Args:
        backbone: name of a ``torchvision.models`` constructor (ResNet family expected).
        pretrained: forwarded to the constructor. torchvision ≥ 0.13 deprecates this
            keyword in favour of ``weights=``; it is kept for compatibility.
    """

    def __init__(self, backbone="resnet18", pretrained=True):
        super().__init__()

        # backbone (easily switchable)
        base = getattr(models, backbone)(pretrained=pretrained)
        self.feature_extractor = nn.Sequential(*list(base.children())[:-1])  # -> [B, C, 1, 1] for ResNets

        # Embedding width comes from the backbone's classifier: 512 for resnet18/34,
        # 2048 for resnet50+. Previously 512 was hard-coded, which broke larger
        # backbones. Falls back to 512 when there is no `fc` layer.
        feat_dim = getattr(getattr(base, "fc", None), "in_features", 512)

        # head
        self.fc = nn.Sequential(
            nn.Linear(feat_dim * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1)
        )

    def forward(self, img1, img2, attn_map=None):
        """
        img1, img2: [B, 3, H, W]
        attn_map   : [B, 1, H, W] or None; only applied while USE_DCT_ATTENTION is truthy.
        Returns    : [B, 1] logits.
        """
        if USE_DCT_ATTENTION and attn_map is not None:
            attn_map = attn_map.expand(-1, 3, -1, -1)  # broadcast to RGB
            img1, img2 = img1 * attn_map, img2 * attn_map

        f1 = torch.flatten(self.feature_extractor(img1), 1)  # [B, feat_dim]
        f2 = torch.flatten(self.feature_extractor(img2), 1)  # [B, feat_dim]

        return self.fc(torch.cat([f1, f2], dim=1))  # [B, 1]  (logits)

In [None]:
from torch.cuda.amp import autocast, GradScaler  # ✅ modern AMP usage
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.utils.data import DataLoader
import torch.nn as nn
import pandas as pd
import torch
import os

def _unpack_siamese_batch(batch, device, use_dct_attention):
    """Move one collated FacePairDataset batch to ``device``; attn is None when absent or disabled."""
    img1 = batch["img1"].to(device, non_blocking=True)
    img2 = batch["img2"].to(device, non_blocking=True)
    labels = batch["label"].to(device, non_blocking=True)
    attn = batch["attn"].to(device, non_blocking=True) if use_dct_attention and "attn" in batch else None
    return img1, img2, labels, attn


def _evaluate_siamese(model, loader, criterion, device, use_dct_attention):
    """Run one validation pass and return (loss, acc, precision, recall, f1)."""
    model.eval()
    val_loss, val_correct, total_val = 0.0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in loader:
            img1, img2, labels, attn = _unpack_siamese_batch(batch, device, use_dct_attention)

            outputs = model(img1, img2, attn).squeeze(1)
            loss = criterion(outputs, labels)

            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = (probs > 0.5).astype(float)
            labels_np = labels.cpu().numpy()

            all_preds.extend(preds)
            all_labels.extend(labels_np)

            val_correct += (preds == labels_np).sum()
            val_loss += loss.item() * labels.size(0)
            total_val += labels.size(0)

    if total_val == 0:
        raise ValueError("Validation dataset produced no batches.")

    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    return val_loss / total_val, val_correct / total_val, precision, recall, f1


def train_siamese_model_amp(
    model,
    train_dataset,
    val_dataset,
    epochs=20,
    batch_size=32,
    lr=1e-4,
    save_dir='/kaggle/working/',
    use_amp=True,
    patience=5,
    use_dct_attention=True
):
    """Train a Siamese verifier with BCE-with-logits loss, optional AMP and early stopping.

    Args:
        model: module called as ``model(img1, img2, attn)`` returning [B, 1] logits.
        train_dataset, val_dataset: datasets yielding dicts with ``img1``, ``img2``,
            ``label`` and optionally ``attn`` (e.g. ``FacePairDataset``).
        epochs: maximum number of epochs.
        batch_size: DataLoader batch size for both splits.
        lr: Adam learning rate. It follows cosine warm restarts (T_0=5, T_mult=2),
            stepped once per epoch.
        save_dir: directory for the best checkpoint and ``training_metrics_cycle1.csv``.
            It is created if missing.
        use_amp: enable mixed precision; only honoured when running on CUDA.
        patience: epochs without a validation-accuracy improvement before stopping.
        use_dct_attention: pass the batch's ``attn`` map to the model when present.

    Returns:
        pandas.DataFrame with one row of metrics per completed epoch, also saved as CSV.

    Raises:
        ValueError: if the training or validation dataset yields no batches.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    os.makedirs(save_dir, exist_ok=True)

    # Mixed precision (and pinned memory) only help on CUDA; on CPU run plain fp32.
    on_cuda = device.type == "cuda"
    amp_enabled = bool(use_amp) and on_cuda

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=on_cuda)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=on_cuda)
    if len(train_loader) == 0:
        raise ValueError("Training dataset produced no batches.")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
    criterion = nn.BCEWithLogitsLoss()
    # A disabled GradScaler is a pass-through (scale -> identity, step -> optimizer.step),
    # so one code path covers both AMP and fp32 training.
    scaler = GradScaler(enabled=amp_enabled)

    history = {
        "epoch": [], "train_loss": [], "train_acc": [],
        "val_loss": [], "val_acc": [],
        "precision": [], "recall": [], "f1": []
    }

    best_val_acc = 0.0
    bad_epochs = 0
    best_model_path = None

    for epoch in range(epochs):
        model.train()
        train_loss, train_correct, total_train = 0.0, 0, 0

        print(f"\n📚 Epoch {epoch+1}/{epochs}")
        loop = tqdm(train_loader, desc="🔁 Training", leave=False)

        for batch in loop:
            img1, img2, labels, attn = _unpack_siamese_batch(batch, device, use_dct_attention)

            optimizer.zero_grad()

            # torch.cuda.amp.autocast does not accept `device_type` (it raised TypeError);
            # the device-generic torch.autocast does.
            with torch.autocast(device_type=device.type, enabled=amp_enabled):
                outputs = model(img1, img2, attn).squeeze(1)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            preds = (torch.sigmoid(outputs) > 0.5).float()
            train_correct += (preds == labels).sum().item()
            train_loss += loss.item() * labels.size(0)
            total_train += labels.size(0)

            loop.set_postfix(loss=loss.item(), acc=train_correct / total_train)

        train_loss = train_loss / total_train
        train_acc = train_correct / total_train
        scheduler.step()

        # --- Validation ---
        val_loss, val_acc, precision, recall, f1 = _evaluate_siamese(
            model, val_loader, criterion, device, use_dct_attention
        )

        print(f"📣 Epoch {epoch+1}: "
              f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f} | "
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

        history["epoch"].append(epoch + 1)
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["precision"].append(precision)
        history["recall"].append(recall)
        history["f1"].append(f1)

        if val_acc > best_val_acc:
            # Keep only the single best checkpoint on disk.
            if best_model_path and os.path.exists(best_model_path):
                os.remove(best_model_path)
            best_val_acc = val_acc
            bad_epochs = 0
            model_name = f"siamese_best_epoch{epoch+1}_acc{val_acc:.4f}.pt"
            best_model_path = os.path.join(save_dir, model_name)
            torch.save(model.state_dict(), best_model_path)
            print(f"✅ New best model saved: {model_name}")
        else:
            bad_epochs += 1
            print(f"⚠️ No improvement. Bad epochs: {bad_epochs}/{patience}")
            if bad_epochs >= patience:
                print("🛑 Early stopping triggered.")
                break

    df = pd.DataFrame(history)
    df.to_csv(os.path.join(save_dir, "training_metrics_cycle1.csv"), index=False)
    print("📊 Training history saved to training_metrics_cycle1.csv")
    return df


In [None]:
# Build a fresh model before training. `model` was previously never defined in this
# notebook, so this cell raised NameError on a fresh kernel.
torch.manual_seed(42)  # reproducible initialisation of the (non-pretrained) head
model = SiameseNet(backbone="resnet18", pretrained=True)

train_siamese_model_amp(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    epochs=25,
    batch_size=64,
    save_dir='/kaggle/working/',
    use_amp=True,
    patience=5,
    # Keep in sync with the config flag: the dataset only yields "attn" when it is on.
    use_dct_attention=bool(USE_DCT_ATTENTION),
)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Load the per-epoch metrics written by train_siamese_model_amp.
df = pd.read_csv('/kaggle/working/training_metrics_cycle1.csv')

# One entry per subplot: (title, y-axis label, [(column, legend label, extra plot kwargs), ...])
panels = [
    ("Loss per Epoch", "Loss",
     [("train_loss", "Train Loss", {}), ("val_loss", "Val Loss", {})]),
    ("Accuracy per Epoch", "Accuracy",
     [("train_acc", "Train Accuracy", {}), ("val_acc", "Val Accuracy", {})]),
    ("Precision & Recall per Epoch", "Score",
     [("precision", "Precision", {}), ("recall", "Recall", {})]),
    ("F1 Score per Epoch", "F1 Score",
     [("f1", "F1 Score", {"color": "green"})]),
]

plt.figure(figsize=(16, 10))
for position, (title, ylabel, series) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    for column, legend_label, extra in series:
        plt.plot(df['epoch'], df[column], label=legend_label, marker='o', **extra)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

# Model Testing

In [None]:
import os
import torch
import torchvision.transforms as T
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, f1_score
import glob
import random

# --- Constants ---
IMG_SIZE = 224
# NOTE(review): this overrides the config-cell flag for every later cell, and
# SiameseNet.forward reads it at call time. It must match how TEST_MODEL_PATH was
# trained (with or without attention) — confirm.
USE_DCT_ATTENTION = True  # True because you're using FFT attention
TEST_MODEL_PATH = "/kaggle/input/siamese_facecom/pytorch/default/1/siamese_best_epoch4_acc0.8825.pt"
TEST_DIR = "/kaggle/input/facecom/Comys_Hackathon5/Task_B/val"  # currently the validation split
BATCH_SIZE = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load model ---
test_model = SiameseNet(backbone="resnet18", pretrained=False)
# The checkpoint is a plain state_dict (saved via model.state_dict()). weights_only=True
# refuses arbitrary pickled objects from the externally supplied file.
test_model.load_state_dict(torch.load(TEST_MODEL_PATH, map_location=device, weights_only=True))
test_model.to(device).eval()

# --- Preprocessing ---
# RGB model input: [3, IMG_SIZE, IMG_SIZE] (same as FacePairDataset's default transform).
image_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor()
])

# Grayscale FFT-attention input: [1, IMG_SIZE, IMG_SIZE].
fft_tensor = T.Compose([
    T.Grayscale(),
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor()
])

def compute_fft_attention_batch(batch1, batch2):
    """Per-pair attention maps from the spectral difference of two image batches.

    Args:
        batch1, batch2: real tensors of shape [B, H, W] (the per-sample reductions
            below run over dims 1 and 2).

    Returns:
        Tensor [B, 1, H, W]: the real part of ``ifft2(|fft2(batch1) - fft2(batch2)|)``,
        min-max normalised per sample to [0, 1] and then inverted (``1 - x``).
        Only the magnitude of the difference is used, so swapping the batches
        yields the same map.
    """
    spectral_gap = (torch.fft.fft2(batch1) - torch.fft.fft2(batch2)).abs()
    response = torch.fft.ifft2(spectral_gap).real                 # [B, H, W]
    shifted = response - response.amin(dim=(1, 2), keepdim=True)
    scaled = shifted / (shifted.amax(dim=(1, 2), keepdim=True) + 1e-8)
    return (1.0 - scaled).unsqueeze(1)                            # [B, 1, H, W]

# --- Load image from path ---
def load_image_tensor(path, for_fft=False):
    """Open ``path`` as RGB and return a tensor.

    With ``for_fft=True`` the grayscale ``fft_tensor`` pipeline is used, producing
    [1, H, W]. Otherwise ``image_transform`` produces the [3, H, W] model input.
    """
    transform = fft_tensor if for_fft else image_transform
    rgb = Image.open(path).convert("RGB")
    return transform(rgb)

In [None]:
import random
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, f1_score

def evaluate_few_shot_siamese(model, support_dict, batch_size=5, seed=None):
    """Few-shot identification accuracy of a Siamese verifier.

    One clean image per identity serves as the support example, and every distorted
    image is a query. Each query is scored against every support image and assigned
    the identity with the highest sigmoid score.

    Args:
        model: module called as ``model(support_img, query_img, attn)`` -> [B, 1] logits.
        support_dict: ``{class: {"clean": [paths] | path, "distorted": [paths]}}``,
            as returned by ``collect_image_paths``.
        batch_size: number of queries scored per forward pass.
        seed: optional seed for picking the support image when a class has several
            clean images. ``None`` keeps using the global ``random`` module.

    Returns:
        ``(accuracy, macro_f1)`` over all queries.

    Raises:
        ValueError: if there are no distorted query images.
    """
    chooser = random.Random(seed) if seed is not None else random
    support_classes = list(support_dict.keys())
    tqdm.write(f"🔢 Total support classes: {len(support_classes)}")

    # Step 1: Select 1 clean image per class
    support_images = []
    support_fft_imgs = []
    for c in support_classes:
        clean_list = support_dict[c]['clean']
        chosen_path = chooser.choice(clean_list) if isinstance(clean_list, list) else clean_list
        try:
            img = load_image_tensor(chosen_path).to(device)                    # [3, H, W]
            fft_img = load_image_tensor(chosen_path, for_fft=True).to(device)  # [1, H, W]
        except Exception as e:
            tqdm.write(f"❌ Failed to load support image for class {c}: {e}")
            raise
        support_images.append(img)
        support_fft_imgs.append(fft_img.squeeze(0) if fft_img.dim() == 3 else fft_img)

    # Step 2: Collect query images and labels
    query_paths = []
    query_labels = []
    for cls in support_classes:
        for q in support_dict[cls]['distorted']:
            query_paths.append(q)
            query_labels.append(cls)

    tqdm.write(f"🔍 Total query images: {len(query_paths)}")
    if not query_paths:
        raise ValueError("No distorted query images found; nothing to evaluate.")

    # The support set is the same for every query batch, so stack it once.
    support_imgs = torch.stack(support_images).to(device)    # [S, 3, H, W]
    support_ffts = torch.stack(support_fft_imgs).to(device)  # [S, H, W]
    S = len(support_classes)

    predicted_labels = []
    with torch.no_grad():
        for start in tqdm(range(0, len(query_paths), batch_size), desc="🧪 Inference"):
            batch_paths = query_paths[start:start + batch_size]

            try:
                query_imgs = torch.stack([load_image_tensor(p) for p in batch_paths]).to(device)  # [B, 3, H, W]
                query_fft_imgs = torch.stack([load_image_tensor(p, for_fft=True) for p in batch_paths]).to(device)
            except Exception as e:
                tqdm.write(f"❌ Error loading query batch: {e}")
                raise

            # Attention expects [B, H, W]; the grayscale transform yields [B, 1, H, W].
            if query_fft_imgs.dim() == 4 and query_fft_imgs.shape[1] == 1:
                query_fft_imgs = query_fft_imgs.squeeze(1)

            # Scoring now runs for every batch. It used to be indented under the `if`
            # above, so a batch that skipped the squeeze produced no predictions and
            # left predicted_labels misaligned with query_labels.
            B = query_imgs.shape[0]
            scores = []
            for i in range(S):
                support_img = support_imgs[i].unsqueeze(0).repeat(B, 1, 1, 1)  # [B, 3, H, W]
                support_fft = support_ffts[i].unsqueeze(0).repeat(B, 1, 1)     # [B, H, W]
                attn = compute_fft_attention_batch(support_fft, query_fft_imgs)  # [B, 1, H, W]
                logit = model(support_img, query_imgs, attn).squeeze(1)         # [B]
                scores.append(torch.sigmoid(logit))

            preds = torch.argmax(torch.stack(scores, dim=1), dim=1).cpu().tolist()  # [B, S] -> [B]
            predicted_labels.extend(support_classes[p] for p in preds)

    acc = accuracy_score(query_labels, predicted_labels)
    f1 = f1_score(query_labels, predicted_labels, average='macro')
    return acc, f1

In [None]:
# --- Run Evaluation ---
# TEST_DIR currently points at the validation split. Use a separate name so the
# earlier `val_dict` (used to build the validation pairs) is not silently overwritten.
test_dict = collect_image_paths(TEST_DIR)
# Use the configured BATCH_SIZE (the function default is 5). With the model in eval
# mode this should only affect throughput, not predictions (up to float noise).
acc, f1 = evaluate_few_shot_siamese(test_model, test_dict, batch_size=BATCH_SIZE)
print(f"✅ Few-shot Classification Accuracy: {acc:.4f}")
print(f"✅ Macro-Averaged F1 Score: {f1:.4f}")