<a href="https://colab.research.google.com/github/Rahad31/Kl-FedDis-Research-/blob/main/Truncated.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from scipy.stats import truncnorm

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch import nn
from typing import Dict

In [None]:
# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 images.

    Args:
        x_dim: flattened input size. Stored as an attribute only; the layers
            below hard-code 3x32x32 inputs.
        h_dim: hidden size. Stored as an attribute only; not used by the layers.
        z_dim: latent dimensionality. The encoder emits ``2 * z_dim`` values per
            image, which :meth:`encode` splits into ``mu`` and ``logvar``.

    The decoder ends in ``Tanh`` so reconstructions lie in [-1, 1], the range of
    images after ``ToTensor()`` + ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``
    used for this notebook's data. A ``Sigmoid`` output is confined to [0, 1]
    and can never reconstruct the negative half of such targets.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: 3x32x32 -> 32x16x16 -> 64x8x8 -> 128x4x4 -> 2048 -> 256 -> 2*z_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),
        )

        # Decoder mirrors the encoder: z_dim -> 256 -> 2048 -> 128x4x4 -> ... -> 3x32x32
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            nn.Tanh(),  # match the [-1, 1] range of the normalized inputs
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``(batch, z_dim)`` to images ``(batch, 3, 32, 32)``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent, decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [None]:
# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int, lr: float = 1e-3) -> None:
    """Fit ``vae`` on the inputs yielded by ``trainloader`` by minimising ``vae_loss``.

    Args:
        vae: model trained in place.
        trainloader: yields ``(inputs, labels)`` batches; labels are ignored.
        epochs: number of full passes over ``trainloader``.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-3).
    """
    optimizer = torch.optim.Adam(vae.parameters(), lr=lr)
    # Put the model in training mode, as the later versions of this function do.
    vae.train()
    for _epoch in range(epochs):
        for inputs, _labels in trainloader:
            optimizer.zero_grad()
            recon_x, mu, logvar = vae(inputs)
            loss = vae_loss(recon_x, inputs, mu, logvar)
            loss.backward()
            optimizer.step()


In [None]:

# Define classification model
class Net(nn.Module):
    """LeNet-style CNN mapping 3x32x32 images to 10 class logits."""

    def __init__(self) -> None:
        super().__init__()
        # 3x32x32 -conv5-> 6x28x28 -pool-> 6x14x14
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        # 6x14x14 -conv5-> 16x10x10 -pool-> 16x5x5
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        relu = nn.functional.relu
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = relu(fc(x))
        return self.fc3(x)

In [None]:
# Load CIFAR-10. ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 and
# std 0.5 per channel then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

full_dataset = CIFAR10("./data", train=True, transform=transform, download=True)  # later split into train/val
test_set = CIFAR10("./data", train=False, transform=transform, download=True)  # held-out test split

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 48844265.98it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Split dataset 80/20 into training and validation sets.
# A seeded generator makes the split reproducible across kernel restarts.
SPLIT_SEED = 42
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)


In [None]:
# Create train/val/test loaders (batch size 128, 2 worker processes); only training data is shuffled.
trainloader, valloader, testloader = (
    DataLoader(dataset, batch_size=128, shuffle=is_training, num_workers=2)
    for dataset, is_training in ((train_set, True), (val_set, False), (test_set, False))
)


In [None]:
# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO of a diagonal-Gaussian VAE, summed over the batch.

    Returns the summed squared reconstruction error plus the closed-form
    KL divergence KL(N(mu, exp(logvar)) || N(0, I)).
    NOTE(review): the reconstruction term is MSE, not binary cross-entropy.
    """
    reconstruction_error = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_error + kl_divergence

In [None]:
# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int,
          save_path: str = "best_model.pth") -> None:
    """Train ``net`` with SGD + cross-entropy and checkpoint the best validation epoch.

    Each epoch makes one pass over ``trainloader``, then measures top-1 accuracy on
    ``valloader``. Whenever that beats the best accuracy seen so far in this call,
    the weights are written to ``save_path``. One progress line is printed per epoch.

    Args:
        net: classifier returning logits; trained in place.
        trainloader: yields ``(inputs, labels)`` training batches.
        valloader: yields ``(images, labels)`` validation batches.
        epochs: number of training epochs.
        save_path: checkpoint file (defaults to the previously hard-coded
            ``'best_model.pth'``).

    Empty loaders no longer raise (``NameError`` for an empty ``trainloader``,
    ``ZeroDivisionError`` for an empty ``valloader``). The affected figure is
    reported as ``nan`` instead.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training pass
        net.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Validation pass
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total if total else float("nan")
        # nan never compares greater, so an empty validation set never checkpoints
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), save_path)

        train_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {train_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

In [None]:
# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> float:
    """Return the top-1 accuracy (in percent) of ``net`` on ``testloader``.

    ``testloader`` yields ``(images, labels)`` batches. The model is put in eval
    mode and run without gradients. Returns ``nan`` when the loader yields no
    samples (previously this raised ``ZeroDivisionError``).
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return float("nan")
    return 100 * correct / total

In [None]:
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into ``num_clients`` contiguous shards, one shuffling loader each.

    With ``N = len(trainset)``, client ``k`` gets the indices
    ``[k * N // num_clients, (k + 1) * N // num_clients)``. Shard sizes therefore
    differ by at most one, and the last shard ends exactly at ``N``.
    ``transform`` is accepted but not used by this function.

    Returns:
        dict mapping ``"client_<k>"`` to that client's ``DataLoader``.
    """
    data_utils = torch.utils.data

    def shard(k):
        start = k * len(trainset) // num_clients
        stop = (k + 1) * len(trainset) // num_clients
        return data_utils.Subset(trainset, range(start, stop))

    return {
        f"client_{k}": data_utils.DataLoader(shard(k), batch_size=batch_size, shuffle=True)
        for k in range(num_clients)
    }

In [None]:
def get_distribution_info(vae: VAE, dataloader=None, max_batches=None) -> Dict:
    """Summarise the VAE's latent distribution as per-dimension ``mean``/``std`` arrays.

    ``VAE.encode`` splits the encoder output into ``mu`` (first ``z_dim`` values)
    and ``logvar`` (last ``z_dim`` values), so both returned arrays have shape
    ``(z_dim,)``. That is the shape ``generate_augmented_data`` samples with.

    Args:
        vae: trained VAE (only ``encoder``, ``encode`` and ``z_dim`` are used).
        dataloader: optional loader yielding ``(inputs, labels)``. When given,
            the moments of the aggregate posterior over its samples are returned:
            ``mean = E[mu]`` and ``std = sqrt(E[exp(logvar) + mu^2] - mean^2)``
            (law of total variance).
        max_batches: optional cap on the number of batches read from ``dataloader``.

    Without a ``dataloader``, a data-free proxy is returned: the ``(mu, std)`` the
    encoder emits when its last hidden activations are all zero, i.e. the final
    layer's bias split into its ``mu`` and ``logvar`` halves.
    (The previous version returned the raw 2*z_dim bias as "mean" and
    ``exp(0.5 * weight)``, a 2*z_dim x 256 matrix, as "std".)

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    z_dim = vae.z_dim
    if dataloader is None:
        bias = vae.encoder[-1].bias.detach()
        mean = bias[:z_dim]
        std = torch.exp(0.5 * bias[z_dim:])
        return {"mean": mean.cpu().numpy(), "std": std.cpu().numpy()}

    was_training = vae.training
    vae.eval()
    mu_sum = 0.0
    second_moment_sum = 0.0
    count = 0
    with torch.no_grad():
        for batch_idx, (inputs, _labels) in enumerate(dataloader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            mu, logvar = vae.encode(inputs)
            mu_sum = mu_sum + mu.sum(dim=0)
            second_moment_sum = second_moment_sum + (logvar.exp() + mu.pow(2)).sum(dim=0)
            count += mu.shape[0]
    vae.train(was_training)  # restore the caller's mode

    if count == 0:
        raise ValueError("dataloader yielded no samples")
    mean = mu_sum / count
    # clamp guards against tiny negative values from floating-point cancellation
    variance = (second_moment_sum / count - mean.pow(2)).clamp_min(0)
    return {"mean": mean.cpu().numpy(), "std": variance.sqrt().cpu().numpy()}

In [None]:
def send_distribution_info(distribution_info: Dict) -> None:
    """Publish this client's latent-distribution summary to the global server.

    Placeholder: no transport is implemented, so the call does nothing and
    returns ``None``. A real version could use the ``socket`` module or a
    message queue such as RabbitMQ.
    """
    return None


In [None]:
# Define logic to generate augmented data using Truncated Normal distribution
def generate_augmented_data(vae: VAE, distribution_info: Dict, num_samples: int = 64,
                            lower: float = 0.0, upper: float = np.inf,
                            random_state=None) -> torch.Tensor:
    """Sample latent vectors from a per-dimension truncated normal distribution.

    Args:
        vae: only ``vae.z_dim`` is used, to size the samples.
        distribution_info: dict with ``"mean"`` and ``"std"`` arrays broadcastable
            to ``(z_dim,)``.
        num_samples: number of vectors to draw (previously fixed at 64).
        lower, upper: truncation bounds in data units, applied to every dimension.
            The defaults reproduce the original behaviour (truncated below at 0,
            unbounded above).
        random_state: forwarded to ``scipy.stats.truncnorm.rvs`` for reproducibility.

    Returns:
        float32 tensor of shape ``(num_samples, vae.z_dim)``. These are latent
        codes, not images; pass them through ``vae.decode`` to obtain pixels.

    Raises:
        ValueError: if any std is not strictly positive (the standardized
            bounds would otherwise divide by zero).
    """
    mean = np.asarray(distribution_info["mean"], dtype=np.float64)
    std = np.asarray(distribution_info["std"], dtype=np.float64)
    if np.any(std <= 0):
        raise ValueError("distribution_info['std'] must be strictly positive")
    # truncnorm expects its bounds in standardized units: (bound - loc) / scale
    a = (lower - mean) / std
    b = (upper - mean) / std
    samples = truncnorm.rvs(a, b, loc=mean, scale=std, size=(num_samples, vae.z_dim),
                            random_state=random_state)
    return torch.from_numpy(samples).float()

In [None]:
def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int, vae_epochs: int = 10, local_epochs: int = 10) -> None:
    """Run ``epochs`` rounds of the simulated FedDIS loop over every client.

    For each client in turn:
    1. fit the shared ``vae`` on the client's data;
    2. publish a latent summary;
    3. fetch the other clients' summary and sample latent codes from it;
    4. fine-tune the shared ``net`` on the client's data (validated on ``valloader``);
    5. publish the model weights.

    Args:
        net: classifier, trained in place.
        vae: VAE, trained in place.
        trainloaders: mapping ``client_id -> DataLoader`` of that client's shard.
        trainloader: currently unused; kept for interface compatibility.
        valloader: validation loader passed to ``train``.
        epochs: number of federated rounds.
        vae_epochs: VAE epochs per client per round (previously hard-coded to 10).
        local_epochs: classifier epochs per client per round (previously hard-coded to 10).

    NOTE(review): a single ``net``/``vae`` object is trained sequentially across
    clients and nothing is aggregated here (the send/receive helpers are stubs),
    so this is sequential fine-tuning rather than federated averaging.
    """
    for _round in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Train VAE on client data
            vae_train(vae, client_trainloader, epochs=vae_epochs)

            # Share distribution information with global server
            distribution_info = get_distribution_info(vae)
            send_distribution_info(distribution_info)

            # Receive distribution information from other clients
            other_distribution_info = receive_distribution_info()

            # Sample latent codes from the received distribution.
            # TODO: not yet used for training. They are latent vectors, so they would
            # need decoding (vae.decode) and labels before they can augment `train`.
            augmented_data = generate_augmented_data(vae, other_distribution_info)

            # Train classification model on local data, validating on valloader
            train(net, client_trainloader, valloader, epochs=local_epochs)

            # Send model updates to global server
            send_model_update(client_id, net.state_dict())

In [None]:
# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Placeholder for receiving a latent-distribution summary from the global server.

    No communication is implemented: returns a standard-normal description
    (mean 0, std 1) for each of ``z_dim`` latent dimensions. ``z_dim`` must match
    the VAE's latent size; the default 20 matches ``global_server``.
    """
    # TODO: replace with a real receive (socket, message queue, ...).
    return {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}

In [None]:
def send_model_update(client_id: str, model_update: Dict) -> None:
    """Publish ``client_id``'s model weights (a state dict) to the global server.

    Placeholder: no transport is implemented, so the call does nothing and
    returns ``None``. A real version could use the ``socket`` module or a
    message queue such as RabbitMQ.
    """
    return None

In [None]:
# Define global server procedure
def global_server(num_clients: int = 5, rounds: int = 10, z_dim: int = 20, batch_size: int = 128) -> None:
    """Build the models, shard ``train_set`` across clients, run FedDIS and report test accuracy.

    Relies on the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader`` defined earlier in the notebook.

    Args:
        num_clients: number of simulated clients (previously fixed at 5).
        rounds: federated rounds passed to ``federated_train`` (previously 10).
        z_dim: VAE latent size; should match ``receive_distribution_info``'s size (20).
        batch_size: per-client loader batch size (previously 128).
    """
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size (stored on the VAE; its layers hard-code 3x32x32)
    h_dim = 400  # stored on the VAE only
    vae = VAE(x_dim, h_dim, z_dim)

    clients = initialize_clients(train_set, transform, batch_size=batch_size, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=rounds)

    # A later cell redefines `evaluate` to return (accuracy, tp, fp, tn, fn,
    # precision, recall, f1); accept either that tuple or a bare accuracy.
    result = evaluate(net, testloader)
    test_accuracy = result[0] if isinstance(result, tuple) else result
    print(f"Test Accuracy: {test_accuracy:.2f}%")

# In a notebook kernel __name__ is "__main__", so executing this cell starts training.
if __name__ == "__main__":
    global_server()


Epoch [1/10], Training Loss: 2.305, Validation Accuracy: 9.84%
Epoch [2/10], Training Loss: 2.304, Validation Accuracy: 10.01%
Epoch [3/10], Training Loss: 2.303, Validation Accuracy: 10.99%
Epoch [4/10], Training Loss: 2.301, Validation Accuracy: 11.45%
Epoch [5/10], Training Loss: 2.300, Validation Accuracy: 11.84%
Epoch [6/10], Training Loss: 2.299, Validation Accuracy: 12.89%
Epoch [7/10], Training Loss: 2.297, Validation Accuracy: 13.90%
Epoch [8/10], Training Loss: 2.295, Validation Accuracy: 15.55%
Epoch [9/10], Training Loss: 2.292, Validation Accuracy: 16.91%
Epoch [10/10], Training Loss: 2.288, Validation Accuracy: 17.83%
Epoch [1/10], Training Loss: 2.285, Validation Accuracy: 18.46%
Epoch [2/10], Training Loss: 2.277, Validation Accuracy: 18.37%
Epoch [3/10], Training Loss: 2.265, Validation Accuracy: 20.00%
Epoch [4/10], Training Loss: 2.245, Validation Accuracy: 23.10%
Epoch [5/10], Training Loss: 2.211, Validation Accuracy: 23.20%
Epoch [6/10], Training Loss: 2.162, Vali

In [None]:
from scipy.stats import truncnorm
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 images.

    Args:
        x_dim: flattened input size. Stored as an attribute only; the layers
            below hard-code 3x32x32 inputs.
        h_dim: hidden size. Stored as an attribute only; not used by the layers.
        z_dim: latent dimensionality. The encoder emits ``2 * z_dim`` values per
            image, which :meth:`encode` splits into ``mu`` and ``logvar``.

    The decoder ends in ``Tanh`` so reconstructions lie in [-1, 1], the range of
    images after ``ToTensor()`` + ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``
    used for this notebook's data. A ``Sigmoid`` output is confined to [0, 1]
    and can never reconstruct the negative half of such targets.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: 3x32x32 -> 32x16x16 -> 64x8x8 -> 128x4x4 -> 2048 -> 256 -> 2*z_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),
        )

        # Decoder mirrors the encoder: z_dim -> 256 -> 2048 -> 128x4x4 -> ... -> 3x32x32
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            nn.Tanh(),  # match the [-1, 1] range of the normalized inputs
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``(batch, z_dim)`` to images ``(batch, 3, 32, 32)``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent, decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on the inputs from ``trainloader`` by minimising ``vae_loss``.

    Uses Adam (lr=1e-3) with a StepLR schedule that halves the learning rate
    every 20 epochs. Labels yielded by the loader are ignored. The model is
    switched to training mode at the start of every epoch.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    lr_schedule = torch.optim.lr_scheduler.StepLR(adam, step_size=20, gamma=0.5)

    for _epoch in range(epochs):
        vae.train()
        for batch in trainloader:
            images, _labels = batch
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            vae_loss(reconstruction, images, mu, logvar).backward()
            adam.step()
        # NOTE(review): when a call runs fewer than 20 epochs (federated_train passes 10) the decay never fires.
        lr_schedule.step()


# Define classification model
class Net(nn.Module):
    """LeNet-style CNN mapping 3x32x32 images to 10 class logits."""

    def __init__(self) -> None:
        super().__init__()
        # 3x32x32 -conv5-> 6x28x28 -pool-> 6x14x14
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        # 6x14x14 -conv5-> 16x10x10 -pool-> 16x5x5
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        relu = nn.functional.relu
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = relu(fc(x))
        return self.fc3(x)


# Load CIFAR10 dataset.
# ToTensor gives [0, 1]; Normalize with mean 0.5 / std 0.5 per channel rescales to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Split dataset 80/20 into training and validation sets.
# A seeded generator makes the split reproducible across kernel restarts.
SPLIT_SEED = 42
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Create training, validation and test loaders; only the training loader shuffles.
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO of a diagonal-Gaussian VAE, summed over the batch.

    Returns the summed squared reconstruction error plus the closed-form
    KL divergence KL(N(mu, exp(logvar)) || N(0, I)).
    NOTE(review): the reconstruction term is MSE, not binary cross-entropy.
    """
    reconstruction_error = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_error + kl_divergence

# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int,
          save_path: str = "best_model.pth") -> None:
    """Train ``net`` with SGD + cross-entropy and checkpoint the best validation epoch.

    Each epoch makes one pass over ``trainloader``, steps a StepLR schedule (halves
    the learning rate every 20 epochs), then measures top-1 accuracy on
    ``valloader``. Whenever that beats the best accuracy seen so far in this call,
    the weights are written to ``save_path``. One progress line is printed per epoch.

    Args:
        net: classifier returning logits; trained in place.
        trainloader: yields ``(inputs, labels)`` training batches.
        valloader: yields ``(images, labels)`` validation batches.
        epochs: number of training epochs.
        save_path: checkpoint file (defaults to the previously hard-coded
            ``'best_model.pth'``).

    Empty loaders no longer raise (``NameError`` for an empty ``trainloader``,
    ``ZeroDivisionError`` for an empty ``valloader``). The affected figure is
    reported as ``nan`` instead.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training pass
        net.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation pass
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total if total else float("nan")
        # nan never compares greater, so an empty validation set never checkpoints
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), save_path)

        train_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {train_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")




# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> float:
    """Return the top-1 accuracy (in percent) of ``net`` on ``testloader``.

    ``testloader`` yields ``(images, labels)`` batches. The model is put in eval
    mode and run without gradients. Returns ``nan`` when the loader yields no
    samples (previously this raised ``ZeroDivisionError``).
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return float("nan")
    return 100 * correct / total




# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into ``num_clients`` contiguous shards, one shuffling loader each.

    With ``N = len(trainset)``, client ``k`` gets the indices
    ``[k * N // num_clients, (k + 1) * N // num_clients)``. Shard sizes therefore
    differ by at most one, and the last shard ends exactly at ``N``.
    ``transform`` is accepted but not used by this function.

    Returns:
        dict mapping ``"client_<k>"`` to that client's ``DataLoader``.
    """
    data_utils = torch.utils.data

    def shard(k):
        start = k * len(trainset) // num_clients
        stop = (k + 1) * len(trainset) // num_clients
        return data_utils.Subset(trainset, range(start, stop))

    return {
        f"client_{k}": data_utils.DataLoader(shard(k), batch_size=batch_size, shuffle=True)
        for k in range(num_clients)
    }

def get_distribution_info(vae: VAE, dataloader=None, max_batches=None) -> Dict:
    """Summarise the VAE's latent distribution as per-dimension ``mean``/``std`` arrays.

    ``VAE.encode`` splits the encoder output into ``mu`` (first ``z_dim`` values)
    and ``logvar`` (last ``z_dim`` values), so both returned arrays have shape
    ``(z_dim,)``. That is the shape ``generate_augmented_data`` samples with.

    Args:
        vae: trained VAE (only ``encoder``, ``encode`` and ``z_dim`` are used).
        dataloader: optional loader yielding ``(inputs, labels)``. When given,
            the moments of the aggregate posterior over its samples are returned:
            ``mean = E[mu]`` and ``std = sqrt(E[exp(logvar) + mu^2] - mean^2)``
            (law of total variance).
        max_batches: optional cap on the number of batches read from ``dataloader``.

    Without a ``dataloader``, a data-free proxy is returned: the ``(mu, std)`` the
    encoder emits when its last hidden activations are all zero, i.e. the final
    layer's bias split into its ``mu`` and ``logvar`` halves.
    (The previous version returned the raw 2*z_dim bias as "mean" and
    ``exp(0.5 * weight)``, a 2*z_dim x 256 matrix, as "std".)

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    z_dim = vae.z_dim
    if dataloader is None:
        bias = vae.encoder[-1].bias.detach()
        mean = bias[:z_dim]
        std = torch.exp(0.5 * bias[z_dim:])
        return {"mean": mean.cpu().numpy(), "std": std.cpu().numpy()}

    was_training = vae.training
    vae.eval()
    mu_sum = 0.0
    second_moment_sum = 0.0
    count = 0
    with torch.no_grad():
        for batch_idx, (inputs, _labels) in enumerate(dataloader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            mu, logvar = vae.encode(inputs)
            mu_sum = mu_sum + mu.sum(dim=0)
            second_moment_sum = second_moment_sum + (logvar.exp() + mu.pow(2)).sum(dim=0)
            count += mu.shape[0]
    vae.train(was_training)  # restore the caller's mode

    if count == 0:
        raise ValueError("dataloader yielded no samples")
    mean = mu_sum / count
    # clamp guards against tiny negative values from floating-point cancellation
    variance = (second_moment_sum / count - mean.pow(2)).clamp_min(0)
    return {"mean": mean.cpu().numpy(), "std": variance.sqrt().cpu().numpy()}

def send_distribution_info(distribution_info: Dict) -> None:
    """Publish this client's latent-distribution summary to the global server.

    Placeholder: no transport is implemented, so the call does nothing and
    returns ``None``. A real version could use the ``socket`` module or a
    message queue such as RabbitMQ.
    """
    return None


# Define logic to generate augmented data using Truncated Normal distribution
def generate_augmented_data(vae: VAE, distribution_info: Dict, num_samples: int = 64,
                            lower: float = 0.0, upper: float = np.inf,
                            random_state=None) -> torch.Tensor:
    """Sample latent vectors from a per-dimension truncated normal distribution.

    Args:
        vae: only ``vae.z_dim`` is used, to size the samples.
        distribution_info: dict with ``"mean"`` and ``"std"`` arrays broadcastable
            to ``(z_dim,)``.
        num_samples: number of vectors to draw (previously fixed at 64).
        lower, upper: truncation bounds in data units, applied to every dimension.
            The defaults reproduce the original behaviour (truncated below at 0,
            unbounded above).
        random_state: forwarded to ``scipy.stats.truncnorm.rvs`` for reproducibility.

    Returns:
        float32 tensor of shape ``(num_samples, vae.z_dim)``. These are latent
        codes, not images; pass them through ``vae.decode`` to obtain pixels.

    Raises:
        ValueError: if any std is not strictly positive (the standardized
            bounds would otherwise divide by zero).
    """
    mean = np.asarray(distribution_info["mean"], dtype=np.float64)
    std = np.asarray(distribution_info["std"], dtype=np.float64)
    if np.any(std <= 0):
        raise ValueError("distribution_info['std'] must be strictly positive")
    # truncnorm expects its bounds in standardized units: (bound - loc) / scale
    a = (lower - mean) / std
    b = (upper - mean) / std
    samples = truncnorm.rvs(a, b, loc=mean, scale=std, size=(num_samples, vae.z_dim),
                            random_state=random_state)
    return torch.from_numpy(samples).float()





def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int, vae_epochs: int = 10, local_epochs: int = 10) -> None:
    """Run ``epochs`` rounds of the simulated FedDIS loop over every client.

    For each client in turn:
    1. fit the shared ``vae`` on the client's data;
    2. publish a latent summary;
    3. fetch the other clients' summary and sample latent codes from it;
    4. fine-tune the shared ``net`` on the client's data (validated on ``valloader``);
    5. publish the model weights.

    Args:
        net: classifier, trained in place.
        vae: VAE, trained in place.
        trainloaders: mapping ``client_id -> DataLoader`` of that client's shard.
        trainloader: currently unused; kept for interface compatibility.
        valloader: validation loader passed to ``train``.
        epochs: number of federated rounds.
        vae_epochs: VAE epochs per client per round (previously hard-coded to 10).
        local_epochs: classifier epochs per client per round (previously hard-coded to 10).

    NOTE(review): a single ``net``/``vae`` object is trained sequentially across
    clients and nothing is aggregated here (the send/receive helpers are stubs),
    so this is sequential fine-tuning rather than federated averaging.
    """
    for _round in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Train VAE on client data
            vae_train(vae, client_trainloader, epochs=vae_epochs)

            # Share distribution information with global server
            distribution_info = get_distribution_info(vae)
            send_distribution_info(distribution_info)

            # Receive distribution information from other clients
            other_distribution_info = receive_distribution_info()

            # Sample latent codes from the received distribution.
            # TODO: not yet used for training. They are latent vectors, so they would
            # need decoding (vae.decode) and labels before they can augment `train`.
            augmented_data = generate_augmented_data(vae, other_distribution_info)

            # Train classification model on local data, validating on valloader
            train(net, client_trainloader, valloader, epochs=local_epochs)

            # Send model updates to global server
            send_model_update(client_id, net.state_dict())


def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Placeholder for receiving a latent-distribution summary from the global server.

    No communication is implemented: returns a standard-normal description
    (mean 0, std 1) for each of ``z_dim`` latent dimensions. ``z_dim`` must match
    the VAE's latent size; the default 20 matches ``global_server``.
    """
    # TODO: replace with a real receive (socket, message queue, ...).
    return {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Publish ``client_id``'s model weights (a state dict) to the global server.

    Placeholder: no transport is implemented, so the call does nothing and
    returns ``None``. A real version could use the ``socket`` module or a
    message queue such as RabbitMQ.
    """
    return None

# Define global server procedure
def global_server(num_clients: int = 5, rounds: int = 10, z_dim: int = 20, batch_size: int = 128) -> None:
    """Build the models, shard ``train_set`` across clients, run FedDIS and report test accuracy.

    Relies on the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader`` defined earlier in the notebook.

    Args:
        num_clients: number of simulated clients (previously fixed at 5).
        rounds: federated rounds passed to ``federated_train`` (previously 10).
        z_dim: VAE latent size; should match ``receive_distribution_info``'s size (20).
        batch_size: per-client loader batch size (previously 128).
    """
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size (stored on the VAE; its layers hard-code 3x32x32)
    h_dim = 400  # stored on the VAE only
    vae = VAE(x_dim, h_dim, z_dim)

    clients = initialize_clients(train_set, transform, batch_size=batch_size, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=rounds)

    # A later cell redefines `evaluate` to return (accuracy, tp, fp, tn, fn,
    # precision, recall, f1); accept either that tuple or a bare accuracy.
    result = evaluate(net, testloader)
    test_accuracy = result[0] if isinstance(result, tuple) else result
    print(f"Test Accuracy: {test_accuracy:.2f}%")

# In a notebook kernel __name__ is "__main__", so executing this cell starts training.
if __name__ == "__main__":
    global_server()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 74705682.27it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch [1/10], Training Loss: 2.305, Validation Accuracy: 10.19%
Epoch [2/10], Training Loss: 2.305, Validation Accuracy: 10.19%
Epoch [3/10], Training Loss: 2.304, Validation Accuracy: 10.19%
Epoch [4/10], Training Loss: 2.303, Validation Accuracy: 10.19%
Epoch [5/10], Training Loss: 2.302, Validation Accuracy: 10.19%
Epoch [6/10], Training Loss: 2.302, Validation Accuracy: 10.19%
Epoch [7/10], Training Loss: 2.301, Validation Accuracy: 10.19%
Epoch [8/10], Training Loss: 2.300, Validation Accuracy: 10.19%
Epoch [9/10], Training Loss: 2.299, Validation Accuracy: 10.19%
Epoch [10/10], Training Loss: 2.298, Validation Accuracy: 10.19%
Epoch [1/10], Training Loss: 2.298, Validation Accuracy: 10.23%
Epoch [2/10], Training Loss: 2.296, Validation Accuracy: 11.73%
Epoch [3/10], Training Loss: 2.293, Validation Accuracy: 13.83%
Epoch [4/10], Training Loss: 2.290, Validation Accuracy: 15.32%
Epoch [5/10], 

In [None]:
from scipy.stats import truncnorm
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 images.

    Args:
        x_dim: flattened input size. Stored as an attribute only; the layers
            below hard-code 3x32x32 inputs.
        h_dim: hidden size. Stored as an attribute only; not used by the layers.
        z_dim: latent dimensionality. The encoder emits ``2 * z_dim`` values per
            image, which :meth:`encode` splits into ``mu`` and ``logvar``.

    The decoder ends in ``Tanh`` so reconstructions lie in [-1, 1], the range of
    images after ``ToTensor()`` + ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``
    used for this notebook's data. A ``Sigmoid`` output is confined to [0, 1]
    and can never reconstruct the negative half of such targets.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: 3x32x32 -> 32x16x16 -> 64x8x8 -> 128x4x4 -> 2048 -> 256 -> 2*z_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),
        )

        # Decoder mirrors the encoder: z_dim -> 256 -> 2048 -> 128x4x4 -> ... -> 3x32x32
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            nn.Tanh(),  # match the [-1, 1] range of the normalized inputs
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``(batch, z_dim)`` to images ``(batch, 3, 32, 32)``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent, decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on the inputs from ``trainloader`` by minimising ``vae_loss``.

    Uses Adam (lr=1e-3) with a StepLR schedule that halves the learning rate
    every 20 epochs. Labels yielded by the loader are ignored. The model is
    switched to training mode at the start of every epoch.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    lr_schedule = torch.optim.lr_scheduler.StepLR(adam, step_size=20, gamma=0.5)

    for _epoch in range(epochs):
        vae.train()
        for batch in trainloader:
            images, _labels = batch
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            vae_loss(reconstruction, images, mu, logvar).backward()
            adam.step()
        # NOTE(review): when a call runs fewer than 20 epochs the decay never fires.
        lr_schedule.step()


# Define classification model
class Net(nn.Module):
    """LeNet-style CNN mapping 3x32x32 images to 10 class logits."""

    def __init__(self) -> None:
        super().__init__()
        # 3x32x32 -conv5-> 6x28x28 -pool-> 6x14x14
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        # 6x14x14 -conv5-> 16x10x10 -pool-> 16x5x5
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        relu = nn.functional.relu
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = relu(fc(x))
        return self.fc3(x)


# Load CIFAR10 dataset.
# ToTensor gives [0, 1]; Normalize with mean 0.5 / std 0.5 per channel rescales to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Split dataset 80/20 into training and validation sets.
# A seeded generator makes the split reproducible across kernel restarts.
SPLIT_SEED = 42
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Create training, validation and test loaders; only the training loader shuffles.
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO of a diagonal-Gaussian VAE, summed over the batch.

    Returns the summed squared reconstruction error plus the closed-form
    KL divergence KL(N(mu, exp(logvar)) || N(0, I)).
    NOTE(review): the reconstruction term is MSE, not binary cross-entropy.
    """
    reconstruction_error = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_error + kl_divergence

# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int,
          save_path: str = "best_model.pth") -> None:
    """Train ``net`` with SGD + cross-entropy and checkpoint the best validation epoch.

    Each epoch makes one pass over ``trainloader``, steps a StepLR schedule (halves
    the learning rate every 20 epochs), then measures top-1 accuracy on
    ``valloader``. Whenever that beats the best accuracy seen so far in this call,
    the weights are written to ``save_path``. One progress line is printed per epoch.

    Args:
        net: classifier returning logits; trained in place.
        trainloader: yields ``(inputs, labels)`` training batches.
        valloader: yields ``(images, labels)`` validation batches.
        epochs: number of training epochs.
        save_path: checkpoint file (defaults to the previously hard-coded
            ``'best_model.pth'``).

    Empty loaders no longer raise (``NameError`` for an empty ``trainloader``,
    ``ZeroDivisionError`` for an empty ``valloader``). The affected figure is
    reported as ``nan`` instead.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training pass
        net.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation pass
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total if total else float("nan")
        # nan never compares greater, so an empty validation set never checkpoints
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), save_path)

        train_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {train_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def _safe_ratio(numerator, denominator):
    """Elementwise ``numerator / denominator`` as float64, with 0 where the denominator is 0."""
    return np.divide(numerator, denominator,
                     out=np.zeros(len(numerator), dtype=np.float64),
                     where=denominator > 0)


def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate ``net`` on ``testloader`` and derive per-class confusion statistics.

    The confusion matrix (rows = true class, columns = predicted class) is printed.

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)``. ``accuracy`` is a
        percentage (``nan`` if the loader is empty). Every other entry is an array
        indexed by class id, of length ``num_classes``: the model's output width,
        or the largest label + 1 if that is larger. A class that never appears in
        the labels or predictions therefore still gets its own entry; previously
        the matrix silently shrank and misaligned class ids. Ratios with a zero
        denominator are reported as 0, and f1 is computed as
        ``2*tp / (2*tp + fp + fn)``.
    """
    net.eval()
    label_batches = []
    prediction_batches = []
    num_classes = 0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            num_classes = max(num_classes, outputs.shape[1])
            label_batches.append(labels.cpu().numpy())
            prediction_batches.append(outputs.argmax(dim=1).cpu().numpy())

    empty = np.zeros(0, dtype=np.int64)
    all_labels = np.concatenate(label_batches).astype(np.int64) if label_batches else empty
    all_predictions = np.concatenate(prediction_batches).astype(np.int64) if prediction_batches else empty
    total = all_labels.size
    if total:
        num_classes = max(num_classes, int(all_labels.max()) + 1)
    accuracy = 100 * int((all_labels == all_predictions).sum()) / total if total else float("nan")

    # Confusion matrix over the fixed label set 0..num_classes-1
    cm = np.bincount(all_labels * num_classes + all_predictions,
                     minlength=num_classes * num_classes).reshape(num_classes, num_classes)
    print(f"Confusion Matrix:\n{cm}")

    # Per-class TP, FP, TN, FN
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # Per-class precision, recall and F1, derived from the same confusion matrix
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1 = _safe_ratio(2 * tp, 2 * tp + fp + fn)

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into ``num_clients`` contiguous shards, one DataLoader each.

    Shard ``k`` covers indices ``[k*N//num_clients, (k+1)*N//num_clients)``, with
    ``N = len(trainset)``. The shards together cover the whole dataset; later shards
    absorb the rounding remainder.

    Args:
        trainset: indexable dataset to partition.
        transform: accepted for interface compatibility; not used here.
        batch_size: batch size of every client's DataLoader.
        num_clients: number of shards; ``0`` (or fewer) yields an empty dict.

    Returns:
        dict mapping ``"client_<k>"`` to a shuffling DataLoader over shard ``k``.
    """
    def shard_loader(idx):
        # Bounds are recomputed per shard so no division happens when there are no clients.
        n_samples = len(trainset)
        start = idx * n_samples // num_clients
        stop = (idx + 1) * n_samples // num_clients
        shard = torch.utils.data.Subset(trainset, range(start, stop))
        return torch.utils.data.DataLoader(shard, batch_size=batch_size, shuffle=True)

    return {f"client_{idx}": shard_loader(idx) for idx in range(num_clients)}

def get_distribution_info(vae: "VAE", dataloader=None) -> Dict:
    """Summarise the VAE's latent distribution as per-dimension ``mean`` / ``std`` arrays.

    The encoder's last layer emits ``2 * z_dim`` values. ``VAE.encode`` splits them
    into ``mu`` (first ``z_dim``) and ``logvar`` (last ``z_dim``). Both statistics
    below follow that split, so each returned array has exactly ``z_dim`` entries.

    Args:
        vae: model exposing ``z_dim``, an indexable ``encoder`` whose last module has
            a ``bias``, and ``encode(x) -> (mu, logvar)``.
        dataloader: optional iterable of batches, each either a tensor of inputs or a
            ``(inputs, ...)`` list/tuple.
            - When given, the result is the first two moments of the aggregate
              posterior over those samples.
            - When ``None``, a data-free estimate is used: the last layer's bias,
              i.e. the encoder output for a zero hidden activation.

    Returns:
        ``{"mean": ndarray, "std": ndarray}``, both float32 with shape ``(z_dim,)``.

    Raises:
        ValueError: if ``dataloader`` is given but yields no samples.
    """
    z_dim = vae.z_dim

    if dataloader is None:
        bias = vae.encoder[-1].bias.detach().cpu()
        mu, logvar = bias[:z_dim], bias[z_dim:]
        return {"mean": mu.numpy(), "std": torch.exp(0.5 * logvar).numpy()}

    device = next(vae.parameters()).device
    mu_sum = torch.zeros(z_dim, dtype=torch.float64)
    mu_sq_sum = torch.zeros(z_dim, dtype=torch.float64)
    var_sum = torch.zeros(z_dim, dtype=torch.float64)
    count = 0
    with torch.no_grad():
        for batch in dataloader:
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            mu, logvar = vae.encode(images.to(device))
            mu = mu.detach().cpu().double()
            var = logvar.detach().cpu().double().exp()
            mu_sum += mu.sum(dim=0)
            mu_sq_sum += (mu * mu).sum(dim=0)
            var_sum += var.sum(dim=0)
            count += mu.shape[0]

    if count == 0:
        raise ValueError("dataloader yielded no samples")

    mean = mu_sum / count
    # Law of total variance for z ~ N(mu(x), sigma^2(x)) over the data:
    # Var[z] = E[sigma^2(x)] + Var[mu(x)]. Clamp guards tiny negative round-off.
    mu_var = (mu_sq_sum / count - mean * mean).clamp(min=0.0)
    std = torch.sqrt(var_sum / count + mu_var)
    return {"mean": mean.float().numpy(), "std": std.float().numpy()}

def send_distribution_info(distribution_info: Dict) -> None:
    """Publish a client's latent-distribution summary to the global server.

    Placeholder: no transport exists yet. A socket connection or a message queue
    such as RabbitMQ would go here. The summary is accepted and ignored, and the
    call returns ``None``.

    Args:
        distribution_info: summary dict (as produced by ``get_distribution_info``).
    """
    return None

# Define logic to generate augmented data using Truncated Normal distribution
def generate_augmented_data(vae: "VAE", distribution_info: Dict, num_samples: int = 64, lower: float = 0.0, upper: float = np.inf, random_state=None) -> torch.Tensor:
    """Sample latent codes from a per-dimension truncated normal distribution.

    Latent dimension ``j`` is drawn from ``N(mean[j], std[j]**2)`` truncated to
    ``[lower, upper]``. The defaults ``[0, inf)`` match the original behaviour. The
    result holds latent vectors of width ``vae.z_dim``; it is not decoded.

    Args:
        vae: only ``vae.z_dim`` is read.
        distribution_info: dict with ``"mean"`` and ``"std"``, each broadcastable to
            ``(num_samples, vae.z_dim)``.
        num_samples: number of latent vectors (rows) to draw.
        lower: lower truncation bound, in latent units.
        upper: upper truncation bound, in latent units.
        random_state: seed or numpy Generator forwarded to scipy. ``None`` uses
            numpy's global RNG, as before.

    Returns:
        float32 tensor of shape ``(num_samples, vae.z_dim)``.

    Raises:
        ValueError: if any ``std`` entry is not strictly positive (including NaN).
    """
    mean = np.asarray(distribution_info["mean"], dtype=float)
    std = np.asarray(distribution_info["std"], dtype=float)
    if not np.all(std > 0):
        raise ValueError("distribution_info['std'] must be strictly positive")
    # scipy's truncnorm takes its bounds in standard-deviation units around ``loc``.
    a = (lower - mean) / std
    b = (upper - mean) / std
    samples = truncnorm.rvs(a, b, loc=mean, scale=std, size=(num_samples, vae.z_dim), random_state=random_state)
    return torch.from_numpy(samples).float()

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int, vae_epochs: int = 10, local_epochs: int = 10) -> None:
    """Run ``epochs`` rounds of per-client VAE and classifier training.

    In every round, each client in ``trainloaders`` is visited in turn. For each client:
    1. train the shared ``vae`` on that client's data;
    2. publish the VAE's distribution summary;
    3. fetch a summary from the server;
    4. sample latent codes from it;
    5. train the shared ``net`` on the client's loader (validating on ``valloader``);
    6. publish the resulting weights.

    NOTE(review): ``net`` and ``vae`` are single objects updated in place by every
    client one after another. Nothing in this function averages client updates, so
    as written this is sequential fine-tuning rather than FedAvg-style aggregation.
    ``trainloader`` is accepted but not used.

    Args:
        net: classifier trained by ``train``.
        vae: VAE trained by ``vae_train``.
        trainloaders: client id -> that client's training DataLoader.
        trainloader: unused; kept for interface compatibility.
        valloader: validation loader forwarded to ``train``.
        epochs: number of outer rounds over all clients.
        vae_epochs: epochs per ``vae_train`` call (previously fixed at 10).
        local_epochs: epochs per ``train`` call (previously fixed at 10).
    """
    for _round in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Train VAE on client data
            vae_train(vae, client_trainloader, epochs=vae_epochs)

            # Share distribution information with global server
            distribution_info = get_distribution_info(vae)
            send_distribution_info(distribution_info)

            # Receive distribution information from other clients
            other_distribution_info = receive_distribution_info()

            # Generate augmented data using received distribution information.
            # TODO: these latent samples are not yet passed to ``train``; they would
            # need decoding and labels before they can augment the classifier's data.
            augmented_data = generate_augmented_data(vae, other_distribution_info)

            # Train classification model on the client's local data (validated on valloader)
            train(net, client_trainloader, valloader, epochs=local_epochs)

            # Send model updates to global server
            send_model_update(client_id, net.state_dict())

def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Fetch a latent-distribution summary (from other clients) via the global server.

    Placeholder: no transport exists yet. A socket connection or a message queue
    such as RabbitMQ would go here. Until then it returns a standard-normal summary:
    mean 0 and std 1 in every latent dimension.

    Args:
        z_dim: number of latent dimensions. Defaults to 20, the value the original
            hard-coded (and the ``z_dim`` that ``global_server`` builds the VAE with).

    Returns:
        ``{"mean": np.zeros(z_dim), "std": np.ones(z_dim)}`` (float64 arrays).
    """
    return {
        "mean": np.zeros(z_dim),
        "std": np.ones(z_dim),
    }

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Publish one client's model weights to the global server.

    Placeholder: no transport exists yet. A socket connection or a message queue
    such as RabbitMQ would go here. Both arguments are ignored, and the call
    returns ``None``.

    Args:
        client_id: key identifying the sending client (e.g. ``"client_0"``).
        model_update: the client's model ``state_dict``.
    """
    return None

# Define global server procedure
def global_server(num_clients: int = 5, rounds: int = 10, batch_size: int = 128) -> None:
    """Build the models, shard the training set, run FedDIS training, report test metrics.

    Uses ``Net``, ``train_set``, ``transform``, ``trainloader``, ``valloader`` and
    ``testloader``, which are defined elsewhere in the notebook.

    Args:
        num_clients: number of shards ``train_set`` is split into (default 5).
        rounds: outer training rounds, forwarded to ``federated_train`` as
            ``epochs`` (default 10).
        batch_size: batch size of each client's DataLoader (default 128).
    """
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size
    h_dim = 400
    z_dim = 20
    # NOTE(review): the VAE's layer sizes are hard-coded for 3x32x32 inputs; x_dim and
    # h_dim are stored on the model but do not change its architecture.
    vae = VAE(x_dim, h_dim, z_dim)

    # Initialize clients
    clients = initialize_clients(train_set, transform, batch_size=batch_size, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=rounds)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    print("True Positives (TP):", tp)
    print("False Positives (FP):", fp)
    print("True Negatives (TN):", tn)
    print("False Negatives (FN):", fn)
    # Sanity check: evaluate() derives TN as total - (TP + FP + FN), so every entry
    # here equals the number of evaluated samples.
    print("TP + FP + TN + FN (per class):", fn + tn + tp + fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

if __name__ == "__main__":
    global_server()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 63859793.39it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch [1/10], Training Loss: 2.303, Validation Accuracy: 10.75%
Epoch [2/10], Training Loss: 2.301, Validation Accuracy: 12.22%
Epoch [3/10], Training Loss: 2.299, Validation Accuracy: 14.96%
Epoch [4/10], Training Loss: 2.297, Validation Accuracy: 14.55%
Epoch [5/10], Training Loss: 2.294, Validation Accuracy: 13.37%
Epoch [6/10], Training Loss: 2.290, Validation Accuracy: 14.07%
Epoch [7/10], Training Loss: 2.285, Validation Accuracy: 16.00%
Epoch [8/10], Training Loss: 2.278, Validation Accuracy: 16.99%
Epoch [9/10], Training Loss: 2.269, Validation Accuracy: 16.84%
Epoch [10/10], Training Loss: 2.254, Validation Accuracy: 17.89%
Epoch [1/10], Training Loss: 2.239, Validation Accuracy: 19.78%
Epoch [2/10], Training Loss: 2.207, Validation Accuracy: 21.91%
Epoch [3/10], Training Loss: 2.165, Validation Accuracy: 23.47%
Epoch [4/10], Training Loss: 2.123, Validation Accuracy: 24.92%
Epoch [5/10], 

In [None]:
import re

# Your provided text
log = """
Epoch [1/10], Training Loss: 2.303, Validation Accuracy: 10.75%
Epoch [2/10], Training Loss: 2.301, Validation Accuracy: 12.22%
Epoch [3/10], Training Loss: 2.299, Validation Accuracy: 14.96%
Epoch [4/10], Training Loss: 2.297, Validation Accuracy: 14.55%
Epoch [5/10], Training Loss: 2.294, Validation Accuracy: 13.37%
Epoch [6/10], Training Loss: 2.290, Validation Accuracy: 14.07%
Epoch [7/10], Training Loss: 2.285, Validation Accuracy: 16.00%
Epoch [8/10], Training Loss: 2.278, Validation Accuracy: 16.99%
Epoch [9/10], Training Loss: 2.269, Validation Accuracy: 16.84%
Epoch [10/10], Training Loss: 2.254, Validation Accuracy: 17.89%
Epoch [1/10], Training Loss: 2.239, Validation Accuracy: 19.78%
Epoch [2/10], Training Loss: 2.207, Validation Accuracy: 21.91%
Epoch [3/10], Training Loss: 2.165, Validation Accuracy: 23.47%
Epoch [4/10], Training Loss: 2.123, Validation Accuracy: 24.92%
Epoch [5/10], Training Loss: 2.088, Validation Accuracy: 26.15%
Epoch [6/10], Training Loss: 2.058, Validation Accuracy: 26.80%
Epoch [7/10], Training Loss: 2.028, Validation Accuracy: 28.03%
Epoch [8/10], Training Loss: 1.998, Validation Accuracy: 28.58%
Epoch [9/10], Training Loss: 1.972, Validation Accuracy: 29.07%
Epoch [10/10], Training Loss: 1.953, Validation Accuracy: 29.60%
Epoch [1/10], Training Loss: 1.938, Validation Accuracy: 29.41%
Epoch [2/10], Training Loss: 1.925, Validation Accuracy: 30.06%
Epoch [3/10], Training Loss: 1.908, Validation Accuracy: 30.35%
Epoch [4/10], Training Loss: 1.897, Validation Accuracy: 30.56%
Epoch [5/10], Training Loss: 1.885, Validation Accuracy: 31.46%
Epoch [6/10], Training Loss: 1.870, Validation Accuracy: 32.00%
Epoch [7/10], Training Loss: 1.857, Validation Accuracy: 32.21%
Epoch [8/10], Training Loss: 1.843, Validation Accuracy: 33.06%
Epoch [9/10], Training Loss: 1.829, Validation Accuracy: 33.21%
Epoch [10/10], Training Loss: 1.809, Validation Accuracy: 33.59%
Epoch [1/10], Training Loss: 1.817, Validation Accuracy: 33.50%
Epoch [2/10], Training Loss: 1.799, Validation Accuracy: 34.90%
Epoch [3/10], Training Loss: 1.779, Validation Accuracy: 34.94%
Epoch [4/10], Training Loss: 1.764, Validation Accuracy: 36.12%
Epoch [5/10], Training Loss: 1.742, Validation Accuracy: 36.42%
Epoch [6/10], Training Loss: 1.722, Validation Accuracy: 36.51%
Epoch [7/10], Training Loss: 1.700, Validation Accuracy: 37.89%
Epoch [8/10], Training Loss: 1.681, Validation Accuracy: 38.89%
Epoch [9/10], Training Loss: 1.664, Validation Accuracy: 39.38%
Epoch [10/10], Training Loss: 1.645, Validation Accuracy: 39.29%
Epoch [1/10], Training Loss: 1.678, Validation Accuracy: 40.31%
Epoch [2/10], Training Loss: 1.656, Validation Accuracy: 40.58%
Epoch [3/10], Training Loss: 1.643, Validation Accuracy: 41.17%
Epoch [4/10], Training Loss: 1.631, Validation Accuracy: 41.30%
Epoch [5/10], Training Loss: 1.617, Validation Accuracy: 41.20%
Epoch [6/10], Training Loss: 1.601, Validation Accuracy: 43.14%
Epoch [7/10], Training Loss: 1.592, Validation Accuracy: 42.31%
Epoch [8/10], Training Loss: 1.574, Validation Accuracy: 42.53%
Epoch [9/10], Training Loss: 1.570, Validation Accuracy: 43.27%
Epoch [10/10], Training Loss: 1.556, Validation Accuracy: 43.35%
Epoch [1/10], Training Loss: 1.550, Validation Accuracy: 44.38%
Epoch [2/10], Training Loss: 1.544, Validation Accuracy: 43.99%
Epoch [3/10], Training Loss: 1.525, Validation Accuracy: 44.24%
Epoch [4/10], Training Loss: 1.510, Validation Accuracy: 45.09%
Epoch [5/10], Training Loss: 1.502, Validation Accuracy: 44.40%
Epoch [6/10], Training Loss: 1.495, Validation Accuracy: 44.95%
Epoch [7/10], Training Loss: 1.484, Validation Accuracy: 44.86%
Epoch [8/10], Training Loss: 1.476, Validation Accuracy: 45.87%
Epoch [9/10], Training Loss: 1.466, Validation Accuracy: 45.88%
Epoch [10/10], Training Loss: 1.462, Validation Accuracy: 45.97%
Epoch [1/10], Training Loss: 1.495, Validation Accuracy: 46.04%
Epoch [2/10], Training Loss: 1.484, Validation Accuracy: 44.78%
Epoch [3/10], Training Loss: 1.467, Validation Accuracy: 46.37%
Epoch [4/10], Training Loss: 1.456, Validation Accuracy: 46.48%
Epoch [5/10], Training Loss: 1.444, Validation Accuracy: 45.00%
Epoch [6/10], Training Loss: 1.438, Validation Accuracy: 46.11%
Epoch [7/10], Training Loss: 1.430, Validation Accuracy: 47.73%
Epoch [8/10], Training Loss: 1.417, Validation Accuracy: 46.55%
Epoch [9/10], Training Loss: 1.412, Validation Accuracy: 47.39%
Epoch [10/10], Training Loss: 1.401, Validation Accuracy: 47.55%
Epoch [1/10], Training Loss: 1.447, Validation Accuracy: 46.85%
Epoch [2/10], Training Loss: 1.444, Validation Accuracy: 47.81%
Epoch [3/10], Training Loss: 1.418, Validation Accuracy: 48.53%
Epoch [4/10], Training Loss: 1.403, Validation Accuracy: 47.72%
Epoch [5/10], Training Loss: 1.389, Validation Accuracy: 48.32%
Epoch [6/10], Training Loss: 1.387, Validation Accuracy: 48.97%
Epoch [7/10], Training Loss: 1.386, Validation Accuracy: 48.08%
Epoch [8/10], Training Loss: 1.378, Validation Accuracy: 49.20%
Epoch [9/10], Training Loss: 1.361, Validation Accuracy: 48.95%
Epoch [10/10], Training Loss: 1.350, Validation Accuracy: 48.90%
Epoch [1/10], Training Loss: 1.399, Validation Accuracy: 48.67%
Epoch [2/10], Training Loss: 1.404, Validation Accuracy: 49.39%
Epoch [3/10], Training Loss: 1.384, Validation Accuracy: 49.50%
Epoch [4/10], Training Loss: 1.366, Validation Accuracy: 50.09%
Epoch [5/10], Training Loss: 1.355, Validation Accuracy: 49.58%
Epoch [6/10], Training Loss: 1.342, Validation Accuracy: 49.85%
Epoch [7/10], Training Loss: 1.333, Validation Accuracy: 49.38%
Epoch [8/10], Training Loss: 1.332, Validation Accuracy: 49.95%
Epoch [9/10], Training Loss: 1.325, Validation Accuracy: 49.52%
Epoch [10/10], Training Loss: 1.317, Validation Accuracy: 49.48%
Epoch [1/10], Training Loss: 1.402, Validation Accuracy: 50.61%
Epoch [2/10], Training Loss: 1.385, Validation Accuracy: 49.86%
Epoch [3/10], Training Loss: 1.363, Validation Accuracy: 51.39%
Epoch [4/10], Training Loss: 1.350, Validation Accuracy: 50.31%
Epoch [5/10], Training Loss: 1.338, Validation Accuracy: 51.06%
Epoch [6/10], Training Loss: 1.331, Validation Accuracy: 51.26%
Epoch [7/10], Training Loss: 1.320, Validation Accuracy: 51.25%
Epoch [8/10], Training Loss: 1.308, Validation Accuracy: 51.34%
Epoch [9/10], Training Loss: 1.303, Validation Accuracy: 51.73%
Epoch [10/10], Training Loss: 1.295, Validation Accuracy: 51.03%
Epoch [1/10], Training Loss: 1.344, Validation Accuracy: 51.80%
Epoch [2/10], Training Loss: 1.315, Validation Accuracy: 51.69%
Epoch [3/10], Training Loss: 1.306, Validation Accuracy: 51.73%
Epoch [4/10], Training Loss: 1.291, Validation Accuracy: 52.09%
Epoch [5/10], Training Loss: 1.286, Validation Accuracy: 51.65%
Epoch [6/10], Training Loss: 1.275, Validation Accuracy: 52.54%
Epoch [7/10], Training Loss: 1.261, Validation Accuracy: 52.99%
Epoch [8/10], Training Loss: 1.250, Validation Accuracy: 51.97%
Epoch [9/10], Training Loss: 1.252, Validation Accuracy: 52.91%
Epoch [10/10], Training Loss: 1.230, Validation Accuracy: 52.17%
Epoch [1/10], Training Loss: 1.321, Validation Accuracy: 52.32%
Epoch [2/10], Training Loss: 1.297, Validation Accuracy: 52.50%
Epoch [3/10], Training Loss: 1.279, Validation Accuracy: 51.98%
Epoch [4/10], Training Loss: 1.265, Validation Accuracy: 53.10%
Epoch [5/10], Training Loss: 1.257, Validation Accuracy: 53.31%
Epoch [6/10], Training Loss: 1.246, Validation Accuracy: 53.55%
Epoch [7/10], Training Loss: 1.223, Validation Accuracy: 52.59%
Epoch [8/10], Training Loss: 1.225, Validation Accuracy: 53.26%
Epoch [9/10], Training Loss: 1.213, Validation Accuracy: 53.29%
Epoch [10/10], Training Loss: 1.210, Validation Accuracy: 53.37%
Epoch [1/10], Training Loss: 1.289, Validation Accuracy: 53.89%
Epoch [2/10], Training Loss: 1.261, Validation Accuracy: 53.18%
Epoch [3/10], Training Loss: 1.244, Validation Accuracy: 53.82%
Epoch [4/10], Training Loss: 1.229, Validation Accuracy: 54.18%
Epoch [5/10], Training Loss: 1.213, Validation Accuracy: 53.99%
Epoch [6/10], Training Loss: 1.206, Validation Accuracy: 54.12%
Epoch [7/10], Training Loss: 1.204, Validation Accuracy: 54.25%
Epoch [8/10], Training Loss: 1.185, Validation Accuracy: 54.70%
Epoch [9/10], Training Loss: 1.173, Validation Accuracy: 54.14%
Epoch [10/10], Training Loss: 1.168, Validation Accuracy: 54.66%
Epoch [1/10], Training Loss: 1.260, Validation Accuracy: 53.79%
Epoch [2/10], Training Loss: 1.237, Validation Accuracy: 54.48%
Epoch [3/10], Training Loss: 1.227, Validation Accuracy: 55.40%
Epoch [4/10], Training Loss: 1.206, Validation Accuracy: 55.12%
Epoch [5/10], Training Loss: 1.190, Validation Accuracy: 54.17%
Epoch [6/10], Training Loss: 1.186, Validation Accuracy: 55.28%
Epoch [7/10], Training Loss: 1.174, Validation Accuracy: 55.13%
Epoch [8/10], Training Loss: 1.166, Validation Accuracy: 54.98%
Epoch [9/10], Training Loss: 1.155, Validation Accuracy: 54.58%
Epoch [10/10], Training Loss: 1.141, Validation Accuracy: 54.91%
Epoch [1/10], Training Loss: 1.257, Validation Accuracy: 55.20%
Epoch [2/10], Training Loss: 1.225, Validation Accuracy: 55.63%
Epoch [3/10], Training Loss: 1.210, Validation Accuracy: 55.18%
Epoch [4/10], Training Loss: 1.188, Validation Accuracy: 54.97%
Epoch [5/10], Training Loss: 1.185, Validation Accuracy: 54.84%
Epoch [6/10], Training Loss: 1.169, Validation Accuracy: 54.73%
Epoch [7/10], Training Loss: 1.154, Validation Accuracy: 55.09%
Epoch [8/10], Training Loss: 1.148, Validation Accuracy: 53.81%
Epoch [9/10], Training Loss: 1.136, Validation Accuracy: 54.99%
Epoch [10/10], Training Loss: 1.128, Validation Accuracy: 55.13%
Epoch [1/10], Training Loss: 1.223, Validation Accuracy: 54.77%
Epoch [2/10], Training Loss: 1.190, Validation Accuracy: 55.85%
Epoch [3/10], Training Loss: 1.167, Validation Accuracy: 55.71%
Epoch [4/10], Training Loss: 1.147, Validation Accuracy: 55.74%
Epoch [5/10], Training Loss: 1.137, Validation Accuracy: 55.47%
Epoch [6/10], Training Loss: 1.125, Validation Accuracy: 55.80%
Epoch [7/10], Training Loss: 1.111, Validation Accuracy: 56.09%
Epoch [8/10], Training Loss: 1.109, Validation Accuracy: 55.88%
Epoch [9/10], Training Loss: 1.094, Validation Accuracy: 56.12%
Epoch [10/10], Training Loss: 1.077, Validation Accuracy: 55.99%
Epoch [1/10], Training Loss: 1.206, Validation Accuracy: 55.06%
Epoch [2/10], Training Loss: 1.176, Validation Accuracy: 56.26%
Epoch [3/10], Training Loss: 1.155, Validation Accuracy: 54.83%
Epoch [4/10], Training Loss: 1.135, Validation Accuracy: 56.65%
Epoch [5/10], Training Loss: 1.115, Validation Accuracy: 55.80%
Epoch [6/10], Training Loss: 1.108, Validation Accuracy: 56.58%
Epoch [7/10], Training Loss: 1.096, Validation Accuracy: 57.01%
Epoch [8/10], Training Loss: 1.089, Validation Accuracy: 56.65%
Epoch [9/10], Training Loss: 1.066, Validation Accuracy: 55.41%
Epoch [10/10], Training Loss: 1.062, Validation Accuracy: 56.28%
Epoch [1/10], Training Loss: 1.183, Validation Accuracy: 56.32%
Epoch [2/10], Training Loss: 1.147, Validation Accuracy: 55.79%
Epoch [3/10], Training Loss: 1.135, Validation Accuracy: 57.08%
Epoch [4/10], Training Loss: 1.106, Validation Accuracy: 57.83%
Epoch [5/10], Training Loss: 1.085, Validation Accuracy: 56.70%
Epoch [6/10], Training Loss: 1.088, Validation Accuracy: 56.37%
Epoch [7/10], Training Loss: 1.065, Validation Accuracy: 57.17%
Epoch [8/10], Training Loss: 1.047, Validation Accuracy: 57.15%
Epoch [9/10], Training Loss: 1.041, Validation Accuracy: 57.08%
Epoch [10/10], Training Loss: 1.030, Validation Accuracy: 56.99%
Epoch [1/10], Training Loss: 1.177, Validation Accuracy: 57.65%
Epoch [2/10], Training Loss: 1.133, Validation Accuracy: 56.81%
Epoch [3/10], Training Loss: 1.104, Validation Accuracy: 56.95%
Epoch [4/10], Training Loss: 1.090, Validation Accuracy: 57.00%
Epoch [5/10], Training Loss: 1.073, Validation Accuracy: 57.57%
Epoch [6/10], Training Loss: 1.068, Validation Accuracy: 57.33%
Epoch [7/10], Training Loss: 1.035, Validation Accuracy: 57.66%
Epoch [8/10], Training Loss: 1.024, Validation Accuracy: 56.14%
Epoch [9/10], Training Loss: 1.016, Validation Accuracy: 57.50%
Epoch [10/10], Training Loss: 1.008, Validation Accuracy: 57.02%
Epoch [1/10], Training Loss: 1.157, Validation Accuracy: 57.44%
Epoch [2/10], Training Loss: 1.120, Validation Accuracy: 58.49%
Epoch [3/10], Training Loss: 1.095, Validation Accuracy: 57.27%
Epoch [4/10], Training Loss: 1.078, Validation Accuracy: 58.17%
Epoch [5/10], Training Loss: 1.055, Validation Accuracy: 57.95%
Epoch [6/10], Training Loss: 1.040, Validation Accuracy: 57.58%
Epoch [7/10], Training Loss: 1.028, Validation Accuracy: 58.01%
Epoch [8/10], Training Loss: 1.013, Validation Accuracy: 57.34%
Epoch [9/10], Training Loss: 0.998, Validation Accuracy: 57.41%
Epoch [10/10], Training Loss: 0.992, Validation Accuracy: 55.79%
Epoch [1/10], Training Loss: 1.136, Validation Accuracy: 57.63%
Epoch [2/10], Training Loss: 1.087, Validation Accuracy: 57.93%
Epoch [3/10], Training Loss: 1.067, Validation Accuracy: 58.69%
Epoch [4/10], Training Loss: 1.039, Validation Accuracy: 57.28%
Epoch [5/10], Training Loss: 1.021, Validation Accuracy: 57.58%
Epoch [6/10], Training Loss: 1.007, Validation Accuracy: 57.83%
Epoch [7/10], Training Loss: 0.989, Validation Accuracy: 58.23%
Epoch [8/10], Training Loss: 0.979, Validation Accuracy: 58.16%
Epoch [9/10], Training Loss: 0.958, Validation Accuracy: 58.70%
Epoch [10/10], Training Loss: 0.954, Validation Accuracy: 58.29%
Epoch [1/10], Training Loss: 1.113, Validation Accuracy: 58.54%
Epoch [2/10], Training Loss: 1.065, Validation Accuracy: 58.43%
Epoch [3/10], Training Loss: 1.045, Validation Accuracy: 58.80%
Epoch [4/10], Training Loss: 1.015, Validation Accuracy: 58.56%
Epoch [5/10], Training Loss: 1.001, Validation Accuracy: 58.64%
Epoch [6/10], Training Loss: 0.980, Validation Accuracy: 58.27%
Epoch [7/10], Training Loss: 0.970, Validation Accuracy: 57.46%
Epoch [8/10], Training Loss: 0.953, Validation Accuracy: 58.32%
Epoch [9/10], Training Loss: 0.943, Validation Accuracy: 58.43%
Epoch [10/10], Training Loss: 0.927, Validation Accuracy: 58.90%
Epoch [1/10], Training Loss: 1.099, Validation Accuracy: 57.97%
Epoch [2/10], Training Loss: 1.059, Validation Accuracy: 58.53%
Epoch [3/10], Training Loss: 1.021, Validation Accuracy: 58.74%
Epoch [4/10], Training Loss: 0.996, Validation Accuracy: 58.18%
Epoch [5/10], Training Loss: 0.974, Validation Accuracy: 58.88%
Epoch [6/10], Training Loss: 0.959, Validation Accuracy: 59.06%
Epoch [7/10], Training Loss: 0.942, Validation Accuracy: 59.20%
Epoch [8/10], Training Loss: 0.925, Validation Accuracy: 58.98%
Epoch [9/10], Training Loss: 0.915, Validation Accuracy: 59.17%
Epoch [10/10], Training Loss: 0.898, Validation Accuracy: 59.16%
Epoch [1/10], Training Loss: 1.084, Validation Accuracy: 59.21%
Epoch [2/10], Training Loss: 1.036, Validation Accuracy: 59.11%
Epoch [3/10], Training Loss: 1.004, Validation Accuracy: 59.65%
Epoch [4/10], Training Loss: 0.978, Validation Accuracy: 59.48%
Epoch [5/10], Training Loss: 0.959, Validation Accuracy: 59.29%
Epoch [6/10], Training Loss: 0.945, Validation Accuracy: 59.28%
Epoch [7/10], Training Loss: 0.924, Validation Accuracy: 59.05%
Epoch [8/10], Training Loss: 0.908, Validation Accuracy: 58.99%
Epoch [9/10], Training Loss: 0.906, Validation Accuracy: 59.20%
Epoch [10/10], Training Loss: 0.882, Validation Accuracy: 58.93%
Epoch [1/10], Training Loss: 1.074, Validation Accuracy: 59.00%
Epoch [2/10], Training Loss: 1.032, Validation Accuracy: 59.18%
Epoch [3/10], Training Loss: 0.992, Validation Accuracy: 59.39%
Epoch [4/10], Training Loss: 0.971, Validation Accuracy: 59.28%
Epoch [5/10], Training Loss: 0.943, Validation Accuracy: 59.39%
Epoch [6/10], Training Loss: 0.925, Validation Accuracy: 59.53%
Epoch [7/10], Training Loss: 0.911, Validation Accuracy: 58.83%
Epoch [8/10], Training Loss: 0.899, Validation Accuracy: 58.81%
Epoch [9/10], Training Loss: 0.882, Validation Accuracy: 58.41%
Epoch [10/10], Training Loss: 0.859, Validation Accuracy: 59.05%
Epoch [1/10], Training Loss: 1.046, Validation Accuracy: 59.80%
Epoch [2/10], Training Loss: 0.992, Validation Accuracy: 59.57%
Epoch [3/10], Training Loss: 0.961, Validation Accuracy: 59.66%
Epoch [4/10], Training Loss: 0.933, Validation Accuracy: 60.23%
Epoch [5/10], Training Loss: 0.919, Validation Accuracy: 59.11%
Epoch [6/10], Training Loss: 0.888, Validation Accuracy: 59.25%
Epoch [7/10], Training Loss: 0.876, Validation Accuracy: 59.74%
Epoch [8/10], Training Loss: 0.871, Validation Accuracy: 59.29%
Epoch [9/10], Training Loss: 0.844, Validation Accuracy: 59.95%
Epoch [10/10], Training Loss: 0.831, Validation Accuracy: 59.50%
Epoch [1/10], Training Loss: 1.049, Validation Accuracy: 59.77%
Epoch [2/10], Training Loss: 0.980, Validation Accuracy: 58.98%
Epoch [3/10], Training Loss: 0.954, Validation Accuracy: 59.68%
Epoch [4/10], Training Loss: 0.918, Validation Accuracy: 58.98%
Epoch [5/10], Training Loss: 0.894, Validation Accuracy: 59.79%
Epoch [6/10], Training Loss: 0.870, Validation Accuracy: 59.83%
Epoch [7/10], Training Loss: 0.861, Validation Accuracy: 58.60%
Epoch [8/10], Training Loss: 0.842, Validation Accuracy: 59.04%
Epoch [9/10], Training Loss: 0.826, Validation Accuracy: 59.50%
Epoch [10/10], Training Loss: 0.807, Validation Accuracy: 59.10%
Epoch [1/10], Training Loss: 1.038, Validation Accuracy: 59.71%
Epoch [2/10], Training Loss: 0.968, Validation Accuracy: 58.39%
Epoch [3/10], Training Loss: 0.943, Validation Accuracy: 59.55%
Epoch [4/10], Training Loss: 0.901, Validation Accuracy: 59.53%
Epoch [5/10], Training Loss: 0.876, Validation Accuracy: 59.20%
Epoch [6/10], Training Loss: 0.858, Validation Accuracy: 59.24%
Epoch [7/10], Training Loss: 0.842, Validation Accuracy: 58.87%
Epoch [8/10], Training Loss: 0.816, Validation Accuracy: 59.82%
Epoch [9/10], Training Loss: 0.796, Validation Accuracy: 59.86%
Epoch [10/10], Training Loss: 0.785, Validation Accuracy: 59.79%
Epoch [1/10], Training Loss: 1.025, Validation Accuracy: 58.61%
Epoch [2/10], Training Loss: 0.964, Validation Accuracy: 59.65%
Epoch [3/10], Training Loss: 0.915, Validation Accuracy: 59.85%
Epoch [4/10], Training Loss: 0.895, Validation Accuracy: 60.28%
Epoch [5/10], Training Loss: 0.862, Validation Accuracy: 59.44%
Epoch [6/10], Training Loss: 0.839, Validation Accuracy: 60.17%
Epoch [7/10], Training Loss: 0.831, Validation Accuracy: 60.55%
Epoch [8/10], Training Loss: 0.805, Validation Accuracy: 59.58%
Epoch [9/10], Training Loss: 0.782, Validation Accuracy: 60.13%
Epoch [10/10], Training Loss: 0.769, Validation Accuracy: 58.70%
Epoch [1/10], Training Loss: 1.015, Validation Accuracy: 60.41%
Epoch [2/10], Training Loss: 0.946, Validation Accuracy: 60.13%
Epoch [3/10], Training Loss: 0.910, Validation Accuracy: 59.99%
Epoch [4/10], Training Loss: 0.881, Validation Accuracy: 60.18%
Epoch [5/10], Training Loss: 0.845, Validation Accuracy: 60.26%
Epoch [6/10], Training Loss: 0.814, Validation Accuracy: 60.02%
Epoch [7/10], Training Loss: 0.803, Validation Accuracy: 58.56%
Epoch [8/10], Training Loss: 0.792, Validation Accuracy: 59.95%
Epoch [9/10], Training Loss: 0.773, Validation Accuracy: 59.65%
Epoch [10/10], Training Loss: 0.751, Validation Accuracy: 59.15%
Epoch [1/10], Training Loss: 1.006, Validation Accuracy: 60.29%
Epoch [2/10], Training Loss: 0.924, Validation Accuracy: 58.62%
Epoch [3/10], Training Loss: 0.884, Validation Accuracy: 60.32%
Epoch [4/10], Training Loss: 0.845, Validation Accuracy: 60.46%
Epoch [5/10], Training Loss: 0.823, Validation Accuracy: 60.80%
Epoch [6/10], Training Loss: 0.796, Validation Accuracy: 60.79%
Epoch [7/10], Training Loss: 0.769, Validation Accuracy: 60.37%
Epoch [8/10], Training Loss: 0.757, Validation Accuracy: 59.99%
Epoch [9/10], Training Loss: 0.738, Validation Accuracy: 60.29%
Epoch [10/10], Training Loss: 0.722, Validation Accuracy: 59.66%
Epoch [1/10], Training Loss: 0.987, Validation Accuracy: 60.28%
Epoch [2/10], Training Loss: 0.902, Validation Accuracy: 59.93%
Epoch [3/10], Training Loss: 0.863, Validation Accuracy: 60.77%
Epoch [4/10], Training Loss: 0.827, Validation Accuracy: 59.92%
Epoch [5/10], Training Loss: 0.805, Validation Accuracy: 60.38%
Epoch [6/10], Training Loss: 0.776, Validation Accuracy: 60.88%
Epoch [7/10], Training Loss: 0.763, Validation Accuracy: 60.41%
Epoch [8/10], Training Loss: 0.737, Validation Accuracy: 60.28%
Epoch [9/10], Training Loss: 0.722, Validation Accuracy: 59.76%
Epoch [10/10], Training Loss: 0.710, Validation Accuracy: 60.40%
Epoch [1/10], Training Loss: 0.985, Validation Accuracy: 59.15%
Epoch [2/10], Training Loss: 0.891, Validation Accuracy: 59.80%
Epoch [3/10], Training Loss: 0.841, Validation Accuracy: 60.98%
Epoch [4/10], Training Loss: 0.804, Validation Accuracy: 60.57%
Epoch [5/10], Training Loss: 0.778, Validation Accuracy: 60.90%
Epoch [6/10], Training Loss: 0.762, Validation Accuracy: 60.43%
Epoch [7/10], Training Loss: 0.731, Validation Accuracy: 60.28%
Epoch [8/10], Training Loss: 0.713, Validation Accuracy: 61.00%
Epoch [9/10], Training Loss: 0.693, Validation Accuracy: 60.68%
Epoch [10/10], Training Loss: 0.674, Validation Accuracy: 60.03%
Epoch [1/10], Training Loss: 0.977, Validation Accuracy: 60.70%
Epoch [2/10], Training Loss: 0.893, Validation Accuracy: 60.53%
Epoch [3/10], Training Loss: 0.835, Validation Accuracy: 61.15%
Epoch [4/10], Training Loss: 0.795, Validation Accuracy: 60.73%
Epoch [5/10], Training Loss: 0.769, Validation Accuracy: 59.67%
Epoch [6/10], Training Loss: 0.748, Validation Accuracy: 60.27%
Epoch [7/10], Training Loss: 0.727, Validation Accuracy: 60.41%
Epoch [8/10], Training Loss: 0.710, Validation Accuracy: 60.49%
Epoch [9/10], Training Loss: 0.685, Validation Accuracy: 60.18%
Epoch [10/10], Training Loss: 0.660, Validation Accuracy: 60.43%
Epoch [1/10], Training Loss: 0.958, Validation Accuracy: 60.67%
Epoch [2/10], Training Loss: 0.870, Validation Accuracy: 60.57%
Epoch [3/10], Training Loss: 0.818, Validation Accuracy: 60.83%
Epoch [4/10], Training Loss: 0.787, Validation Accuracy: 60.69%
Epoch [5/10], Training Loss: 0.748, Validation Accuracy: 60.85%
Epoch [6/10], Training Loss: 0.723, Validation Accuracy: 60.49%
Epoch [7/10], Training Loss: 0.708, Validation Accuracy: 59.95%
Epoch [8/10], Training Loss: 0.681, Validation Accuracy: 60.11%
Epoch [9/10], Training Loss: 0.661, Validation Accuracy: 60.29%
Epoch [10/10], Training Loss: 0.638, Validation Accuracy: 60.08%
Epoch [1/10], Training Loss: 0.941, Validation Accuracy: 61.20%
Epoch [2/10], Training Loss: 0.846, Validation Accuracy: 61.34%
Epoch [3/10], Training Loss: 0.796, Validation Accuracy: 61.03%
Epoch [4/10], Training Loss: 0.756, Validation Accuracy: 61.21%
Epoch [5/10], Training Loss: 0.735, Validation Accuracy: 59.42%
Epoch [6/10], Training Loss: 0.702, Validation Accuracy: 60.59%
Epoch [7/10], Training Loss: 0.683, Validation Accuracy: 60.66%
Epoch [8/10], Training Loss: 0.658, Validation Accuracy: 59.84%
Epoch [9/10], Training Loss: 0.638, Validation Accuracy: 61.00%
Epoch [10/10], Training Loss: 0.618, Validation Accuracy: 61.00%
Epoch [1/10], Training Loss: 0.927, Validation Accuracy: 60.70%
Epoch [2/10], Training Loss: 0.836, Validation Accuracy: 61.26%
Epoch [3/10], Training Loss: 0.788, Validation Accuracy: 60.18%
Epoch [4/10], Training Loss: 0.745, Validation Accuracy: 60.51%
Epoch [5/10], Training Loss: 0.710, Validation Accuracy: 60.41%
Epoch [6/10], Training Loss: 0.694, Validation Accuracy: 60.85%
Epoch [7/10], Training Loss: 0.664, Validation Accuracy: 60.59%
Epoch [8/10], Training Loss: 0.644, Validation Accuracy: 61.21%
Epoch [9/10], Training Loss: 0.627, Validation Accuracy: 60.22%
Epoch [10/10], Training Loss: 0.609, Validation Accuracy: 60.47%
Epoch [1/10], Training Loss: 0.930, Validation Accuracy: 60.75%
Epoch [2/10], Training Loss: 0.822, Validation Accuracy: 60.35%
Epoch [3/10], Training Loss: 0.765, Validation Accuracy: 60.58%
Epoch [4/10], Training Loss: 0.721, Validation Accuracy: 60.24%
Epoch [5/10], Training Loss: 0.686, Validation Accuracy: 61.10%
Epoch [6/10], Training Loss: 0.662, Validation Accuracy: 61.18%
Epoch [7/10], Training Loss: 0.623, Validation Accuracy: 59.76%
Epoch [8/10], Training Loss: 0.618, Validation Accuracy: 59.98%
Epoch [9/10], Training Loss: 0.603, Validation Accuracy: 60.49%
Epoch [10/10], Training Loss: 0.576, Validation Accuracy: 59.84%
Epoch [1/10], Training Loss: 0.929, Validation Accuracy: 60.47%
Epoch [2/10], Training Loss: 0.820, Validation Accuracy: 60.51%
Epoch [3/10], Training Loss: 0.751, Validation Accuracy: 61.09%
Epoch [4/10], Training Loss: 0.717, Validation Accuracy: 59.94%
Epoch [5/10], Training Loss: 0.692, Validation Accuracy: 60.14%
Epoch [6/10], Training Loss: 0.652, Validation Accuracy: 60.57%
Epoch [7/10], Training Loss: 0.628, Validation Accuracy: 60.60%
Epoch [8/10], Training Loss: 0.606, Validation Accuracy: 60.14%
Epoch [9/10], Training Loss: 0.585, Validation Accuracy: 60.45%
Epoch [10/10], Training Loss: 0.565, Validation Accuracy: 60.09%
Epoch [1/10], Training Loss: 0.910, Validation Accuracy: 60.25%
Epoch [2/10], Training Loss: 0.800, Validation Accuracy: 61.09%
Epoch [3/10], Training Loss: 0.738, Validation Accuracy: 61.00%
Epoch [4/10], Training Loss: 0.694, Validation Accuracy: 60.14%
Epoch [5/10], Training Loss: 0.664, Validation Accuracy: 61.00%
Epoch [6/10], Training Loss: 0.632, Validation Accuracy: 60.11%
Epoch [7/10], Training Loss: 0.620, Validation Accuracy: 60.27%
Epoch [8/10], Training Loss: 0.582, Validation Accuracy: 59.52%
Epoch [9/10], Training Loss: 0.563, Validation Accuracy: 59.78%
Epoch [10/10], Training Loss: 0.555, Validation Accuracy: 60.49%
Epoch [1/10], Training Loss: 0.894, Validation Accuracy: 61.07%
Epoch [2/10], Training Loss: 0.767, Validation Accuracy: 60.74%
Epoch [3/10], Training Loss: 0.720, Validation Accuracy: 61.31%
Epoch [4/10], Training Loss: 0.668, Validation Accuracy: 61.57%
Epoch [5/10], Training Loss: 0.645, Validation Accuracy: 60.80%
Epoch [6/10], Training Loss: 0.610, Validation Accuracy: 60.83%
Epoch [7/10], Training Loss: 0.587, Validation Accuracy: 60.41%
Epoch [8/10], Training Loss: 0.570, Validation Accuracy: 60.60%
Epoch [9/10], Training Loss: 0.537, Validation Accuracy: 60.70%
Epoch [10/10], Training Loss: 0.516, Validation Accuracy: 60.82%
Epoch [1/10], Training Loss: 0.898, Validation Accuracy: 60.52%
Epoch [2/10], Training Loss: 0.767, Validation Accuracy: 61.40%
Epoch [3/10], Training Loss: 0.713, Validation Accuracy: 60.38%
Epoch [4/10], Training Loss: 0.671, Validation Accuracy: 61.33%
Epoch [5/10], Training Loss: 0.636, Validation Accuracy: 61.09%
Epoch [6/10], Training Loss: 0.598, Validation Accuracy: 60.96%
Epoch [7/10], Training Loss: 0.574, Validation Accuracy: 60.22%
Epoch [8/10], Training Loss: 0.554, Validation Accuracy: 60.86%
Epoch [9/10], Training Loss: 0.532, Validation Accuracy: 60.22%
Epoch [10/10], Training Loss: 0.509, Validation Accuracy: 60.27%
Epoch [1/10], Training Loss: 0.884, Validation Accuracy: 60.15%
Epoch [2/10], Training Loss: 0.762, Validation Accuracy: 60.42%
Epoch [3/10], Training Loss: 0.682, Validation Accuracy: 60.36%
Epoch [4/10], Training Loss: 0.650, Validation Accuracy: 60.48%
Epoch [5/10], Training Loss: 0.603, Validation Accuracy: 60.64%
Epoch [6/10], Training Loss: 0.573, Validation Accuracy: 60.68%
Epoch [7/10], Training Loss: 0.551, Validation Accuracy: 59.55%
Epoch [8/10], Training Loss: 0.520, Validation Accuracy: 59.78%
Epoch [9/10], Training Loss: 0.495, Validation Accuracy: 60.54%
Epoch [10/10], Training Loss: 0.484, Validation Accuracy: 60.68%
Epoch [1/10], Training Loss: 0.884, Validation Accuracy: 59.23%
Epoch [2/10], Training Loss: 0.747, Validation Accuracy: 59.69%
Epoch [3/10], Training Loss: 0.686, Validation Accuracy: 60.60%
Epoch [4/10], Training Loss: 0.643, Validation Accuracy: 60.77%
Epoch [5/10], Training Loss: 0.608, Validation Accuracy: 61.17%
Epoch [6/10], Training Loss: 0.581, Validation Accuracy: 60.36%
Epoch [7/10], Training Loss: 0.541, Validation Accuracy: 60.86%
Epoch [8/10], Training Loss: 0.519, Validation Accuracy: 60.60%
Epoch [9/10], Training Loss: 0.500, Validation Accuracy: 60.67%
Epoch [10/10], Training Loss: 0.475, Validation Accuracy: 60.22%
Epoch [1/10], Training Loss: 0.859, Validation Accuracy: 60.80%
Epoch [2/10], Training Loss: 0.734, Validation Accuracy: 59.52%
Epoch [3/10], Training Loss: 0.667, Validation Accuracy: 60.82%
Epoch [4/10], Training Loss: 0.618, Validation Accuracy: 60.80%
Epoch [5/10], Training Loss: 0.579, Validation Accuracy: 60.60%
Epoch [6/10], Training Loss: 0.548, Validation Accuracy: 60.32%
Epoch [7/10], Training Loss: 0.519, Validation Accuracy: 60.78%
Epoch [8/10], Training Loss: 0.494, Validation Accuracy: 59.67%
Epoch [9/10], Training Loss: 0.478, Validation Accuracy: 60.24%
Epoch [10/10], Training Loss: 0.458, Validation Accuracy: 60.07%
Epoch [1/10], Training Loss: 0.865, Validation Accuracy: 60.55%
Epoch [2/10], Training Loss: 0.711, Validation Accuracy: 61.52%
Epoch [3/10], Training Loss: 0.632, Validation Accuracy: 61.45%
Epoch [4/10], Training Loss: 0.609, Validation Accuracy: 61.00%
Epoch [5/10], Training Loss: 0.565, Validation Accuracy: 60.96%
Epoch [6/10], Training Loss: 0.523, Validation Accuracy: 61.15%
Epoch [7/10], Training Loss: 0.498, Validation Accuracy: 61.22%
Epoch [8/10], Training Loss: 0.474, Validation Accuracy: 60.79%
Epoch [9/10], Training Loss: 0.459, Validation Accuracy: 60.97%
Epoch [10/10], Training Loss: 0.427, Validation Accuracy: 60.79%
Epoch [1/10], Training Loss: 0.885, Validation Accuracy: 60.62%
Epoch [2/10], Training Loss: 0.732, Validation Accuracy: 60.76%
Epoch [3/10], Training Loss: 0.635, Validation Accuracy: 61.17%
Epoch [4/10], Training Loss: 0.592, Validation Accuracy: 61.02%
Epoch [5/10], Training Loss: 0.546, Validation Accuracy: 61.12%
Epoch [6/10], Training Loss: 0.518, Validation Accuracy: 60.72%
Epoch [7/10], Training Loss: 0.489, Validation Accuracy: 60.68%
Epoch [8/10], Training Loss: 0.468, Validation Accuracy: 60.49%
Epoch [9/10], Training Loss: 0.444, Validation Accuracy: 60.34%
Epoch [10/10], Training Loss: 0.417, Validation Accuracy: 60.08%
Epoch [1/10], Training Loss: 0.849, Validation Accuracy: 59.92%
Epoch [2/10], Training Loss: 0.692, Validation Accuracy: 60.47%
Epoch [3/10], Training Loss: 0.612, Validation Accuracy: 60.60%
Epoch [4/10], Training Loss: 0.554, Validation Accuracy: 59.98%
Epoch [5/10], Training Loss: 0.518, Validation Accuracy: 60.13%
Epoch [6/10], Training Loss: 0.487, Validation Accuracy: 60.58%
Epoch [7/10], Training Loss: 0.468, Validation Accuracy: 59.58%
Epoch [8/10], Training Loss: 0.442, Validation Accuracy: 60.56%
Epoch [9/10], Training Loss: 0.411, Validation Accuracy: 60.52%
Epoch [10/10], Training Loss: 0.393, Validation Accuracy: 60.38%
Epoch [1/10], Training Loss: 0.869, Validation Accuracy: 60.28%
Epoch [2/10], Training Loss: 0.678, Validation Accuracy: 60.24%
Epoch [3/10], Training Loss: 0.604, Validation Accuracy: 61.14%
Epoch [4/10], Training Loss: 0.558, Validation Accuracy: 60.30%
Epoch [5/10], Training Loss: 0.528, Validation Accuracy: 59.88%
Epoch [6/10], Training Loss: 0.489, Validation Accuracy: 60.34%
Epoch [7/10], Training Loss: 0.458, Validation Accuracy: 59.62%
Epoch [8/10], Training Loss: 0.436, Validation Accuracy: 60.23%
Epoch [9/10], Training Loss: 0.409, Validation Accuracy: 60.09%
Epoch [10/10], Training Loss: 0.405, Validation Accuracy: 59.79%
Epoch [1/10], Training Loss: 0.834, Validation Accuracy: 60.54%
Epoch [2/10], Training Loss: 0.671, Validation Accuracy: 60.20%
Epoch [3/10], Training Loss: 0.598, Validation Accuracy: 60.30%
Epoch [4/10], Training Loss: 0.535, Validation Accuracy: 59.64%
Epoch [5/10], Training Loss: 0.499, Validation Accuracy: 60.54%
Epoch [6/10], Training Loss: 0.468, Validation Accuracy: 59.64%
Epoch [7/10], Training Loss: 0.432, Validation Accuracy: 59.60%
Epoch [8/10], Training Loss: 0.410, Validation Accuracy: 59.80%
Epoch [9/10], Training Loss: 0.385, Validation Accuracy: 60.16%
Epoch [10/10], Training Loss: 0.365, Validation Accuracy: 59.90%
"""

# Extract every "Validation Accuracy: NN.NN%" value from the captured log text
# and convert each match straight to a float.
accuracies = [float(match) for match in re.findall(r'Validation Accuracy: (\d+\.\d+)%', log)]

# Show the parsed values and how many epochs were found
print("Accuracies:", accuracies)
print("Size of array:", len(accuracies))


Accuracies: [10.75, 12.22, 14.96, 14.55, 13.37, 14.07, 16.0, 16.99, 16.84, 17.89, 19.78, 21.91, 23.47, 24.92, 26.15, 26.8, 28.03, 28.58, 29.07, 29.6, 29.41, 30.06, 30.35, 30.56, 31.46, 32.0, 32.21, 33.06, 33.21, 33.59, 33.5, 34.9, 34.94, 36.12, 36.42, 36.51, 37.89, 38.89, 39.38, 39.29, 40.31, 40.58, 41.17, 41.3, 41.2, 43.14, 42.31, 42.53, 43.27, 43.35, 44.38, 43.99, 44.24, 45.09, 44.4, 44.95, 44.86, 45.87, 45.88, 45.97, 46.04, 44.78, 46.37, 46.48, 45.0, 46.11, 47.73, 46.55, 47.39, 47.55, 46.85, 47.81, 48.53, 47.72, 48.32, 48.97, 48.08, 49.2, 48.95, 48.9, 48.67, 49.39, 49.5, 50.09, 49.58, 49.85, 49.38, 49.95, 49.52, 49.48, 50.61, 49.86, 51.39, 50.31, 51.06, 51.26, 51.25, 51.34, 51.73, 51.03, 51.8, 51.69, 51.73, 52.09, 51.65, 52.54, 52.99, 51.97, 52.91, 52.17, 52.32, 52.5, 51.98, 53.1, 53.31, 53.55, 52.59, 53.26, 53.29, 53.37, 53.89, 53.18, 53.82, 54.18, 53.99, 54.12, 54.25, 54.7, 54.14, 54.66, 53.79, 54.48, 55.4, 55.12, 54.17, 55.28, 55.13, 54.98, 54.58, 54.91, 55.2, 55.63, 55.18, 54.97

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 images.

    Args:
        x_dim: stored as ``self.x_dim``; layer sizes below are fixed and do not use it.
        h_dim: stored as ``self.h_dim``; layer sizes below are fixed and do not use it.
        z_dim: size of the latent vector. The encoder emits ``2 * z_dim`` values
            (mu followed by logvar) and the decoder consumes ``z_dim`` values.

    The decoder ends in a Sigmoid, so reconstructions lie in [0, 1].
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim, self.h_dim, self.z_dim = x_dim, h_dim, z_dim

        # Encoder: three stride-2 convs (32x32 -> 16x16 -> 8x8 -> 4x4), then
        # flatten 128*4*4 = 2048 features into two Linear layers.
        encoder_layers = []
        for in_ch, out_ch in ((3, 32), (32, 64), (64, 128)):
            encoder_layers += [nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1), nn.ReLU()]
        encoder_layers += [
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: mirror image of the encoder (4x4 -> 8x8 -> 16x16 -> 32x32).
        decoder_layers = [
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
        ]
        for in_ch, out_ch in ((128, 64), (64, 32)):
            decoder_layers += [nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1), nn.ReLU()]
        decoder_layers += [nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1), nn.Sigmoid()]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output."""
        mu, logvar = self.encoder(x).chunk(2, dim=1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + std * eps`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = (0.5 * logvar).exp()
        return mu + torch.randn_like(std) * std

    def decode(self, z):
        """Map latent codes back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent code, and decode; returns ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on the images yielded by ``trainloader`` for ``epochs`` passes.

    Batch labels are ignored; the objective is ``vae_loss``. A new Adam optimizer
    (lr=1e-3) and StepLR schedule (halve every 20 epochs) are built on every call,
    so optimizer state does not carry over between calls, and runs shorter than
    20 epochs never decay the learning rate.
    """
    opt = torch.optim.Adam(vae.parameters(), lr=1e-3)
    lr_sched = torch.optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for images, _labels in trainloader:
            opt.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar)
            batch_loss.backward()
            opt.step()

        lr_sched.step()


# Define classification model
class Net(nn.Module):
    """LeNet-style CNN mapping 3-channel 32x32 images to 10 class logits."""

    # For 32x32 inputs: conv(5) -> 28, pool -> 14, conv(5) -> 10, pool -> 5,
    # leaving 16 channels of 5x5 features.
    _FLAT_FEATURES = 16 * 5 * 5

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores for a batch of images."""
        relu = nn.functional.relu
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, self._FLAT_FEATURES)
        for fc in (self.fc1, self.fc2):
            x = relu(fc(x))
        return self.fc3(x)


# Load CIFAR10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Separate dataset by class. Labels are read from `.targets` so we don't decode
# and transform every image just to learn its label.
class_indices = {i: [] for i in range(10)}  # CIFAR-10 has 10 classes
for idx, label in enumerate(full_dataset.targets):
    class_indices[label].append(idx)

# Draw a random target count per class. The total is the size of the training
# split (50,000 images, 5,000 per class), not 60,000: with 60,000 every class was
# over-requested, so the cap below turned the "random" sampling into a no-op.
class_counts = np.random.multinomial(len(full_dataset), [0.1] * 10)  # Adjust probabilities if you want specific class biases
print("Random Images per Class:", class_counts)

# Sample indices based on the specified class counts
indices = []
sampled_counts = []
for class_id, count in enumerate(class_counts):
    # Cap the request at the number of images actually available for this class
    count = min(int(count), len(class_indices[class_id]))
    indices.extend(random.sample(class_indices[class_id], count))
    sampled_counts.append(count)
print("Sampled Images per Class:", sampled_counts)

# Create a custom CIFAR-10 dataset with the sampled indices
custom_dataset = Subset(full_dataset, indices)

# Split the *sampled* dataset into training and validation sets
# (the split previously used full_dataset, leaving custom_dataset unused)
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

# Create training and validation loaders
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar, beta: float = 1.0):
    """VAE objective: summed squared reconstruction error plus a weighted KL term.

    Args:
        recon_x: decoder output; must have the same shape as ``x``.
        x: reconstruction target.
        mu: latent means from the encoder.
        logvar: latent log-variances from the encoder.
        beta: weight on the KL term. The default 1.0 reproduces the original
            unweighted objective.

    Returns:
        Scalar tensor equal to ``sum((recon_x - x)**2) + beta * KL``, where KL is
        the divergence of N(mu, exp(logvar)) from N(0, I), summed over all elements.

    NOTE(review): the decoder ends in Sigmoid (outputs in [0, 1]), while the
    training transform normalizes images to [-1, 1]. Negative target pixels can
    therefore never be reconstructed exactly — consider rescaling the target or
    the decoder's output activation.
    """
    # The reconstruction term is mean-squared error (summed), not binary
    # cross-entropy; the previous local name `BCE` was misleading.
    recon_term = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_term + beta * kl_term

# Define training procedure for classification model
def train(
    net: nn.Module,
    trainloader: DataLoader,
    valloader: DataLoader,
    epochs: int,
    lr: float = 0.001,
    checkpoint_path: str = 'best_model.pth',
) -> None:
    """Train ``net`` with SGD + cross-entropy, validating after every epoch.

    Args:
        net: classifier returning per-class logits.
        trainloader: yields ``(inputs, labels)`` training batches.
        valloader: yields ``(images, labels)`` validation batches.
        epochs: number of passes over ``trainloader``.
        lr: SGD learning rate (momentum 0.9; halved every 20 epochs by StepLR).
        checkpoint_path: where the state_dict is saved whenever validation
            accuracy beats the best seen *in this call*.

    Prints one line per epoch with the mean training loss and validation
    accuracy. Empty loaders are tolerated: the loss prints as 0.000 and the
    accuracy as 0.00%, and no checkpoint is written.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against empty loaders (previously NameError on `i` / ZeroDivisionError)
        val_acc = 100 * correct / total if total else 0.0
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), checkpoint_path)

        avg_loss = running_loss / max(num_batches, 1)
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def _safe_divide(numerator: np.ndarray, denominator: np.ndarray) -> np.ndarray:
    """Element-wise ``numerator / denominator`` as floats, yielding 0.0 where the denominator is 0."""
    numerator = numerator.astype(float)
    denominator = denominator.astype(float)
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator > 0)


def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score ``net`` on ``testloader`` and compute per-class confusion statistics.

    The class set is ``0 .. K-1``, where K is the logit width of ``net``'s output,
    widened if any label is larger. Every per-class array therefore has one entry
    per class, even for classes that never occur. (sklearn's default drops unseen
    classes, which misaligned the arrays.)

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)``:
        - ``accuracy`` is a percentage (0.0 for an empty loader).
        - The rest are length-K arrays.
        - Precision, recall and F1 are 0.0 wherever their denominator is 0.

    Also prints the K x K confusion matrix (rows = true class, columns = predicted).
    """
    net.eval()
    label_batches = []
    prediction_batches = []
    num_classes = 0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            num_classes = max(num_classes, outputs.shape[1])
            label_batches.append(labels.cpu().numpy())
            prediction_batches.append(outputs.argmax(dim=1).cpu().numpy())

    if label_batches:
        y_true = np.concatenate(label_batches).astype(np.int64)
        y_pred = np.concatenate(prediction_batches).astype(np.int64)
    else:
        y_true = np.zeros(0, dtype=np.int64)
        y_pred = np.zeros(0, dtype=np.int64)

    total = y_true.size
    if total:
        # Make sure every observed label/prediction has a row/column
        num_classes = max(num_classes, int(y_true.max()) + 1, int(y_pred.max()) + 1)
    accuracy = 100 * int(np.sum(y_true == y_pred)) / total if total else 0.0

    # Confusion matrix via a single bincount over (true, predicted) pairs
    cm = np.bincount(num_classes * y_true + y_pred, minlength=num_classes * num_classes)
    cm = cm.reshape(num_classes, num_classes)
    print(f"Confusion Matrix:\n{cm}")

    # Calculate TP, FP, TN, FN for each class
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # Precision, recall and F1 follow directly from the counts above
    precision = _safe_divide(tp, tp + fp)
    recall = _safe_divide(tp, tp + fn)
    f1 = _safe_divide(2 * tp, 2 * tp + fp + fn)

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into ``num_clients`` contiguous shards, one shuffled DataLoader each.

    Shard ``k`` covers indices ``[k*n // num_clients, (k+1)*n // num_clients)``,
    so the shards together cover the whole dataset and differ in size by at most
    one. ``transform`` is accepted for call compatibility but is not used.

    Returns:
        dict mapping ``"client_<k>"`` to that shard's DataLoader.
    """
    total = len(trainset)
    return {
        f"client_{k}": torch.utils.data.DataLoader(
            torch.utils.data.Subset(
                trainset,
                range(k * total // num_clients, (k + 1) * total // num_clients),
            ),
            batch_size=batch_size,
            shuffle=True,
        )
        for k in range(num_clients)
    }

def get_distribution_info(vae: VAE) -> Dict:
    """Summarize the VAE's latent distribution as per-dimension mean/std arrays.

    The encoder's last Linear layer emits ``2 * vae.z_dim`` values, which
    ``VAE.encode`` splits into mu (first half) and logvar (second half). Without
    access to data, this uses that layer's *bias* as a proxy for the encoder
    output: ``mean = bias[:z_dim]`` and ``std = exp(0.5 * bias[z_dim:])``.
    The proxy is the encoder's output when the preceding hidden activation is
    all zeros.

    Previously "mean" was the whole (2*z_dim,) bias (mu and logvar mixed) and
    "std" was ``exp(0.5 * weight)``, a (2*z_dim, 256) matrix. Neither matched the
    (z_dim,) shape that ``generate_augmented_data`` broadcasts against.

    Returns:
        ``{"truncated": {"mean": ndarray (z_dim,), "std": ndarray (z_dim,)}}``
    """
    bias = vae.encoder[-1].bias.detach().cpu()
    z_dim = vae.z_dim
    mu_bias, logvar_bias = bias[:z_dim], bias[z_dim:]

    distribution_info = {
        "truncated": {
            "mean": mu_bias.numpy(),
            "std": torch.exp(0.5 * logvar_bias).numpy(),
        }
    }

    return distribution_info

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for sending ``distribution_info`` to the global server.

    Not implemented yet: the argument is ignored and the function returns None.
    A real implementation could transmit it over a socket connection or publish
    it to a message queue such as RabbitMQ.
    """
    return None

def generate_augmented_data(vae: VAE, distribution_info_truncated: Dict, num_samples: int = 64, lower: float = 0.0) -> torch.Tensor:
    """Sample latent vectors from a per-dimension truncated normal distribution.

    Dimension ``d`` is drawn from N(mean[d], std[d]**2) truncated to
    ``[lower, +inf)``. Only a truncated normal is used; there is no uniform
    component and no averaging, despite what the earlier comments said.

    Args:
        vae: only ``vae.z_dim`` is read, to set the number of columns.
        distribution_info_truncated: dict with "mean" and "std" entries,
            presumably arrays of length z_dim or broadcastable to it.
        num_samples: number of rows to draw (default 64, the former hard-coded value).
        lower: truncation lower bound in the same units as "mean" (default 0.0).

    Returns:
        float32 tensor of shape ``(num_samples, vae.z_dim)``. These are latent
        codes, not decoded images.
    """
    mean = np.asarray(distribution_info_truncated["mean"], dtype=float)
    std = np.asarray(distribution_info_truncated["std"], dtype=float)

    # scipy's truncnorm takes bounds in standardized units: (bound - loc) / scale
    a = (lower - mean) / std
    b = np.inf
    samples = truncnorm.rvs(a, b, loc=mean, scale=std, size=(num_samples, vae.z_dim))

    return torch.from_numpy(samples).float()

def federated_train(
    net: nn.Module,
    vae: VAE,
    trainloaders: Dict[str, DataLoader],
    trainloader: DataLoader,
    valloader: DataLoader,
    epochs: int,
    vae_epochs: int = 10,
    local_epochs: int = 10,
) -> None:
    """Run ``epochs`` rounds, visiting every client in ``trainloaders`` in order.

    For each client:
      1. Train the shared ``vae`` on the client's data (``vae_epochs`` epochs).
      2. Pass its distribution summary to ``send_distribution_info``.
      3. Fetch a summary via ``receive_distribution_info``.
      4. Sample latent codes from that summary.
      5. Train the shared ``net`` on the client's data (``local_epochs`` epochs),
         validating on ``valloader``.
      6. Pass ``net``'s state_dict to ``send_model_update``.

    NOTE(review):
      * ``net`` and ``vae`` are single objects updated in place by each client in
        turn; there is no aggregation step in this function.
      * ``augmented_data`` is generated but not yet used for classifier training.
      * ``trainloader`` is accepted but currently unused.
    """
    for _round in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Train VAE on client data
            vae_train(vae, client_trainloader, epochs=vae_epochs)

            # Share distribution information with global server
            distribution_info = get_distribution_info(vae)
            send_distribution_info(distribution_info)

            # Receive distribution information from other clients
            other_distribution_info = receive_distribution_info()

            # Generate latent samples from the received distribution.
            # TODO: decode these and mix them into the classifier's training data.
            augmented_data = generate_augmented_data(vae, other_distribution_info["truncated"])

            # Train classification model on local data, validating on valloader
            train(net, client_trainloader, valloader, epochs=local_epochs)

            # Send model updates to global server
            send_model_update(client_id, net.state_dict())

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Placeholder for receiving aggregated distribution info from the server.

    Returns a standard-normal description (mean 0, std 1) for each of ``z_dim``
    latent dimensions under the "truncated" key. ``z_dim`` must match the VAE's
    latent size; the default 20 is the value that used to be hard-coded.

    Returns:
        ``{"truncated": {"mean": zeros(z_dim), "std": ones(z_dim)}}``
    """
    distribution_info = {
        "truncated": {
            "mean": np.zeros(z_dim),
            "std": np.ones(z_dim),
        }
    }
    return distribution_info

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for sending a client's model update to the global server.

    Not implemented yet: both arguments are ignored and the function returns
    None. A real implementation could transmit ``model_update`` (e.g. a
    state_dict) tagged with ``client_id`` over a socket connection or through a
    message queue such as RabbitMQ.
    """
    return None

# Define global server procedure
def global_server() -> None:
    """Build the models, shard the training data, run federated training and report test metrics.

    Reads the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader`` created earlier in this cell.
    """
    net = Net()
    x_dim = 3 * 32 * 32  # CIFAR-10 input size (stored by VAE; its layer sizes are fixed)
    h_dim = 400  # stored by VAE but not used to size any layer
    z_dim = 20
    vae = VAE(x_dim, h_dim, z_dim)  # Initialize VAE object with required arguments

    # Initialize clients
    num_clients = 5  # Define the number of clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    print("True Positives (TP):", tp)
    print("False Positives (FP):", fp)
    print("True Negatives (TN):", tn)
    print("False Negatives (FN):", fn)
    # Sanity check: per class, TP + FP + TN + FN equals the number of test samples
    print("TP + FP + TN + FN (per class):", fn + tn + tp + fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

if __name__ == "__main__":
    global_server()

Files already downloaded and verified
Files already downloaded and verified
Random Images per Class: [5843 5936 6049 6011 6149 5995 6058 6085 5882 5992]
Epoch [1/10], Training Loss: 2.303, Validation Accuracy: 9.09%
Epoch [2/10], Training Loss: 2.302, Validation Accuracy: 8.80%
Epoch [3/10], Training Loss: 2.302, Validation Accuracy: 8.83%
Epoch [4/10], Training Loss: 2.302, Validation Accuracy: 8.95%
Epoch [5/10], Training Loss: 2.301, Validation Accuracy: 9.52%
Epoch [6/10], Training Loss: 2.301, Validation Accuracy: 9.62%
Epoch [7/10], Training Loss: 2.300, Validation Accuracy: 9.65%
Epoch [8/10], Training Loss: 2.300, Validation Accuracy: 9.66%
Epoch [9/10], Training Loss: 2.299, Validation Accuracy: 9.66%
Epoch [10/10], Training Loss: 2.298, Validation Accuracy: 9.66%
Epoch [1/10], Training Loss: 2.299, Validation Accuracy: 9.66%
Epoch [2/10], Training Loss: 2.298, Validation Accuracy: 9.72%
Epoch [3/10], Training Loss: 2.296, Validation Accuracy: 9.78%
Epoch [4/10], Training Loss

In [None]:
import re

# Your provided text
log = """
Epoch [1/10], Training Loss: 2.303, Validation Accuracy: 9.09%
Epoch [2/10], Training Loss: 2.302, Validation Accuracy: 8.80%
Epoch [3/10], Training Loss: 2.302, Validation Accuracy: 8.83%
Epoch [4/10], Training Loss: 2.302, Validation Accuracy: 8.95%
Epoch [5/10], Training Loss: 2.301, Validation Accuracy: 9.52%
Epoch [6/10], Training Loss: 2.301, Validation Accuracy: 9.62%
Epoch [7/10], Training Loss: 2.300, Validation Accuracy: 9.65%
Epoch [8/10], Training Loss: 2.300, Validation Accuracy: 9.66%
Epoch [9/10], Training Loss: 2.299, Validation Accuracy: 9.66%
Epoch [10/10], Training Loss: 2.298, Validation Accuracy: 9.66%
Epoch [1/10], Training Loss: 2.299, Validation Accuracy: 9.66%
Epoch [2/10], Training Loss: 2.298, Validation Accuracy: 9.72%
Epoch [3/10], Training Loss: 2.296, Validation Accuracy: 9.78%
Epoch [4/10], Training Loss: 2.294, Validation Accuracy: 10.35%
Epoch [5/10], Training Loss: 2.292, Validation Accuracy: 11.38%
Epoch [6/10], Training Loss: 2.288, Validation Accuracy: 13.31%
Epoch [7/10], Training Loss: 2.283, Validation Accuracy: 14.51%
Epoch [8/10], Training Loss: 2.274, Validation Accuracy: 15.77%
Epoch [9/10], Training Loss: 2.260, Validation Accuracy: 16.55%
Epoch [10/10], Training Loss: 2.240, Validation Accuracy: 16.95%
Epoch [1/10], Training Loss: 2.213, Validation Accuracy: 17.66%
Epoch [2/10], Training Loss: 2.179, Validation Accuracy: 21.05%
Epoch [3/10], Training Loss: 2.141, Validation Accuracy: 23.12%
Epoch [4/10], Training Loss: 2.099, Validation Accuracy: 24.57%
Epoch [5/10], Training Loss: 2.058, Validation Accuracy: 26.33%
Epoch [6/10], Training Loss: 2.021, Validation Accuracy: 27.11%
Epoch [7/10], Training Loss: 1.989, Validation Accuracy: 28.22%
Epoch [8/10], Training Loss: 1.963, Validation Accuracy: 28.16%
Epoch [9/10], Training Loss: 1.941, Validation Accuracy: 29.41%
Epoch [10/10], Training Loss: 1.919, Validation Accuracy: 30.06%
Epoch [1/10], Training Loss: 1.902, Validation Accuracy: 30.18%
Epoch [2/10], Training Loss: 1.881, Validation Accuracy: 31.11%
Epoch [3/10], Training Loss: 1.864, Validation Accuracy: 31.47%
Epoch [4/10], Training Loss: 1.851, Validation Accuracy: 31.89%
Epoch [5/10], Training Loss: 1.836, Validation Accuracy: 32.67%
Epoch [6/10], Training Loss: 1.820, Validation Accuracy: 32.83%
Epoch [7/10], Training Loss: 1.806, Validation Accuracy: 33.57%
Epoch [8/10], Training Loss: 1.793, Validation Accuracy: 34.17%
Epoch [9/10], Training Loss: 1.779, Validation Accuracy: 34.26%
Epoch [10/10], Training Loss: 1.769, Validation Accuracy: 35.14%
Epoch [1/10], Training Loss: 1.769, Validation Accuracy: 35.42%
Epoch [2/10], Training Loss: 1.752, Validation Accuracy: 36.06%
Epoch [3/10], Training Loss: 1.735, Validation Accuracy: 36.02%
Epoch [4/10], Training Loss: 1.725, Validation Accuracy: 36.89%
Epoch [5/10], Training Loss: 1.708, Validation Accuracy: 37.26%
Epoch [6/10], Training Loss: 1.702, Validation Accuracy: 37.56%
Epoch [7/10], Training Loss: 1.686, Validation Accuracy: 38.08%
Epoch [8/10], Training Loss: 1.671, Validation Accuracy: 38.22%
Epoch [9/10], Training Loss: 1.663, Validation Accuracy: 38.43%
Epoch [10/10], Training Loss: 1.651, Validation Accuracy: 39.11%
Epoch [1/10], Training Loss: 1.651, Validation Accuracy: 39.45%
Epoch [2/10], Training Loss: 1.633, Validation Accuracy: 40.51%
Epoch [3/10], Training Loss: 1.623, Validation Accuracy: 39.55%
Epoch [4/10], Training Loss: 1.607, Validation Accuracy: 40.74%
Epoch [5/10], Training Loss: 1.593, Validation Accuracy: 41.25%
Epoch [6/10], Training Loss: 1.583, Validation Accuracy: 41.52%
Epoch [7/10], Training Loss: 1.575, Validation Accuracy: 40.70%
Epoch [8/10], Training Loss: 1.559, Validation Accuracy: 42.09%
Epoch [9/10], Training Loss: 1.552, Validation Accuracy: 42.08%
Epoch [10/10], Training Loss: 1.540, Validation Accuracy: 42.14%
Epoch [1/10], Training Loss: 1.585, Validation Accuracy: 43.02%
Epoch [2/10], Training Loss: 1.570, Validation Accuracy: 42.68%
Epoch [3/10], Training Loss: 1.554, Validation Accuracy: 43.28%
Epoch [4/10], Training Loss: 1.550, Validation Accuracy: 43.59%
Epoch [5/10], Training Loss: 1.533, Validation Accuracy: 43.75%
Epoch [6/10], Training Loss: 1.524, Validation Accuracy: 44.03%
Epoch [7/10], Training Loss: 1.516, Validation Accuracy: 43.75%
Epoch [8/10], Training Loss: 1.503, Validation Accuracy: 44.79%
Epoch [9/10], Training Loss: 1.494, Validation Accuracy: 44.87%
Epoch [10/10], Training Loss: 1.490, Validation Accuracy: 44.57%
Epoch [1/10], Training Loss: 1.514, Validation Accuracy: 45.16%
Epoch [2/10], Training Loss: 1.495, Validation Accuracy: 45.63%
Epoch [3/10], Training Loss: 1.482, Validation Accuracy: 45.46%
Epoch [4/10], Training Loss: 1.470, Validation Accuracy: 45.66%
Epoch [5/10], Training Loss: 1.462, Validation Accuracy: 46.49%
Epoch [6/10], Training Loss: 1.457, Validation Accuracy: 46.26%
Epoch [7/10], Training Loss: 1.445, Validation Accuracy: 45.98%
Epoch [8/10], Training Loss: 1.437, Validation Accuracy: 46.30%
Epoch [9/10], Training Loss: 1.434, Validation Accuracy: 46.20%
Epoch [10/10], Training Loss: 1.417, Validation Accuracy: 46.86%
Epoch [1/10], Training Loss: 1.460, Validation Accuracy: 46.61%
Epoch [2/10], Training Loss: 1.449, Validation Accuracy: 46.93%
Epoch [3/10], Training Loss: 1.430, Validation Accuracy: 47.17%
Epoch [4/10], Training Loss: 1.420, Validation Accuracy: 47.88%
Epoch [5/10], Training Loss: 1.416, Validation Accuracy: 47.29%
Epoch [6/10], Training Loss: 1.403, Validation Accuracy: 47.72%
Epoch [7/10], Training Loss: 1.395, Validation Accuracy: 48.24%
Epoch [8/10], Training Loss: 1.383, Validation Accuracy: 48.99%
Epoch [9/10], Training Loss: 1.377, Validation Accuracy: 48.79%
Epoch [10/10], Training Loss: 1.367, Validation Accuracy: 48.20%
Epoch [1/10], Training Loss: 1.410, Validation Accuracy: 48.34%
Epoch [2/10], Training Loss: 1.389, Validation Accuracy: 49.19%
Epoch [3/10], Training Loss: 1.375, Validation Accuracy: 49.54%
Epoch [4/10], Training Loss: 1.363, Validation Accuracy: 49.72%
Epoch [5/10], Training Loss: 1.353, Validation Accuracy: 49.56%
Epoch [6/10], Training Loss: 1.342, Validation Accuracy: 49.58%
Epoch [7/10], Training Loss: 1.337, Validation Accuracy: 49.54%
Epoch [8/10], Training Loss: 1.320, Validation Accuracy: 49.97%
Epoch [9/10], Training Loss: 1.315, Validation Accuracy: 50.01%
Epoch [10/10], Training Loss: 1.308, Validation Accuracy: 50.39%
Epoch [1/10], Training Loss: 1.351, Validation Accuracy: 49.71%
Epoch [2/10], Training Loss: 1.335, Validation Accuracy: 50.78%
Epoch [3/10], Training Loss: 1.318, Validation Accuracy: 50.10%
Epoch [4/10], Training Loss: 1.308, Validation Accuracy: 50.62%
Epoch [5/10], Training Loss: 1.296, Validation Accuracy: 49.97%
Epoch [6/10], Training Loss: 1.289, Validation Accuracy: 51.72%
Epoch [7/10], Training Loss: 1.268, Validation Accuracy: 51.33%
Epoch [8/10], Training Loss: 1.259, Validation Accuracy: 51.54%
Epoch [9/10], Training Loss: 1.257, Validation Accuracy: 51.24%
Epoch [10/10], Training Loss: 1.241, Validation Accuracy: 50.79%
Epoch [1/10], Training Loss: 1.341, Validation Accuracy: 51.82%
Epoch [2/10], Training Loss: 1.322, Validation Accuracy: 52.07%
Epoch [3/10], Training Loss: 1.302, Validation Accuracy: 52.53%
Epoch [4/10], Training Loss: 1.289, Validation Accuracy: 52.64%
Epoch [5/10], Training Loss: 1.278, Validation Accuracy: 52.13%
Epoch [6/10], Training Loss: 1.266, Validation Accuracy: 52.96%
Epoch [7/10], Training Loss: 1.268, Validation Accuracy: 52.76%
Epoch [8/10], Training Loss: 1.246, Validation Accuracy: 52.55%
Epoch [9/10], Training Loss: 1.242, Validation Accuracy: 52.14%
Epoch [10/10], Training Loss: 1.230, Validation Accuracy: 52.86%
Epoch [1/10], Training Loss: 1.288, Validation Accuracy: 51.95%
Epoch [2/10], Training Loss: 1.275, Validation Accuracy: 53.17%
Epoch [3/10], Training Loss: 1.251, Validation Accuracy: 53.91%
Epoch [4/10], Training Loss: 1.231, Validation Accuracy: 53.38%
Epoch [5/10], Training Loss: 1.228, Validation Accuracy: 53.74%
Epoch [6/10], Training Loss: 1.207, Validation Accuracy: 53.32%
Epoch [7/10], Training Loss: 1.201, Validation Accuracy: 52.78%
Epoch [8/10], Training Loss: 1.202, Validation Accuracy: 54.22%
Epoch [9/10], Training Loss: 1.191, Validation Accuracy: 53.65%
Epoch [10/10], Training Loss: 1.174, Validation Accuracy: 53.61%
Epoch [1/10], Training Loss: 1.274, Validation Accuracy: 52.64%
Epoch [2/10], Training Loss: 1.239, Validation Accuracy: 53.96%
Epoch [3/10], Training Loss: 1.230, Validation Accuracy: 53.72%
Epoch [4/10], Training Loss: 1.220, Validation Accuracy: 53.51%
Epoch [5/10], Training Loss: 1.207, Validation Accuracy: 54.83%
Epoch [6/10], Training Loss: 1.194, Validation Accuracy: 54.70%
Epoch [7/10], Training Loss: 1.185, Validation Accuracy: 54.61%
Epoch [8/10], Training Loss: 1.181, Validation Accuracy: 54.85%
Epoch [9/10], Training Loss: 1.166, Validation Accuracy: 54.05%
Epoch [10/10], Training Loss: 1.156, Validation Accuracy: 54.72%
Epoch [1/10], Training Loss: 1.222, Validation Accuracy: 54.60%
Epoch [2/10], Training Loss: 1.205, Validation Accuracy: 54.76%
Epoch [3/10], Training Loss: 1.179, Validation Accuracy: 55.29%
Epoch [4/10], Training Loss: 1.159, Validation Accuracy: 55.20%
Epoch [5/10], Training Loss: 1.155, Validation Accuracy: 55.93%
Epoch [6/10], Training Loss: 1.146, Validation Accuracy: 54.46%
Epoch [7/10], Training Loss: 1.132, Validation Accuracy: 54.13%
Epoch [8/10], Training Loss: 1.115, Validation Accuracy: 55.33%
Epoch [9/10], Training Loss: 1.115, Validation Accuracy: 55.26%
Epoch [10/10], Training Loss: 1.110, Validation Accuracy: 55.66%
Epoch [1/10], Training Loss: 1.188, Validation Accuracy: 55.33%
Epoch [2/10], Training Loss: 1.165, Validation Accuracy: 54.90%
Epoch [3/10], Training Loss: 1.145, Validation Accuracy: 55.48%
Epoch [4/10], Training Loss: 1.134, Validation Accuracy: 55.43%
Epoch [5/10], Training Loss: 1.119, Validation Accuracy: 56.41%
Epoch [6/10], Training Loss: 1.107, Validation Accuracy: 55.60%
Epoch [7/10], Training Loss: 1.091, Validation Accuracy: 56.08%
Epoch [8/10], Training Loss: 1.081, Validation Accuracy: 55.76%
Epoch [9/10], Training Loss: 1.072, Validation Accuracy: 56.72%
Epoch [10/10], Training Loss: 1.062, Validation Accuracy: 55.63%
Epoch [1/10], Training Loss: 1.202, Validation Accuracy: 56.35%
Epoch [2/10], Training Loss: 1.176, Validation Accuracy: 56.21%
Epoch [3/10], Training Loss: 1.159, Validation Accuracy: 56.31%
Epoch [4/10], Training Loss: 1.141, Validation Accuracy: 55.99%
Epoch [5/10], Training Loss: 1.133, Validation Accuracy: 56.52%
Epoch [6/10], Training Loss: 1.120, Validation Accuracy: 55.86%
Epoch [7/10], Training Loss: 1.102, Validation Accuracy: 56.58%
Epoch [8/10], Training Loss: 1.092, Validation Accuracy: 56.63%
Epoch [9/10], Training Loss: 1.082, Validation Accuracy: 56.68%
Epoch [10/10], Training Loss: 1.071, Validation Accuracy: 56.09%
Epoch [1/10], Training Loss: 1.168, Validation Accuracy: 57.04%
Epoch [2/10], Training Loss: 1.134, Validation Accuracy: 56.76%
Epoch [3/10], Training Loss: 1.117, Validation Accuracy: 57.11%
Epoch [4/10], Training Loss: 1.098, Validation Accuracy: 57.17%
Epoch [5/10], Training Loss: 1.081, Validation Accuracy: 57.30%
Epoch [6/10], Training Loss: 1.066, Validation Accuracy: 56.35%
Epoch [7/10], Training Loss: 1.057, Validation Accuracy: 57.42%
Epoch [8/10], Training Loss: 1.041, Validation Accuracy: 57.16%
Epoch [9/10], Training Loss: 1.029, Validation Accuracy: 57.09%
Epoch [10/10], Training Loss: 1.021, Validation Accuracy: 57.35%
Epoch [1/10], Training Loss: 1.147, Validation Accuracy: 57.35%
Epoch [2/10], Training Loss: 1.122, Validation Accuracy: 57.02%
Epoch [3/10], Training Loss: 1.103, Validation Accuracy: 57.45%
Epoch [4/10], Training Loss: 1.091, Validation Accuracy: 55.62%
Epoch [5/10], Training Loss: 1.077, Validation Accuracy: 56.74%
Epoch [6/10], Training Loss: 1.053, Validation Accuracy: 58.15%
Epoch [7/10], Training Loss: 1.044, Validation Accuracy: 57.47%
Epoch [8/10], Training Loss: 1.037, Validation Accuracy: 57.08%
Epoch [9/10], Training Loss: 1.026, Validation Accuracy: 56.93%
Epoch [10/10], Training Loss: 1.011, Validation Accuracy: 57.03%
Epoch [1/10], Training Loss: 1.114, Validation Accuracy: 57.07%
Epoch [2/10], Training Loss: 1.082, Validation Accuracy: 57.69%
Epoch [3/10], Training Loss: 1.063, Validation Accuracy: 58.45%
Epoch [4/10], Training Loss: 1.036, Validation Accuracy: 58.10%
Epoch [5/10], Training Loss: 1.020, Validation Accuracy: 57.61%
Epoch [6/10], Training Loss: 1.007, Validation Accuracy: 58.01%
Epoch [7/10], Training Loss: 0.996, Validation Accuracy: 58.58%
Epoch [8/10], Training Loss: 0.981, Validation Accuracy: 57.86%
Epoch [9/10], Training Loss: 0.972, Validation Accuracy: 57.86%
Epoch [10/10], Training Loss: 0.962, Validation Accuracy: 57.49%
Epoch [1/10], Training Loss: 1.090, Validation Accuracy: 57.98%
Epoch [2/10], Training Loss: 1.046, Validation Accuracy: 57.70%
Epoch [3/10], Training Loss: 1.030, Validation Accuracy: 58.55%
Epoch [4/10], Training Loss: 1.011, Validation Accuracy: 58.33%
Epoch [5/10], Training Loss: 0.994, Validation Accuracy: 58.11%
Epoch [6/10], Training Loss: 0.984, Validation Accuracy: 58.35%
Epoch [7/10], Training Loss: 0.966, Validation Accuracy: 58.51%
Epoch [8/10], Training Loss: 0.950, Validation Accuracy: 58.20%
Epoch [9/10], Training Loss: 0.941, Validation Accuracy: 57.10%
Epoch [10/10], Training Loss: 0.924, Validation Accuracy: 58.33%
Epoch [1/10], Training Loss: 1.114, Validation Accuracy: 57.85%
Epoch [2/10], Training Loss: 1.072, Validation Accuracy: 58.97%
Epoch [3/10], Training Loss: 1.055, Validation Accuracy: 58.44%
Epoch [4/10], Training Loss: 1.031, Validation Accuracy: 59.06%
Epoch [5/10], Training Loss: 1.010, Validation Accuracy: 58.88%
Epoch [6/10], Training Loss: 0.996, Validation Accuracy: 59.02%
Epoch [7/10], Training Loss: 0.984, Validation Accuracy: 59.29%
Epoch [8/10], Training Loss: 0.963, Validation Accuracy: 58.66%
Epoch [9/10], Training Loss: 0.959, Validation Accuracy: 58.55%
Epoch [10/10], Training Loss: 0.945, Validation Accuracy: 58.77%
Epoch [1/10], Training Loss: 1.081, Validation Accuracy: 59.16%
Epoch [2/10], Training Loss: 1.037, Validation Accuracy: 58.63%
Epoch [3/10], Training Loss: 1.021, Validation Accuracy: 58.62%
Epoch [4/10], Training Loss: 0.990, Validation Accuracy: 59.30%
Epoch [5/10], Training Loss: 0.976, Validation Accuracy: 58.45%
Epoch [6/10], Training Loss: 0.968, Validation Accuracy: 58.58%
Epoch [7/10], Training Loss: 0.943, Validation Accuracy: 58.68%
Epoch [8/10], Training Loss: 0.936, Validation Accuracy: 58.74%
Epoch [9/10], Training Loss: 0.920, Validation Accuracy: 58.46%
Epoch [10/10], Training Loss: 0.907, Validation Accuracy: 58.81%
Epoch [1/10], Training Loss: 1.076, Validation Accuracy: 57.98%
Epoch [2/10], Training Loss: 1.037, Validation Accuracy: 59.57%
Epoch [3/10], Training Loss: 1.004, Validation Accuracy: 58.96%
Epoch [4/10], Training Loss: 0.980, Validation Accuracy: 58.89%
Epoch [5/10], Training Loss: 0.972, Validation Accuracy: 59.78%
Epoch [6/10], Training Loss: 0.947, Validation Accuracy: 58.99%
Epoch [7/10], Training Loss: 0.940, Validation Accuracy: 59.33%
Epoch [8/10], Training Loss: 0.919, Validation Accuracy: 59.39%
Epoch [9/10], Training Loss: 0.918, Validation Accuracy: 58.27%
Epoch [10/10], Training Loss: 0.896, Validation Accuracy: 58.94%
Epoch [1/10], Training Loss: 1.041, Validation Accuracy: 59.67%
Epoch [2/10], Training Loss: 0.988, Validation Accuracy: 59.57%
Epoch [3/10], Training Loss: 0.957, Validation Accuracy: 60.14%
Epoch [4/10], Training Loss: 0.933, Validation Accuracy: 58.64%
Epoch [5/10], Training Loss: 0.916, Validation Accuracy: 60.54%
Epoch [6/10], Training Loss: 0.893, Validation Accuracy: 59.96%
Epoch [7/10], Training Loss: 0.881, Validation Accuracy: 59.52%
Epoch [8/10], Training Loss: 0.863, Validation Accuracy: 60.14%
Epoch [9/10], Training Loss: 0.853, Validation Accuracy: 59.33%
Epoch [10/10], Training Loss: 0.851, Validation Accuracy: 59.42%
Epoch [1/10], Training Loss: 1.003, Validation Accuracy: 59.61%
Epoch [2/10], Training Loss: 0.949, Validation Accuracy: 59.89%
Epoch [3/10], Training Loss: 0.926, Validation Accuracy: 59.43%
Epoch [4/10], Training Loss: 0.907, Validation Accuracy: 59.39%
Epoch [5/10], Training Loss: 0.900, Validation Accuracy: 59.64%
Epoch [6/10], Training Loss: 0.868, Validation Accuracy: 60.05%
Epoch [7/10], Training Loss: 0.849, Validation Accuracy: 59.52%
Epoch [8/10], Training Loss: 0.841, Validation Accuracy: 59.64%
Epoch [9/10], Training Loss: 0.823, Validation Accuracy: 59.94%
Epoch [10/10], Training Loss: 0.816, Validation Accuracy: 59.76%
Epoch [1/10], Training Loss: 1.045, Validation Accuracy: 58.89%
Epoch [2/10], Training Loss: 0.991, Validation Accuracy: 60.21%
Epoch [3/10], Training Loss: 0.955, Validation Accuracy: 59.86%
Epoch [4/10], Training Loss: 0.936, Validation Accuracy: 60.37%
Epoch [5/10], Training Loss: 0.911, Validation Accuracy: 60.57%
Epoch [6/10], Training Loss: 0.896, Validation Accuracy: 60.36%
Epoch [7/10], Training Loss: 0.882, Validation Accuracy: 59.95%
Epoch [8/10], Training Loss: 0.863, Validation Accuracy: 59.81%
Epoch [9/10], Training Loss: 0.846, Validation Accuracy: 59.88%
Epoch [10/10], Training Loss: 0.831, Validation Accuracy: 59.40%
Epoch [1/10], Training Loss: 1.021, Validation Accuracy: 59.98%
Epoch [2/10], Training Loss: 0.957, Validation Accuracy: 59.12%
Epoch [3/10], Training Loss: 0.927, Validation Accuracy: 60.83%
Epoch [4/10], Training Loss: 0.902, Validation Accuracy: 60.91%
Epoch [5/10], Training Loss: 0.878, Validation Accuracy: 60.02%
Epoch [6/10], Training Loss: 0.861, Validation Accuracy: 60.58%
Epoch [7/10], Training Loss: 0.846, Validation Accuracy: 60.20%
Epoch [8/10], Training Loss: 0.838, Validation Accuracy: 59.35%
Epoch [9/10], Training Loss: 0.820, Validation Accuracy: 60.18%
Epoch [10/10], Training Loss: 0.803, Validation Accuracy: 60.27%
Epoch [1/10], Training Loss: 1.013, Validation Accuracy: 60.82%
Epoch [2/10], Training Loss: 0.952, Validation Accuracy: 60.33%
Epoch [3/10], Training Loss: 0.911, Validation Accuracy: 60.38%
Epoch [4/10], Training Loss: 0.892, Validation Accuracy: 59.98%
Epoch [5/10], Training Loss: 0.873, Validation Accuracy: 60.02%
Epoch [6/10], Training Loss: 0.853, Validation Accuracy: 60.48%
Epoch [7/10], Training Loss: 0.833, Validation Accuracy: 60.22%
Epoch [8/10], Training Loss: 0.812, Validation Accuracy: 60.46%
Epoch [9/10], Training Loss: 0.801, Validation Accuracy: 60.47%
Epoch [10/10], Training Loss: 0.787, Validation Accuracy: 60.16%
Epoch [1/10], Training Loss: 0.958, Validation Accuracy: 60.16%
Epoch [2/10], Training Loss: 0.909, Validation Accuracy: 60.82%
Epoch [3/10], Training Loss: 0.870, Validation Accuracy: 60.16%
Epoch [4/10], Training Loss: 0.845, Validation Accuracy: 60.75%
Epoch [5/10], Training Loss: 0.810, Validation Accuracy: 60.38%
Epoch [6/10], Training Loss: 0.808, Validation Accuracy: 61.06%
Epoch [7/10], Training Loss: 0.772, Validation Accuracy: 60.95%
Epoch [8/10], Training Loss: 0.761, Validation Accuracy: 60.30%
Epoch [9/10], Training Loss: 0.743, Validation Accuracy: 60.92%
Epoch [10/10], Training Loss: 0.726, Validation Accuracy: 60.45%
Epoch [1/10], Training Loss: 0.935, Validation Accuracy: 60.20%
Epoch [2/10], Training Loss: 0.875, Validation Accuracy: 61.38%
Epoch [3/10], Training Loss: 0.839, Validation Accuracy: 61.17%
Epoch [4/10], Training Loss: 0.814, Validation Accuracy: 60.25%
Epoch [5/10], Training Loss: 0.787, Validation Accuracy: 61.42%
Epoch [6/10], Training Loss: 0.769, Validation Accuracy: 60.80%
Epoch [7/10], Training Loss: 0.753, Validation Accuracy: 60.75%
Epoch [8/10], Training Loss: 0.736, Validation Accuracy: 60.38%
Epoch [9/10], Training Loss: 0.727, Validation Accuracy: 60.67%
Epoch [10/10], Training Loss: 0.709, Validation Accuracy: 60.59%
Epoch [1/10], Training Loss: 0.986, Validation Accuracy: 61.07%
Epoch [2/10], Training Loss: 0.911, Validation Accuracy: 61.22%
Epoch [3/10], Training Loss: 0.882, Validation Accuracy: 60.83%
Epoch [4/10], Training Loss: 0.840, Validation Accuracy: 60.81%
Epoch [5/10], Training Loss: 0.828, Validation Accuracy: 61.11%
Epoch [6/10], Training Loss: 0.802, Validation Accuracy: 60.86%
Epoch [7/10], Training Loss: 0.784, Validation Accuracy: 60.54%
Epoch [8/10], Training Loss: 0.764, Validation Accuracy: 60.48%
Epoch [9/10], Training Loss: 0.745, Validation Accuracy: 60.97%
Epoch [10/10], Training Loss: 0.729, Validation Accuracy: 61.24%
Epoch [1/10], Training Loss: 0.964, Validation Accuracy: 60.95%
Epoch [2/10], Training Loss: 0.887, Validation Accuracy: 61.15%
Epoch [3/10], Training Loss: 0.850, Validation Accuracy: 61.34%
Epoch [4/10], Training Loss: 0.826, Validation Accuracy: 60.99%
Epoch [5/10], Training Loss: 0.811, Validation Accuracy: 61.79%
Epoch [6/10], Training Loss: 0.782, Validation Accuracy: 59.86%
Epoch [7/10], Training Loss: 0.751, Validation Accuracy: 61.18%
Epoch [8/10], Training Loss: 0.737, Validation Accuracy: 61.49%
Epoch [9/10], Training Loss: 0.719, Validation Accuracy: 61.40%
Epoch [10/10], Training Loss: 0.702, Validation Accuracy: 60.64%
Epoch [1/10], Training Loss: 0.958, Validation Accuracy: 61.00%
Epoch [2/10], Training Loss: 0.882, Validation Accuracy: 60.78%
Epoch [3/10], Training Loss: 0.841, Validation Accuracy: 61.46%
Epoch [4/10], Training Loss: 0.803, Validation Accuracy: 61.31%
Epoch [5/10], Training Loss: 0.777, Validation Accuracy: 60.74%
Epoch [6/10], Training Loss: 0.763, Validation Accuracy: 60.59%
Epoch [7/10], Training Loss: 0.742, Validation Accuracy: 61.30%
Epoch [8/10], Training Loss: 0.725, Validation Accuracy: 60.83%
Epoch [9/10], Training Loss: 0.695, Validation Accuracy: 60.70%
Epoch [10/10], Training Loss: 0.695, Validation Accuracy: 60.51%
Epoch [1/10], Training Loss: 0.902, Validation Accuracy: 61.67%
Epoch [2/10], Training Loss: 0.833, Validation Accuracy: 60.83%
Epoch [3/10], Training Loss: 0.793, Validation Accuracy: 60.96%
Epoch [4/10], Training Loss: 0.754, Validation Accuracy: 61.89%
Epoch [5/10], Training Loss: 0.730, Validation Accuracy: 61.14%
Epoch [6/10], Training Loss: 0.701, Validation Accuracy: 61.02%
Epoch [7/10], Training Loss: 0.677, Validation Accuracy: 60.77%
Epoch [8/10], Training Loss: 0.682, Validation Accuracy: 61.25%
Epoch [9/10], Training Loss: 0.640, Validation Accuracy: 61.27%
Epoch [10/10], Training Loss: 0.625, Validation Accuracy: 60.71%
Epoch [1/10], Training Loss: 0.880, Validation Accuracy: 61.26%
Epoch [2/10], Training Loss: 0.803, Validation Accuracy: 60.06%
Epoch [3/10], Training Loss: 0.767, Validation Accuracy: 61.66%
Epoch [4/10], Training Loss: 0.734, Validation Accuracy: 61.29%
Epoch [5/10], Training Loss: 0.708, Validation Accuracy: 61.63%
Epoch [6/10], Training Loss: 0.679, Validation Accuracy: 61.06%
Epoch [7/10], Training Loss: 0.662, Validation Accuracy: 60.89%
Epoch [8/10], Training Loss: 0.649, Validation Accuracy: 60.80%
Epoch [9/10], Training Loss: 0.632, Validation Accuracy: 60.95%
Epoch [10/10], Training Loss: 0.610, Validation Accuracy: 60.81%
Epoch [1/10], Training Loss: 0.930, Validation Accuracy: 61.21%
Epoch [2/10], Training Loss: 0.848, Validation Accuracy: 61.15%
Epoch [3/10], Training Loss: 0.788, Validation Accuracy: 61.61%
Epoch [4/10], Training Loss: 0.755, Validation Accuracy: 61.58%
Epoch [5/10], Training Loss: 0.736, Validation Accuracy: 61.45%
Epoch [6/10], Training Loss: 0.709, Validation Accuracy: 61.61%
Epoch [7/10], Training Loss: 0.694, Validation Accuracy: 60.51%
Epoch [8/10], Training Loss: 0.681, Validation Accuracy: 60.56%
Epoch [9/10], Training Loss: 0.658, Validation Accuracy: 61.58%
Epoch [10/10], Training Loss: 0.635, Validation Accuracy: 61.31%
Epoch [1/10], Training Loss: 0.904, Validation Accuracy: 60.82%
Epoch [2/10], Training Loss: 0.821, Validation Accuracy: 61.04%
Epoch [3/10], Training Loss: 0.763, Validation Accuracy: 60.55%
Epoch [4/10], Training Loss: 0.740, Validation Accuracy: 61.68%
Epoch [5/10], Training Loss: 0.702, Validation Accuracy: 61.80%
Epoch [6/10], Training Loss: 0.683, Validation Accuracy: 61.59%
Epoch [7/10], Training Loss: 0.661, Validation Accuracy: 61.35%
Epoch [8/10], Training Loss: 0.646, Validation Accuracy: 60.90%
Epoch [9/10], Training Loss: 0.620, Validation Accuracy: 60.66%
Epoch [10/10], Training Loss: 0.609, Validation Accuracy: 61.01%
Epoch [1/10], Training Loss: 0.898, Validation Accuracy: 60.93%
Epoch [2/10], Training Loss: 0.817, Validation Accuracy: 61.83%
Epoch [3/10], Training Loss: 0.765, Validation Accuracy: 61.43%
Epoch [4/10], Training Loss: 0.731, Validation Accuracy: 61.78%
Epoch [5/10], Training Loss: 0.698, Validation Accuracy: 61.44%
Epoch [6/10], Training Loss: 0.680, Validation Accuracy: 60.58%
Epoch [7/10], Training Loss: 0.657, Validation Accuracy: 60.64%
Epoch [8/10], Training Loss: 0.625, Validation Accuracy: 61.10%
Epoch [9/10], Training Loss: 0.606, Validation Accuracy: 60.48%
Epoch [10/10], Training Loss: 0.587, Validation Accuracy: 60.75%
Epoch [1/10], Training Loss: 0.856, Validation Accuracy: 60.63%
Epoch [2/10], Training Loss: 0.767, Validation Accuracy: 61.22%
Epoch [3/10], Training Loss: 0.710, Validation Accuracy: 61.46%
Epoch [4/10], Training Loss: 0.681, Validation Accuracy: 61.09%
Epoch [5/10], Training Loss: 0.653, Validation Accuracy: 61.30%
Epoch [6/10], Training Loss: 0.632, Validation Accuracy: 61.46%
Epoch [7/10], Training Loss: 0.585, Validation Accuracy: 61.20%
Epoch [8/10], Training Loss: 0.575, Validation Accuracy: 61.28%
Epoch [9/10], Training Loss: 0.548, Validation Accuracy: 60.94%
Epoch [10/10], Training Loss: 0.534, Validation Accuracy: 61.48%
Epoch [1/10], Training Loss: 0.850, Validation Accuracy: 60.52%
Epoch [2/10], Training Loss: 0.750, Validation Accuracy: 61.53%
Epoch [3/10], Training Loss: 0.694, Validation Accuracy: 61.45%
Epoch [4/10], Training Loss: 0.658, Validation Accuracy: 61.22%
Epoch [5/10], Training Loss: 0.638, Validation Accuracy: 60.74%
Epoch [6/10], Training Loss: 0.602, Validation Accuracy: 61.21%
Epoch [7/10], Training Loss: 0.583, Validation Accuracy: 60.89%
Epoch [8/10], Training Loss: 0.554, Validation Accuracy: 60.75%
Epoch [9/10], Training Loss: 0.546, Validation Accuracy: 61.03%
Epoch [10/10], Training Loss: 0.519, Validation Accuracy: 60.52%
Epoch [1/10], Training Loss: 0.889, Validation Accuracy: 60.63%
Epoch [2/10], Training Loss: 0.778, Validation Accuracy: 60.97%
Epoch [3/10], Training Loss: 0.734, Validation Accuracy: 61.07%
Epoch [4/10], Training Loss: 0.684, Validation Accuracy: 61.42%
Epoch [5/10], Training Loss: 0.647, Validation Accuracy: 61.78%
Epoch [6/10], Training Loss: 0.622, Validation Accuracy: 60.76%
Epoch [7/10], Training Loss: 0.599, Validation Accuracy: 61.21%
Epoch [8/10], Training Loss: 0.580, Validation Accuracy: 60.98%
Epoch [9/10], Training Loss: 0.555, Validation Accuracy: 61.16%
Epoch [10/10], Training Loss: 0.544, Validation Accuracy: 60.90%
Epoch [1/10], Training Loss: 0.858, Validation Accuracy: 60.63%
Epoch [2/10], Training Loss: 0.756, Validation Accuracy: 61.39%
Epoch [3/10], Training Loss: 0.703, Validation Accuracy: 61.05%
Epoch [4/10], Training Loss: 0.652, Validation Accuracy: 61.41%
Epoch [5/10], Training Loss: 0.618, Validation Accuracy: 61.70%
Epoch [6/10], Training Loss: 0.588, Validation Accuracy: 61.47%
Epoch [7/10], Training Loss: 0.568, Validation Accuracy: 61.66%
Epoch [8/10], Training Loss: 0.547, Validation Accuracy: 61.19%
Epoch [9/10], Training Loss: 0.530, Validation Accuracy: 61.06%
Epoch [10/10], Training Loss: 0.509, Validation Accuracy: 61.69%
Epoch [1/10], Training Loss: 0.859, Validation Accuracy: 61.57%
Epoch [2/10], Training Loss: 0.749, Validation Accuracy: 61.34%
Epoch [3/10], Training Loss: 0.695, Validation Accuracy: 61.22%
Epoch [4/10], Training Loss: 0.660, Validation Accuracy: 61.09%
Epoch [5/10], Training Loss: 0.618, Validation Accuracy: 61.11%
Epoch [6/10], Training Loss: 0.587, Validation Accuracy: 61.15%
Epoch [7/10], Training Loss: 0.555, Validation Accuracy: 61.30%
Epoch [8/10], Training Loss: 0.542, Validation Accuracy: 60.64%
Epoch [9/10], Training Loss: 0.530, Validation Accuracy: 60.77%
Epoch [10/10], Training Loss: 0.500, Validation Accuracy: 60.86%
Epoch [1/10], Training Loss: 0.828, Validation Accuracy: 61.12%
Epoch [2/10], Training Loss: 0.712, Validation Accuracy: 61.19%
Epoch [3/10], Training Loss: 0.645, Validation Accuracy: 60.97%
Epoch [4/10], Training Loss: 0.602, Validation Accuracy: 61.42%
Epoch [5/10], Training Loss: 0.555, Validation Accuracy: 61.11%
Epoch [6/10], Training Loss: 0.536, Validation Accuracy: 61.23%
Epoch [7/10], Training Loss: 0.511, Validation Accuracy: 60.96%
Epoch [8/10], Training Loss: 0.497, Validation Accuracy: 60.68%
Epoch [9/10], Training Loss: 0.471, Validation Accuracy: 61.48%
Epoch [10/10], Training Loss: 0.442, Validation Accuracy: 60.90%
Epoch [1/10], Training Loss: 0.797, Validation Accuracy: 60.49%
Epoch [2/10], Training Loss: 0.697, Validation Accuracy: 60.83%
Epoch [3/10], Training Loss: 0.621, Validation Accuracy: 61.30%
Epoch [4/10], Training Loss: 0.582, Validation Accuracy: 61.18%
Epoch [5/10], Training Loss: 0.550, Validation Accuracy: 61.41%
Epoch [6/10], Training Loss: 0.515, Validation Accuracy: 60.60%
Epoch [7/10], Training Loss: 0.495, Validation Accuracy: 61.07%
Epoch [8/10], Training Loss: 0.465, Validation Accuracy: 60.57%
Epoch [9/10], Training Loss: 0.450, Validation Accuracy: 61.08%
Epoch [10/10], Training Loss: 0.429, Validation Accuracy: 60.64%
Epoch [1/10], Training Loss: 0.848, Validation Accuracy: 60.84%
Epoch [2/10], Training Loss: 0.712, Validation Accuracy: 60.76%
Epoch [3/10], Training Loss: 0.658, Validation Accuracy: 61.16%
Epoch [4/10], Training Loss: 0.604, Validation Accuracy: 60.94%
Epoch [5/10], Training Loss: 0.568, Validation Accuracy: 60.87%
Epoch [6/10], Training Loss: 0.553, Validation Accuracy: 60.69%
Epoch [7/10], Training Loss: 0.516, Validation Accuracy: 61.60%
Epoch [8/10], Training Loss: 0.494, Validation Accuracy: 60.91%
Epoch [9/10], Training Loss: 0.481, Validation Accuracy: 60.83%
Epoch [10/10], Training Loss: 0.455, Validation Accuracy: 61.19%
Epoch [1/10], Training Loss: 0.821, Validation Accuracy: 61.22%
Epoch [2/10], Training Loss: 0.682, Validation Accuracy: 61.71%
Epoch [3/10], Training Loss: 0.623, Validation Accuracy: 60.64%
Epoch [4/10], Training Loss: 0.582, Validation Accuracy: 60.81%
Epoch [5/10], Training Loss: 0.545, Validation Accuracy: 61.38%
Epoch [6/10], Training Loss: 0.514, Validation Accuracy: 61.57%
Epoch [7/10], Training Loss: 0.500, Validation Accuracy: 61.52%
Epoch [8/10], Training Loss: 0.472, Validation Accuracy: 61.03%
Epoch [9/10], Training Loss: 0.447, Validation Accuracy: 60.76%
Epoch [10/10], Training Loss: 0.427, Validation Accuracy: 60.42%
Epoch [1/10], Training Loss: 0.817, Validation Accuracy: 60.97%
Epoch [2/10], Training Loss: 0.683, Validation Accuracy: 61.65%
Epoch [3/10], Training Loss: 0.618, Validation Accuracy: 61.12%
Epoch [4/10], Training Loss: 0.573, Validation Accuracy: 61.11%
Epoch [5/10], Training Loss: 0.537, Validation Accuracy: 60.95%
Epoch [6/10], Training Loss: 0.514, Validation Accuracy: 61.18%
Epoch [7/10], Training Loss: 0.474, Validation Accuracy: 60.96%
Epoch [8/10], Training Loss: 0.455, Validation Accuracy: 60.49%
Epoch [9/10], Training Loss: 0.433, Validation Accuracy: 60.64%
Epoch [10/10], Training Loss: 0.411, Validation Accuracy: 60.34%
Epoch [1/10], Training Loss: 0.791, Validation Accuracy: 60.47%
Epoch [2/10], Training Loss: 0.650, Validation Accuracy: 59.43%
Epoch [3/10], Training Loss: 0.584, Validation Accuracy: 60.48%
Epoch [4/10], Training Loss: 0.537, Validation Accuracy: 61.41%
Epoch [5/10], Training Loss: 0.482, Validation Accuracy: 61.16%
Epoch [6/10], Training Loss: 0.475, Validation Accuracy: 60.67%
Epoch [7/10], Training Loss: 0.430, Validation Accuracy: 60.86%
Epoch [8/10], Training Loss: 0.408, Validation Accuracy: 61.11%
Epoch [9/10], Training Loss: 0.385, Validation Accuracy: 60.45%
Epoch [10/10], Training Loss: 0.365, Validation Accuracy: 60.91%
"""

# Pattern that captures the numeric part of every "Validation Accuracy: NN.NN%" entry
accuracy_pattern = r'Validation Accuracy: (\d+\.\d+)%'

# Extract every match from the captured log and convert each one to float in one pass
accuracies = list(map(float, re.findall(accuracy_pattern, log)))

# Report the parsed values and how many epochs were found
print("Accuracies:", accuracies)
print("Size of array:", len(accuracies))


Accuracies: [9.09, 8.8, 8.83, 8.95, 9.52, 9.62, 9.65, 9.66, 9.66, 9.66, 9.66, 9.72, 9.78, 10.35, 11.38, 13.31, 14.51, 15.77, 16.55, 16.95, 17.66, 21.05, 23.12, 24.57, 26.33, 27.11, 28.22, 28.16, 29.41, 30.06, 30.18, 31.11, 31.47, 31.89, 32.67, 32.83, 33.57, 34.17, 34.26, 35.14, 35.42, 36.06, 36.02, 36.89, 37.26, 37.56, 38.08, 38.22, 38.43, 39.11, 39.45, 40.51, 39.55, 40.74, 41.25, 41.52, 40.7, 42.09, 42.08, 42.14, 43.02, 42.68, 43.28, 43.59, 43.75, 44.03, 43.75, 44.79, 44.87, 44.57, 45.16, 45.63, 45.46, 45.66, 46.49, 46.26, 45.98, 46.3, 46.2, 46.86, 46.61, 46.93, 47.17, 47.88, 47.29, 47.72, 48.24, 48.99, 48.79, 48.2, 48.34, 49.19, 49.54, 49.72, 49.56, 49.58, 49.54, 49.97, 50.01, 50.39, 49.71, 50.78, 50.1, 50.62, 49.97, 51.72, 51.33, 51.54, 51.24, 50.79, 51.82, 52.07, 52.53, 52.64, 52.13, 52.96, 52.76, 52.55, 52.14, 52.86, 51.95, 53.17, 53.91, 53.38, 53.74, 53.32, 52.78, 54.22, 53.65, 53.61, 52.64, 53.96, 53.72, 53.51, 54.83, 54.7, 54.61, 54.85, 54.05, 54.72, 54.6, 54.76, 55.29, 55.2, 5

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 images.

    Args:
        x_dim: stored as ``self.x_dim``; not used to size any layer (the conv
            stack below is hard-wired for 3x32x32 inputs: 32 -> 16 -> 8 -> 4).
        h_dim: stored as ``self.h_dim``; not used (the hidden width is fixed at 256).
        z_dim: latent dimensionality. The encoder head emits ``2 * z_dim`` values,
            split by ``encode`` into ``mu`` and ``logvar``.

    The decoder ends in ``Tanh`` so reconstructions lie in (-1, 1). That matches
    inputs normalised with ``Normalize((0.5,)*3, (0.5,)*3)`` (range [-1, 1]). A
    ``Sigmoid`` head (range (0, 1)) could never reproduce the negative half of
    such inputs, leaving a permanent floor on the MSE reconstruction loss.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: 3x32x32 image -> 2*z_dim vector (mu and logvar concatenated)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),  # 2048 = 128 channels * 4 * 4
            nn.ReLU(),
            nn.Linear(256, 2*z_dim)
        )

        # Decoder: z_dim vector -> 3x32x32 image in (-1, 1)
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            # Tanh (not Sigmoid) so the output range matches [-1, 1]-normalised inputs.
            # It has no parameters, so state_dict keys/indices are unchanged.
            nn.Tanh()
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes of shape ``(batch, z_dim)`` to images ``(batch, 3, 32, 32)``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent, and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on the images yielded by ``trainloader`` (labels are ignored).

    Optimises ``vae_loss`` with Adam (lr=1e-3) and halves the learning rate every
    20 epochs via ``StepLR``. ``vae_loss`` is defined later in this cell; it is
    looked up when this function runs, so the definition order is fine.
    """
    opt = torch.optim.Adam(vae.parameters(), lr=1e-3)
    lr_decay = torch.optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for images, _labels in trainloader:
            opt.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar)
            batch_loss.backward()
            opt.step()
        # Scheduler advances once per epoch, after all optimiser steps
        lr_decay.step()


# Define classification model
class Net(nn.Module):
    """LeNet-style CNN mapping 3x32x32 images to 10 class logits."""

    def __init__(self) -> None:
        super().__init__()
        # Feature extractor: conv(5) + 2x2 max-pool twice (32 -> 28 -> 14 -> 10 -> 5)
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Fully connected head: 16*5*5 -> 120 -> 84 -> 10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        relu = nn.functional.relu
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            x = relu(hidden(x))
        # Raw logits; softmax is left to the loss function
        return self.fc3(x)


# Load CIFAR10 dataset
# ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Separate dataset by class.
# Read labels from `full_dataset.targets` rather than iterating the dataset:
# iterating would load and transform every image just to look at its label.
class_indices = {i: [] for i in range(10)}  # CIFAR-10 has 10 classes
for idx, label in enumerate(full_dataset.targets):
    class_indices[label].append(idx)

# Draw a random target count per class (multinomial over 10 equally likely classes).
# NOTE: the CIFAR-10 training split only has 5,000 images per class, so requests
# above that are capped below; the sampled subset holds at most 50,000 images
# even though 60,000 are requested here.
class_counts = np.random.multinomial(60000, [0.1] * 10)  # Adjust probabilities if you want specific class biases
print("Random Images per Class:", class_counts)

# Sample indices based on the specified class counts
indices = []
for class_id, count in enumerate(class_counts):
    # Ensure count does not exceed available images
    count = min(count, len(class_indices[class_id]))
    indices.extend(random.sample(class_indices[class_id], count))

# Create a custom CIFAR-10 dataset with the sampled indices
custom_dataset = Subset(full_dataset, indices)

# Split the *sampled* subset (not `full_dataset`) so that the per-class sampling
# above actually determines the training/validation data.
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

# Create training and validation loaders
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO of a Gaussian-latent VAE, summed over the whole batch.

    Returns ``reconstruction_term + kl_term`` where
      * ``reconstruction_term`` is the summed squared error between ``recon_x``
        and ``x`` (an MSE term, not binary cross-entropy), and
      * ``kl_term`` is the closed-form ``KL(N(mu, exp(logvar)) || N(0, I))``.
    """
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction='sum')
    variance = logvar.exp()
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - variance)
    return reconstruction_term + kl_term

# Define training procedure for classification model
def train(
    net: nn.Module,
    trainloader: DataLoader,
    valloader: DataLoader,
    epochs: int,
    checkpoint_path: str = "best_model.pth",
) -> None:
    """Train ``net`` with SGD and report validation accuracy after every epoch.

    Args:
        net: classifier returning logits of shape ``(batch, num_classes)``.
        trainloader: yields ``(inputs, labels)`` training batches.
        valloader: yields ``(images, labels)`` validation batches.
        epochs: number of passes over ``trainloader``.
        checkpoint_path: where ``net.state_dict()`` is saved whenever validation
            accuracy beats the best seen so far *in this call*.

    Notes:
        A fresh optimizer, StepLR scheduler and ``best_acc`` are created on every
        call. Repeated calls (e.g. one per federated round) therefore each
        overwrite ``checkpoint_path`` with their own best epoch.
        Empty loaders are tolerated: the epoch's training loss is reported as
        ``nan`` and an empty validation set counts as 0% accuracy.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training pass
        net.train()
        running_loss = 0.0
        n_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1

        scheduler.step()

        # Validation pass
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against empty loaders (previously ZeroDivisionError / unbound `i`)
        val_acc = 100 * correct / total if total else 0.0
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), checkpoint_path)

        avg_loss = running_loss / n_batches if n_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score ``net`` on ``testloader`` and derive per-class metrics.

    Prints the confusion matrix as a side effect.

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)``. ``accuracy`` is a
        percentage; every other entry is a per-class array. Classes follow
        sklearn's default ordering (sorted union of true and predicted labels).
    """
    net.eval()
    y_true, y_pred = [], []
    n_correct = n_total = 0
    with torch.no_grad():
        for images, labels in testloader:
            _, preds = torch.max(net(images).data, 1)
            n_total += labels.size(0)
            n_correct += (preds == labels).sum().item()
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    accuracy = 100 * n_correct / n_total

    cm = confusion_matrix(y_true, y_pred)
    print(f"Confusion Matrix:\n{cm}")

    # One-vs-rest counts per class: rows of `cm` are true labels, columns predictions
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # Per-class (average=None) precision, recall and F1, computed in that order
    precision, recall, f1 = (
        metric(y_true, y_pred, average=None)
        for metric in (precision_score, recall_score, f1_score)
    )

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into ``num_clients`` contiguous shards, one DataLoader each.

    Shard ``k`` covers indices ``[k*N // num_clients, (k+1)*N // num_clients)`` with
    ``N = len(trainset)``, so shards are disjoint and together cover the dataset.
    ``transform`` is accepted for interface compatibility but is not used.

    Returns:
        Dict mapping ``"client_<k>"`` to a shuffling DataLoader over shard ``k``.
    """
    def shard(k):
        n_items = len(trainset)
        lo = k * n_items // num_clients
        hi = (k + 1) * n_items // num_clients
        return torch.utils.data.Subset(trainset, range(lo, hi))

    return {
        f"client_{k}": torch.utils.data.DataLoader(shard(k), batch_size=batch_size, shuffle=True)
        for k in range(num_clients)
    }

def get_distribution_info(vae: "VAE") -> Dict:
    """Summarise the VAE's latent distribution as a per-dimension (mean, std) pair.

    How the summary is derived:

    - The encoder's final Linear layer emits ``2 * z_dim`` values. ``VAE.encode``
      splits them into ``mu = h[:, :z_dim]`` and ``logvar = h[:, z_dim:]``.
    - This helper reads only that layer's bias. That is the (mu, logvar) the encoder
      would output if its penultimate activation were all zeros.
    - logvar is converted to a standard deviation with ``std = exp(0.5 * logvar)``,
      the same transform as ``VAE.reparameterize``.

    It is a cheap, data-free proxy. A data-driven summary would instead average
    ``encode`` outputs over the client's data.

    Args:
        vae: Model exposing ``encoder`` (last module a Linear with ``2*z_dim``
            outputs) and ``z_dim``.

    Returns:
        ``{"truncated": {"mean": ndarray (z_dim,), "std": ndarray (z_dim,)}}``.
        This matches the per-dimension layout consumed by
        ``generate_augmented_data``.
    """
    bias = vae.encoder[-1].bias.detach().cpu()
    # Split exactly as VAE.encode does: first half -> mu head, second half -> logvar head.
    mu_bias = bias[: vae.z_dim]
    logvar_bias = bias[vae.z_dim:]
    return {
        "truncated": {
            "mean": mu_bias.numpy(),
            "std": torch.exp(0.5 * logvar_bias).numpy(),
        }
    }

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for publishing this client's distribution summary to the server.

    Not implemented: the call does nothing and returns ``None``. A real
    implementation would serialise ``distribution_info`` and push it to the global
    server. Options include a raw ``socket`` connection or a message queue such as
    RabbitMQ.
    """
    return None

def generate_augmented_data(vae: "VAE", distribution_info_truncated: Dict,
                            num_samples: int = 64, lower: float = 0.0,
                            random_state=None) -> torch.Tensor:
    """Sample latent vectors from a per-dimension normal truncated to ``[lower, inf)``.

    Only a truncated normal is used; there is no uniform component and no
    averaging. The returned tensor holds latent codes, not decoded images. Pass it
    through ``vae.decode`` to obtain images.

    Args:
        vae: Model exposing ``z_dim``; only used for the output width.
        distribution_info_truncated: Dict with ``"mean"`` and ``"std"`` entries. Each
            is a scalar or an array broadcastable to ``(z_dim,)``.
        num_samples: Number of latent vectors to draw (default 64, the previous
            hard-coded value).
        lower: Lower truncation bound in latent units (default 0.0). The upper bound
            is always +inf.
        random_state: Forwarded to ``truncnorm.rvs`` (int seed, Generator, or None
            for numpy's global RNG, the previous behaviour).

    Returns:
        float32 tensor of shape ``(num_samples, vae.z_dim)``.

    Raises:
        ValueError: if any std is not strictly positive. A zero std would otherwise
            produce an infinite standardized bound and NaN samples.
    """
    mean = np.asarray(distribution_info_truncated["mean"], dtype=np.float64)
    std = np.asarray(distribution_info_truncated["std"], dtype=np.float64)
    if np.any(std <= 0):
        raise ValueError("generate_augmented_data(): all std values must be > 0")

    # truncnorm takes its bounds in standardized units: (bound - loc) / scale.
    a = (lower - mean) / std
    samples = truncnorm.rvs(
        a, np.inf, loc=mean, scale=std,
        size=(num_samples, vae.z_dim), random_state=random_state,
    )
    return torch.from_numpy(samples).float()

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int,
                    vae_epochs: int = 10, local_epochs: int = 10) -> None:
    """Run the FedDIS-style loop sequentially over every client for ``epochs`` rounds.

    Per client and round, in order:

    1. Train the VAE on the client's data.
    2. Publish its distribution summary.
    3. Fetch the summary from the server.
    4. Sample augmented latent codes.
    5. Train the classifier locally.
    6. Push the classifier weights to the server.

    Args:
        net: Classifier, shared by all clients.
        vae: VAE, shared by all clients.
        trainloaders: Mapping client id -> that client's training loader.
        trainloader: Full training loader; currently unused (kept for API compatibility).
        valloader: Validation loader handed to ``train`` for per-epoch validation.
        epochs: Number of global rounds.
        vae_epochs: VAE epochs per client per round (default 10, previously hard-coded).
        local_epochs: Classifier epochs per client per round (default 10, previously
            hard-coded).

    Note:
        ``net`` and ``vae`` are single objects, so each client continues training
        from the previous client's weights. ``send_model_update`` is currently a
        no-op, so no server-side aggregation takes place.
    """
    for _round in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Train VAE on client data
            vae_train(vae, client_trainloader, epochs=vae_epochs)

            # Share distribution information with global server
            distribution_info = get_distribution_info(vae)
            send_distribution_info(distribution_info)

            # Receive distribution information from other clients
            other_distribution_info = receive_distribution_info()

            # Generate augmented latent codes from the received distribution.
            # TODO: augmented_data is not consumed below. It would need to be decoded
            # (vae.decode) and labelled before it could be mixed into classifier training.
            augmented_data = generate_augmented_data(vae, other_distribution_info["truncated"])

            # Train the classifier on local client data only; valloader is used by
            # train() for validation accuracy / best-checkpoint saving.
            train(net, client_trainloader, valloader, epochs=local_epochs)

            # Send model updates to global server
            send_model_update(client_id, net.state_dict())

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Placeholder for fetching other clients' distribution summaries from the server.

    Until a real transport exists, this returns standard-normal parameters
    (mean 0, std 1) for every latent dimension.

    Args:
        z_dim: Number of latent dimensions (default 20, the previously hard-coded
            value). It should match the VAE's ``z_dim``.

    Returns:
        ``{"truncated": {"mean": zeros(z_dim), "std": ones(z_dim)}}``. Fresh arrays
        are created on every call.
    """
    return {
        "truncated": {
            "mean": np.zeros(z_dim),
            "std": np.ones(z_dim),
        }
    }

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for uploading a client's model weights to the global server.

    Not implemented: ``client_id`` and ``model_update`` (e.g. a ``state_dict``) are
    ignored and ``None`` is returned. A real version would transmit them over a
    network channel. Options include a ``socket`` connection or a message queue
    such as RabbitMQ.
    """
    return None

# Define global server procedure
def global_server() -> None:
    """Build the classifier and VAE, run federated training, and report test metrics.

    Depends on notebook-level names defined in other cells: ``Net``, ``train_set``,
    ``transform``, ``trainloader``, ``valloader`` and ``testloader``.
    """
    net = Net()
    # The VAE's layer sizes are hard-coded for 3x32x32 inputs (Linear(2048, 256));
    # x_dim and h_dim are stored on the model but do not change its architecture.
    x_dim = 3 * 32 * 32  # CIFAR-10 input size
    h_dim = 400
    z_dim = 20  # should match the latent size assumed by receive_distribution_info()
    vae = VAE(x_dim, h_dim, z_dim)  # Initialize VAE object with required arguments

    # Initialize clients
    num_clients = 5  # Define the number of clients
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    print("True Positives (TP):", tp)
    print("False Positives (FP):", fp)
    print("True Negatives (TN):", tn)
    print("False Negatives (FN):", fn)
    # Sanity check: tn is derived as cm.sum() - (fp + fn + tp), so every entry
    # equals the total number of test samples.
    print("TP+FP+TN+FN per class (= number of test samples):", fn + tn + tp + fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

# Entry point. Inside a Jupyter/IPython kernel __name__ is "__main__", so running
# this cell immediately launches the full download/train/evaluate pipeline.
if __name__ == "__main__":
    global_server()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:02<00:00, 80.0MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Random Images per Class: [5958 6147 6057 5950 6003 6035 5879 5965 6068 5938]
Epoch [1/10], Training Loss: 2.306, Validation Accuracy: 10.56%
Epoch [2/10], Training Loss: 2.305, Validation Accuracy: 10.56%
Epoch [3/10], Training Loss: 2.304, Validation Accuracy: 10.56%
Epoch [4/10], Training Loss: 2.303, Validation Accuracy: 10.56%
Epoch [5/10], Training Loss: 2.302, Validation Accuracy: 10.56%
Epoch [6/10], Training Loss: 2.301, Validation Accuracy: 10.65%
Epoch [7/10], Training Loss: 2.300, Validation Accuracy: 10.97%
Epoch [8/10], Training Loss: 2.299, Validation Accuracy: 12.26%
Epoch [9/10], Training Loss: 2.298, Validation Accuracy: 14.45%
Epoch [10/10], Training Loss: 2.296, Validation Accuracy: 16.50%
Epoch [1/10], Training Loss: 2.295, Validation Accuracy: 16.73%
Epoch [2/10], Training Loss: 2.293, Validation Accuracy: 16.87%
Epoch [3/10], Training Loss: 2.289, Validation Accuracy: 16.48%
E