<a href="https://colab.research.google.com/github/Rahad31/Different-VAE-for-KL-FedDis/blob/main/Beta_Uni_Trunc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Beta=0.1

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 images.

    The encoder maps an image batch to ``2 * z_dim`` values which are split
    into a latent mean and a latent log-variance. The decoder maps a latent
    vector back to a 3x32x32 image.

    Args:
        x_dim: Flattened input size. Stored on the instance; no layer uses it.
        h_dim: Nominal hidden size. Stored on the instance; no layer uses it
            (the hidden width is fixed at 256).
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: three stride-2 convolutions halve the resolution each time.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),  # 128 channels * 4 * 4
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),  # first half: mu, second half: logvar
        )

        # Decoder: mirror image of the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            # The images in this file are normalised with mean 0.5 / std 0.5,
            # i.e. into [-1, 1]. A Sigmoid output ([0, 1]) can never match the
            # negative half of that range, so the reconstruction uses Tanh.
            nn.Tanh(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors of shape ``(batch, z_dim)`` to images in [-1, 1]."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the image batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int, beta: float =0.1):
    """Fit ``vae`` on ``trainloader`` for ``epochs`` passes.

    A fresh Adam optimiser (lr 1e-3) and a StepLR schedule (halve the rate
    every 20 scheduler steps) are created on every call, so neither carries
    state between calls.

    Args:
        vae: Model whose forward pass returns ``(reconstruction, mu, logvar)``.
        trainloader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        epochs: Number of passes over ``trainloader``.
        beta: Weight of the KL term, forwarded to ``vae_loss``.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    lr_schedule = torch.optim.lr_scheduler.StepLR(adam, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for images, _labels in trainloader:
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar, beta=beta)
            batch_loss.backward()
            adam.step()

        # One scheduler step per epoch, after all batches of that epoch.
        lr_schedule.step()



# Define classification model
class Net(nn.Module):
    """Classifier with two conv+pool stages followed by three linear layers.

    Produces 10 logits per image. The flatten step assumes the second pooling
    stage yields 16 feature maps of 5x5, which is what 3x32x32 inputs give.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for the image batch ``x``."""
        relu = nn.functional.relu
        # Convolutional feature extractor: conv -> ReLU -> 2x2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        # Fully connected head; the last layer emits raw logits (no ReLU).
        for dense in (self.fc1, self.fc2):
            x = relu(dense(x))
        return self.fc3(x)


# Load CIFAR10 dataset
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 then
# maps every channel into [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Separate dataset by class.
# Labels are read from `targets` directly: iterating the dataset itself would
# decode and transform every image just to obtain its label.
class_indices = {i: [] for i in range(10)}  # CIFAR-10 has 10 classes
for idx, label in enumerate(full_dataset.targets):
    class_indices[label].append(idx)

# Define target count per class with a random distribution. The total is the
# size of the training split: asking for more images than exist made the
# `min` clamp below level every class to its full size, cancelling the draw.
class_counts = np.random.multinomial(len(full_dataset), [0.1] * 10)  # Adjust probabilities if you want specific class biases
print("Random Images per Class:", class_counts)

# Sample indices based on the specified class counts
indices = []
for class_id, count in enumerate(class_counts):
    # Ensure count does not exceed available images (and is a plain int).
    count = min(int(count), len(class_indices[class_id]))
    selected_indices = random.sample(class_indices[class_id], count)
    indices.extend(selected_indices)

# Create a custom CIFAR-10 dataset with the sampled indices
custom_dataset = Subset(full_dataset, indices)
# Split the sampled dataset (not the untouched full dataset) into training
# and validation sets, so the per-class sampling above actually takes effect.
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

# Create training and validation loaders
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)




# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar, beta=0.1):
    """Return the beta-weighted VAE objective for one batch.

    The reconstruction term is the summed squared error between ``recon_x``
    and ``x``; the regulariser is the KL divergence between the diagonal
    Gaussian described by ``(mu, logvar)`` and a standard normal, also summed
    over every element. The result is ``reconstruction + beta * kl``.
    """
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction="sum")
    kl_elements = 1 + logvar - mu.pow(2) - torch.exp(logvar)
    kl_term = torch.sum(kl_elements) * -0.5
    return reconstruction_term + beta * kl_term


# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after each epoch.

    A fresh optimiser (SGD, lr 0.001, momentum 0.9) and StepLR schedule are
    created on every call. Whenever the validation accuracy exceeds the best
    value seen *within this call*, the weights are written to
    ``best_model.pth`` in the working directory.

    Empty loaders are tolerated: with no training batches the reported loss is
    0.0, and with no validation samples the reported accuracy is 0.0 (so no
    checkpoint is written).

    Args:
        net: Classifier returning per-class logits.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0  # counted explicitly so an empty loader cannot break the average
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total if total else 0.0
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        mean_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {mean_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate ``net`` on ``testloader`` and compute per-class metrics.

    The class set is taken from the width of the network's logits, and passed
    to scikit-learn explicitly. Every returned per-class array therefore has
    one entry per output class, in class-id order, even when a class never
    appears among the labels or predictions of this run.

    Args:
        net: Classifier returning per-class logits.
        testloader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)`` where
        ``accuracy`` is a percentage and the remaining items are per-class
        arrays. The confusion matrix is printed as a side effect.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    num_classes = None
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            num_classes = outputs.size(1)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    accuracy = 100 * correct / total

    # Fix the class order so row/column k always refers to class k.
    class_ids = list(range(num_classes))

    # Calculate confusion matrix
    cm = confusion_matrix(all_labels, all_predictions, labels=class_ids)
    print(f"Confusion Matrix:\n{cm}")

    # Calculate TP, FP, TN, FN for each class
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # Calculate precision, recall, and F1 score for each class.
    # zero_division=0 keeps the score at 0 for classes with no predictions or
    # no samples, without emitting an undefined-metric warning for each one.
    precision = precision_score(all_labels, all_predictions, labels=class_ids, average=None, zero_division=0)
    recall = recall_score(all_labels, all_predictions, labels=class_ids, average=None, zero_division=0)
    f1 = f1_score(all_labels, all_predictions, labels=class_ids, average=None, zero_division=0)

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into contiguous shards, one shuffled loader per client.

    Client ``k`` receives indices ``[k * N // num_clients, (k + 1) * N // num_clients)``
    where ``N = len(trainset)``. ``transform`` is accepted but not used.

    Returns:
        Dict mapping ``"client_<k>"`` to that client's ``DataLoader``.
    """
    total = len(trainset)
    return {
        f"client_{k}": torch.utils.data.DataLoader(
            torch.utils.data.Subset(
                trainset,
                range(k * total // num_clients, (k + 1) * total // num_clients),
            ),
            batch_size=batch_size,
            shuffle=True,
        )
        for k in range(num_clients)
    }

def get_distribution_info(vae: VAE) -> Dict:
    """Build the distribution summary a client shares about its VAE.

    Both the ``"uniform"`` and ``"normal"`` entries hold the same two values,
    read from the last encoder layer: ``"mean"`` is that layer's bias and
    ``"std"`` is ``exp(0.5 * weight)`` of that layer.

    NOTE(review): these are raw layer parameters, not statistics of encoded
    data, and ``"std"`` has the shape of the weight matrix rather than of a
    latent vector. This looks like a placeholder; confirm before relying on it.
    """
    head = vae.encoder[-1]

    def _layer_stats() -> Dict:
        # Evaluated once per entry so each entry gets its own "std" array.
        return {
            "mean": head.bias.data.cpu().numpy(),
            "std": torch.exp(0.5 * head.weight.data).cpu().numpy(),
        }

    return {"uniform": _layer_stats(), "normal": _layer_stats()}

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for shipping ``distribution_info`` to the global server.

    Currently a no-op: nothing is transmitted and the argument is ignored.
    A real implementation would hand the payload to some transport (for
    example a socket or a message queue).
    """
    return None

def generate_augmented_data(vae: "VAE", distribution_info_uniform: Dict, distribution_info_normal: Dict, num_samples: int = 64) -> torch.Tensor:
    """Draw latent samples by averaging a normal and a uniform draw.

    Two batches of shape ``(num_samples, vae.z_dim)`` are generated and
    averaged element-wise:

    * a normal draw scaled and shifted by the per-dimension ``"std"`` and
      ``"mean"`` of ``distribution_info_normal``;
    * a uniform draw on ``[m - s, m + s]`` where ``m`` and ``s`` are the
      scalar averages of the ``"mean"`` and ``"std"`` entries of
      ``distribution_info_uniform``.

    The statistics are converted to float32 tensors up front. The previous
    version multiplied a torch tensor by NumPy arrays directly, which relied
    on implicit conversion (the recorded notebook output shows a warning on
    that line).

    Args:
        vae: Object exposing ``z_dim``; only that attribute is read.
        distribution_info_uniform: Mapping with array-like ``"mean"``/``"std"``.
        distribution_info_normal: Mapping with array-like ``"mean"``/``"std"``
            that broadcast against ``(num_samples, vae.z_dim)``.
        num_samples: Number of latent vectors to generate (default 64).

    Returns:
        A float32 tensor of shape ``(num_samples, vae.z_dim)``.
    """
    latent_dim = vae.z_dim

    # Scalar bounds for the uniform draw (averaged over all dimensions).
    mean = float(np.asarray(distribution_info_uniform["mean"], dtype=np.float64).mean())
    std = float(np.asarray(distribution_info_uniform["std"], dtype=np.float64).mean())

    mean_normal = torch.as_tensor(distribution_info_normal["mean"], dtype=torch.float32)
    std_normal = torch.as_tensor(distribution_info_normal["std"], dtype=torch.float32)

    # Normal draw first, then uniform, so the RNG is consumed in a fixed order.
    augmented_data_normal = torch.randn(num_samples, latent_dim) * std_normal + mean_normal
    augmented_data_uniform = torch.empty(num_samples, latent_dim).uniform_(mean - std, mean + std)

    # Calculate the average of augmented data from both distributions
    return (augmented_data_uniform + augmented_data_normal) / 2

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int, local_epochs: int = 10, beta: float = 0.1) -> None:
    """Run the per-client training rounds.

    For every round and every client, in dictionary order: train ``vae`` on
    the client's data, exchange distribution information, draw augmented
    latent samples, train ``net`` on the client's data, and send the model
    state. The same ``net`` and ``vae`` objects are trained by every client
    in turn; no aggregation step happens here.

    NOTE(review): the augmented samples are generated but not passed to
    ``train``, and the ``trainloader`` argument is not used; confirm whether
    either was meant to feed the classifier.

    Args:
        net: Classifier trained on each client's loader.
        vae: VAE trained on each client's loader.
        trainloaders: Mapping of client id to that client's loader.
        trainloader: Unused; kept for interface compatibility.
        valloader: Validation loader forwarded to ``train``.
        epochs: Number of rounds over all clients.
        local_epochs: Epochs of VAE and classifier training per client visit.
        beta: KL weight forwarded to ``vae_train``.
    """
    for epoch in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Train VAE on client data
            vae_train(vae, client_trainloader, epochs=local_epochs, beta=beta)

            # Share distribution information with global server
            distribution_info = get_distribution_info(vae)
            send_distribution_info(distribution_info)

            # Receive distribution information from other clients
            other_distribution_info = receive_distribution_info()

            # Generate augmented data using received distribution information.
            # Kept even though the result is unused, so the random number
            # stream is consumed exactly as before.
            augmented_data = generate_augmented_data(vae, other_distribution_info["uniform"], other_distribution_info["normal"])

            # Train classification model on the client's local data
            train(net, client_trainloader, valloader, epochs=local_epochs)

            # Send model updates to global server
            send_model_update(client_id, net.state_dict())

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution information from the global server.

    No communication takes place: both the ``"uniform"`` and ``"normal"``
    entries describe a zero mean and unit standard deviation per latent
    dimension. Every array in the result is a separate object.

    Args:
        z_dim: Latent dimensionality, i.e. the length of each array. Defaults
            to 20, the value that was previously hard-coded.

    Returns:
        ``{"uniform": {"mean", "std"}, "normal": {"mean", "std"}}`` with
        NumPy arrays of length ``z_dim``.
    """
    def _standard_stats() -> Dict:
        return {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}

    return {"uniform": _standard_stats(), "normal": _standard_stats()}

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for sending a client's model update to the global server.

    Currently a no-op: both arguments are ignored and nothing is transmitted.
    A real implementation would hand ``model_update`` to some transport (for
    example a socket or a message queue), tagged with ``client_id``.
    """
    return None

# Define global server procedure
def global_server() -> None:
    """Build the models, run the federated training, and print test metrics.

    Relies on the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``. The network evaluated at the end is the
    in-memory ``net`` after the last training step; the ``best_model.pth``
    checkpoint written by ``train`` is not reloaded here.
    """
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size
    h_dim = 400
    z_dim = 20
    vae = VAE(x_dim, h_dim, z_dim)  # Initialize VAE object with required arguments

    # Initialize clients
    num_clients = 5  # Define the number of clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    print("True Positives (TP):", tp)
    print("False Positives (FP):", fp)
    print("True Negatives (TN):", tn)
    print("False Negatives (FN):", fn)
    # Sanity check: the four counts add up to the number of test samples for
    # every class. This line previously printed with an empty label.
    print("Total samples (TP+FP+TN+FN):", fn+tn+tp+fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

# Entry point: run the whole experiment when this code is executed directly
# (as a script or notebook cell) rather than imported as a module.
if __name__ == "__main__":
    global_server()

100%|██████████| 170M/170M [00:02<00:00, 69.1MB/s]


Random Images per Class: [5893 6152 6063 5953 6013 6066 5914 6053 5926 5967]


  augmented_data_normal = torch.randn(64, vae.z_dim) * std_normal + mean_normal


Epoch [1/10], Training Loss: 2.304, Validation Accuracy: 10.54%
Epoch [2/10], Training Loss: 2.302, Validation Accuracy: 10.68%
Epoch [3/10], Training Loss: 2.300, Validation Accuracy: 12.05%
Epoch [4/10], Training Loss: 2.299, Validation Accuracy: 13.80%
Epoch [5/10], Training Loss: 2.296, Validation Accuracy: 14.89%
Epoch [6/10], Training Loss: 2.293, Validation Accuracy: 15.71%
Epoch [7/10], Training Loss: 2.289, Validation Accuracy: 16.57%
Epoch [8/10], Training Loss: 2.284, Validation Accuracy: 18.38%
Epoch [9/10], Training Loss: 2.276, Validation Accuracy: 18.15%
Epoch [10/10], Training Loss: 2.264, Validation Accuracy: 17.43%
Epoch [1/10], Training Loss: 2.249, Validation Accuracy: 17.42%
Epoch [2/10], Training Loss: 2.225, Validation Accuracy: 17.12%
Epoch [3/10], Training Loss: 2.195, Validation Accuracy: 18.14%
Epoch [4/10], Training Loss: 2.163, Validation Accuracy: 21.13%
Epoch [5/10], Training Loss: 2.136, Validation Accuracy: 23.99%
Epoch [6/10], Training Loss: 2.115, Val

Beta=0.5

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 images.

    The encoder maps an image batch to ``2 * z_dim`` values which are split
    into a latent mean and a latent log-variance. The decoder maps a latent
    vector back to a 3x32x32 image.

    Args:
        x_dim: Flattened input size. Stored on the instance; no layer uses it.
        h_dim: Nominal hidden size. Stored on the instance; no layer uses it
            (the hidden width is fixed at 256).
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: three stride-2 convolutions halve the resolution each time.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),  # 128 channels * 4 * 4
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),  # first half: mu, second half: logvar
        )

        # Decoder: mirror image of the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            # The images in this file are normalised with mean 0.5 / std 0.5,
            # i.e. into [-1, 1]. A Sigmoid output ([0, 1]) can never match the
            # negative half of that range, so the reconstruction uses Tanh.
            nn.Tanh(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors of shape ``(batch, z_dim)`` to images in [-1, 1]."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the image batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int, beta: float =0.5):
    """Fit ``vae`` on ``trainloader`` for ``epochs`` passes.

    A fresh Adam optimiser (lr 1e-3) and a StepLR schedule (halve the rate
    every 20 scheduler steps) are created on every call, so neither carries
    state between calls.

    Args:
        vae: Model whose forward pass returns ``(reconstruction, mu, logvar)``.
        trainloader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        epochs: Number of passes over ``trainloader``.
        beta: Weight of the KL term, forwarded to ``vae_loss``.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    lr_schedule = torch.optim.lr_scheduler.StepLR(adam, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for images, _labels in trainloader:
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar, beta=beta)
            batch_loss.backward()
            adam.step()

        # One scheduler step per epoch, after all batches of that epoch.
        lr_schedule.step()



# Define classification model
class Net(nn.Module):
    """Classifier with two conv+pool stages followed by three linear layers.

    Produces 10 logits per image. The flatten step assumes the second pooling
    stage yields 16 feature maps of 5x5, which is what 3x32x32 inputs give.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for the image batch ``x``."""
        relu = nn.functional.relu
        # Convolutional feature extractor: conv -> ReLU -> 2x2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        # Fully connected head; the last layer emits raw logits (no ReLU).
        for dense in (self.fc1, self.fc2):
            x = relu(dense(x))
        return self.fc3(x)


# Load CIFAR10 dataset
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 then
# maps every channel into [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Separate dataset by class.
# Labels are read from `targets` directly: iterating the dataset itself would
# decode and transform every image just to obtain its label.
class_indices = {i: [] for i in range(10)}  # CIFAR-10 has 10 classes
for idx, label in enumerate(full_dataset.targets):
    class_indices[label].append(idx)

# Define target count per class with a random distribution. The total is the
# size of the training split: asking for more images than exist made the
# `min` clamp below level every class to its full size, cancelling the draw.
class_counts = np.random.multinomial(len(full_dataset), [0.1] * 10)  # Adjust probabilities if you want specific class biases
print("Random Images per Class:", class_counts)

# Sample indices based on the specified class counts
indices = []
for class_id, count in enumerate(class_counts):
    # Ensure count does not exceed available images (and is a plain int).
    count = min(int(count), len(class_indices[class_id]))
    selected_indices = random.sample(class_indices[class_id], count)
    indices.extend(selected_indices)

# Create a custom CIFAR-10 dataset with the sampled indices
custom_dataset = Subset(full_dataset, indices)
# Split the sampled dataset (not the untouched full dataset) into training
# and validation sets, so the per-class sampling above actually takes effect.
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

# Create training and validation loaders
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)




# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar, beta=0.5):
    """Return the beta-weighted VAE objective for one batch.

    The reconstruction term is the summed squared error between ``recon_x``
    and ``x``; the regulariser is the KL divergence between the diagonal
    Gaussian described by ``(mu, logvar)`` and a standard normal, also summed
    over every element. The result is ``reconstruction + beta * kl``.
    """
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction="sum")
    kl_elements = 1 + logvar - mu.pow(2) - torch.exp(logvar)
    kl_term = torch.sum(kl_elements) * -0.5
    return reconstruction_term + beta * kl_term


# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after each epoch.

    A fresh optimiser (SGD, lr 0.001, momentum 0.9) and StepLR schedule are
    created on every call. Whenever the validation accuracy exceeds the best
    value seen *within this call*, the weights are written to
    ``best_model.pth`` in the working directory.

    Empty loaders are tolerated: with no training batches the reported loss is
    0.0, and with no validation samples the reported accuracy is 0.0 (so no
    checkpoint is written).

    Args:
        net: Classifier returning per-class logits.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0  # counted explicitly so an empty loader cannot break the average
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total if total else 0.0
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        mean_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {mean_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate ``net`` on ``testloader`` and compute per-class metrics.

    The class set is taken from the width of the network's logits, and passed
    to scikit-learn explicitly. Every returned per-class array therefore has
    one entry per output class, in class-id order, even when a class never
    appears among the labels or predictions of this run.

    Args:
        net: Classifier returning per-class logits.
        testloader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)`` where
        ``accuracy`` is a percentage and the remaining items are per-class
        arrays. The confusion matrix is printed as a side effect.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    num_classes = None
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            num_classes = outputs.size(1)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    accuracy = 100 * correct / total

    # Fix the class order so row/column k always refers to class k.
    class_ids = list(range(num_classes))

    # Calculate confusion matrix
    cm = confusion_matrix(all_labels, all_predictions, labels=class_ids)
    print(f"Confusion Matrix:\n{cm}")

    # Calculate TP, FP, TN, FN for each class
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # Calculate precision, recall, and F1 score for each class.
    # zero_division=0 keeps the score at 0 for classes with no predictions or
    # no samples, without emitting an undefined-metric warning for each one.
    precision = precision_score(all_labels, all_predictions, labels=class_ids, average=None, zero_division=0)
    recall = recall_score(all_labels, all_predictions, labels=class_ids, average=None, zero_division=0)
    f1 = f1_score(all_labels, all_predictions, labels=class_ids, average=None, zero_division=0)

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into contiguous shards, one shuffled loader per client.

    Client ``k`` receives indices ``[k * N // num_clients, (k + 1) * N // num_clients)``
    where ``N = len(trainset)``. ``transform`` is accepted but not used.

    Returns:
        Dict mapping ``"client_<k>"`` to that client's ``DataLoader``.
    """
    total = len(trainset)
    return {
        f"client_{k}": torch.utils.data.DataLoader(
            torch.utils.data.Subset(
                trainset,
                range(k * total // num_clients, (k + 1) * total // num_clients),
            ),
            batch_size=batch_size,
            shuffle=True,
        )
        for k in range(num_clients)
    }

def get_distribution_info(vae: VAE) -> Dict:
    """Build the distribution summary a client shares about its VAE.

    Both the ``"uniform"`` and ``"normal"`` entries hold the same two values,
    read from the last encoder layer: ``"mean"`` is that layer's bias and
    ``"std"`` is ``exp(0.5 * weight)`` of that layer.

    NOTE(review): these are raw layer parameters, not statistics of encoded
    data, and ``"std"`` has the shape of the weight matrix rather than of a
    latent vector. This looks like a placeholder; confirm before relying on it.
    """
    head = vae.encoder[-1]

    def _layer_stats() -> Dict:
        # Evaluated once per entry so each entry gets its own "std" array.
        return {
            "mean": head.bias.data.cpu().numpy(),
            "std": torch.exp(0.5 * head.weight.data).cpu().numpy(),
        }

    return {"uniform": _layer_stats(), "normal": _layer_stats()}

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for shipping ``distribution_info`` to the global server.

    Currently a no-op: nothing is transmitted and the argument is ignored.
    A real implementation would hand the payload to some transport (for
    example a socket or a message queue).
    """
    return None

def generate_augmented_data(vae: "VAE", distribution_info_uniform: Dict, distribution_info_normal: Dict, num_samples: int = 64) -> torch.Tensor:
    """Draw latent samples by averaging a normal and a uniform draw.

    Two batches of shape ``(num_samples, vae.z_dim)`` are generated and
    averaged element-wise:

    * a normal draw scaled and shifted by the per-dimension ``"std"`` and
      ``"mean"`` of ``distribution_info_normal``;
    * a uniform draw on ``[m - s, m + s]`` where ``m`` and ``s`` are the
      scalar averages of the ``"mean"`` and ``"std"`` entries of
      ``distribution_info_uniform``.

    The statistics are converted to float32 tensors up front. The previous
    version multiplied a torch tensor by NumPy arrays directly, which relied
    on implicit conversion (the recorded notebook output shows a warning on
    that line).

    Args:
        vae: Object exposing ``z_dim``; only that attribute is read.
        distribution_info_uniform: Mapping with array-like ``"mean"``/``"std"``.
        distribution_info_normal: Mapping with array-like ``"mean"``/``"std"``
            that broadcast against ``(num_samples, vae.z_dim)``.
        num_samples: Number of latent vectors to generate (default 64).

    Returns:
        A float32 tensor of shape ``(num_samples, vae.z_dim)``.
    """
    latent_dim = vae.z_dim

    # Scalar bounds for the uniform draw (averaged over all dimensions).
    mean = float(np.asarray(distribution_info_uniform["mean"], dtype=np.float64).mean())
    std = float(np.asarray(distribution_info_uniform["std"], dtype=np.float64).mean())

    mean_normal = torch.as_tensor(distribution_info_normal["mean"], dtype=torch.float32)
    std_normal = torch.as_tensor(distribution_info_normal["std"], dtype=torch.float32)

    # Normal draw first, then uniform, so the RNG is consumed in a fixed order.
    augmented_data_normal = torch.randn(num_samples, latent_dim) * std_normal + mean_normal
    augmented_data_uniform = torch.empty(num_samples, latent_dim).uniform_(mean - std, mean + std)

    # Calculate the average of augmented data from both distributions
    return (augmented_data_uniform + augmented_data_normal) / 2

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int, local_epochs: int = 10, beta: float = 0.5) -> None:
    """Run the per-client training rounds.

    For every round and every client, in dictionary order: train ``vae`` on
    the client's data, exchange distribution information, draw augmented
    latent samples, train ``net`` on the client's data, and send the model
    state. The same ``net`` and ``vae`` objects are trained by every client
    in turn; no aggregation step happens here.

    NOTE(review): the augmented samples are generated but not passed to
    ``train``, and the ``trainloader`` argument is not used; confirm whether
    either was meant to feed the classifier.

    Args:
        net: Classifier trained on each client's loader.
        vae: VAE trained on each client's loader.
        trainloaders: Mapping of client id to that client's loader.
        trainloader: Unused; kept for interface compatibility.
        valloader: Validation loader forwarded to ``train``.
        epochs: Number of rounds over all clients.
        local_epochs: Epochs of VAE and classifier training per client visit.
        beta: KL weight forwarded to ``vae_train``.
    """
    for epoch in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Train VAE on client data
            vae_train(vae, client_trainloader, epochs=local_epochs, beta=beta)

            # Share distribution information with global server
            distribution_info = get_distribution_info(vae)
            send_distribution_info(distribution_info)

            # Receive distribution information from other clients
            other_distribution_info = receive_distribution_info()

            # Generate augmented data using received distribution information.
            # Kept even though the result is unused, so the random number
            # stream is consumed exactly as before.
            augmented_data = generate_augmented_data(vae, other_distribution_info["uniform"], other_distribution_info["normal"])

            # Train classification model on the client's local data
            train(net, client_trainloader, valloader, epochs=local_epochs)

            # Send model updates to global server
            send_model_update(client_id, net.state_dict())

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution information from the global server.

    No communication takes place: both the ``"uniform"`` and ``"normal"``
    entries describe a zero mean and unit standard deviation per latent
    dimension. Every array in the result is a separate object.

    Args:
        z_dim: Latent dimensionality, i.e. the length of each array. Defaults
            to 20, the value that was previously hard-coded.

    Returns:
        ``{"uniform": {"mean", "std"}, "normal": {"mean", "std"}}`` with
        NumPy arrays of length ``z_dim``.
    """
    def _standard_stats() -> Dict:
        return {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}

    return {"uniform": _standard_stats(), "normal": _standard_stats()}

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for sending a client's model update to the global server.

    Currently a no-op: both arguments are ignored and nothing is transmitted.
    A real implementation would hand ``model_update`` to some transport (for
    example a socket or a message queue), tagged with ``client_id``.
    """
    return None

# Define global server procedure
def global_server() -> None:
    """Build the models, run the federated training, and print test metrics.

    Relies on the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``. The network evaluated at the end is the
    in-memory ``net`` after the last training step; the ``best_model.pth``
    checkpoint written by ``train`` is not reloaded here.
    """
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size
    h_dim = 400
    z_dim = 20
    vae = VAE(x_dim, h_dim, z_dim)  # Initialize VAE object with required arguments

    # Initialize clients
    num_clients = 5  # Define the number of clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    print("True Positives (TP):", tp)
    print("False Positives (FP):", fp)
    print("True Negatives (TN):", tn)
    print("False Negatives (FN):", fn)
    # Sanity check: the four counts add up to the number of test samples for
    # every class. This line previously printed with an empty label.
    print("Total samples (TP+FP+TN+FN):", fn+tn+tp+fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

# Entry point: run the whole experiment when this code is executed directly
# (as a script or notebook cell) rather than imported as a module.
if __name__ == "__main__":
    global_server()

100%|██████████| 170M/170M [00:01<00:00, 98.8MB/s]


Random Images per Class: [6024 5975 6030 5992 5894 6008 5978 5917 6166 6016]


  augmented_data_normal = torch.randn(64, vae.z_dim) * std_normal + mean_normal


Epoch [1/10], Training Loss: 2.305, Validation Accuracy: 12.49%
Epoch [2/10], Training Loss: 2.304, Validation Accuracy: 10.42%
Epoch [3/10], Training Loss: 2.303, Validation Accuracy: 10.12%
Epoch [4/10], Training Loss: 2.303, Validation Accuracy: 10.11%
Epoch [5/10], Training Loss: 2.302, Validation Accuracy: 10.11%
Epoch [6/10], Training Loss: 2.302, Validation Accuracy: 10.11%
Epoch [7/10], Training Loss: 2.301, Validation Accuracy: 10.12%
Epoch [8/10], Training Loss: 2.300, Validation Accuracy: 10.18%
Epoch [9/10], Training Loss: 2.299, Validation Accuracy: 10.32%
Epoch [10/10], Training Loss: 2.299, Validation Accuracy: 11.37%
Epoch [1/10], Training Loss: 2.299, Validation Accuracy: 11.90%
Epoch [2/10], Training Loss: 2.298, Validation Accuracy: 13.02%
Epoch [3/10], Training Loss: 2.296, Validation Accuracy: 14.02%
Epoch [4/10], Training Loss: 2.295, Validation Accuracy: 15.70%
Epoch [5/10], Training Loss: 2.293, Validation Accuracy: 16.82%
Epoch [6/10], Training Loss: 2.290, Val