In [None]:
# === Config ===
USE_DCT_ATTENTION = True   # Master switch for attention maps: when True they are cached, loaded by the dataset and applied in the model. NOTE(review): despite the name, the cached maps are computed with an FFT.
IMG_SIZE = 224             # Side length in pixels used by FacePairDataset's default transform. NOTE(review): the attention-map pipeline hard-codes 224, so keep both in sync.

In [None]:
import os
import glob

"""
the function below just loads picks up the right paths from the test or the validation directory and
arranges them in a dictionary
The directory structure is that say we pickup a random directory X... under it we will have one or multiple clean images 
and under sub-dir named distortion under which all the distorted images are present
"""
# 📌 Function to collect image paths from a directory structure, organizing them by person ID
def collect_image_paths(root_dir):
    """Index the clean and distorted ``.jpg`` paths of every identity folder.

    Each sub-directory of ``root_dir`` is treated as one identity. ``.jpg``
    files directly inside it are the clean images; ``.jpg`` files inside its
    ``distortion`` sub-directory are the distorted ones.

    Args:
        root_dir (str): Directory whose sub-directories are identity folders.

    Returns:
        dict: ``{person_id: {"clean": [paths], "distorted": [paths]}}``.
        Identities without any clean image are omitted. Keys are inserted in
        sorted order and every path list is sorted, so downstream seeded
        sampling is reproducible across file systems.
    """
    person_dict = {}

    for person_folder in sorted(os.listdir(root_dir)):
        folder_path = os.path.join(root_dir, person_folder)
        if not os.path.isdir(folder_path):
            continue

        # glob.escape keeps folder names containing "[", "*" or "?" literal.
        # The pattern is non-recursive, so only files directly in the folder match;
        # the "distortion" filter is applied to the file name only, otherwise a
        # root or identity directory containing that word would drop every image.
        clean_imgs = sorted(
            path
            for path in glob.glob(os.path.join(glob.escape(folder_path), "*.jpg"))
            if "distortion" not in os.path.basename(path)
        )

        distortion_dir = os.path.join(folder_path, "distortion")
        distortion_imgs = []
        if os.path.isdir(distortion_dir):
            distortion_imgs = sorted(
                glob.glob(os.path.join(glob.escape(distortion_dir), "*.jpg"))
            )

        # An identity is only usable when it has at least one clean image.
        if clean_imgs:
            person_dict[person_folder] = {
                "clean": clean_imgs,
                "distorted": distortion_imgs,
            }

    return person_dict

In [None]:
# Training split: index every identity folder found under the train directory.
train_dir = "/kaggle/input/facecom/Comys_Hackathon5/Task_B/train"
person_dict = collect_image_paths(train_dir)  # {person_id: {"clean": [...], "distorted": [...]}}

In [None]:
# Validation split: same folder layout as the training split.
validation_dir = "/kaggle/input/facecom/Comys_Hackathon5/Task_B/val"
val_dict = collect_image_paths(validation_dir)  # {person_id: {"clean": [...], "distorted": [...]}}

In [None]:
# Sanity check: report how many identities were indexed and preview a few of their IDs.
print(f"Found {len(person_dict)} person folders")
print("Available keys:", list(person_dict)[:5])

# Dump the complete record of the first identity, provided anything was found at all.
if person_dict:
    first_person = next(iter(person_dict))
    print(f"\nFirst person '{first_person}':")
    print(person_dict[first_person])

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Training-split summary: one row per identity with its clean / distorted image counts.
records = []
for identity, imgs in person_dict.items():
    records.append(dict(
        identity=identity,
        num_clean=len(imgs["clean"]),
        num_distorted=len(imgs["distorted"]),
    ))

df = pd.DataFrame(records)

# Distribution of the number of distorted images per identity.
plt.figure(figsize=(10, 6))
sns.histplot(df["num_distorted"], bins=30, color="skyblue")
plt.title("Number of Distorted Images per Identity")
plt.xlabel("Distorted Images")
plt.ylabel("Frequency")
plt.grid(True)
plt.show()

# How many identities have 1, 2, ... clean images.
plt.figure(figsize=(8, 5))
sns.countplot(data=df, x="num_clean", palette="Set2")
plt.title("Number of Undistorted Images per Identity")
plt.xlabel("Undistorted Images")
plt.ylabel("Number of Identities")
plt.grid(True, axis='y')
plt.show()

# Headline numbers for the split.
n_identities = df.shape[0]
clean_counts = df['num_clean']
print(f"📌 Total Identities: {n_identities}")
print(f"📸 Avg distorted per identity: {df['num_distorted'].mean():.2f}")
print(f"🧼 Avg undistorted per identity: {clean_counts.mean():.2f}")
print(f"🧼 Min undistorted: {clean_counts.min()}, Max undistorted: {clean_counts.max()}")

In [None]:
# Upper bound on clean-only positive pairs: each identity with n clean images contributes C(n, 2) = n*(n-1)/2.
# NOTE(review): this reads and mutates the global `df`; re-running the cell after another cell
# has reassigned `df` computes the statistic on that other frame.
df["pos_pairs"] = df["num_clean"] * (df["num_clean"] - 1) // 2  # floor division keeps the counts integral
total_positive_pairs = df["pos_pairs"].sum()

print(f"✅ Total possible positive pairs from undistorted images: {total_positive_pairs}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Validation-split summary: one row per identity with its clean / distorted image counts.
# The frame gets its own name (`val_df`) so the training frame `df` is not clobbered;
# overwriting it made the positive-pair cell depend on cell execution order.
records_val = []
for identity, imgs in val_dict.items():
    records_val.append({
        "identity": identity,
        "num_clean": len(imgs["clean"]),
        "num_distorted": len(imgs["distorted"])
    })

val_df = pd.DataFrame(records_val)

# Distribution of the number of distorted images per identity.
plt.figure(figsize=(10, 6))
sns.histplot(val_df["num_distorted"], bins=30, color="skyblue")
plt.title("Number of Distorted Images per Identity")
plt.xlabel("Distorted Images")
plt.ylabel("Frequency")
plt.grid(True)
plt.show()

# How many identities have 1, 2, ... clean images.
plt.figure(figsize=(8, 5))
sns.countplot(data=val_df, x="num_clean", palette="Set2")
plt.title("Number of Undistorted Images per Identity")
plt.xlabel("Undistorted Images")
plt.ylabel("Number of Identities")
plt.grid(True, axis='y')
plt.show()

# Headline numbers for the split.
print(f"📌 Total Identities: {val_df.shape[0]}")
print(f"📸 Avg distorted per identity: {val_df['num_distorted'].mean():.2f}")
print(f"🧼 Avg undistorted per identity: {val_df['num_clean'].mean():.2f}")
print(f"🧼 Min undistorted: {val_df['num_clean'].min()}, Max undistorted: {val_df['num_clean'].max()}")

In [None]:
import random

# 📌 Function to generate balanced positive and negative pairs for face verification
def generate_balanced_augmented_pairs(person_dict, min_pos_per_id=28, num_neg_per_pos=3, seed=42):
    """Build labelled same-identity / different-identity image pairs.

    For every identity, up to ``min_pos_per_id`` positive pairs are sampled
    from all 2-combinations of its clean and distorted images. For each
    sampled positive pair, up to ``num_neg_per_pos`` other identities are drawn
    and the pair's first image is matched with one random image of each
    (a distorted image when the identity has any, otherwise a clean one).

    Args:
        person_dict (dict): ``{person_id: {"clean": list[str] | str, "distorted": list[str]}}``.
        min_pos_per_id (int): Maximum number of positive pairs sampled per identity
            (fewer are produced when the identity has fewer combinations).
        num_neg_per_pos (int): Number of negative pairs generated per positive pair.
        seed (int): Seed applied to the global ``random`` module for reproducibility.

    Returns:
        tuple: ``(all_pairs, positive_pairs, negative_pairs)``. Each element is a
        list of ``(img1_path, img2_path, label)`` tuples with label 1 for
        positives and 0 for negatives; ``all_pairs`` is the shuffled union.
    """
    def _as_path_list(value):
        # 'clean' is normally a list of paths but a single path string is tolerated.
        if isinstance(value, list):
            return list(value)
        return [value] if value else []

    random.seed(seed)
    all_ids = list(person_dict.keys())
    positive_pairs = []
    negative_pairs = []

    for person_id in all_ids:
        entry = person_dict[person_id]
        images = _as_path_list(entry['clean']) + list(entry['distorted'])

        # A pair needs two images.
        if len(images) < 2:
            continue

        all_pos_pairs = [(a, b) for idx, a in enumerate(images) for b in images[idx+1:] if a != b]
        selected_pos_pairs = random.sample(all_pos_pairs, min(min_pos_per_id, len(all_pos_pairs)))

        # Loop-invariant for this identity, so computed once instead of per positive pair.
        other_ids = [pid for pid in all_ids if pid != person_id]
        neg_count = min(num_neg_per_pos, len(other_ids))

        for img1, img2 in selected_pos_pairs:
            positive_pairs.append((img1, img2, 1))

            for neg_id in random.sample(other_ids, neg_count):
                # Fall back to the clean images as a flat list of paths; wrapping the
                # list in another list would make random.choice return a list, not a path.
                neg_candidates = (
                    person_dict[neg_id]['distorted']
                    or _as_path_list(person_dict[neg_id]['clean'])
                )
                if not neg_candidates:
                    continue
                negative_pairs.append((img1, random.choice(neg_candidates), 0))

    all_pairs = positive_pairs + negative_pairs
    random.shuffle(all_pairs)
    return all_pairs, positive_pairs, negative_pairs

In [None]:
# Training pairs: up to 28 positive pairs per identity and one negative pair per positive,
# which keeps the two classes roughly balanced.
all_pairs, positive_pairs, negative_pairs = generate_balanced_augmented_pairs(
    person_dict=person_dict, 
    min_pos_per_id=28, 
    num_neg_per_pos=1
)

In [None]:
# Validation pairs: up to 20 positive pairs per identity, one negative pair per positive.
all_val_pairs, val_pos, val_neg = generate_balanced_augmented_pairs(
    person_dict=val_dict,
    min_pos_per_id=20,
    num_neg_per_pos=1,
)

In [None]:
# Report how many training pairs were generated and preview a few of them.
print(f"✅ Total pairs: {len(all_pairs)}")
print(f"🔵 Positive pairs: {len(positive_pairs)}")
print(f"🔴 Negative pairs: {len(negative_pairs)}")
print("🧾 Sample pairs:", all_pairs[:3])

# Visualize the class balance of the generated pairs.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

pair_stats = {"identity": [], "label": []}
for img1, img2, label in all_pairs:
    # Distorted images live in <identity>/distortion/, so step up one more level
    # for them; otherwise every distorted image is attributed to "distortion".
    parent = os.path.dirname(img1)
    if os.path.basename(parent) == "distortion":
        parent = os.path.dirname(parent)
    identity = os.path.basename(parent)
    pair_stats["identity"].append(identity)
    pair_stats["label"].append(label)

df_pairs = pd.DataFrame(pair_stats)

plt.figure(figsize=(10, 5))
sns.countplot(data=df_pairs, x="label", palette="Set2")
plt.xticks([0, 1], ["Negative", "Positive"])
plt.title("Distribution of Positive vs Negative Pairs")
plt.xlabel("Pair Type")
plt.ylabel("Count")
plt.grid(True, axis='y')
plt.show()

In [None]:
# Report how many validation pairs were generated and preview a few of them.
print(f"✅ Total pairs: {len(all_val_pairs)}")
print(f"🔵 Positive pairs: {len(val_pos)}")
print(f"🔴 Negative pairs: {len(val_neg)}")
print("🧾 Sample pairs:", all_val_pairs[:3])  # preview the validation pairs, not the training ones

pair_stats = {"identity": [], "label": []}
for img1, img2, label in all_val_pairs:
    # Distorted images live in <identity>/distortion/, so step up one more level
    # for them; otherwise every distorted image is attributed to "distortion".
    parent = os.path.dirname(img1)
    if os.path.basename(parent) == "distortion":
        parent = os.path.dirname(parent)
    identity = os.path.basename(parent)
    pair_stats["identity"].append(identity)
    pair_stats["label"].append(label)

df_pairs = pd.DataFrame(pair_stats)

plt.figure(figsize=(10, 5))
sns.countplot(data=df_pairs, x="label", palette="Set2")
plt.xticks([0, 1], ["Negative", "Positive"])
plt.title("Distribution of Positive vs Negative Pairs")
plt.xlabel("Pair Type")
plt.ylabel("Count")
plt.grid(True, axis='y')
plt.show()

In [None]:
from PIL import Image

# 📌 Function to display positive and negative image pairs for visual inspection
def show_image_pairs(pairs, title, num=5):
    """Plot the first ``num`` image pairs side by side for visual inspection.

    Args:
        pairs (list): Sequence of ``(img1_path, img2_path, ...)`` tuples; only
            the first two elements of each tuple are used.
        title (str): Prefix used in every subplot title (e.g. "Positive").
        num (int): Maximum number of pairs to show; clamped to ``len(pairs)``.

    Returns:
        None. Draws a ``num x 2`` grid with matplotlib, or prints a notice when
        there is nothing to show.
    """
    # Clamp so that a short list does not raise IndexError.
    num = min(num, len(pairs))
    if num <= 0:
        print(f"No {title} pairs to display.")
        return

    plt.figure(figsize=(15, 3 * num))
    for i in range(num):
        for col, path in enumerate(pairs[i][:2], start=1):
            plt.subplot(num, 2, 2 * i + col)
            # Context manager closes the file handle once the pixels are drawn.
            with Image.open(path) as img:
                plt.imshow(img)
            plt.axis("off")
            plt.title(f"{title} Pair {i+1} - Img{col}")

    plt.tight_layout()
    plt.show()

# Visual spot-check of the training pairs: both images of a positive pair should show the
# same identity, the images of a negative pair two different identities.
show_image_pairs(positive_pairs, "Positive", num=5)
show_image_pairs(negative_pairs, "Negative", num=5)

In [None]:
import torch
import torchvision.transforms as T
from PIL import Image

# Preprocessing for attention-map computation: single-channel, IMG_SIZE x IMG_SIZE, float tensor in [0, 1].
# The size comes from the config cell so the maps always match the images the dataset feeds to the model.
to_tensor = T.Compose([
    T.Grayscale(),                    # collapse to one luminance channel
    T.Resize((IMG_SIZE, IMG_SIZE)),   # same spatial size as the model input
    T.ToTensor()                      # PIL image -> tensor of shape [1, H, W]
])

# 📌 Function to compute FFT-based attention maps for a batch of image pairs
def compute_fft_attention_batch(batch1, batch2):
    """Compute one FFT-difference attention map per image pair.

    Both inputs are transformed with a 2-D FFT, the magnitude of the spectrum
    difference is transformed back, min-max normalised per sample over dims
    (1, 2) and finally inverted.

    Args:
        batch1, batch2: Real tensors of identical shape ``[B, H, W]``.

    Returns:
        Tensor of shape ``[B, 1, H, W]``.
    """
    spectrum_gap = (torch.fft.fft2(batch1) - torch.fft.fft2(batch2)).abs()
    spatial = torch.fft.ifft2(spectrum_gap).real

    # Per-sample min-max scaling; the epsilon protects against a flat map.
    shifted = spatial - spatial.amin(dim=(1, 2), keepdim=True)
    scaled = shifted / (shifted.amax(dim=(1, 2), keepdim=True) + 1e-8)

    # Invert the scaled map, then insert the channel axis.
    return (1.0 - scaled).unsqueeze(1)

In [None]:
import os, hashlib
from torchvision.transforms import Compose, Grayscale, Resize, ToTensor
from PIL import Image
from tqdm.notebook import tqdm
import torch

# Run the FFT on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Attention maps are cached on disk, one .pt file per image pair.
# The resize target follows IMG_SIZE from the config cell so the cached maps
# always have the same spatial size as the images the dataset produces.
resize = Resize((IMG_SIZE, IMG_SIZE))
attention_root = "/kaggle/working/fft_attention_maps"
os.makedirs(attention_root, exist_ok=True)

# Grayscale -> resize -> tensor of shape [1, IMG_SIZE, IMG_SIZE].
to_tensor = Compose([Grayscale(), resize, ToTensor()])

In [None]:
# 📌 Generate a unique cache key for each image pair using MD5 hash
def _cache_key(path1, path2):
    """Return the cache file path for an image pair.

    The two absolute paths are sorted before hashing, so ``(a, b)`` and
    ``(b, a)`` map to the same file. The file name is the MD5 hex digest of
    ``"<first>|<second>"`` with a ``.pt`` suffix, located in ``attention_root``.
    """
    ordered = sorted(os.path.abspath(p) for p in (path1, path2))
    digest = hashlib.md5("|".join(ordered).encode()).hexdigest()
    return os.path.join(attention_root, digest + ".pt")

In [None]:
# 📌 Function to compute and cache FFT attention maps for image pairs in batches
def batch_cache_attention(pairs, batch_size=32):
    """Compute and store FFT attention maps for every pair that is not cached yet.

    Args:
        pairs: Iterable of ``(img1_path, img2_path, label)`` tuples.
        batch_size (int): Number of pairs processed per FFT call.

    Returns:
        None. One half-precision tensor per pair is written to the path given
        by ``_cache_key``. Pairs whose images cannot be loaded are reported and
        skipped, so no cache file is written for them.
    """
    # Map cache path -> image paths. The dict removes duplicate pairs (which
    # share a key) while preserving first-seen order.
    pending = {}
    for p1, p2, _ in pairs:
        out_path = _cache_key(p1, p2)
        if out_path not in pending and not os.path.exists(out_path):
            pending[out_path] = (p1, p2)

    if not pending:
        print("✅ All attention maps already cached.")
        return

    todo = list(pending.items())
    for i in tqdm(range(0, len(todo), batch_size), desc="⚡ Batch-caching attention maps"):
        imgs1, imgs2, out_paths = [], [], []

        for out_path, (p1, p2) in todo[i:i + batch_size]:
            try:
                # Load both images before appending anything, so a failure on
                # either side cannot shift the image lists against out_paths.
                with Image.open(p1) as im1, Image.open(p2) as im2:
                    t1_img = to_tensor(im1.convert("RGB"))
                    t2_img = to_tensor(im2.convert("RGB"))
            except Exception as e:
                print(f"❌ Failed to load: {p1} or {p2} — {e}")
                continue
            imgs1.append(t1_img)
            imgs2.append(t2_img)
            out_paths.append(out_path)

        if not out_paths:
            continue

        # [B, 1, H, W] -> [B, H, W], the layout compute_fft_attention_batch expects.
        t1 = torch.stack(imgs1).squeeze(1).to(device)
        t2 = torch.stack(imgs2).squeeze(1).to(device)

        attn_batch = compute_fft_attention_batch(t1, t2).cpu()

        # Stored as float16 to halve the disk footprint of the cache.
        for attn_map, out_path in zip(attn_batch, out_paths):
            torch.save(attn_map.half(), out_path, pickle_protocol=4)

In [None]:
# Pre-compute the attention maps of all training pairs, unless attention is switched off.
if not USE_DCT_ATTENTION:
    print("⚡ Skipping FFT-attention caching (USE_DCT_ATTENTION = False)")
else:
    batch_cache_attention(all_pairs, batch_size=32)

In [None]:
# Pre-compute the attention maps of all validation pairs, unless attention is switched off.
if not USE_DCT_ATTENTION:
    print("⚡ Skipping FFT-attention caching (USE_DCT_ATTENTION = False)")
else:
    batch_cache_attention(all_val_pairs, batch_size=32)

In [None]:
import torch
import matplotlib.pyplot as plt
import os

# Load and visualize the cached attention map of the first training pair.
# The path is derived from the pair itself: a hard-coded hash file name only exists
# when the pair list happens to be identical to the one of the original run.
# Requires the caching cell to have run with USE_DCT_ATTENTION = True.
attn_path = _cache_key(all_pairs[0][0], all_pairs[0][1])
attn_tensor = torch.load(attn_path, weights_only=False).squeeze()  # drop the channel dimension -> [H, W]

plt.figure(figsize=(6, 6))
plt.imshow(attn_tensor.cpu().numpy(), cmap='viridis')
plt.colorbar(label="Attention Intensity")
plt.title("DCT/FFT-based Attention Map")
plt.axis("off")
plt.show()

In [None]:
import os
import hashlib

# 📌 Function to find original image paths from a cached attention map
def reverse_lookup_path(attn_path, all_pairs):
    """Find the image pair whose cache key matches a cached attention file.

    Args:
        attn_path (str): Path of a cached ``<md5>.pt`` attention file.
        all_pairs: Iterable of ``(img1_path, img2_path, label)`` tuples to search.

    Returns:
        tuple: ``(img1_path, img2_path)`` of the first matching pair, or
        ``(None, None)`` when no pair hashes to the file name.
    """
    wanted = os.path.basename(attn_path).replace(".pt", "")

    for first, second, _ in all_pairs:
        # Same recipe as the cache key: sorted absolute paths joined by "|".
        ordered = sorted([os.path.abspath(first), os.path.abspath(second)])
        candidate = hashlib.md5("|".join(ordered).encode()).hexdigest()
        if candidate == wanted:
            return first, second

    return None, None

from PIL import Image
import numpy as np
from scipy.fftpack import dct, idct
import torch

# 📌 Compute 2D DCT for an image
def apply_2d_dct(img):
    """Orthonormal 2-D DCT of a 2-D array, applied separably along both axes."""
    # First pass: transform along axis 0 by working on the transpose.
    column_pass = dct(img.T, norm='ortho').T
    # Second pass: transform along the last axis.
    return dct(column_pass, norm='ortho')

# 📌 Compute 2D inverse DCT
def apply_2d_idct(coeffs):
    """Orthonormal 2-D inverse DCT of a 2-D coefficient array."""
    # First pass: invert along axis 0 by working on the transpose.
    column_pass = idct(coeffs.T, norm='ortho').T
    # Second pass: invert along the last axis.
    return idct(column_pass, norm='ortho')

# 📌 Compute DCT-based attention map for a pair of images
def compute_dct_attention(img1, img2):
    """Compute a DCT-difference attention map for one image pair.

    Both images are converted to grayscale and transformed with a 2-D DCT. The
    absolute coefficient difference is transformed back, min-max normalised and
    inverted.

    Args:
        img1, img2: PIL images. When their sizes differ, ``img2`` is resized to
            the size of ``img1`` so the coefficient arrays can be subtracted.

    Returns:
        torch.FloatTensor of shape ``[H, W]`` (the size of ``img1``).
    """
    # The element-wise difference below requires equally shaped arrays.
    if img2.size != img1.size:
        img2 = img2.resize(img1.size)

    gray1 = np.array(img1.convert("L"), dtype=np.float32)
    gray2 = np.array(img2.convert("L"), dtype=np.float32)

    coeff_gap = np.abs(apply_2d_dct(gray1) - apply_2d_dct(gray2))
    attn_map = apply_2d_idct(coeff_gap)

    # Min-max normalise (epsilon guards a flat map), then invert.
    attn_map -= attn_map.min()
    attn_map /= (attn_map.max() + 1e-8)
    attn_map = 1.0 - attn_map
    return torch.tensor(attn_map).float()

In [None]:
import matplotlib.pyplot as plt

# Reload the cached attention map so this cell does not depend on `attn_tensor` from the cell above.
attn_cached = torch.load(attn_path, weights_only=False).squeeze()

# Recover the image pair whose cache key matches the file name of the cached map.
img1_path, img2_path = reverse_lookup_path(attn_path, all_pairs)
assert img1_path and img2_path, "❌ Original paths not found!"

# Recompute an attention map for the same pair with the DCT-based variant.
# NOTE(review): the cached map comes from an FFT on resized grayscale tensors, whereas this one
# uses a DCT on the images at their native resolution, so only a qualitative match can be expected.
img1 = Image.open(img1_path).convert("RGB")
img2 = Image.open(img2_path).convert("RGB")
attn_dct = compute_dct_attention(img1, img2)

# Bring the recomputed map to the cached map's spatial size when the two differ.
if attn_cached.shape != attn_dct.shape:
    from torchvision.transforms import Resize
    attn_dct = Resize(attn_cached.shape)(attn_dct.unsqueeze(0)).squeeze(0)

# Show the cached map (left) next to the recomputed one (right).
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.imshow(attn_cached.cpu().numpy(), cmap='viridis')
plt.title("Cached Attention Map")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(attn_dct.cpu().numpy(), cmap='viridis')
plt.title("Recomputed DCT Attention")
plt.axis("off")

plt.suptitle("Sanity Check: Cached vs DCT Recomputed", fontsize=14)
plt.show()

In [None]:
import os, hashlib
from torchvision.transforms import Compose, Grayscale, Resize, ToTensor
from PIL import Image
from tqdm.notebook import tqdm
import torch

# Re-establish the device and attention-cache settings right before the dataset definition.
# The resize target follows IMG_SIZE from the config cell so attention maps and
# model inputs cannot drift apart.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resize = Resize((IMG_SIZE, IMG_SIZE))
attention_root = "/kaggle/working/fft_attention_maps"
os.makedirs(attention_root, exist_ok=True)
to_tensor = Compose([Grayscale(), resize, ToTensor()])

In [None]:
from torch.utils.data import Dataset
# 📌 Custom Dataset class for loading image pairs and their attention maps
class FacePairDataset(Dataset):
    """Dataset of labelled image pairs with optional cached attention maps.

    Each item is a dict with the keys ``"img1"``, ``"img2"`` and ``"label"``;
    when ``USE_DCT_ATTENTION`` is true it additionally holds ``"attn"``, the
    tensor loaded from the pair's cache file.

    Args:
        pairs: Sequence of ``(img1_path, img2_path, label)`` tuples.
        transform: Callable applied to each RGB PIL image. Defaults to a resize
            to ``IMG_SIZE x IMG_SIZE`` followed by ``ToTensor``.
    """

    def __init__(self, pairs, transform=None):
        self.pairs = pairs
        self.transform = transform or T.Compose([
            T.Resize((IMG_SIZE, IMG_SIZE)), T.ToTensor()
        ])

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _load_rgb(path):
        """Read an image as RGB and release the file handle immediately."""
        # convert() returns a fully loaded copy, so it stays valid after the file
        # is closed; without the context manager each worker leaks descriptors.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        p1, p2, label = self.pairs[idx]

        sample = {
            "img1": self.transform(self._load_rgb(p1)),
            "img2": self.transform(self._load_rgb(p2)),
        }
        if USE_DCT_ATTENTION:
            # The cache file must have been written beforehand for this pair.
            sample["attn"] = torch.load(_cache_key(p1, p2), weights_only=False)
        sample["label"] = torch.tensor(label, dtype=torch.float32)
        return sample

In [None]:
# Build the training dataset and print the tensor shapes of its first sample.
train_dataset = FacePairDataset(all_pairs)
print("Total samples:", len(train_dataset))

sample = train_dataset[0]
img1, img2, label = sample["img1"], sample["img2"], sample["label"]
attn = sample.get("attn")  # None when attention maps are disabled

print(f"Image 1 shape       : {img1.shape}")
print(f"Image 2 shape       : {img2.shape}")
if attn is None:
    print("Attention map       : ❌ Not used (USE_DCT_ATTENTION = False)")
else:
    print(f"Attention map shape : {attn.shape}")
print(f"Label               : {label}")

In [None]:
# Build the validation dataset and print the tensor shapes of its first sample.
val_dataset = FacePairDataset(all_val_pairs)
print("Total samples:", len(val_dataset))

sample = val_dataset[0]
img1, img2, label = sample["img1"], sample["img2"], sample["label"]
attn = sample.get("attn")  # None when attention maps are disabled

print(f"Image 1 shape       : {img1.shape}")
print(f"Image 2 shape       : {img2.shape}")
if attn is None:
    print("Attention map       : ❌ Not used (USE_DCT_ATTENTION = False)")
else:
    print(f"Attention map shape : {attn.shape}")
print(f"Label               : {label}")

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# 📌 Define Siamese Network for face verification
class SiameseNet(nn.Module):
    """Siamese verifier: shared backbone, concatenated features, binary head.

    Both images pass through the same torchvision backbone (with its last
    child module removed). The two feature vectors are concatenated and mapped
    to a single logit.

    Args:
        backbone (str): Name of a constructor in ``torchvision.models``.
        pretrained (bool): Forwarded to the backbone constructor.
    """

    def __init__(self, backbone="resnet18", pretrained=True):
        super().__init__()

        base = getattr(models, backbone)(pretrained=pretrained)

        # Read the feature width from the classifier that is being dropped, so
        # wider backbones work too; 512 is the fallback when the model exposes
        # no `fc` layer.
        feat_dim = getattr(getattr(base, "fc", None), "in_features", 512)

        # Everything except the final child module.
        self.feature_extractor = nn.Sequential(*list(base.children())[:-1])

        self.fc = nn.Sequential(
            nn.Linear(feat_dim * 2, 512),  # input: features of both images, concatenated
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1)  # single logit for binary classification
        )

    def forward(self, img1, img2, attn_map=None):
        """Return one logit per pair.

        Args:
            img1, img2: Tensors of shape ``[B, 3, H, W]``.
            attn_map: Tensor of shape ``[B, 1, H, W]`` or ``None``. It is only
                applied when ``USE_DCT_ATTENTION`` is true.

        Returns:
            Tensor of shape ``[B, 1]``.
        """
        if USE_DCT_ATTENTION and attn_map is not None:
            # Repeat the single-channel map over the three colour channels and
            # weight both images with it.
            attn_map = attn_map.expand(-1, 3, -1, -1)
            img1, img2 = img1 * attn_map, img2 * attn_map

        # Shared weights: the same extractor embeds both images.
        f1 = self.feature_extractor(img1).view(img1.size(0), -1)
        f2 = self.feature_extractor(img2).view(img2.size(0), -1)

        out = self.fc(torch.cat([f1, f2], dim=1))
        return out

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score
import pandas as pd

# 📌 Function to train the Siamese Network
def train_siamese_model(
    model,
    train_dataset,
    val_dataset,
    epochs=10,
    batch_size=32,
    lr=1e-4,
    save_dir='/kaggle/working/'
):
    """Train a Siamese verifier and keep the checkpoint with the best validation accuracy.

    Args:
        model: Module called as ``model(img1, img2, attn)`` returning ``[B, 1]`` logits.
        train_dataset, val_dataset: Datasets yielding dicts with ``"img1"``,
            ``"img2"``, ``"label"`` and, when ``USE_DCT_ATTENTION`` is true, ``"attn"``.
        epochs (int): Number of passes over the training set.
        batch_size (int): Batch size for both loaders.
        lr (float): Adam learning rate.
        save_dir (str): Directory for the best checkpoint and the metrics CSV;
            created when missing.

    Returns:
        pandas.DataFrame: One row per epoch with the columns epoch, train_loss,
        train_acc, val_loss, val_acc, precision, recall and f1. The same table
        is written to ``training_metrics_cycle1.csv`` in ``save_dir``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    os.makedirs(save_dir, exist_ok=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    columns = ["epoch", "train_loss", "train_acc", "val_loss", "val_acc", "precision", "recall", "f1"]
    rows = []

    best_val_acc = 0.0
    best_model_path = None

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_loss, train_correct, total_train = 0.0, 0, 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

        for batch in loop:
            img1 = batch["img1"].to(device)
            img2 = batch["img2"].to(device)
            labels = batch["label"].to(device)
            attn = batch["attn"].to(device) if USE_DCT_ATTENTION else None

            outputs = model(img1, img2, attn).squeeze(1)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            preds = (torch.sigmoid(outputs) > 0.5).float()
            train_correct += (preds == labels).sum().item()
            # Weight by batch size so the epoch loss is a true per-sample mean.
            train_loss += loss.item() * labels.size(0)
            total_train += labels.size(0)
            loop.set_postfix(loss=loss.item())

        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        train_acc = train_correct / max(total_train, 1)
        train_loss = train_loss / max(total_train, 1)

        # ---- validation pass ----
        model.eval()
        val_loss, val_correct, total_val = 0.0, 0, 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in val_loader:
                img1 = batch["img1"].to(device)
                img2 = batch["img2"].to(device)
                labels = batch["label"].to(device)
                attn = batch["attn"].to(device) if USE_DCT_ATTENTION else None

                outputs = model(img1, img2, attn).squeeze(1)
                loss = criterion(outputs, labels)

                probs = torch.sigmoid(outputs).cpu().numpy()
                preds = (probs > 0.5).astype(float)
                labels_np = labels.cpu().numpy()

                all_preds.extend(preds)
                all_labels.extend(labels_np)

                val_correct += (preds == labels_np).sum().item()
                val_loss += loss.item() * labels.size(0)
                total_val += labels.size(0)

        val_acc = val_correct / max(total_val, 1)
        val_loss = val_loss / max(total_val, 1)
        precision = precision_score(all_labels, all_preds, zero_division=0)
        recall = recall_score(all_labels, all_preds, zero_division=0)
        f1 = f1_score(all_labels, all_preds, zero_division=0)

        print(f"📣 Epoch {epoch+1}: "
              f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f} | "
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

        # Keep a single checkpoint on disk: the one with the best validation accuracy.
        if val_acc > best_val_acc:
            if best_model_path and os.path.exists(best_model_path):
                os.remove(best_model_path)
            model_name = f"siamese_best_epoch{epoch+1}_acc{val_acc:.4f}.pt"
            best_model_path = os.path.join(save_dir, model_name)
            torch.save(model.state_dict(), best_model_path)
            best_val_acc = val_acc
            print(f"✅ New best model saved: {model_name}")

        rows.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "precision": precision,
            "recall": recall,
            "f1": f1,
        })

    # Explicit columns keep the CSV header intact even when epochs == 0.
    history_df = pd.DataFrame(rows, columns=columns)
    metrics_path = os.path.join(save_dir, "training_metrics_cycle1.csv")
    history_df.to_csv(metrics_path, index=False)
    # Report the file that was actually written.
    print(f"📈 Training metrics saved to {metrics_path}")
    return history_df

In [None]:
# Instantiate the Siamese network on a ResNet18 backbone; pretrained=True is forwarded to the torchvision constructor.
model_backbone_r18 = SiameseNet(backbone="resnet18", pretrained=True)

In [None]:
# Show the GPU status before training (IPython shell escape; needs the NVIDIA driver tools on PATH).
!nvidia-smi

In [None]:
# Train the ResNet18 Siamese network; the best checkpoint and the metrics CSV are written to save_dir.
train_siamese_model(
    model=model_backbone_r18,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    epochs=5,                        # number of passes over the training pairs
    batch_size=32,                   # pairs per optimisation step
    lr=1e-4,                         # Adam learning rate
    save_dir="/kaggle/working"
)