<a href="https://colab.research.google.com/github/Rahad31/Kl-FedDis-Research-/blob/main/normal_uniform.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch import nn
from typing import Dict
from scipy.stats import truncnorm

In [None]:

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational auto-encoder for 3x32x32 images.

    The encoder maps an image batch to a ``2 * z_dim`` vector that is split
    into a mean half and a log-variance half; the decoder maps a ``z_dim``
    latent vector back to a 3x32x32 image.

    Args:
        x_dim: Flattened input size. Stored on the instance only; the layer
            sizes below do not depend on it.
        h_dim: Hidden size. Stored on the instance only; the layers use a
            fixed hidden width of 256.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: three stride-2 convolutions halve the resolution each time
        # (32 -> 16 -> 8 -> 4), leaving 128 * 4 * 4 = 2048 features.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim)  # [mu | logvar]
        )

        # Decoder: mirror image of the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            # The notebook normalises images with mean 0.5 / std 0.5, i.e. to
            # [-1, 1]. A Sigmoid output ([0, 1]) could never reproduce the
            # negative half of that range, so the output uses Tanh instead.
            nn.Tanh()
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructed images in ``[-1, 1]``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, sample a latent and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar



In [None]:
# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Optimise ``vae`` in place on ``trainloader`` for ``epochs`` passes.

    A fresh Adam optimiser (lr=1e-3) is created on every call. Labels
    yielded by the loader are ignored; the loss is ``vae_loss``.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    for _ in range(epochs):
        for images, _labels in trainloader:
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar)
            batch_loss.backward()
            adam.step()

In [None]:
# Define classification model
class Net(nn.Module):
    """Small two-conv / three-linear classifier producing 10 logits.

    The flatten step assumes the feature map after the second pooling is
    16 x 5 x 5, which is what a 3x32x32 input yields.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores for the image batch ``x``."""
        relu = nn.functional.relu
        # Convolutional stage: conv -> ReLU -> 2x2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        # Fully connected stage; the last layer has no activation.
        for hidden in (self.fc1, self.fc2):
            x = relu(hidden(x))
        return self.fc3(x)

In [None]:
# Load CIFAR10 dataset
# Convert PIL images to tensors and rescale every channel from [0, 1] to
# [-1, 1] (subtract 0.5, divide by 0.5).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# The training split is later divided into train/validation; the official
# test split is kept separate.
full_dataset = CIFAR10("./data", train=True, transform=transform, download=True)
test_set = CIFAR10("./data", train=False, transform=transform, download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 63066760.12it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Split dataset into training and validation sets
# 80 % of the samples go to training, the remainder to validation.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# A dedicated, seeded generator keeps the train/validation membership identical
# across kernel restarts, so validation accuracies are comparable between runs.
train_set, val_set = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)


In [None]:
# Create training and validation loaders
# All three loaders share batch size and worker count; only the training
# loader reshuffles its samples every epoch.
trainloader, valloader, testloader = (
    DataLoader(dataset, batch_size=128, shuffle=reshuffle, num_workers=2)
    for dataset, reshuffle in ((train_set, True), (val_set, False), (test_set, False))
)


In [None]:
# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective for one batch.

    The value is the sum-reduced squared error between ``recon_x`` and ``x``
    plus the KL divergence of N(mu, exp(logvar)) from the standard normal,
    summed over all elements.
    """
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction_term + kl_term

In [None]:
# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after every epoch.

    Args:
        net: Classifier to optimise in place.
        trainloader: Yields ``(inputs, labels)`` batches used for the updates.
        valloader: Yields ``(images, labels)`` batches used for validation.
        epochs: Number of passes over ``trainloader``.

    Side effects:
        Prints one progress line per epoch and writes the weights to
        ``best_model.pth`` whenever the validation accuracy improves on the
        best value seen during this call.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0  # counted explicitly so an empty loader is harmless
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Report 0 instead of dividing by zero when a loader produced no data.
        val_acc = 100 * correct / total if total else 0.0
        avg_loss = running_loss / num_batches if num_batches else 0.0

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")


In [None]:
# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> float:
    """Return the top-1 accuracy of ``net`` on ``testloader`` in percent.

    The model is switched to eval mode and left there. An empty loader
    yields ``0.0`` rather than raising ``ZeroDivisionError``.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total

In [None]:
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into contiguous shards, one shuffled loader per client.

    Client ``k`` receives the index range
    ``[k * len // num_clients, (k + 1) * len // num_clients)``.
    ``transform`` is accepted but not used here.

    Returns:
        Dict mapping ``"client_<k>"`` to that client's ``DataLoader``.
    """
    def _shard(k):
        # Integer arithmetic keeps the shards disjoint and covering.
        size = len(trainset)
        return range(k * size // num_clients, (k + 1) * size // num_clients)

    return {
        f"client_{k}": torch.utils.data.DataLoader(
            torch.utils.data.Subset(trainset, _shard(k)),
            batch_size=batch_size,
            shuffle=True,
        )
        for k in range(num_clients)
    }

In [None]:
def get_distribution_info(vae: VAE) -> Dict:
    """Summarise the VAE's latent distribution as per-dimension mean and std.

    The last encoder layer outputs ``[mu | logvar]`` (split at ``vae.z_dim``
    in ``VAE.encode``). This placeholder uses the bias of that layer, i.e. its
    input-independent part: the first half as the mean and
    ``exp(0.5 * second half)`` as the standard deviation. Both arrays have
    length ``vae.z_dim``, the shape ``generate_augmented_data`` broadcasts
    against.

    NOTE(review): a data-driven estimate would need encoded samples, which
    this signature does not receive.

    Returns:
        ``{"normal": {"mean", "std"}, "uniform": {"mean", "std"}}`` holding
        independent numpy copies, so later training of ``vae`` does not
        change information that has already been handed out.
    """
    z_dim = vae.z_dim
    head_bias = vae.encoder[-1].bias.detach().cpu()

    # .numpy() shares memory with the tensor; copy to detach from the model.
    mean = head_bias[:z_dim].numpy().copy()
    std = torch.exp(0.5 * head_bias[z_dim:]).numpy()

    return {
        "normal": {"mean": mean.copy(), "std": std.copy()},
        "uniform": {"mean": mean.copy(), "std": std.copy()},
    }

In [None]:
def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for shipping ``distribution_info`` to the global server.

    Currently a no-op: the argument is ignored and nothing is transmitted.
    """
    # A real implementation would hand the payload to some transport here,
    # e.g. a raw `socket` connection or a message queue such as `RabbitMQ`.
    return None


In [None]:
def generate_augmented_data(vae: VAE, distribution_info_normal: Dict, distribution_info_uniform: Dict, num_samples: int = 64) -> torch.Tensor:
    """Draw latent samples by averaging a normal and a uniform draw.

    Args:
        vae: Only ``vae.z_dim`` is read, to size the samples.
        distribution_info_normal: Dict with per-dimension ``"mean"`` and
            ``"std"`` (array-like, broadcastable to ``z_dim``).
        distribution_info_uniform: Dict with ``"mean"`` and ``"std"``; their
            overall averages define the interval ``[mean - std, mean + std]``.
        num_samples: Number of latent vectors to draw (default 64).

    Returns:
        A float32 tensor of shape ``(num_samples, z_dim)``. These are latent
        vectors; nothing is decoded here.
    """
    # Convert explicitly so the result stays float32 regardless of the dtype
    # of the incoming arrays and can be fed to the VAE decoder.
    mean_normal = torch.as_tensor(distribution_info_normal["mean"], dtype=torch.float32)
    std_normal = torch.as_tensor(distribution_info_normal["std"], dtype=torch.float32)

    # Generate augmented data from normal distribution
    augmented_data_normal = torch.randn(num_samples, vae.z_dim) * std_normal + mean_normal

    # Generate augmented data using Uniform distribution; the per-dimension
    # statistics are collapsed to two scalars.
    mean = torch.as_tensor(distribution_info_uniform["mean"], dtype=torch.float64).mean().item()
    std = torch.as_tensor(distribution_info_uniform["std"], dtype=torch.float64).mean().item()
    augmented_data_uniform = torch.empty(num_samples, vae.z_dim).uniform_(mean - std, mean + std)

    # Calculate the average of augmented data from both distributions
    return (augmented_data_normal + augmented_data_uniform) / 2

In [None]:
# Define the federated training logic
def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Run ``epochs`` rounds in which every client trains the shared models in turn.

    For each client, in dictionary order: the VAE is trained on the client's
    loader (10 epochs), distribution info is extracted and sent, peer info is
    received, latent samples are generated, the classifier is trained on the
    client's loader (10 epochs, validated on ``valloader``) and its weights
    are sent.

    NOTE(review): ``trainloader`` is not used, the generated samples are not
    fed into classifier training, and all clients update the same ``net``
    and ``vae`` objects sequentially.
    """
    for _round in range(epochs):
        for client_id in trainloaders:
            local_loader = trainloaders[client_id]

            # Fit the shared VAE to this client's data.
            vae_train(vae, local_loader, epochs=10)

            # Publish this client's latent statistics, then fetch the peers'.
            send_distribution_info(get_distribution_info(vae))
            peer_info = receive_distribution_info()

            # Sample from the received distributions (result currently unused).
            augmented_data = generate_augmented_data(
                vae, peer_info["normal"], peer_info["uniform"]
            )

            # Local classifier update followed by upload of the weights.
            train(net, local_loader, valloader, epochs=10)
            send_model_update(client_id, net.state_dict())

In [None]:
# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution info: zero mean and unit std.

    Args:
        z_dim: Length of every returned array; should match the latent
            dimension of the VAE in use. Defaults to 20.

    Returns:
        ``{"normal": {"mean", "std"}, "uniform": {"mean", "std"}}`` where each
        value is a fresh numpy array of length ``z_dim``.
    """
    return {
        family: {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}
        for family in ("normal", "uniform")
    }

In [None]:
def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for uploading a client's model weights to the global server.

    Currently a no-op: both arguments are ignored and nothing is transmitted.
    """
    # A real implementation would hand the update to some transport here,
    # e.g. a raw `socket` connection or a message queue such as `RabbitMQ`.
    return None

In [None]:
# Define global server procedure
def global_server() -> None:
    """Build the models, run federated training over 5 clients and print test accuracy.

    Reads the notebook-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``.
    """
    classifier = Net()

    # VAE hyper-parameters: flattened CIFAR-10 input size, hidden and latent sizes.
    input_dim, hidden_dim, latent_dim = 3 * 32 * 32, 400, 20
    vae = VAE(input_dim, hidden_dim, latent_dim)

    # One shuffled loader per client over a contiguous shard of the training set.
    client_loaders = initialize_clients(train_set, transform, batch_size=128, num_clients=5)

    # Train model using FedDIS
    federated_train(classifier, vae, client_loaders, trainloader, valloader, epochs=5)

    # Evaluate final model
    test_accuracy = evaluate(classifier, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")


if __name__ == "__main__":
    global_server()

Epoch [1/10], Training Loss: 2.303, Validation Accuracy: 10.18%
Epoch [2/10], Training Loss: 2.303, Validation Accuracy: 10.18%
Epoch [3/10], Training Loss: 2.302, Validation Accuracy: 10.18%
Epoch [4/10], Training Loss: 2.301, Validation Accuracy: 10.18%
Epoch [5/10], Training Loss: 2.300, Validation Accuracy: 10.18%
Epoch [6/10], Training Loss: 2.299, Validation Accuracy: 10.18%
Epoch [7/10], Training Loss: 2.298, Validation Accuracy: 10.18%
Epoch [8/10], Training Loss: 2.296, Validation Accuracy: 10.18%
Epoch [9/10], Training Loss: 2.294, Validation Accuracy: 10.24%
Epoch [10/10], Training Loss: 2.292, Validation Accuracy: 10.93%
Epoch [1/10], Training Loss: 2.289, Validation Accuracy: 12.14%
Epoch [2/10], Training Loss: 2.284, Validation Accuracy: 13.77%
Epoch [3/10], Training Loss: 2.277, Validation Accuracy: 15.59%
Epoch [4/10], Training Loss: 2.265, Validation Accuracy: 16.43%
Epoch [5/10], Training Loss: 2.249, Validation Accuracy: 17.58%
Epoch [6/10], Training Loss: 2.224, Val

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational auto-encoder for 3x32x32 images.

    The encoder maps an image batch to a ``2 * z_dim`` vector that is split
    into a mean half and a log-variance half; the decoder maps a ``z_dim``
    latent vector back to a 3x32x32 image.

    Args:
        x_dim: Flattened input size. Stored on the instance only; the layer
            sizes below do not depend on it.
        h_dim: Hidden size. Stored on the instance only; the layers use a
            fixed hidden width of 256.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: three stride-2 convolutions halve the resolution each time
        # (32 -> 16 -> 8 -> 4), leaving 128 * 4 * 4 = 2048 features.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim)  # [mu | logvar]
        )

        # Decoder: mirror image of the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            # The notebook normalises images with mean 0.5 / std 0.5, i.e. to
            # [-1, 1]. A Sigmoid output ([0, 1]) could never reproduce the
            # negative half of that range, so the output uses Tanh instead.
            nn.Tanh()
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructed images in ``[-1, 1]``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, sample a latent and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Optimise ``vae`` in place on ``trainloader`` for ``epochs`` passes.

    A fresh Adam optimiser (lr=1e-3) and a StepLR schedule (halve the rate
    every 20 epochs) are created on every call. Labels yielded by the loader
    are ignored; the loss is ``vae_loss``.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    lr_schedule = torch.optim.lr_scheduler.StepLR(adam, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for images, _labels in trainloader:
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar)
            batch_loss.backward()
            adam.step()

        # The schedule advances once per epoch, after all batch updates.
        lr_schedule.step()


# Define classification model
class Net(nn.Module):
    """Small two-conv / three-linear classifier producing 10 logits.

    The flatten step assumes the feature map after the second pooling is
    16 x 5 x 5, which is what a 3x32x32 input yields.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores for the image batch ``x``."""
        relu = nn.functional.relu
        # Convolutional stage: conv -> ReLU -> 2x2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        # Fully connected stage; the last layer has no activation.
        for hidden in (self.fc1, self.fc2):
            x = relu(hidden(x))
        return self.fc3(x)


# Load CIFAR10 dataset
# Convert PIL images to tensors and rescale every channel from [0, 1] to
# [-1, 1] (subtract 0.5, divide by 0.5).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Split dataset into training (80 %) and validation (20 %) sets.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# A dedicated, seeded generator keeps the train/validation membership identical
# across kernel restarts, so validation accuracies are comparable between runs.
train_set, val_set = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Create training and validation loaders; only the training loader shuffles.
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective for one batch.

    The value is the sum-reduced squared error between ``recon_x`` and ``x``
    plus the KL divergence of N(mu, exp(logvar)) from the standard normal,
    summed over all elements.
    """
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction_term + kl_term

# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after every epoch.

    Args:
        net: Classifier to optimise in place.
        trainloader: Yields ``(inputs, labels)`` batches used for the updates.
        valloader: Yields ``(images, labels)`` batches used for validation.
        epochs: Number of passes over ``trainloader``.

    Side effects:
        Prints one progress line per epoch and writes the weights to
        ``best_model.pth`` whenever the validation accuracy improves on the
        best value seen during this call.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    # Halve the learning rate every 20 epochs of this call.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0  # counted explicitly so an empty loader is harmless
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Report 0 instead of dividing by zero when a loader produced no data.
        val_acc = 100 * correct / total if total else 0.0
        avg_loss = running_loss / num_batches if num_batches else 0.0

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score ``net`` on ``testloader`` and derive per-class statistics.

    The model is switched to eval mode and left there; the confusion matrix
    is printed as a side effect.

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)`` where
        ``accuracy`` is a percentage and every other entry is an array with
        one value per class present in the confusion matrix.
    """
    net.eval()
    all_labels, all_predictions = [], []
    correct = total = 0

    with torch.no_grad():
        for images, labels in testloader:
            predicted = torch.max(net(images), dim=1).indices
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    accuracy = 100 * correct / total

    # Rows are true classes, columns are predicted classes.
    cm = confusion_matrix(all_labels, all_predictions)
    print(f"Confusion Matrix:\n{cm}")

    # One-vs-rest counts per class, read off the confusion matrix.
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp   # predicted as the class but belonging elsewhere
    fn = cm.sum(axis=1) - tp   # belonging to the class but predicted elsewhere
    tn = cm.sum() - (fp + fn + tp)

    # Per-class (unaveraged) scores.
    precision, recall, f1 = (
        metric(all_labels, all_predictions, average=None)
        for metric in (precision_score, recall_score, f1_score)
    )

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into contiguous shards, one shuffled loader per client.

    Client ``k`` receives the index range
    ``[k * len // num_clients, (k + 1) * len // num_clients)``.
    ``transform`` is accepted but not used here.

    Returns:
        Dict mapping ``"client_<k>"`` to that client's ``DataLoader``.
    """
    def _shard(k):
        # Integer arithmetic keeps the shards disjoint and covering.
        size = len(trainset)
        return range(k * size // num_clients, (k + 1) * size // num_clients)

    return {
        f"client_{k}": torch.utils.data.DataLoader(
            torch.utils.data.Subset(trainset, _shard(k)),
            batch_size=batch_size,
            shuffle=True,
        )
        for k in range(num_clients)
    }

def get_distribution_info(vae: VAE) -> Dict:
    """Summarise the VAE's latent distribution as per-dimension mean and std.

    The last encoder layer outputs ``[mu | logvar]`` (split at ``vae.z_dim``
    in ``VAE.encode``). This placeholder uses the bias of that layer, i.e. its
    input-independent part: the first half as the mean and
    ``exp(0.5 * second half)`` as the standard deviation. Both arrays have
    length ``vae.z_dim``, the shape ``generate_augmented_data`` broadcasts
    against.

    NOTE(review): a data-driven estimate would need encoded samples, which
    this signature does not receive.

    Returns:
        ``{"normal": {"mean", "std"}, "uniform": {"mean", "std"}}`` holding
        independent numpy copies, so later training of ``vae`` does not
        change information that has already been handed out.
    """
    z_dim = vae.z_dim
    head_bias = vae.encoder[-1].bias.detach().cpu()

    # .numpy() shares memory with the tensor; copy to detach from the model.
    mean = head_bias[:z_dim].numpy().copy()
    std = torch.exp(0.5 * head_bias[z_dim:]).numpy()

    return {
        "normal": {"mean": mean.copy(), "std": std.copy()},
        "uniform": {"mean": mean.copy(), "std": std.copy()},
    }

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for shipping ``distribution_info`` to the global server.

    Currently a no-op: the argument is ignored and nothing is transmitted.
    """
    # A real implementation would hand the payload to some transport here,
    # e.g. a raw `socket` connection or a message queue such as `RabbitMQ`.
    return None

def generate_augmented_data(vae: VAE, distribution_info_normal: Dict, distribution_info_uniform: Dict, num_samples: int = 64) -> torch.Tensor:
    """Draw latent samples by averaging a normal and a uniform draw.

    Args:
        vae: Only ``vae.z_dim`` is read, to size the samples.
        distribution_info_normal: Dict with per-dimension ``"mean"`` and
            ``"std"`` (array-like, broadcastable to ``z_dim``).
        distribution_info_uniform: Dict with ``"mean"`` and ``"std"``; their
            overall averages define the interval ``[mean - std, mean + std]``.
        num_samples: Number of latent vectors to draw (default 64).

    Returns:
        A float32 tensor of shape ``(num_samples, z_dim)``. These are latent
        vectors; nothing is decoded here.
    """
    # Convert explicitly so the result stays float32 regardless of the dtype
    # of the incoming arrays and can be fed to the VAE decoder.
    mean_normal = torch.as_tensor(distribution_info_normal["mean"], dtype=torch.float32)
    std_normal = torch.as_tensor(distribution_info_normal["std"], dtype=torch.float32)

    # Generate augmented data from normal distribution
    augmented_data_normal = torch.randn(num_samples, vae.z_dim) * std_normal + mean_normal

    # Generate augmented data using Uniform distribution; the per-dimension
    # statistics are collapsed to two scalars.
    mean = torch.as_tensor(distribution_info_uniform["mean"], dtype=torch.float64).mean().item()
    std = torch.as_tensor(distribution_info_uniform["std"], dtype=torch.float64).mean().item()
    augmented_data_uniform = torch.empty(num_samples, vae.z_dim).uniform_(mean - std, mean + std)

    # Calculate the average of augmented data from both distributions
    return (augmented_data_normal + augmented_data_uniform) / 2

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Run ``epochs`` rounds in which every client trains the shared models in turn.

    For each client, in dictionary order: the VAE is trained on the client's
    loader (10 epochs), distribution info is extracted and sent, peer info is
    received, latent samples are generated, the classifier is trained on the
    client's loader (10 epochs, validated on ``valloader``) and its weights
    are sent.

    NOTE(review): ``trainloader`` is not used, the generated samples are not
    fed into classifier training, and all clients update the same ``net``
    and ``vae`` objects sequentially.
    """
    for _round in range(epochs):
        for client_id in trainloaders:
            local_loader = trainloaders[client_id]

            # Fit the shared VAE to this client's data.
            vae_train(vae, local_loader, epochs=10)

            # Publish this client's latent statistics, then fetch the peers'.
            send_distribution_info(get_distribution_info(vae))
            peer_info = receive_distribution_info()

            # Sample from the received distributions (result currently unused).
            augmented_data = generate_augmented_data(
                vae, peer_info["normal"], peer_info["uniform"]
            )

            # Local classifier update followed by upload of the weights.
            train(net, local_loader, valloader, epochs=10)
            send_model_update(client_id, net.state_dict())

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution info: zero mean and unit std.

    Args:
        z_dim: Length of every returned array; should match the latent
            dimension of the VAE in use. Defaults to 20.

    Returns:
        ``{"normal": {"mean", "std"}, "uniform": {"mean", "std"}}`` where each
        value is a fresh numpy array of length ``z_dim``.
    """
    return {
        family: {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}
        for family in ("normal", "uniform")
    }

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for uploading a client's model weights to the global server.

    Currently a no-op: both arguments are ignored and nothing is transmitted.
    """
    # A real implementation would hand the update to some transport here,
    # e.g. a raw `socket` connection or a message queue such as `RabbitMQ`.
    return None

# Define global server procedure
def global_server() -> None:
    """Build the models, run federated training over 5 clients and print test metrics.

    Reads the notebook-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``. Prints the test accuracy followed by
    the per-class counts and scores returned by ``evaluate``.
    """
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size
    h_dim = 400
    z_dim = 20
    vae = VAE(x_dim, h_dim, z_dim)  # Initialize VAE object with required arguments

    # Initialize clients
    num_clients = 5  # Define the number of clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    print("True Positives (TP):", tp)
    print("False Positives (FP):", fp)
    print("True Negatives (TN):", tn)
    print("False Negatives (FN):", fn)
    # Sanity check: for every class the four counts add up to the number of
    # evaluated samples. This line was previously printed with an empty label.
    print("Total (TP+FP+TN+FN):", fn + tn + tp + fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

if __name__ == "__main__":
    global_server()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 44975392.00it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch [1/10], Training Loss: 2.305, Validation Accuracy: 9.89%
Epoch [2/10], Training Loss: 2.304, Validation Accuracy: 9.89%
Epoch [3/10], Training Loss: 2.303, Validation Accuracy: 9.89%
Epoch [4/10], Training Loss: 2.302, Validation Accuracy: 9.89%
Epoch [5/10], Training Loss: 2.302, Validation Accuracy: 9.89%
Epoch [6/10], Training Loss: 2.301, Validation Accuracy: 9.89%
Epoch [7/10], Training Loss: 2.300, Validation Accuracy: 9.89%
Epoch [8/10], Training Loss: 2.299, Validation Accuracy: 9.89%
Epoch [9/10], Training Loss: 2.298, Validation Accuracy: 9.89%
Epoch [10/10], Training Loss: 2.297, Validation Accuracy: 9.89%
Epoch [1/10], Training Loss: 2.297, Validation Accuracy: 9.89%
Epoch [2/10], Training Loss: 2.296, Validation Accuracy: 9.90%
Epoch [3/10], Training Loss: 2.294, Validation Accuracy: 10.05%
Epoch [4/10], Training Loss: 2.291, Validation Accuracy: 11.80%
Epoch [5/10], Training Los

In [None]:
import re

# Training log pasted in from the printed output of a previous training run
log = """
Epoch [1/10], Training Loss: 2.305, Validation Accuracy: 9.89%
Epoch [2/10], Training Loss: 2.304, Validation Accuracy: 9.89%
Epoch [3/10], Training Loss: 2.303, Validation Accuracy: 9.89%
Epoch [4/10], Training Loss: 2.302, Validation Accuracy: 9.89%
Epoch [5/10], Training Loss: 2.302, Validation Accuracy: 9.89%
Epoch [6/10], Training Loss: 2.301, Validation Accuracy: 9.89%
Epoch [7/10], Training Loss: 2.300, Validation Accuracy: 9.89%
Epoch [8/10], Training Loss: 2.299, Validation Accuracy: 9.89%
Epoch [9/10], Training Loss: 2.298, Validation Accuracy: 9.89%
Epoch [10/10], Training Loss: 2.297, Validation Accuracy: 9.89%
Epoch [1/10], Training Loss: 2.297, Validation Accuracy: 9.89%
Epoch [2/10], Training Loss: 2.296, Validation Accuracy: 9.90%
Epoch [3/10], Training Loss: 2.294, Validation Accuracy: 10.05%
Epoch [4/10], Training Loss: 2.291, Validation Accuracy: 11.80%
Epoch [5/10], Training Loss: 2.288, Validation Accuracy: 14.01%
Epoch [6/10], Training Loss: 2.283, Validation Accuracy: 14.36%
Epoch [7/10], Training Loss: 2.276, Validation Accuracy: 14.27%
Epoch [8/10], Training Loss: 2.266, Validation Accuracy: 15.41%
Epoch [9/10], Training Loss: 2.252, Validation Accuracy: 16.01%
Epoch [10/10], Training Loss: 2.231, Validation Accuracy: 17.05%
Epoch [1/10], Training Loss: 2.213, Validation Accuracy: 19.88%
Epoch [2/10], Training Loss: 2.179, Validation Accuracy: 21.49%
Epoch [3/10], Training Loss: 2.142, Validation Accuracy: 23.10%
Epoch [4/10], Training Loss: 2.115, Validation Accuracy: 24.23%
Epoch [5/10], Training Loss: 2.090, Validation Accuracy: 25.19%
Epoch [6/10], Training Loss: 2.063, Validation Accuracy: 26.34%
Epoch [7/10], Training Loss: 2.039, Validation Accuracy: 26.71%
Epoch [8/10], Training Loss: 2.013, Validation Accuracy: 27.22%
Epoch [9/10], Training Loss: 1.991, Validation Accuracy: 27.68%
Epoch [10/10], Training Loss: 1.972, Validation Accuracy: 27.52%
Epoch [1/10], Training Loss: 1.961, Validation Accuracy: 28.60%
Epoch [2/10], Training Loss: 1.944, Validation Accuracy: 29.16%
Epoch [3/10], Training Loss: 1.930, Validation Accuracy: 29.75%
Epoch [4/10], Training Loss: 1.917, Validation Accuracy: 30.42%
Epoch [5/10], Training Loss: 1.902, Validation Accuracy: 31.17%
Epoch [6/10], Training Loss: 1.886, Validation Accuracy: 31.34%
Epoch [7/10], Training Loss: 1.871, Validation Accuracy: 31.20%
Epoch [8/10], Training Loss: 1.857, Validation Accuracy: 31.76%
Epoch [9/10], Training Loss: 1.837, Validation Accuracy: 32.39%
Epoch [10/10], Training Loss: 1.820, Validation Accuracy: 33.22%
Epoch [1/10], Training Loss: 1.830, Validation Accuracy: 33.81%
Epoch [2/10], Training Loss: 1.816, Validation Accuracy: 34.08%
Epoch [3/10], Training Loss: 1.802, Validation Accuracy: 34.79%
Epoch [4/10], Training Loss: 1.787, Validation Accuracy: 34.94%
Epoch [5/10], Training Loss: 1.775, Validation Accuracy: 35.08%
Epoch [6/10], Training Loss: 1.759, Validation Accuracy: 35.60%
Epoch [7/10], Training Loss: 1.748, Validation Accuracy: 35.07%
Epoch [8/10], Training Loss: 1.738, Validation Accuracy: 36.25%
Epoch [9/10], Training Loss: 1.722, Validation Accuracy: 36.16%
Epoch [10/10], Training Loss: 1.708, Validation Accuracy: 37.01%
Epoch [1/10], Training Loss: 1.713, Validation Accuracy: 37.73%
Epoch [2/10], Training Loss: 1.697, Validation Accuracy: 37.92%
Epoch [3/10], Training Loss: 1.679, Validation Accuracy: 38.34%
Epoch [4/10], Training Loss: 1.670, Validation Accuracy: 39.26%
Epoch [5/10], Training Loss: 1.648, Validation Accuracy: 39.74%
Epoch [6/10], Training Loss: 1.633, Validation Accuracy: 40.38%
Epoch [7/10], Training Loss: 1.621, Validation Accuracy: 40.66%
Epoch [8/10], Training Loss: 1.613, Validation Accuracy: 40.50%
Epoch [9/10], Training Loss: 1.606, Validation Accuracy: 41.16%
Epoch [10/10], Training Loss: 1.589, Validation Accuracy: 41.16%
Epoch [1/10], Training Loss: 1.594, Validation Accuracy: 41.27%
Epoch [2/10], Training Loss: 1.582, Validation Accuracy: 41.93%
Epoch [3/10], Training Loss: 1.569, Validation Accuracy: 42.55%
Epoch [4/10], Training Loss: 1.566, Validation Accuracy: 41.94%
Epoch [5/10], Training Loss: 1.550, Validation Accuracy: 42.29%
Epoch [6/10], Training Loss: 1.542, Validation Accuracy: 43.00%
Epoch [7/10], Training Loss: 1.536, Validation Accuracy: 42.71%
Epoch [8/10], Training Loss: 1.527, Validation Accuracy: 43.48%
Epoch [9/10], Training Loss: 1.525, Validation Accuracy: 43.33%
Epoch [10/10], Training Loss: 1.519, Validation Accuracy: 43.78%
Epoch [1/10], Training Loss: 1.537, Validation Accuracy: 43.60%
Epoch [2/10], Training Loss: 1.523, Validation Accuracy: 44.25%
Epoch [3/10], Training Loss: 1.513, Validation Accuracy: 44.79%
Epoch [4/10], Training Loss: 1.506, Validation Accuracy: 44.12%
Epoch [5/10], Training Loss: 1.502, Validation Accuracy: 44.43%
Epoch [6/10], Training Loss: 1.498, Validation Accuracy: 44.76%
Epoch [7/10], Training Loss: 1.488, Validation Accuracy: 45.19%
Epoch [8/10], Training Loss: 1.483, Validation Accuracy: 45.29%
Epoch [9/10], Training Loss: 1.474, Validation Accuracy: 45.16%
Epoch [10/10], Training Loss: 1.464, Validation Accuracy: 45.61%
Epoch [1/10], Training Loss: 1.470, Validation Accuracy: 45.77%
Epoch [2/10], Training Loss: 1.459, Validation Accuracy: 45.85%
Epoch [3/10], Training Loss: 1.454, Validation Accuracy: 46.24%
Epoch [4/10], Training Loss: 1.441, Validation Accuracy: 46.14%
Epoch [5/10], Training Loss: 1.429, Validation Accuracy: 46.65%
Epoch [6/10], Training Loss: 1.438, Validation Accuracy: 46.16%
Epoch [7/10], Training Loss: 1.423, Validation Accuracy: 46.26%
Epoch [8/10], Training Loss: 1.416, Validation Accuracy: 46.58%
Epoch [9/10], Training Loss: 1.409, Validation Accuracy: 47.02%
Epoch [10/10], Training Loss: 1.402, Validation Accuracy: 47.42%
Epoch [1/10], Training Loss: 1.458, Validation Accuracy: 47.18%
Epoch [2/10], Training Loss: 1.449, Validation Accuracy: 47.64%
Epoch [3/10], Training Loss: 1.441, Validation Accuracy: 47.32%
Epoch [4/10], Training Loss: 1.434, Validation Accuracy: 47.77%
Epoch [5/10], Training Loss: 1.421, Validation Accuracy: 47.66%
Epoch [6/10], Training Loss: 1.415, Validation Accuracy: 47.22%
Epoch [7/10], Training Loss: 1.413, Validation Accuracy: 48.34%
Epoch [8/10], Training Loss: 1.401, Validation Accuracy: 48.53%
Epoch [9/10], Training Loss: 1.389, Validation Accuracy: 48.21%
Epoch [10/10], Training Loss: 1.385, Validation Accuracy: 46.46%
Epoch [1/10], Training Loss: 1.431, Validation Accuracy: 48.53%
Epoch [2/10], Training Loss: 1.413, Validation Accuracy: 48.52%
Epoch [3/10], Training Loss: 1.401, Validation Accuracy: 48.53%
Epoch [4/10], Training Loss: 1.395, Validation Accuracy: 49.06%
Epoch [5/10], Training Loss: 1.383, Validation Accuracy: 48.32%
Epoch [6/10], Training Loss: 1.374, Validation Accuracy: 49.13%
Epoch [7/10], Training Loss: 1.371, Validation Accuracy: 49.09%
Epoch [8/10], Training Loss: 1.360, Validation Accuracy: 48.39%
Epoch [9/10], Training Loss: 1.356, Validation Accuracy: 49.79%
Epoch [10/10], Training Loss: 1.344, Validation Accuracy: 49.42%
Epoch [1/10], Training Loss: 1.383, Validation Accuracy: 49.62%
Epoch [2/10], Training Loss: 1.365, Validation Accuracy: 50.18%
Epoch [3/10], Training Loss: 1.350, Validation Accuracy: 49.71%
Epoch [4/10], Training Loss: 1.346, Validation Accuracy: 50.38%
Epoch [5/10], Training Loss: 1.339, Validation Accuracy: 50.64%
Epoch [6/10], Training Loss: 1.332, Validation Accuracy: 50.49%
Epoch [7/10], Training Loss: 1.320, Validation Accuracy: 50.81%
Epoch [8/10], Training Loss: 1.319, Validation Accuracy: 50.69%
Epoch [9/10], Training Loss: 1.306, Validation Accuracy: 50.79%
Epoch [10/10], Training Loss: 1.305, Validation Accuracy: 50.40%
Epoch [1/10], Training Loss: 1.347, Validation Accuracy: 51.25%
Epoch [2/10], Training Loss: 1.339, Validation Accuracy: 51.56%
Epoch [3/10], Training Loss: 1.319, Validation Accuracy: 51.83%
Epoch [4/10], Training Loss: 1.307, Validation Accuracy: 51.59%
Epoch [5/10], Training Loss: 1.309, Validation Accuracy: 51.28%
Epoch [6/10], Training Loss: 1.293, Validation Accuracy: 51.29%
Epoch [7/10], Training Loss: 1.291, Validation Accuracy: 51.41%
Epoch [8/10], Training Loss: 1.277, Validation Accuracy: 52.19%
Epoch [9/10], Training Loss: 1.266, Validation Accuracy: 51.60%
Epoch [10/10], Training Loss: 1.257, Validation Accuracy: 51.63%
Epoch [1/10], Training Loss: 1.304, Validation Accuracy: 52.33%
Epoch [2/10], Training Loss: 1.282, Validation Accuracy: 52.37%
Epoch [3/10], Training Loss: 1.263, Validation Accuracy: 52.24%
Epoch [4/10], Training Loss: 1.257, Validation Accuracy: 52.40%
Epoch [5/10], Training Loss: 1.257, Validation Accuracy: 52.10%
Epoch [6/10], Training Loss: 1.239, Validation Accuracy: 53.48%
Epoch [7/10], Training Loss: 1.235, Validation Accuracy: 53.22%
Epoch [8/10], Training Loss: 1.220, Validation Accuracy: 52.68%
Epoch [9/10], Training Loss: 1.212, Validation Accuracy: 52.91%
Epoch [10/10], Training Loss: 1.215, Validation Accuracy: 51.90%
Epoch [1/10], Training Loss: 1.296, Validation Accuracy: 52.50%
Epoch [2/10], Training Loss: 1.270, Validation Accuracy: 52.99%
Epoch [3/10], Training Loss: 1.252, Validation Accuracy: 53.83%
Epoch [4/10], Training Loss: 1.240, Validation Accuracy: 53.30%
Epoch [5/10], Training Loss: 1.233, Validation Accuracy: 53.85%
Epoch [6/10], Training Loss: 1.225, Validation Accuracy: 53.61%
Epoch [7/10], Training Loss: 1.209, Validation Accuracy: 53.76%
Epoch [8/10], Training Loss: 1.207, Validation Accuracy: 52.71%
Epoch [9/10], Training Loss: 1.200, Validation Accuracy: 53.21%
Epoch [10/10], Training Loss: 1.188, Validation Accuracy: 54.21%
Epoch [1/10], Training Loss: 1.267, Validation Accuracy: 54.21%
Epoch [2/10], Training Loss: 1.249, Validation Accuracy: 53.70%
Epoch [3/10], Training Loss: 1.229, Validation Accuracy: 54.41%
Epoch [4/10], Training Loss: 1.213, Validation Accuracy: 54.99%
Epoch [5/10], Training Loss: 1.199, Validation Accuracy: 54.79%
Epoch [6/10], Training Loss: 1.199, Validation Accuracy: 54.27%
Epoch [7/10], Training Loss: 1.184, Validation Accuracy: 54.76%
Epoch [8/10], Training Loss: 1.174, Validation Accuracy: 54.35%
Epoch [9/10], Training Loss: 1.166, Validation Accuracy: 54.77%
Epoch [10/10], Training Loss: 1.157, Validation Accuracy: 54.92%
Epoch [1/10], Training Loss: 1.236, Validation Accuracy: 55.68%
Epoch [2/10], Training Loss: 1.213, Validation Accuracy: 55.70%
Epoch [3/10], Training Loss: 1.202, Validation Accuracy: 54.97%
Epoch [4/10], Training Loss: 1.191, Validation Accuracy: 55.97%
Epoch [5/10], Training Loss: 1.176, Validation Accuracy: 55.82%
Epoch [6/10], Training Loss: 1.165, Validation Accuracy: 55.66%
Epoch [7/10], Training Loss: 1.160, Validation Accuracy: 55.47%
Epoch [8/10], Training Loss: 1.149, Validation Accuracy: 56.19%
Epoch [9/10], Training Loss: 1.138, Validation Accuracy: 56.17%
Epoch [10/10], Training Loss: 1.130, Validation Accuracy: 55.90%
Epoch [1/10], Training Loss: 1.218, Validation Accuracy: 55.22%
Epoch [2/10], Training Loss: 1.190, Validation Accuracy: 56.28%
Epoch [3/10], Training Loss: 1.182, Validation Accuracy: 55.57%
Epoch [4/10], Training Loss: 1.163, Validation Accuracy: 56.08%
Epoch [5/10], Training Loss: 1.154, Validation Accuracy: 55.89%
Epoch [6/10], Training Loss: 1.136, Validation Accuracy: 55.74%
Epoch [7/10], Training Loss: 1.128, Validation Accuracy: 55.81%
Epoch [8/10], Training Loss: 1.121, Validation Accuracy: 56.81%
Epoch [9/10], Training Loss: 1.106, Validation Accuracy: 56.53%
Epoch [10/10], Training Loss: 1.101, Validation Accuracy: 55.60%
Epoch [1/10], Training Loss: 1.169, Validation Accuracy: 54.27%
Epoch [2/10], Training Loss: 1.147, Validation Accuracy: 57.01%
Epoch [3/10], Training Loss: 1.126, Validation Accuracy: 56.94%
Epoch [4/10], Training Loss: 1.110, Validation Accuracy: 56.63%
Epoch [5/10], Training Loss: 1.097, Validation Accuracy: 56.18%
Epoch [6/10], Training Loss: 1.093, Validation Accuracy: 56.34%
Epoch [7/10], Training Loss: 1.087, Validation Accuracy: 56.97%
Epoch [8/10], Training Loss: 1.059, Validation Accuracy: 56.58%
Epoch [9/10], Training Loss: 1.057, Validation Accuracy: 57.59%
Epoch [10/10], Training Loss: 1.037, Validation Accuracy: 57.36%
Epoch [1/10], Training Loss: 1.177, Validation Accuracy: 57.54%
Epoch [2/10], Training Loss: 1.143, Validation Accuracy: 57.49%
Epoch [3/10], Training Loss: 1.124, Validation Accuracy: 57.50%
Epoch [4/10], Training Loss: 1.103, Validation Accuracy: 57.57%
Epoch [5/10], Training Loss: 1.093, Validation Accuracy: 57.98%
Epoch [6/10], Training Loss: 1.084, Validation Accuracy: 57.73%
Epoch [7/10], Training Loss: 1.067, Validation Accuracy: 57.86%
Epoch [8/10], Training Loss: 1.065, Validation Accuracy: 56.37%
Epoch [9/10], Training Loss: 1.047, Validation Accuracy: 57.81%
Epoch [10/10], Training Loss: 1.045, Validation Accuracy: 57.96%
Epoch [1/10], Training Loss: 1.155, Validation Accuracy: 57.84%
Epoch [2/10], Training Loss: 1.123, Validation Accuracy: 57.90%
Epoch [3/10], Training Loss: 1.101, Validation Accuracy: 58.31%
Epoch [4/10], Training Loss: 1.085, Validation Accuracy: 58.11%
Epoch [5/10], Training Loss: 1.071, Validation Accuracy: 57.74%
Epoch [6/10], Training Loss: 1.061, Validation Accuracy: 57.87%
Epoch [7/10], Training Loss: 1.052, Validation Accuracy: 57.58%
Epoch [8/10], Training Loss: 1.035, Validation Accuracy: 57.87%
Epoch [9/10], Training Loss: 1.021, Validation Accuracy: 57.69%
Epoch [10/10], Training Loss: 1.018, Validation Accuracy: 58.39%
Epoch [1/10], Training Loss: 1.143, Validation Accuracy: 57.43%
Epoch [2/10], Training Loss: 1.112, Validation Accuracy: 57.52%
Epoch [3/10], Training Loss: 1.092, Validation Accuracy: 58.54%
Epoch [4/10], Training Loss: 1.073, Validation Accuracy: 58.99%
Epoch [5/10], Training Loss: 1.053, Validation Accuracy: 58.65%
Epoch [6/10], Training Loss: 1.035, Validation Accuracy: 58.87%
Epoch [7/10], Training Loss: 1.033, Validation Accuracy: 58.70%
Epoch [8/10], Training Loss: 1.023, Validation Accuracy: 57.97%
Epoch [9/10], Training Loss: 1.005, Validation Accuracy: 58.61%
Epoch [10/10], Training Loss: 0.991, Validation Accuracy: 58.78%
Epoch [1/10], Training Loss: 1.124, Validation Accuracy: 58.41%
Epoch [2/10], Training Loss: 1.092, Validation Accuracy: 58.99%
Epoch [3/10], Training Loss: 1.069, Validation Accuracy: 58.86%
Epoch [4/10], Training Loss: 1.054, Validation Accuracy: 58.51%
Epoch [5/10], Training Loss: 1.035, Validation Accuracy: 59.07%
Epoch [6/10], Training Loss: 1.027, Validation Accuracy: 58.84%
Epoch [7/10], Training Loss: 1.014, Validation Accuracy: 58.63%
Epoch [8/10], Training Loss: 0.995, Validation Accuracy: 58.58%
Epoch [9/10], Training Loss: 0.987, Validation Accuracy: 58.82%
Epoch [10/10], Training Loss: 0.976, Validation Accuracy: 59.16%
Epoch [1/10], Training Loss: 1.078, Validation Accuracy: 58.80%
Epoch [2/10], Training Loss: 1.041, Validation Accuracy: 59.16%
Epoch [3/10], Training Loss: 1.014, Validation Accuracy: 59.32%
Epoch [4/10], Training Loss: 0.995, Validation Accuracy: 58.74%
Epoch [5/10], Training Loss: 0.984, Validation Accuracy: 58.73%
Epoch [6/10], Training Loss: 0.969, Validation Accuracy: 58.45%
Epoch [7/10], Training Loss: 0.958, Validation Accuracy: 58.41%
Epoch [8/10], Training Loss: 0.950, Validation Accuracy: 58.57%
Epoch [9/10], Training Loss: 0.936, Validation Accuracy: 58.25%
Epoch [10/10], Training Loss: 0.923, Validation Accuracy: 59.41%
Epoch [1/10], Training Loss: 1.084, Validation Accuracy: 59.75%
Epoch [2/10], Training Loss: 1.048, Validation Accuracy: 59.60%
Epoch [3/10], Training Loss: 1.024, Validation Accuracy: 58.71%
Epoch [4/10], Training Loss: 0.999, Validation Accuracy: 58.89%
Epoch [5/10], Training Loss: 0.986, Validation Accuracy: 59.49%
Epoch [6/10], Training Loss: 0.972, Validation Accuracy: 59.39%
Epoch [7/10], Training Loss: 0.956, Validation Accuracy: 59.59%
Epoch [8/10], Training Loss: 0.939, Validation Accuracy: 58.78%
Epoch [9/10], Training Loss: 0.930, Validation Accuracy: 60.04%
Epoch [10/10], Training Loss: 0.925, Validation Accuracy: 59.48%
Epoch [1/10], Training Loss: 1.071, Validation Accuracy: 60.47%
Epoch [2/10], Training Loss: 1.024, Validation Accuracy: 60.00%
Epoch [3/10], Training Loss: 1.005, Validation Accuracy: 60.19%
Epoch [4/10], Training Loss: 0.978, Validation Accuracy: 59.80%
Epoch [5/10], Training Loss: 0.962, Validation Accuracy: 59.99%
Epoch [6/10], Training Loss: 0.952, Validation Accuracy: 59.92%
Epoch [7/10], Training Loss: 0.932, Validation Accuracy: 59.81%
Epoch [8/10], Training Loss: 0.920, Validation Accuracy: 60.10%
Epoch [9/10], Training Loss: 0.910, Validation Accuracy: 60.04%
Epoch [10/10], Training Loss: 0.896, Validation Accuracy: 59.38%
Epoch [1/10], Training Loss: 1.059, Validation Accuracy: 60.48%
Epoch [2/10], Training Loss: 1.016, Validation Accuracy: 59.64%
Epoch [3/10], Training Loss: 0.993, Validation Accuracy: 59.80%
Epoch [4/10], Training Loss: 0.974, Validation Accuracy: 59.81%
Epoch [5/10], Training Loss: 0.961, Validation Accuracy: 60.12%
Epoch [6/10], Training Loss: 0.933, Validation Accuracy: 60.03%
Epoch [7/10], Training Loss: 0.926, Validation Accuracy: 58.60%
Epoch [8/10], Training Loss: 0.912, Validation Accuracy: 60.66%
Epoch [9/10], Training Loss: 0.894, Validation Accuracy: 60.23%
Epoch [10/10], Training Loss: 0.882, Validation Accuracy: 60.01%
Epoch [1/10], Training Loss: 1.054, Validation Accuracy: 60.57%
Epoch [2/10], Training Loss: 1.009, Validation Accuracy: 60.59%
Epoch [3/10], Training Loss: 0.980, Validation Accuracy: 60.49%
Epoch [4/10], Training Loss: 0.959, Validation Accuracy: 60.45%
Epoch [5/10], Training Loss: 0.941, Validation Accuracy: 59.54%
Epoch [6/10], Training Loss: 0.929, Validation Accuracy: 60.58%
Epoch [7/10], Training Loss: 0.904, Validation Accuracy: 60.20%
Epoch [8/10], Training Loss: 0.907, Validation Accuracy: 60.44%
Epoch [9/10], Training Loss: 0.887, Validation Accuracy: 59.70%
Epoch [10/10], Training Loss: 0.881, Validation Accuracy: 59.99%
Epoch [1/10], Training Loss: 1.015, Validation Accuracy: 60.68%
Epoch [2/10], Training Loss: 0.954, Validation Accuracy: 61.33%
Epoch [3/10], Training Loss: 0.931, Validation Accuracy: 61.07%
Epoch [4/10], Training Loss: 0.908, Validation Accuracy: 60.68%
Epoch [5/10], Training Loss: 0.901, Validation Accuracy: 60.42%
Epoch [6/10], Training Loss: 0.876, Validation Accuracy: 60.09%
Epoch [7/10], Training Loss: 0.862, Validation Accuracy: 60.44%
Epoch [8/10], Training Loss: 0.846, Validation Accuracy: 60.24%
Epoch [9/10], Training Loss: 0.836, Validation Accuracy: 59.63%
Epoch [10/10], Training Loss: 0.814, Validation Accuracy: 60.46%
Epoch [1/10], Training Loss: 1.022, Validation Accuracy: 60.50%
Epoch [2/10], Training Loss: 0.971, Validation Accuracy: 60.71%
Epoch [3/10], Training Loss: 0.931, Validation Accuracy: 60.49%
Epoch [4/10], Training Loss: 0.915, Validation Accuracy: 59.92%
Epoch [5/10], Training Loss: 0.895, Validation Accuracy: 60.71%
Epoch [6/10], Training Loss: 0.883, Validation Accuracy: 60.16%
Epoch [7/10], Training Loss: 0.864, Validation Accuracy: 60.86%
Epoch [8/10], Training Loss: 0.852, Validation Accuracy: 60.39%
Epoch [9/10], Training Loss: 0.828, Validation Accuracy: 60.67%
Epoch [10/10], Training Loss: 0.825, Validation Accuracy: 60.05%
Epoch [1/10], Training Loss: 0.999, Validation Accuracy: 60.73%
Epoch [2/10], Training Loss: 0.947, Validation Accuracy: 60.87%
Epoch [3/10], Training Loss: 0.911, Validation Accuracy: 61.27%
Epoch [4/10], Training Loss: 0.894, Validation Accuracy: 60.63%
Epoch [5/10], Training Loss: 0.866, Validation Accuracy: 61.32%
Epoch [6/10], Training Loss: 0.849, Validation Accuracy: 61.15%
Epoch [7/10], Training Loss: 0.832, Validation Accuracy: 60.39%
Epoch [8/10], Training Loss: 0.822, Validation Accuracy: 61.19%
Epoch [9/10], Training Loss: 0.808, Validation Accuracy: 60.94%
Epoch [10/10], Training Loss: 0.788, Validation Accuracy: 61.11%
Epoch [1/10], Training Loss: 1.002, Validation Accuracy: 60.92%
Epoch [2/10], Training Loss: 0.945, Validation Accuracy: 61.37%
Epoch [3/10], Training Loss: 0.908, Validation Accuracy: 61.11%
Epoch [4/10], Training Loss: 0.884, Validation Accuracy: 61.27%
Epoch [5/10], Training Loss: 0.862, Validation Accuracy: 61.47%
Epoch [6/10], Training Loss: 0.844, Validation Accuracy: 61.10%
Epoch [7/10], Training Loss: 0.825, Validation Accuracy: 61.19%
Epoch [8/10], Training Loss: 0.803, Validation Accuracy: 61.06%
Epoch [9/10], Training Loss: 0.798, Validation Accuracy: 60.75%
Epoch [10/10], Training Loss: 0.775, Validation Accuracy: 60.55%
Epoch [1/10], Training Loss: 0.993, Validation Accuracy: 61.34%
Epoch [2/10], Training Loss: 0.941, Validation Accuracy: 60.97%
Epoch [3/10], Training Loss: 0.904, Validation Accuracy: 61.32%
Epoch [4/10], Training Loss: 0.877, Validation Accuracy: 60.36%
Epoch [5/10], Training Loss: 0.854, Validation Accuracy: 61.23%
Epoch [6/10], Training Loss: 0.830, Validation Accuracy: 60.32%
Epoch [7/10], Training Loss: 0.805, Validation Accuracy: 61.30%
Epoch [8/10], Training Loss: 0.789, Validation Accuracy: 61.06%
Epoch [9/10], Training Loss: 0.781, Validation Accuracy: 60.72%
Epoch [10/10], Training Loss: 0.775, Validation Accuracy: 60.84%
Epoch [1/10], Training Loss: 0.953, Validation Accuracy: 61.68%
Epoch [2/10], Training Loss: 0.887, Validation Accuracy: 61.24%
Epoch [3/10], Training Loss: 0.848, Validation Accuracy: 60.67%
Epoch [4/10], Training Loss: 0.829, Validation Accuracy: 61.34%
Epoch [5/10], Training Loss: 0.795, Validation Accuracy: 61.20%
Epoch [6/10], Training Loss: 0.784, Validation Accuracy: 60.65%
Epoch [7/10], Training Loss: 0.767, Validation Accuracy: 60.73%
Epoch [8/10], Training Loss: 0.745, Validation Accuracy: 60.59%
Epoch [9/10], Training Loss: 0.731, Validation Accuracy: 60.77%
Epoch [10/10], Training Loss: 0.720, Validation Accuracy: 60.91%
Epoch [1/10], Training Loss: 0.963, Validation Accuracy: 61.14%
Epoch [2/10], Training Loss: 0.903, Validation Accuracy: 60.96%
Epoch [3/10], Training Loss: 0.857, Validation Accuracy: 59.82%
Epoch [4/10], Training Loss: 0.828, Validation Accuracy: 60.69%
Epoch [5/10], Training Loss: 0.798, Validation Accuracy: 61.06%
Epoch [6/10], Training Loss: 0.782, Validation Accuracy: 60.88%
Epoch [7/10], Training Loss: 0.765, Validation Accuracy: 60.95%
Epoch [8/10], Training Loss: 0.744, Validation Accuracy: 60.98%
Epoch [9/10], Training Loss: 0.720, Validation Accuracy: 61.06%
Epoch [10/10], Training Loss: 0.713, Validation Accuracy: 60.95%
Epoch [1/10], Training Loss: 0.955, Validation Accuracy: 61.04%
Epoch [2/10], Training Loss: 0.875, Validation Accuracy: 61.61%
Epoch [3/10], Training Loss: 0.837, Validation Accuracy: 62.18%
Epoch [4/10], Training Loss: 0.806, Validation Accuracy: 61.62%
Epoch [5/10], Training Loss: 0.786, Validation Accuracy: 60.82%
Epoch [6/10], Training Loss: 0.761, Validation Accuracy: 61.63%
Epoch [7/10], Training Loss: 0.740, Validation Accuracy: 61.27%
Epoch [8/10], Training Loss: 0.726, Validation Accuracy: 60.90%
Epoch [9/10], Training Loss: 0.710, Validation Accuracy: 60.89%
Epoch [10/10], Training Loss: 0.686, Validation Accuracy: 61.21%
Epoch [1/10], Training Loss: 0.945, Validation Accuracy: 60.91%
Epoch [2/10], Training Loss: 0.867, Validation Accuracy: 61.66%
Epoch [3/10], Training Loss: 0.828, Validation Accuracy: 62.10%
Epoch [4/10], Training Loss: 0.795, Validation Accuracy: 60.03%
Epoch [5/10], Training Loss: 0.794, Validation Accuracy: 61.86%
Epoch [6/10], Training Loss: 0.760, Validation Accuracy: 61.55%
Epoch [7/10], Training Loss: 0.726, Validation Accuracy: 61.93%
Epoch [8/10], Training Loss: 0.712, Validation Accuracy: 60.96%
Epoch [9/10], Training Loss: 0.696, Validation Accuracy: 61.50%
Epoch [10/10], Training Loss: 0.681, Validation Accuracy: 61.02%
Epoch [1/10], Training Loss: 0.949, Validation Accuracy: 60.93%
Epoch [2/10], Training Loss: 0.869, Validation Accuracy: 61.45%
Epoch [3/10], Training Loss: 0.821, Validation Accuracy: 61.74%
Epoch [4/10], Training Loss: 0.788, Validation Accuracy: 61.18%
Epoch [5/10], Training Loss: 0.762, Validation Accuracy: 61.11%
Epoch [6/10], Training Loss: 0.736, Validation Accuracy: 61.70%
Epoch [7/10], Training Loss: 0.721, Validation Accuracy: 61.10%
Epoch [8/10], Training Loss: 0.706, Validation Accuracy: 60.79%
Epoch [9/10], Training Loss: 0.678, Validation Accuracy: 61.04%
Epoch [10/10], Training Loss: 0.662, Validation Accuracy: 61.28%
Epoch [1/10], Training Loss: 0.917, Validation Accuracy: 61.15%
Epoch [2/10], Training Loss: 0.835, Validation Accuracy: 61.20%
Epoch [3/10], Training Loss: 0.787, Validation Accuracy: 60.89%
Epoch [4/10], Training Loss: 0.756, Validation Accuracy: 61.70%
Epoch [5/10], Training Loss: 0.726, Validation Accuracy: 61.28%
Epoch [6/10], Training Loss: 0.692, Validation Accuracy: 61.60%
Epoch [7/10], Training Loss: 0.674, Validation Accuracy: 61.25%
Epoch [8/10], Training Loss: 0.658, Validation Accuracy: 61.01%
Epoch [9/10], Training Loss: 0.649, Validation Accuracy: 61.49%
Epoch [10/10], Training Loss: 0.625, Validation Accuracy: 61.07%
Epoch [1/10], Training Loss: 0.903, Validation Accuracy: 61.71%
Epoch [2/10], Training Loss: 0.828, Validation Accuracy: 61.67%
Epoch [3/10], Training Loss: 0.772, Validation Accuracy: 60.88%
Epoch [4/10], Training Loss: 0.749, Validation Accuracy: 61.34%
Epoch [5/10], Training Loss: 0.719, Validation Accuracy: 61.42%
Epoch [6/10], Training Loss: 0.689, Validation Accuracy: 60.87%
Epoch [7/10], Training Loss: 0.669, Validation Accuracy: 61.56%
Epoch [8/10], Training Loss: 0.645, Validation Accuracy: 61.32%
Epoch [9/10], Training Loss: 0.626, Validation Accuracy: 60.62%
Epoch [10/10], Training Loss: 0.610, Validation Accuracy: 61.23%
Epoch [1/10], Training Loss: 0.897, Validation Accuracy: 61.46%
Epoch [2/10], Training Loss: 0.801, Validation Accuracy: 61.48%
Epoch [3/10], Training Loss: 0.760, Validation Accuracy: 61.71%
Epoch [4/10], Training Loss: 0.728, Validation Accuracy: 61.61%
Epoch [5/10], Training Loss: 0.696, Validation Accuracy: 61.89%
Epoch [6/10], Training Loss: 0.664, Validation Accuracy: 61.72%
Epoch [7/10], Training Loss: 0.642, Validation Accuracy: 61.24%
Epoch [8/10], Training Loss: 0.618, Validation Accuracy: 61.40%
Epoch [9/10], Training Loss: 0.602, Validation Accuracy: 61.59%
Epoch [10/10], Training Loss: 0.592, Validation Accuracy: 60.90%
Epoch [1/10], Training Loss: 0.905, Validation Accuracy: 62.03%
Epoch [2/10], Training Loss: 0.813, Validation Accuracy: 61.40%
Epoch [3/10], Training Loss: 0.759, Validation Accuracy: 61.87%
Epoch [4/10], Training Loss: 0.724, Validation Accuracy: 61.33%
Epoch [5/10], Training Loss: 0.689, Validation Accuracy: 61.73%
Epoch [6/10], Training Loss: 0.662, Validation Accuracy: 61.53%
Epoch [7/10], Training Loss: 0.639, Validation Accuracy: 61.85%
Epoch [8/10], Training Loss: 0.622, Validation Accuracy: 61.31%
Epoch [9/10], Training Loss: 0.598, Validation Accuracy: 61.02%
Epoch [10/10], Training Loss: 0.579, Validation Accuracy: 60.55%
Epoch [1/10], Training Loss: 0.908, Validation Accuracy: 61.86%
Epoch [2/10], Training Loss: 0.800, Validation Accuracy: 61.14%
Epoch [3/10], Training Loss: 0.749, Validation Accuracy: 61.24%
Epoch [4/10], Training Loss: 0.712, Validation Accuracy: 61.37%
Epoch [5/10], Training Loss: 0.680, Validation Accuracy: 61.77%
Epoch [6/10], Training Loss: 0.662, Validation Accuracy: 61.05%
Epoch [7/10], Training Loss: 0.629, Validation Accuracy: 61.60%
Epoch [8/10], Training Loss: 0.610, Validation Accuracy: 60.62%
Epoch [9/10], Training Loss: 0.583, Validation Accuracy: 60.78%
Epoch [10/10], Training Loss: 0.570, Validation Accuracy: 61.44%
Epoch [1/10], Training Loss: 0.866, Validation Accuracy: 60.89%
Epoch [2/10], Training Loss: 0.776, Validation Accuracy: 61.59%
Epoch [3/10], Training Loss: 0.708, Validation Accuracy: 61.15%
Epoch [4/10], Training Loss: 0.671, Validation Accuracy: 61.54%
Epoch [5/10], Training Loss: 0.647, Validation Accuracy: 61.38%
Epoch [6/10], Training Loss: 0.618, Validation Accuracy: 61.50%
Epoch [7/10], Training Loss: 0.592, Validation Accuracy: 61.12%
Epoch [8/10], Training Loss: 0.570, Validation Accuracy: 61.18%
Epoch [9/10], Training Loss: 0.550, Validation Accuracy: 60.64%
Epoch [10/10], Training Loss: 0.534, Validation Accuracy: 61.33%
Epoch [1/10], Training Loss: 0.864, Validation Accuracy: 61.15%
Epoch [2/10], Training Loss: 0.752, Validation Accuracy: 60.97%
Epoch [3/10], Training Loss: 0.701, Validation Accuracy: 60.15%
Epoch [4/10], Training Loss: 0.685, Validation Accuracy: 61.62%
Epoch [5/10], Training Loss: 0.637, Validation Accuracy: 61.13%
Epoch [6/10], Training Loss: 0.604, Validation Accuracy: 61.75%
Epoch [7/10], Training Loss: 0.569, Validation Accuracy: 61.07%
Epoch [8/10], Training Loss: 0.554, Validation Accuracy: 61.46%
Epoch [9/10], Training Loss: 0.527, Validation Accuracy: 61.45%
Epoch [10/10], Training Loss: 0.509, Validation Accuracy: 61.20%
Epoch [1/10], Training Loss: 0.855, Validation Accuracy: 60.95%
Epoch [2/10], Training Loss: 0.748, Validation Accuracy: 60.90%
Epoch [3/10], Training Loss: 0.692, Validation Accuracy: 61.92%
Epoch [4/10], Training Loss: 0.641, Validation Accuracy: 61.84%
Epoch [5/10], Training Loss: 0.607, Validation Accuracy: 61.36%
Epoch [6/10], Training Loss: 0.580, Validation Accuracy: 61.63%
Epoch [7/10], Training Loss: 0.552, Validation Accuracy: 61.90%
Epoch [8/10], Training Loss: 0.532, Validation Accuracy: 61.44%
Epoch [9/10], Training Loss: 0.505, Validation Accuracy: 61.18%
Epoch [10/10], Training Loss: 0.489, Validation Accuracy: 61.78%
Epoch [1/10], Training Loss: 0.858, Validation Accuracy: 60.89%
Epoch [2/10], Training Loss: 0.743, Validation Accuracy: 61.65%
Epoch [3/10], Training Loss: 0.681, Validation Accuracy: 61.73%
Epoch [4/10], Training Loss: 0.646, Validation Accuracy: 61.27%
Epoch [5/10], Training Loss: 0.613, Validation Accuracy: 61.22%
Epoch [6/10], Training Loss: 0.578, Validation Accuracy: 61.52%
Epoch [7/10], Training Loss: 0.548, Validation Accuracy: 60.91%
Epoch [8/10], Training Loss: 0.526, Validation Accuracy: 61.24%
Epoch [9/10], Training Loss: 0.503, Validation Accuracy: 61.15%
Epoch [10/10], Training Loss: 0.491, Validation Accuracy: 61.04%
Epoch [1/10], Training Loss: 0.869, Validation Accuracy: 60.81%
Epoch [2/10], Training Loss: 0.748, Validation Accuracy: 61.49%
Epoch [3/10], Training Loss: 0.675, Validation Accuracy: 61.07%
Epoch [4/10], Training Loss: 0.625, Validation Accuracy: 61.57%
Epoch [5/10], Training Loss: 0.603, Validation Accuracy: 60.67%
Epoch [6/10], Training Loss: 0.568, Validation Accuracy: 61.15%
Epoch [7/10], Training Loss: 0.541, Validation Accuracy: 61.32%
Epoch [8/10], Training Loss: 0.514, Validation Accuracy: 60.98%
Epoch [9/10], Training Loss: 0.497, Validation Accuracy: 60.43%
Epoch [10/10], Training Loss: 0.477, Validation Accuracy: 60.64%
Epoch [1/10], Training Loss: 0.841, Validation Accuracy: 60.88%
Epoch [2/10], Training Loss: 0.713, Validation Accuracy: 61.34%
Epoch [3/10], Training Loss: 0.647, Validation Accuracy: 60.69%
Epoch [4/10], Training Loss: 0.609, Validation Accuracy: 61.52%
Epoch [5/10], Training Loss: 0.561, Validation Accuracy: 60.97%
Epoch [6/10], Training Loss: 0.540, Validation Accuracy: 61.43%
Epoch [7/10], Training Loss: 0.509, Validation Accuracy: 61.33%
Epoch [8/10], Training Loss: 0.486, Validation Accuracy: 61.12%
Epoch [9/10], Training Loss: 0.464, Validation Accuracy: 60.26%
Epoch [10/10], Training Loss: 0.439, Validation Accuracy: 60.43%
Epoch [1/10], Training Loss: 0.819, Validation Accuracy: 60.46%
Epoch [2/10], Training Loss: 0.692, Validation Accuracy: 60.77%
Epoch [3/10], Training Loss: 0.619, Validation Accuracy: 61.06%
Epoch [4/10], Training Loss: 0.574, Validation Accuracy: 61.28%
Epoch [5/10], Training Loss: 0.542, Validation Accuracy: 61.05%
Epoch [6/10], Training Loss: 0.509, Validation Accuracy: 60.91%
Epoch [7/10], Training Loss: 0.482, Validation Accuracy: 60.70%
Epoch [8/10], Training Loss: 0.474, Validation Accuracy: 60.91%
Epoch [9/10], Training Loss: 0.450, Validation Accuracy: 61.02%
Epoch [10/10], Training Loss: 0.418, Validation Accuracy: 61.30%
"""

# Extract every "Validation Accuracy: <number>%" reading from the pasted log
# and convert the captured text to floats in a single step.
accuracies = list(map(float, re.findall(r'Validation Accuracy: (\d+\.\d+)%', log)))

# Report the parsed values, then how many readings were found.
print("Accuracies:", accuracies)
print("Size of array:", len(accuracies))


Accuracies: [9.89, 9.89, 9.89, 9.89, 9.89, 9.89, 9.89, 9.89, 9.89, 9.89, 9.89, 9.9, 10.05, 11.8, 14.01, 14.36, 14.27, 15.41, 16.01, 17.05, 19.88, 21.49, 23.1, 24.23, 25.19, 26.34, 26.71, 27.22, 27.68, 27.52, 28.6, 29.16, 29.75, 30.42, 31.17, 31.34, 31.2, 31.76, 32.39, 33.22, 33.81, 34.08, 34.79, 34.94, 35.08, 35.6, 35.07, 36.25, 36.16, 37.01, 37.73, 37.92, 38.34, 39.26, 39.74, 40.38, 40.66, 40.5, 41.16, 41.16, 41.27, 41.93, 42.55, 41.94, 42.29, 43.0, 42.71, 43.48, 43.33, 43.78, 43.6, 44.25, 44.79, 44.12, 44.43, 44.76, 45.19, 45.29, 45.16, 45.61, 45.77, 45.85, 46.24, 46.14, 46.65, 46.16, 46.26, 46.58, 47.02, 47.42, 47.18, 47.64, 47.32, 47.77, 47.66, 47.22, 48.34, 48.53, 48.21, 46.46, 48.53, 48.52, 48.53, 49.06, 48.32, 49.13, 49.09, 48.39, 49.79, 49.42, 49.62, 50.18, 49.71, 50.38, 50.64, 50.49, 50.81, 50.69, 50.79, 50.4, 51.25, 51.56, 51.83, 51.59, 51.28, 51.29, 51.41, 52.19, 51.6, 51.63, 52.33, 52.37, 52.24, 52.4, 52.1, 53.48, 53.22, 52.68, 52.91, 51.9, 52.5, 52.99, 53.83, 53.3, 53.85, 

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for batches of 3x32x32 images.

    The encoder applies three stride-2 convolutions and two linear layers and
    emits ``2 * z_dim`` values per sample: the first ``z_dim`` columns are read
    as the latent mean, the remaining columns as the latent log-variance. The
    decoder mirrors the encoder with transposed convolutions and finishes with
    a ``Sigmoid``, so reconstructions lie in ``[0, 1]``.

    Args:
        x_dim: Stored on the instance; no layer is sized from it.
        h_dim: Stored on the instance; no layer is sized from it.
        z_dim: Number of latent dimensions.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim, self.h_dim, self.z_dim = x_dim, h_dim, z_dim

        # All (transposed) convolutions share one geometry: each halves
        # (encoder) or doubles (decoder) the spatial resolution.
        conv = dict(kernel_size=4, stride=2, padding=1)

        # Encoder
        encoder_layers = [
            nn.Conv2d(3, 32, **conv),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, **conv),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, **conv),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder
        decoder_layers = [
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, **conv),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, **conv),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, **conv),  # 32x32
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output."""
        stats = self.encoder(x)
        latent = self.z_dim
        return stats[:, :latent], stats[:, latent:]

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` standard normal and ``std = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors ``z`` through the decoder."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on ``trainloader`` for ``epochs`` passes using Adam.

    A fresh Adam optimizer (lr 1e-3) and a ``StepLR`` schedule (halve the
    learning rate every 20 scheduler steps) are created on every call, so no
    optimizer state carries over between calls.

    Args:
        vae: Model to optimise in place.
        trainloader: Iterable of ``(inputs, labels)`` pairs; the labels are ignored.
        epochs: Number of passes over ``trainloader``.
    """
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for images, _labels in trainloader:
            optimizer.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar)
            batch_loss.backward()
            optimizer.step()

        # One scheduler tick per epoch.
        scheduler.step()


# Define classification model
class Net(nn.Module):
    """Small CNN classifier: two 5x5 convolutions followed by three linear layers.

    The flatten step assumes 16 feature maps of size 5x5, which is what a
    3x32x32 input produces after the two convolution + 2x2 max-pool stages.
    The output has 10 values per sample.
    """

    def __init__(self) -> None:
        super().__init__()
        # Feature extractor; the single pooling module is reused after both convolutions.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 un-normalised class scores for each sample in ``x``."""
        relu = nn.functional.relu
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            x = relu(hidden(x))
        return self.fc3(x)


# Load CIFAR10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Separate dataset by class. Labels are read straight from `targets`, so no
# image has to be decoded and transformed just to look at its label.
class_indices = {i: [] for i in range(10)}  # CIFAR-10 has 10 classes
for idx, label in enumerate(full_dataset.targets):
    class_indices[label].append(idx)

# Draw a random per-class image count whose total equals the size of the
# training split. A larger total (e.g. 60,000) would push every class past the
# images it actually has and the clamp below would select the whole split.
# NOTE: no RNG seed is set here, so counts and splits differ between runs.
class_counts = np.random.multinomial(len(full_dataset), [0.1] * 10)  # Adjust probabilities if you want specific class biases
# A class cannot contribute more images than it holds; clamp before reporting
# so the printed counts match what is actually sampled.
class_counts = np.array(
    [min(int(count), len(class_indices[class_id])) for class_id, count in enumerate(class_counts)]
)
print("Random Images per Class:", class_counts)

# Sample indices based on the specified class counts
indices = []
for class_id, count in enumerate(class_counts):
    # int(): hand random.sample a plain Python int rather than a NumPy scalar.
    selected_indices = random.sample(class_indices[class_id], int(count))
    indices.extend(selected_indices)

# Create a custom CIFAR-10 dataset with the sampled indices
custom_dataset = Subset(full_dataset, indices)
# Split the sampled dataset (not the untouched full dataset) into training and
# validation sets, so the class sampling above actually takes effect.
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

# Create training and validation loaders
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return the VAE objective: summed squared reconstruction error plus the KL term.

    Args:
        recon_x: Reconstruction produced by the decoder.
        x: Target the reconstruction is compared against.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor ``sum((recon_x - x)^2) - 0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    # Despite the usual "BCE" naming in VAE code, the reconstruction term here
    # is a summed mean-squared-error, not a binary cross-entropy.
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_elements)
    return reconstruction_term + kl_term

# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after every epoch.

    A new SGD optimizer (lr 0.001, momentum 0.9) and ``StepLR`` schedule are
    created on each call. After every epoch the model is scored on
    ``valloader``; whenever the accuracy beats the best seen *within this
    call*, the weights are written to ``best_model.pth`` in the current
    working directory.

    Args:
        net: Classifier to optimise in place.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``trainloader``.

    An empty ``trainloader`` reports a training loss of 0.000 and an empty
    ``valloader`` reports an accuracy of 0.00% instead of raising.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        # Counted explicitly: relying on the loop index after the loop would
        # fail when the loader yields no batches.
        num_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against an empty validation loader (total == 0).
        val_acc = 100 * correct / total if total else 0.0
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        avg_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score ``net`` on ``testloader`` and derive per-class metrics.

    Prints the confusion matrix as a side effect.

    Args:
        net: Classifier to evaluate; switched to eval mode.
        testloader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)`` where ``accuracy``
        is a percentage and the remaining entries are per-class arrays.
    """
    net.eval()
    n_correct, n_seen = 0, 0
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in testloader:
            predicted = torch.max(net(images).data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    accuracy = 100 * n_correct / n_seen

    # Calculate confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    print(f"Confusion Matrix:\n{cm}")

    # Per-class counts from the matrix: the diagonal holds the hits, column
    # sums minus the diagonal give false positives, row sums minus the
    # diagonal give false negatives, and everything else is a true negative.
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # Per-class precision, recall and F1 (average=None keeps one value per class).
    precision, recall, f1 = (
        metric(y_true, y_pred, average=None)
        for metric in (precision_score, recall_score, f1_score)
    )

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into ``num_clients`` contiguous shards, one shuffled loader each.

    Args:
        trainset: Indexable dataset with a length.
        transform: Accepted but not used by this function.
        batch_size: Batch size of every client loader.
        num_clients: Number of shards to create.

    Returns:
        Dict mapping ``"client_<i>"`` to the ``DataLoader`` over shard ``i``,
        which covers indices ``[i * n // num_clients, (i + 1) * n // num_clients)``.
    """

    def shard_loader(client_idx):
        total = len(trainset)
        start = client_idx * total // num_clients
        stop = (client_idx + 1) * total // num_clients
        shard = torch.utils.data.Subset(trainset, range(start, stop))
        return torch.utils.data.DataLoader(shard, batch_size=batch_size, shuffle=True)

    return {f"client_{client_idx}": shard_loader(client_idx) for client_idx in range(num_clients)}

def get_distribution_info(vae: VAE) -> Dict:
    """Build an example distribution summary from the encoder's last layer.

    This is a placeholder: "mean" is the bias vector of ``vae.encoder[-1]`` and
    "std" is ``exp(0.5 * weight)`` of that same layer. These are layer
    parameters, not statistics computed from encoded data, and "std" therefore
    has the shape of the weight matrix rather than of "mean".

    Args:
        vae: Object whose ``encoder[-1]`` exposes ``bias`` and ``weight``.

    Returns:
        ``{"uniform": {...}, "normal": {...}}``; both entries carry the same
        "mean"/"std" values as NumPy arrays.
    """
    final_layer = vae.encoder[-1]

    def layer_stats() -> Dict:
        # Called once per family so each entry gets its own array objects.
        return {
            "mean": final_layer.bias.data.cpu().numpy(),
            "std": torch.exp(0.5 * final_layer.weight.data).cpu().numpy(),
        }

    return {"uniform": layer_stats(), "normal": layer_stats()}

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for sending ``distribution_info`` to the global server.

    Currently a no-op: nothing is transmitted and the argument is not modified.
    A real implementation would push the payload over some transport, for
    example a network socket or a message queue such as RabbitMQ.
    """
    return None

def generate_augmented_data(vae: VAE, distribution_info_uniform: Dict, distribution_info_normal: Dict, num_samples: int = 64) -> torch.Tensor:
    """Sample latent vectors by averaging a normal draw and a uniform draw.

    Args:
        vae: Only ``vae.z_dim`` is read; it sets the width of the samples.
        distribution_info_uniform: Mapping with "mean" and "std" entries
            (NumPy arrays or tensors). Each is collapsed to its scalar average
            and the uniform draw covers ``[mean - std, mean + std)``.
        distribution_info_normal: Mapping with "mean" and "std" entries that
            broadcast against ``(num_samples, vae.z_dim)``; they shift and
            scale a standard-normal draw.
        num_samples: Number of latent vectors to generate (default 64).

    Returns:
        Tensor of shape ``(num_samples, vae.z_dim)`` holding the element-wise
        average of the two draws.
    """
    # Scalar bounds for the uniform draw.
    mean = distribution_info_uniform["mean"].mean().item()
    std = distribution_info_uniform["std"].mean().item()

    # Convert explicitly instead of relying on implicit Tensor * ndarray
    # interop; as_tensor keeps the incoming dtype and is a no-op for tensors.
    mean_normal = torch.as_tensor(distribution_info_normal["mean"])
    std_normal = torch.as_tensor(distribution_info_normal["std"])

    # Normal draw first, then the uniform draw (fixed RNG consumption order).
    augmented_data_normal = torch.randn(num_samples, vae.z_dim) * std_normal + mean_normal
    augmented_data_uniform = torch.empty(num_samples, vae.z_dim, dtype=torch.float32).uniform_(mean - std, mean + std)

    # Average of the two draws.
    return (augmented_data_uniform + augmented_data_normal) / 2

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Run ``epochs`` rounds in which each client in turn trains the shared ``vae`` and ``net``.

    Args:
        net: Classifier, trained in place on one client after another.
        vae: VAE, trained in place on one client after another.
        trainloaders: Mapping from client id to that client's training loader.
        trainloader: Accepted but not used by this function.
        valloader: Validation loader handed to ``train`` for every client.
        epochs: Number of rounds over all clients.
    """
    for _round in range(epochs):
        for client_id, local_loader in trainloaders.items():
            # Fit the VAE on this client's data.
            vae_train(vae, local_loader, epochs=10)

            # Publish this client's distribution summary.
            send_distribution_info(get_distribution_info(vae))

            # Fetch the summary coming from the other clients and sample from it.
            # The sampled latents are not passed on to `train` below.
            received = receive_distribution_info()
            augmented_data = generate_augmented_data(vae, received["uniform"], received["normal"])

            # Train the classifier on the client's local data, validating on `valloader`.
            train(net, local_loader, valloader, epochs=10)

            # Report the updated classifier weights.
            send_model_update(client_id, net.state_dict())

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution parameters for a ``z_dim``-wide latent space.

    Nothing is actually received: every call builds zero means and unit
    standard deviations.

    Args:
        z_dim: Length of each "mean"/"std" vector; set it to the latent
            dimension of the VAE in use (default 20).

    Returns:
        ``{"uniform": {"mean", "std"}, "normal": {"mean", "std"}}`` where each
        value is a new NumPy array of length ``z_dim``.
    """
    return {
        family: {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}
        for family in ("uniform", "normal")
    }

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for sending a client's model update to the global server.

    Currently a no-op: nothing is transmitted and neither argument is modified.
    A real implementation would push ``model_update`` over some transport, for
    example a network socket or a message queue such as RabbitMQ.
    """
    return None

# Define global server procedure
def global_server() -> None:
    """Build the models, run federated training over 5 clients and print test metrics.

    Reads the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``.
    """
    # The classifier is created before the VAE.
    net = Net()
    # CIFAR-10 input size (3 * 32 * 32), hidden width and latent width.
    vae = VAE(x_dim=3 * 32 * 32, h_dim=400, z_dim=20)

    # Initialize clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=5)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    # The entry labelled ":" is the per-class sum of all four counts.
    report = (
        ("True Positives (TP):", tp),
        ("False Positives (FP):", fp),
        ("True Negatives (TN):", tn),
        ("False Negatives (FN):", fn),
        (":", fn+tn+tp+fp),
        ("Precision:", precision),
        ("Recall:", recall),
        ("F1 Score:", f1),
    )
    for label, value in report:
        print(label, value)


if __name__ == "__main__":
    global_server()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:01<00:00, 93.7MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Random Images per Class: [6038 6060 6069 5999 5989 6051 5982 6005 5876 5931]
Epoch [1/10], Training Loss: 2.305, Validation Accuracy: 10.34%
Epoch [2/10], Training Loss: 2.303, Validation Accuracy: 10.30%
Epoch [3/10], Training Loss: 2.301, Validation Accuracy: 10.49%
Epoch [4/10], Training Loss: 2.299, Validation Accuracy: 11.05%
Epoch [5/10], Training Loss: 2.296, Validation Accuracy: 11.75%
Epoch [6/10], Training Loss: 2.293, Validation Accuracy: 12.24%
Epoch [7/10], Training Loss: 2.289, Validation Accuracy: 12.45%
Epoch [8/10], Training Loss: 2.283, Validation Accuracy: 15.12%
Epoch [9/10], Training Loss: 2.276, Validation Accuracy: 17.40%
Epoch [10/10], Training Loss: 2.266, Validation Accuracy: 18.49%
Epoch [1/10], Training Loss: 2.256, Validation Accuracy: 18.26%
Epoch [2/10], Training Loss: 2.242, Validation Accuracy: 18.88%
Epoch [3/10], Training Loss: 2.223, Validation Accuracy: 20.73%
E

In [None]:
import re

# Your provided text
log = """
Epoch [1/10], Training Loss: 2.305, Validation Accuracy: 10.34%
Epoch [2/10], Training Loss: 2.303, Validation Accuracy: 10.30%
Epoch [3/10], Training Loss: 2.301, Validation Accuracy: 10.49%
Epoch [4/10], Training Loss: 2.299, Validation Accuracy: 11.05%
Epoch [5/10], Training Loss: 2.296, Validation Accuracy: 11.75%
Epoch [6/10], Training Loss: 2.293, Validation Accuracy: 12.24%
Epoch [7/10], Training Loss: 2.289, Validation Accuracy: 12.45%
Epoch [8/10], Training Loss: 2.283, Validation Accuracy: 15.12%
Epoch [9/10], Training Loss: 2.276, Validation Accuracy: 17.40%
Epoch [10/10], Training Loss: 2.266, Validation Accuracy: 18.49%
Epoch [1/10], Training Loss: 2.256, Validation Accuracy: 18.26%
Epoch [2/10], Training Loss: 2.242, Validation Accuracy: 18.88%
Epoch [3/10], Training Loss: 2.223, Validation Accuracy: 20.73%
Epoch [4/10], Training Loss: 2.197, Validation Accuracy: 22.43%
Epoch [5/10], Training Loss: 2.168, Validation Accuracy: 23.29%
Epoch [6/10], Training Loss: 2.138, Validation Accuracy: 26.29%
Epoch [7/10], Training Loss: 2.118, Validation Accuracy: 27.01%
Epoch [8/10], Training Loss: 2.102, Validation Accuracy: 27.20%
Epoch [9/10], Training Loss: 2.087, Validation Accuracy: 27.19%
Epoch [10/10], Training Loss: 2.071, Validation Accuracy: 27.50%
Epoch [1/10], Training Loss: 2.038, Validation Accuracy: 26.39%
Epoch [2/10], Training Loss: 2.020, Validation Accuracy: 28.42%
Epoch [3/10], Training Loss: 2.003, Validation Accuracy: 28.60%
Epoch [4/10], Training Loss: 1.990, Validation Accuracy: 28.81%
Epoch [5/10], Training Loss: 1.970, Validation Accuracy: 29.77%
Epoch [6/10], Training Loss: 1.955, Validation Accuracy: 29.89%
Epoch [7/10], Training Loss: 1.942, Validation Accuracy: 30.61%
Epoch [8/10], Training Loss: 1.920, Validation Accuracy: 31.12%
Epoch [9/10], Training Loss: 1.900, Validation Accuracy: 31.63%
Epoch [10/10], Training Loss: 1.884, Validation Accuracy: 31.96%
Epoch [1/10], Training Loss: 1.909, Validation Accuracy: 33.55%
Epoch [2/10], Training Loss: 1.889, Validation Accuracy: 32.80%
Epoch [3/10], Training Loss: 1.872, Validation Accuracy: 33.40%
Epoch [4/10], Training Loss: 1.857, Validation Accuracy: 34.79%
Epoch [5/10], Training Loss: 1.842, Validation Accuracy: 34.94%
Epoch [6/10], Training Loss: 1.823, Validation Accuracy: 35.74%
Epoch [7/10], Training Loss: 1.808, Validation Accuracy: 35.06%
Epoch [8/10], Training Loss: 1.796, Validation Accuracy: 36.08%
Epoch [9/10], Training Loss: 1.780, Validation Accuracy: 36.70%
Epoch [10/10], Training Loss: 1.767, Validation Accuracy: 35.39%
Epoch [1/10], Training Loss: 1.789, Validation Accuracy: 37.45%
Epoch [2/10], Training Loss: 1.765, Validation Accuracy: 37.89%
Epoch [3/10], Training Loss: 1.747, Validation Accuracy: 38.94%
Epoch [4/10], Training Loss: 1.733, Validation Accuracy: 39.32%
Epoch [5/10], Training Loss: 1.718, Validation Accuracy: 39.46%
Epoch [6/10], Training Loss: 1.698, Validation Accuracy: 39.68%
Epoch [7/10], Training Loss: 1.684, Validation Accuracy: 39.67%
Epoch [8/10], Training Loss: 1.673, Validation Accuracy: 40.03%
Epoch [9/10], Training Loss: 1.655, Validation Accuracy: 40.11%
Epoch [10/10], Training Loss: 1.644, Validation Accuracy: 41.04%
Epoch [1/10], Training Loss: 1.640, Validation Accuracy: 41.27%
Epoch [2/10], Training Loss: 1.631, Validation Accuracy: 41.38%
Epoch [3/10], Training Loss: 1.617, Validation Accuracy: 42.16%
Epoch [4/10], Training Loss: 1.594, Validation Accuracy: 42.49%
Epoch [5/10], Training Loss: 1.590, Validation Accuracy: 42.07%
Epoch [6/10], Training Loss: 1.577, Validation Accuracy: 43.36%
Epoch [7/10], Training Loss: 1.562, Validation Accuracy: 43.20%
Epoch [8/10], Training Loss: 1.549, Validation Accuracy: 43.27%
Epoch [9/10], Training Loss: 1.543, Validation Accuracy: 44.33%
Epoch [10/10], Training Loss: 1.538, Validation Accuracy: 43.43%
Epoch [1/10], Training Loss: 1.561, Validation Accuracy: 44.09%
Epoch [2/10], Training Loss: 1.548, Validation Accuracy: 44.37%
Epoch [3/10], Training Loss: 1.533, Validation Accuracy: 44.40%
Epoch [4/10], Training Loss: 1.525, Validation Accuracy: 43.98%
Epoch [5/10], Training Loss: 1.521, Validation Accuracy: 45.08%
Epoch [6/10], Training Loss: 1.505, Validation Accuracy: 45.37%
Epoch [7/10], Training Loss: 1.502, Validation Accuracy: 45.18%
Epoch [8/10], Training Loss: 1.485, Validation Accuracy: 45.15%
Epoch [9/10], Training Loss: 1.478, Validation Accuracy: 45.85%
Epoch [10/10], Training Loss: 1.473, Validation Accuracy: 46.03%
Epoch [1/10], Training Loss: 1.469, Validation Accuracy: 46.43%
Epoch [2/10], Training Loss: 1.451, Validation Accuracy: 45.51%
Epoch [3/10], Training Loss: 1.452, Validation Accuracy: 46.53%
Epoch [4/10], Training Loss: 1.426, Validation Accuracy: 46.31%
Epoch [5/10], Training Loss: 1.427, Validation Accuracy: 46.87%
Epoch [6/10], Training Loss: 1.409, Validation Accuracy: 46.87%
Epoch [7/10], Training Loss: 1.416, Validation Accuracy: 47.37%
Epoch [8/10], Training Loss: 1.395, Validation Accuracy: 46.94%
Epoch [9/10], Training Loss: 1.385, Validation Accuracy: 47.77%
Epoch [10/10], Training Loss: 1.383, Validation Accuracy: 47.28%
Epoch [1/10], Training Loss: 1.445, Validation Accuracy: 48.27%
Epoch [2/10], Training Loss: 1.438, Validation Accuracy: 48.57%
Epoch [3/10], Training Loss: 1.414, Validation Accuracy: 48.05%
Epoch [4/10], Training Loss: 1.401, Validation Accuracy: 48.28%
Epoch [5/10], Training Loss: 1.391, Validation Accuracy: 48.74%
Epoch [6/10], Training Loss: 1.382, Validation Accuracy: 48.92%
Epoch [7/10], Training Loss: 1.369, Validation Accuracy: 47.21%
Epoch [8/10], Training Loss: 1.359, Validation Accuracy: 49.35%
Epoch [9/10], Training Loss: 1.351, Validation Accuracy: 48.97%
Epoch [10/10], Training Loss: 1.350, Validation Accuracy: 49.21%
Epoch [1/10], Training Loss: 1.409, Validation Accuracy: 49.23%
Epoch [2/10], Training Loss: 1.390, Validation Accuracy: 48.86%
Epoch [3/10], Training Loss: 1.382, Validation Accuracy: 49.86%
Epoch [4/10], Training Loss: 1.370, Validation Accuracy: 49.93%
Epoch [5/10], Training Loss: 1.357, Validation Accuracy: 49.84%
Epoch [6/10], Training Loss: 1.347, Validation Accuracy: 49.44%
Epoch [7/10], Training Loss: 1.349, Validation Accuracy: 50.97%
Epoch [8/10], Training Loss: 1.342, Validation Accuracy: 50.44%
Epoch [9/10], Training Loss: 1.322, Validation Accuracy: 51.30%
Epoch [10/10], Training Loss: 1.314, Validation Accuracy: 50.78%
Epoch [1/10], Training Loss: 1.366, Validation Accuracy: 51.51%
Epoch [2/10], Training Loss: 1.354, Validation Accuracy: 51.64%
Epoch [3/10], Training Loss: 1.331, Validation Accuracy: 51.77%
Epoch [4/10], Training Loss: 1.327, Validation Accuracy: 50.46%
Epoch [5/10], Training Loss: 1.311, Validation Accuracy: 51.98%
Epoch [6/10], Training Loss: 1.309, Validation Accuracy: 51.86%
Epoch [7/10], Training Loss: 1.291, Validation Accuracy: 51.88%
Epoch [8/10], Training Loss: 1.281, Validation Accuracy: 51.45%
Epoch [9/10], Training Loss: 1.274, Validation Accuracy: 52.23%
Epoch [10/10], Training Loss: 1.267, Validation Accuracy: 52.37%
Epoch [1/10], Training Loss: 1.338, Validation Accuracy: 52.06%
Epoch [2/10], Training Loss: 1.312, Validation Accuracy: 51.97%
Epoch [3/10], Training Loss: 1.312, Validation Accuracy: 53.12%
Epoch [4/10], Training Loss: 1.289, Validation Accuracy: 52.70%
Epoch [5/10], Training Loss: 1.282, Validation Accuracy: 52.39%
Epoch [6/10], Training Loss: 1.263, Validation Accuracy: 52.49%
Epoch [7/10], Training Loss: 1.261, Validation Accuracy: 52.72%
Epoch [8/10], Training Loss: 1.254, Validation Accuracy: 52.09%
Epoch [9/10], Training Loss: 1.240, Validation Accuracy: 52.49%
Epoch [10/10], Training Loss: 1.238, Validation Accuracy: 52.20%
Epoch [1/10], Training Loss: 1.275, Validation Accuracy: 53.35%
Epoch [2/10], Training Loss: 1.264, Validation Accuracy: 53.00%
Epoch [3/10], Training Loss: 1.241, Validation Accuracy: 52.74%
Epoch [4/10], Training Loss: 1.228, Validation Accuracy: 53.15%
Epoch [5/10], Training Loss: 1.217, Validation Accuracy: 53.42%
Epoch [6/10], Training Loss: 1.212, Validation Accuracy: 53.79%
Epoch [7/10], Training Loss: 1.187, Validation Accuracy: 53.46%
Epoch [8/10], Training Loss: 1.182, Validation Accuracy: 53.54%
Epoch [9/10], Training Loss: 1.167, Validation Accuracy: 53.58%
Epoch [10/10], Training Loss: 1.170, Validation Accuracy: 52.89%
Epoch [1/10], Training Loss: 1.272, Validation Accuracy: 54.33%
Epoch [2/10], Training Loss: 1.251, Validation Accuracy: 53.09%
Epoch [3/10], Training Loss: 1.230, Validation Accuracy: 54.03%
Epoch [4/10], Training Loss: 1.214, Validation Accuracy: 54.15%
Epoch [5/10], Training Loss: 1.204, Validation Accuracy: 53.98%
Epoch [6/10], Training Loss: 1.203, Validation Accuracy: 54.18%
Epoch [7/10], Training Loss: 1.188, Validation Accuracy: 54.87%
Epoch [8/10], Training Loss: 1.179, Validation Accuracy: 54.66%
Epoch [9/10], Training Loss: 1.171, Validation Accuracy: 54.34%
Epoch [10/10], Training Loss: 1.159, Validation Accuracy: 54.49%
Epoch [1/10], Training Loss: 1.259, Validation Accuracy: 55.55%
Epoch [2/10], Training Loss: 1.232, Validation Accuracy: 55.03%
Epoch [3/10], Training Loss: 1.217, Validation Accuracy: 55.06%
Epoch [4/10], Training Loss: 1.200, Validation Accuracy: 55.69%
Epoch [5/10], Training Loss: 1.194, Validation Accuracy: 54.86%
Epoch [6/10], Training Loss: 1.183, Validation Accuracy: 55.93%
Epoch [7/10], Training Loss: 1.172, Validation Accuracy: 55.37%
Epoch [8/10], Training Loss: 1.154, Validation Accuracy: 55.47%
Epoch [9/10], Training Loss: 1.150, Validation Accuracy: 55.77%
Epoch [10/10], Training Loss: 1.141, Validation Accuracy: 55.49%
Epoch [1/10], Training Loss: 1.222, Validation Accuracy: 55.70%
Epoch [2/10], Training Loss: 1.192, Validation Accuracy: 55.88%
Epoch [3/10], Training Loss: 1.171, Validation Accuracy: 56.18%
Epoch [4/10], Training Loss: 1.171, Validation Accuracy: 55.40%
Epoch [5/10], Training Loss: 1.156, Validation Accuracy: 56.57%
Epoch [6/10], Training Loss: 1.139, Validation Accuracy: 55.85%
Epoch [7/10], Training Loss: 1.127, Validation Accuracy: 54.97%
Epoch [8/10], Training Loss: 1.124, Validation Accuracy: 56.36%
Epoch [9/10], Training Loss: 1.102, Validation Accuracy: 56.35%
Epoch [10/10], Training Loss: 1.092, Validation Accuracy: 55.38%
Epoch [1/10], Training Loss: 1.206, Validation Accuracy: 56.42%
Epoch [2/10], Training Loss: 1.172, Validation Accuracy: 56.23%
Epoch [3/10], Training Loss: 1.164, Validation Accuracy: 56.12%
Epoch [4/10], Training Loss: 1.148, Validation Accuracy: 56.25%
Epoch [5/10], Training Loss: 1.125, Validation Accuracy: 56.35%
Epoch [6/10], Training Loss: 1.114, Validation Accuracy: 56.48%
Epoch [7/10], Training Loss: 1.104, Validation Accuracy: 56.86%
Epoch [8/10], Training Loss: 1.091, Validation Accuracy: 57.16%
Epoch [9/10], Training Loss: 1.086, Validation Accuracy: 56.74%
Epoch [10/10], Training Loss: 1.071, Validation Accuracy: 56.52%
Epoch [1/10], Training Loss: 1.151, Validation Accuracy: 57.10%
Epoch [2/10], Training Loss: 1.122, Validation Accuracy: 57.27%
Epoch [3/10], Training Loss: 1.096, Validation Accuracy: 57.50%
Epoch [4/10], Training Loss: 1.087, Validation Accuracy: 56.01%
Epoch [5/10], Training Loss: 1.071, Validation Accuracy: 57.93%
Epoch [6/10], Training Loss: 1.052, Validation Accuracy: 57.72%
Epoch [7/10], Training Loss: 1.033, Validation Accuracy: 57.47%
Epoch [8/10], Training Loss: 1.028, Validation Accuracy: 56.71%
Epoch [9/10], Training Loss: 1.018, Validation Accuracy: 57.45%
Epoch [10/10], Training Loss: 1.018, Validation Accuracy: 57.50%
Epoch [1/10], Training Loss: 1.165, Validation Accuracy: 57.08%
Epoch [2/10], Training Loss: 1.128, Validation Accuracy: 57.65%
Epoch [3/10], Training Loss: 1.106, Validation Accuracy: 57.93%
Epoch [4/10], Training Loss: 1.085, Validation Accuracy: 57.91%
Epoch [5/10], Training Loss: 1.066, Validation Accuracy: 57.61%
Epoch [6/10], Training Loss: 1.059, Validation Accuracy: 57.55%
Epoch [7/10], Training Loss: 1.038, Validation Accuracy: 57.34%
Epoch [8/10], Training Loss: 1.045, Validation Accuracy: 57.53%
Epoch [9/10], Training Loss: 1.014, Validation Accuracy: 58.23%
Epoch [10/10], Training Loss: 1.009, Validation Accuracy: 57.96%
Epoch [1/10], Training Loss: 1.151, Validation Accuracy: 58.06%
Epoch [2/10], Training Loss: 1.117, Validation Accuracy: 58.65%
Epoch [3/10], Training Loss: 1.100, Validation Accuracy: 58.38%
Epoch [4/10], Training Loss: 1.069, Validation Accuracy: 58.56%
Epoch [5/10], Training Loss: 1.060, Validation Accuracy: 57.84%
Epoch [6/10], Training Loss: 1.043, Validation Accuracy: 57.87%
Epoch [7/10], Training Loss: 1.036, Validation Accuracy: 57.78%
Epoch [8/10], Training Loss: 1.023, Validation Accuracy: 57.54%
Epoch [9/10], Training Loss: 1.013, Validation Accuracy: 58.31%
Epoch [10/10], Training Loss: 0.992, Validation Accuracy: 58.22%
Epoch [1/10], Training Loss: 1.129, Validation Accuracy: 57.28%
Epoch [2/10], Training Loss: 1.092, Validation Accuracy: 59.14%
Epoch [3/10], Training Loss: 1.068, Validation Accuracy: 58.27%
Epoch [4/10], Training Loss: 1.046, Validation Accuracy: 58.78%
Epoch [5/10], Training Loss: 1.032, Validation Accuracy: 58.61%
Epoch [6/10], Training Loss: 1.012, Validation Accuracy: 58.23%
Epoch [7/10], Training Loss: 1.006, Validation Accuracy: 58.64%
Epoch [8/10], Training Loss: 0.989, Validation Accuracy: 58.34%
Epoch [9/10], Training Loss: 0.984, Validation Accuracy: 59.02%
Epoch [10/10], Training Loss: 0.973, Validation Accuracy: 58.37%
Epoch [1/10], Training Loss: 1.103, Validation Accuracy: 58.27%
Epoch [2/10], Training Loss: 1.080, Validation Accuracy: 58.32%
Epoch [3/10], Training Loss: 1.052, Validation Accuracy: 58.60%
Epoch [4/10], Training Loss: 1.030, Validation Accuracy: 58.55%
Epoch [5/10], Training Loss: 1.010, Validation Accuracy: 58.77%
Epoch [6/10], Training Loss: 0.997, Validation Accuracy: 58.65%
Epoch [7/10], Training Loss: 0.986, Validation Accuracy: 58.77%
Epoch [8/10], Training Loss: 0.972, Validation Accuracy: 58.93%
Epoch [9/10], Training Loss: 0.952, Validation Accuracy: 58.32%
Epoch [10/10], Training Loss: 0.944, Validation Accuracy: 58.54%
Epoch [1/10], Training Loss: 1.070, Validation Accuracy: 59.56%
Epoch [2/10], Training Loss: 1.026, Validation Accuracy: 59.29%
Epoch [3/10], Training Loss: 1.013, Validation Accuracy: 58.67%
Epoch [4/10], Training Loss: 0.983, Validation Accuracy: 59.30%
Epoch [5/10], Training Loss: 0.967, Validation Accuracy: 59.23%
Epoch [6/10], Training Loss: 0.938, Validation Accuracy: 59.09%
Epoch [7/10], Training Loss: 0.933, Validation Accuracy: 59.37%
Epoch [8/10], Training Loss: 0.916, Validation Accuracy: 59.20%
Epoch [9/10], Training Loss: 0.894, Validation Accuracy: 59.13%
Epoch [10/10], Training Loss: 0.896, Validation Accuracy: 59.19%
Epoch [1/10], Training Loss: 1.085, Validation Accuracy: 59.52%
Epoch [2/10], Training Loss: 1.028, Validation Accuracy: 58.62%
Epoch [3/10], Training Loss: 1.005, Validation Accuracy: 58.92%
Epoch [4/10], Training Loss: 0.987, Validation Accuracy: 59.80%
Epoch [5/10], Training Loss: 0.964, Validation Accuracy: 58.54%
Epoch [6/10], Training Loss: 0.957, Validation Accuracy: 59.97%
Epoch [7/10], Training Loss: 0.937, Validation Accuracy: 59.30%
Epoch [8/10], Training Loss: 0.929, Validation Accuracy: 59.29%
Epoch [9/10], Training Loss: 0.909, Validation Accuracy: 58.65%
Epoch [10/10], Training Loss: 0.893, Validation Accuracy: 58.57%
Epoch [1/10], Training Loss: 1.076, Validation Accuracy: 58.19%
Epoch [2/10], Training Loss: 1.021, Validation Accuracy: 59.74%
Epoch [3/10], Training Loss: 1.001, Validation Accuracy: 59.76%
Epoch [4/10], Training Loss: 0.977, Validation Accuracy: 60.15%
Epoch [5/10], Training Loss: 0.968, Validation Accuracy: 59.78%
Epoch [6/10], Training Loss: 0.944, Validation Accuracy: 58.73%
Epoch [7/10], Training Loss: 0.928, Validation Accuracy: 60.43%
Epoch [8/10], Training Loss: 0.912, Validation Accuracy: 59.84%
Epoch [9/10], Training Loss: 0.898, Validation Accuracy: 59.46%
Epoch [10/10], Training Loss: 0.889, Validation Accuracy: 59.61%
Epoch [1/10], Training Loss: 1.048, Validation Accuracy: 59.96%
Epoch [2/10], Training Loss: 1.002, Validation Accuracy: 60.33%
Epoch [3/10], Training Loss: 0.972, Validation Accuracy: 59.91%
Epoch [4/10], Training Loss: 0.952, Validation Accuracy: 60.15%
Epoch [5/10], Training Loss: 0.938, Validation Accuracy: 59.99%
Epoch [6/10], Training Loss: 0.913, Validation Accuracy: 59.48%
Epoch [7/10], Training Loss: 0.898, Validation Accuracy: 59.83%
Epoch [8/10], Training Loss: 0.888, Validation Accuracy: 60.02%
Epoch [9/10], Training Loss: 0.869, Validation Accuracy: 59.95%
Epoch [10/10], Training Loss: 0.866, Validation Accuracy: 59.76%
Epoch [1/10], Training Loss: 1.032, Validation Accuracy: 59.61%
Epoch [2/10], Training Loss: 0.979, Validation Accuracy: 59.72%
Epoch [3/10], Training Loss: 0.964, Validation Accuracy: 59.41%
Epoch [4/10], Training Loss: 0.930, Validation Accuracy: 60.16%
Epoch [5/10], Training Loss: 0.922, Validation Accuracy: 60.08%
Epoch [6/10], Training Loss: 0.893, Validation Accuracy: 60.08%
Epoch [7/10], Training Loss: 0.878, Validation Accuracy: 59.35%
Epoch [8/10], Training Loss: 0.868, Validation Accuracy: 59.67%
Epoch [9/10], Training Loss: 0.853, Validation Accuracy: 59.74%
Epoch [10/10], Training Loss: 0.839, Validation Accuracy: 60.45%
Epoch [1/10], Training Loss: 1.004, Validation Accuracy: 60.10%
Epoch [2/10], Training Loss: 0.945, Validation Accuracy: 60.73%
Epoch [3/10], Training Loss: 0.914, Validation Accuracy: 60.89%
Epoch [4/10], Training Loss: 0.900, Validation Accuracy: 60.71%
Epoch [5/10], Training Loss: 0.861, Validation Accuracy: 60.62%
Epoch [6/10], Training Loss: 0.853, Validation Accuracy: 61.08%
Epoch [7/10], Training Loss: 0.831, Validation Accuracy: 60.20%
Epoch [8/10], Training Loss: 0.814, Validation Accuracy: 60.47%
Epoch [9/10], Training Loss: 0.797, Validation Accuracy: 60.93%
Epoch [10/10], Training Loss: 0.791, Validation Accuracy: 60.74%
Epoch [1/10], Training Loss: 1.014, Validation Accuracy: 60.54%
Epoch [2/10], Training Loss: 0.951, Validation Accuracy: 60.95%
Epoch [3/10], Training Loss: 0.928, Validation Accuracy: 60.64%
Epoch [4/10], Training Loss: 0.885, Validation Accuracy: 60.95%
Epoch [5/10], Training Loss: 0.868, Validation Accuracy: 60.51%
Epoch [6/10], Training Loss: 0.850, Validation Accuracy: 60.26%
Epoch [7/10], Training Loss: 0.845, Validation Accuracy: 60.55%
Epoch [8/10], Training Loss: 0.818, Validation Accuracy: 60.30%
Epoch [9/10], Training Loss: 0.806, Validation Accuracy: 59.54%
Epoch [10/10], Training Loss: 0.794, Validation Accuracy: 60.17%
Epoch [1/10], Training Loss: 1.008, Validation Accuracy: 59.57%
Epoch [2/10], Training Loss: 0.949, Validation Accuracy: 60.94%
Epoch [3/10], Training Loss: 0.908, Validation Accuracy: 60.97%
Epoch [4/10], Training Loss: 0.884, Validation Accuracy: 61.13%
Epoch [5/10], Training Loss: 0.871, Validation Accuracy: 61.35%
Epoch [6/10], Training Loss: 0.849, Validation Accuracy: 61.07%
Epoch [7/10], Training Loss: 0.828, Validation Accuracy: 61.42%
Epoch [8/10], Training Loss: 0.811, Validation Accuracy: 61.73%
Epoch [9/10], Training Loss: 0.799, Validation Accuracy: 61.10%
Epoch [10/10], Training Loss: 0.777, Validation Accuracy: 60.39%
Epoch [1/10], Training Loss: 0.999, Validation Accuracy: 60.46%
Epoch [2/10], Training Loss: 0.927, Validation Accuracy: 60.90%
Epoch [3/10], Training Loss: 0.893, Validation Accuracy: 60.70%
Epoch [4/10], Training Loss: 0.870, Validation Accuracy: 60.59%
Epoch [5/10], Training Loss: 0.847, Validation Accuracy: 60.78%
Epoch [6/10], Training Loss: 0.828, Validation Accuracy: 60.03%
Epoch [7/10], Training Loss: 0.806, Validation Accuracy: 60.35%
Epoch [8/10], Training Loss: 0.793, Validation Accuracy: 60.34%
Epoch [9/10], Training Loss: 0.769, Validation Accuracy: 60.53%
Epoch [10/10], Training Loss: 0.756, Validation Accuracy: 60.13%
Epoch [1/10], Training Loss: 0.979, Validation Accuracy: 60.64%
Epoch [2/10], Training Loss: 0.917, Validation Accuracy: 60.00%
Epoch [3/10], Training Loss: 0.876, Validation Accuracy: 61.40%
Epoch [4/10], Training Loss: 0.849, Validation Accuracy: 60.23%
Epoch [5/10], Training Loss: 0.836, Validation Accuracy: 61.03%
Epoch [6/10], Training Loss: 0.803, Validation Accuracy: 60.53%
Epoch [7/10], Training Loss: 0.782, Validation Accuracy: 61.30%
Epoch [8/10], Training Loss: 0.770, Validation Accuracy: 60.76%
Epoch [9/10], Training Loss: 0.749, Validation Accuracy: 60.95%
Epoch [10/10], Training Loss: 0.727, Validation Accuracy: 61.11%
Epoch [1/10], Training Loss: 0.949, Validation Accuracy: 61.20%
Epoch [2/10], Training Loss: 0.874, Validation Accuracy: 60.76%
Epoch [3/10], Training Loss: 0.839, Validation Accuracy: 61.18%
Epoch [4/10], Training Loss: 0.801, Validation Accuracy: 61.49%
Epoch [5/10], Training Loss: 0.777, Validation Accuracy: 61.84%
Epoch [6/10], Training Loss: 0.758, Validation Accuracy: 61.31%
Epoch [7/10], Training Loss: 0.730, Validation Accuracy: 61.41%
Epoch [8/10], Training Loss: 0.730, Validation Accuracy: 61.11%
Epoch [9/10], Training Loss: 0.704, Validation Accuracy: 61.39%
Epoch [10/10], Training Loss: 0.690, Validation Accuracy: 61.23%
Epoch [1/10], Training Loss: 0.968, Validation Accuracy: 61.31%
Epoch [2/10], Training Loss: 0.894, Validation Accuracy: 61.39%
Epoch [3/10], Training Loss: 0.852, Validation Accuracy: 60.85%
Epoch [4/10], Training Loss: 0.820, Validation Accuracy: 61.04%
Epoch [5/10], Training Loss: 0.796, Validation Accuracy: 61.41%
Epoch [6/10], Training Loss: 0.766, Validation Accuracy: 61.15%
Epoch [7/10], Training Loss: 0.749, Validation Accuracy: 61.09%
Epoch [8/10], Training Loss: 0.732, Validation Accuracy: 60.16%
Epoch [9/10], Training Loss: 0.710, Validation Accuracy: 61.23%
Epoch [10/10], Training Loss: 0.693, Validation Accuracy: 60.86%
Epoch [1/10], Training Loss: 0.963, Validation Accuracy: 59.94%
Epoch [2/10], Training Loss: 0.889, Validation Accuracy: 61.78%
Epoch [3/10], Training Loss: 0.843, Validation Accuracy: 60.82%
Epoch [4/10], Training Loss: 0.826, Validation Accuracy: 61.40%
Epoch [5/10], Training Loss: 0.795, Validation Accuracy: 62.43%
Epoch [6/10], Training Loss: 0.760, Validation Accuracy: 62.30%
Epoch [7/10], Training Loss: 0.742, Validation Accuracy: 61.08%
Epoch [8/10], Training Loss: 0.725, Validation Accuracy: 61.46%
Epoch [9/10], Training Loss: 0.708, Validation Accuracy: 60.98%
Epoch [10/10], Training Loss: 0.689, Validation Accuracy: 61.03%
Epoch [1/10], Training Loss: 0.932, Validation Accuracy: 61.41%
Epoch [2/10], Training Loss: 0.856, Validation Accuracy: 60.87%
Epoch [3/10], Training Loss: 0.816, Validation Accuracy: 61.03%
Epoch [4/10], Training Loss: 0.784, Validation Accuracy: 61.44%
Epoch [5/10], Training Loss: 0.766, Validation Accuracy: 61.36%
Epoch [6/10], Training Loss: 0.732, Validation Accuracy: 61.28%
Epoch [7/10], Training Loss: 0.708, Validation Accuracy: 60.37%
Epoch [8/10], Training Loss: 0.687, Validation Accuracy: 60.74%
Epoch [9/10], Training Loss: 0.674, Validation Accuracy: 60.86%
Epoch [10/10], Training Loss: 0.652, Validation Accuracy: 61.28%
Epoch [1/10], Training Loss: 0.936, Validation Accuracy: 60.56%
Epoch [2/10], Training Loss: 0.850, Validation Accuracy: 61.32%
Epoch [3/10], Training Loss: 0.796, Validation Accuracy: 61.47%
Epoch [4/10], Training Loss: 0.772, Validation Accuracy: 61.45%
Epoch [5/10], Training Loss: 0.739, Validation Accuracy: 61.52%
Epoch [6/10], Training Loss: 0.710, Validation Accuracy: 61.68%
Epoch [7/10], Training Loss: 0.686, Validation Accuracy: 59.85%
Epoch [8/10], Training Loss: 0.683, Validation Accuracy: 61.43%
Epoch [9/10], Training Loss: 0.659, Validation Accuracy: 61.04%
Epoch [10/10], Training Loss: 0.637, Validation Accuracy: 61.42%
Epoch [1/10], Training Loss: 0.893, Validation Accuracy: 60.67%
Epoch [2/10], Training Loss: 0.816, Validation Accuracy: 61.44%
Epoch [3/10], Training Loss: 0.767, Validation Accuracy: 61.69%
Epoch [4/10], Training Loss: 0.721, Validation Accuracy: 61.99%
Epoch [5/10], Training Loss: 0.693, Validation Accuracy: 61.38%
Epoch [6/10], Training Loss: 0.666, Validation Accuracy: 61.40%
Epoch [7/10], Training Loss: 0.644, Validation Accuracy: 61.46%
Epoch [8/10], Training Loss: 0.630, Validation Accuracy: 61.20%
Epoch [9/10], Training Loss: 0.610, Validation Accuracy: 61.29%
Epoch [10/10], Training Loss: 0.588, Validation Accuracy: 61.55%
Epoch [1/10], Training Loss: 0.918, Validation Accuracy: 60.70%
Epoch [2/10], Training Loss: 0.827, Validation Accuracy: 60.99%
Epoch [3/10], Training Loss: 0.778, Validation Accuracy: 61.21%
Epoch [4/10], Training Loss: 0.733, Validation Accuracy: 61.79%
Epoch [5/10], Training Loss: 0.697, Validation Accuracy: 61.07%
Epoch [6/10], Training Loss: 0.685, Validation Accuracy: 61.23%
Epoch [7/10], Training Loss: 0.655, Validation Accuracy: 60.96%
Epoch [8/10], Training Loss: 0.636, Validation Accuracy: 60.93%
Epoch [9/10], Training Loss: 0.618, Validation Accuracy: 60.89%
Epoch [10/10], Training Loss: 0.593, Validation Accuracy: 60.59%
Epoch [1/10], Training Loss: 0.921, Validation Accuracy: 61.38%
Epoch [2/10], Training Loss: 0.833, Validation Accuracy: 61.53%
Epoch [3/10], Training Loss: 0.769, Validation Accuracy: 62.04%
Epoch [4/10], Training Loss: 0.738, Validation Accuracy: 61.20%
Epoch [5/10], Training Loss: 0.707, Validation Accuracy: 60.26%
Epoch [6/10], Training Loss: 0.689, Validation Accuracy: 61.20%
Epoch [7/10], Training Loss: 0.656, Validation Accuracy: 61.72%
Epoch [8/10], Training Loss: 0.629, Validation Accuracy: 60.75%
Epoch [9/10], Training Loss: 0.611, Validation Accuracy: 61.19%
Epoch [10/10], Training Loss: 0.597, Validation Accuracy: 61.34%
Epoch [1/10], Training Loss: 0.895, Validation Accuracy: 61.63%
Epoch [2/10], Training Loss: 0.794, Validation Accuracy: 61.48%
Epoch [3/10], Training Loss: 0.738, Validation Accuracy: 60.74%
Epoch [4/10], Training Loss: 0.703, Validation Accuracy: 61.27%
Epoch [5/10], Training Loss: 0.668, Validation Accuracy: 61.29%
Epoch [6/10], Training Loss: 0.645, Validation Accuracy: 60.45%
Epoch [7/10], Training Loss: 0.622, Validation Accuracy: 60.83%
Epoch [8/10], Training Loss: 0.598, Validation Accuracy: 61.05%
Epoch [9/10], Training Loss: 0.584, Validation Accuracy: 60.90%
Epoch [10/10], Training Loss: 0.568, Validation Accuracy: 60.99%
Epoch [1/10], Training Loss: 0.888, Validation Accuracy: 61.18%
Epoch [2/10], Training Loss: 0.784, Validation Accuracy: 61.46%
Epoch [3/10], Training Loss: 0.726, Validation Accuracy: 61.18%
Epoch [4/10], Training Loss: 0.685, Validation Accuracy: 61.54%
Epoch [5/10], Training Loss: 0.648, Validation Accuracy: 61.52%
Epoch [6/10], Training Loss: 0.632, Validation Accuracy: 61.13%
Epoch [7/10], Training Loss: 0.599, Validation Accuracy: 61.35%
Epoch [8/10], Training Loss: 0.583, Validation Accuracy: 61.31%
Epoch [9/10], Training Loss: 0.559, Validation Accuracy: 61.37%
Epoch [10/10], Training Loss: 0.543, Validation Accuracy: 60.80%
Epoch [1/10], Training Loss: 0.857, Validation Accuracy: 60.62%
Epoch [2/10], Training Loss: 0.748, Validation Accuracy: 61.31%
Epoch [3/10], Training Loss: 0.686, Validation Accuracy: 61.25%
Epoch [4/10], Training Loss: 0.644, Validation Accuracy: 61.11%
Epoch [5/10], Training Loss: 0.616, Validation Accuracy: 61.44%
Epoch [6/10], Training Loss: 0.587, Validation Accuracy: 61.97%
Epoch [7/10], Training Loss: 0.554, Validation Accuracy: 61.81%
Epoch [8/10], Training Loss: 0.540, Validation Accuracy: 61.25%
Epoch [9/10], Training Loss: 0.520, Validation Accuracy: 61.53%
Epoch [10/10], Training Loss: 0.512, Validation Accuracy: 61.52%
Epoch [1/10], Training Loss: 0.883, Validation Accuracy: 60.91%
Epoch [2/10], Training Loss: 0.765, Validation Accuracy: 60.37%
Epoch [3/10], Training Loss: 0.721, Validation Accuracy: 60.91%
Epoch [4/10], Training Loss: 0.667, Validation Accuracy: 61.35%
Epoch [5/10], Training Loss: 0.626, Validation Accuracy: 60.49%
Epoch [6/10], Training Loss: 0.607, Validation Accuracy: 61.29%
Epoch [7/10], Training Loss: 0.576, Validation Accuracy: 61.34%
Epoch [8/10], Training Loss: 0.552, Validation Accuracy: 60.05%
Epoch [9/10], Training Loss: 0.539, Validation Accuracy: 60.84%
Epoch [10/10], Training Loss: 0.522, Validation Accuracy: 60.98%
Epoch [1/10], Training Loss: 0.903, Validation Accuracy: 59.76%
Epoch [2/10], Training Loss: 0.767, Validation Accuracy: 60.85%
Epoch [3/10], Training Loss: 0.701, Validation Accuracy: 61.15%
Epoch [4/10], Training Loss: 0.656, Validation Accuracy: 61.13%
Epoch [5/10], Training Loss: 0.625, Validation Accuracy: 61.00%
Epoch [6/10], Training Loss: 0.599, Validation Accuracy: 61.28%
Epoch [7/10], Training Loss: 0.580, Validation Accuracy: 61.16%
Epoch [8/10], Training Loss: 0.560, Validation Accuracy: 61.03%
Epoch [9/10], Training Loss: 0.525, Validation Accuracy: 60.70%
Epoch [10/10], Training Loss: 0.503, Validation Accuracy: 61.17%
Epoch [1/10], Training Loss: 0.853, Validation Accuracy: 60.83%
Epoch [2/10], Training Loss: 0.737, Validation Accuracy: 61.56%
Epoch [3/10], Training Loss: 0.672, Validation Accuracy: 61.25%
Epoch [4/10], Training Loss: 0.630, Validation Accuracy: 60.95%
Epoch [5/10], Training Loss: 0.588, Validation Accuracy: 60.57%
Epoch [6/10], Training Loss: 0.565, Validation Accuracy: 61.26%
Epoch [7/10], Training Loss: 0.540, Validation Accuracy: 60.21%
Epoch [8/10], Training Loss: 0.530, Validation Accuracy: 60.14%
Epoch [9/10], Training Loss: 0.493, Validation Accuracy: 59.85%
Epoch [10/10], Training Loss: 0.466, Validation Accuracy: 59.98%
Epoch [1/10], Training Loss: 0.853, Validation Accuracy: 60.70%
Epoch [2/10], Training Loss: 0.733, Validation Accuracy: 60.36%
Epoch [3/10], Training Loss: 0.660, Validation Accuracy: 61.07%
Epoch [4/10], Training Loss: 0.608, Validation Accuracy: 61.37%
Epoch [5/10], Training Loss: 0.567, Validation Accuracy: 61.24%
Epoch [6/10], Training Loss: 0.537, Validation Accuracy: 61.21%
Epoch [7/10], Training Loss: 0.519, Validation Accuracy: 61.08%
Epoch [8/10], Training Loss: 0.488, Validation Accuracy: 61.47%
Epoch [9/10], Training Loss: 0.469, Validation Accuracy: 60.38%
Epoch [10/10], Training Loss: 0.460, Validation Accuracy: 60.49%
Epoch [1/10], Training Loss: 0.814, Validation Accuracy: 60.29%
Epoch [2/10], Training Loss: 0.687, Validation Accuracy: 61.36%
Epoch [3/10], Training Loss: 0.610, Validation Accuracy: 60.48%
Epoch [4/10], Training Loss: 0.590, Validation Accuracy: 62.01%
Epoch [5/10], Training Loss: 0.541, Validation Accuracy: 61.52%
Epoch [6/10], Training Loss: 0.506, Validation Accuracy: 61.05%
Epoch [7/10], Training Loss: 0.472, Validation Accuracy: 61.31%
Epoch [8/10], Training Loss: 0.462, Validation Accuracy: 61.05%
Epoch [9/10], Training Loss: 0.433, Validation Accuracy: 60.84%
Epoch [10/10], Training Loss: 0.412, Validation Accuracy: 60.82%
Epoch [1/10], Training Loss: 0.844, Validation Accuracy: 59.78%
Epoch [2/10], Training Loss: 0.712, Validation Accuracy: 60.97%
Epoch [3/10], Training Loss: 0.642, Validation Accuracy: 61.14%
Epoch [4/10], Training Loss: 0.588, Validation Accuracy: 60.75%
Epoch [5/10], Training Loss: 0.552, Validation Accuracy: 61.14%
Epoch [6/10], Training Loss: 0.530, Validation Accuracy: 61.50%
Epoch [7/10], Training Loss: 0.497, Validation Accuracy: 60.76%
Epoch [8/10], Training Loss: 0.484, Validation Accuracy: 60.88%
Epoch [9/10], Training Loss: 0.445, Validation Accuracy: 60.46%
Epoch [10/10], Training Loss: 0.434, Validation Accuracy: 60.17%
Epoch [1/10], Training Loss: 0.855, Validation Accuracy: 60.90%
Epoch [2/10], Training Loss: 0.701, Validation Accuracy: 61.00%
Epoch [3/10], Training Loss: 0.635, Validation Accuracy: 61.05%
Epoch [4/10], Training Loss: 0.578, Validation Accuracy: 60.20%
Epoch [5/10], Training Loss: 0.537, Validation Accuracy: 61.16%
Epoch [6/10], Training Loss: 0.505, Validation Accuracy: 61.05%
Epoch [7/10], Training Loss: 0.488, Validation Accuracy: 61.03%
Epoch [8/10], Training Loss: 0.460, Validation Accuracy: 60.58%
Epoch [9/10], Training Loss: 0.447, Validation Accuracy: 60.49%
Epoch [10/10], Training Loss: 0.416, Validation Accuracy: 60.63%
"""

# Extract every "Validation Accuracy: NN.NN%" figure from the log text and
# convert each captured number to a float in a single pass.
accuracies = [
    float(match)
    for match in re.findall(r'Validation Accuracy: (\d+\.\d+)%', log)
]

# Show the parsed values followed by how many were found.
print("Accuracies:", accuracies)
print("Size of array:", len(accuracies))


Accuracies: [10.34, 10.3, 10.49, 11.05, 11.75, 12.24, 12.45, 15.12, 17.4, 18.49, 18.26, 18.88, 20.73, 22.43, 23.29, 26.29, 27.01, 27.2, 27.19, 27.5, 26.39, 28.42, 28.6, 28.81, 29.77, 29.89, 30.61, 31.12, 31.63, 31.96, 33.55, 32.8, 33.4, 34.79, 34.94, 35.74, 35.06, 36.08, 36.7, 35.39, 37.45, 37.89, 38.94, 39.32, 39.46, 39.68, 39.67, 40.03, 40.11, 41.04, 41.27, 41.38, 42.16, 42.49, 42.07, 43.36, 43.2, 43.27, 44.33, 43.43, 44.09, 44.37, 44.4, 43.98, 45.08, 45.37, 45.18, 45.15, 45.85, 46.03, 46.43, 45.51, 46.53, 46.31, 46.87, 46.87, 47.37, 46.94, 47.77, 47.28, 48.27, 48.57, 48.05, 48.28, 48.74, 48.92, 47.21, 49.35, 48.97, 49.21, 49.23, 48.86, 49.86, 49.93, 49.84, 49.44, 50.97, 50.44, 51.3, 50.78, 51.51, 51.64, 51.77, 50.46, 51.98, 51.86, 51.88, 51.45, 52.23, 52.37, 52.06, 51.97, 53.12, 52.7, 52.39, 52.49, 52.72, 52.09, 52.49, 52.2, 53.35, 53.0, 52.74, 53.15, 53.42, 53.79, 53.46, 53.54, 53.58, 52.89, 54.33, 53.09, 54.03, 54.15, 53.98, 54.18, 54.87, 54.66, 54.34, 54.49, 55.55, 55.03, 55.06, 

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 32x32 images.

    The encoder halves the spatial size three times (32 -> 16 -> 8 -> 4),
    flattens the 128x4x4 feature map to 2048 values and projects it to
    ``2 * z_dim`` numbers: the first ``z_dim`` are the latent mean, the rest
    the latent log-variance.  The decoder mirrors that path and ends in a
    sigmoid, so reconstructions lie in [0, 1].

    NOTE(review): the dataset transform used elsewhere in this file normalizes
    pixels with mean 0.5 / std 0.5, which a sigmoid output cannot fully
    cover -- confirm whether that mismatch is intended.

    Args:
        x_dim: Stored on the instance only; no layer size depends on it.
        h_dim: Stored on the instance only; no layer size depends on it.
        z_dim: Width of the latent space.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim, self.h_dim, self.z_dim = x_dim, h_dim, z_dim

        # Encoder: layers are listed in execution order; their positions in
        # the Sequential are relied upon by code that indexes `encoder[-1]`.
        encoder_layers = [
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: inverse of the encoder, finishing with a sigmoid.
        decoder_layers = [
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output."""
        stats = self.encoder(x)
        mu = stats.narrow(1, 0, self.z_dim)
        logvar = stats.narrow(1, self.z_dim, stats.size(1) - self.z_dim)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` standard normal noise."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map latent vectors ``z`` back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on ``trainloader`` for ``epochs`` passes.

    Uses Adam (lr 1e-3) with a StepLR schedule that halves the learning rate
    every 20 scheduler steps; the scheduler is stepped once per epoch.  Labels
    yielded by the loader are ignored -- only the inputs are reconstructed.
    A fresh optimizer and scheduler are created on every call.
    """
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for images, _labels in trainloader:
            optimizer.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar)
            batch_loss.backward()
            optimizer.step()

        scheduler.step()


# Define classification model
class Net(nn.Module):
    """Small LeNet-style classifier producing 10 logits per image.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed
    three fully connected layers (400 -> 120 -> 84 -> 10).  The flatten step
    assumes the pooled feature map is 16x5x5, which is what a 3x32x32 input
    yields.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class scores (no softmax) for the batch ``x``."""
        relu = nn.functional.relu

        # Convolutional stage: conv -> ReLU -> shared max-pool, twice.
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(relu(conv(features)))

        # Dense stage: flatten, two hidden ReLU layers, then the output layer.
        hidden = features.view(-1, 16 * 5 * 5)
        for dense in (self.fc1, self.fc2):
            hidden = relu(dense(hidden))
        return self.fc3(hidden)


# Load CIFAR10 dataset
# ToTensor scales pixels to [0, 1]; Normalize with mean/std 0.5 then maps each
# channel to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Separate dataset by class.  Labels are read from `targets` so that the image
# transform is not executed on every sample just to look up its class.
class_indices = {i: [] for i in range(10)}  # CIFAR-10 has 10 classes
for idx, label in enumerate(full_dataset.targets):
    class_indices[int(label)].append(idx)

# Draw a random per-class image count whose total equals the number of images
# actually available (the previous fixed total of 60,000 exceeded the training
# split, so every class was clamped to its full size and nothing was sampled).
class_counts = np.random.multinomial(len(full_dataset), [0.1] * 10)  # Adjust probabilities if you want specific class biases
print("Random Images per Class:", class_counts)

# Sample indices based on the specified class counts
indices = []
for class_id, count in enumerate(class_counts):
    # A class cannot contribute more images than it has, so clamp the request.
    count = min(int(count), len(class_indices[class_id]))
    selected_indices = random.sample(class_indices[class_id], count)
    indices.extend(selected_indices)

# Create a custom CIFAR-10 dataset with the sampled indices
custom_dataset = Subset(full_dataset, indices)

# Split the *sampled* dataset (not the untouched full dataset) 80/20 into
# training and validation sets so the class sampling above takes effect.
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

# Create training, validation and test loaders
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective for one batch.

    The result is the sum of two scalars:
      * the squared-error reconstruction term between ``recon_x`` and ``x``,
        summed over every element (not averaged);
      * the KL divergence of the diagonal Gaussian described by ``mu`` and
        ``logvar`` from a standard normal, summed over every element.
    """
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_term = (1 + logvar - mu.pow(2) - logvar.exp()).sum() * -0.5
    return reconstruction_term + kl_term

# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after each epoch.

    Args:
        net: Classifier whose outputs are per-class scores.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``trainloader``.

    Side effects:
        * Prints one progress line per epoch.
        * Overwrites ``best_model.pth`` in the working directory whenever the
          validation accuracy exceeds the best value seen *within this call*.

    An empty ``trainloader`` or ``valloader`` no longer raises: the affected
    metric is reported as ``nan`` and no checkpoint is written for it.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0  # explicit counter: the loop variable is unbound if the loader is empty
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard both averages against empty loaders instead of crashing.
        avg_loss = running_loss / num_batches if num_batches else float("nan")
        val_acc = 100 * correct / total if total else float("nan")

        # A nan accuracy compares False here, so it never triggers a checkpoint.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score ``net`` on ``testloader`` and derive per-class metrics.

    The confusion matrix is printed as a side effect.

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)`` where
        ``accuracy`` is a percentage and every other element is an array with
        one entry per class.
    """
    net.eval()
    all_labels = []
    all_predictions = []
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in testloader:
            predicted = torch.max(net(images), dim=1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().tolist())
            all_predictions.extend(predicted.cpu().tolist())

    accuracy = 100 * correct / total

    # Rows of the matrix are true classes, columns are predicted classes.
    cm = confusion_matrix(all_labels, all_predictions)
    print(f"Confusion Matrix:\n{cm}")

    # Per-class counts read off the matrix: the diagonal holds the hits,
    # column sums give everything predicted as a class, row sums everything
    # that truly belongs to it.
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # average=None keeps one score per class instead of a single aggregate.
    precision = precision_score(all_labels, all_predictions, average=None)
    recall = recall_score(all_labels, all_predictions, average=None)
    f1 = f1_score(all_labels, all_predictions, average=None)

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into contiguous shards, one shuffled loader per client.

    Client ``k`` receives the index range
    ``[k * len(trainset) // num_clients, (k + 1) * len(trainset) // num_clients)``,
    so the shards are disjoint and together cover the whole dataset.

    Args:
        trainset: Indexable dataset with a length.
        transform: Accepted for interface compatibility; not used here.
        batch_size: Batch size of every client loader.
        num_clients: Number of shards / loaders to create.

    Returns:
        Dict mapping ``"client_<k>"`` to that client's ``DataLoader``.
    """

    def _shard(k):
        size = len(trainset)
        return range(k * size // num_clients, (k + 1) * size // num_clients)

    return {
        f"client_{k}": torch.utils.data.DataLoader(
            torch.utils.data.Subset(trainset, _shard(k)),
            batch_size=batch_size,
            shuffle=True,
        )
        for k in range(num_clients)
    }

def get_distribution_info(vae: VAE) -> Dict:
    """Build the distribution summary a client shares about its VAE.

    Both the ``"uniform"`` and the ``"normal"`` entry carry the same two
    quantities, read from the last encoder layer:
      * ``"mean"``: that layer's bias vector;
      * ``"std"``: ``exp(0.5 * weight)`` of that layer's weight matrix.

    NOTE(review): these are layer parameters, not statistics computed from
    encoded data, and ``"std"`` has the weight matrix's 2-D shape -- this is
    the placeholder implementation carried over as-is.
    """
    head = vae.encoder[-1]

    def _layer_stats():
        # Evaluated afresh on each call so the two entries get separate arrays.
        return {
            "mean": head.bias.data.cpu().numpy(),
            "std": torch.exp(0.5 * head.weight.data).cpu().numpy(),
        }

    return {"uniform": _layer_stats(), "normal": _layer_stats()}

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for publishing ``distribution_info`` to the global server.

    Intentionally does nothing.  A real implementation would hand the
    dictionary to some transport (for example a socket connection or a
    message queue such as RabbitMQ).
    """
    return None

def generate_augmented_data(vae: VAE, distribution_info_uniform: Dict, distribution_info_normal: Dict, num_samples: int = 64) -> torch.Tensor:
    """Sample latent vectors by averaging a normal and a uniform draw.

    Args:
        vae: Only ``vae.z_dim`` is read, to size the samples.
        distribution_info_uniform: Mapping with ``"mean"`` and ``"std"``.
            Each is collapsed to a single scalar (its overall mean); the
            uniform draw covers ``[mean - std, mean + std)``.
        distribution_info_normal: Mapping with ``"mean"`` and ``"std"`` that
            must broadcast against a ``(num_samples, z_dim)`` tensor; used to
            shift and scale standard-normal noise.
        num_samples: Number of latent vectors to draw.  Defaults to 64, the
            value that used to be hard-coded.

    Returns:
        Tensor of shape ``(num_samples, vae.z_dim)`` holding the element-wise
        mean of the two draws.
    """
    # Scalar bounds for the uniform draw.
    mean = distribution_info_uniform["mean"].mean().item()
    std = distribution_info_uniform["std"].mean().item()

    # Convert explicitly so the arithmetic below is tensor-with-tensor rather
    # than relying on implicit NumPy/PyTorch interoperability.
    mean_normal = torch.as_tensor(distribution_info_normal["mean"])
    std_normal = torch.as_tensor(distribution_info_normal["std"])

    # Normal draw first, then uniform, matching the original sampling order.
    augmented_data_normal = torch.randn(num_samples, vae.z_dim) * std_normal + mean_normal
    augmented_data_uniform = torch.empty(num_samples, vae.z_dim, dtype=torch.float32).uniform_(mean - std, mean + std)

    # Average the samples from both distributions.
    return (augmented_data_uniform + augmented_data_normal) / 2

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Run ``epochs`` rounds in which every client trains in turn.

    For each client, in dictionary order, one round does the following:
      1. trains ``vae`` on the client's loader for 10 epochs;
      2. extracts and "sends" the VAE distribution summary;
      3. "receives" a distribution summary and samples latent vectors from it;
      4. trains the shared ``net`` on the client's loader for 10 epochs,
         validating against ``valloader``;
      5. "sends" ``net``'s state dict.

    NOTE(review): the sampled latent vectors are not fed into step 4, the
    ``trainloader`` argument is never read, and all clients update the same
    ``net`` object sequentially -- preserved exactly as in the original.
    """
    for _ in range(epochs):
        for client_id, local_loader in trainloaders.items():
            # Local generative model and exchange of its summary.
            vae_train(vae, local_loader, epochs=10)
            send_distribution_info(get_distribution_info(vae))

            # Summary from the other participants, used to sample latents.
            peer_info = receive_distribution_info()
            generate_augmented_data(vae, peer_info["uniform"], peer_info["normal"])

            # Local classifier update, then report the new weights.
            train(net, local_loader, valloader, epochs=10)
            send_model_update(client_id, net.state_dict())

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return the distribution summary "received" from the global server.

    This is a stand-in that fabricates a standard summary: for both the
    ``"uniform"`` and the ``"normal"`` entry, ``"mean"`` is a zero vector and
    ``"std"`` a vector of ones.

    Args:
        z_dim: Length of every vector; should match the latent dimension of
            the VAE that consumes the result.  Defaults to 20, the size that
            used to be hard-coded.

    Returns:
        ``{"uniform": {"mean", "std"}, "normal": {"mean", "std"}}`` with four
        independent NumPy arrays of length ``z_dim``.
    """

    def _standard_stats():
        # Fresh arrays per call so the two entries never share memory.
        return {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}

    return {"uniform": _standard_stats(), "normal": _standard_stats()}

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for uploading a client's model weights to the global server.

    Intentionally does nothing with ``client_id`` or ``model_update``.  A real
    implementation would transmit them over some transport (for example a
    socket connection or a message queue such as RabbitMQ).
    """
    return None

# Define global server procedure
def global_server() -> None:
    """Run the whole experiment: build models, train across clients, evaluate.

    Relies on the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``.  Prints the test accuracy followed by
    the per-class confusion counts and precision/recall/F1 arrays.
    """
    # Classifier first, then the VAE, so parameter initialization happens in
    # the same order as before.
    net = Net()
    x_dim, h_dim, z_dim = 3 * 32 * 32, 400, 20  # CIFAR-10 input size, hidden size, latent size
    vae = VAE(x_dim, h_dim, z_dim)

    # Five clients, each with a contiguous shard of the training split.
    num_clients = 5
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    # Label/value pairs printed in order; the ":" row is the per-class sum of
    # all four confusion counts.
    report = (
        ("True Positives (TP):", tp),
        ("False Positives (FP):", fp),
        ("True Negatives (TN):", tn),
        ("False Negatives (FN):", fn),
        (":", fn+tn+tp+fp),
        ("Precision:", precision),
        ("Recall:", recall),
        ("F1 Score:", f1),
    )
    for label, values in report:
        print(label, values)

# Entry point: run the full federated experiment only when this code is
# executed as the main module, not when it is imported.
if __name__ == "__main__":
    global_server()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:07<00:00, 21.7MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Random Images per Class: [5915 5870 6025 6095 6076 6021 5977 5943 6058 6020]
Epoch [1/10], Training Loss: 2.304, Validation Accuracy: 9.99%
Epoch [2/10], Training Loss: 2.302, Validation Accuracy: 9.95%
Epoch [3/10], Training Loss: 2.301, Validation Accuracy: 10.47%
Epoch [4/10], Training Loss: 2.299, Validation Accuracy: 11.61%
Epoch [5/10], Training Loss: 2.297, Validation Accuracy: 13.41%
Epoch [6/10], Training Loss: 2.294, Validation Accuracy: 14.80%
Epoch [7/10], Training Loss: 2.291, Validation Accuracy: 15.18%
Epoch [8/10], Training Loss: 2.286, Validation Accuracy: 15.46%
Epoch [9/10], Training Loss: 2.280, Validation Accuracy: 15.18%
Epoch [10/10], Training Loss: 2.269, Validation Accuracy: 16.26%
Epoch [1/10], Training Loss: 2.259, Validation Accuracy: 16.59%
Epoch [2/10], Training Loss: 2.238, Validation Accuracy: 19.19%
Epoch [3/10], Training Loss: 2.208, Validation Accuracy: 19.83%
Epo