In [11]:
import copy
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader


In [12]:
def seed_everything(seed=42):
    """
    Seed every random number generator this notebook relies on.

    seed: integer handed to Python's `random`, NumPy's legacy global RNG and
          PyTorch (CPU, plus all CUDA devices when CUDA is available).
    """
    # The three CPU-side generators are seeded in the same order every time.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators are only touched when a GPU is actually usable.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Seed once at notebook start so a fresh "Run All" begins from a fixed RNG state.
seed_everything()

In [13]:
def load_mnist_datasets():
    """
    Download (if needed) MNIST and split the training set by digit class.

    Returns:
        class_datasets: list of 10 `torch.utils.data.Subset` objects; entry `d`
            holds every training sample whose label is `d`, in the original
            dataset order.
        full_test: the complete MNIST test split (unpartitioned).

    Both splits use `transforms.ToTensor()` and are stored under "../data".
    """
    transform = transforms.ToTensor()

    full_train = torchvision.datasets.MNIST(
        root="../data", train=True, download=True, transform=transform
    )
    full_test = torchvision.datasets.MNIST(
        root="../data", train=False, download=True, transform=transform
    )

    # Read the labels straight from the dataset's `targets` instead of
    # iterating the dataset: iteration decodes and transforms every image
    # just to discard it, which is by far the slowest part of this function.
    labels = torch.as_tensor(full_train.targets)

    # Build one subset per digit class; nonzero() returns indices in ascending
    # order, i.e. the same order an enumerate() scan would have produced.
    class_datasets = []
    for digit in range(10):
        subset_indices = torch.nonzero(labels == digit, as_tuple=True)[0].tolist()
        class_datasets.append(torch.utils.data.Subset(full_train, subset_indices))

    return class_datasets, full_test


In [14]:
def fourier_domain_adaptation(
    source_img: torch.Tensor,
    alpha: float = 0.1
) -> torch.Tensor:
    """
    Simple demonstration for FDA: perturb an image by mixing random phase
    noise into its 2-D Fourier phase while keeping the magnitude spectrum.

    source_img: real tensor whose last two dimensions are (H, W). The MNIST
        case is (1, 28, 28); any leading dimensions (channels, batch) are
        accepted and share one (H, W) random phase map.
    alpha: how strong the style perturbation should be. 0 leaves the phase
        untouched (the input is reconstructed up to floating-point error);
        1 replaces the phase entirely with the random one.

    Returns a real tensor with the same shape as `source_img`. Only the real
    part of the inverse transform is kept, and the result is not clamped, so
    values may fall outside the input's original range.

    Raises ValueError if `source_img` has fewer than two dimensions.
    """
    if source_img.dim() < 2:
        raise ValueError(
            f"source_img must have at least 2 dimensions (..., H, W), got shape {tuple(source_img.shape)}"
        )

    # Restrict every transform/shift to the spatial axes so that leading
    # channel/batch axes are never rolled or transformed.
    spatial_dims = (-2, -1)

    # Convert to frequency domain
    fft_source = torch.fft.fft2(source_img, dim=spatial_dims)
    fft_source_shifted = torch.fft.fftshift(fft_source, dim=spatial_dims)

    # Create random phase noise on the same device as the input; a CPU-only
    # noise tensor would fail to combine with a CUDA image.
    h, w = source_img.shape[-2:]
    noise = torch.exp(1j * 2 * np.pi * torch.rand((h, w), device=source_img.device))

    # Blend with alpha
    # We'll blend the phase of the source with the random phase, keeping the magnitude
    magnitude = torch.abs(fft_source_shifted)
    phase = torch.angle(fft_source_shifted)
    random_phase = torch.angle(noise)  # wrapped into (-pi, pi]

    new_phase = (1 - alpha) * phase + alpha * random_phase
    new_fft = magnitude * torch.exp(1j * new_phase)

    # Inverse shift and inverse FFT
    new_fft_ishifted = torch.fft.ifftshift(new_fft, dim=spatial_dims)
    perturbed_img = torch.fft.ifft2(new_fft_ishifted, dim=spatial_dims).real

    return perturbed_img


In [15]:
class FDADataset(Dataset):
    """
    Dataset view that runs every sample of a wrapped dataset through
    `fourier_domain_adaptation` before handing it out.

    subset: any indexable, sized collection of (image, label) pairs
            (here: a Subset of MNIST).
    alpha:  perturbation strength forwarded to `fourier_domain_adaptation`.
    """
    def __init__(self, subset, alpha=0.1):
        self.subset, self.alpha = subset, alpha

    def __len__(self):
        # Same number of samples as the wrapped collection.
        return len(self.subset)

    def __getitem__(self, idx):
        # The wrapped dataset yields (image, label); only the image is altered.
        sample, target = self.subset[idx]
        return fourier_domain_adaptation(sample, alpha=self.alpha), target


In [16]:
class SimpleCNN(nn.Module):
    """
    Two-convolution classifier for single-channel 28x28 images.

    Layout: conv(1->32, 3x3) -> ReLU -> 2x2 max-pool
            -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool
            -> flatten -> linear(64*7*7 -> 128) -> ReLU -> linear(128 -> num_classes).
    The first linear layer's input size (64 * 7 * 7) fixes the expected
    spatial input size at 28x28.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        # Each conv stage halves the spatial resolution via the shared pool.
        stage1 = self.pool(F.relu(self.conv1(x)))
        stage2 = self.pool(F.relu(self.conv2(stage1)))

        # Collapse channel and spatial axes into one feature vector per sample.
        flat = stage2.view(stage2.size(0), -1)

        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


In [17]:
class ClientNode:
    """
    Represents a federated client responsible for training on its local dataset
    (which in this case is one digit class + some FDA-based perturbation).

    client_id:  identifier stored on the instance for bookkeeping.
    dataset:    local dataset yielding (input, label) pairs.
    batch_size: mini-batch size of the local (shuffled) DataLoader.
    lr:         learning rate for the local SGD optimizer.
    device:     device on which local training runs.
    """

    def __init__(self, client_id, dataset, batch_size=64, lr=0.01, device='cpu'):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device

        self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        self.lr = lr

    def local_train(self, global_model: nn.Module, epochs=1):
        """
        Train a copy of the global model on the local dataset for a few epochs,
        and return the updated state_dict.

        global_model: model whose architecture and weights are the starting
            point; it is never modified.
        epochs: number of full passes over the local DataLoader.

        Returns the `state_dict()` of the locally trained copy.
        """
        # Clone the model actually passed in rather than constructing a
        # hard-coded architecture: this keeps the client working for any
        # nn.Module (and any number of classes), and avoids drawing a
        # throw-away random initialisation on every call.
        model = copy.deepcopy(global_model).to(self.device)

        optimizer = optim.SGD(model.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()

        model.train()
        for _ in range(epochs):
            for images, labels in self.dataloader:
                images, labels = images.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        return model.state_dict()


In [18]:
class FedServer:
    """
    Central server orchestrating federated training:
    - Initializes a global model
    - Sends it to each client
    - Averages updates
    - Repeats for multiple rounds
    """

    def __init__(self, num_classes=10, device='cpu'):
        self.global_model = SimpleCNN(num_classes).to(device)
        self.device = device

    def aggregate_weights(self, client_state_dicts):
        """
        Perform a simple FedAvg: an unweighted element-wise mean of the
        clients' state_dicts.

        client_state_dicts: non-empty sequence of state_dicts that all contain
            the keys of the first one, with matching tensor shapes.

        Returns a new dict with the same keys (and key order) as the first
        client's state_dict. Floating-point entries keep their dtype; integer
        entries (e.g. step counters) are averaged in floating point and cast
        back to their original dtype, truncating any fractional part.

        Raises:
            ValueError: if `client_state_dicts` is empty.
            KeyError: if a client is missing a key present in the first client.
        """
        if not client_state_dicts:
            raise ValueError("aggregate_weights requires at least one client state_dict")

        num_clients = len(client_state_dicts)
        reference = client_state_dicts[0]

        global_state_dict = {}
        for key, ref_tensor in reference.items():
            keeps_dtype = ref_tensor.is_floating_point() or ref_tensor.is_complex()
            # Integer tensors are accumulated in float64 so the mean is
            # computed exactly before being cast back.
            acc_dtype = ref_tensor.dtype if keeps_dtype else torch.float64

            # Sum up all client weights for this entry
            total = torch.zeros_like(ref_tensor, dtype=acc_dtype)
            for state_dict in client_state_dicts:
                total += state_dict[key].to(device=total.device, dtype=acc_dtype)

            # Average
            mean = total / num_clients
            global_state_dict[key] = mean if keeps_dtype else mean.to(ref_tensor.dtype)

        return global_state_dict

    def update_global_model(self, global_state_dict):
        """Load the given state_dict into the server's global model."""
        self.global_model.load_state_dict(global_state_dict)

    def evaluate(self, test_loader):
        """
        Evaluate the global model on a test set.

        test_loader: iterable of (images, labels) batches.

        Returns top-1 accuracy as a percentage in [0, 100]; returns 0.0 when
        the loader yields no samples. The global model is left in eval mode.
        """
        self.global_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                # Predicted class = index of the largest logit per sample.
                predicted = self.global_model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if total == 0:
            # Nothing was evaluated; avoid dividing by zero.
            return 0.0
        return 100.0 * correct / total


In [19]:
def main_federated_sdfa_mnist(
    rounds=5, local_epochs=1, alpha=0.1, batch_size=64, lr=0.01, device='cpu'
):
    """
    Run federated averaging on MNIST with one FDA-perturbed client per digit.

    rounds:       number of communication rounds.
    local_epochs: epochs each client trains per round.
    alpha:        FDA perturbation strength applied to every client's samples.
    batch_size:   mini-batch size for each client's local DataLoader.
    lr:           learning rate for each client's local optimizer.
    device:       device used by the clients and the server.

    Prints the global test accuracy after every round; returns nothing.
    """
    # Per-digit training subsets plus the unpartitioned test split.
    per_digit_subsets, mnist_test = load_mnist_datasets()

    # One client per digit class, each wrapping its subset with FDA noise.
    clients = [
        ClientNode(
            client_id=digit,
            dataset=FDADataset(digit_subset, alpha=alpha),
            batch_size=batch_size,
            lr=lr,
            device=device,
        )
        for digit, digit_subset in enumerate(per_digit_subsets)
    ]

    # The server owns the global model; the test split is evaluated without FDA.
    server = FedServer(num_classes=10, device=device)
    eval_loader = DataLoader(mnist_test, batch_size=256, shuffle=False)

    for round_idx in range(rounds):
        # Every client starts from the current global weights, in client order.
        local_states = [
            node.local_train(global_model=server.global_model, epochs=local_epochs)
            for node in clients
        ]

        # Average the client models and install the result as the new global model.
        server.update_global_model(server.aggregate_weights(local_states))

        acc = server.evaluate(eval_loader)
        print(f"[Round {round_idx + 1}] Global Test Accuracy: {acc:.2f}%")

    print("Federated training with SFDA on MNIST is complete.")


In [20]:
if __name__ == "__main__":
    # Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # 10 federated rounds, 1 local epoch per round, FDA strength 0.1.
    main_federated_sdfa_mnist(
        rounds=10,
        local_epochs=1,
        alpha=0.1,
        batch_size=64,
        lr=0.01,
        device=device,
    )


[Round 1] Global Test Accuracy: 9.95%
[Round 2] Global Test Accuracy: 14.30%
[Round 3] Global Test Accuracy: 14.93%
[Round 4] Global Test Accuracy: 18.73%
[Round 5] Global Test Accuracy: 25.91%
[Round 6] Global Test Accuracy: 32.79%
[Round 7] Global Test Accuracy: 35.90%
[Round 8] Global Test Accuracy: 40.51%
[Round 9] Global Test Accuracy: 45.53%
[Round 10] Global Test Accuracy: 48.60%
Federated training with SFDA on MNIST is complete.
