<a href="https://colab.research.google.com/github/Zurehma/AML_Project/blob/nicolai/LocalSGD/local_sgd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision.datasets import CIFAR10, CIFAR100
from torch.utils.data import DataLoader, Subset

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import random

print("PyTorch version:", torch.__version__)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seeding for reproducibility
seed = 2025
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


# data loading
def load_data(dataset_name="cifar100", batch_size=128, num_workers=0):
    """CIFAR-100"""

    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)
            ),  # CIFAR-100 mean and std
        ]
    )

    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
        ]
    )

    if dataset_name.lower() == "cifar100":
        train_dataset = CIFAR100(
            root="./data", train=True, download=True, transform=train_transform
        )
        test_dataset = CIFAR100(
            root="./data", train=False, download=True, transform=test_transform
        )
        num_classes = 100
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, test_loader, num_classes, train_dataset, test_dataset


# Models
class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class ImprovedCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(ImprovedCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(64 * 8 * 8, 384)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(384, 192)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(192, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x


def create_model(model_name, num_classes=100):
    if model_name.lower() == "improved_cnn":
        return ImprovedCNN(num_classes=num_classes)
    elif model_name.lower() == "lenet5":
        return LeNet5(num_classes=num_classes)
    else:
        raise ValueError(f"Unknown model: {model_name}")


def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc="Training", leave=False)

    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({"loss": running_loss / total, "acc": 100.0 * correct / total})

    train_loss = running_loss / total
    train_acc = 100.0 * correct / total

    return train_loss, train_acc


def evaluate(model, test_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    test_loss = running_loss / total
    test_acc = 100.0 * correct / total

    return test_loss, test_acc


# Data distribution functions
def create_iid_shards(train_dataset, num_clients):
    """Create IID data shards for clients"""
    indices = list(range(len(train_dataset)))
    random.shuffle(indices)

    shard_size = len(indices) // num_clients
    shards = []

    for i in range(num_clients):
        start_idx = i * shard_size
        if i == num_clients - 1:
            end_idx = len(indices)
        else:
            end_idx = start_idx + shard_size

        shard_indices = indices[start_idx:end_idx]
        shards.append(Subset(train_dataset, shard_indices))

    print(f"Created {num_clients} IID shards with ~{shard_size} samples each")
    return shards


def create_non_iid_shards(train_dataset, num_clients, alpha=0.5):
    """Create non-IID data shards using Dirichlet distribution"""
    if hasattr(train_dataset, "targets"):
        targets = np.array(train_dataset.targets)
    else:
        targets = np.array([train_dataset[i][1] for i in range(len(train_dataset))])

    num_classes = len(np.unique(targets))

    # group indices by class
    class_indices = {k: [] for k in range(num_classes)}
    for idx, label in enumerate(targets):
        class_indices[label].append(idx)

    # client indices
    client_indices = [[] for _ in range(num_clients)]

    # Distribute each class according to Dirichlet distribution
    for class_id in range(num_classes):
        class_samples = class_indices[class_id]
        random.shuffle(class_samples)

        # Sample proportions from Dirichlet distribution
        proportions = np.random.dirichlet([alpha] * num_clients)

        # Distribute samples
        start_idx = 0
        for client_id in range(num_clients):
            num_samples = int(len(class_samples) * proportions[client_id])
            if client_id == num_clients - 1:
                end_idx = len(class_samples)
            else:
                end_idx = start_idx + num_samples

            client_indices[client_id].extend(class_samples[start_idx:end_idx])
            start_idx = end_idx

    # Shuffle and create subsets
    for indices in client_indices:
        random.shuffle(indices)

    shards = [Subset(train_dataset, indices) for indices in client_indices]

    sizes = [len(shard) for shard in shards]
    print(f"Created {num_clients} non-IID shards (α={alpha}) with sizes: {sizes}")

    return shards


# Centralized training
def train_centralized_model(
    model_name="improved_cnn",
    dataset_name="cifar100",
    batch_size=64,
    epochs=80,
    lr=0.01,
    momentum=0.9,
    weight_decay=4e-4,
    lr_scheduler="step",
):
    """Centralized training baseline"""
    print("=" * 50)
    print("CENTRALIZED TRAINING")
    print("=" * 50)

    # Load data
    train_loader, test_loader, num_classes, _, _ = load_data(
        dataset_name=dataset_name, batch_size=batch_size
    )

    # Create model
    model = create_model(model_name, num_classes=num_classes)
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {model_name}, Total parameters: {total_params:,}")

    # loss function and optimizer
    criterion = nn.CrossEntropyLoss()  # CrossEntropyLoss for multi-class classification
    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )  # SGDM optimizer

    # Learning rate scheduler
    if lr_scheduler == "step":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[40, 60], gamma=0.1
        )
    elif lr_scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    else:
        scheduler = None
        print("No learning rate scheduler used")

    # Training history
    history = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": [],
    }

    best_acc = 0.0

    # Training loop
    epoch_pbar = tqdm(range(epochs), desc="Centralized Training")

    for epoch in epoch_pbar:
        # Training
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Evaluation
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        if scheduler:
            scheduler.step()

        # Update best accuracy
        if test_acc > best_acc:
            best_acc = test_acc

        # Update history
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)

        # Update progress bar
        epoch_pbar.set_postfix(
            {
                "train_acc": f"{train_acc:.2f}%",
                "test_acc": f"{test_acc:.2f}%",
                "best": f"{best_acc:.2f}%",
            }
        )

    print(f"\nCentralized Results: Final={test_acc:.2f}%, Best={best_acc:.2f}%")

    return model, history, best_acc


# LocalSGD training
def train_localsgd(
    model_name="improved_cnn",
    dataset_name="cifar100",
    num_clients=12,
    local_steps=30,
    communication_rounds=200,
    client_batch_size=64,
    lr=0.01,
    momentum=0.9,
    weight_decay=4e-4,
    iid=True,
    alpha=0.5,
):
    """LocalSGD distributed training"""
    print("=" * 50)
    print("LOCALSGD DISTRIBUTED TRAINING")
    print("=" * 50)
    print(
        f"Clients: {num_clients}, Local steps: {local_steps}, Rounds: {communication_rounds}"
    )
    print(f"Data distribution: {'IID' if iid else f'non-IID (α={alpha})'}")

    # Load data
    _, test_loader, num_classes, train_dataset, _ = load_data(
        dataset_name=dataset_name, batch_size=128
    )

    # create data shards
    if iid:
        shards = create_iid_shards(train_dataset, num_clients)
    else:
        shards = create_non_iid_shards(train_dataset, num_clients, alpha)

    # Create client data loaders
    client_loaders = [
        DataLoader(shard, batch_size=client_batch_size, shuffle=True, num_workers=0)
        for shard in shards
    ]

    # initialize client models
    client_models = [
        create_model(model_name, num_classes=num_classes).to(device)
        for _ in range(num_clients)
    ]

    # synchronize initial weights
    with torch.no_grad():
        for client_model in client_models[1:]:
            for target_param, source_param in zip(
                client_model.parameters(), client_models[0].parameters()
            ):
                target_param.data.copy_(source_param.data)

    # optimizers
    client_optimizers = [
        torch.optim.SGD(
            model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
        )
        for model in client_models
    ]

    criterion = nn.CrossEntropyLoss()

    # metrics
    test_accuracies = []
    test_losses = []

    # LocalSGD main loop
    round_pbar = tqdm(range(communication_rounds), desc="LocalSGD Progress")

    for round_num in round_pbar:
        # Local training phase
        for client_id in range(num_clients):
            model = client_models[client_id]
            optimizer = client_optimizers[client_id]
            data_loader = client_loaders[client_id]

            model.train()
            data_iter = iter(data_loader)

            for step in range(local_steps):
                try:
                    inputs, targets = next(data_iter)
                except StopIteration:
                    data_iter = iter(data_loader)
                    inputs, targets = next(data_iter)

                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

        # Communication phase - model averaging
        with torch.no_grad():
            # Initialize averaged parameters
            averaged_params = []
            for param in client_models[0].parameters():
                averaged_params.append(torch.zeros_like(param))

            # Sum all client parameters
            for client_model in client_models:
                for avg_param, client_param in zip(
                    averaged_params, client_model.parameters()
                ):
                    avg_param += client_param

            # Average the parameters
            for avg_param in averaged_params:
                avg_param /= num_clients

            # Update all client models
            for client_model in client_models:
                for client_param, avg_param in zip(
                    client_model.parameters(), averaged_params
                ):
                    client_param.data.copy_(avg_param)

        # Evaluation
        if (round_num + 1) % 5 == 0:  # Evaluate every 5 rounds
            global_model = client_models[0]
            test_loss, test_acc = evaluate(global_model, test_loader, criterion, device)
            test_accuracies.append(test_acc)
            test_losses.append(test_loss)

            round_pbar.set_postfix(
                {"test_acc": f"{test_acc:.2f}%", "test_loss": f"{test_loss:.4f}"}
            )

    final_acc = test_accuracies[-1] if test_accuracies else 0
    best_acc = max(test_accuracies) if test_accuracies else 0

    print(f"\nLocalSGD Results: Final={final_acc:.2f}%, Best={best_acc:.2f}%")

    history = {"test_accuracies": test_accuracies, "test_losses": test_losses}

    return client_models[0], history, best_acc


def main(mode="centralized", **kwargs):
    """Main function to run different training modes.

    Modes:
    - "centralized": Run only centralized training
    - "localsgd": Run only LocalSGD training
    """

    if mode == "centralized":
        model, history, best_acc = train_centralized_model(**kwargs)
        return model, history

    elif mode == "localsgd":
        model, history, best_acc = train_localsgd(**kwargs)
        return model, history

    else:
        raise ValueError(f"Unknown mode: {mode}. Use 'centralized'")


if __name__ == "__main__":
    main(
        mode="localsgd",
        model_name="improved_cnn",
        dataset_name="cifar100",
        num_clients=4,
        local_steps=10,
        communication_rounds=200,
        client_batch_size=64,
        lr=0.01,
        momentum=0.9,
        weight_decay=4e-4,
    )


PyTorch version: 2.6.0+cu124
Using device: cuda:0
LOCALSGD DISTRIBUTED TRAINING
Clients: 4, Local steps: 10, Rounds: 200
Data distribution: IID
Created 4 IID shards with ~12500 samples each


LocalSGD Progress:   0%|          | 0/200 [00:00<?, ?it/s]


LocalSGD Results: Final=29.34%, Best=29.40%
