# Federated Learning Demo with MNIST Dataset

This notebook demonstrates a simple federated learning scenario using the MNIST dataset.

The goal is to show that a model trained with federated learning — simulated here by splitting the training set into two halves, each used by a separate client — performs similarly to a model trained on the whole dataset at once.


In [None]:
# Setting up the environment
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision import datasets as tv_datasets  # alias that survives the later `datasets = random_split(...)` rebinding

In [None]:
# Load and Prepare the Data (Code Cell)
# The torchvision module is referenced via the `tv_datasets` alias because this
# cell rebinds the name `datasets` below; using `datasets.MNIST` here would make
# the cell fail with AttributeError on a second run.
SPLIT_SEED = 42  # fixed seed so the two client subsets are identical on every run

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download the MNIST training set
dataset = tv_datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Split the dataset into two disjoint parts, one per simulated client.
# NOTE: `datasets` is kept as the name of the resulting list of two Subsets for
# the later cells; from here on it shadows the torchvision module.
train_size = int(0.5 * len(dataset))
subset_sizes = [train_size, len(dataset) - train_size]
datasets = torch.utils.data.random_split(
    dataset,
    subset_sizes,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

assert sum(len(part) for part in datasets) == len(dataset), "client splits must cover the full dataset"


In [None]:
# Define the Neural Network Model (Code Cell)

class SimpleNN(nn.Module):
    """Fully connected classifier: flatten -> Linear/ReLU -> Linear/ReLU -> Linear.

    The defaults reproduce the MNIST setup (28x28 inputs, two 512-unit hidden
    layers, 10 output classes). Submodule names (`flatten`, `linear_relu_stack`)
    are fixed so that every instance has the same `state_dict` keys, which lets
    weights be exchanged between separately trained copies.

    Args:
        in_features: Number of values per example after flattening.
        hidden_features: Width of each of the two hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape (batch, num_classes).

        `x` is flattened from dim 1 onward, so e.g. a (batch, 1, 28, 28) image
        batch becomes (batch, 784) before the linear layers.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
