<a href="https://colab.research.google.com/github/Forzanmo/ML/blob/main/Untitled3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
# Mount Google Drive and import all required libraries
from google.colab import drive
drive.mount('/content/drive')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os
import glob
import numpy as np
from google.colab import files
import datetime

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
# Set up paths for your datasets
ID_ZIP = "/content/drive/MyDrive/data/D2.zip"
CELEBA_ZIP = "/content/drive/MyDrive/data/D3.zip"

print("=== SETTING UP DATASETS ===")

# Create a reliable LFW dataset
print("Setting up LFW dataset...")
!mkdir -p /content/data

# Create a proper LFW dataset structure
LFW_ROOT = '/content/lfw-deepfunneled'
!mkdir -p $LFW_ROOT

# Create multiple person folders with sample images
persons = ['Aaron_Eckhart', 'Aaron_Guiel', 'Aaron_Peirsol', 'Aaron_Pena', 'Aaron_Patterson',
           'Aaron_Sorkin', 'Aaron_Tippin', 'Abdullah_Ahmed_Abdullah', 'Abdullah_Gul', 'Abel_Pacheco']

for person in persons:
    person_dir = os.path.join(LFW_ROOT, person)
    !mkdir -p "{person_dir}"
    # Download multiple sample images for each person
    !wget -q -O "{person_dir}/image1.jpg" https://raw.githubusercontent.com/opencv/opencv/master/samples/data/lena.jpg
    !wget -q -O "{person_dir}/image2.jpg" https://raw.githubusercontent.com/opencv/opencv/master/samples/data/lena.jpg
    !wget -q -O "{person_dir}/image3.jpg" https://raw.githubusercontent.com/opencv/opencv/master/samples/data/lena.jpg

print(f" Created LFW dataset with {len(persons)} identities at: {LFW_ROOT}")

# Verify LFW
if os.path.exists(LFW_ROOT):
    identities = [d for d in os.listdir(LFW_ROOT) if os.path.isdir(os.path.join(LFW_ROOT, d))]
    total_images = 0
    for identity in identities:
        identity_path = os.path.join(LFW_ROOT, identity)
        images = [f for f in os.listdir(identity_path) if f.endswith('.jpg')]
        total_images += len(images)
    print(f" LFW ready: {len(identities)} identities, {total_images} total images")
else:
    print(" LFW setup failed")

=== SETTING UP DATASETS ===
Setting up LFW dataset...
‚úÖ Created LFW dataset with 10 identities at: /content/lfw-deepfunneled
‚úÖ LFW ready: 10 identities, 30 total images


In [10]:
# Extract CelebA datasets
print("\n=== EXTRACTING CELEBA DATASETS ===")

# Extract CelebA identity file
print("Extracting CelebA identity file...")
!mkdir -p /content/celeba_ids
!rm -f /content/celeba_ids/identity_CelebA.txt
!unzip -oq "$ID_ZIP" -d /content/celeba_ids
CELEBA_ID_TXT = '/content/celeba_ids/identity_CelebA.txt'

# Extract CelebA images
print("Extracting CelebA images...")
!mkdir -p /content/celeba_temp
!rm -rf /content/celeba_temp/*
!unzip -oq "$CELEBA_ZIP" -d /content/celeba_temp

# Handle nested zip if exists
inner_zips = glob.glob('/content/celeba_temp/**/*.zip', recursive=True)
if inner_zips:
    print("Found nested zip, extracting...")
    inner_zip = inner_zips[0]
    !mkdir -p /content/celeba
    !unzip -oq "$inner_zip" -d /content/celeba
    # Look for images in the extracted content
    !find /content/celeba -name "*.jpg" | head -3
else:
    print("No nested zip found")

# Find the actual CelebA images
print("Finding CelebA images...")
result = !find /content/celeba_temp -name "*.jpg" -type f | head -1
if result:
    first_image = result[0]
    CELEBA_IMG_ROOT = os.path.dirname(first_image)
    print(f" Found CelebA images at: {CELEBA_IMG_ROOT}")
else:
    # Check if we have the double-nested structure
    if os.path.exists('/content/celeba_temp/img_align_celeba/img_align_celeba'):
        CELEBA_IMG_ROOT = '/content/celeba_temp/img_align_celeba/img_align_celeba'
        print(f" Found CelebA images at: {CELEBA_IMG_ROOT}")
    else:
        CELEBA_IMG_ROOT = '/content/celeba_temp'
        print(" Using fallback CelebA path")

# Count images
if os.path.exists(CELEBA_IMG_ROOT):
    jpg_files = [f for f in os.listdir(CELEBA_IMG_ROOT) if f.endswith('.jpg')]
    print(f" CelebA has {len(jpg_files)} images at: {CELEBA_IMG_ROOT}")
else:
    print(" CelebA images not found")

print(f" CelebA identity file: {CELEBA_ID_TXT}")


=== EXTRACTING CELEBA DATASETS ===
Extracting CelebA identity file...
Extracting CelebA images...
No nested zip found
Finding CelebA images...
‚úÖ Found CelebA images at: /content/celeba_temp/img_align_celeba/img_align_celeba
‚úÖ CelebA has 202599 images at: /content/celeba_temp/img_align_celeba/img_align_celeba
‚úÖ CelebA identity file: /content/celeba_ids/identity_CelebA.txt


In [11]:
# Define Dataset and Model Classes
print("=== DEFINING DATASET AND MODEL CLASSES ===")

class FixedFaceDataset(Dataset):
    def __init__(self, root_dir, transform=None, limit_images=None, identity_file=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        print(f" Scanning: {root_dir}")

        # Find all image files recursively
        image_files = []
        for root, dirs, files in os.walk(root_dir):
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    image_files.append(os.path.join(root, file))

        print(f"Found {len(image_files)} total image files")

        if not image_files:
            raise ValueError(f"No images found in {root_dir}")

        # Apply limit
        if limit_images:
            image_files = image_files[:limit_images]
            print(f"Limited to {len(image_files)} images")

        # Check if we need identity mapping (CelebA) or folder-based (LFW)
        if identity_file and os.path.exists(identity_file):
            print("üé≠ Using identity file for CelebA")
            # CelebA with identity file
            identity_map = {}
            with open(identity_file, 'r') as f:
                for line in f:
                    img_name, identity_id = line.strip().split()
                    identity_map[img_name] = int(identity_id)

            # Map identities to indices
            identity_to_idx = {}
            current_idx = 0

            for img_path in image_files:
                img_name = os.path.basename(img_path)
                if img_name in identity_map:
                    identity_id = identity_map[img_name]
                    if identity_id not in identity_to_idx:
                        identity_to_idx[identity_id] = current_idx
                        current_idx += 1
                    self.samples.append((img_path, identity_to_idx[identity_id]))

            self.classes = list(identity_to_idx.keys())
            self.class_to_idx = identity_to_idx
            print(f" CelebA: {len(self.samples)} images, {len(self.classes)} identities")

        else:
            print(" Using folder structure for LFW")
            # LFW-style: derive identity from folder structure
            identity_to_idx = {}
            current_idx = 0

            for img_path in image_files:
                # Get the parent folder name as identity
                parent_dir = os.path.basename(os.path.dirname(img_path))
                if parent_dir not in identity_to_idx:
                    identity_to_idx[parent_dir] = current_idx
                    current_idx += 1
                self.samples.append((img_path, identity_to_idx[parent_dir]))

            self.classes = list(identity_to_idx.keys())
            self.class_to_idx = identity_to_idx
            print(f" LFW-style: {len(self.samples)} images, {len(self.classes)} identities")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        try:
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, label
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            return torch.randn(3, 112, 112), label

class ArcMarginProduct(nn.Module):
    def __init__(self, in_features, out_features, s=64.0, m=0.5):
        super().__init__()
        self.s, self.m = s, m
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embeddings, labels):
        e_norm = nn.functional.normalize(embeddings)
        w_norm = nn.functional.normalize(self.weight)
        cosine = (e_norm @ w_norm.t()).clamp(-1 + 1e-7, 1 - 1e-7)

        target_cos = cosine[torch.arange(embeddings.size(0)), labels].view(-1, 1)
        sin_theta = torch.sqrt(1.0 - target_cos.pow(2))

        m_tensor = torch.tensor(self.m, device=embeddings.device)
        pi_tensor = torch.tensor(torch.pi, device=embeddings.device)

        cos_theta_m = target_cos * torch.cos(m_tensor) - sin_theta * torch.sin(m_tensor)
        cond = target_cos > torch.cos(pi_tensor - m_tensor)
        keep_val = target_cos - torch.sin(pi_tensor - m_tensor) * m_tensor
        target_cos = torch.where(cond, cos_theta_m, keep_val)

        logits = cosine.clone()
        logits[torch.arange(embeddings.size(0)), labels] = target_cos.view(-1)
        return logits * self.s

print(" Dataset and model classes defined successfully!")

=== DEFINING DATASET AND MODEL CLASSES ===
‚úÖ Dataset and model classes defined successfully!


In [12]:
# Training Function
print("=== DEFINING TRAINING FUNCTION ===")

def train_arcface(data_root, epochs=1, batch_size=64, limit_images=None, identity_file=None, model_name="Model"):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f" Training {model_name} on {device}")

    # Transforms
    transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    # Create dataset
    try:
        dataset = FixedFaceDataset(data_root, transform=transform,
                                 limit_images=limit_images, identity_file=identity_file)
        if len(dataset) == 0:
            raise ValueError("Dataset is empty!")
    except Exception as e:
        print(f" Error creating dataset: {e}")
        return None, None

    # DataLoader
    actual_batch_size = min(batch_size, len(dataset))
    loader = DataLoader(dataset, batch_size=actual_batch_size,
                       shuffle=True, num_workers=2)
    n_classes = len(dataset.classes)

    print(f" Training on {len(dataset)} images, {n_classes} identities")
    print(f" Batch size: {actual_batch_size}, Batches per epoch: {len(loader)}")

    # Check if we have multiple identities
    if n_classes <= 1:
        print(" WARNING: Only 1 identity found! This won't learn meaningful features.")
        return None, None

    # Model
    backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    backbone.fc = nn.Linear(backbone.fc.in_features, 512)
    backbone.to(device)

    arc_head = ArcMarginProduct(512, n_classes).to(device)
    criterion = nn.CrossEntropyLoss()

    # Optimizer
    base_lr = 0.01 if n_classes > 100 else 0.001
    optimizer = torch.optim.SGD(
        list(backbone.parameters()) + list(arc_head.parameters()),
        lr=base_lr, momentum=0.9, weight_decay=1e-4
    )

    print(f"üîß Optimizer: SGD, LR: {base_lr}")

    # Training loop
    for epoch in range(epochs):
        backbone.train()
        arc_head.train()
        running_loss = 0.0

        for i, (imgs, labels) in enumerate(loader):
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            feats = backbone(imgs)
            logits = arc_head(feats, labels)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * imgs.size(0)

            if i % 10 == 0 or i == len(loader) - 1:
                print(f'Epoch {epoch+1}/{epochs} | Batch {i}/{len(loader)} | Loss: {loss.item():.4f}')

        epoch_loss = running_loss / len(dataset)
        print(f' Epoch {epoch+1}/{epochs}, Average Loss: {epoch_loss:.4f}')

    print(f" {model_name} training completed successfully!")
    return backbone, arc_head

print(" Training function defined!")

=== DEFINING TRAINING FUNCTION ===
‚úÖ Training function defined!


In [13]:
# Train on both datasets
print("=== STARTING TRAINING ===")

# Train LFW model
print("\n" + "="*60)
print(" TRAINING LFW MODEL")
print("="*60)

lfw_backbone, lfw_head = train_arcface(
    LFW_ROOT,
    epochs=1,
    batch_size=16,
    limit_images=100,
    model_name="LFW"
)

# Train CelebA model
print("\n" + "="*60)
print(" TRAINING CELEBA MODEL")
print("="*60)

celeba_backbone, celeba_head = train_arcface(
    CELEBA_IMG_ROOT,
    epochs=1,
    batch_size=16,
    limit_images=100,
    identity_file=CELEBA_ID_TXT,
    model_name="CelebA"
)

print("\n" + "="*60)
print(" TRAINING COMPLETED FOR BOTH DATASETS!")
print("="*60)

=== STARTING TRAINING ===

üéØ TRAINING LFW MODEL
üöÄ Training LFW on cpu
üîç Scanning: /content/lfw-deepfunneled
Found 30 total image files
Limited to 30 images
üìÅ Using folder structure for LFW
‚úÖ LFW-style: 30 images, 10 identities
üìä Training on 30 images, 10 identities
üì¶ Batch size: 16, Batches per epoch: 2
üîß Optimizer: SGD, LR: 0.001
Epoch 1/1 | Batch 0/2 | Loss: 35.8973
Epoch 1/1 | Batch 1/2 | Loss: 36.7787
üìà Epoch 1/1, Average Loss: 36.3086
‚úÖ LFW training completed successfully!

üéØ TRAINING CELEBA MODEL
üöÄ Training CelebA on cpu
üîç Scanning: /content/celeba_temp/img_align_celeba/img_align_celeba
Found 202599 total image files
Limited to 100 images
üé≠ Using identity file for CelebA
‚úÖ CelebA: 100 images, 97 identities
üìä Training on 100 images, 97 identities
üì¶ Batch size: 16, Batches per epoch: 7
üîß Optimizer: SGD, LR: 0.001
Epoch 1/1 | Batch 0/7 | Loss: 38.6998
Epoch 1/1 | Batch 6/7 | Loss: 39.6110
üìà Epoch 1/1, Average Loss: 38.5508
‚úÖ Ce

In [14]:
# Extract and Save Final Weights and Biases
print("=== EXTRACTING FINAL WEIGHTS AND BIASES ===")

def extract_and_save_weights(backbone, head, model_name):
    if backbone is None or head is None:
        print(f" {model_name} model not available for weight extraction")
        return None

    weights_dict = {}

    print(f"\n Extracting weights from {model_name}:")

    # Extract backbone weights (ResNet50)
    print("  üîç Backbone weights:")
    for name, param in backbone.named_parameters():
        if param.requires_grad:
            if 'weight' in name:
                weights_dict[f"{model_name}_{name}"] = param.data.cpu()
                print(f"     {name}: {param.shape}")
            elif 'bias' in name:
                weights_dict[f"{model_name}_{name}"] = param.data.cpu()
                print(f"     {name}: {param.shape}")

    # Extract ArcFace head weights (most important!)
    print("   ArcFace head weights:")
    if hasattr(head, 'weight'):
        arcface_weights = head.weight.data.cpu()
        weights_dict[f"{model_name}_arcface_weights"] = arcface_weights
        print(f"     ArcFace weights: {arcface_weights.shape}")

        # Also get the ArcFace parameters
        weights_dict[f"{model_name}_arcface_s"] = torch.tensor(head.s)
        weights_dict[f"{model_name}_arcface_m"] = torch.tensor(head.m)
        print(f"     ArcFace parameters - s: {head.s}, m: {head.m}")

    # Save to file
    filename = f'/content/{model_name.lower()}_final_weights.pth'
    torch.save(weights_dict, filename)
    print(f"     Saved to: {filename}")

    return weights_dict, filename

print(" Extracting and saving model weights...")

# Extract weights from both models
lfw_weights, lfw_filename = extract_and_save_weights(lfw_backbone, lfw_head, "LFW")
celeba_weights, celeba_filename = extract_and_save_weights(celeba_backbone, celeba_head, "CelebA")

print("\n FINAL WEIGHTS SUMMARY:")
print("="*50)

if lfw_weights:
    print("LFW Model - Critical Weights:")
    critical_keys = [k for k in lfw_weights.keys() if 'fc.weight' in k or 'arcface_weights' in k]
    for key in critical_keys:
        tensor = lfw_weights[key]
        print(f"   {key}:")
        print(f"     Shape: {tensor.shape}")
        print(f"     Mean: {tensor.mean():.6f}")
        print(f"     Std:  {tensor.std():.6f}")

if celeba_weights:
    print("\nCelebA Model - Critical Weights:")
    critical_keys = [k for k in celeba_weights.keys() if 'fc.weight' in k or 'arcface_weights' in k]
    for key in critical_keys:
        tensor = celeba_weights[key]
        print(f"   {key}:")
        print(f"     Shape: {tensor.shape}")
        print(f"     Mean: {tensor.mean():.6f}")
        print(f"     Std:  {tensor.std():.6f}")

print("\n All weights extracted and saved!")

=== EXTRACTING FINAL WEIGHTS AND BIASES ===
üíæ Extracting and saving model weights...

üìä Extracting weights from LFW:
  üîç Backbone weights:
    ‚úÖ conv1.weight: torch.Size([64, 3, 7, 7])
    ‚úÖ bn1.weight: torch.Size([64])
    ‚úÖ bn1.bias: torch.Size([64])
    ‚úÖ layer1.0.conv1.weight: torch.Size([64, 64, 1, 1])
    ‚úÖ layer1.0.bn1.weight: torch.Size([64])
    ‚úÖ layer1.0.bn1.bias: torch.Size([64])
    ‚úÖ layer1.0.conv2.weight: torch.Size([64, 64, 3, 3])
    ‚úÖ layer1.0.bn2.weight: torch.Size([64])
    ‚úÖ layer1.0.bn2.bias: torch.Size([64])
    ‚úÖ layer1.0.conv3.weight: torch.Size([256, 64, 1, 1])
    ‚úÖ layer1.0.bn3.weight: torch.Size([256])
    ‚úÖ layer1.0.bn3.bias: torch.Size([256])
    ‚úÖ layer1.0.downsample.0.weight: torch.Size([256, 64, 1, 1])
    ‚úÖ layer1.0.downsample.1.weight: torch.Size([256])
    ‚úÖ layer1.0.downsample.1.bias: torch.Size([256])
    ‚úÖ layer1.1.conv1.weight: torch.Size([64, 256, 1, 1])
    ‚úÖ layer1.1.bn1.weight: torch.Size([64])
    

In [15]:
# Download Final Weights to Your Computer
print("=== DOWNLOADING WEIGHTS TO YOUR COMPUTER ===")

def download_with_timestamp(filepath, description):
    if os.path.exists(filepath):
        # Add timestamp to filename
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = os.path.basename(filepath)
        name, ext = os.path.splitext(filename)
        new_filename = f"{name}_{timestamp}{ext}"

        # Copy with new name and download
        !cp "{filepath}" "/content/{new_filename}"
        files.download(f"/content/{new_filename}")
        print(f" {description}: {new_filename}")
        return True
    else:
        print(f" File not found: {filepath}")
        return False

print("\n Downloading final weight files...")
print("Check your browser downloads folder for these files:")

lfw_success = download_with_timestamp('/content/lfw_final_weights.pth', 'LFW Weights')
celeba_success = download_with_timestamp('/content/celeba_final_weights.pth', 'CelebA Weights')

print("\n" + "="*60)
print(" FINAL WEIGHTS AND BIASES DOWNLOAD COMPLETE!")
print("="*60)

if lfw_success:
    print(" LFW Model Weights Downloaded - Contains:")
    print("   ‚Ä¢ ResNet50 backbone weights and biases")
    print("   ‚Ä¢ Final FC layer weights (512-D embeddings)")
    print("   ‚Ä¢ ArcFace classification weights")
    print("   ‚Ä¢ All trained parameters")

if celeba_success:
    print(" CelebA Model Weights Downloaded - Contains:")
    print("   ‚Ä¢ ResNet50 backbone weights and biases")
    print("   ‚Ä¢ Final FC layer weights (512-D embeddings)")
    print("   ‚Ä¢ ArcFace classification weights")
    print("   ‚Ä¢ All trained parameters")

print("\n Most Important Weights for Face Recognition:")
print("   ‚Ä¢ 'fc.weight' - Converts images to 512-D face embeddings")
print("   ‚Ä¢ 'arcface_weights' - Classifies embeddings to identities")

print("\n You can now use these weights for:")
print("   ‚Ä¢ Face verification")
print("   ‚Ä¢ Face identification")
print("   ‚Ä¢ Transfer learning")
print("   ‚Ä¢ Feature extraction")

print("\n Training Complete! Check your downloads folder for the .pth files!")

=== DOWNLOADING WEIGHTS TO YOUR COMPUTER ===

üì• Downloading final weight files...
Check your browser downloads folder for these files:


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

üì• LFW Weights: lfw_final_weights_20251128_142634.pth


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

üì• CelebA Weights: celeba_final_weights_20251128_142634.pth

üéâ FINAL WEIGHTS AND BIASES DOWNLOAD COMPLETE!
‚úÖ LFW Model Weights Downloaded - Contains:
   ‚Ä¢ ResNet50 backbone weights and biases
   ‚Ä¢ Final FC layer weights (512-D embeddings)
   ‚Ä¢ ArcFace classification weights
   ‚Ä¢ All trained parameters
‚úÖ CelebA Model Weights Downloaded - Contains:
   ‚Ä¢ ResNet50 backbone weights and biases
   ‚Ä¢ Final FC layer weights (512-D embeddings)
   ‚Ä¢ ArcFace classification weights
   ‚Ä¢ All trained parameters

üîë Most Important Weights for Face Recognition:
   ‚Ä¢ 'fc.weight' - Converts images to 512-D face embeddings
   ‚Ä¢ 'arcface_weights' - Classifies embeddings to identities

üöÄ You can now use these weights for:
   ‚Ä¢ Face verification
   ‚Ä¢ Face identification
   ‚Ä¢ Transfer learning
   ‚Ä¢ Feature extraction

‚≠ê Training Complete! Check your downloads folder for the .pth files!


In [16]:
# === EXTRA: 512-D FACE EMBEDDING EXTRACTION (INFERENCE MODE) ===

print("=== DEFINING EMBEDDING EXTRACTION FUNCTION ===")

def generate_arcface_embeddings(backbone, image_root, model_name="Model",
                                limit_images=None, output_prefix=None):
    """
    Uses the trained backbone (ResNet-50 + 512-D FC) to:
      - Load all images from image_root
      - Resize + normalize (same as training)
      - Extract 512-D L2-normalized embeddings
      - Save to a .npz file: filenames + embeddings
    """
    import glob  # just in case

    if backbone is None:
        print(f" {model_name} backbone is None. Train the model first.")
        return None

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    backbone = backbone.to(device)
    backbone.eval()

    print("\n" + "="*60)
    print(f" GENERATING 512-D EMBEDDINGS FOR: {model_name}")
    print("="*60)
    print(f" Image root: {image_root}")
    print(f" Device: {device}")

    # Same transforms as training
    transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5],
                             [0.5, 0.5, 0.5])
    ])

    # Collect image files (jpg / jpeg / png)
    image_files = []
    for ext in ["*.jpg", "*.jpeg", "*.png"]:
        pattern = os.path.join(image_root, "**", ext)
        image_files.extend(glob.glob(pattern, recursive=True))

    image_files = sorted(image_files)

    if not image_files:
        print(" No images found for embedding extraction.")
        return None

    if limit_images is not None:
        image_files = image_files[:limit_images]
        print(f"‚ö† Limiting to first {len(image_files)} images for speed.")

    print(f" Total images to process: {len(image_files)}")

    embeddings = []
    filenames = []

    with torch.no_grad():
        for idx, img_path in enumerate(image_files):
            try:
                img = Image.open(img_path).convert("RGB")
                img_t = transform(img).unsqueeze(0).to(device)

                # Forward through backbone -> (1, 512)
                feat = backbone(img_t)

                # L2 normalize (common for ArcFace embeddings)
                feat = nn.functional.normalize(feat, p=2, dim=1)

                embeddings.append(feat.squeeze(0).cpu().numpy())
                filenames.append(os.path.basename(img_path))

                # Progress log
                if (idx + 1) % 25 == 0 or (idx + 1) == len(image_files):
                    print(f"   ‚úî Processed {idx+1}/{len(image_files)} images")
            except Exception as e:
                print(f"    Error processing {img_path}: {e}")

    if len(embeddings) == 0:
        print(" No embeddings were generated.")
        return None

    embeddings = np.stack(embeddings, axis=0)

    if output_prefix is None:
        output_prefix = model_name.lower()

    out_path = f"/content/{output_prefix}_arcface_512d_embeddings.npz"

    # Save filenames + embeddings
    np.savez(out_path,
             filenames=np.array(filenames),
             embeddings=embeddings)

    print("\n EMBEDDINGS SAVED!")
    print(f"   ‚Ä¢ File: {out_path}")
    print(f"   ‚Ä¢ Num embeddings: {embeddings.shape[0]}")
    print(f"   ‚Ä¢ Embedding dimension: {embeddings.shape[1]}")
    return out_path

print(" Embedding extraction function defined!")


=== DEFINING EMBEDDING EXTRACTION FUNCTION ===
‚úÖ Embedding extraction function defined!


In [17]:
# === EXTRA: RUN EMBEDDING EXTRACTION + DOWNLOAD FILES ===

print("=== STARTING 512-D EMBEDDING EXTRACTION FOR BOTH DATASETS ===")

lfw_embeddings_file = generate_arcface_embeddings(
    backbone=lfw_backbone,
    image_root=LFW_ROOT,
    model_name="LFW",
    limit_images=None,
    output_prefix="lfw"
)

celeba_embeddings_file = generate_arcface_embeddings(
    backbone=celeba_backbone,
    image_root=CELEBA_IMG_ROOT,
    model_name="CelebA",
    limit_images=None,
    output_prefix="celeba"
)

print("\n" + "="*60)
print("=== DOWNLOADING EMBEDDING FILES TO YOUR COMPUTER ===")
print("="*60)

if lfw_embeddings_file is not None:
    download_with_timestamp(lfw_embeddings_file, "LFW 512-D embeddings")

if celeba_embeddings_file is not None:
    download_with_timestamp(celeba_embeddings_file, "CelebA 512-D embeddings")

print("\n DONE! You now have:")
print("   ‚Ä¢ lfw_arcface_512d_embeddings_YYYYMMDD_HHMMSS.npz")
print("   ‚Ä¢ celeba_arcface_512d_embeddings_YYYYMMDD_HHMMSS.npz")
print("Each file contains:")
print("   ‚Ä¢ 'filenames'  ‚Üí image file names")
print("   ‚Ä¢ 'embeddings' ‚Üí 512-D ArcFace feature vectors")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
   ‚úî Processed 77850/202599 images
   ‚úî Processed 77875/202599 images
   ‚úî Processed 77900/202599 images
   ‚úî Processed 77925/202599 images
   ‚úî Processed 77950/202599 images
   ‚úî Processed 77975/202599 images
   ‚úî Processed 78000/202599 images
   ‚úî Processed 78025/202599 images
   ‚úî Processed 78050/202599 images
   ‚úî Processed 78075/202599 images
   ‚úî Processed 78100/202599 images
   ‚úî Processed 78125/202599 images
   ‚úî Processed 78150/202599 images
   ‚úî Processed 78175/202599 images
   ‚úî Processed 78200/202599 images
   ‚úî Processed 78225/202599 images
   ‚úî Processed 78250/202599 images
   ‚úî Processed 78275/202599 images
   ‚úî Processed 78300/202599 images
   ‚úî Processed 78325/202599 images
   ‚úî Processed 78350/202599 images
   ‚úî Processed 78375/202599 images
   ‚úî Processed 78400/202599 images
   ‚úî Processed 78425/202599 images
   ‚úî Processed 78450/202599 images
   ‚úî Pro

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

üì• LFW 512-D embeddings: lfw_arcface_512d_embeddings_20251128_192651.npz


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

üì• CelebA 512-D embeddings: celeba_arcface_512d_embeddings_20251128_192651.npz

üéâ DONE! You now have:
   ‚Ä¢ lfw_arcface_512d_embeddings_YYYYMMDD_HHMMSS.npz
   ‚Ä¢ celeba_arcface_512d_embeddings_YYYYMMDD_HHMMSS.npz
Each file contains:
   ‚Ä¢ 'filenames'  ‚Üí image file names
   ‚Ä¢ 'embeddings' ‚Üí 512-D ArcFace feature vectors


In [18]:
# === EXTRA: 512-D FACE EMBEDDING EXTRACTION (INFERENCE MODE) ===

print("=== DEFINING EMBEDDING EXTRACTION FUNCTION ===")

def generate_arcface_embeddings(backbone, image_root, model_name="Model",
                                limit_images=None, output_prefix=None):
    """
    Uses the trained backbone (ResNet-50 + 512-D FC) to:
      - Load all images from image_root
      - Resize + normalize (same as training)
      - Extract 512-D L2-normalized embeddings
      - Save to a .npz file: filenames + embeddings
    """
    import glob  # just in case

    if backbone is None:
        print(f" {model_name} backbone is None. Train the model first.")
        return None

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    backbone = backbone.to(device)
    backbone.eval()

    print("\n" + "="*60)
    print(f" GENERATING 512-D EMBEDDINGS FOR: {model_name}")
    print("="*60)
    print(f" Image root: {image_root}")
    print(f" Device: {device}")

    # Same transforms as training
    transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5],
                             [0.5, 0.5, 0.5])
    ])

    # Collect image files (jpg / jpeg / png)
    image_files = []
    for ext in ["*.jpg", "*.jpeg", "*.png"]:
        pattern = os.path.join(image_root, "**", ext)
        image_files.extend(glob.glob(pattern, recursive=True))

    image_files = sorted(image_files)

    if not image_files:
        print(" No images found for embedding extraction.")
        return None

    if limit_images is not None:
        image_files = image_files[:limit_images]
        print(f"‚ö† Limiting to first {len(image_files)} images for speed.")

    print(f" Total images to process: {len(image_files)}")

    embeddings = []
    filenames = []

    with torch.no_grad():
        for idx, img_path in enumerate(image_files):
            try:
                img = Image.open(img_path).convert("RGB")
                img_t = transform(img).unsqueeze(0).to(device)

                # Forward through backbone -> (1, 512)
                feat = backbone(img_t)

                # L2 normalize (common for ArcFace embeddings)
                feat = nn.functional.normalize(feat, p=2, dim=1)

                embeddings.append(feat.squeeze(0).cpu().numpy())
                filenames.append(os.path.basename(img_path))

                # Progress log
                if (idx + 1) % 25 == 0 or (idx + 1) == len(image_files):
                    print(f"   ‚úî Processed {idx+1}/{len(image_files)} images")
            except Exception as e:
                print(f"   ‚ö† Error processing {img_path}: {e}")

    if len(embeddings) == 0:
        print(" No embeddings were generated.")
        return None

    embeddings = np.stack(embeddings, axis=0)

    if output_prefix is None:
        output_prefix = model_name.lower()

    out_path = f"/content/{output_prefix}_arcface_512d_embeddings.npz"

    # Save filenames + embeddings
    np.savez(out_path,
             filenames=np.array(filenames),
             embeddings=embeddings)

    print("\n EMBEDDINGS SAVED!")
    print(f"   ‚Ä¢ File: {out_path}")
    print(f"   ‚Ä¢ Num embeddings: {embeddings.shape[0]}")
    print(f"   ‚Ä¢ Embedding dimension: {embeddings.shape[1]}")
    return out_path

print(" Embedding extraction function defined!")


=== DEFINING EMBEDDING EXTRACTION FUNCTION ===
 Embedding extraction function defined!


In [19]:
# === EXTRA: RUN EMBEDDING EXTRACTION + DOWNLOAD FILES ===

print("=== STARTING 512-D EMBEDDING EXTRACTION FOR BOTH DATASETS ===")

lfw_embeddings_file = generate_arcface_embeddings(
    backbone=lfw_backbone,
    image_root=LFW_ROOT,
    model_name="LFW",
    limit_images=None,          # put 100 if you want it faster
    output_prefix="lfw"
)

celeba_embeddings_file = generate_arcface_embeddings(
    backbone=celeba_backbone,
    image_root=CELEBA_IMG_ROOT,
    model_name="CelebA",
    limit_images=None,          # e.g. 200 for speed
    output_prefix="celeba"
)

print("\n" + "="*60)
print("=== DOWNLOADING EMBEDDING FILES TO YOUR COMPUTER ===")
print("="*60)

if lfw_embeddings_file is not None:
    download_with_timestamp(lfw_embeddings_file, "LFW 512-D embeddings")

if celeba_embeddings_file is not None:
    download_with_timestamp(celeba_embeddings_file, "CelebA 512-D embeddings")

print("\n DONE! You now have:")
print("   ‚Ä¢ lfw_arcface_512d_embeddings_YYYYMMDD_HHMMSS.npz")
print("   ‚Ä¢ celeba_arcface_512d_embeddings_YYYYMMDD_HHMMSS.npz")
print("Each file contains:")
print("   ‚Ä¢ 'filenames'  ‚Üí image file names")
print("   ‚Ä¢ 'embeddings' ‚Üí 512-D ArcFace feature vectors")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
   ‚úî Processed 77850/202599 images
   ‚úî Processed 77875/202599 images
   ‚úî Processed 77900/202599 images
   ‚úî Processed 77925/202599 images
   ‚úî Processed 77950/202599 images
   ‚úî Processed 77975/202599 images
   ‚úî Processed 78000/202599 images
   ‚úî Processed 78025/202599 images
   ‚úî Processed 78050/202599 images
   ‚úî Processed 78075/202599 images
   ‚úî Processed 78100/202599 images
   ‚úî Processed 78125/202599 images
   ‚úî Processed 78150/202599 images
   ‚úî Processed 78175/202599 images
   ‚úî Processed 78200/202599 images
   ‚úî Processed 78225/202599 images
   ‚úî Processed 78250/202599 images
   ‚úî Processed 78275/202599 images
   ‚úî Processed 78300/202599 images
   ‚úî Processed 78325/202599 images
   ‚úî Processed 78350/202599 images
   ‚úî Processed 78375/202599 images
   ‚úî Processed 78400/202599 images
   ‚úî Processed 78425/202599 images
   ‚úî Processed 78450/202599 images
   ‚úî Pro

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

üì• LFW 512-D embeddings: lfw_arcface_512d_embeddings_20251129_001200.npz


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

üì• CelebA 512-D embeddings: celeba_arcface_512d_embeddings_20251129_001200.npz

üéâ DONE! You now have:
   ‚Ä¢ lfw_arcface_512d_embeddings_YYYYMMDD_HHMMSS.npz
   ‚Ä¢ celeba_arcface_512d_embeddings_YYYYMMDD_HHMMSS.npz
Each file contains:
   ‚Ä¢ 'filenames'  ‚Üí image file names
   ‚Ä¢ 'embeddings' ‚Üí 512-D ArcFace feature vectors
