<a href="https://colab.research.google.com/github/MahzabinC/Federated-Learning/blob/main/WO_Federated.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Device configuration: use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: seed torch's RNGs (CPU and CUDA) so weight initialisation and
# DataLoader shuffling repeat across Restart & Run All. Some GPU kernels may still
# be nondeterministic, so GPU runs can differ slightly.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters
BATCH_SIZE = 128
LR = 0.01
EPOCHS = 20

In [13]:
# Load Datasets
def get_centralized_dataloader(dataset_name, batch_size=None, root="./data"):
    """Build train/test DataLoaders for centralized (non-federated) training.

    Args:
        dataset_name: Either "mnist" or "cifar10".
        batch_size: Mini-batch size for both loaders. Defaults to the
            notebook-level ``BATCH_SIZE``, read at call time.
        root: Directory where torchvision stores (and downloads) the data.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    supported = ("mnist", "cifar10")
    # Validate first so a typo fails fast, before any dataset is built or downloaded.
    if dataset_name not in supported:
        raise ValueError(
            f"Dataset not supported: {dataset_name!r} (expected one of {supported})"
        )

    if dataset_name == "mnist":
        dataset_cls = datasets.MNIST
        # MNIST is only converted to tensors; no normalization is applied.
        transform = transforms.ToTensor()
    else:
        dataset_cls = datasets.CIFAR10
        # Shift/scale each RGB channel from [0, 1] to [-1, 1].
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    # Train and test splits share one transform so preprocessing cannot drift apart.
    train_dataset = dataset_cls(root=root, train=True, download=True, transform=transform)
    test_dataset = dataset_cls(root=root, train=False, download=True, transform=transform)

    if batch_size is None:
        batch_size = BATCH_SIZE
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [3]:
# Model Architectures
class MNIST2NN(nn.Module):
    """Fully connected MNIST baseline: 784 -> 200 -> 200 -> 10.

    Two ReLU hidden layers of width 200, followed by a linear layer that emits
    unnormalised class scores (logits) for the 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Layer registration order fixes the state_dict layout and the order in
        # which parameters draw from the RNG at initialisation.
        self.fc1 = nn.Linear(28 * 28, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 10)

    def forward(self, x):
        # Flatten every 28x28 image into one row of 784 pixels.
        hidden = x.view(-1, 28 * 28)
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

class CNN(nn.Module):
    """Two-block convolutional classifier producing 10 logits.

    Layout: [conv5x5(32) -> ReLU -> maxpool2] -> [conv5x5(64) -> ReLU -> maxpool2]
    -> FC(512) -> ReLU -> FC(10).

    Args:
        input_channels: Channels of the input images (1 for grayscale, 3 for RGB).
        input_size: Height/width of the square input images. The default of 32
            (CIFAR-10) gives the 64*8*8 flatten size; pass 28 for MNIST.
    """

    def __init__(self, input_channels=1, input_size=32):
        super().__init__()
        # A 5x5 kernel with padding=2 and stride=1 preserves height/width, so only
        # the two 2x2 max-pools shrink the feature map: side -> side // 4.
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        feature_side = input_size // 4
        self.fc1 = nn.Linear(64 * feature_side * feature_side, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.max_pool2d(torch.relu(self.conv1(x)), 2)
        x = torch.max_pool2d(torch.relu(self.conv2(x)), 2)
        # Flatten each sample's (64, s, s) feature map. Taking the width from the
        # tensor (not a hard-coded constant) means an input whose spatial size does
        # not match `input_size` fails in fc1. Otherwise view(-1, ...) could silently
        # regroup features from several samples into one row.
        x = x.view(-1, x.shape[-3:].numel())
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [10]:
# Model Training Function
def train_centralized(model, train_loader, criterion, optimizer, epochs):
    """Train ``model`` on the full dataset in one place and log per-epoch loss.

    Args:
        model: Module to train. Batches are moved to the notebook-level ``device``,
            so the model must already live there.
        train_loader: Sized iterable of ``(data, target)`` batches (e.g. a DataLoader).
        criterion: Loss function called as ``criterion(outputs, target)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of full passes over ``train_loader``.

    Returns:
        List with the mean batch loss of each epoch (one float per epoch), so the
        training curve can be plotted or compared with other runs.

    Raises:
        ValueError: If ``train_loader`` has no batches (mean loss is undefined).
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            # Forward pass
            outputs = model(data)
            loss = criterion(outputs, target)

            # Backward pass: clear gradients left from the previous step first.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        mean_loss = total_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss:.4f}")
    return epoch_losses

# Model Evaluation Function
def evaluate_centralized(model, test_loader):
    """Compute, print and return the top-1 accuracy of ``model`` on ``test_loader``.

    Args:
        model: Classifier whose output has class scores along dim 1. Batches are
            moved to the notebook-level ``device``, so the model must live there.
        test_loader: Iterable of ``(data, target)`` batches.

    Returns:
        Accuracy in [0, 1] as a float, so results can be compared programmatically.

    Raises:
        ValueError: If ``test_loader`` yields no samples (accuracy is undefined).
    """
    model.eval()  # switch layers such as dropout/batch-norm to inference behaviour
    correct, total = 0, 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = outputs.max(1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
    if total == 0:
        raise ValueError("test_loader is empty; accuracy is undefined")
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy



In [11]:
# Run MNIST Experiment (centralized baseline with the fully connected MNIST2NN).
# Experiment-specific names keep this cell self-contained. With shared names,
# re-running the training call after the CIFAR-10 cell would step an optimizer
# over the CIFAR-10 model's parameters, and mnist_model would never be updated.
print("Training on MNIST with Centralized Learning...")
mnist_train_loader, mnist_test_loader = get_centralized_dataloader("mnist")
mnist_model = MNIST2NN().to(device)
mnist_criterion = nn.CrossEntropyLoss()
mnist_optimizer = optim.SGD(mnist_model.parameters(), lr=LR)

train_centralized(mnist_model, mnist_train_loader, mnist_criterion, mnist_optimizer, EPOCHS)
evaluate_centralized(mnist_model, mnist_test_loader)


Training on MNIST with Centralized Learning...
Epoch [1/20], Loss: 2.1347
Epoch [2/20], Loss: 1.1524
Epoch [3/20], Loss: 0.6007
Epoch [4/20], Loss: 0.4648
Epoch [5/20], Loss: 0.4052
Epoch [6/20], Loss: 0.3708
Epoch [7/20], Loss: 0.3473
Epoch [8/20], Loss: 0.3293
Epoch [9/20], Loss: 0.3150
Epoch [10/20], Loss: 0.3025
Epoch [11/20], Loss: 0.2916
Epoch [12/20], Loss: 0.2814
Epoch [13/20], Loss: 0.2719
Epoch [14/20], Loss: 0.2634
Epoch [15/20], Loss: 0.2552
Epoch [16/20], Loss: 0.2471
Epoch [17/20], Loss: 0.2395
Epoch [18/20], Loss: 0.2321
Epoch [19/20], Loss: 0.2249
Epoch [20/20], Loss: 0.2182
Test Accuracy: 0.9389


In [14]:
# Run CIFAR-10 Experiment (centralized baseline with the 3-channel CNN).
# Experiment-specific names keep this cell self-contained. Generic names would
# overwrite the MNIST cell's objects and tie re-runs to whichever cell ran last.
print("Training on CIFAR-10 with Centralized Learning...")
cifar10_train_loader, cifar10_test_loader = get_centralized_dataloader("cifar10")
cifar10_model = CNN(input_channels=3).to(device)
cifar10_criterion = nn.CrossEntropyLoss()
cifar10_optimizer = optim.SGD(cifar10_model.parameters(), lr=LR)

train_centralized(cifar10_model, cifar10_train_loader, cifar10_criterion, cifar10_optimizer, EPOCHS)
evaluate_centralized(cifar10_model, cifar10_test_loader)

Training on CIFAR-10 with Centralized Learning...
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 2.1117
Epoch [2/20], Loss: 1.8036
Epoch [3/20], Loss: 1.6159
Epoch [4/20], Loss: 1.4862
Epoch [5/20], Loss: 1.4067
Epoch [6/20], Loss: 1.3513
Epoch [7/20], Loss: 1.2987
Epoch [8/20], Loss: 1.2530
Epoch [9/20], Loss: 1.2100
Epoch [10/20], Loss: 1.1731
Epoch [11/20], Loss: 1.1357
Epoch [12/20], Loss: 1.0965
Epoch [13/20], Loss: 1.0586
Epoch [14/20], Loss: 1.0234
Epoch [15/20], Loss: 0.9908
Epoch [16/20], Loss: 0.9594
Epoch [17/20], Loss: 0.9286
Epoch [18/20], Loss: 0.8970
Epoch [19/20], Loss: 0.8719
Epoch [20/20], Loss: 0.8399
Test Accuracy: 0.6585
