In [1]:
# Shell escape: list the contents of the current working directory.
!ls

requirements.txt  sample_data


In [2]:
# Shell escape: create requirements.txt if it does not exist yet (touch leaves
# an existing file's contents unchanged), so the install step has a file to read.
!touch requirements.txt

In [3]:
!pip install -r requirements.txt



In [7]:
import torch
import torch.nn as nn
from torch.utils.data import Subset, DataLoader, random_split
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

In [8]:
# Data transformation: convert each sample to a tensor, then normalise it
# with a single-channel mean of 0.5 and standard deviation of 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [9]:
# Model definition
class SimpleModel(nn.Module):
    """Fully connected classifier: 784 inputs -> 128 hidden units -> ReLU -> 10 outputs.

    Inputs are flattened from dimension 1 onward, so every sample must contain
    exactly 784 values. The output holds one raw score per class; no softmax
    is applied inside the model.
    """

    def __init__(self):
        super().__init__()
        # The attribute names double as state_dict keys, and the creation order
        # fixes the order in which the random initial weights are drawn.
        self.fc = nn.Linear(in_features=784, out_features=128)
        self.relu = nn.ReLU()
        self.out = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the ``(batch, 10)`` score tensor for the batch ``x``."""
        flat = x.flatten(start_dim=1)
        hidden = self.relu(self.fc(flat))
        return self.out(hidden)




In [10]:
# Training Function
def train_model(model, train_set, batch_size=64, num_epochs=10, lr=0.01, momentum=0.9):
    """Train ``model`` in place on ``train_set`` using SGD and cross-entropy loss.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to per-class scores.
        train_set: Dataset yielding ``(input, label)`` pairs.
        batch_size: Number of samples per mini-batch.
        num_epochs: Number of full passes over ``train_set``.
        lr: SGD learning rate.
        momentum: SGD momentum factor.

    Returns:
        None. The model's parameters are updated in place, the model is left
        in training mode, and the mean batch loss of each epoch is printed.
    """
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    model.train()  # set the model to training mode

    # num_epochs is an integer count, so it is passed to range() directly.
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Report NaN rather than dividing by zero if the loader produced no batches.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}: Loss = {epoch_loss}")

    print("Training complete")




In [11]:
# Evaluation Function

def evaluate_model(model, test_set):
    """Compute the mean cross-entropy loss and the accuracy of ``model`` on ``test_set``.

    The model is switched to evaluation mode and is left in that mode.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to per-class scores.
        test_set: Dataset yielding ``(input, label)`` pairs.

    Returns:
        A tuple ``(average_loss, accuracy)``: the cross-entropy loss averaged
        over all samples, and the fraction of samples whose highest-scoring
        class equals the label.

    Raises:
        ValueError: If ``test_set`` yields no samples.
    """
    model.eval()  # set the model to evaluation mode
    correct = 0
    total = 0
    total_loss = 0.0

    test_loader = DataLoader(test_set, batch_size=64, shuffle=False)
    # Sum the per-sample losses and divide by the sample count at the end, so a
    # smaller final batch does not carry the same weight as a full batch.
    criterion = nn.CrossEntropyLoss(reduction="sum")

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            total_loss += criterion(outputs, labels).item()

    if total == 0:
        raise ValueError("evaluate_model: test_set contains no samples")

    accuracy = correct / total
    average_loss = total_loss / total

    return average_loss, accuracy


In [12]:
# Data Manipulation Functions
def include_digits(dataset, included_digits):
    """Return a ``Subset`` of ``dataset`` holding only samples whose label is in ``included_digits``.

    A sample's label is read as element 1 of ``dataset[idx]``; the original
    sample order is preserved.
    """
    kept_positions = []
    for position in range(len(dataset)):
        label = dataset[position][1]
        if label in included_digits:
            kept_positions.append(position)
    return torch.utils.data.Subset(dataset, kept_positions)

def exclude_digits(dataset, excluded_digits):
    """Return a ``Subset`` of ``dataset`` without the samples whose label is in ``excluded_digits``.

    A sample's label is read as element 1 of ``dataset[idx]``; the original
    sample order is preserved.
    """

    def is_kept(position):
        # Keep a sample only when its label is absent from the exclusion list.
        return dataset[position][1] not in excluded_digits

    remaining_positions = list(filter(is_kept, range(len(dataset))))
    return torch.utils.data.Subset(dataset, remaining_positions)

In [13]:
# Visualising functions
def plot_distribution(dataset, title):
    """Display a bar chart of how many samples ``dataset`` holds for each digit 0-9.

    Args:
        dataset: Iterable of samples whose element 1 is the label.
        title: Title drawn above the chart.

    Digits without any sample are drawn with a count of zero. The chart is
    shown with ``plt.show()``; nothing is returned.
    """
    sample_labels = torch.tensor([sample[1] for sample in dataset])
    present, frequencies = torch.unique(sample_labels, return_counts=True)

    plt.figure(figsize=(4, 2))

    count_by_label = dict(zip(present.tolist(), frequencies.tolist()))

    # Always plot all ten digits so charts of different subsets share an x-axis.
    digits = np.arange(10)
    heights = []
    for digit in digits:
        heights.append(count_by_label.get(digit, 0))

    plt.bar(digits, heights)
    plt.title(title)
    plt.xlabel("Digit")
    plt.ylabel("Count")
    plt.xticks(digits)
    plt.show()

def compute_confusion_matrix(model, testset, labels=None):
    """Compute the confusion matrix of ``model``'s predictions on ``testset``.

    Inference runs in evaluation mode with gradient tracking disabled; the
    model's previous training/evaluation mode is restored afterwards.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to per-class scores.
        testset: Dataset yielding ``(image, label)`` pairs.
        labels: Optional list of class labels forwarded to scikit-learn's
            ``confusion_matrix`` to fix the rows/columns of the result (for
            example ``list(range(10))`` so that classes missing from
            ``testset`` still get a row). ``None`` keeps scikit-learn's
            default of using only the labels that occur in the data.

    Returns:
        The matrix returned by ``confusion_matrix(true, predicted, labels=labels)``.
    """
    true_labels = []
    predicted_labels = []

    loader = DataLoader(testset, batch_size=64, shuffle=False)

    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for prediction; skipping them avoids building
        # an autograd graph for every forward pass.
        with torch.no_grad():
            for images, targets in loader:
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)

                true_labels.extend(targets.tolist())
                predicted_labels.extend(predicted.tolist())
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)

    true_labels = np.array(true_labels)
    predicted_labels = np.array(predicted_labels)

    cm = confusion_matrix(true_labels, predicted_labels, labels=labels)

    return cm

def plot_confusion_matrix(cm, title):
    """Display ``cm`` as an annotated heat map.

    Args:
        cm: Confusion matrix of integer counts (cells are formatted with ``"d"``).
        title: Title drawn above the heat map.

    The x-axis is labelled with predicted labels and the y-axis with true
    labels. The figure is shown with ``plt.show()``; nothing is returned.
    """
    plt.figure(figsize=(6, 4))
    heatmap_axes = sns.heatmap(cm, annot=True, cmap="Blues", fmt="d", linewidths=0.5)
    heatmap_axes.set_title(title)
    heatmap_axes.set_xlabel("Predicted Label")
    heatmap_axes.set_ylabel("True Label")
    plt.show()

In [14]:
# ACTUAL FEDERATED LEARNING EXAMPLE