In [None]:
# ============================================================================
# INSTALLS & IMPORTS
# ============================================================================

!pip install torch torchvision --quiet

import os
from typing import Dict, List, Tuple

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm.auto import tqdm

# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)


Using device: cuda


In [None]:
# ============================================================================
# 2. CONFIGURATION (PATHS & HYPERPARAMS)
# ============================================================================

# One dataset folder per federated client.
CELEBA_ROOT, VGGFACE2_ROOT, GENDERFACE_ROOT, LFW_ROOT = (
    "/content/data/celeba",
    "/content/data/vggface2",
    "/content/data/genderface",
    "/content/data/lfw",
)

# Client name -> dataset folder.
DATA_ROOTS = dict(
    celeba=CELEBA_ROOT,
    vggface2=VGGFACE2_ROOT,
    genderface=GENDERFACE_ROOT,
    lfw=LFW_ROOT,
)

for name, p in DATA_ROOTS.items():
    print(f"{name:10s} -> {p}")

# Federated training hyperparameters
NUM_ROUNDS, LOCAL_EPOCHS = 3, 1   # federated rounds / local epochs per round
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
NUM_WORKERS = 2


celeba     -> /content/data/celeba
vggface2   -> /content/data/vggface2
genderface -> /content/data/genderface
lfw        -> /content/data/lfw


In [None]:
# ============================================================================
# DOWNLOAD & PREPARE ALL FOUR DATASETS FOR FEDERATED TRAINING
# ============================================================================

import os, json, shutil, gc, zipfile, random
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm


!pip install -q kaggle gender-guesser

import gender_guesser.detector as gender

# --------------------------------------------------------------------------
# 0. KAGGLE CREDENTIALS (READ FROM THE ENVIRONMENT - NEVER HARD-CODE THEM)
# --------------------------------------------------------------------------
# SECURITY: a Kaggle API key used to be written here in plain text. Treat that
# key as leaked and revoke it in the Kaggle account settings.
#
# Provide credentials in one of two ways before running this cell:
#   * set the KAGGLE_USERNAME and KAGGLE_KEY environment variables, or
#   * place an existing kaggle.json in ~/.kaggle/.

kaggle_dir = Path.home() / ".kaggle"
kaggle_json_path = kaggle_dir / "kaggle.json"

kaggle_username = os.environ.get("KAGGLE_USERNAME")
kaggle_key = os.environ.get("KAGGLE_KEY")

if kaggle_username and kaggle_key:
    # Materialise kaggle.json from the environment so the `!kaggle` calls below find it.
    kaggle_dir.mkdir(exist_ok=True)
    with open(kaggle_json_path, "w") as f:
        json.dump({"username": kaggle_username, "key": kaggle_key}, f)
    os.chmod(kaggle_json_path, 0o600)
elif kaggle_json_path.exists():
    # Reuse the existing file; just make sure only the owner can read it.
    os.chmod(kaggle_json_path, 0o600)
else:
    # Fail here with a clear message instead of letting every download below fail.
    raise RuntimeError(
        "Kaggle credentials not found. Set the KAGGLE_USERNAME and KAGGLE_KEY "
        f"environment variables or create {kaggle_json_path}."
    )
print("✓ Kaggle credentials configured")

# All prepared datasets live in their own folder below this base directory.
BASE_DATA_DIR = Path("/content/data")
BASE_DATA_DIR.mkdir(parents=True, exist_ok=True)

CELEBA_ROOT, VGGFACE2_ROOT, GENDERFACE_ROOT, LFW_ROOT = (
    BASE_DATA_DIR / "celeba",
    BASE_DATA_DIR / "vggface2",
    BASE_DATA_DIR / "genderface",
    BASE_DATA_DIR / "lfw",
)

# Single print call; the emitted lines are the same as one print per root.
print(
    "Data roots:",
    f"  CELEBA_ROOT    : {CELEBA_ROOT}",
    f"  VGGFACE2_ROOT  : {VGGFACE2_ROOT}",
    f"  GENDERFACE_ROOT: {GENDERFACE_ROOT}",
    f"  LFW_ROOT       : {LFW_ROOT}",
    sep="\n",
)


# ============================================================================
# 1. CELEBA (Kaggle: jessicali9530/celeba-dataset)
# ============================================================================

SAMPLE_SIZE = 2000

print("\n" + "="*70)
print("CELEBA: DOWNLOAD + SAMPLE + TRAIN/VAL/TEST SPLIT")
print("="*70 + "\n")

# Download
!kaggle datasets download -d jessicali9530/celeba-dataset --force

# Extract everything into celeba_raw/
!mkdir -p celeba_raw
!unzip -q celeba-dataset.zip -d celeba_raw

def extract_sampled_celeba(raw_dir: str, sample_size: int = None) -> Path:
    """Copy a gender-labelled CelebA subset into ``celeba_sampled/<split>/<gender>/``.

    Reads ``list_eval_partition.csv`` and ``list_attr_celeba.csv`` from
    ``<raw_dir>/celeba_raw``, joins them on ``image_id`` and tags every image
    with its split (partition 0/1/2 -> train/val/test) and its gender
    (``Male == 1`` -> male, anything else -> female).

    Args:
        raw_dir: Directory that contains the ``celeba_raw`` folder.
        sample_size: When given, each gender keeps at most ``sample_size // 2``
            train images and at most ``int(sample_size * 0.15)`` images for
            each of val and test, drawn with ``random_state=42``.
            ``None`` keeps every image.

    Returns:
        The relative path ``celeba_sampled``.

    Rows whose image file does not exist are skipped and not counted.
    """
    split_names = ("train", "val", "test")
    gender_names = ("male", "female")

    output_dir = Path("celeba_sampled")
    for split_name in split_names:
        for gender_name in gender_names:
            (output_dir / split_name / gender_name).mkdir(exist_ok=True, parents=True)

    raw_root = Path(raw_dir) / "celeba_raw"
    partitions = pd.read_csv(raw_root / "list_eval_partition.csv")
    attributes = pd.read_csv(raw_root / "list_attr_celeba.csv")
    df = partitions.merge(attributes[["image_id", "Male"]], on="image_id")
    df["split"] = df["partition"].map({0: "train", 1: "val", 2: "test"})
    df["gender"] = df["Male"].map(lambda flag: "male" if flag == 1 else "female")

    img_root = raw_root / "img_align_celeba" / "img_align_celeba"

    # One frame per (split, gender); insertion order fixes the copy order below.
    groups = {
        (split_name, gender_name): df[
            (df["split"] == split_name) & (df["gender"] == gender_name)
        ].reset_index(drop=True)
        for split_name in split_names
        for gender_name in gender_names
    }

    n_male_train = len(groups[("train", "male")])
    n_female_train = len(groups[("train", "female")])
    print(f"Found {n_male_train} male train images")
    print(f"Found {n_female_train} female train images")

    if sample_size is not None:
        # Per-gender quota of each split.
        quotas = {
            "train": sample_size // 2,
            "val": int(sample_size * 0.15),
            "test": int(sample_size * 0.15),
        }
        groups = {
            key: frame.sample(n=min(quotas[key[0]], len(frame)), random_state=42)
            for key, frame in groups.items()
        }
        print(f"Sampling ~{sample_size} images total")

    samples = pd.concat(list(groups.values()))

    counts = {s: {g: 0 for g in gender_names} for s in split_names}

    records = samples[["image_id", "split", "gender"]].itertuples(index=False, name=None)
    for img_id, split, gender_name in tqdm(records, total=len(samples), desc="Extracting CelebA images"):
        src = img_root / img_id
        if src.exists():
            shutil.copy2(src, output_dir / split / gender_name / img_id)
            counts[split][gender_name] += 1

    print(f"\n✓ CelebA extraction complete:")
    for label, split in (("Train: ", "train"), ("Val:   ", "val"), ("Test:  ", "test")):
        print(f"  - {label}{counts[split]['male']} male, {counts[split]['female']} female")

    gc.collect()
    return output_dir

# Build the sampled CelebA tree; "." is the directory that holds celeba_raw/.
celeba_sampled_dir = extract_sampled_celeba(".", sample_size=SAMPLE_SIZE)

# Move to CELEBA_ROOT
# A previous CELEBA_ROOT is deleted first so the move never lands inside stale data.
if CELEBA_ROOT.exists():
    shutil.rmtree(CELEBA_ROOT)
shutil.move(str(celeba_sampled_dir), CELEBA_ROOT)
print("✓ CELEBA ready at", CELEBA_ROOT)


# ============================================================================
# 2. LFW (Kaggle: jessicali9530/lfw-dataset)
# ============================================================================

print("\n" + "="*70)
print("LFW: DOWNLOAD + NAME-BASED GENDER + TRAIN/VAL/TEST SPLIT")
print("="*70 + "\n")

# Download and extract
!kaggle datasets download -d jessicali9530/lfw-dataset --force
!unzip -q lfw-dataset.zip -d lfw_data
print("✓ LFW dataset downloaded and extracted")

def get_gender_from_name(name: str):
    """Predict gender from the first name using gender_guesser.

    Args:
        name: Identifier whose first ``_``-separated token is the first name
            (for example an LFW folder name such as ``"George_W_Bush"``).

    Returns:
        ``"male"`` when the detector answers ``male``/``mostly_male``,
        ``"female"`` for ``female``/``mostly_female``, otherwise ``None``.
    """
    # Build the detector once and reuse it: this function is called once per
    # person folder, and constructing a new Detector every time is wasted work.
    detector = getattr(get_gender_from_name, "_detector", None)
    if detector is None:
        detector = gender.Detector()
        get_gender_from_name._detector = detector

    first_name = name.split("_")[0]
    result = detector.get_gender(first_name)

    if result in ("male", "mostly_male"):
        return "male"
    if result in ("female", "mostly_female"):
        return "female"
    return None

def organize_lfw():
    """Copy every LFW image into ``lfw_organized/{male,female,unknown}``.

    Each sub-folder of ``lfw_data/lfw-deepfunneled/lfw-deepfunneled`` is one
    person. The folder name is passed to ``get_gender_from_name`` and all
    ``*.jpg`` files of that person are copied, prefixed with the person name,
    into the matching bucket (``unknown`` when no gender is predicted).

    Returns:
        The relative path ``lfw_organized``.
    """
    source_dir = Path("lfw_data/lfw-deepfunneled/lfw-deepfunneled")
    dest_dir = Path("lfw_organized")
    dest_dir.mkdir(exist_ok=True)

    bucket_dirs = {label: dest_dir / label for label in ("male", "female", "unknown")}
    for bucket_dir in bucket_dirs.values():
        bucket_dir.mkdir(exist_ok=True)
    image_counts = dict.fromkeys(bucket_dirs, 0)

    print("Processing LFW images and predicting gender from names...")
    person_folders = list(source_dir.iterdir())
    for person_folder in tqdm(person_folders, desc="LFW folders"):
        if person_folder.is_dir():
            person_name = person_folder.name
            gender_pred = get_gender_from_name(person_name)

            # None -> unknown, "male" -> male, any other prediction -> female.
            if gender_pred is None:
                bucket = "unknown"
            else:
                bucket = "male" if gender_pred == "male" else "female"

            imgs = list(person_folder.glob("*.jpg"))
            image_counts[bucket] += len(imgs)

            for img_file in imgs:
                shutil.copy2(img_file, bucket_dirs[bucket] / f"{person_name}_{img_file.name}")

    print(f"\n✓ LFW organized:")
    print(f"  - Male images:    {image_counts['male']}")
    print(f"  - Female images:  {image_counts['female']}")
    print(f"  - Unknown gender: {image_counts['unknown']} (in 'unknown/')")

    return dest_dir

# Sort the extracted LFW images into lfw_organized/{male,female,unknown}.
lfw_organized_dir = organize_lfw()

def create_splits_lfw(data_dir: Path, out_dir: Path, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Split ``data_dir/{male,female}/*.jpg`` into ``out_dir/<split>/<gender>/``.

    ``out_dir`` is deleted and recreated. For each gender the images are split
    with ``train_test_split`` (``random_state=42``): first train vs. the rest,
    then the rest into val and test in proportion to their ratios. Files are
    copied, not moved. A gender without images only produces a warning.

    Args:
        data_dir: Folder holding ``male/`` and ``female/`` sub-folders.
        out_dir: Destination root; any existing content is removed.
        train_ratio, val_ratio, test_ratio: Fractions that must sum to 1.
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6

    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True)

    genders = ("male", "female")
    for split_name in ("train", "val", "test"):
        for gender_name in genders:
            (out_dir / split_name / gender_name).mkdir(parents=True, exist_ok=True)

    for gender_name in genders:
        images = list((data_dir / gender_name).glob("*.jpg"))
        if not images:
            print(f"⚠ No LFW images found for {gender_name}")
            continue

        train_imgs, holdout_imgs = train_test_split(images, train_size=train_ratio, random_state=42)
        # Share of the hold-out set that becomes validation data.
        val_share = val_ratio / (val_ratio + test_ratio)
        val_imgs, test_imgs = train_test_split(holdout_imgs, train_size=val_share, random_state=42)

        partitions = {"train": train_imgs, "val": val_imgs, "test": test_imgs}
        for split_name, members in partitions.items():
            destination = out_dir / split_name / gender_name
            for img in members:
                shutil.copy2(img, destination / img.name)

        print(f"LFW {gender_name.capitalize()}: {len(train_imgs)} train, {len(val_imgs)} val, {len(test_imgs)} test")

# Write the LFW train/val/test tree straight into LFW_ROOT
# (create_splits_lfw deletes and recreates the destination first).
lfw_splits_dir = LFW_ROOT
create_splits_lfw(lfw_organized_dir, lfw_splits_dir)
print("✓ LFW ready at", LFW_ROOT)


# ============================================================================
# 3. VGGFACE2 (Kaggle: hearfool/vggface2) + sampling with mapping
# ============================================================================

print("\n" + "="*70)
print("VGGFACE2: DOWNLOAD + GENDER MAPPING + SAMPLING + SPLIT")
print("="*70 + "\n")

SAMPLE_SIZE = 2000  # number of VGGFace2 images to sample (passed to extract_sampled_vggface2)

# The archive is not unzipped here; extract_sampled_vggface2 reads vggface2.zip directly.
!kaggle datasets download -d hearfool/vggface2 --force

# Gender labels for VGGFace2, read by extract_sampled_vggface2.
# NOTE(review): nothing in the visible part of this notebook creates this file —
# presumably it must be placed in /content by hand before running the cell; confirm.
gender_mapping_path = Path("/content/gender_mapping.txt")

def _load_vggface2_gender_ids(mapping_path):
    """Parse the gender mapping file into ``(male_ids, female_ids)`` sets of person IDs.

    Each non-empty line is ``<image path><whitespace><label>`` where the first
    path component is the person ID and the label is 0 (female) or 1 (male).
    Lines that cannot be parsed are reported and skipped.
    """
    male_ids, female_ids = set(), set()
    with open(mapping_path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                # Split on the last run of whitespace (a tab in the data file).
                image_path_str, label_str = line.rsplit(None, 1)
                person_id = Path(image_path_str).parts[0]  # "n000002"
                label = int(label_str)
            except (ValueError, IndexError) as e:
                print(f"Warning: could not parse line: {line} ({e})")
                continue
            if label == 0:
                female_ids.add(person_id)
            elif label == 1:
                male_ids.add(person_id)
    return male_ids, female_ids


def _balanced_sample_counts(n_male, n_female, sample_size):
    """Return ``(target_male, target_female)``: an even split, topped up when one gender is short."""
    target_male = min(sample_size // 2, n_male)
    target_female = min(sample_size // 2, n_female)

    remaining = sample_size - (target_male + target_female)
    if remaining > 0:
        # Fill the gap from the male pool first, then from the female pool.
        extra_m = min(remaining, n_male - target_male)
        target_male += extra_m
        remaining -= extra_m
    if remaining > 0:
        target_female += min(remaining, n_female - target_female)
    return target_male, target_female


def extract_sampled_vggface2(zip_path, sample_size=None, gender_mapping_file_path=None, seed=42):
    """Extract a gender-labelled VGGFace2 sample into ``vggface2_sampled/{male,female}``.

    Only archive members under ``train/`` with a .jpg/.jpeg/.png suffix are
    considered. An image is labelled through its person folder
    (``train/<person_id>/...``) using the IDs found in the mapping file.

    Args:
        zip_path: Path of the VGGFace2 zip archive.
        sample_size: Total number of images to extract, split evenly between
            genders where possible. Falsy means "use every classified image".
        gender_mapping_file_path: Path (``str`` or ``Path``) of the gender
            mapping file. When missing, nothing is extracted.
        seed: Seed of the private random generator used for sampling, so that
            repeated runs pick the same images. ``None`` gives a fresh random
            sample on every call.

    Returns:
        The relative path ``vggface2_sampled`` (possibly without images when no
        mapping file or no matching image was found).
    """
    output_dir = Path("vggface2_sampled")
    male_dir = output_dir / "male"
    female_dir = output_dir / "female"
    male_dir.mkdir(exist_ok=True, parents=True)
    female_dir.mkdir(exist_ok=True, parents=True)

    print("\nVGGFace2: scanning and pre-classifying...")

    mapping_path = Path(gender_mapping_file_path) if gender_mapping_file_path else None
    if mapping_path is None or not mapping_path.exists():
        print("⚠ No gender mapping file found; cannot classify VGGFace2 by gender.")
        return output_dir

    print(f"Loading gender mappings from {mapping_path}...")
    male_ids, female_ids = _load_vggface2_gender_ids(mapping_path)
    print(f"  -> {len(male_ids)} male IDs, {len(female_ids)} female IDs")

    male_image_paths = []
    female_image_paths = []

    with zipfile.ZipFile(zip_path, "r") as zf:
        all_candidate_files = [
            f for f in zf.namelist()
            if f.lower().endswith((".jpg", ".jpeg", ".png")) and f.startswith("train/")
        ]

    for file_path in tqdm(all_candidate_files, desc="Pre-classifying VGGFace2"):
        parts = Path(file_path).parts
        if len(parts) < 3:
            continue
        person_id = parts[1]

        if person_id in male_ids:
            male_image_paths.append(file_path)
        elif person_id in female_ids:
            female_image_paths.append(file_path)

    print(f"Found {len(male_image_paths)} male images, {len(female_image_paths)} female images in archive.")

    sampled_files_with_gender = []

    if sample_size:
        # Private generator: reproducible and independent of the global random state.
        rng = random.Random(seed)
        target_male, target_female = _balanced_sample_counts(
            len(male_image_paths), len(female_image_paths), sample_size
        )

        if target_male > 0:
            sampled_files_with_gender.extend(
                (f, "male") for f in rng.sample(male_image_paths, target_male)
            )
        if target_female > 0:
            sampled_files_with_gender.extend(
                (f, "female") for f in rng.sample(female_image_paths, target_female)
            )

        rng.shuffle(sampled_files_with_gender)
        print(f"Sampling {len(sampled_files_with_gender)} VGGFace2 images "
              f"(male: {target_male}, female: {target_female}).")
    else:
        sampled_files_with_gender.extend((f, "male") for f in male_image_paths)
        sampled_files_with_gender.extend((f, "female") for f in female_image_paths)
        print(f"Using all classified images: {len(sampled_files_with_gender)}")

    if not sampled_files_with_gender:
        print("No VGGFace2 images to extract after sampling.")
        return output_dir

    # Drop files from an earlier run so samples do not accumulate across re-runs.
    for stale_dir in (male_dir, female_dir):
        for stale_file in stale_dir.iterdir():
            if stale_file.is_file():
                stale_file.unlink()

    male_extracted = female_extracted = 0

    with zipfile.ZipFile(zip_path, "r") as zf:
        for file_path, gender_label in tqdm(sampled_files_with_gender, desc="Extracting VGGFace2"):
            try:
                parts = Path(file_path).parts
                person_id = parts[1]
                filename = parts[-1]

                target_dir = male_dir if gender_label == "male" else female_dir
                # Prefix with the person ID so equal file names of different people cannot clash.
                target_file = target_dir / f"{person_id}_{filename}"

                with zf.open(file_path) as src, open(target_file, "wb") as dst:
                    dst.write(src.read())
            except (OSError, KeyError, zipfile.BadZipFile) as e:
                # Best effort: report the broken member and carry on with the rest.
                print(f"Error extracting {file_path}: {e}")
                continue

            if gender_label == "male":
                male_extracted += 1
            else:
                female_extracted += 1

    print(f"\n✓ VGGFace2 extracted: {male_extracted} male, {female_extracted} female")
    gc.collect()
    return output_dir

# Sample SAMPLE_SIZE images straight out of the downloaded archive into vggface2_sampled/.
vgg_sampled_dir = extract_sampled_vggface2("vggface2.zip", sample_size=SAMPLE_SIZE, gender_mapping_file_path=gender_mapping_path)

def create_splits_generic(src_dir: Path, out_dir: Path, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Split ``src_dir/{male,female}`` into ``out_dir/<split>/<gender>/``.

    ``out_dir`` is removed and rebuilt. Every entry matching ``*.*`` in a
    gender folder takes part. Per gender, ``train_test_split`` with
    ``random_state=42`` first separates the training set, then divides the
    remainder into val and test according to their ratios. Files are copied.

    Args:
        src_dir: Folder holding ``male/`` and ``female/`` sub-folders.
        out_dir: Destination root; any existing content is removed.
        train_ratio, val_ratio, test_ratio: Fractions that must sum to 1.
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6

    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True)

    split_order = ["train", "val", "test"]
    gender_order = ["male", "female"]

    for split_name in split_order:
        for gender_name in gender_order:
            (out_dir / split_name / gender_name).mkdir(parents=True, exist_ok=True)

    for gender_name in gender_order:
        images = list((src_dir / gender_name).glob("*.*"))
        if not images:
            print(f"⚠ No images for {gender_name} in {src_dir}")
            continue

        train_imgs, rest_imgs = train_test_split(images, train_size=train_ratio, random_state=42)
        val_imgs, test_imgs = train_test_split(
            rest_imgs,
            train_size=val_ratio / (val_ratio + test_ratio),
            random_state=42,
        )

        for split_name, members in zip(split_order, (train_imgs, val_imgs, test_imgs)):
            destination = out_dir / split_name / gender_name
            for img in members:
                shutil.copy2(img, destination / img.name)

        print(f"{src_dir.name} {gender_name}: {len(train_imgs)} train, {len(val_imgs)} val, {len(test_imgs)} test")

# Split the sampled images into VGGFACE2_ROOT/<split>/<gender>/
# (create_splits_generic deletes and recreates the destination first).
create_splits_generic(vgg_sampled_dir, VGGFACE2_ROOT)
print("✓ VGGFace2 ready at", VGGFACE2_ROOT)


# ============================================================================
# 4. BIGGEST GENDERFACE DATASET (Kaggle: maciejgronczynski/biggest-genderface-recognition-dataset)
# ============================================================================

print("\n" + "="*70)
print("GENDERFACE: DOWNLOAD + AUTO-DISCOVER GENDER FOLDERS + SPLIT")
print("="*70 + "\n")

!kaggle datasets download -d maciejgronczynski/biggest-genderface-recognition-dataset --force
!unzip -q biggest-genderface-recognition-dataset.zip -d genderface_raw
print("✓ Genderface dataset downloaded and extracted")

def find_gender_labeled_images(root: Path):
    """Walk ``root`` and label every image from the directories that contain it.

    Only the directory names *below* ``root`` are inspected (not the file name
    and not ``root`` or its parents), so an unrelated path component such as a
    parent folder containing "man" cannot label the whole dataset.

    A directory name containing "female", "woman" or "girl" marks the image as
    female; otherwise one containing "male", "man" or "boy" marks it as male.
    The female markers are tested first because "male"/"man" are substrings of
    "female"/"woman". Images with no matching directory are left out.

    Args:
        root: Dataset root folder (``str`` or ``Path``).

    Returns:
        List of ``(filepath, label)`` with ``label`` in ``{'male', 'female'}``,
        sorted by path so the result does not depend on filesystem order.
    """
    image_suffixes = {".jpg", ".jpeg", ".png"}
    female_markers = ("female", "woman", "girl")
    male_markers = ("male", "man", "boy")

    root = Path(root)
    files_with_labels = []
    for p in sorted(root.rglob("*.*")):
        # Exact suffix match (case-insensitive); ".jpgx" and the like are not images.
        if not p.is_file() or p.suffix.lower() not in image_suffixes:
            continue

        dir_parts = [part.lower() for part in p.relative_to(root).parts[:-1]]
        if any(marker in part for part in dir_parts for marker in female_markers):
            label = "female"
        elif any(marker in part for part in dir_parts for marker in male_markers):
            label = "male"
        else:
            continue
        files_with_labels.append((p, label))
    return files_with_labels

# Collect every labelled image below the extracted Genderface archive.
genderface_raw_root = Path("genderface_raw")
genderface_files = find_gender_labeled_images(genderface_raw_root)
print(f"Found {len(genderface_files)} labeled Genderface images.")

# Build DataFrame and split per gender
# Start from an empty GENDERFACE_ROOT so files from earlier runs cannot linger.
if GENDERFACE_ROOT.exists():
    shutil.rmtree(GENDERFACE_ROOT)
GENDERFACE_ROOT.mkdir(parents=True)

# Create the <split>/<gender> folder layout.
for split in ["train", "val", "test"]:
    for gender_name in ["male", "female"]:
        (GENDERFACE_ROOT / split / gender_name).mkdir(parents=True, exist_ok=True)

df_gf = pd.DataFrame(genderface_files, columns=["path", "label"])

# Split each gender separately: 70% train, then the remainder half/half into val and test.
for gender_name in ["male", "female"]:
    sub = df_gf[df_gf["label"] == gender_name]
    paths = sub["path"].tolist()
    if len(paths) == 0:
        print(f"⚠ No Genderface images for {gender_name}")
        continue

    train_paths, temp_paths = train_test_split(paths, train_size=0.7, random_state=42)
    val_paths, test_paths = train_test_split(temp_paths, train_size=0.5, random_state=42)

    for img_list, split in [(train_paths, "train"), (val_paths, "val"), (test_paths, "test")]:
        for img in img_list:
            # NOTE(review): only the bare file name is kept, so two source files with the
            # same name in different folders would overwrite each other here — confirm
            # that file names are unique across the dataset.
            target = GENDERFACE_ROOT / split / gender_name / img.name
            shutil.copy2(img, target)

    # These counts come from the path lists, not from the files actually written.
    print(f"Genderface {gender_name}: {len(train_paths)} train, {len(val_paths)} val, {len(test_paths)} test")

print("✓ Genderface ready at", GENDERFACE_ROOT)

gc.collect()
print("\nAll datasets are prepared under /content/data/")


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/379.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m379.3/379.3 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[?25h✓ Kaggle credentials configured
Data roots:
  CELEBA_ROOT    : /content/data/celeba
  VGGFACE2_ROOT  : /content/data/vggface2
  GENDERFACE_ROOT: /content/data/genderface
  LFW_ROOT       : /content/data/lfw

CELEBA: DOWNLOAD + SAMPLE + TRAIN/VAL/TEST SPLIT

Dataset URL: https://www.kaggle.com/datasets/jessicali9530/celeba-dataset
License(s): other
Downloading celeba-dataset.zip to /content
 99% 1.31G/1.33G [00:14<00:00, 176MB/s]
100% 1.33G/1.33G [00:14<00:00, 97.2MB/s]
Found 68261 male train images
Found 94509 female train images
Sampling ~2000 images total


Extracting CelebA images:   0%|          | 0/3200 [00:00<?, ?it/s]


✓ CelebA extraction complete:
  - Train: 1000 male, 1000 female
  - Val:   300 male, 300 female
  - Test:  300 male, 300 female
✓ CELEBA ready at /content/data/celeba

LFW: DOWNLOAD + NAME-BASED GENDER + TRAIN/VAL/TEST SPLIT

Dataset URL: https://www.kaggle.com/datasets/jessicali9530/lfw-dataset
License(s): other
Downloading lfw-dataset.zip to /content
  0% 0.00/112M [00:00<?, ?B/s]
100% 112M/112M [00:00<00:00, 1.42GB/s]
✓ LFW dataset downloaded and extracted
Processing LFW images and predicting gender from names...


LFW folders:   0%|          | 0/5749 [00:00<?, ?it/s]


✓ LFW organized:
  - Male images:    9256
  - Female images:  2818
  - Unknown gender: 1159 (in 'unknown/')
LFW Male: 6479 train, 1388 val, 1389 test
LFW Female: 1972 train, 423 val, 423 test
✓ LFW ready at /content/data/lfw

VGGFACE2: DOWNLOAD + GENDER MAPPING + SAMPLING + SPLIT

Dataset URL: https://www.kaggle.com/datasets/hearfool/vggface2
License(s): unknown
Downloading vggface2.zip to /content
 99% 2.30G/2.32G [00:35<00:00, 58.5MB/s]
100% 2.32G/2.32G [00:36<00:00, 68.9MB/s]

VGGFace2: scanning and pre-classifying...
Loading gender mappings from /content/gender_mapping.txt...
  -> 3567 male IDs, 2433 female IDs


Pre-classifying VGGFace2:   0%|          | 0/176398 [00:00<?, ?it/s]

Found 79324 male images, 69594 female images in archive.
Sampling 2000 VGGFace2 images (male: 1000, female: 1000).


Extracting VGGFace2:   0%|          | 0/2000 [00:00<?, ?it/s]


✓ VGGFace2 extracted: 1000 male, 1000 female
vggface2_sampled male: 700 train, 150 val, 150 test
vggface2_sampled female: 700 train, 150 val, 150 test
✓ VGGFace2 ready at /content/data/vggface2

GENDERFACE: DOWNLOAD + AUTO-DISCOVER GENDER FOLDERS + SPLIT

Dataset URL: https://www.kaggle.com/datasets/maciejgronczynski/biggest-genderface-recognition-dataset
License(s): CC0-1.0
Downloading biggest-genderface-recognition-dataset.zip to /content
 97% 424M/439M [00:01<00:00, 304MB/s]
100% 439M/439M [00:01<00:00, 339MB/s]
✓ Genderface dataset downloaded and extracted
Found 27167 labeled Genderface images.
Genderface male: 12374 train, 2652 val, 2652 test
Genderface female: 6642 train, 1423 val, 1424 test
✓ Genderface ready at /content/data/genderface

All datasets are prepared under /content/data/


In [None]:
# ============================================================================
# 3. DATA TRANSFORMS & DATALOADER FACTORY
# ============================================================================

# ImageNet channel statistics, used because the backbone is an ImageNet-pretrained ResNet18.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: resize, random crop + horizontal flip as augmentation, then normalise.
_train_steps = [
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
train_transform = transforms.Compose(_train_steps)

# Validation / test: deterministic resize + centre crop, then normalise.
_eval_steps = [
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
eval_transform = transforms.Compose(_eval_steps)


def make_dataloaders(root: str, batch_size: int = 32) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train/val/test loaders of one dataset folder.

    Expected layout::

        root/
          train/male, train/female
          val/male,   val/female
          test/male,  test/female

    The training set uses ``train_transform`` and is shuffled; validation and
    test use ``eval_transform`` and keep their order.

    Args:
        root: Dataset folder (joined with ``os.path.join``).
        batch_size: Batch size of all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    # (sub-folder, transform, shuffle?) for each split, in return order.
    split_specs = (
        ("train", train_transform, True),
        ("val", eval_transform, False),
        ("test", eval_transform, False),
    )

    # Build all datasets before any loader.
    split_datasets = [
        datasets.ImageFolder(os.path.join(root, folder), transform=transform)
        for folder, transform, _ in split_specs
    ]
    train_loader, val_loader, test_loader = (
        DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )
        for dataset, (_, _, shuffle) in zip(split_datasets, split_specs)
    )

    n_train, n_val, n_test = (len(dataset) for dataset in split_datasets)
    print(f"Loaded from {root}: "
          f"{n_train} train, {n_val} val, {n_test} test images.")

    return train_loader, val_loader, test_loader


In [None]:
# ============================================================================
# 4. MODEL ARCHITECTURE (SHARED BY ALL CLIENTS)
# ============================================================================

def create_model(num_classes: int = 2) -> nn.Module:
    """Return an ImageNet-pretrained ResNet18 with a fresh classification head.

    The original ``fc`` layer is replaced by
    Linear(in_features, 128) -> ReLU -> Dropout(0.3) -> Linear(128, num_classes),
    and the model is moved to the global ``device``.

    Args:
        num_classes: Number of output logits (2 for male/female).
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    head = nn.Sequential(
        nn.Linear(backbone.fc.in_features, 128),
        nn.ReLU(inplace=True),
        nn.Dropout(0.3),
        nn.Linear(128, num_classes),
    )
    backbone.fc = head
    return backbone.to(device)


def create_optimizer(model: nn.Module, lr: float = LEARNING_RATE):
    """Return an Adam optimizer over all parameters of ``model`` with learning rate ``lr``."""
    params = model.parameters()
    return optim.Adam(params, lr=lr)


# Shared loss for training and evaluation; it takes the raw logits produced by the model head.
criterion = nn.CrossEntropyLoss()
print("Model and loss function ready.")


Model and loss function ready.


In [None]:
# ============================================================================
# 5. TRAINING & EVALUATION HELPERS
# ============================================================================

def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; switched to training mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss per sample, accuracy)`` over the epoch, or ``(0.0, 0.0)``
        when the loader yields no samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for images, labels in tqdm(dataloader, leave=False):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per sample.
        loss_sum += batch_loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen


def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    """Score ``model`` on ``dataloader`` without updating it.

    The model is put in eval mode and the whole pass runs under
    ``torch.inference_mode()``.

    Args:
        model: Network to evaluate.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss per sample, accuracy)``, or ``(0.0, 0.0)`` when the loader
        yields no samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.inference_mode():
        for images, labels in dataloader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            batch_size = labels.size(0)

            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Weight the batch loss by batch size so the mean is per sample.
            loss_sum += batch_loss.item() * batch_size
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen


In [None]:
# ============================================================================
# 6. CREATE DATALOADERS & LOCAL MODELS FOR 4 CLIENTS
# ============================================================================

class Client:
    """A federated-learning participant owning one dataset and one local model.

    Instance attributes (all set in ``__init__``):
        name: Label used in log output.
        root: Dataset folder handed to ``make_dataloaders``.
        train_loader, val_loader, test_loader: Loaders built from ``root``.
        model: Local network from ``create_model``.
        optimizer: Optimizer from ``create_optimizer``, bound to ``model``.
    """

    def __init__(self, name: str, root: str):
        """Load the data under ``root`` and create the local model and optimizer."""
        self.name = name
        self.root = root

        # Load data
        self.train_loader, self.val_loader, self.test_loader = make_dataloaders(root, BATCH_SIZE)

        # Local model & optimizer
        # NOTE(review): the optimizer is created once and is not reset by set_weights,
        # so whatever state it accumulates carries over between rounds — confirm intended.
        self.model = create_model(num_classes=2)
        self.optimizer = create_optimizer(self.model, LEARNING_RATE)

    def set_weights(self, global_state_dict: Dict[str, torch.Tensor]):
        """Overwrite the local model with ``global_state_dict`` (keys must match exactly)."""
        self.model.load_state_dict(global_state_dict, strict=True)

    def get_weights(self) -> Dict[str, torch.Tensor]:
        """Return a detached CPU copy of the local state dict, independent of the live model."""
        return {k: v.cpu().clone() for k, v in self.model.state_dict().items()}

    def num_training_samples(self) -> int:
        """Return the size of the local training set."""
        return len(self.train_loader.dataset)

    def local_train(self, local_epochs: int = 1) -> Dict[str, float]:
        """Train locally for ``local_epochs`` epochs, then validate once.

        Returns:
            Dict with ``train_loss``/``train_acc`` of the last epoch and
            ``val_loss``/``val_acc`` measured after training.

        Raises:
            ValueError: If ``local_epochs`` is smaller than 1 (there would be
                no training metrics to report).
        """
        if local_epochs < 1:
            raise ValueError(f"local_epochs must be >= 1, got {local_epochs}")

        for _ in range(local_epochs):
            train_loss, train_acc = train_one_epoch(
                self.model, self.train_loader, self.optimizer, criterion, device
            )
        val_loss, val_acc = evaluate(self.model, self.val_loader, criterion, device)
        return {
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
        }

    def local_test(self) -> Tuple[float, float]:
        """Return ``(loss, accuracy)`` of the local model on the local test set."""
        test_loss, test_acc = evaluate(self.model, self.test_loader, criterion, device)
        return test_loss, test_acc


# Instantiate the 4 clients
# Each Client builds its own data loaders, model and optimizer from its dataset folder.
celeba_client     = Client("celeba",     CELEBA_ROOT)
vggface2_client   = Client("vggface2",   VGGFACE2_ROOT)
genderface_client = Client("genderface", GENDERFACE_ROOT)
lfw_client        = Client("lfw",        LFW_ROOT)

# Order in which the clients are visited when iterating over this list.
clients = [celeba_client, vggface2_client, genderface_client, lfw_client]
print("\nClients initialized:", [c.name for c in clients])


Loaded from /content/data/celeba: 2000 train, 600 val, 600 test images.
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 199MB/s]


Loaded from /content/data/vggface2: 1400 train, 300 val, 300 test images.
Loaded from /content/data/genderface: 19016 train, 4075 val, 4076 test images.
Loaded from /content/data/lfw: 8451 train, 1811 val, 1812 test images.

Clients initialized: ['celeba', 'vggface2', 'genderface', 'lfw']


In [None]:
# ============================================================================
# 7. SERVER: GLOBAL MODEL & FEDAVG UTILITIES
# ============================================================================

def average_state_dicts(
    state_dicts_and_sizes: List[Tuple[Dict[str, torch.Tensor], int]]
) -> Dict[str, torch.Tensor]:
    """
    Weighted average (FedAvg) of model state dicts from multiple clients.

    Each client's tensors are weighted by num_samples / total_samples. The
    keys of the first state dict define which entries are averaged; every
    other state dict must provide the same keys.

    Args:
        state_dicts_and_sizes: list of (state_dict, num_samples) pairs.

    Returns:
        A new dict mapping each key to the weighted average tensor. Every
        entry keeps the dtype of the corresponding tensor in the first state
        dict, so integer entries stay integer (the fractional part of their
        average is truncated) and float64 entries are not downcast.
        Tensors stay on the device of the first state dict's tensors; nothing
        is moved here (Client.get_weights already hands over CPU copies).

    Raises:
        ValueError: if the sample counts sum to zero (including an empty list).
    """
    total_samples = sum(n for _, n in state_dicts_and_sizes)
    if total_samples == 0:
        raise ValueError("No samples across clients!")

    # The per-client weight does not depend on the tensor being averaged,
    # so compute it once instead of once per key.
    weights = [n / total_samples for _, n in state_dicts_and_sizes]
    first_state_dict = state_dicts_and_sizes[0][0]

    avg_state_dict: Dict[str, torch.Tensor] = {}
    for key, reference in first_state_dict.items():
        # Accumulate in floating point (integer tensors cannot hold fractional
        # weights); keep float64 precision when the source is float64.
        acc_dtype = torch.float64 if reference.dtype == torch.float64 else torch.float32
        avg_param = torch.zeros_like(reference, dtype=acc_dtype)
        for (state_dict, _), weight in zip(state_dicts_and_sizes, weights):
            avg_param += state_dict[key].to(acc_dtype) * weight
        # Cast back so the result has the same dtypes as a regular state dict.
        avg_state_dict[key] = avg_param.to(reference.dtype)

    return avg_state_dict


# Server-side model; its state_dict is the initial global state that the
# training loop broadcasts to the clients.
global_model = create_model(num_classes=2)
global_state = global_model.state_dict()
# len() of a state_dict counts its entries (tensors), so the number printed
# below is an entry count rather than a count of scalar parameters.
print(
    "Global model initialized with",
    len(global_state),
    "parameters.",
)


Global model initialized with 124 parameters.


In [None]:
# ============================================================================
# 8. FEDERATED TRAINING LOOP (SERVER + 4 CLIENTS)
# ============================================================================

# Per-round log: the round index and the mean test accuracy measured after
# aggregation in that round.
# NOTE: `history` is re-created each time this cell runs, but `global_state`
# is not reset here, so re-running the cell continues from the previously
# aggregated weights while the recorded rounds start again at 1.
history = {
    "round": [],
    "global_test_acc_mean": [],
}

for rnd in range(1, NUM_ROUNDS + 1):
    print(f"\n{'='*30} ROUND {rnd} / {NUM_ROUNDS} {'='*30}")

    # 1. Broadcast global weights to all clients
    for client in clients:
        client.set_weights(global_state)

    # 2. Each client trains locally
    # Collect (weights, training-set size) pairs; the sizes are the FedAvg weights.
    client_states_and_sizes = []
    for client in clients:
        print(f"\n--- Client: {client.name} local training ---")
        logs = client.local_train(local_epochs=LOCAL_EPOCHS)
        print(f"Client {client.name} - "
              f"train_loss: {logs['train_loss']:.4f}, train_acc: {logs['train_acc']:.4f}, "
              f"val_loss: {logs['val_loss']:.4f}, val_acc: {logs['val_acc']:.4f}")

        # get_weights() returns CPU copies, so later training cannot alter
        # the snapshot stored for aggregation.
        local_state = client.get_weights()
        num_samples = client.num_training_samples()
        client_states_and_sizes.append((local_state, num_samples))

    # 3. Server aggregates (FedAvg)
    # The averaged state becomes the new global state and is also loaded into
    # the server-side model so both stay in sync.
    global_state = average_state_dicts(client_states_and_sizes)
    global_model.load_state_dict(global_state, strict=True)

    # 4. Evaluate global model on each client's test set
    test_accs = []
    for client in clients:
        # use client's test set but global weights
        client.set_weights(global_state)
        test_loss, test_acc = client.local_test()
        test_accs.append(test_acc)
        print(f"Client {client.name} TEST - loss: {test_loss:.4f}, acc: {test_acc:.4f}")

    # Unweighted mean: every client counts equally regardless of test-set size.
    mean_test_acc = float(np.mean(test_accs))
    print(f"\n>>> ROUND {rnd} - Mean test accuracy across clients: {mean_test_acc:.4f}")

    history["round"].append(rnd)
    history["global_test_acc_mean"].append(mean_test_acc)

print("\nFederated training complete!")




--- Client: celeba local training ---


  0%|          | 0/63 [00:00<?, ?it/s]

Client celeba - train_loss: 0.3663, train_acc: 0.8340, val_loss: 0.1149, val_acc: 0.9583

--- Client: vggface2 local training ---


  0%|          | 0/44 [00:00<?, ?it/s]

Client vggface2 - train_loss: 0.4128, train_acc: 0.8171, val_loss: 0.2982, val_acc: 0.8533

--- Client: genderface local training ---


  0%|          | 0/595 [00:00<?, ?it/s]

Client genderface - train_loss: 0.2419, train_acc: 0.8962, val_loss: 0.1081, val_acc: 0.9617

--- Client: lfw local training ---


  0%|          | 0/265 [00:00<?, ?it/s]

Client lfw - train_loss: 0.2796, train_acc: 0.8920, val_loss: 0.1718, val_acc: 0.9431
Client celeba TEST - loss: 0.1224, acc: 0.9450
Client vggface2 TEST - loss: 0.1060, acc: 0.9733
Client genderface TEST - loss: 0.1234, acc: 0.9588
Client lfw TEST - loss: 0.1672, acc: 0.9531

>>> ROUND 1 - Mean test accuracy across clients: 0.9576


--- Client: celeba local training ---


  0%|          | 0/63 [00:00<?, ?it/s]

Client celeba - train_loss: 0.1745, train_acc: 0.9270, val_loss: 0.0740, val_acc: 0.9700

--- Client: vggface2 local training ---


  0%|          | 0/44 [00:00<?, ?it/s]

Client vggface2 - train_loss: 0.1915, train_acc: 0.9164, val_loss: 0.0532, val_acc: 0.9900

--- Client: genderface local training ---


  0%|          | 0/595 [00:00<?, ?it/s]

Client genderface - train_loss: 0.1860, train_acc: 0.9263, val_loss: 0.1100, val_acc: 0.9637

--- Client: lfw local training ---


  0%|          | 0/265 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x783df0135c60>
Traceback (most recent call last):
Exception ignored in:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x783df0135c60>    self._shutdown_workers()

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
Traceback (most recent call last):
    if w.is_alive():
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
      self._shutdown_workers()
     File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
      if w.is_alive():^
^ ^^ ^^   ^ ^ ^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^ ^^ 
  File "/usr/lib/pyt

Client lfw - train_loss: 0.1953, train_acc: 0.9355, val_loss: 0.1342, val_acc: 0.9636
Client celeba TEST - loss: 0.1088, acc: 0.9567
Client vggface2 TEST - loss: 0.0867, acc: 0.9733
Client genderface TEST - loss: 0.1151, acc: 0.9598
Client lfw TEST - loss: 0.1534, acc: 0.9647

>>> ROUND 2 - Mean test accuracy across clients: 0.9636


--- Client: celeba local training ---


  0%|          | 0/63 [00:00<?, ?it/s]

Client celeba - train_loss: 0.1601, train_acc: 0.9345, val_loss: 0.0553, val_acc: 0.9817

--- Client: vggface2 local training ---


  0%|          | 0/44 [00:00<?, ?it/s]

Client vggface2 - train_loss: 0.1728, train_acc: 0.9286, val_loss: 0.0458, val_acc: 0.9867

--- Client: genderface local training ---


  0%|          | 0/595 [00:00<?, ?it/s]

Client genderface - train_loss: 0.1603, train_acc: 0.9377, val_loss: 0.1205, val_acc: 0.9558

--- Client: lfw local training ---


  0%|          | 0/265 [00:00<?, ?it/s]

Client lfw - train_loss: 0.1767, train_acc: 0.9432, val_loss: 0.1476, val_acc: 0.9536
Client celeba TEST - loss: 0.1108, acc: 0.9583
Client vggface2 TEST - loss: 0.0810, acc: 0.9700
Client genderface TEST - loss: 0.1032, acc: 0.9679
Client lfw TEST - loss: 0.1581, acc: 0.9636

>>> ROUND 3 - Mean test accuracy across clients: 0.9649

Federated training complete!


In [None]:
# ============================================================================
# 9. SAVE FINAL GLOBAL MODEL
# ============================================================================

# Persist the aggregated weights (a plain state_dict, not the model object).
save_path = "/content/federated_global_model.pth"
torch.save(obj=global_state, f=save_path)
print(
    "Saved global model to:",
    save_path,
)


Saved global model to: /content/federated_global_model.pth
