<a href="https://colab.research.google.com/github/Rahad31/Different-VAE-for-KL-FedDis/blob/main/BvaeOrd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [30]:
%%time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch import nn
from typing import Dict

CPU times: user 118 µs, sys: 0 ns, total: 118 µs
Wall time: 126 µs


In [31]:
# Define VAE model
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2*z_dim)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [32]:
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int, beta: float = 4.0) -> None:
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    for epoch in range(epochs):
        for i, data in enumerate(trainloader, 0):
            inputs, _ = data
            optimizer.zero_grad()
            recon_x, mu, logvar = vae(inputs)
            loss = vae_loss(recon_x, inputs, mu, logvar, beta=beta)
            loss.backward()
            optimizer.step()



In [33]:

# Define classification model
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [34]:
# Load CIFAR10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

In [35]:
# Split dataset into training and validation sets
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(full_dataset, [train_size, val_size])


In [36]:
# Create training and validation loaders
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)


In [37]:
def vae_loss(recon_x, x, mu, logvar, beta=4):
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + beta * KLD


In [38]:
# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data in valloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {running_loss / (i+1):.3f}, Validation Accuracy: {val_acc:.2f}%")

In [39]:
# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> float:
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

In [40]:
def initialize_clients(trainset, transform, batch_size, num_clients):
    clients = {}
    for i in range(num_clients):
        client_trainset = torch.utils.data.Subset(trainset, range(i * len(trainset) // num_clients, (i + 1) * len(trainset) // num_clients))
        client_trainloader = torch.utils.data.DataLoader(client_trainset, batch_size=batch_size, shuffle=True)
        clients[f"client_{i}"] = client_trainloader
    return clients

In [41]:
def get_distribution_info(vae: VAE) -> Dict:
    # Implement the logic to extract distribution information from the VAE
    # This can involve computing statistics, parameters, or any other relevant information
    # that can be used to generate augmented data

    # Example implementation:
    distribution_info = {
        "mean": vae.encoder[-1].bias.data.cpu().numpy(),
        "std": torch.exp(0.5 * vae.encoder[-1].weight.data).cpu().numpy()
    }

    return distribution_info

In [42]:
def send_distribution_info(distribution_info: Dict) -> None:
    # Implement the logic to send the distribution information to the global server
    # This can involve using a network protocol, a message queue, or any other communication mechanism
    # to send the information

    # Example implementation:
    # Send the distribution information to the global server using a network protocol
    # For example, you can use the `socket` module to send the information over a network
    # or use a message queue like `RabbitMQ` to send the information
    pass


In [43]:
# Define logic to generate augmented data using Ordinary Normal distribution
def generate_augmented_data(vae: VAE, distribution_info: Dict) -> torch.Tensor:
    # Generate augmented data using Ordinary Normal distribution
    mean = distribution_info["mean"]
    std = distribution_info["std"]
    augmented_data = torch.randn(64, vae.z_dim) * std + mean
    return augmented_data

In [44]:
def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    for epoch in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Train VAE on client data
            vae_train(vae, client_trainloader, epochs=10, beta=4)

            # Share distribution information with global server
            distribution_info = get_distribution_info(vae)
            send_distribution_info(distribution_info)

            # Receive distribution information from other clients
            other_distribution_info = receive_distribution_info()

            # Generate augmented data using received distribution information
            augmented_data = generate_augmented_data(vae, other_distribution_info)

            # Train classification model using local, augmented, and validation data
            train(net, client_trainloader, valloader, epochs=10)

            # Send model updates to global server
            send_model_update(client_id, net.state_dict())

In [45]:
# Define logic to receive distribution information from global server
def receive_distribution_info() -> Dict:
    # Receive distribution information logic
    distribution_info = {
        "mean": np.zeros(20),  # Adjust the size based on your latent space dimension
        "std": np.ones(20)
    }
    return distribution_info

In [46]:
def send_model_update(client_id: str, model_update: Dict) -> None:
    # Implement the logic to send the model update to the global server
    # This can involve using a network protocol, a message queue, or any other communication mechanism
    # to send the model update

    # Example implementation:
    # Send the model update to the global server using a network protocol
    # For example, you can use the `socket` module to send the information over a network
    # or use a message queue like `RabbitMQ` to send the information
    pass

***-------For 250 Epochs B=4.0-------***

---



In [28]:
%%time

# Define global server procedure
def global_server() -> None:
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size
    h_dim = 400
    z_dim = 20
    vae = VAE(x_dim, h_dim, z_dim)  # Initialize VAE object with required arguments

    # Initialize clients
    num_clients = 5  # Define the number of clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=5)

    # Evaluate final model
    test_accuracy = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

if __name__ == "__main__":
    global_server()


  augmented_data = torch.randn(64, vae.z_dim) * std + mean


Epoch [1/10], Training Loss: 2.305, Validation Accuracy: 10.51%
Epoch [2/10], Training Loss: 2.304, Validation Accuracy: 11.08%
Epoch [3/10], Training Loss: 2.303, Validation Accuracy: 11.00%
Epoch [4/10], Training Loss: 2.302, Validation Accuracy: 10.34%
Epoch [5/10], Training Loss: 2.302, Validation Accuracy: 10.15%
Epoch [6/10], Training Loss: 2.301, Validation Accuracy: 10.08%
Epoch [7/10], Training Loss: 2.300, Validation Accuracy: 10.08%
Epoch [8/10], Training Loss: 2.299, Validation Accuracy: 10.08%
Epoch [9/10], Training Loss: 2.299, Validation Accuracy: 10.08%
Epoch [10/10], Training Loss: 2.297, Validation Accuracy: 10.09%
Epoch [1/10], Training Loss: 2.297, Validation Accuracy: 10.14%
Epoch [2/10], Training Loss: 2.296, Validation Accuracy: 10.60%
Epoch [3/10], Training Loss: 2.294, Validation Accuracy: 12.86%
Epoch [4/10], Training Loss: 2.291, Validation Accuracy: 15.67%
Epoch [5/10], Training Loss: 2.288, Validation Accuracy: 16.95%
Epoch [6/10], Training Loss: 2.283, Val

***-------For 500 Epochs B=4.0-------***

---

In [None]:
# Define global server procedure
def global_server() -> None:
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size
    h_dim = 400
    z_dim = 20
    vae = VAE(x_dim, h_dim, z_dim)  # Initialize VAE object with required arguments

    # Initialize clients
    num_clients = 5  # Define the number of clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    print("True Positives (TP):", tp)
    print("False Positives (FP):", fp)
    print("True Negatives (TN):", tn)
    print("False Negatives (FN):", fn)
    print(":", fn+tn+tp+fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

if __name__ == "__main__":
    global_server()

  augmented_data = torch.randn(64, vae.z_dim) * std + mean


Epoch [1/10], Training Loss: 2.306, Validation Accuracy: 10.74%
Epoch [2/10], Training Loss: 2.305, Validation Accuracy: 10.77%
Epoch [3/10], Training Loss: 2.304, Validation Accuracy: 10.82%
Epoch [4/10], Training Loss: 2.303, Validation Accuracy: 10.87%
Epoch [5/10], Training Loss: 2.303, Validation Accuracy: 11.09%
Epoch [6/10], Training Loss: 2.302, Validation Accuracy: 11.48%
Epoch [7/10], Training Loss: 2.301, Validation Accuracy: 12.17%
Epoch [8/10], Training Loss: 2.300, Validation Accuracy: 12.77%
Epoch [9/10], Training Loss: 2.299, Validation Accuracy: 13.61%
Epoch [10/10], Training Loss: 2.298, Validation Accuracy: 15.41%
Epoch [1/10], Training Loss: 2.297, Validation Accuracy: 16.92%
Epoch [2/10], Training Loss: 2.295, Validation Accuracy: 17.60%
Epoch [3/10], Training Loss: 2.292, Validation Accuracy: 17.82%
Epoch [4/10], Training Loss: 2.288, Validation Accuracy: 17.94%
Epoch [5/10], Training Loss: 2.283, Validation Accuracy: 18.24%
Epoch [6/10], Training Loss: 2.275, Val

In [3]:
%%time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch import nn
from typing import Dict

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define VAE model
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2*z_dim)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define β-VAE loss function
def vae_loss(recon_x, x, mu, logvar, beta=4.0):
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + beta * KLD


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int, beta=4.0, device='cpu') -> None:
    vae.to(device)
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    vae.train()
    for epoch in range(epochs):
        for i, data in enumerate(trainloader, 0):
            inputs, _ = data
            inputs = inputs.to(device)
            optimizer.zero_grad()
            recon_x, mu, logvar = vae(inputs)
            loss = vae_loss(recon_x, inputs, mu, logvar, beta)
            loss.backward()
            optimizer.step()


# Define classification model
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Load CIFAR10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(full_dataset, [train_size, val_size])

trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)


# Training procedure for classifier
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int, device='cpu') -> None:
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    best_acc = 0.0
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data in valloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / (i+1):.3f}, Val Acc: {val_acc:.2f}%")


# Evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> float:
    net.to(device)
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total


# Initialize simulated clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    clients = {}
    for i in range(num_clients):
        start = i * len(trainset) // num_clients
        end = (i + 1) * len(trainset) // num_clients
        client_data = torch.utils.data.Subset(trainset, list(range(start, end)))
        clients[f"client_{i}"] = DataLoader(client_data, batch_size=batch_size, shuffle=True)
    return clients


# Distribution info (using client data)
def get_distribution_info(vae: VAE, dataloader: DataLoader, device='cpu') -> Dict:
    vae.eval()
    mu_list, logvar_list = [], []
    with torch.no_grad():
        for data in dataloader:
            inputs, _ = data
            inputs = inputs.to(device)
            mu, logvar = vae.encode(inputs)
            mu_list.append(mu.cpu())
            logvar_list.append(logvar.cpu())
    mu_all = torch.cat(mu_list)
    logvar_all = torch.cat(logvar_list)
    mean = mu_all.mean(dim=0).numpy()
    std = torch.exp(0.5 * logvar_all).mean(dim=0).numpy()
    return {"mean": mean, "std": std}


# Augmentation from other clients
def generate_augmented_data(vae: VAE, distribution_info: Dict) -> torch.Tensor:
    mean = torch.tensor(distribution_info["mean"], dtype=vae.decoder[0].weight.dtype)
    std = torch.tensor(distribution_info["std"], dtype=vae.decoder[0].weight.dtype)
    z = torch.randn(64, vae.z_dim, dtype=vae.decoder[0].weight.dtype) * std + mean
    generated = vae.decode(z.to(device))
    return generated


def receive_distribution_info() -> Dict:
    return {"mean": np.zeros(20), "std": np.ones(20)}


def send_distribution_info(info: Dict): pass
def send_model_update(client_id: str, model_update: Dict): pass


def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    for epoch in range(epochs):
        for client_id, client_loader in trainloaders.items():
            vae_train(vae, client_loader, epochs=5, beta=4.0, device=device)
            dist_info = get_distribution_info(vae, client_loader, device)
            send_distribution_info(dist_info)
            other_info = receive_distribution_info()
            aug_data = generate_augmented_data(vae, other_info)
            train(net, client_loader, valloader, epochs=1, device=device)
            send_model_update(client_id, net.state_dict())


# Global server
def global_server() -> None:
    net = Net()
    vae = VAE(x_dim=3*32*32, h_dim=400, z_dim=20)
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=5)
    federated_train(net, vae, clients, trainloader, valloader, epochs=5)
    test_acc = evaluate(net, testloader)
    print(f"Test Accuracy: {test_acc:.2f}%")


if __name__ == "__main__":
    global_server()

Epoch [1/1], Loss: 2.304, Val Acc: 9.87%
Epoch [1/1], Loss: 2.304, Val Acc: 9.74%
Epoch [1/1], Loss: 2.302, Val Acc: 9.74%
Epoch [1/1], Loss: 2.302, Val Acc: 9.76%
Epoch [1/1], Loss: 2.301, Val Acc: 9.83%
Epoch [1/1], Loss: 2.301, Val Acc: 10.04%
Epoch [1/1], Loss: 2.300, Val Acc: 9.78%
Epoch [1/1], Loss: 2.299, Val Acc: 9.81%
Epoch [1/1], Loss: 2.298, Val Acc: 10.78%
Epoch [1/1], Loss: 2.297, Val Acc: 13.31%
Epoch [1/1], Loss: 2.296, Val Acc: 15.81%
Epoch [1/1], Loss: 2.294, Val Acc: 16.77%
Epoch [1/1], Loss: 2.292, Val Acc: 17.76%
Epoch [1/1], Loss: 2.290, Val Acc: 19.16%
Epoch [1/1], Loss: 2.286, Val Acc: 19.82%
Epoch [1/1], Loss: 2.282, Val Acc: 20.03%
Epoch [1/1], Loss: 2.274, Val Acc: 20.48%
Epoch [1/1], Loss: 2.263, Val Acc: 20.99%
Epoch [1/1], Loss: 2.244, Val Acc: 21.35%
Epoch [1/1], Loss: 2.217, Val Acc: 21.51%
Epoch [1/1], Loss: 2.183, Val Acc: 22.25%
Epoch [1/1], Loss: 2.144, Val Acc: 22.99%
Epoch [1/1], Loss: 2.118, Val Acc: 24.10%
Epoch [1/1], Loss: 2.086, Val Acc: 24.16%

In [4]:
%%time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch import nn
from typing import Dict

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define VAE model
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2*z_dim)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define β-VAE loss function
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + beta * KLD


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int, beta=4.0, device='cpu') -> None:
    vae.to(device)
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    vae.train()
    for epoch in range(epochs):
        for i, data in enumerate(trainloader, 0):
            inputs, _ = data
            inputs = inputs.to(device)
            optimizer.zero_grad()
            recon_x, mu, logvar = vae(inputs)
            loss = vae_loss(recon_x, inputs, mu, logvar, beta)
            loss.backward()
            optimizer.step()


# Define classification model
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Load CIFAR10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(full_dataset, [train_size, val_size])

trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)


# Training procedure for classifier
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int, device='cpu') -> None:
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    best_acc = 0.0
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data in valloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / (i+1):.3f}, Val Acc: {val_acc:.2f}%")


# Evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> float:
    net.to(device)
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total


# Initialize simulated clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    clients = {}
    for i in range(num_clients):
        start = i * len(trainset) // num_clients
        end = (i + 1) * len(trainset) // num_clients
        client_data = torch.utils.data.Subset(trainset, list(range(start, end)))
        clients[f"client_{i}"] = DataLoader(client_data, batch_size=batch_size, shuffle=True)
    return clients


# Distribution info (using client data)
def get_distribution_info(vae: VAE, dataloader: DataLoader, device='cpu') -> Dict:
    vae.eval()
    mu_list, logvar_list = [], []
    with torch.no_grad():
        for data in dataloader:
            inputs, _ = data
            inputs = inputs.to(device)
            mu, logvar = vae.encode(inputs)
            mu_list.append(mu.cpu())
            logvar_list.append(logvar.cpu())
    mu_all = torch.cat(mu_list)
    logvar_all = torch.cat(logvar_list)
    mean = mu_all.mean(dim=0).numpy()
    std = torch.exp(0.5 * logvar_all).mean(dim=0).numpy()
    return {"mean": mean, "std": std}


# Augmentation from other clients
def generate_augmented_data(vae: VAE, distribution_info: Dict) -> torch.Tensor:
    mean = torch.tensor(distribution_info["mean"], dtype=vae.decoder[0].weight.dtype)
    std = torch.tensor(distribution_info["std"], dtype=vae.decoder[0].weight.dtype)
    z = torch.randn(64, vae.z_dim, dtype=vae.decoder[0].weight.dtype) * std + mean
    generated = vae.decode(z.to(device))
    return generated


def receive_distribution_info() -> Dict:
    return {"mean": np.zeros(20), "std": np.ones(20)}


def send_distribution_info(info: Dict): pass
def send_model_update(client_id: str, model_update: Dict): pass


def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    for epoch in range(epochs):
        for client_id, client_loader in trainloaders.items():
            vae_train(vae, client_loader, epochs=5, beta=4.0, device=device)
            dist_info = get_distribution_info(vae, client_loader, device)
            send_distribution_info(dist_info)
            other_info = receive_distribution_info()
            aug_data = generate_augmented_data(vae, other_info)
            train(net, client_loader, valloader, epochs=1, device=device)
            send_model_update(client_id, net.state_dict())


# Global server
def global_server() -> None:
    net = Net()
    vae = VAE(x_dim=3*32*32, h_dim=400, z_dim=20)
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=5)
    federated_train(net, vae, clients, trainloader, valloader, epochs=5)
    test_acc = evaluate(net, testloader)
    print(f"Test Accuracy: {test_acc:.2f}%")


if __name__ == "__main__":
    global_server()

Epoch [1/1], Loss: 2.302, Val Acc: 9.48%
Epoch [1/1], Loss: 2.301, Val Acc: 9.52%
Epoch [1/1], Loss: 2.300, Val Acc: 9.25%
Epoch [1/1], Loss: 2.298, Val Acc: 9.65%
Epoch [1/1], Loss: 2.297, Val Acc: 11.89%
Epoch [1/1], Loss: 2.295, Val Acc: 12.20%
Epoch [1/1], Loss: 2.293, Val Acc: 13.87%
Epoch [1/1], Loss: 2.290, Val Acc: 16.91%
Epoch [1/1], Loss: 2.286, Val Acc: 17.25%
Epoch [1/1], Loss: 2.281, Val Acc: 17.45%
Epoch [1/1], Loss: 2.273, Val Acc: 17.81%
Epoch [1/1], Loss: 2.264, Val Acc: 18.73%
Epoch [1/1], Loss: 2.252, Val Acc: 20.64%
Epoch [1/1], Loss: 2.234, Val Acc: 20.92%
Epoch [1/1], Loss: 2.205, Val Acc: 21.57%
Epoch [1/1], Loss: 2.176, Val Acc: 25.18%
Epoch [1/1], Loss: 2.141, Val Acc: 24.91%
Epoch [1/1], Loss: 2.107, Val Acc: 25.89%
Epoch [1/1], Loss: 2.087, Val Acc: 25.44%
Epoch [1/1], Loss: 2.049, Val Acc: 26.19%
Epoch [1/1], Loss: 2.042, Val Acc: 27.60%
Epoch [1/1], Loss: 2.030, Val Acc: 28.19%
Epoch [1/1], Loss: 2.010, Val Acc: 28.29%
Epoch [1/1], Loss: 1.998, Val Acc: 28.

In [29]:
import re

# Your provided text
log = """
Epoch [1/10], Training Loss: 2.305, Validation Accuracy: 10.51%
Epoch [2/10], Training Loss: 2.304, Validation Accuracy: 11.08%
Epoch [3/10], Training Loss: 2.303, Validation Accuracy: 11.00%
Epoch [4/10], Training Loss: 2.302, Validation Accuracy: 10.34%
Epoch [5/10], Training Loss: 2.302, Validation Accuracy: 10.15%
Epoch [6/10], Training Loss: 2.301, Validation Accuracy: 10.08%
Epoch [7/10], Training Loss: 2.300, Validation Accuracy: 10.08%
Epoch [8/10], Training Loss: 2.299, Validation Accuracy: 10.08%
Epoch [9/10], Training Loss: 2.299, Validation Accuracy: 10.08%
Epoch [10/10], Training Loss: 2.297, Validation Accuracy: 10.09%
Epoch [1/10], Training Loss: 2.297, Validation Accuracy: 10.14%
Epoch [2/10], Training Loss: 2.296, Validation Accuracy: 10.60%
Epoch [3/10], Training Loss: 2.294, Validation Accuracy: 12.86%
Epoch [4/10], Training Loss: 2.291, Validation Accuracy: 15.67%
Epoch [5/10], Training Loss: 2.288, Validation Accuracy: 16.95%
Epoch [6/10], Training Loss: 2.283, Validation Accuracy: 18.19%
Epoch [7/10], Training Loss: 2.276, Validation Accuracy: 18.12%
Epoch [8/10], Training Loss: 2.265, Validation Accuracy: 18.67%
Epoch [9/10], Training Loss: 2.248, Validation Accuracy: 19.60%
Epoch [10/10], Training Loss: 2.221, Validation Accuracy: 20.25%
Epoch [1/10], Training Loss: 2.185, Validation Accuracy: 19.53%
Epoch [2/10], Training Loss: 2.151, Validation Accuracy: 22.15%
Epoch [3/10], Training Loss: 2.129, Validation Accuracy: 23.43%
Epoch [4/10], Training Loss: 2.115, Validation Accuracy: 24.52%
Epoch [5/10], Training Loss: 2.103, Validation Accuracy: 24.94%
Epoch [6/10], Training Loss: 2.090, Validation Accuracy: 25.31%
Epoch [7/10], Training Loss: 2.082, Validation Accuracy: 26.61%
Epoch [8/10], Training Loss: 2.070, Validation Accuracy: 27.15%
Epoch [9/10], Training Loss: 2.060, Validation Accuracy: 26.64%
Epoch [10/10], Training Loss: 2.048, Validation Accuracy: 28.15%
Epoch [1/10], Training Loss: 2.053, Validation Accuracy: 28.44%
Epoch [2/10], Training Loss: 2.038, Validation Accuracy: 28.81%
Epoch [3/10], Training Loss: 2.021, Validation Accuracy: 29.30%
Epoch [4/10], Training Loss: 2.001, Validation Accuracy: 29.67%
Epoch [5/10], Training Loss: 1.977, Validation Accuracy: 30.49%
Epoch [6/10], Training Loss: 1.948, Validation Accuracy: 31.65%
Epoch [7/10], Training Loss: 1.920, Validation Accuracy: 30.70%
Epoch [8/10], Training Loss: 1.897, Validation Accuracy: 31.97%
Epoch [9/10], Training Loss: 1.878, Validation Accuracy: 33.31%
Epoch [10/10], Training Loss: 1.859, Validation Accuracy: 33.29%
Epoch [1/10], Training Loss: 1.844, Validation Accuracy: 34.10%
Epoch [2/10], Training Loss: 1.819, Validation Accuracy: 35.30%
Epoch [3/10], Training Loss: 1.801, Validation Accuracy: 35.39%
Epoch [4/10], Training Loss: 1.784, Validation Accuracy: 36.47%
Epoch [5/10], Training Loss: 1.761, Validation Accuracy: 36.91%
Epoch [6/10], Training Loss: 1.753, Validation Accuracy: 35.36%
Epoch [7/10], Training Loss: 1.740, Validation Accuracy: 37.15%
Epoch [8/10], Training Loss: 1.715, Validation Accuracy: 37.59%
Epoch [9/10], Training Loss: 1.704, Validation Accuracy: 38.47%
Epoch [10/10], Training Loss: 1.690, Validation Accuracy: 38.18%
Epoch [1/10], Training Loss: 1.700, Validation Accuracy: 38.11%
Epoch [2/10], Training Loss: 1.686, Validation Accuracy: 39.10%
Epoch [3/10], Training Loss: 1.672, Validation Accuracy: 38.89%
Epoch [4/10], Training Loss: 1.657, Validation Accuracy: 38.70%
Epoch [5/10], Training Loss: 1.646, Validation Accuracy: 39.68%
Epoch [6/10], Training Loss: 1.635, Validation Accuracy: 39.43%
Epoch [7/10], Training Loss: 1.623, Validation Accuracy: 40.51%
Epoch [8/10], Training Loss: 1.610, Validation Accuracy: 40.64%
Epoch [9/10], Training Loss: 1.602, Validation Accuracy: 39.32%
Epoch [10/10], Training Loss: 1.595, Validation Accuracy: 40.34%
Epoch [1/10], Training Loss: 1.613, Validation Accuracy: 41.74%
Epoch [2/10], Training Loss: 1.595, Validation Accuracy: 41.46%
Epoch [3/10], Training Loss: 1.585, Validation Accuracy: 40.92%
Epoch [4/10], Training Loss: 1.572, Validation Accuracy: 41.78%
Epoch [5/10], Training Loss: 1.562, Validation Accuracy: 41.79%
Epoch [6/10], Training Loss: 1.555, Validation Accuracy: 42.63%
Epoch [7/10], Training Loss: 1.542, Validation Accuracy: 42.85%
Epoch [8/10], Training Loss: 1.531, Validation Accuracy: 42.43%
Epoch [9/10], Training Loss: 1.519, Validation Accuracy: 43.45%
Epoch [10/10], Training Loss: 1.510, Validation Accuracy: 43.54%
Epoch [1/10], Training Loss: 1.559, Validation Accuracy: 43.72%
Epoch [2/10], Training Loss: 1.538, Validation Accuracy: 43.79%
Epoch [3/10], Training Loss: 1.526, Validation Accuracy: 43.84%
Epoch [4/10], Training Loss: 1.520, Validation Accuracy: 43.18%
Epoch [5/10], Training Loss: 1.510, Validation Accuracy: 44.20%
Epoch [6/10], Training Loss: 1.502, Validation Accuracy: 45.07%
Epoch [7/10], Training Loss: 1.489, Validation Accuracy: 44.66%
Epoch [8/10], Training Loss: 1.473, Validation Accuracy: 45.04%
Epoch [9/10], Training Loss: 1.469, Validation Accuracy: 45.57%
Epoch [10/10], Training Loss: 1.460, Validation Accuracy: 45.04%
Epoch [1/10], Training Loss: 1.502, Validation Accuracy: 45.72%
Epoch [2/10], Training Loss: 1.482, Validation Accuracy: 45.89%
Epoch [3/10], Training Loss: 1.472, Validation Accuracy: 46.66%
Epoch [4/10], Training Loss: 1.465, Validation Accuracy: 47.08%
Epoch [5/10], Training Loss: 1.443, Validation Accuracy: 46.88%
Epoch [6/10], Training Loss: 1.442, Validation Accuracy: 47.02%
Epoch [7/10], Training Loss: 1.433, Validation Accuracy: 46.47%
Epoch [8/10], Training Loss: 1.425, Validation Accuracy: 46.65%
Epoch [9/10], Training Loss: 1.418, Validation Accuracy: 47.34%
Epoch [10/10], Training Loss: 1.408, Validation Accuracy: 47.98%
Epoch [1/10], Training Loss: 1.435, Validation Accuracy: 47.75%
Epoch [2/10], Training Loss: 1.415, Validation Accuracy: 47.89%
Epoch [3/10], Training Loss: 1.413, Validation Accuracy: 47.70%
Epoch [4/10], Training Loss: 1.392, Validation Accuracy: 47.86%
Epoch [5/10], Training Loss: 1.385, Validation Accuracy: 48.91%
Epoch [6/10], Training Loss: 1.381, Validation Accuracy: 48.63%
Epoch [7/10], Training Loss: 1.374, Validation Accuracy: 48.38%
Epoch [8/10], Training Loss: 1.355, Validation Accuracy: 48.57%
Epoch [9/10], Training Loss: 1.348, Validation Accuracy: 49.14%
Epoch [10/10], Training Loss: 1.342, Validation Accuracy: 48.70%
Epoch [1/10], Training Loss: 1.399, Validation Accuracy: 48.01%
Epoch [2/10], Training Loss: 1.382, Validation Accuracy: 49.04%
Epoch [3/10], Training Loss: 1.362, Validation Accuracy: 48.24%
Epoch [4/10], Training Loss: 1.358, Validation Accuracy: 48.76%
Epoch [5/10], Training Loss: 1.347, Validation Accuracy: 48.86%
Epoch [6/10], Training Loss: 1.340, Validation Accuracy: 49.66%
Epoch [7/10], Training Loss: 1.328, Validation Accuracy: 49.66%
Epoch [8/10], Training Loss: 1.328, Validation Accuracy: 49.24%
Epoch [9/10], Training Loss: 1.314, Validation Accuracy: 48.93%
Epoch [10/10], Training Loss: 1.309, Validation Accuracy: 49.91%
Epoch [1/10], Training Loss: 1.357, Validation Accuracy: 50.16%
Epoch [2/10], Training Loss: 1.337, Validation Accuracy: 50.10%
Epoch [3/10], Training Loss: 1.320, Validation Accuracy: 51.06%
Epoch [4/10], Training Loss: 1.314, Validation Accuracy: 50.81%
Epoch [5/10], Training Loss: 1.298, Validation Accuracy: 50.99%
Epoch [6/10], Training Loss: 1.288, Validation Accuracy: 50.81%
Epoch [7/10], Training Loss: 1.287, Validation Accuracy: 50.66%
Epoch [8/10], Training Loss: 1.277, Validation Accuracy: 51.31%
Epoch [9/10], Training Loss: 1.265, Validation Accuracy: 51.44%
Epoch [10/10], Training Loss: 1.254, Validation Accuracy: 51.13%
Epoch [1/10], Training Loss: 1.341, Validation Accuracy: 50.71%
Epoch [2/10], Training Loss: 1.316, Validation Accuracy: 51.48%
Epoch [3/10], Training Loss: 1.297, Validation Accuracy: 51.77%
Epoch [4/10], Training Loss: 1.302, Validation Accuracy: 50.56%
Epoch [5/10], Training Loss: 1.283, Validation Accuracy: 51.73%
Epoch [6/10], Training Loss: 1.270, Validation Accuracy: 51.27%
Epoch [7/10], Training Loss: 1.258, Validation Accuracy: 52.15%
Epoch [8/10], Training Loss: 1.250, Validation Accuracy: 51.86%
Epoch [9/10], Training Loss: 1.240, Validation Accuracy: 51.50%
Epoch [10/10], Training Loss: 1.239, Validation Accuracy: 51.90%
Epoch [1/10], Training Loss: 1.314, Validation Accuracy: 52.56%
Epoch [2/10], Training Loss: 1.287, Validation Accuracy: 52.62%
Epoch [3/10], Training Loss: 1.278, Validation Accuracy: 53.11%
Epoch [4/10], Training Loss: 1.259, Validation Accuracy: 53.02%
Epoch [5/10], Training Loss: 1.256, Validation Accuracy: 51.67%
Epoch [6/10], Training Loss: 1.240, Validation Accuracy: 53.08%
Epoch [7/10], Training Loss: 1.238, Validation Accuracy: 52.69%
Epoch [8/10], Training Loss: 1.215, Validation Accuracy: 53.46%
Epoch [9/10], Training Loss: 1.208, Validation Accuracy: 53.14%
Epoch [10/10], Training Loss: 1.205, Validation Accuracy: 53.20%
Epoch [1/10], Training Loss: 1.282, Validation Accuracy: 53.82%
Epoch [2/10], Training Loss: 1.245, Validation Accuracy: 54.02%
Epoch [3/10], Training Loss: 1.238, Validation Accuracy: 53.95%
Epoch [4/10], Training Loss: 1.222, Validation Accuracy: 53.95%
Epoch [5/10], Training Loss: 1.220, Validation Accuracy: 53.04%
Epoch [6/10], Training Loss: 1.202, Validation Accuracy: 53.41%
Epoch [7/10], Training Loss: 1.190, Validation Accuracy: 53.74%
Epoch [8/10], Training Loss: 1.183, Validation Accuracy: 53.47%
Epoch [9/10], Training Loss: 1.173, Validation Accuracy: 54.33%
Epoch [10/10], Training Loss: 1.165, Validation Accuracy: 54.20%
Epoch [1/10], Training Loss: 1.244, Validation Accuracy: 54.29%
Epoch [2/10], Training Loss: 1.222, Validation Accuracy: 54.34%
Epoch [3/10], Training Loss: 1.205, Validation Accuracy: 54.96%
Epoch [4/10], Training Loss: 1.197, Validation Accuracy: 54.60%
Epoch [5/10], Training Loss: 1.182, Validation Accuracy: 54.92%
Epoch [6/10], Training Loss: 1.162, Validation Accuracy: 54.63%
Epoch [7/10], Training Loss: 1.159, Validation Accuracy: 54.16%
Epoch [8/10], Training Loss: 1.146, Validation Accuracy: 53.44%
Epoch [9/10], Training Loss: 1.150, Validation Accuracy: 55.14%
Epoch [10/10], Training Loss: 1.131, Validation Accuracy: 55.33%
Epoch [1/10], Training Loss: 1.225, Validation Accuracy: 55.48%
Epoch [2/10], Training Loss: 1.203, Validation Accuracy: 55.48%
Epoch [3/10], Training Loss: 1.178, Validation Accuracy: 55.60%
Epoch [4/10], Training Loss: 1.174, Validation Accuracy: 55.18%
Epoch [5/10], Training Loss: 1.161, Validation Accuracy: 55.65%
Epoch [6/10], Training Loss: 1.144, Validation Accuracy: 55.68%
Epoch [7/10], Training Loss: 1.128, Validation Accuracy: 55.60%
Epoch [8/10], Training Loss: 1.125, Validation Accuracy: 55.51%
Epoch [9/10], Training Loss: 1.109, Validation Accuracy: 54.97%
Epoch [10/10], Training Loss: 1.109, Validation Accuracy: 55.22%
Epoch [1/10], Training Loss: 1.193, Validation Accuracy: 55.00%
Epoch [2/10], Training Loss: 1.163, Validation Accuracy: 55.43%
Epoch [3/10], Training Loss: 1.151, Validation Accuracy: 55.20%
Epoch [4/10], Training Loss: 1.134, Validation Accuracy: 55.63%
Epoch [5/10], Training Loss: 1.118, Validation Accuracy: 55.97%
Epoch [6/10], Training Loss: 1.114, Validation Accuracy: 55.88%
Epoch [7/10], Training Loss: 1.104, Validation Accuracy: 55.28%
Epoch [8/10], Training Loss: 1.101, Validation Accuracy: 55.11%
Epoch [9/10], Training Loss: 1.093, Validation Accuracy: 55.58%
Epoch [10/10], Training Loss: 1.069, Validation Accuracy: 55.67%
Epoch [1/10], Training Loss: 1.195, Validation Accuracy: 56.33%
Epoch [2/10], Training Loss: 1.162, Validation Accuracy: 56.58%
Epoch [3/10], Training Loss: 1.146, Validation Accuracy: 55.84%
Epoch [4/10], Training Loss: 1.132, Validation Accuracy: 56.45%
Epoch [5/10], Training Loss: 1.115, Validation Accuracy: 57.07%
Epoch [6/10], Training Loss: 1.099, Validation Accuracy: 56.76%
Epoch [7/10], Training Loss: 1.089, Validation Accuracy: 56.87%
Epoch [8/10], Training Loss: 1.088, Validation Accuracy: 57.09%
Epoch [9/10], Training Loss: 1.075, Validation Accuracy: 56.84%
Epoch [10/10], Training Loss: 1.058, Validation Accuracy: 56.24%
Epoch [1/10], Training Loss: 1.168, Validation Accuracy: 57.38%
Epoch [2/10], Training Loss: 1.133, Validation Accuracy: 57.64%
Epoch [3/10], Training Loss: 1.118, Validation Accuracy: 57.67%
Epoch [4/10], Training Loss: 1.102, Validation Accuracy: 57.53%
Epoch [5/10], Training Loss: 1.088, Validation Accuracy: 57.47%
Epoch [6/10], Training Loss: 1.076, Validation Accuracy: 57.68%
Epoch [7/10], Training Loss: 1.063, Validation Accuracy: 57.87%
Epoch [8/10], Training Loss: 1.066, Validation Accuracy: 56.99%
Epoch [9/10], Training Loss: 1.039, Validation Accuracy: 56.60%
Epoch [10/10], Training Loss: 1.040, Validation Accuracy: 57.40%
Epoch [1/10], Training Loss: 1.154, Validation Accuracy: 57.56%
Epoch [2/10], Training Loss: 1.115, Validation Accuracy: 58.17%
Epoch [3/10], Training Loss: 1.085, Validation Accuracy: 57.47%
Epoch [4/10], Training Loss: 1.083, Validation Accuracy: 56.48%
Epoch [5/10], Training Loss: 1.062, Validation Accuracy: 57.82%
Epoch [6/10], Training Loss: 1.053, Validation Accuracy: 56.99%
Epoch [7/10], Training Loss: 1.033, Validation Accuracy: 56.86%
Epoch [8/10], Training Loss: 1.024, Validation Accuracy: 57.88%
Epoch [9/10], Training Loss: 1.014, Validation Accuracy: 57.49%
Epoch [10/10], Training Loss: 1.002, Validation Accuracy: 58.08%
Epoch [1/10], Training Loss: 1.120, Validation Accuracy: 57.40%
Epoch [2/10], Training Loss: 1.095, Validation Accuracy: 58.81%
Epoch [3/10], Training Loss: 1.059, Validation Accuracy: 58.17%
Epoch [4/10], Training Loss: 1.060, Validation Accuracy: 58.43%
Epoch [5/10], Training Loss: 1.047, Validation Accuracy: 57.26%
Epoch [6/10], Training Loss: 1.020, Validation Accuracy: 58.31%
Epoch [7/10], Training Loss: 1.015, Validation Accuracy: 57.97%
Epoch [8/10], Training Loss: 0.997, Validation Accuracy: 58.30%
Epoch [9/10], Training Loss: 0.992, Validation Accuracy: 58.56%
Epoch [10/10], Training Loss: 0.977, Validation Accuracy: 58.08%
Epoch [1/10], Training Loss: 1.108, Validation Accuracy: 58.74%
Epoch [2/10], Training Loss: 1.067, Validation Accuracy: 57.33%
Epoch [3/10], Training Loss: 1.045, Validation Accuracy: 58.40%
Epoch [4/10], Training Loss: 1.019, Validation Accuracy: 58.72%
Epoch [5/10], Training Loss: 1.012, Validation Accuracy: 58.38%
Epoch [6/10], Training Loss: 0.996, Validation Accuracy: 58.31%
Epoch [7/10], Training Loss: 0.980, Validation Accuracy: 58.14%
Epoch [8/10], Training Loss: 0.975, Validation Accuracy: 57.83%
Epoch [9/10], Training Loss: 0.965, Validation Accuracy: 57.29%
Epoch [10/10], Training Loss: 0.958, Validation Accuracy: 57.91%
Epoch [1/10], Training Loss: 1.103, Validation Accuracy: 58.24%
Epoch [2/10], Training Loss: 1.068, Validation Accuracy: 58.75%
Epoch [3/10], Training Loss: 1.040, Validation Accuracy: 59.03%
Epoch [4/10], Training Loss: 1.028, Validation Accuracy: 59.22%
Epoch [5/10], Training Loss: 1.008, Validation Accuracy: 59.18%
Epoch [6/10], Training Loss: 0.989, Validation Accuracy: 59.49%
Epoch [7/10], Training Loss: 0.971, Validation Accuracy: 59.37%
Epoch [8/10], Training Loss: 0.967, Validation Accuracy: 59.20%
Epoch [9/10], Training Loss: 0.950, Validation Accuracy: 59.34%
Epoch [10/10], Training Loss: 0.942, Validation Accuracy: 59.03%
Epoch [1/10], Training Loss: 1.079, Validation Accuracy: 59.55%
Epoch [2/10], Training Loss: 1.050, Validation Accuracy: 59.49%
Epoch [3/10], Training Loss: 1.013, Validation Accuracy: 59.68%
Epoch [4/10], Training Loss: 0.997, Validation Accuracy: 59.26%
Epoch [5/10], Training Loss: 0.979, Validation Accuracy: 59.39%
Epoch [6/10], Training Loss: 0.966, Validation Accuracy: 58.82%
Epoch [7/10], Training Loss: 0.950, Validation Accuracy: 59.11%
Epoch [8/10], Training Loss: 0.941, Validation Accuracy: 59.27%
Epoch [9/10], Training Loss: 0.932, Validation Accuracy: 59.08%
Epoch [10/10], Training Loss: 0.912, Validation Accuracy: 59.34%
"""

# Regular expression to find validation accuracies
accuracies = re.findall(r'Validation Accuracy: (\d+\.\d+)%', log)

# Convert accuracies from string to float
accuracies = [float(acc) for acc in accuracies]

# Print accuracies
print("Accuracies:", accuracies)

# Print size of the array
print("Size of array:", len(accuracies))


Accuracies: [10.51, 11.08, 11.0, 10.34, 10.15, 10.08, 10.08, 10.08, 10.08, 10.09, 10.14, 10.6, 12.86, 15.67, 16.95, 18.19, 18.12, 18.67, 19.6, 20.25, 19.53, 22.15, 23.43, 24.52, 24.94, 25.31, 26.61, 27.15, 26.64, 28.15, 28.44, 28.81, 29.3, 29.67, 30.49, 31.65, 30.7, 31.97, 33.31, 33.29, 34.1, 35.3, 35.39, 36.47, 36.91, 35.36, 37.15, 37.59, 38.47, 38.18, 38.11, 39.1, 38.89, 38.7, 39.68, 39.43, 40.51, 40.64, 39.32, 40.34, 41.74, 41.46, 40.92, 41.78, 41.79, 42.63, 42.85, 42.43, 43.45, 43.54, 43.72, 43.79, 43.84, 43.18, 44.2, 45.07, 44.66, 45.04, 45.57, 45.04, 45.72, 45.89, 46.66, 47.08, 46.88, 47.02, 46.47, 46.65, 47.34, 47.98, 47.75, 47.89, 47.7, 47.86, 48.91, 48.63, 48.38, 48.57, 49.14, 48.7, 48.01, 49.04, 48.24, 48.76, 48.86, 49.66, 49.66, 49.24, 48.93, 49.91, 50.16, 50.1, 51.06, 50.81, 50.99, 50.81, 50.66, 51.31, 51.44, 51.13, 50.71, 51.48, 51.77, 50.56, 51.73, 51.27, 52.15, 51.86, 51.5, 51.9, 52.56, 52.62, 53.11, 53.02, 51.67, 53.08, 52.69, 53.46, 53.14, 53.2, 53.82, 54.02, 53.95, 53

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random

# Define VAE model
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2*z_dim)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    for epoch in range(epochs):
        vae.train()
        for i, data in enumerate(trainloader, 0):
            inputs, _ = data
            optimizer.zero_grad()
            recon_x, mu, logvar = vae(inputs)
            loss = vae_loss(recon_x, inputs, mu, logvar)
            loss.backward()
            optimizer.step()

        scheduler.step()


# Define classification model
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Load CIFAR10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Separate dataset by class
class_indices = {i: [] for i in range(10)}  # CIFAR-10 has 10 classes
for idx, (_, label) in enumerate(full_dataset):
    class_indices[label].append(idx)

# Define target count per class, summing to 60,000 with random distribution
class_counts = np.random.multinomial(60000, [0.1] * 10)  # Adjust probabilities if you want specific class biases
print("Random Images per Class:", class_counts)

# Sample indices based on the specified class counts
indices = []
for class_id, count in enumerate(class_counts):
    # Ensure count does not exceed available images
    count = min(count, len(class_indices[class_id]))
    selected_indices = random.sample(class_indices[class_id], count)
    indices.extend(selected_indices)

# Create a custom CIFAR-10 dataset with the sampled indices
custom_dataset = Subset(full_dataset, indices)
# Split dataset into training and validation sets
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(full_dataset, [train_size, val_size])

# Create training and validation loaders
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        scheduler.step()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data in valloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {running_loss / (i+1):.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    net.eval()
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    accuracy = 100 * correct / total

    # Calculate confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    print(f"Confusion Matrix:\n{cm}")

    # Calculate TP, FP, TN, FN for each class
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # Calculate precision, recall, and F1 score for each class
    precision = precision_score(all_labels, all_predictions, average=None)
    recall = recall_score(all_labels, all_predictions, average=None)
    f1 = f1_score(all_labels, all_predictions, average=None)

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    clients = {}
    for i in range(num_clients):
        client_trainset = torch.utils.data.Subset(trainset, range(i * len(trainset) // num_clients, (i + 1) * len(trainset) // num_clients))
        client_trainloader = torch.utils.data.DataLoader(client_trainset, batch_size=batch_size, shuffle=True)
        clients[f"client_{i}"] = client_trainloader
    return clients

def get_distribution_info(vae: VAE) -> Dict:
    # Implement the logic to extract distribution information from the VAE
    # This can involve computing statistics, parameters, or any other relevant information
    # that can be used to generate augmented data

    # Example implementation:
    distribution_info = {

        "normal": {
            "mean": vae.encoder[-1].bias.data.cpu().numpy(),
            "std": torch.exp(0.5 * vae.encoder[-1].weight.data).cpu().numpy()
        }
    }

    return distribution_info

def send_distribution_info(distribution_info: Dict) -> None:
    # Implement the logic to send the distribution information to the global server
    # This can involve using a network protocol, a message queue, or any other communication mechanism
    # to send the information

    # Example implementation:
    # Send the distribution information to the global server using a network protocol
    # For example, you can use the `socket` module to send the information over a network
    # or use a message queue like `RabbitMQ` to send the information
    pass

def generate_augmented_data(vae: VAE, distribution_info_normal: Dict) -> torch.Tensor:
    # Generate augmented data using both uniform and truncated uniform distributions



    mean_normal = distribution_info_normal["mean"]
    std_normal = distribution_info_normal["std"]

    augmented_data_normal = torch.randn(64, vae.z_dim) * std_normal + mean_normal





    # Calculate the average of augmented data from both distributions
    augmented_data_average =  augmented_data_normal

    return augmented_data_average

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    for epoch in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Train VAE on client data
            vae_train(vae, client_trainloader, epochs=10)

            # Share distribution information with global server
            distribution_info = get_distribution_info(vae)
            send_distribution_info(distribution_info)

            # Receive distribution information from other clients
            other_distribution_info = receive_distribution_info()

            # Generate augmented data using received distribution information
            augmented_data = generate_augmented_data(vae, other_distribution_info["normal"])

            # Train classification model using local, augmented, and validation data
            train(net, client_trainloader, valloader, epochs=10)

            # Send model updates to global server
            send_model_update(client_id, net.state_dict())

# Define logic to receive distribution information from global server
def receive_distribution_info() -> Dict:
    # Receive distribution information logic
    distribution_info = {

        "normal": {
            "mean": np.zeros(20),  # Adjust the size based on your latent space dimension
            "std": np.ones(20)
        }
    }
    return distribution_info

def send_model_update(client_id: str, model_update: Dict) -> None:
    # Implement the logic to send the model update to the global server
    # This can involve using a network protocol, a message queue, or any other communication mechanism
    # to send the model update

    # Example implementation:
    # Send the model update to the global server using a network protocol
    # For example, you can use the `socket` module to send the information over a network
    # or use a message queue like `RabbitMQ` to send the information
    pass

# Define global server procedure
def global_server() -> None:
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size
    h_dim = 400
    z_dim = 20
    vae = VAE(x_dim, h_dim, z_dim)  # Initialize VAE object with required arguments

    # Initialize clients
    num_clients = 5  # Define the number of clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    print("True Positives (TP):", tp)
    print("False Positives (FP):", fp)
    print("True Negatives (TN):", tn)
    print("False Negatives (FN):", fn)
    print(":", fn+tn+tp+fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

if __name__ == "__main__":
    global_server()

Files already downloaded and verified
Files already downloaded and verified
Random Images per Class: [5958 6008 5989 5987 5944 5912 6049 6086 5993 6074]
Epoch [1/10], Training Loss: 2.306, Validation Accuracy: 9.51%
Epoch [2/10], Training Loss: 2.305, Validation Accuracy: 9.51%
Epoch [3/10], Training Loss: 2.304, Validation Accuracy: 9.51%
Epoch [4/10], Training Loss: 2.303, Validation Accuracy: 9.59%
Epoch [5/10], Training Loss: 2.302, Validation Accuracy: 10.49%
Epoch [6/10], Training Loss: 2.301, Validation Accuracy: 12.03%
Epoch [7/10], Training Loss: 2.300, Validation Accuracy: 12.99%
Epoch [8/10], Training Loss: 2.299, Validation Accuracy: 13.97%
Epoch [9/10], Training Loss: 2.297, Validation Accuracy: 14.28%
Epoch [10/10], Training Loss: 2.296, Validation Accuracy: 14.70%
Epoch [1/10], Training Loss: 2.295, Validation Accuracy: 15.30%
Epoch [2/10], Training Loss: 2.293, Validation Accuracy: 16.12%
Epoch [3/10], Training Loss: 2.289, Validation Accuracy: 16.74%
Epoch [4/10], Trai

In [None]:
# Define global server procedure
def global_server() -> None:
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size
    h_dim = 400
    z_dim = 20
    vae = VAE(x_dim, h_dim, z_dim)  # Initialize VAE object with required arguments

    # Initialize clients
    num_clients = 5  # Define the number of clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    print("True Positives (TP):", tp)
    print("False Positives (FP):", fp)
    print("True Negatives (TN):", tn)
    print("False Negatives (FN):", fn)
    print(":", fn+tn+tp+fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

if __name__ == "__main__":
    global_server()