<a href="https://colab.research.google.com/github/Rahad31/Kl-FedDis-Research-/blob/main/ordinary_truncated.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch import nn
from typing import Dict
from scipy.stats import truncnorm

In [None]:

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 32x32 images.

    The encoder maps an image batch to ``2 * z_dim`` values per sample, which
    ``encode`` splits into a mean and a log-variance of ``z_dim`` values each.
    The decoder maps a latent vector of size ``z_dim`` back to a 3x32x32 image.

    Args:
        x_dim: Stored on the instance; it does not size any layer.
        h_dim: Stored on the instance; it does not size any layer.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: three stride-2 convolutions halve the spatial size each
        # time (32 -> 16 -> 8 -> 4), leaving 128 * 4 * 4 = 2048 features.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),  # [mu | logvar]
        )

        # Decoder: mirror image of the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            # Tanh keeps reconstructions in [-1, 1], the range produced by the
            # Normalize(0.5, 0.5) transform used in this file. A Sigmoid
            # (0..1) could never reproduce the negative half of that range.
            nn.Tanh(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` sampled from a standard normal."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructed images."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar



In [None]:
# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on ``trainloader`` for ``epochs`` passes with Adam (lr=1e-3).

    Each batch is expected to unpack into ``(inputs, labels)``; the labels are
    ignored. The loss for every batch comes from ``vae_loss``.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    for _ in range(epochs):
        for images, _labels in trainloader:
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar)
            batch_loss.backward()
            adam.step()

In [None]:
# Define classification model
class Net(nn.Module):
    """Small convolutional classifier producing 10 logits per 3x32x32 image.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed
    three fully connected layers (400 -> 120 -> 84 -> 10).
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores (no softmax) for the batch ``x``."""
        relu = nn.functional.relu
        features = self.pool(relu(self.conv1(x)))
        features = self.pool(relu(self.conv2(features)))
        # Collapse channels and spatial positions into one feature axis.
        flat = features.view(-1, 16 * 5 * 5)
        hidden = relu(self.fc2(relu(self.fc1(flat))))
        return self.fc3(hidden)

In [None]:
# Load CIFAR10 dataset
# Convert each image to a tensor, then shift and scale every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# The training split is divided into train/validation further down; the
# test split is kept separate for the final evaluation.
full_dataset = CIFAR10(root="./data", train=True, transform=transform, download=True)
test_set = CIFAR10(root="./data", train=False, transform=transform, download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 77306790.11it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Split dataset into training and validation sets
# 80% of the samples go to training; whatever remains becomes validation.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(
    dataset=full_dataset,
    lengths=[train_size, val_size],
)


In [None]:
# Create training and validation loaders
# Only the training loader shuffles; validation and test batches keep a
# fixed order.
trainloader = DataLoader(dataset=train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(dataset=val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=2)


In [None]:
# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective: reconstruction error plus KL term.

    The reconstruction error is the sum of squared differences between
    ``recon_x`` and ``x``. The KL term is
    ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    reconstruction = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction + kl_divergence

In [None]:
# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after every epoch.

    Uses cross-entropy loss and SGD (lr=0.001, momentum=0.9). After each epoch
    the model is scored on ``valloader``. Whenever the accuracy beats the best
    value seen so far in this call, the weights are written to
    ``best_model.pth`` in the working directory. One progress line is printed
    per epoch.

    Args:
        net: Classifier to optimize in place.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Count batches explicitly: relying on the loop index would raise a
        # NameError when the loader yields nothing.
        mean_loss = running_loss / num_batches if num_batches else 0.0

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                _, predicted = torch.max(net(images), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # An empty validation loader reports 0% instead of dividing by zero.
        val_acc = 100 * correct / total if total else 0.0
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {mean_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")


In [None]:
# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> float:
    """Return the accuracy of ``net`` on ``testloader`` as a percentage.

    The model is switched to eval mode and gradients are disabled. The
    prediction for each sample is the index of its largest output score.
    """
    net.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            scores = net(images)
            predicted = torch.max(scores.data, 1)[1]
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
    return 100 * hits / seen

In [None]:
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Split ``trainset`` into contiguous shards, one shuffling loader per client.

    Client ``i`` receives the index range
    ``[i * len(trainset) // num_clients, (i + 1) * len(trainset) // num_clients)``.
    ``transform`` is accepted but not used here.

    Returns:
        Dict mapping ``"client_<i>"`` to that client's ``DataLoader``.
    """
    total = len(trainset)
    clients = {}
    for idx in range(num_clients):
        start = idx * total // num_clients
        stop = (idx + 1) * total // num_clients
        shard = torch.utils.data.Subset(trainset, range(start, stop))
        clients[f"client_{idx}"] = torch.utils.data.DataLoader(
            shard, batch_size=batch_size, shuffle=True
        )
    return clients

In [None]:
def get_distribution_info(vae: VAE) -> Dict:
    """Summarize the VAE's latent distribution as per-dimension mean and std.

    The encoder's last ``Linear`` layer emits ``[mu | logvar]`` (see
    ``VAE.encode``), so its bias is split the same way: the first ``z_dim``
    entries give the mean, and ``exp(0.5 * logvar_bias)`` gives the std. Both
    arrays have shape ``(z_dim,)``, which is what ``generate_augmented_data``
    broadcasts against.

    NOTE(review): this is a data-independent proxy. It is the output of that
    layer for an all-zero hidden activation, not statistics gathered over the
    client's data. Replace it with aggregated encoder outputs if real latent
    statistics are needed.

    Returns:
        ``{"normal": {"mean", "std"}, "truncated": {"mean", "std"}}`` holding
        independent NumPy copies, so later training cannot mutate what was
        shared.
    """
    z_dim = vae.z_dim
    head_bias = vae.encoder[-1].bias.detach().cpu()

    mean = head_bias[:z_dim].numpy().copy()
    std = torch.exp(0.5 * head_bias[z_dim:]).numpy()

    distribution_info = {
        "normal": {"mean": mean.copy(), "std": std.copy()},
        "truncated": {"mean": mean.copy(), "std": std.copy()},
    }

    return distribution_info

In [None]:
def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for shipping distribution information to the global server.

    No transport is implemented: the argument is accepted and nothing happens.
    A real implementation could send the payload over a network protocol
    (for example with the ``socket`` module) or through a message queue such
    as RabbitMQ.
    """
    return None


In [None]:
def generate_augmented_data(vae: VAE, distribution_info_normal: Dict, distribution_info_truncated: Dict, num_samples: int = 64) -> torch.Tensor:
    """Sample latent vectors by averaging normal and truncated-normal draws.

    Args:
        vae: Only ``vae.z_dim`` is read, to size the samples.
        distribution_info_normal: Dict with ``"mean"`` and ``"std"`` entries
            that broadcast against ``(num_samples, z_dim)``.
        distribution_info_truncated: Same layout; these parameters drive a
            normal distribution truncated to values >= 0.
        num_samples: Number of latent vectors to draw (defaults to the
            previously hard-coded 64).

    Returns:
        A float32 tensor of shape ``(num_samples, z_dim)``: the element-wise
        mean of one normal sample and one truncated-normal sample.
    """
    z_dim = vae.z_dim

    # Convert explicitly so the arithmetic below never depends on implicit
    # mixing of torch tensors with NumPy arrays.
    mean_normal = torch.as_tensor(distribution_info_normal["mean"], dtype=torch.float32)
    std_normal = torch.as_tensor(distribution_info_normal["std"], dtype=torch.float32)
    augmented_data_normal = torch.randn(num_samples, z_dim) * std_normal + mean_normal

    mean_truncated = np.asarray(distribution_info_truncated["mean"], dtype=np.float64)
    std_truncated = np.asarray(distribution_info_truncated["std"], dtype=np.float64)

    # truncnorm takes its bounds in standard-deviation units relative to
    # loc/scale, so a lower bound of 0 in data space becomes (0 - mean) / std.
    a = (0 - mean_truncated) / std_truncated
    b = np.inf
    truncated_samples = truncnorm.rvs(
        a, b, loc=mean_truncated, scale=std_truncated, size=(num_samples, z_dim)
    )
    augmented_data_truncated = torch.from_numpy(truncated_samples).float()

    # Calculate the average of augmented data from both distributions
    return (augmented_data_normal + augmented_data_truncated) / 2

In [None]:
# Define the federated training logic
def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Run ``epochs`` rounds in which every client trains the shared models.

    In each round the clients are visited in dictionary order. Each client
    trains ``vae`` and then ``net`` on its own loader, for 10 epochs each.
    The same ``net`` and ``vae`` objects are reused across clients, so every
    client continues from the previous client's weights.

    NOTE(review): ``trainloader`` is never read. The augmented samples are
    generated but not passed to ``train``, so the classifier sees only the
    client's local data plus ``valloader`` for validation.
    """
    for _round in range(epochs):
        for client_id in trainloaders:
            client_trainloader = trainloaders[client_id]

            # Local VAE fit on this client's shard.
            vae_train(vae, client_trainloader, epochs=10)

            # Publish this client's distribution summary.
            local_info = get_distribution_info(vae)
            send_distribution_info(local_info)

            # Fetch the summary provided for the other clients and sample
            # from it.
            other_distribution_info = receive_distribution_info()
            augmented_data = generate_augmented_data(
                vae,
                other_distribution_info["normal"],
                other_distribution_info["truncated"],
            )

            # Local classifier training, validated on the shared loader.
            train(net, client_trainloader, valloader, epochs=10)

            # Report the resulting weights.
            send_model_update(client_id, net.state_dict())

In [None]:
# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution info, as if received from the server.

    No communication takes place. Both the ``"normal"`` and ``"truncated"``
    entries describe a standard normal: zero mean and unit std.

    Args:
        z_dim: Latent dimensionality, i.e. the length of every returned
            array. Defaults to 20, the size previously hard-coded here.

    Returns:
        ``{"normal": {"mean", "std"}, "truncated": {"mean", "std"}}`` with
        freshly allocated NumPy arrays of shape ``(z_dim,)``.
    """
    def _standard_normal() -> Dict:
        # Fresh arrays per entry so callers may modify one without
        # affecting the other.
        return {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}

    return {
        "normal": _standard_normal(),
        "truncated": _standard_normal(),
    }

In [None]:
def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for sending a client's model update to the global server.

    No transport is implemented: both arguments are accepted and nothing
    happens. A real implementation could send the update over a network
    protocol (for example with the ``socket`` module) or through a message
    queue such as RabbitMQ.
    """
    return None

In [None]:
# Define global server procedure
def global_server() -> None:
    """Build the models, run federated training, and print the test accuracy.

    Relies on the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``.
    """
    # Shared classifier and VAE (CIFAR-10 input size, hidden size 400,
    # 20 latent dimensions).
    net = Net()
    vae = VAE(x_dim=3 * 32 * 32, h_dim=400, z_dim=20)

    # Five clients, each with its own shard of the training split.
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=5)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Score the final classifier on the held-out test set.
    test_accuracy = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")


if __name__ == "__main__":
    global_server()

Epoch [1/10], Training Loss: 2.304, Validation Accuracy: 9.91%
Epoch [2/10], Training Loss: 2.303, Validation Accuracy: 9.95%
Epoch [3/10], Training Loss: 2.302, Validation Accuracy: 10.01%
Epoch [4/10], Training Loss: 2.301, Validation Accuracy: 10.32%
Epoch [5/10], Training Loss: 2.299, Validation Accuracy: 11.35%
Epoch [6/10], Training Loss: 2.298, Validation Accuracy: 13.96%
Epoch [7/10], Training Loss: 2.296, Validation Accuracy: 14.94%
Epoch [8/10], Training Loss: 2.294, Validation Accuracy: 14.61%
Epoch [9/10], Training Loss: 2.290, Validation Accuracy: 14.43%
Epoch [10/10], Training Loss: 2.286, Validation Accuracy: 14.36%
Epoch [1/10], Training Loss: 2.280, Validation Accuracy: 14.73%
Epoch [2/10], Training Loss: 2.270, Validation Accuracy: 13.47%
Epoch [3/10], Training Loss: 2.256, Validation Accuracy: 12.90%
Epoch [4/10], Training Loss: 2.240, Validation Accuracy: 13.76%
Epoch [5/10], Training Loss: 2.225, Validation Accuracy: 14.83%
Epoch [6/10], Training Loss: 2.209, Valid

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 32x32 images.

    The encoder produces ``2 * z_dim`` values per sample, which ``encode``
    splits into a mean and a log-variance. The decoder turns a latent vector
    of size ``z_dim`` back into a 3x32x32 image.

    Args:
        x_dim: Stored on the instance; it does not size any layer.
        h_dim: Stored on the instance; it does not size any layer.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder: each stride-2 convolution halves the spatial size
        # (32 -> 16 -> 8 -> 4); 128 channels * 4 * 4 = 2048 features follow.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),  # [mu | logvar]
        )

        # Decoder: mirror image of the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 32x32
            # The inputs in this file are normalized to [-1, 1]
            # (Normalize with mean 0.5, std 0.5). Tanh covers that range,
            # whereas a Sigmoid output could never match negative pixels.
            nn.Tanh(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output."""
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` sampled from a standard normal."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructed images."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on ``trainloader`` for ``epochs`` passes.

    Optimizes with Adam (lr=1e-3). A StepLR schedule halves the learning rate
    every 20 epochs and is stepped once at the end of each epoch. Batches must
    unpack into ``(inputs, labels)``; the labels are ignored.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    lr_schedule = torch.optim.lr_scheduler.StepLR(adam, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for images, _labels in trainloader:
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar)
            batch_loss.backward()
            adam.step()

        lr_schedule.step()


# Define classification model
class Net(nn.Module):
    """Small convolutional classifier producing 10 logits per 3x32x32 image.

    Layout: conv(3->6, 5x5) -> ReLU -> 2x2 max pool -> conv(6->16, 5x5) ->
    ReLU -> 2x2 max pool -> fully connected 400 -> 120 -> 84 -> 10.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores (no softmax) for the batch ``x``."""
        activate = nn.functional.relu
        for conv in (self.conv1, self.conv2):
            x = self.pool(activate(conv(x)))
        # Collapse channels and spatial positions into one feature axis.
        x = x.view(-1, 16 * 5 * 5)
        for dense in (self.fc1, self.fc2):
            x = activate(dense(x))
        return self.fc3(x)


# Load CIFAR10 dataset
# Convert each image to a tensor, then shift and scale every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

full_dataset = CIFAR10(root="./data", train=True, transform=transform, download=True)
test_set = CIFAR10(root="./data", train=False, transform=transform, download=True)

# 80% of the training split is used for training; the rest is validation.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(
    dataset=full_dataset,
    lengths=[train_size, val_size],
)

# Only the training loader shuffles; validation and test batches keep a
# fixed order.
trainloader = DataLoader(dataset=train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(dataset=val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed squared reconstruction error plus the KL term.

    The KL term is ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``. Both
    parts are summed over every element of the batch, not averaged.
    """
    squared_error = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return squared_error + kl_divergence

# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after every epoch.

    Uses cross-entropy loss and SGD (lr=0.001, momentum=0.9). A StepLR
    schedule halves the learning rate every 20 epochs. After each epoch the
    model is scored on ``valloader``. Whenever the accuracy beats the best
    value seen so far in this call, the weights are written to
    ``best_model.pth`` in the working directory. One progress line is printed
    per epoch.

    Args:
        net: Classifier to optimize in place.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Count batches explicitly: relying on the loop index would raise a
        # NameError when the loader yields nothing.
        mean_loss = running_loss / num_batches if num_batches else 0.0

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                _, predicted = torch.max(net(images), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # An empty validation loader reports 0% instead of dividing by zero.
        val_acc = 100 * correct / total if total else 0.0
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {mean_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score ``net`` on ``testloader`` and derive per-class statistics.

    The confusion matrix is printed as a side effect.

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)``. ``accuracy``
        is a percentage; every other entry is an array with one value per
        class.
    """
    net.eval()
    hits = 0
    seen = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for images, labels in testloader:
            predicted = torch.max(net(images).data, 1)[1]
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    accuracy = 100 * hits / seen

    cm = confusion_matrix(y_true, y_pred)
    print(f"Confusion Matrix:\n{cm}")

    # The diagonal holds the correct predictions. Column sums count everything
    # predicted as a class, row sums everything that truly belongs to it.
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # average=None yields one score per class instead of a single aggregate.
    precision, recall, f1 = [
        metric(y_true, y_pred, average=None)
        for metric in (precision_score, recall_score, f1_score)
    ]

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    clients = {}
    for i in range(num_clients):
        client_trainset = torch.utils.data.Subset(trainset, range(i * len(trainset) // num_clients, (i + 1) * len(trainset) // num_clients))
        client_trainloader = torch.utils.data.DataLoader(client_trainset, batch_size=batch_size, shuffle=True)
        clients[f"client_{i}"] = client_trainloader
    return clients

def get_distribution_info(vae: VAE) -> Dict:
    """Summarize the VAE's latent distribution as per-dimension mean and std.

    The encoder's last ``Linear`` layer emits ``[mu | logvar]`` (see
    ``VAE.encode``), so its bias is split the same way: the first ``z_dim``
    entries give the mean, and ``exp(0.5 * logvar_bias)`` gives the std. Both
    arrays have shape ``(z_dim,)``, which is what ``generate_augmented_data``
    broadcasts against.

    NOTE(review): this is a data-independent proxy. It is the output of that
    layer for an all-zero hidden activation, not statistics gathered over the
    client's data. Replace it with aggregated encoder outputs if real latent
    statistics are needed.

    Returns:
        ``{"normal": {"mean", "std"}, "truncated": {"mean", "std"}}`` holding
        independent NumPy copies, so later training cannot mutate what was
        shared.
    """
    z_dim = vae.z_dim
    head_bias = vae.encoder[-1].bias.detach().cpu()

    mean = head_bias[:z_dim].numpy().copy()
    std = torch.exp(0.5 * head_bias[z_dim:]).numpy()

    distribution_info = {
        "normal": {"mean": mean.copy(), "std": std.copy()},
        "truncated": {"mean": mean.copy(), "std": std.copy()},
    }

    return distribution_info

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for shipping distribution information to the global server.

    Nothing is transmitted: the function accepts its argument and returns.
    A real implementation could use a network protocol (for example the
    ``socket`` module) or a message queue such as RabbitMQ.
    """
    return None

def generate_augmented_data(vae: VAE, distribution_info_normal: Dict, distribution_info_truncated: Dict, num_samples: int = 64) -> torch.Tensor:
    """Sample latent vectors by averaging normal and truncated-normal draws.

    Args:
        vae: Only ``vae.z_dim`` is read, to size the samples.
        distribution_info_normal: Dict with ``"mean"`` and ``"std"`` entries
            that broadcast against ``(num_samples, z_dim)``.
        distribution_info_truncated: Same layout; these parameters drive a
            normal distribution truncated to values >= 0.
        num_samples: Number of latent vectors to draw (defaults to the
            previously hard-coded 64).

    Returns:
        A float32 tensor of shape ``(num_samples, z_dim)``: the element-wise
        mean of one normal sample and one truncated-normal sample.
    """
    z_dim = vae.z_dim

    # Convert explicitly so the arithmetic below never depends on implicit
    # mixing of torch tensors with NumPy arrays.
    mean_normal = torch.as_tensor(distribution_info_normal["mean"], dtype=torch.float32)
    std_normal = torch.as_tensor(distribution_info_normal["std"], dtype=torch.float32)
    augmented_data_normal = torch.randn(num_samples, z_dim) * std_normal + mean_normal

    mean_truncated = np.asarray(distribution_info_truncated["mean"], dtype=np.float64)
    std_truncated = np.asarray(distribution_info_truncated["std"], dtype=np.float64)

    # truncnorm takes its bounds in standard-deviation units relative to
    # loc/scale, so a lower bound of 0 in data space becomes (0 - mean) / std.
    a = (0 - mean_truncated) / std_truncated
    b = np.inf
    truncated_samples = truncnorm.rvs(
        a, b, loc=mean_truncated, scale=std_truncated, size=(num_samples, z_dim)
    )
    augmented_data_truncated = torch.from_numpy(truncated_samples).float()

    # Calculate the average of augmented data from both distributions
    return (augmented_data_normal + augmented_data_truncated) / 2

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Run ``epochs`` rounds in which every client trains the shared models.

    In each round the clients are visited in dictionary order. Each client
    trains ``vae`` and then ``net`` on its own loader, for 10 epochs each.
    The same ``net`` and ``vae`` objects are reused across clients, so every
    client continues from the previous client's weights.

    NOTE(review): ``trainloader`` is never read. The augmented samples are
    generated but not passed to ``train``, so the classifier sees only the
    client's local data plus ``valloader`` for validation.
    """
    for _round in range(epochs):
        for client_id in trainloaders:
            client_trainloader = trainloaders[client_id]

            # Local VAE fit on this client's shard.
            vae_train(vae, client_trainloader, epochs=10)

            # Publish this client's distribution summary.
            local_info = get_distribution_info(vae)
            send_distribution_info(local_info)

            # Fetch the summary provided for the other clients and sample
            # from it.
            other_distribution_info = receive_distribution_info()
            augmented_data = generate_augmented_data(
                vae,
                other_distribution_info["normal"],
                other_distribution_info["truncated"],
            )

            # Local classifier training, validated on the shared loader.
            train(net, client_trainloader, valloader, epochs=10)

            # Report the resulting weights.
            send_model_update(client_id, net.state_dict())

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution info, as if received from the server.

    No communication takes place. Both the ``"normal"`` and ``"truncated"``
    entries describe a standard normal: zero mean and unit std.

    Args:
        z_dim: Latent dimensionality, i.e. the length of every returned
            array. Defaults to 20, the size previously hard-coded here.

    Returns:
        ``{"normal": {"mean", "std"}, "truncated": {"mean", "std"}}`` with
        freshly allocated NumPy arrays of shape ``(z_dim,)``.
    """
    def _standard_normal() -> Dict:
        # Fresh arrays per entry so callers may modify one without
        # affecting the other.
        return {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}

    return {
        "normal": _standard_normal(),
        "truncated": _standard_normal(),
    }

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for sending a client's model update to the global server.

    Nothing is transmitted: both arguments are accepted and the function
    returns. A real implementation could use a network protocol (for example
    the ``socket`` module) or a message queue such as RabbitMQ.
    """
    return None

# Define global server procedure
def global_server() -> None:
    """Build the models, run federated training, and print test metrics.

    Relies on the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``.
    """
    # Shared classifier and VAE (CIFAR-10 input size, hidden size 400,
    # 20 latent dimensions).
    net = Net()
    vae = VAE(x_dim=3 * 32 * 32, h_dim=400, z_dim=20)

    # Five clients, each with its own shard of the training split.
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=5)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Score the final classifier on the held-out test set.
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    # Per-class arrays, printed one labelled line each. The unlabelled line
    # is the element-wise sum of the four confusion counts.
    report = (
        ("True Positives (TP):", tp),
        ("False Positives (FP):", fp),
        ("True Negatives (TN):", tn),
        ("False Negatives (FN):", fn),
        (":", fn+tn+tp+fp),
        ("Precision:", precision),
        ("Recall:", recall),
        ("F1 Score:", f1),
    )
    for label, values in report:
        print(label, values)


if __name__ == "__main__":
    global_server()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43758282.68it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch [1/10], Training Loss: 2.302, Validation Accuracy: 9.65%
Epoch [2/10], Training Loss: 2.301, Validation Accuracy: 9.72%
Epoch [3/10], Training Loss: 2.300, Validation Accuracy: 9.99%
Epoch [4/10], Training Loss: 2.299, Validation Accuracy: 10.53%
Epoch [5/10], Training Loss: 2.297, Validation Accuracy: 11.09%
Epoch [6/10], Training Loss: 2.295, Validation Accuracy: 12.15%
Epoch [7/10], Training Loss: 2.293, Validation Accuracy: 12.91%
Epoch [8/10], Training Loss: 2.290, Validation Accuracy: 14.11%
Epoch [9/10], Training Loss: 2.286, Validation Accuracy: 15.58%
Epoch [10/10], Training Loss: 2.281, Validation Accuracy: 16.29%
Epoch [1/10], Training Loss: 2.275, Validation Accuracy: 16.97%
Epoch [2/10], Training Loss: 2.265, Validation Accuracy: 19.59%
Epoch [3/10], Training Loss: 2.251, Validation Accuracy: 20.96%
Epoch [4/10], Training Loss: 2.230, Validation Accuracy: 20.48%
Epoch [5/10], Tra

In [None]:
import re

# Raw training log (one "Epoch [...]" line per local epoch), pasted in as text to be parsed below
log = """
Epoch [1/10], Training Loss: 2.302, Validation Accuracy: 9.65%
Epoch [2/10], Training Loss: 2.301, Validation Accuracy: 9.72%
Epoch [3/10], Training Loss: 2.300, Validation Accuracy: 9.99%
Epoch [4/10], Training Loss: 2.299, Validation Accuracy: 10.53%
Epoch [5/10], Training Loss: 2.297, Validation Accuracy: 11.09%
Epoch [6/10], Training Loss: 2.295, Validation Accuracy: 12.15%
Epoch [7/10], Training Loss: 2.293, Validation Accuracy: 12.91%
Epoch [8/10], Training Loss: 2.290, Validation Accuracy: 14.11%
Epoch [9/10], Training Loss: 2.286, Validation Accuracy: 15.58%
Epoch [10/10], Training Loss: 2.281, Validation Accuracy: 16.29%
Epoch [1/10], Training Loss: 2.275, Validation Accuracy: 16.97%
Epoch [2/10], Training Loss: 2.265, Validation Accuracy: 19.59%
Epoch [3/10], Training Loss: 2.251, Validation Accuracy: 20.96%
Epoch [4/10], Training Loss: 2.230, Validation Accuracy: 20.48%
Epoch [5/10], Training Loss: 2.203, Validation Accuracy: 20.64%
Epoch [6/10], Training Loss: 2.166, Validation Accuracy: 21.06%
Epoch [7/10], Training Loss: 2.123, Validation Accuracy: 22.06%
Epoch [8/10], Training Loss: 2.088, Validation Accuracy: 23.46%
Epoch [9/10], Training Loss: 2.057, Validation Accuracy: 24.85%
Epoch [10/10], Training Loss: 2.031, Validation Accuracy: 27.06%
Epoch [1/10], Training Loss: 2.020, Validation Accuracy: 27.46%
Epoch [2/10], Training Loss: 2.002, Validation Accuracy: 28.43%
Epoch [3/10], Training Loss: 1.983, Validation Accuracy: 28.91%
Epoch [4/10], Training Loss: 1.967, Validation Accuracy: 29.96%
Epoch [5/10], Training Loss: 1.948, Validation Accuracy: 30.19%
Epoch [6/10], Training Loss: 1.928, Validation Accuracy: 31.15%
Epoch [7/10], Training Loss: 1.912, Validation Accuracy: 31.49%
Epoch [8/10], Training Loss: 1.897, Validation Accuracy: 31.84%
Epoch [9/10], Training Loss: 1.880, Validation Accuracy: 33.01%
Epoch [10/10], Training Loss: 1.869, Validation Accuracy: 33.31%
Epoch [1/10], Training Loss: 1.865, Validation Accuracy: 32.68%
Epoch [2/10], Training Loss: 1.851, Validation Accuracy: 34.10%
Epoch [3/10], Training Loss: 1.835, Validation Accuracy: 34.41%
Epoch [4/10], Training Loss: 1.822, Validation Accuracy: 33.65%
Epoch [5/10], Training Loss: 1.806, Validation Accuracy: 34.98%
Epoch [6/10], Training Loss: 1.792, Validation Accuracy: 35.97%
Epoch [7/10], Training Loss: 1.776, Validation Accuracy: 36.18%
Epoch [8/10], Training Loss: 1.762, Validation Accuracy: 35.61%
Epoch [9/10], Training Loss: 1.747, Validation Accuracy: 36.83%
Epoch [10/10], Training Loss: 1.731, Validation Accuracy: 36.36%
Epoch [1/10], Training Loss: 1.736, Validation Accuracy: 37.58%
Epoch [2/10], Training Loss: 1.719, Validation Accuracy: 38.09%
Epoch [3/10], Training Loss: 1.707, Validation Accuracy: 38.32%
Epoch [4/10], Training Loss: 1.696, Validation Accuracy: 38.69%
Epoch [5/10], Training Loss: 1.685, Validation Accuracy: 38.91%
Epoch [6/10], Training Loss: 1.673, Validation Accuracy: 39.31%
Epoch [7/10], Training Loss: 1.662, Validation Accuracy: 39.59%
Epoch [8/10], Training Loss: 1.648, Validation Accuracy: 39.36%
Epoch [9/10], Training Loss: 1.648, Validation Accuracy: 40.28%
Epoch [10/10], Training Loss: 1.628, Validation Accuracy: 39.35%
Epoch [1/10], Training Loss: 1.639, Validation Accuracy: 41.04%
Epoch [2/10], Training Loss: 1.624, Validation Accuracy: 40.83%
Epoch [3/10], Training Loss: 1.616, Validation Accuracy: 41.26%
Epoch [4/10], Training Loss: 1.605, Validation Accuracy: 41.38%
Epoch [5/10], Training Loss: 1.595, Validation Accuracy: 42.07%
Epoch [6/10], Training Loss: 1.587, Validation Accuracy: 42.00%
Epoch [7/10], Training Loss: 1.575, Validation Accuracy: 41.79%
Epoch [8/10], Training Loss: 1.568, Validation Accuracy: 41.98%
Epoch [9/10], Training Loss: 1.553, Validation Accuracy: 42.78%
Epoch [10/10], Training Loss: 1.544, Validation Accuracy: 42.13%
Epoch [1/10], Training Loss: 1.573, Validation Accuracy: 43.33%
Epoch [2/10], Training Loss: 1.556, Validation Accuracy: 43.97%
Epoch [3/10], Training Loss: 1.545, Validation Accuracy: 43.27%
Epoch [4/10], Training Loss: 1.537, Validation Accuracy: 44.94%
Epoch [5/10], Training Loss: 1.526, Validation Accuracy: 44.13%
Epoch [6/10], Training Loss: 1.516, Validation Accuracy: 45.28%
Epoch [7/10], Training Loss: 1.503, Validation Accuracy: 45.46%
Epoch [8/10], Training Loss: 1.498, Validation Accuracy: 45.07%
Epoch [9/10], Training Loss: 1.484, Validation Accuracy: 44.97%
Epoch [10/10], Training Loss: 1.482, Validation Accuracy: 45.27%
Epoch [1/10], Training Loss: 1.508, Validation Accuracy: 46.60%
Epoch [2/10], Training Loss: 1.486, Validation Accuracy: 46.94%
Epoch [3/10], Training Loss: 1.477, Validation Accuracy: 46.27%
Epoch [4/10], Training Loss: 1.474, Validation Accuracy: 47.17%
Epoch [5/10], Training Loss: 1.464, Validation Accuracy: 47.20%
Epoch [6/10], Training Loss: 1.451, Validation Accuracy: 47.48%
Epoch [7/10], Training Loss: 1.440, Validation Accuracy: 47.21%
Epoch [8/10], Training Loss: 1.433, Validation Accuracy: 47.16%
Epoch [9/10], Training Loss: 1.423, Validation Accuracy: 47.92%
Epoch [10/10], Training Loss: 1.413, Validation Accuracy: 47.82%
Epoch [1/10], Training Loss: 1.447, Validation Accuracy: 48.27%
Epoch [2/10], Training Loss: 1.426, Validation Accuracy: 48.11%
Epoch [3/10], Training Loss: 1.416, Validation Accuracy: 48.06%
Epoch [4/10], Training Loss: 1.405, Validation Accuracy: 47.73%
Epoch [5/10], Training Loss: 1.402, Validation Accuracy: 48.28%
Epoch [6/10], Training Loss: 1.387, Validation Accuracy: 48.30%
Epoch [7/10], Training Loss: 1.382, Validation Accuracy: 49.37%
Epoch [8/10], Training Loss: 1.372, Validation Accuracy: 48.91%
Epoch [9/10], Training Loss: 1.364, Validation Accuracy: 49.06%
Epoch [10/10], Training Loss: 1.356, Validation Accuracy: 49.25%
Epoch [1/10], Training Loss: 1.413, Validation Accuracy: 48.93%
Epoch [2/10], Training Loss: 1.392, Validation Accuracy: 49.74%
Epoch [3/10], Training Loss: 1.383, Validation Accuracy: 49.70%
Epoch [4/10], Training Loss: 1.373, Validation Accuracy: 49.69%
Epoch [5/10], Training Loss: 1.363, Validation Accuracy: 48.81%
Epoch [6/10], Training Loss: 1.352, Validation Accuracy: 49.43%
Epoch [7/10], Training Loss: 1.345, Validation Accuracy: 48.99%
Epoch [8/10], Training Loss: 1.340, Validation Accuracy: 49.84%
Epoch [9/10], Training Loss: 1.326, Validation Accuracy: 50.29%
Epoch [10/10], Training Loss: 1.327, Validation Accuracy: 49.53%
Epoch [1/10], Training Loss: 1.380, Validation Accuracy: 50.18%
Epoch [2/10], Training Loss: 1.356, Validation Accuracy: 50.52%
Epoch [3/10], Training Loss: 1.344, Validation Accuracy: 50.58%
Epoch [4/10], Training Loss: 1.330, Validation Accuracy: 50.80%
Epoch [5/10], Training Loss: 1.325, Validation Accuracy: 51.28%
Epoch [6/10], Training Loss: 1.311, Validation Accuracy: 50.94%
Epoch [7/10], Training Loss: 1.307, Validation Accuracy: 51.00%
Epoch [8/10], Training Loss: 1.293, Validation Accuracy: 50.60%
Epoch [9/10], Training Loss: 1.295, Validation Accuracy: 51.47%
Epoch [10/10], Training Loss: 1.280, Validation Accuracy: 51.58%
Epoch [1/10], Training Loss: 1.338, Validation Accuracy: 52.23%
Epoch [2/10], Training Loss: 1.324, Validation Accuracy: 51.29%
Epoch [3/10], Training Loss: 1.313, Validation Accuracy: 52.52%
Epoch [4/10], Training Loss: 1.304, Validation Accuracy: 52.00%
Epoch [5/10], Training Loss: 1.291, Validation Accuracy: 52.04%
Epoch [6/10], Training Loss: 1.280, Validation Accuracy: 52.78%
Epoch [7/10], Training Loss: 1.265, Validation Accuracy: 52.07%
Epoch [8/10], Training Loss: 1.269, Validation Accuracy: 52.58%
Epoch [9/10], Training Loss: 1.262, Validation Accuracy: 52.15%
Epoch [10/10], Training Loss: 1.246, Validation Accuracy: 52.18%
Epoch [1/10], Training Loss: 1.305, Validation Accuracy: 53.37%
Epoch [2/10], Training Loss: 1.285, Validation Accuracy: 53.04%
Epoch [3/10], Training Loss: 1.273, Validation Accuracy: 53.16%
Epoch [4/10], Training Loss: 1.259, Validation Accuracy: 53.66%
Epoch [5/10], Training Loss: 1.252, Validation Accuracy: 53.26%
Epoch [6/10], Training Loss: 1.239, Validation Accuracy: 53.11%
Epoch [7/10], Training Loss: 1.230, Validation Accuracy: 53.84%
Epoch [8/10], Training Loss: 1.228, Validation Accuracy: 53.29%
Epoch [9/10], Training Loss: 1.216, Validation Accuracy: 53.92%
Epoch [10/10], Training Loss: 1.204, Validation Accuracy: 53.05%
Epoch [1/10], Training Loss: 1.279, Validation Accuracy: 53.26%
Epoch [2/10], Training Loss: 1.256, Validation Accuracy: 54.23%
Epoch [3/10], Training Loss: 1.231, Validation Accuracy: 53.36%
Epoch [4/10], Training Loss: 1.220, Validation Accuracy: 54.48%
Epoch [5/10], Training Loss: 1.210, Validation Accuracy: 53.55%
Epoch [6/10], Training Loss: 1.200, Validation Accuracy: 54.20%
Epoch [7/10], Training Loss: 1.190, Validation Accuracy: 54.21%
Epoch [8/10], Training Loss: 1.194, Validation Accuracy: 54.25%
Epoch [9/10], Training Loss: 1.171, Validation Accuracy: 54.02%
Epoch [10/10], Training Loss: 1.159, Validation Accuracy: 54.46%
Epoch [1/10], Training Loss: 1.253, Validation Accuracy: 54.78%
Epoch [2/10], Training Loss: 1.230, Validation Accuracy: 54.96%
Epoch [3/10], Training Loss: 1.209, Validation Accuracy: 54.55%
Epoch [4/10], Training Loss: 1.203, Validation Accuracy: 53.36%
Epoch [5/10], Training Loss: 1.195, Validation Accuracy: 54.76%
Epoch [6/10], Training Loss: 1.170, Validation Accuracy: 55.04%
Epoch [7/10], Training Loss: 1.158, Validation Accuracy: 54.21%
Epoch [8/10], Training Loss: 1.162, Validation Accuracy: 54.30%
Epoch [9/10], Training Loss: 1.145, Validation Accuracy: 54.42%
Epoch [10/10], Training Loss: 1.135, Validation Accuracy: 53.80%
Epoch [1/10], Training Loss: 1.227, Validation Accuracy: 55.64%
Epoch [2/10], Training Loss: 1.191, Validation Accuracy: 55.67%
Epoch [3/10], Training Loss: 1.182, Validation Accuracy: 56.61%
Epoch [4/10], Training Loss: 1.170, Validation Accuracy: 56.60%
Epoch [5/10], Training Loss: 1.154, Validation Accuracy: 56.44%
Epoch [6/10], Training Loss: 1.141, Validation Accuracy: 56.56%
Epoch [7/10], Training Loss: 1.124, Validation Accuracy: 55.45%
Epoch [8/10], Training Loss: 1.121, Validation Accuracy: 56.26%
Epoch [9/10], Training Loss: 1.106, Validation Accuracy: 56.11%
Epoch [10/10], Training Loss: 1.107, Validation Accuracy: 55.40%
Epoch [1/10], Training Loss: 1.208, Validation Accuracy: 56.75%
Epoch [2/10], Training Loss: 1.176, Validation Accuracy: 55.52%
Epoch [3/10], Training Loss: 1.169, Validation Accuracy: 56.40%
Epoch [4/10], Training Loss: 1.149, Validation Accuracy: 56.51%
Epoch [5/10], Training Loss: 1.129, Validation Accuracy: 56.61%
Epoch [6/10], Training Loss: 1.129, Validation Accuracy: 55.53%
Epoch [7/10], Training Loss: 1.117, Validation Accuracy: 56.29%
Epoch [8/10], Training Loss: 1.104, Validation Accuracy: 56.82%
Epoch [9/10], Training Loss: 1.093, Validation Accuracy: 56.93%
Epoch [10/10], Training Loss: 1.083, Validation Accuracy: 57.18%
Epoch [1/10], Training Loss: 1.170, Validation Accuracy: 57.39%
Epoch [2/10], Training Loss: 1.139, Validation Accuracy: 57.96%
Epoch [3/10], Training Loss: 1.128, Validation Accuracy: 57.39%
Epoch [4/10], Training Loss: 1.114, Validation Accuracy: 57.70%
Epoch [5/10], Training Loss: 1.099, Validation Accuracy: 57.86%
Epoch [6/10], Training Loss: 1.089, Validation Accuracy: 58.02%
Epoch [7/10], Training Loss: 1.074, Validation Accuracy: 57.12%
Epoch [8/10], Training Loss: 1.070, Validation Accuracy: 56.72%
Epoch [9/10], Training Loss: 1.061, Validation Accuracy: 56.92%
Epoch [10/10], Training Loss: 1.045, Validation Accuracy: 57.34%
Epoch [1/10], Training Loss: 1.149, Validation Accuracy: 56.76%
Epoch [2/10], Training Loss: 1.116, Validation Accuracy: 58.08%
Epoch [3/10], Training Loss: 1.110, Validation Accuracy: 57.29%
Epoch [4/10], Training Loss: 1.095, Validation Accuracy: 56.76%
Epoch [5/10], Training Loss: 1.073, Validation Accuracy: 58.35%
Epoch [6/10], Training Loss: 1.064, Validation Accuracy: 58.15%
Epoch [7/10], Training Loss: 1.049, Validation Accuracy: 57.56%
Epoch [8/10], Training Loss: 1.036, Validation Accuracy: 56.66%
Epoch [9/10], Training Loss: 1.023, Validation Accuracy: 58.04%
Epoch [10/10], Training Loss: 1.008, Validation Accuracy: 57.81%
Epoch [1/10], Training Loss: 1.136, Validation Accuracy: 58.21%
Epoch [2/10], Training Loss: 1.103, Validation Accuracy: 58.41%
Epoch [3/10], Training Loss: 1.083, Validation Accuracy: 58.51%
Epoch [4/10], Training Loss: 1.065, Validation Accuracy: 57.82%
Epoch [5/10], Training Loss: 1.049, Validation Accuracy: 58.73%
Epoch [6/10], Training Loss: 1.036, Validation Accuracy: 59.10%
Epoch [7/10], Training Loss: 1.017, Validation Accuracy: 58.10%
Epoch [8/10], Training Loss: 1.014, Validation Accuracy: 58.57%
Epoch [9/10], Training Loss: 0.998, Validation Accuracy: 58.00%
Epoch [10/10], Training Loss: 0.994, Validation Accuracy: 58.37%
Epoch [1/10], Training Loss: 1.116, Validation Accuracy: 58.84%
Epoch [2/10], Training Loss: 1.087, Validation Accuracy: 59.51%
Epoch [3/10], Training Loss: 1.061, Validation Accuracy: 59.44%
Epoch [4/10], Training Loss: 1.047, Validation Accuracy: 59.23%
Epoch [5/10], Training Loss: 1.028, Validation Accuracy: 59.30%
Epoch [6/10], Training Loss: 1.011, Validation Accuracy: 59.28%
Epoch [7/10], Training Loss: 1.008, Validation Accuracy: 59.24%
Epoch [8/10], Training Loss: 1.000, Validation Accuracy: 58.72%
Epoch [9/10], Training Loss: 0.981, Validation Accuracy: 58.46%
Epoch [10/10], Training Loss: 0.972, Validation Accuracy: 59.08%
Epoch [1/10], Training Loss: 1.104, Validation Accuracy: 59.30%
Epoch [2/10], Training Loss: 1.073, Validation Accuracy: 59.92%
Epoch [3/10], Training Loss: 1.057, Validation Accuracy: 59.25%
Epoch [4/10], Training Loss: 1.035, Validation Accuracy: 59.12%
Epoch [5/10], Training Loss: 1.020, Validation Accuracy: 59.53%
Epoch [6/10], Training Loss: 1.007, Validation Accuracy: 59.65%
Epoch [7/10], Training Loss: 0.991, Validation Accuracy: 59.22%
Epoch [8/10], Training Loss: 0.980, Validation Accuracy: 59.36%
Epoch [9/10], Training Loss: 0.965, Validation Accuracy: 59.32%
Epoch [10/10], Training Loss: 0.955, Validation Accuracy: 59.73%
Epoch [1/10], Training Loss: 1.067, Validation Accuracy: 60.06%
Epoch [2/10], Training Loss: 1.038, Validation Accuracy: 60.30%
Epoch [3/10], Training Loss: 1.016, Validation Accuracy: 59.78%
Epoch [4/10], Training Loss: 1.000, Validation Accuracy: 59.94%
Epoch [5/10], Training Loss: 0.983, Validation Accuracy: 59.53%
Epoch [6/10], Training Loss: 0.972, Validation Accuracy: 59.63%
Epoch [7/10], Training Loss: 0.959, Validation Accuracy: 60.03%
Epoch [8/10], Training Loss: 0.949, Validation Accuracy: 59.79%
Epoch [9/10], Training Loss: 0.932, Validation Accuracy: 59.75%
Epoch [10/10], Training Loss: 0.922, Validation Accuracy: 60.07%
Epoch [1/10], Training Loss: 1.068, Validation Accuracy: 58.22%
Epoch [2/10], Training Loss: 1.029, Validation Accuracy: 59.66%
Epoch [3/10], Training Loss: 1.000, Validation Accuracy: 60.26%
Epoch [4/10], Training Loss: 0.980, Validation Accuracy: 60.31%
Epoch [5/10], Training Loss: 0.961, Validation Accuracy: 60.03%
Epoch [6/10], Training Loss: 0.948, Validation Accuracy: 58.87%
Epoch [7/10], Training Loss: 0.939, Validation Accuracy: 59.67%
Epoch [8/10], Training Loss: 0.920, Validation Accuracy: 59.86%
Epoch [9/10], Training Loss: 0.914, Validation Accuracy: 60.17%
Epoch [10/10], Training Loss: 0.895, Validation Accuracy: 60.09%
Epoch [1/10], Training Loss: 1.046, Validation Accuracy: 60.58%
Epoch [2/10], Training Loss: 1.009, Validation Accuracy: 58.70%
Epoch [3/10], Training Loss: 0.991, Validation Accuracy: 60.30%
Epoch [4/10], Training Loss: 0.960, Validation Accuracy: 60.36%
Epoch [5/10], Training Loss: 0.951, Validation Accuracy: 60.19%
Epoch [6/10], Training Loss: 0.939, Validation Accuracy: 60.19%
Epoch [7/10], Training Loss: 0.920, Validation Accuracy: 60.44%
Epoch [8/10], Training Loss: 0.904, Validation Accuracy: 60.63%
Epoch [9/10], Training Loss: 0.894, Validation Accuracy: 60.86%
Epoch [10/10], Training Loss: 0.890, Validation Accuracy: 60.43%
Epoch [1/10], Training Loss: 1.042, Validation Accuracy: 61.05%
Epoch [2/10], Training Loss: 0.999, Validation Accuracy: 60.01%
Epoch [3/10], Training Loss: 0.969, Validation Accuracy: 60.66%
Epoch [4/10], Training Loss: 0.951, Validation Accuracy: 61.15%
Epoch [5/10], Training Loss: 0.935, Validation Accuracy: 60.88%
Epoch [6/10], Training Loss: 0.925, Validation Accuracy: 60.61%
Epoch [7/10], Training Loss: 0.898, Validation Accuracy: 60.50%
Epoch [8/10], Training Loss: 0.894, Validation Accuracy: 60.82%
Epoch [9/10], Training Loss: 0.867, Validation Accuracy: 60.23%
Epoch [10/10], Training Loss: 0.868, Validation Accuracy: 60.84%
Epoch [1/10], Training Loss: 1.022, Validation Accuracy: 61.02%
Epoch [2/10], Training Loss: 0.986, Validation Accuracy: 61.76%
Epoch [3/10], Training Loss: 0.957, Validation Accuracy: 60.99%
Epoch [4/10], Training Loss: 0.932, Validation Accuracy: 61.56%
Epoch [5/10], Training Loss: 0.917, Validation Accuracy: 61.18%
Epoch [6/10], Training Loss: 0.903, Validation Accuracy: 61.07%
Epoch [7/10], Training Loss: 0.886, Validation Accuracy: 60.96%
Epoch [8/10], Training Loss: 0.866, Validation Accuracy: 61.23%
Epoch [9/10], Training Loss: 0.848, Validation Accuracy: 61.04%
Epoch [10/10], Training Loss: 0.842, Validation Accuracy: 60.94%
Epoch [1/10], Training Loss: 1.007, Validation Accuracy: 61.79%
Epoch [2/10], Training Loss: 0.965, Validation Accuracy: 61.94%
Epoch [3/10], Training Loss: 0.934, Validation Accuracy: 61.75%
Epoch [4/10], Training Loss: 0.905, Validation Accuracy: 61.20%
Epoch [5/10], Training Loss: 0.893, Validation Accuracy: 61.19%
Epoch [6/10], Training Loss: 0.876, Validation Accuracy: 61.53%
Epoch [7/10], Training Loss: 0.861, Validation Accuracy: 61.08%
Epoch [8/10], Training Loss: 0.849, Validation Accuracy: 61.62%
Epoch [9/10], Training Loss: 0.833, Validation Accuracy: 61.17%
Epoch [10/10], Training Loss: 0.817, Validation Accuracy: 61.46%
Epoch [1/10], Training Loss: 0.984, Validation Accuracy: 61.31%
Epoch [2/10], Training Loss: 0.937, Validation Accuracy: 61.33%
Epoch [3/10], Training Loss: 0.913, Validation Accuracy: 61.52%
Epoch [4/10], Training Loss: 0.887, Validation Accuracy: 61.49%
Epoch [5/10], Training Loss: 0.871, Validation Accuracy: 61.80%
Epoch [6/10], Training Loss: 0.855, Validation Accuracy: 61.20%
Epoch [7/10], Training Loss: 0.837, Validation Accuracy: 61.64%
Epoch [8/10], Training Loss: 0.820, Validation Accuracy: 61.08%
Epoch [9/10], Training Loss: 0.809, Validation Accuracy: 61.51%
Epoch [10/10], Training Loss: 0.796, Validation Accuracy: 61.53%
Epoch [1/10], Training Loss: 0.992, Validation Accuracy: 61.19%
Epoch [2/10], Training Loss: 0.937, Validation Accuracy: 60.96%
Epoch [3/10], Training Loss: 0.900, Validation Accuracy: 61.89%
Epoch [4/10], Training Loss: 0.873, Validation Accuracy: 61.31%
Epoch [5/10], Training Loss: 0.848, Validation Accuracy: 61.80%
Epoch [6/10], Training Loss: 0.843, Validation Accuracy: 60.63%
Epoch [7/10], Training Loss: 0.825, Validation Accuracy: 61.86%
Epoch [8/10], Training Loss: 0.803, Validation Accuracy: 61.07%
Epoch [9/10], Training Loss: 0.789, Validation Accuracy: 61.63%
Epoch [10/10], Training Loss: 0.782, Validation Accuracy: 61.71%
Epoch [1/10], Training Loss: 0.976, Validation Accuracy: 61.88%
Epoch [2/10], Training Loss: 0.915, Validation Accuracy: 62.13%
Epoch [3/10], Training Loss: 0.891, Validation Accuracy: 62.07%
Epoch [4/10], Training Loss: 0.868, Validation Accuracy: 61.67%
Epoch [5/10], Training Loss: 0.844, Validation Accuracy: 61.94%
Epoch [6/10], Training Loss: 0.830, Validation Accuracy: 61.81%
Epoch [7/10], Training Loss: 0.817, Validation Accuracy: 61.68%
Epoch [8/10], Training Loss: 0.796, Validation Accuracy: 61.41%
Epoch [9/10], Training Loss: 0.784, Validation Accuracy: 61.92%
Epoch [10/10], Training Loss: 0.784, Validation Accuracy: 60.92%
Epoch [1/10], Training Loss: 0.956, Validation Accuracy: 62.38%
Epoch [2/10], Training Loss: 0.898, Validation Accuracy: 62.08%
Epoch [3/10], Training Loss: 0.871, Validation Accuracy: 61.94%
Epoch [4/10], Training Loss: 0.848, Validation Accuracy: 61.75%
Epoch [5/10], Training Loss: 0.824, Validation Accuracy: 62.29%
Epoch [6/10], Training Loss: 0.797, Validation Accuracy: 62.54%
Epoch [7/10], Training Loss: 0.782, Validation Accuracy: 62.46%
Epoch [8/10], Training Loss: 0.768, Validation Accuracy: 62.27%
Epoch [9/10], Training Loss: 0.743, Validation Accuracy: 62.26%
Epoch [10/10], Training Loss: 0.730, Validation Accuracy: 61.57%
Epoch [1/10], Training Loss: 0.948, Validation Accuracy: 62.37%
Epoch [2/10], Training Loss: 0.887, Validation Accuracy: 61.94%
Epoch [3/10], Training Loss: 0.850, Validation Accuracy: 61.35%
Epoch [4/10], Training Loss: 0.833, Validation Accuracy: 61.88%
Epoch [5/10], Training Loss: 0.809, Validation Accuracy: 62.12%
Epoch [6/10], Training Loss: 0.788, Validation Accuracy: 62.57%
Epoch [7/10], Training Loss: 0.772, Validation Accuracy: 62.52%
Epoch [8/10], Training Loss: 0.754, Validation Accuracy: 61.99%
Epoch [9/10], Training Loss: 0.743, Validation Accuracy: 62.04%
Epoch [10/10], Training Loss: 0.724, Validation Accuracy: 61.86%
Epoch [1/10], Training Loss: 0.928, Validation Accuracy: 61.95%
Epoch [2/10], Training Loss: 0.873, Validation Accuracy: 62.03%
Epoch [3/10], Training Loss: 0.833, Validation Accuracy: 62.37%
Epoch [4/10], Training Loss: 0.813, Validation Accuracy: 62.62%
Epoch [5/10], Training Loss: 0.783, Validation Accuracy: 61.91%
Epoch [6/10], Training Loss: 0.768, Validation Accuracy: 61.50%
Epoch [7/10], Training Loss: 0.746, Validation Accuracy: 62.77%
Epoch [8/10], Training Loss: 0.731, Validation Accuracy: 62.27%
Epoch [9/10], Training Loss: 0.706, Validation Accuracy: 62.72%
Epoch [10/10], Training Loss: 0.704, Validation Accuracy: 61.55%
Epoch [1/10], Training Loss: 0.926, Validation Accuracy: 62.24%
Epoch [2/10], Training Loss: 0.859, Validation Accuracy: 61.45%
Epoch [3/10], Training Loss: 0.816, Validation Accuracy: 62.71%
Epoch [4/10], Training Loss: 0.792, Validation Accuracy: 62.24%
Epoch [5/10], Training Loss: 0.772, Validation Accuracy: 62.08%
Epoch [6/10], Training Loss: 0.748, Validation Accuracy: 62.62%
Epoch [7/10], Training Loss: 0.732, Validation Accuracy: 61.65%
Epoch [8/10], Training Loss: 0.720, Validation Accuracy: 62.03%
Epoch [9/10], Training Loss: 0.700, Validation Accuracy: 61.96%
Epoch [10/10], Training Loss: 0.680, Validation Accuracy: 62.29%
Epoch [1/10], Training Loss: 0.910, Validation Accuracy: 62.44%
Epoch [2/10], Training Loss: 0.853, Validation Accuracy: 61.82%
Epoch [3/10], Training Loss: 0.820, Validation Accuracy: 62.24%
Epoch [4/10], Training Loss: 0.782, Validation Accuracy: 62.53%
Epoch [5/10], Training Loss: 0.761, Validation Accuracy: 62.82%
Epoch [6/10], Training Loss: 0.749, Validation Accuracy: 62.40%
Epoch [7/10], Training Loss: 0.729, Validation Accuracy: 62.65%
Epoch [8/10], Training Loss: 0.707, Validation Accuracy: 62.41%
Epoch [9/10], Training Loss: 0.691, Validation Accuracy: 62.49%
Epoch [10/10], Training Loss: 0.667, Validation Accuracy: 62.99%
Epoch [1/10], Training Loss: 0.905, Validation Accuracy: 63.02%
Epoch [2/10], Training Loss: 0.835, Validation Accuracy: 62.59%
Epoch [3/10], Training Loss: 0.797, Validation Accuracy: 62.71%
Epoch [4/10], Training Loss: 0.764, Validation Accuracy: 63.44%
Epoch [5/10], Training Loss: 0.728, Validation Accuracy: 62.98%
Epoch [6/10], Training Loss: 0.708, Validation Accuracy: 62.90%
Epoch [7/10], Training Loss: 0.687, Validation Accuracy: 63.08%
Epoch [8/10], Training Loss: 0.663, Validation Accuracy: 62.41%
Epoch [9/10], Training Loss: 0.654, Validation Accuracy: 62.70%
Epoch [10/10], Training Loss: 0.639, Validation Accuracy: 61.65%
Epoch [1/10], Training Loss: 0.894, Validation Accuracy: 62.00%
Epoch [2/10], Training Loss: 0.816, Validation Accuracy: 62.58%
Epoch [3/10], Training Loss: 0.775, Validation Accuracy: 62.46%
Epoch [4/10], Training Loss: 0.750, Validation Accuracy: 62.39%
Epoch [5/10], Training Loss: 0.724, Validation Accuracy: 62.35%
Epoch [6/10], Training Loss: 0.699, Validation Accuracy: 62.62%
Epoch [7/10], Training Loss: 0.677, Validation Accuracy: 62.63%
Epoch [8/10], Training Loss: 0.661, Validation Accuracy: 63.19%
Epoch [9/10], Training Loss: 0.631, Validation Accuracy: 62.78%
Epoch [10/10], Training Loss: 0.621, Validation Accuracy: 62.04%
Epoch [1/10], Training Loss: 0.887, Validation Accuracy: 61.19%
Epoch [2/10], Training Loss: 0.812, Validation Accuracy: 62.81%
Epoch [3/10], Training Loss: 0.763, Validation Accuracy: 62.67%
Epoch [4/10], Training Loss: 0.726, Validation Accuracy: 62.79%
Epoch [5/10], Training Loss: 0.701, Validation Accuracy: 62.49%
Epoch [6/10], Training Loss: 0.679, Validation Accuracy: 62.94%
Epoch [7/10], Training Loss: 0.664, Validation Accuracy: 62.55%
Epoch [8/10], Training Loss: 0.629, Validation Accuracy: 62.49%
Epoch [9/10], Training Loss: 0.619, Validation Accuracy: 62.80%
Epoch [10/10], Training Loss: 0.607, Validation Accuracy: 62.28%
Epoch [1/10], Training Loss: 0.880, Validation Accuracy: 61.54%
Epoch [2/10], Training Loss: 0.802, Validation Accuracy: 62.33%
Epoch [3/10], Training Loss: 0.746, Validation Accuracy: 62.69%
Epoch [4/10], Training Loss: 0.718, Validation Accuracy: 62.61%
Epoch [5/10], Training Loss: 0.693, Validation Accuracy: 62.84%
Epoch [6/10], Training Loss: 0.664, Validation Accuracy: 61.88%
Epoch [7/10], Training Loss: 0.641, Validation Accuracy: 62.23%
Epoch [8/10], Training Loss: 0.616, Validation Accuracy: 62.87%
Epoch [9/10], Training Loss: 0.611, Validation Accuracy: 62.40%
Epoch [10/10], Training Loss: 0.601, Validation Accuracy: 61.45%
Epoch [1/10], Training Loss: 0.874, Validation Accuracy: 62.34%
Epoch [2/10], Training Loss: 0.791, Validation Accuracy: 62.28%
Epoch [3/10], Training Loss: 0.745, Validation Accuracy: 63.27%
Epoch [4/10], Training Loss: 0.713, Validation Accuracy: 63.31%
Epoch [5/10], Training Loss: 0.684, Validation Accuracy: 62.79%
Epoch [6/10], Training Loss: 0.655, Validation Accuracy: 62.99%
Epoch [7/10], Training Loss: 0.636, Validation Accuracy: 62.70%
Epoch [8/10], Training Loss: 0.613, Validation Accuracy: 61.98%
Epoch [9/10], Training Loss: 0.603, Validation Accuracy: 62.41%
Epoch [10/10], Training Loss: 0.582, Validation Accuracy: 62.69%
Epoch [1/10], Training Loss: 0.863, Validation Accuracy: 63.13%
Epoch [2/10], Training Loss: 0.765, Validation Accuracy: 63.02%
Epoch [3/10], Training Loss: 0.708, Validation Accuracy: 63.64%
Epoch [4/10], Training Loss: 0.684, Validation Accuracy: 62.23%
Epoch [5/10], Training Loss: 0.656, Validation Accuracy: 62.80%
Epoch [6/10], Training Loss: 0.624, Validation Accuracy: 63.19%
Epoch [7/10], Training Loss: 0.608, Validation Accuracy: 62.38%
Epoch [8/10], Training Loss: 0.576, Validation Accuracy: 63.21%
Epoch [9/10], Training Loss: 0.561, Validation Accuracy: 62.47%
Epoch [10/10], Training Loss: 0.540, Validation Accuracy: 63.02%
Epoch [1/10], Training Loss: 0.843, Validation Accuracy: 62.64%
Epoch [2/10], Training Loss: 0.754, Validation Accuracy: 63.33%
Epoch [3/10], Training Loss: 0.697, Validation Accuracy: 62.39%
Epoch [4/10], Training Loss: 0.668, Validation Accuracy: 63.34%
Epoch [5/10], Training Loss: 0.637, Validation Accuracy: 62.04%
Epoch [6/10], Training Loss: 0.607, Validation Accuracy: 63.05%
Epoch [7/10], Training Loss: 0.589, Validation Accuracy: 62.75%
Epoch [8/10], Training Loss: 0.567, Validation Accuracy: 62.76%
Epoch [9/10], Training Loss: 0.553, Validation Accuracy: 62.30%
Epoch [10/10], Training Loss: 0.541, Validation Accuracy: 62.54%
Epoch [1/10], Training Loss: 0.852, Validation Accuracy: 61.45%
Epoch [2/10], Training Loss: 0.753, Validation Accuracy: 62.70%
Epoch [3/10], Training Loss: 0.690, Validation Accuracy: 62.71%
Epoch [4/10], Training Loss: 0.658, Validation Accuracy: 62.89%
Epoch [5/10], Training Loss: 0.621, Validation Accuracy: 62.91%
Epoch [6/10], Training Loss: 0.593, Validation Accuracy: 62.95%
Epoch [7/10], Training Loss: 0.571, Validation Accuracy: 62.50%
Epoch [8/10], Training Loss: 0.549, Validation Accuracy: 62.74%
Epoch [9/10], Training Loss: 0.534, Validation Accuracy: 62.73%
Epoch [10/10], Training Loss: 0.514, Validation Accuracy: 61.72%
Epoch [1/10], Training Loss: 0.826, Validation Accuracy: 62.88%
Epoch [2/10], Training Loss: 0.722, Validation Accuracy: 61.72%
Epoch [3/10], Training Loss: 0.671, Validation Accuracy: 62.66%
Epoch [4/10], Training Loss: 0.635, Validation Accuracy: 62.83%
Epoch [5/10], Training Loss: 0.599, Validation Accuracy: 61.92%
Epoch [6/10], Training Loss: 0.581, Validation Accuracy: 62.69%
Epoch [7/10], Training Loss: 0.556, Validation Accuracy: 62.14%
Epoch [8/10], Training Loss: 0.530, Validation Accuracy: 62.60%
Epoch [9/10], Training Loss: 0.506, Validation Accuracy: 62.09%
Epoch [10/10], Training Loss: 0.495, Validation Accuracy: 62.12%
Epoch [1/10], Training Loss: 0.828, Validation Accuracy: 61.96%
Epoch [2/10], Training Loss: 0.742, Validation Accuracy: 62.92%
Epoch [3/10], Training Loss: 0.671, Validation Accuracy: 62.43%
Epoch [4/10], Training Loss: 0.639, Validation Accuracy: 63.06%
Epoch [5/10], Training Loss: 0.610, Validation Accuracy: 63.48%
Epoch [6/10], Training Loss: 0.580, Validation Accuracy: 62.87%
Epoch [7/10], Training Loss: 0.565, Validation Accuracy: 62.87%
Epoch [8/10], Training Loss: 0.538, Validation Accuracy: 63.04%
Epoch [9/10], Training Loss: 0.521, Validation Accuracy: 63.10%
Epoch [10/10], Training Loss: 0.497, Validation Accuracy: 62.88%
Epoch [1/10], Training Loss: 0.812, Validation Accuracy: 62.43%
Epoch [2/10], Training Loss: 0.717, Validation Accuracy: 63.10%
Epoch [3/10], Training Loss: 0.645, Validation Accuracy: 63.42%
Epoch [4/10], Training Loss: 0.602, Validation Accuracy: 62.68%
Epoch [5/10], Training Loss: 0.575, Validation Accuracy: 62.68%
Epoch [6/10], Training Loss: 0.549, Validation Accuracy: 61.72%
Epoch [7/10], Training Loss: 0.526, Validation Accuracy: 62.96%
Epoch [8/10], Training Loss: 0.504, Validation Accuracy: 62.72%
Epoch [9/10], Training Loss: 0.490, Validation Accuracy: 62.31%
Epoch [10/10], Training Loss: 0.462, Validation Accuracy: 63.42%
Epoch [1/10], Training Loss: 0.809, Validation Accuracy: 62.49%
Epoch [2/10], Training Loss: 0.689, Validation Accuracy: 62.23%
Epoch [3/10], Training Loss: 0.631, Validation Accuracy: 63.21%
Epoch [4/10], Training Loss: 0.597, Validation Accuracy: 63.10%
Epoch [5/10], Training Loss: 0.557, Validation Accuracy: 62.29%
Epoch [6/10], Training Loss: 0.535, Validation Accuracy: 62.66%
Epoch [7/10], Training Loss: 0.512, Validation Accuracy: 61.56%
Epoch [8/10], Training Loss: 0.488, Validation Accuracy: 62.85%
Epoch [9/10], Training Loss: 0.465, Validation Accuracy: 62.05%
Epoch [10/10], Training Loss: 0.444, Validation Accuracy: 62.21%
Epoch [1/10], Training Loss: 0.804, Validation Accuracy: 62.04%
Epoch [2/10], Training Loss: 0.688, Validation Accuracy: 61.94%
Epoch [3/10], Training Loss: 0.630, Validation Accuracy: 62.26%
Epoch [4/10], Training Loss: 0.580, Validation Accuracy: 62.46%
Epoch [5/10], Training Loss: 0.550, Validation Accuracy: 62.04%
Epoch [6/10], Training Loss: 0.533, Validation Accuracy: 62.39%
Epoch [7/10], Training Loss: 0.494, Validation Accuracy: 62.28%
Epoch [8/10], Training Loss: 0.466, Validation Accuracy: 62.51%
Epoch [9/10], Training Loss: 0.450, Validation Accuracy: 62.01%
Epoch [10/10], Training Loss: 0.437, Validation Accuracy: 62.61%
Epoch [1/10], Training Loss: 0.796, Validation Accuracy: 61.25%
Epoch [2/10], Training Loss: 0.674, Validation Accuracy: 62.05%
Epoch [3/10], Training Loss: 0.607, Validation Accuracy: 62.85%
Epoch [4/10], Training Loss: 0.562, Validation Accuracy: 62.50%
Epoch [5/10], Training Loss: 0.526, Validation Accuracy: 62.43%
Epoch [6/10], Training Loss: 0.497, Validation Accuracy: 62.46%
Epoch [7/10], Training Loss: 0.470, Validation Accuracy: 62.12%
Epoch [8/10], Training Loss: 0.453, Validation Accuracy: 62.23%
Epoch [9/10], Training Loss: 0.434, Validation Accuracy: 61.99%
Epoch [10/10], Training Loss: 0.411, Validation Accuracy: 62.12%
"""

# Pull every "Validation Accuracy: <number>%" figure out of the log text and
# turn each captured number into a float in one pass.
accuracies = list(map(float, re.findall(r'Validation Accuracy: (\d+\.\d+)%', log)))

# Show the parsed values, followed by how many were found.
print("Accuracies:", accuracies)
print("Size of array:", len(accuracies))


Accuracies: [9.65, 9.72, 9.99, 10.53, 11.09, 12.15, 12.91, 14.11, 15.58, 16.29, 16.97, 19.59, 20.96, 20.48, 20.64, 21.06, 22.06, 23.46, 24.85, 27.06, 27.46, 28.43, 28.91, 29.96, 30.19, 31.15, 31.49, 31.84, 33.01, 33.31, 32.68, 34.1, 34.41, 33.65, 34.98, 35.97, 36.18, 35.61, 36.83, 36.36, 37.58, 38.09, 38.32, 38.69, 38.91, 39.31, 39.59, 39.36, 40.28, 39.35, 41.04, 40.83, 41.26, 41.38, 42.07, 42.0, 41.79, 41.98, 42.78, 42.13, 43.33, 43.97, 43.27, 44.94, 44.13, 45.28, 45.46, 45.07, 44.97, 45.27, 46.6, 46.94, 46.27, 47.17, 47.2, 47.48, 47.21, 47.16, 47.92, 47.82, 48.27, 48.11, 48.06, 47.73, 48.28, 48.3, 49.37, 48.91, 49.06, 49.25, 48.93, 49.74, 49.7, 49.69, 48.81, 49.43, 48.99, 49.84, 50.29, 49.53, 50.18, 50.52, 50.58, 50.8, 51.28, 50.94, 51.0, 50.6, 51.47, 51.58, 52.23, 51.29, 52.52, 52.0, 52.04, 52.78, 52.07, 52.58, 52.15, 52.18, 53.37, 53.04, 53.16, 53.66, 53.26, 53.11, 53.84, 53.29, 53.92, 53.05, 53.26, 54.23, 53.36, 54.48, 53.55, 54.2, 54.21, 54.25, 54.02, 54.46, 54.78, 54.96, 54.55, 

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch import nn
from typing import Dict
from scipy.stats import truncnorm

In [None]:

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 32x32 images.

    The encoder halves the spatial size three times (32 -> 16 -> 8 -> 4),
    flattens the 128x4x4 feature map and emits ``2 * z_dim`` numbers per
    sample: the first ``z_dim`` are used as the latent mean, the rest as the
    latent log-variance. The decoder mirrors that path and finishes with a
    sigmoid, so reconstructions lie in (0, 1).

    Args:
        x_dim: Stored on the instance; not used to size any layer.
        h_dim: Stored on the instance; not used to size any layer.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim, self.h_dim, self.z_dim = x_dim, h_dim, z_dim

        # Encoder: three stride-2 convolutions, then a two-layer MLP head.
        encoder_layers = [
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),    # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),   # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: MLP back to a 128x4x4 map, then three stride-2 upsamplings.
        decoder_layers = [
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),    # 32x32
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output."""
        stats = self.encoder(x)
        mu = stats[:, :self.z_dim]
        logvar = stats[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes ``z`` back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, sample a latent code and return ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar



In [None]:
# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on ``trainloader`` for ``epochs`` passes using Adam (lr=1e-3).

    Each batch is unpacked as ``(inputs, labels)``; the labels are ignored.
    The objective is whatever ``vae_loss`` returns for the reconstruction,
    the inputs and the latent statistics.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    for _ in range(epochs):
        for batch in trainloader:
            images, _labels = batch
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            objective = vae_loss(reconstruction, images, mu, logvar)
            objective.backward()
            adam.step()

In [None]:
# Define classification model
class Net(nn.Module):
    """Small convolutional classifier with 10 output logits.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed
    three fully connected layers. The flatten step assumes a 16x5x5 feature
    map, which is what a 3-channel 32x32 input produces.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        relu = nn.functional.relu
        features = self.pool(relu(self.conv1(x)))
        features = self.pool(relu(self.conv2(features)))
        flat = features.view(-1, 16 * 5 * 5)
        hidden = relu(self.fc2(relu(self.fc1(flat))))
        return self.fc3(hidden)

In [None]:
# Load CIFAR10 dataset
# Convert images to tensors, then normalise every channel with mean 0.5 and
# std 0.5.
# NOTE(review): this presumably puts pixels in [-1, 1], while the VAE decoder
# ends in a Sigmoid (outputs in (0, 1)) — confirm the reconstruction target
# range is what was intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 train and test splits, downloaded into ./data when requested.
full_dataset = CIFAR10(root="./data", train=True, transform=transform, download=True)
test_set = CIFAR10(root="./data", train=False, transform=transform, download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 77306790.11it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Split dataset into training and validation sets
# Hold out the part of the training data not covered by the 80% train share
# as a validation set; membership is drawn at random.
num_examples = len(full_dataset)
train_size = int(0.8 * num_examples)
val_size = num_examples - train_size
train_set, val_set = torch.utils.data.random_split(
    full_dataset, lengths=[train_size, val_size]
)


In [None]:
# Create training and validation loaders
# Batches of 128 with two worker processes each; only the training loader
# reshuffles its data.
trainloader = DataLoader(dataset=train_set, batch_size=128, num_workers=2, shuffle=True)
valloader = DataLoader(dataset=val_set, batch_size=128, num_workers=2, shuffle=False)
testloader = DataLoader(dataset=test_set, batch_size=128, num_workers=2, shuffle=False)


In [None]:
# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed reconstruction error plus the KL term for a batch.

    The reconstruction term is the squared error between ``recon_x`` and ``x``
    summed over all elements (not a cross-entropy). The KL term is
    ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    reconstruction_error = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction_error + kl_divergence

In [None]:
# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after every epoch.

    Each epoch makes one pass over ``trainloader`` (cross-entropy loss, SGD
    with lr=0.001 and momentum=0.9) and then scores ``net`` on ``valloader``.
    Whenever the validation accuracy beats the best value seen so far in this
    call, the weights are written to ``best_model.pth`` in the working
    directory.

    Args:
        net: Classifier to optimise in place.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``trainloader``.

    A ``trainloader`` that yields no batches is reported with a training loss
    of ``nan``, and a ``valloader`` that yields no samples with an accuracy of
    0.00%, instead of raising.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0  # counted explicitly so an empty loader cannot leave it undefined
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                _, predicted = torch.max(net(images), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        mean_loss = running_loss / num_batches if num_batches else float("nan")
        val_acc = 100 * correct / total if total else 0.0
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {mean_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")


In [None]:
# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> float:
    """Return the accuracy of ``net`` over ``testloader`` as a percentage.

    ``net`` is switched to eval mode and run without gradient tracking; the
    prediction for each sample is the index of its largest output.
    """
    net.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            predicted = torch.max(net(images).data, 1)[1]
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
    return 100 * hits / seen

In [None]:
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Cut ``trainset`` into ``num_clients`` contiguous shards, one loader each.

    Client ``i`` receives the index range
    ``[i * len // num_clients, (i + 1) * len // num_clients)``. The result maps
    ``"client_<i>"`` to a shuffling DataLoader over that shard. ``transform``
    is accepted but not used.
    """
    def _shard_range(i):
        size = len(trainset)
        return range(i * size // num_clients, (i + 1) * size // num_clients)

    return {
        f"client_{i}": torch.utils.data.DataLoader(
            torch.utils.data.Subset(trainset, _shard_range(i)),
            batch_size=batch_size,
            shuffle=True,
        )
        for i in range(num_clients)
    }

In [None]:
def get_distribution_info(vae: VAE) -> Dict:
    """Build the "normal" and "truncated" distribution summaries for ``vae``.

    Both entries carry the same two values, read from the last layer of
    ``vae.encoder``: its bias as ``"mean"`` and ``exp(0.5 * weight)`` as
    ``"std"``, each converted to a NumPy array.

    NOTE(review): these are parameters of the final encoder layer, not
    statistics of encoded data, and the weight is a 2-D matrix — confirm that
    this placeholder is what downstream sampling is meant to consume.
    """
    head = vae.encoder[-1]

    def _snapshot() -> Dict:
        return {
            "mean": head.bias.data.cpu().numpy(),
            "std": torch.exp(0.5 * head.weight.data).cpu().numpy(),
        }

    return {"normal": _snapshot(), "truncated": _snapshot()}

In [None]:
def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for shipping ``distribution_info`` to the global server.

    No transport is implemented: the argument is ignored and nothing is sent.
    A real implementation would hand the dictionary to some communication
    mechanism (for example a socket or a message queue).
    """
    return None


In [None]:
def generate_augmented_data(vae: VAE, distribution_info_normal: Dict, distribution_info_truncated: Dict) -> torch.Tensor:
    """Draw 64 latent-sized samples by averaging a normal and a truncated-normal draw.

    One draw comes from a normal distribution with the "mean"/"std" of
    ``distribution_info_normal``; the other from a normal distribution with
    the "mean"/"std" of ``distribution_info_truncated``, truncated below at 0.
    The element-wise average of the two is returned. The samples are not
    passed through ``vae.decode``; only ``vae.z_dim`` is read from the model.
    """
    normal_mean = distribution_info_normal["mean"]
    normal_std = distribution_info_normal["std"]

    trunc_mean = distribution_info_truncated["mean"]
    trunc_std = distribution_info_truncated["std"]

    sample_shape = (64, vae.z_dim)

    # Plain Gaussian draw, shifted and scaled.
    normal_draw = torch.randn(*sample_shape) * normal_std + normal_mean

    # Truncated draw: truncnorm expects its bounds in standardised units.
    lower = (0 - trunc_mean) / trunc_std
    upper = np.inf
    truncated_draw = torch.from_numpy(
        truncnorm.rvs(lower, upper, loc=trunc_mean, scale=trunc_std, size=sample_shape)
    ).float()

    return (normal_draw + truncated_draw) / 2

In [None]:
# Define the federated training logic
def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Run ``epochs`` rounds in which every client trains the shared models.

    For each client, in dictionary order: fit ``vae`` on the client's data for
    10 epochs, publish its distribution info, fetch the peers' distribution
    info, sample augmented latents from it, train ``net`` on the client's data
    for 10 epochs (validating on ``valloader``), and publish ``net``'s weights.

    NOTE(review): the sampled augmented data is not fed into ``train`` and the
    ``trainloader`` argument is never read — confirm whether that is intended.
    """
    for _round in range(epochs):
        for client_id, client_trainloader in trainloaders.items():
            # Local VAE fit, then exchange of distribution summaries.
            vae_train(vae, client_trainloader, epochs=10)
            send_distribution_info(get_distribution_info(vae))
            peer_info = receive_distribution_info()

            # Sampled from the received summaries; currently not consumed below.
            _augmented = generate_augmented_data(vae, peer_info["normal"], peer_info["truncated"])

            # Classifier update on the client's own data, scored on valloader.
            train(net, client_trainloader, valloader, epochs=10)

            # Report the resulting weights.
            send_model_update(client_id, net.state_dict())

In [None]:
# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution info as if received from the server.

    No communication happens: both the "normal" and the "truncated" entry
    describe a standard normal (zero mean, unit std) of length ``z_dim``.

    Args:
        z_dim: Length of the mean/std vectors; should match the latent
            dimension of the VAE in use. Defaults to 20.

    Returns:
        ``{"normal": {"mean", "std"}, "truncated": {"mean", "std"}}`` where
        every value is a freshly allocated NumPy array of shape ``(z_dim,)``.
    """
    return {
        kind: {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}
        for kind in ("normal", "truncated")
    }

In [None]:
def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder for uploading a client's model update to the global server.

    No transport is implemented: both arguments are ignored and nothing is
    sent. A real implementation would pass ``model_update`` for ``client_id``
    to some communication mechanism (for example a socket or a message queue).
    """
    return None

In [None]:
# Define global server procedure
def global_server() -> None:
    """Build the models and clients, run federated training, print test accuracy.

    A single ``Net`` and a single ``VAE`` (latent size 20) are created and
    passed to ``federated_train`` together with five client loaders cut from
    the module-level ``train_set``. The final ``net`` is then scored on the
    module-level ``testloader``.
    """
    net = Net()

    # CIFAR-10 sized settings handed to the VAE constructor.
    latent_dim = 20
    hidden_dim = 400
    input_dim = 3 * 32 * 32
    vae = VAE(input_dim, hidden_dim, latent_dim)

    # Five clients, each with its own contiguous shard of the training split.
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=5)

    # Ten federated rounds over all clients.
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    print(f"Test Accuracy: {evaluate(net, testloader):.2f}%")


if __name__ == "__main__":
    global_server()

Epoch [1/10], Training Loss: 2.304, Validation Accuracy: 9.91%
Epoch [2/10], Training Loss: 2.303, Validation Accuracy: 9.95%
Epoch [3/10], Training Loss: 2.302, Validation Accuracy: 10.01%
Epoch [4/10], Training Loss: 2.301, Validation Accuracy: 10.32%
Epoch [5/10], Training Loss: 2.299, Validation Accuracy: 11.35%
Epoch [6/10], Training Loss: 2.298, Validation Accuracy: 13.96%
Epoch [7/10], Training Loss: 2.296, Validation Accuracy: 14.94%
Epoch [8/10], Training Loss: 2.294, Validation Accuracy: 14.61%
Epoch [9/10], Training Loss: 2.290, Validation Accuracy: 14.43%
Epoch [10/10], Training Loss: 2.286, Validation Accuracy: 14.36%
Epoch [1/10], Training Loss: 2.280, Validation Accuracy: 14.73%
Epoch [2/10], Training Loss: 2.270, Validation Accuracy: 13.47%
Epoch [3/10], Training Loss: 2.256, Validation Accuracy: 12.90%
Epoch [4/10], Training Loss: 2.240, Validation Accuracy: 13.76%
Epoch [5/10], Training Loss: 2.225, Validation Accuracy: 14.83%
Epoch [6/10], Training Loss: 2.209, Valid

In [None]:
from torchvision.transforms import Compose, ToTensor, Normalize
from flwr.datasets import emnist
# Preprocessing for the federated EMNIST images: tensor conversion followed by
# single-channel normalisation with mean 0.5 and std 0.5.
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])


# NOTE(review): the recorded run of this cell stopped at the `flwr` import
# with ModuleNotFoundError, so this loader has not been exercised here.
def load_federated_emnist(num_clients: int):
    """Return ``(train_loaders, test_loaders)`` keyed by client id.

    ``emnist.load_data()`` is expected to yield a ``(clients, data)`` pair for
    each of train and test, where ``data`` maps a client id to an iterable of
    ``(image, label)`` pairs. Every image is transformed up front and each
    client gets a DataLoader with batch size 32; only the training loaders
    shuffle. ``num_clients`` is accepted but not used.
    """
    (_train_clients, train_data), (_test_clients, test_data) = emnist.load_data()

    def _client_loaders(per_client, shuffle):
        return {
            client_id: DataLoader(
                [(transform(image), label) for image, label in samples],
                batch_size=32,
                shuffle=shuffle,
            )
            for client_id, samples in per_client.items()
        }

    return _client_loaders(train_data, True), _client_loaders(test_data, False)


# Build the per-client loaders for ten clients.
num_clients = 10
train_loaders, test_loaders = load_federated_emnist(num_clients)

def federated_train(global_model, train_loaders, num_clients, epochs):
    """Run ``epochs`` rounds of local training followed by weight averaging.

    In every round each client starts from a deep copy of ``global_model``,
    makes one pass over its own loader with Adam (lr=1e-3) and cross-entropy
    loss, and contributes its resulting state dict. The state dicts are then
    combined by ``aggregate`` and loaded back into ``global_model``.

    Args:
        global_model: Model updated in place at the end of every round.
        train_loaders: Mapping of client id to an iterable of
            ``(images, labels)`` batches.
        num_clients: Accepted for interface compatibility; not used.
        epochs: Number of federated rounds.
    """
    # Imported locally: none of the imports visible in this notebook bring
    # deepcopy into scope.
    from copy import deepcopy

    # The loss module holds no state, so one instance serves every batch.
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        client_models = {}
        for client_id, train_loader in train_loaders.items():
            # Clone the global model for this client
            client_model = deepcopy(global_model)
            optimizer = torch.optim.Adam(client_model.parameters(), lr=1e-3)

            # Train on client's local data
            client_model.train()
            for images, labels in train_loader:
                optimizer.zero_grad()
                loss = criterion(client_model(images), labels)
                loss.backward()
                optimizer.step()

            # Save the updated client model
            client_models[client_id] = client_model.state_dict()

        # Aggregate client models to update the global model
        global_model.load_state_dict(aggregate(client_models))
        print(f"Epoch {epoch + 1}/{epochs} completed.")

def aggregate(client_models):
    """Aggregate model weights from all clients (simple, unweighted average).

    Args:
        client_models: Mapping of client id to a state dict. All state dicts
            must share the same keys and tensor shapes.

    Returns:
        A copy of the first client's state dict in which every entry is the
        element-wise mean over all clients. Floating-point and complex tensors
        are averaged directly; other dtypes (for example integer counters) are
        averaged in float64 and cast back to their original dtype, because
        ``mean`` is not defined for them.

    Raises:
        IndexError: If ``client_models`` is empty.
    """
    # Imported locally: none of the imports visible in this notebook bring
    # deepcopy into scope.
    from copy import deepcopy

    client_states = list(client_models.values())
    global_model_state = deepcopy(client_states[0])
    for key in global_model_state.keys():
        stacked = torch.stack([state[key] for state in client_states], dim=0)
        if stacked.is_floating_point() or stacked.is_complex():
            global_model_state[key] = stacked.mean(dim=0)
        else:
            global_model_state[key] = stacked.to(torch.float64).mean(dim=0).to(stacked.dtype)
    return global_model_state
def evaluate_federated(global_model, test_loaders):
    """Print per-client accuracy and return the sample-weighted overall accuracy.

    ``test_loaders`` maps a client id to an iterable of ``(images, labels)``
    batches. Each client's accuracy (in percent) is printed, and the overall
    figure weights every client by the number of samples it contributed.
    """
    global_model.eval()
    weighted_sum = 0
    total_samples = 0

    for client_id, test_loader in test_loaders.items():
        hits, count = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                predicted = torch.max(global_model(images), 1)[1]
                count += labels.size(0)
                hits += (predicted == labels).sum().item()

        accuracy = 100 * hits / count
        weighted_sum += accuracy * count
        total_samples += count
        print(f"Client {client_id} Accuracy: {accuracy:.2f}%")

    # Weighted average accuracy across clients
    overall_accuracy = weighted_sum / total_samples
    print(f"Overall Accuracy: {overall_accuracy:.2f}%")
    return overall_accuracy
if __name__ == "__main__":
    num_clients = 10

    # Per-client EMNIST loaders for training and testing.
    train_loaders, test_loaders = load_federated_emnist(num_clients)

    # Global model shared by all clients.
    # NOTE(review): Net's first convolution takes 3 input channels while the
    # transform used for EMNIST normalises a single channel — confirm the
    # model matches the data before running.
    global_model = Net()  # Replace with VAE if needed

    # Ten federated rounds, then a per-client and overall accuracy report.
    federated_train(
        global_model=global_model,
        train_loaders=train_loaders,
        num_clients=num_clients,
        epochs=10,
    )
    evaluate_federated(global_model, test_loaders)



ModuleNotFoundError: No module named 'flwr'

In [None]:
from torchvision.datasets import EMNIST
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Compose, ToTensor, Normalize

# EMNIST ("balanced" split) train and test sets, downloaded into ./data, with
# tensor conversion and single-channel normalisation (mean 0.5, std 0.5).
transform = Compose([
    ToTensor(),
    Normalize(mean=(0.5,), std=(0.5,)),
])
full_dataset = EMNIST(
    root="./data", split="balanced", train=True, transform=transform, download=True
)
test_dataset = EMNIST(
    root="./data", split="balanced", train=False, transform=transform, download=True
)

# Split dataset into federated subsets
def create_federated_datasets(dataset, num_clients):
    """Randomly partition ``dataset`` into ``num_clients`` shuffling loaders.

    Every sample is assigned to exactly one client. Shard sizes differ by at
    most one: when ``len(dataset)`` is not a multiple of ``num_clients`` the
    first ``len(dataset) % num_clients`` clients receive one extra sample, so
    the sizes always add up to the dataset length as ``random_split`` requires.

    Args:
        dataset: Indexable dataset with a length.
        num_clients: Number of shards to create; must be non-zero.

    Returns:
        Dict mapping ``"client_<i>"`` to a DataLoader (batch size 32,
        shuffle=True) over that client's shard.
    """
    base, extra = divmod(len(dataset), num_clients)
    lengths = [base + 1 if i < extra else base for i in range(num_clients)]
    subsets = random_split(dataset, lengths)
    return {
        f"client_{i}": DataLoader(subset, batch_size=32, shuffle=True)
        for i, subset in enumerate(subsets)
    }

# Partition both the training and the test set across ten clients. Both sets
# go through create_federated_datasets, so the test loaders shuffle as well.
num_clients = 10
train_loaders = create_federated_datasets(dataset=full_dataset, num_clients=num_clients)
test_loaders = create_federated_datasets(dataset=test_dataset, num_clients=num_clients)


Downloading https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip to ./data/EMNIST/raw/gzip.zip


100%|██████████| 562M/562M [00:15<00:00, 35.5MB/s]


Extracting ./data/EMNIST/raw/gzip.zip to ./data/EMNIST/raw


In [None]:
import numpy as np

# Hard-coded 10x10 confusion matrix, pasted in as data.
# calculate_errors below reads column sums as false positives and row sums as
# false negatives, i.e. row i is the true class and column i the predicted one.
confusion_matrix = np.array([
    [597, 32, 77, 22, 39, 22, 17, 31, 100, 63],
    [14, 760, 19, 6, 6, 13, 18, 10, 33, 121],
    [50, 18, 482, 58, 104, 115, 85, 61, 15, 12],
    [15, 16, 77, 329, 61, 282, 108, 68, 21, 23],
    [18, 9, 76, 47, 528, 85, 95, 119, 14, 9],
    [7, 6, 52, 136, 43, 620, 33, 87, 9, 7],
    [2, 17, 40, 46, 51, 49, 740, 28, 13, 14],
    [10, 4, 31, 37, 52, 100, 13, 733, 1, 19],
    [68, 46, 20, 13, 10, 15, 13, 19, 736, 60],
    [23, 117, 8, 15, 15, 20, 19, 44, 53, 686]
])

# Function to calculate errors for each class
def calculate_errors(conf_matrix):
    """Return TP / FP / FN / TN and total errors for each class.

    Row ``i`` of ``conf_matrix`` is read as the true class and column ``i`` as
    the predicted class. The result maps the class name at position ``i`` of
    the built-in name list to a dict with keys ``'TP'``, ``'FP'``, ``'FN'``,
    ``'TN'`` and ``'Total Errors'`` (``FP + FN``).
    """
    class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

    grand_total = conf_matrix.sum()
    report = {}

    for cls in range(conf_matrix.shape[0]):
        hits = conf_matrix[cls, cls]                      # diagonal entry
        false_alarms = conf_matrix[:, cls].sum() - hits   # rest of the column
        misses = conf_matrix[cls, :].sum() - hits         # rest of the row
        rejections = grand_total - (hits + false_alarms + misses)

        report[class_names[cls]] = dict(
            [
                ('TP', hits),
                ('FP', false_alarms),
                ('FN', misses),
                ('TN', rejections),
                ('Total Errors', false_alarms + misses),
            ]
        )

    return report

# Per-class error breakdown for the matrix above.
errors = calculate_errors(confusion_matrix)

# One paragraph per class, each followed by a blank line.
for class_name, error_info in errors.items():
    report_lines = [
        f"{class_name}:",
        f"  True Positives (TP): {error_info['TP']}",
        f"  False Positives (FP): {error_info['FP']}",
        f"  False Negatives (FN): {error_info['FN']}",
        f"  True Negatives (TN): {error_info['TN']}",
        f"  Total Errors (FP + FN): {error_info['Total Errors']}",
        "",
    ]
    print("\n".join(report_lines))


Airplane:
  True Positives (TP): 597
  False Positives (FP): 207
  False Negatives (FN): 403
  True Negatives (TN): 8793
  Total Errors (FP + FN): 610

Automobile:
  True Positives (TP): 760
  False Positives (FP): 265
  False Negatives (FN): 240
  True Negatives (TN): 8735
  Total Errors (FP + FN): 505

Bird:
  True Positives (TP): 482
  False Positives (FP): 400
  False Negatives (FN): 518
  True Negatives (TN): 8600
  Total Errors (FP + FN): 918

Cat:
  True Positives (TP): 329
  False Positives (FP): 380
  False Negatives (FN): 671
  True Negatives (TN): 8620
  Total Errors (FP + FN): 1051

Deer:
  True Positives (TP): 528
  False Positives (FP): 381
  False Negatives (FN): 472
  True Negatives (TN): 8619
  Total Errors (FP + FN): 853

Dog:
  True Positives (TP): 620
  False Positives (FP): 701
  False Negatives (FN): 380
  True Negatives (TN): 8299
  Total Errors (FP + FN): 1081

Frog:
  True Positives (TP): 740
  False Positives (FP): 401
  False Negatives (FN): 260
  True Negati

In [None]:
import numpy as np

# 10x10 confusion matrix. The error counts below read column sums as false
# positives and row sums as false negatives (row = true, column = predicted).
# The literal uses a single level of nesting so the array is 2-D; an extra
# pair of brackets would give shape (1, 10, 10) and collapse the loop to one
# iteration that reports an array instead of a count.
confusion_matrix = np.array([
    [597, 32, 77, 22, 39, 22, 17, 31, 100, 63],
    [14, 760, 19, 6, 6, 13, 18, 10, 33, 121],
    [50, 18, 482, 58, 104, 115, 85, 61, 15, 12],
    [15, 16, 77, 329, 61, 282, 108, 68, 21, 23],
    [18, 9, 76, 47, 528, 85, 95, 119, 14, 9],
    [7, 6, 52, 136, 43, 620, 33, 87, 9, 7],
    [2, 17, 40, 46, 51, 49, 740, 28, 13, 14],
    [10, 4, 31, 37, 52, 100, 13, 733, 1, 19],
    [68, 46, 20, 13, 10, 15, 13, 19, 736, 60],
    [23, 117, 8, 15, 15, 20, 19, 44, 53, 686]
])

# Calculate errors for each class: FP is the column sum and FN the row sum,
# each excluding the diagonal (correct) entry.
true_positives = np.diag(confusion_matrix)
false_positives = confusion_matrix.sum(axis=0) - true_positives
false_negatives = confusion_matrix.sum(axis=1) - true_positives
errors_per_class = {
    i: int(fp + fn)
    for i, (fp, fn) in enumerate(zip(false_positives, false_negatives))
}

# Print errors for each class
for class_label, error_count in errors_per_class.items():
    print(f"Class {class_label}: {error_count} errors")


Class 0: [ 9806 10936 10846 10956 10922 10956 10966 10938 10800 10874] errors


In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 32x32 images.

    The encoder halves the spatial size three times (32 -> 16 -> 8 -> 4),
    flattens the 128x4x4 feature map and emits ``2 * z_dim`` numbers per
    sample: the first ``z_dim`` are used as the latent mean, the rest as the
    latent log-variance. The decoder mirrors that path and finishes with a
    sigmoid, so reconstructions lie in (0, 1).

    Args:
        x_dim: Stored on the instance; not used to size any layer.
        h_dim: Stored on the instance; not used to size any layer.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim, self.h_dim, self.z_dim = x_dim, h_dim, z_dim

        # Encoder: three stride-2 convolutions, then a two-layer MLP head.
        encoder_layers = [
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),    # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),   # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: MLP back to a 128x4x4 map, then three stride-2 upsamplings.
        decoder_layers = [
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),    # 32x32
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output."""
        stats = self.encoder(x)
        mu = stats[:, :self.z_dim]
        logvar = stats[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes ``z`` back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, sample a latent code and return ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Fit ``vae`` on ``trainloader`` for ``epochs`` passes using Adam (lr=1e-3).

    Each batch is unpacked as ``(inputs, labels)``; the labels are ignored and
    the objective is ``vae_loss``. A StepLR schedule created inside this call
    halves the learning rate after every 20 epochs.
    """
    adam = torch.optim.Adam(vae.parameters(), lr=1e-3)
    lr_schedule = torch.optim.lr_scheduler.StepLR(adam, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for batch in trainloader:
            images, _labels = batch
            adam.zero_grad()
            reconstruction, mu, logvar = vae(images)
            objective = vae_loss(reconstruction, images, mu, logvar)
            objective.backward()
            adam.step()

        # One scheduler tick per completed epoch.
        lr_schedule.step()


# Define classification model
class Net(nn.Module):
    """Small convolutional classifier with 10 output logits.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed
    three fully connected layers. The flatten step assumes a 16x5x5 feature
    map, which is what a 3-channel 32x32 input produces.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        relu = nn.functional.relu
        features = self.pool(relu(self.conv1(x)))
        features = self.pool(relu(self.conv2(features)))
        flat = features.view(-1, 16 * 5 * 5)
        hidden = relu(self.fc2(relu(self.fc1(flat))))
        return self.fc3(hidden)


# Load CIFAR10 dataset
# Convert images to tensors, then normalise every channel with mean 0.5 and
# std 0.5.
# NOTE(review): this presumably puts pixels in [-1, 1], while the VAE decoder
# ends in a Sigmoid (outputs in (0, 1)) — confirm the reconstruction target
# range is what was intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 train and test splits, downloaded into ./data when requested.
full_dataset = CIFAR10(root="./data", train=True, transform=transform, download=True)
test_set = CIFAR10(root="./data", train=False, transform=transform, download=True)

# Separate dataset by class.
# Labels are read from the dataset's `targets` list so that no image has to be
# decoded and transformed just to learn its class.
class_indices = {i: [] for i in range(10)}  # CIFAR-10 has 10 classes
for idx, label in enumerate(full_dataset.targets):
    class_indices[int(label)].append(idx)

# Draw a random per-class target count (equal class probabilities).
# NOTE(review): 60,000 draws are requested in total, but every class is capped
# at the number of images it actually has (see below). If the training split
# holds fewer than 60,000 images, most or all classes end up capped — confirm
# the intended total.
class_counts = np.random.multinomial(60000, [0.1] * 10)  # Adjust probabilities if you want specific class biases
print("Random Images per Class:", class_counts)

# Sample indices based on the specified class counts
indices = []
for class_id, requested in enumerate(class_counts):
    available = class_indices[class_id]
    # Cap at what the class offers; int() because random.sample is given a
    # plain Python integer rather than a NumPy scalar.
    take = min(int(requested), len(available))
    indices.extend(random.sample(available, take))

# Create a custom CIFAR-10 dataset with the sampled indices
custom_dataset = Subset(full_dataset, indices)
# Hold out the part of the training data not covered by the 80% train share
# as a validation set; membership is drawn at random.
# NOTE(review): the split is taken from full_dataset; the class-sampled
# custom_dataset built just above is not used here — confirm which was meant.
num_examples = len(full_dataset)
train_size = int(0.8 * num_examples)
val_size = num_examples - train_size
train_set, val_set = torch.utils.data.random_split(
    full_dataset, lengths=[train_size, val_size]
)

# Batches of 128 with two worker processes each; only the training loader
# reshuffles its data.
trainloader = DataLoader(dataset=train_set, batch_size=128, num_workers=2, shuffle=True)
valloader = DataLoader(dataset=val_set, batch_size=128, num_workers=2, shuffle=False)
testloader = DataLoader(dataset=test_set, batch_size=128, num_workers=2, shuffle=False)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed reconstruction error plus the KL term for a batch.

    The reconstruction term is the squared error between ``recon_x`` and ``x``
    summed over all elements (not a cross-entropy). The KL term is
    ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    reconstruction_error = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction_error + kl_divergence

# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after every epoch.

    Each epoch makes one pass over ``trainloader`` (cross-entropy loss, SGD
    with lr=0.001 and momentum=0.9), advances a StepLR schedule that halves
    the learning rate after every 20 epochs, and then scores ``net`` on
    ``valloader``. Whenever the validation accuracy beats the best value seen
    so far in this call, the weights are written to ``best_model.pth`` in the
    working directory.

    Args:
        net: Classifier to optimise in place.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``trainloader``.

    A ``trainloader`` that yields no batches is reported with a training loss
    of ``nan``, and a ``valloader`` that yields no samples with an accuracy of
    0.00%, instead of raising.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0  # counted explicitly so an empty loader cannot leave it undefined
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                _, predicted = torch.max(net(images), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        mean_loss = running_loss / num_batches if num_batches else float("nan")
        val_acc = 100 * correct / total if total else 0.0
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {mean_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score ``net`` on ``testloader`` and return accuracy plus per-class metrics.

    The confusion matrix is printed. The returned tuple is
    ``(accuracy, tp, fp, tn, fn, precision, recall, f1)``: ``accuracy`` is a
    percentage, ``tp``/``fp``/``tn``/``fn`` are derived from the confusion
    matrix, and ``precision``/``recall``/``f1`` come from the corresponding
    score functions called with ``average=None``.
    """
    net.eval()
    hits = 0
    seen = 0
    all_labels, all_predictions = [], []
    with torch.no_grad():
        for images, labels in testloader:
            predicted = torch.max(net(images).data, 1)[1]
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    accuracy = 100 * hits / seen

    # Confusion matrix of true labels against predictions.
    cm = confusion_matrix(all_labels, all_predictions)
    print(f"Confusion Matrix:\n{cm}")

    # Per-class counts: the diagonal holds the hits; column and row sums minus
    # the diagonal give false positives and false negatives respectively.
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)

    # Per-class precision, recall and F1, in that order.
    precision, recall, f1 = (
        metric(all_labels, all_predictions, average=None)
        for metric in (precision_score, recall_score, f1_score)
    )

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Cut ``trainset`` into ``num_clients`` contiguous shards, one loader each.

    Client ``i`` receives the index range
    ``[i * len // num_clients, (i + 1) * len // num_clients)``. The result maps
    ``"client_<i>"`` to a shuffling DataLoader over that shard. ``transform``
    is accepted but not used.
    """
    def _shard_range(i):
        size = len(trainset)
        return range(i * size // num_clients, (i + 1) * size // num_clients)

    return {
        f"client_{i}": torch.utils.data.DataLoader(
            torch.utils.data.Subset(trainset, _shard_range(i)),
            batch_size=batch_size,
            shuffle=True,
        )
        for i in range(num_clients)
    }

def get_distribution_info(vae: VAE) -> Dict:
    """Build the "normal" and "truncated" distribution summaries for ``vae``.

    Both entries carry the same two values, read from the last layer of
    ``vae.encoder``: its bias as ``"mean"`` and ``exp(0.5 * weight)`` as
    ``"std"``, each converted to a NumPy array.

    NOTE(review): these are parameters of the final encoder layer, not
    statistics of encoded data, and the weight is a 2-D matrix — confirm that
    this placeholder is what downstream sampling is meant to consume.
    """
    head = vae.encoder[-1]

    def _snapshot() -> Dict:
        return {
            "mean": head.bias.data.cpu().numpy(),
            "std": torch.exp(0.5 * head.weight.data).cpu().numpy(),
        }

    return {"normal": _snapshot(), "truncated": _snapshot()}

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder for shipping ``distribution_info`` to the global server.

    No transport is implemented: the argument is ignored and nothing is sent.
    A real implementation would hand the dictionary to some communication
    mechanism (for example a socket or a message queue).
    """
    return None

def generate_augmented_data(vae: VAE, distribution_info_normal: Dict, distribution_info_truncated: Dict, num_samples: int = 64) -> torch.Tensor:
    """Sample latent vectors as the average of a normal and a truncated-normal draw.

    Two independent batches of shape ``(num_samples, vae.z_dim)`` are drawn:
    one from a normal distribution and one from a normal distribution truncated
    below at 0. Their element-wise mean is returned.

    Args:
        vae: Only ``vae.z_dim`` is read, to size the samples.
        distribution_info_normal: Mapping with "mean" and "std" for the normal
            draw; values must broadcast against ``(num_samples, vae.z_dim)``.
        distribution_info_truncated: Mapping with "mean" and "std" for the
            truncated draw; same broadcasting requirement.
        num_samples: Number of latent vectors to draw (defaults to the
            previously hard-coded 64).

    Returns:
        A float32 tensor of shape ``(num_samples, vae.z_dim)``.
    """
    latent_shape = (num_samples, vae.z_dim)

    # Convert explicitly to float32 tensors so the arithmetic below is pure
    # torch and the result keeps torch's default floating dtype instead of
    # depending on implicit tensor/ndarray promotion.
    mean_normal = torch.as_tensor(distribution_info_normal["mean"], dtype=torch.float32)
    std_normal = torch.as_tensor(distribution_info_normal["std"], dtype=torch.float32)
    augmented_data_normal = torch.randn(latent_shape) * std_normal + mean_normal

    mean_truncated = distribution_info_truncated["mean"]
    std_truncated = distribution_info_truncated["std"]

    # truncnorm takes its clip points in standardised units, i.e. relative to
    # loc and scale; the lower clip point here corresponds to 0 in data units.
    lower = (0 - mean_truncated) / std_truncated
    upper = np.inf
    truncated_draws = truncnorm.rvs(lower, upper, loc=mean_truncated, scale=std_truncated, size=latent_shape)
    augmented_data_truncated = torch.from_numpy(truncated_draws).float()

    # Element-wise average of the two batches.
    return (augmented_data_normal + augmented_data_truncated) / 2

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Run ``epochs`` rounds in which every client trains the shared VAE and classifier in turn.

    Within a round, clients are visited in the iteration order of
    ``trainloaders``. The same ``vae`` and ``net`` objects are updated in place
    by each client; no averaging across clients happens here.

    NOTE(review): ``trainloader`` is accepted but never read, and the latent
    samples produced by ``generate_augmented_data`` are not passed to the
    classifier training call - confirm whether they were meant to be used.

    Args:
        net: Classifier trained by ``train``.
        vae: VAE trained by ``vae_train``.
        trainloaders: Mapping of client id to that client's DataLoader.
        trainloader: Unused.
        valloader: Validation loader forwarded to ``train``.
        epochs: Number of rounds over all clients.
    """
    def _client_round(name, loader):
        # Fit the VAE on this client's data (10 local epochs).
        vae_train(vae, loader, epochs=10)

        # Publish this client's distribution info, then fetch the info to sample from.
        send_distribution_info(get_distribution_info(vae))
        received = receive_distribution_info()

        # Sampling is kept so random-number consumption stays the same; the result is not used.
        generate_augmented_data(vae, received["normal"], received["truncated"])

        # Fit the classifier on this client's data (10 local epochs), then publish its weights.
        train(net, loader, valloader, epochs=10)
        send_model_update(name, net.state_dict())

    for _ in range(epochs):
        for name, loader in trainloaders.items():
            _client_round(name, loader)

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution info standing in for a server response.

    Nothing is actually received: both the "normal" and "truncated" entries are
    a zero mean vector and a unit std vector of length ``z_dim``.

    Args:
        z_dim: Latent dimension the vectors should match. Defaults to 20, the
            size that was previously hard-coded.

    Returns:
        Dict with keys "normal" and "truncated", each holding freshly allocated
        "mean" and "std" NumPy arrays of shape ``(z_dim,)``.
    """
    def _standard_params() -> Dict:
        # Fresh arrays per entry so mutating one entry cannot affect the other.
        return {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}

    return {"normal": _standard_params(), "truncated": _standard_params()}

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder hook for shipping a client's model update to the global server.

    No transport is implemented: both arguments are ignored and nothing is
    sent. A real implementation would transmit ``model_update`` tagged with
    ``client_id`` over the deployment's channel, for example the `socket`
    module or a message queue such as `RabbitMQ`.

    Args:
        client_id: Identifier of the sending client.
        model_update: Payload that would be transmitted.
    """
    return None

# Define global server procedure
def global_server() -> None:
    """Build the models, run federated training over 5 clients, and print test metrics.

    Reads the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``. After training, the classifier is
    evaluated on ``testloader`` and the accuracy plus per-class counts and
    scores returned by ``evaluate`` are printed.
    """
    net = Net()

    # CIFAR-10 input size, hidden width and latent width handed to the VAE constructor.
    x_dim, h_dim, z_dim = 3 * 32 * 32, 400, 20
    vae = VAE(x_dim, h_dim, z_dim)

    # Shard the training split across the clients.
    num_clients = 5
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    # One (label, value) pair per printed line, in output order.
    report = (
        ("True Positives (TP):", tp),
        ("False Positives (FP):", fp),
        ("True Negatives (TN):", tn),
        ("False Negatives (FN):", fn),
        (":", fn+tn+tp+fp),
        ("Precision:", precision),
        ("Recall:", recall),
        ("F1 Score:", f1),
    )
    for label, value in report:
        print(label, value)

# Run the full pipeline only when this code executes as the main module,
# not when it is imported.
if __name__ == "__main__":
    global_server()

Files already downloaded and verified
Files already downloaded and verified
Random Images per Class: [6008 6065 6074 5856 6147 5986 6008 5962 5904 5990]
Epoch [1/10], Training Loss: 2.306, Validation Accuracy: 10.17%
Epoch [2/10], Training Loss: 2.304, Validation Accuracy: 10.24%
Epoch [3/10], Training Loss: 2.301, Validation Accuracy: 10.16%
Epoch [4/10], Training Loss: 2.299, Validation Accuracy: 10.14%
Epoch [5/10], Training Loss: 2.295, Validation Accuracy: 10.12%
Epoch [6/10], Training Loss: 2.291, Validation Accuracy: 10.24%
Epoch [7/10], Training Loss: 2.284, Validation Accuracy: 11.09%
Epoch [8/10], Training Loss: 2.274, Validation Accuracy: 12.03%
Epoch [9/10], Training Loss: 2.259, Validation Accuracy: 12.78%
Epoch [10/10], Training Loss: 2.243, Validation Accuracy: 14.08%
Epoch [1/10], Training Loss: 2.232, Validation Accuracy: 15.69%
Epoch [2/10], Training Loss: 2.217, Validation Accuracy: 17.64%
Epoch [3/10], Training Loss: 2.201, Validation Accuracy: 19.24%
Epoch [4/10], 

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random

# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 32x32 images.

    The encoder halves the spatial size three times (32 -> 16 -> 8 -> 4),
    flattens to 2048 features and emits ``2 * z_dim`` values per sample; the
    first ``z_dim`` are read as the mean and the rest as the log-variance.
    The decoder mirrors this path and ends in a Sigmoid.

    ``x_dim`` and ``h_dim`` are stored on the instance but do not influence
    any layer size.
    """

    def __init__(self, x_dim, h_dim, z_dim):
        super().__init__()
        self.x_dim, self.h_dim, self.z_dim = x_dim, h_dim, z_dim

        # Every (transposed) convolution uses the same geometry, which halves
        # (or doubles) the spatial resolution.
        geometry = dict(kernel_size=4, stride=2, padding=1)

        # Encoder
        encoder_layers = [
            nn.Conv2d(3, 32, **geometry),  # 16x16
            nn.ReLU(),
            nn.Conv2d(32, 64, **geometry),  # 8x8
            nn.ReLU(),
            nn.Conv2d(64, 128, **geometry),  # 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder
        decoder_layers = [
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, **geometry),  # 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, **geometry),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, **geometry),  # 32x32
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output along dim 1."""
        stats = self.encoder(x)
        latent = self.z_dim
        return stats[:, :latent], stats[:, latent:]

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps ~ N(0, 1)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors through the decoder."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, sample a latent, and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar


# Define VAE training procedure
def vae_train(vae: VAE, trainloader: DataLoader, epochs: int) -> None:
    """Optimise ``vae`` on ``trainloader`` for ``epochs`` passes with Adam.

    A fresh Adam optimiser (lr 1e-3) and a StepLR schedule (halve the learning
    rate every 20 scheduler steps, one step per epoch) are created on every
    call. Each batch is expected to unpack into ``(inputs, labels)``; the
    labels are ignored.

    Args:
        vae: Model returning ``(reconstruction, mu, logvar)`` when called.
        trainloader: Iterable of two-element batches.
        epochs: Number of passes over ``trainloader``.
    """
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    for _ in range(epochs):
        vae.train()
        for batch in trainloader:
            images, _labels = batch
            optimizer.zero_grad()
            reconstruction, mu, logvar = vae(images)
            batch_loss = vae_loss(reconstruction, images, mu, logvar)
            batch_loss.backward()
            optimizer.step()

        # One scheduler step per completed epoch.
        scheduler.step()


# Define classification model
class Net(nn.Module):
    """Small convolutional classifier producing 10 logits per image.

    Two conv + ReLU + 2x2 max-pool stages feed three fully connected layers
    (400 -> 120 -> 84 -> 10). The flatten step assumes the second pooling
    stage leaves a 16 x 5 x 5 feature map per sample.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores for a batch of images."""
        relu = nn.functional.relu

        # Convolutional feature extractor; the single pooling layer is reused.
        for conv in (self.conv1, self.conv2):
            x = self.pool(relu(conv(x)))

        # Flatten to (batch, 400) for the fully connected head.
        x = x.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            x = relu(hidden(x))
        return self.fc3(x)


# Load CIFAR10 dataset
# Images are converted to tensors and then normalised per channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Separate dataset by class.
# Labels are read from the dataset's ``targets`` list, so no image has to be
# loaded and transformed just to learn its class.
class_indices = {i: [] for i in range(10)}  # CIFAR-10 has 10 classes
for idx, label in enumerate(full_dataset.targets):
    class_indices[label].append(idx)

# Define target count per class, summing to 60,000 with random distribution
# NOTE(review): the CIFAR-10 training split is documented as 50,000 images; if
# so, every drawn count exceeds its class size and the clamp below always
# selects the whole class - confirm whether the total should be
# len(full_dataset) instead.
class_counts = np.random.multinomial(60000, [0.1] * 10)  # Adjust probabilities if you want specific class biases
print("Random Images per Class:", class_counts)

# Sample indices based on the specified class counts
indices = []
for class_id, count in enumerate(class_counts):
    available = class_indices[class_id]
    # Ensure count does not exceed available images
    count = min(int(count), len(available))
    indices.extend(random.sample(available, count))

# Create a custom CIFAR-10 dataset with the sampled indices
custom_dataset = Subset(full_dataset, indices)

# Split the sampled dataset (not the untouched full dataset) into training and
# validation sets, so the class sampling above actually takes effect.
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

# Create training and validation loaders
trainloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
testloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)



# Define VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return summed squared reconstruction error plus the KL divergence term.

    The reconstruction term is ``mse_loss(recon_x, x)`` with sum reduction, and
    the KL term is ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``. Both are
    summed over every element, so the result scales with batch size.

    NOTE(review): the decoder in this file ends in Sigmoid while the dataset
    transform applies Normalize(0.5, 0.5); if inputs therefore fall outside
    [0, 1] the reconstruction term cannot reach zero - confirm.
    """
    reconstruction_error = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_elements)
    return reconstruction_error + kl_divergence

# Define training procedure for classification model
def train(net: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Train ``net`` with SGD and report validation accuracy after every epoch.

    A fresh SGD optimiser (lr 0.001, momentum 0.9) and StepLR schedule (halve
    the learning rate every 20 scheduler steps, one step per epoch) are created
    on every call. Whenever an epoch's validation accuracy beats the best seen
    in this call, the weights are saved to ``best_model.pth`` in the working
    directory.

    Empty loaders are tolerated: with no training batches the reported mean
    loss is ``nan``, and with no validation samples the accuracy is 0.0.

    Args:
        net: Classifier mapping a batch of inputs to per-class scores.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        valloader: Iterable of ``(inputs, labels)`` validation batches.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    # Best validation accuracy within this call only; it restarts at 0 each call.
    best_acc = 0.0

    for epoch in range(epochs):
        # Training loop
        net.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        scheduler.step()

        # Validation loop
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard both averages so an empty loader cannot raise.
        val_acc = 100 * correct / total if total else 0.0
        mean_loss = running_loss / num_batches if num_batches else float("nan")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), 'best_model.pth')

        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {mean_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Define evaluation procedure
def evaluate(net: nn.Module, testloader: DataLoader) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score ``net`` on ``testloader`` and derive per-class confusion statistics.

    The confusion matrix is printed as a side effect.

    Args:
        net: Classifier mapping a batch of images to per-class scores.
        testloader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``(accuracy, tp, fp, tn, fn, precision, recall, f1)`` where accuracy is
        a percentage and the remaining items are per-class arrays.
    """
    net.eval()
    hits, seen = 0, 0
    all_labels, all_predictions = [], []

    with torch.no_grad():
        for images, labels in testloader:
            scores = net(images)
            _, predicted = torch.max(scores.data, 1)
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    accuracy = 100 * hits / seen

    # Calculate confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    print(f"Confusion Matrix:\n{cm}")

    # Per-class counts: the diagonal holds the hits; column and row totals
    # minus the diagonal give false positives and false negatives.
    tp = np.diag(cm)
    column_totals = cm.sum(axis=0)
    row_totals = cm.sum(axis=1)
    fp = column_totals - tp
    fn = row_totals - tp
    tn = cm.sum() - (fp + fn + tp)

    # average=None yields one score per class.
    precision = precision_score(all_labels, all_predictions, average=None)
    recall = recall_score(all_labels, all_predictions, average=None)
    f1 = f1_score(all_labels, all_predictions, average=None)

    return accuracy, tp, fp, tn, fn, precision, recall, f1

# Initialize clients
def initialize_clients(trainset, transform, batch_size, num_clients):
    """Partition ``trainset`` into contiguous shards, one shuffling DataLoader per client.

    Client ``i`` receives the index range
    ``[i * N // num_clients, (i + 1) * N // num_clients)`` with
    ``N = len(trainset)``, so consecutive shards tile the dataset in order.

    Args:
        trainset: Indexable dataset with a length.
        transform: Accepted for call compatibility; not read by this function.
        batch_size: Batch size given to every client loader.
        num_clients: Number of shards (and loaders) to create.

    Returns:
        Dict mapping ``"client_<i>"`` to that client's DataLoader.
    """
    def _shard(position):
        # Integer division keeps neighbouring shard boundaries gap-free.
        size = len(trainset)
        first = position * size // num_clients
        past_last = (position + 1) * size // num_clients
        return torch.utils.data.Subset(trainset, range(first, past_last))

    return {
        f"client_{position}": torch.utils.data.DataLoader(
            _shard(position), batch_size=batch_size, shuffle=True
        )
        for position in range(num_clients)
    }

def get_distribution_info(vae: VAE) -> Dict:
    """Package parameters of the VAE encoder's final layer as distribution info.

    The "normal" and "truncated" entries are built identically: "mean" is the
    bias of ``vae.encoder[-1]`` and "std" is ``exp(0.5 * weight)`` of that same
    layer, each moved to CPU and converted to a NumPy array.

    NOTE(review): these are layer parameters, not statistics computed from
    encoded samples, and "std" takes the weight matrix's shape rather than the
    bias's shape - confirm this is the intended summary.

    Args:
        vae: Model whose ``encoder`` is indexable and ends in a layer exposing
            ``bias`` and ``weight``.

    Returns:
        Dict with keys "normal" and "truncated", each holding "mean" and "std".
    """
    output_layer = vae.encoder[-1]

    def _layer_summary() -> Dict:
        # Built afresh on every call so the two entries do not share an "std" array.
        return {
            "mean": output_layer.bias.data.cpu().numpy(),
            "std": torch.exp(0.5 * output_layer.weight.data).cpu().numpy(),
        }

    return {"normal": _layer_summary(), "truncated": _layer_summary()}

def send_distribution_info(distribution_info: Dict) -> None:
    """Placeholder hook for shipping a client's distribution info to the global server.

    No transport is implemented: the argument is ignored and nothing is sent.
    A real implementation would serialise ``distribution_info`` and push it over
    whatever channel the deployment uses, for example the `socket` module or a
    message queue such as `RabbitMQ`.

    Args:
        distribution_info: Payload that would be transmitted.
    """
    return None

def generate_augmented_data(vae: VAE, distribution_info_normal: Dict, distribution_info_truncated: Dict, num_samples: int = 64) -> torch.Tensor:
    """Sample latent vectors as the average of a normal and a truncated-normal draw.

    Two independent batches of shape ``(num_samples, vae.z_dim)`` are drawn:
    one from a normal distribution and one from a normal distribution truncated
    below at 0. Their element-wise mean is returned.

    Args:
        vae: Only ``vae.z_dim`` is read, to size the samples.
        distribution_info_normal: Mapping with "mean" and "std" for the normal
            draw; values must broadcast against ``(num_samples, vae.z_dim)``.
        distribution_info_truncated: Mapping with "mean" and "std" for the
            truncated draw; same broadcasting requirement.
        num_samples: Number of latent vectors to draw (defaults to the
            previously hard-coded 64).

    Returns:
        A float32 tensor of shape ``(num_samples, vae.z_dim)``.
    """
    latent_shape = (num_samples, vae.z_dim)

    # Convert explicitly to float32 tensors so the arithmetic below is pure
    # torch and the result keeps torch's default floating dtype instead of
    # depending on implicit tensor/ndarray promotion.
    mean_normal = torch.as_tensor(distribution_info_normal["mean"], dtype=torch.float32)
    std_normal = torch.as_tensor(distribution_info_normal["std"], dtype=torch.float32)
    augmented_data_normal = torch.randn(latent_shape) * std_normal + mean_normal

    mean_truncated = distribution_info_truncated["mean"]
    std_truncated = distribution_info_truncated["std"]

    # truncnorm takes its clip points in standardised units, i.e. relative to
    # loc and scale; the lower clip point here corresponds to 0 in data units.
    lower = (0 - mean_truncated) / std_truncated
    upper = np.inf
    truncated_draws = truncnorm.rvs(lower, upper, loc=mean_truncated, scale=std_truncated, size=latent_shape)
    augmented_data_truncated = torch.from_numpy(truncated_draws).float()

    # Element-wise average of the two batches.
    return (augmented_data_normal + augmented_data_truncated) / 2

def federated_train(net: nn.Module, vae: VAE, trainloaders: Dict[str, DataLoader], trainloader: DataLoader, valloader: DataLoader, epochs: int) -> None:
    """Run ``epochs`` rounds in which every client trains the shared VAE and classifier in turn.

    Within a round, clients are visited in the iteration order of
    ``trainloaders``. The same ``vae`` and ``net`` objects are updated in place
    by each client; no averaging across clients happens here.

    NOTE(review): ``trainloader`` is accepted but never read, and the latent
    samples produced by ``generate_augmented_data`` are not passed to the
    classifier training call - confirm whether they were meant to be used.

    Args:
        net: Classifier trained by ``train``.
        vae: VAE trained by ``vae_train``.
        trainloaders: Mapping of client id to that client's DataLoader.
        trainloader: Unused.
        valloader: Validation loader forwarded to ``train``.
        epochs: Number of rounds over all clients.
    """
    def _client_round(name, loader):
        # Fit the VAE on this client's data (10 local epochs).
        vae_train(vae, loader, epochs=10)

        # Publish this client's distribution info, then fetch the info to sample from.
        send_distribution_info(get_distribution_info(vae))
        received = receive_distribution_info()

        # Sampling is kept so random-number consumption stays the same; the result is not used.
        generate_augmented_data(vae, received["normal"], received["truncated"])

        # Fit the classifier on this client's data (10 local epochs), then publish its weights.
        train(net, loader, valloader, epochs=10)
        send_model_update(name, net.state_dict())

    for _ in range(epochs):
        for name, loader in trainloaders.items():
            _client_round(name, loader)

# Define logic to receive distribution information from global server
def receive_distribution_info(z_dim: int = 20) -> Dict:
    """Return placeholder distribution info standing in for a server response.

    Nothing is actually received: both the "normal" and "truncated" entries are
    a zero mean vector and a unit std vector of length ``z_dim``.

    Args:
        z_dim: Latent dimension the vectors should match. Defaults to 20, the
            size that was previously hard-coded.

    Returns:
        Dict with keys "normal" and "truncated", each holding freshly allocated
        "mean" and "std" NumPy arrays of shape ``(z_dim,)``.
    """
    def _standard_params() -> Dict:
        # Fresh arrays per entry so mutating one entry cannot affect the other.
        return {"mean": np.zeros(z_dim), "std": np.ones(z_dim)}

    return {"normal": _standard_params(), "truncated": _standard_params()}

def send_model_update(client_id: str, model_update: Dict) -> None:
    """Placeholder hook for shipping a client's model update to the global server.

    No transport is implemented: both arguments are ignored and nothing is
    sent. A real implementation would transmit ``model_update`` tagged with
    ``client_id`` over the deployment's channel, for example the `socket`
    module or a message queue such as `RabbitMQ`.

    Args:
        client_id: Identifier of the sending client.
        model_update: Payload that would be transmitted.
    """
    return None

# Define global server procedure
def global_server() -> None:
    """Build the models, run federated training over 5 clients, and print test metrics.

    Reads the module-level ``train_set``, ``transform``, ``trainloader``,
    ``valloader`` and ``testloader``. After training, the classifier is
    evaluated on ``testloader`` and the accuracy plus per-class counts and
    scores returned by ``evaluate`` are printed.
    """
    net = Net()

    # CIFAR-10 input size, hidden width and latent width handed to the VAE constructor.
    x_dim, h_dim, z_dim = 3 * 32 * 32, 400, 20
    vae = VAE(x_dim, h_dim, z_dim)

    # Shard the training split across the clients.
    num_clients = 5
    clients = initialize_clients(train_set, transform, batch_size=128, num_clients=num_clients)

    # Train model using FedDIS
    federated_train(net, vae, clients, trainloader, valloader, epochs=10)

    # Evaluate final model
    test_accuracy, tp, fp, tn, fn, precision, recall, f1 = evaluate(net, testloader)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    # One (label, value) pair per printed line, in output order.
    report = (
        ("True Positives (TP):", tp),
        ("False Positives (FP):", fp),
        ("True Negatives (TN):", tn),
        ("False Negatives (FN):", fn),
        (":", fn+tn+tp+fp),
        ("Precision:", precision),
        ("Recall:", recall),
        ("F1 Score:", f1),
    )
    for label, value in report:
        print(label, value)

# Run the full pipeline only when this code executes as the main module,
# not when it is imported.
if __name__ == "__main__":
    global_server()

Files already downloaded and verified
Files already downloaded and verified
Random Images per Class: [6049 5959 5932 6018 5974 6043 5971 6058 6001 5995]
Epoch [1/10], Training Loss: 2.303, Validation Accuracy: 10.39%
Epoch [2/10], Training Loss: 2.302, Validation Accuracy: 10.79%
Epoch [3/10], Training Loss: 2.300, Validation Accuracy: 10.93%
Epoch [4/10], Training Loss: 2.298, Validation Accuracy: 11.68%
Epoch [5/10], Training Loss: 2.296, Validation Accuracy: 12.45%
Epoch [6/10], Training Loss: 2.293, Validation Accuracy: 12.85%
Epoch [7/10], Training Loss: 2.290, Validation Accuracy: 13.24%
Epoch [8/10], Training Loss: 2.285, Validation Accuracy: 13.36%
Epoch [9/10], Training Loss: 2.278, Validation Accuracy: 13.75%
Epoch [10/10], Training Loss: 2.267, Validation Accuracy: 12.47%
Epoch [1/10], Training Loss: 2.259, Validation Accuracy: 14.27%
Epoch [2/10], Training Loss: 2.244, Validation Accuracy: 16.91%
Epoch [3/10], Training Loss: 2.225, Validation Accuracy: 17.81%
Epoch [4/10], 

In [None]:
import re

# Your provided text
log = """
Epoch [1/10], Training Loss: 2.306, Validation Accuracy: 10.17%
Epoch [2/10], Training Loss: 2.304, Validation Accuracy: 10.24%
Epoch [3/10], Training Loss: 2.301, Validation Accuracy: 10.16%
Epoch [4/10], Training Loss: 2.299, Validation Accuracy: 10.14%
Epoch [5/10], Training Loss: 2.295, Validation Accuracy: 10.12%
Epoch [6/10], Training Loss: 2.291, Validation Accuracy: 10.24%
Epoch [7/10], Training Loss: 2.284, Validation Accuracy: 11.09%
Epoch [8/10], Training Loss: 2.274, Validation Accuracy: 12.03%
Epoch [9/10], Training Loss: 2.259, Validation Accuracy: 12.78%
Epoch [10/10], Training Loss: 2.243, Validation Accuracy: 14.08%
Epoch [1/10], Training Loss: 2.232, Validation Accuracy: 15.69%
Epoch [2/10], Training Loss: 2.217, Validation Accuracy: 17.64%
Epoch [3/10], Training Loss: 2.201, Validation Accuracy: 19.24%
Epoch [4/10], Training Loss: 2.184, Validation Accuracy: 19.88%
Epoch [5/10], Training Loss: 2.168, Validation Accuracy: 20.73%
Epoch [6/10], Training Loss: 2.152, Validation Accuracy: 22.25%
Epoch [7/10], Training Loss: 2.137, Validation Accuracy: 21.60%
Epoch [8/10], Training Loss: 2.121, Validation Accuracy: 22.45%
Epoch [9/10], Training Loss: 2.105, Validation Accuracy: 23.39%
Epoch [10/10], Training Loss: 2.089, Validation Accuracy: 24.34%
Epoch [1/10], Training Loss: 2.082, Validation Accuracy: 27.71%
Epoch [2/10], Training Loss: 2.066, Validation Accuracy: 28.13%
Epoch [3/10], Training Loss: 2.048, Validation Accuracy: 27.49%
Epoch [4/10], Training Loss: 2.028, Validation Accuracy: 28.12%
Epoch [5/10], Training Loss: 2.006, Validation Accuracy: 28.80%
Epoch [6/10], Training Loss: 1.984, Validation Accuracy: 29.36%
Epoch [7/10], Training Loss: 1.962, Validation Accuracy: 29.68%
Epoch [8/10], Training Loss: 1.939, Validation Accuracy: 30.34%
Epoch [9/10], Training Loss: 1.920, Validation Accuracy: 30.88%
Epoch [10/10], Training Loss: 1.902, Validation Accuracy: 31.75%
Epoch [1/10], Training Loss: 1.890, Validation Accuracy: 32.49%
Epoch [2/10], Training Loss: 1.877, Validation Accuracy: 32.30%
Epoch [3/10], Training Loss: 1.863, Validation Accuracy: 32.76%
Epoch [4/10], Training Loss: 1.845, Validation Accuracy: 33.83%
Epoch [5/10], Training Loss: 1.830, Validation Accuracy: 34.41%
Epoch [6/10], Training Loss: 1.815, Validation Accuracy: 34.35%
Epoch [7/10], Training Loss: 1.794, Validation Accuracy: 35.59%
Epoch [8/10], Training Loss: 1.779, Validation Accuracy: 36.21%
Epoch [9/10], Training Loss: 1.759, Validation Accuracy: 36.27%
Epoch [10/10], Training Loss: 1.745, Validation Accuracy: 37.12%
Epoch [1/10], Training Loss: 1.734, Validation Accuracy: 37.81%
Epoch [2/10], Training Loss: 1.715, Validation Accuracy: 38.31%
Epoch [3/10], Training Loss: 1.697, Validation Accuracy: 37.80%
Epoch [4/10], Training Loss: 1.679, Validation Accuracy: 38.92%
Epoch [5/10], Training Loss: 1.661, Validation Accuracy: 38.53%
Epoch [6/10], Training Loss: 1.648, Validation Accuracy: 38.86%
Epoch [7/10], Training Loss: 1.634, Validation Accuracy: 39.55%
Epoch [8/10], Training Loss: 1.623, Validation Accuracy: 39.75%
Epoch [9/10], Training Loss: 1.607, Validation Accuracy: 40.67%
Epoch [10/10], Training Loss: 1.594, Validation Accuracy: 40.43%
Epoch [1/10], Training Loss: 1.639, Validation Accuracy: 40.93%
Epoch [2/10], Training Loss: 1.618, Validation Accuracy: 41.43%
Epoch [3/10], Training Loss: 1.608, Validation Accuracy: 41.32%
Epoch [4/10], Training Loss: 1.596, Validation Accuracy: 41.16%
Epoch [5/10], Training Loss: 1.586, Validation Accuracy: 41.67%
Epoch [6/10], Training Loss: 1.580, Validation Accuracy: 42.35%
Epoch [7/10], Training Loss: 1.570, Validation Accuracy: 42.56%
Epoch [8/10], Training Loss: 1.561, Validation Accuracy: 42.47%
Epoch [9/10], Training Loss: 1.554, Validation Accuracy: 42.86%
Epoch [10/10], Training Loss: 1.543, Validation Accuracy: 42.42%
Epoch [1/10], Training Loss: 1.550, Validation Accuracy: 43.98%
Epoch [2/10], Training Loss: 1.539, Validation Accuracy: 43.37%
Epoch [3/10], Training Loss: 1.531, Validation Accuracy: 43.02%
Epoch [4/10], Training Loss: 1.519, Validation Accuracy: 44.81%
Epoch [5/10], Training Loss: 1.506, Validation Accuracy: 44.49%
Epoch [6/10], Training Loss: 1.501, Validation Accuracy: 44.85%
Epoch [7/10], Training Loss: 1.493, Validation Accuracy: 44.86%
Epoch [8/10], Training Loss: 1.486, Validation Accuracy: 44.82%
Epoch [9/10], Training Loss: 1.473, Validation Accuracy: 45.09%
Epoch [10/10], Training Loss: 1.469, Validation Accuracy: 45.14%
Epoch [1/10], Training Loss: 1.506, Validation Accuracy: 45.68%
Epoch [2/10], Training Loss: 1.497, Validation Accuracy: 46.31%
Epoch [3/10], Training Loss: 1.485, Validation Accuracy: 44.64%
Epoch [4/10], Training Loss: 1.481, Validation Accuracy: 45.50%
Epoch [5/10], Training Loss: 1.468, Validation Accuracy: 46.86%
Epoch [6/10], Training Loss: 1.461, Validation Accuracy: 46.66%
Epoch [7/10], Training Loss: 1.457, Validation Accuracy: 46.80%
Epoch [8/10], Training Loss: 1.442, Validation Accuracy: 46.50%
Epoch [9/10], Training Loss: 1.445, Validation Accuracy: 47.59%
Epoch [10/10], Training Loss: 1.429, Validation Accuracy: 47.65%
Epoch [1/10], Training Loss: 1.452, Validation Accuracy: 46.99%
Epoch [2/10], Training Loss: 1.434, Validation Accuracy: 47.80%
Epoch [3/10], Training Loss: 1.425, Validation Accuracy: 47.92%
Epoch [4/10], Training Loss: 1.417, Validation Accuracy: 48.09%
Epoch [5/10], Training Loss: 1.408, Validation Accuracy: 47.34%
Epoch [6/10], Training Loss: 1.395, Validation Accuracy: 47.69%
Epoch [7/10], Training Loss: 1.398, Validation Accuracy: 47.07%
Epoch [8/10], Training Loss: 1.382, Validation Accuracy: 48.60%
Epoch [9/10], Training Loss: 1.385, Validation Accuracy: 46.95%
Epoch [10/10], Training Loss: 1.379, Validation Accuracy: 48.71%
Epoch [1/10], Training Loss: 1.395, Validation Accuracy: 47.52%
Epoch [2/10], Training Loss: 1.381, Validation Accuracy: 48.75%
Epoch [3/10], Training Loss: 1.370, Validation Accuracy: 48.72%
Epoch [4/10], Training Loss: 1.358, Validation Accuracy: 50.37%
Epoch [5/10], Training Loss: 1.344, Validation Accuracy: 49.66%
Epoch [6/10], Training Loss: 1.342, Validation Accuracy: 50.00%
Epoch [7/10], Training Loss: 1.328, Validation Accuracy: 49.62%
Epoch [8/10], Training Loss: 1.330, Validation Accuracy: 49.95%
Epoch [9/10], Training Loss: 1.311, Validation Accuracy: 49.79%
Epoch [10/10], Training Loss: 1.313, Validation Accuracy: 49.17%
Epoch [1/10], Training Loss: 1.401, Validation Accuracy: 50.98%
Epoch [2/10], Training Loss: 1.384, Validation Accuracy: 51.37%
Epoch [3/10], Training Loss: 1.361, Validation Accuracy: 51.11%
Epoch [4/10], Training Loss: 1.352, Validation Accuracy: 51.07%
Epoch [5/10], Training Loss: 1.350, Validation Accuracy: 50.90%
Epoch [6/10], Training Loss: 1.338, Validation Accuracy: 51.97%
Epoch [7/10], Training Loss: 1.331, Validation Accuracy: 50.99%
Epoch [8/10], Training Loss: 1.322, Validation Accuracy: 51.29%
Epoch [9/10], Training Loss: 1.313, Validation Accuracy: 51.66%
Epoch [10/10], Training Loss: 1.319, Validation Accuracy: 49.10%
Epoch [1/10], Training Loss: 1.344, Validation Accuracy: 51.42%
Epoch [2/10], Training Loss: 1.325, Validation Accuracy: 52.03%
Epoch [3/10], Training Loss: 1.305, Validation Accuracy: 52.39%
Epoch [4/10], Training Loss: 1.295, Validation Accuracy: 52.85%
Epoch [5/10], Training Loss: 1.286, Validation Accuracy: 52.78%
Epoch [6/10], Training Loss: 1.284, Validation Accuracy: 52.30%
Epoch [7/10], Training Loss: 1.277, Validation Accuracy: 51.87%
Epoch [8/10], Training Loss: 1.267, Validation Accuracy: 50.89%
Epoch [9/10], Training Loss: 1.251, Validation Accuracy: 52.01%
Epoch [10/10], Training Loss: 1.249, Validation Accuracy: 52.60%
Epoch [1/10], Training Loss: 1.326, Validation Accuracy: 52.29%
Epoch [2/10], Training Loss: 1.317, Validation Accuracy: 52.48%
Epoch [3/10], Training Loss: 1.302, Validation Accuracy: 52.77%
Epoch [4/10], Training Loss: 1.289, Validation Accuracy: 52.62%
Epoch [5/10], Training Loss: 1.284, Validation Accuracy: 52.61%
Epoch [6/10], Training Loss: 1.271, Validation Accuracy: 52.93%
Epoch [7/10], Training Loss: 1.265, Validation Accuracy: 53.08%
Epoch [8/10], Training Loss: 1.257, Validation Accuracy: 53.41%
Epoch [9/10], Training Loss: 1.248, Validation Accuracy: 52.70%
Epoch [10/10], Training Loss: 1.238, Validation Accuracy: 53.77%
Epoch [1/10], Training Loss: 1.300, Validation Accuracy: 52.18%
Epoch [2/10], Training Loss: 1.274, Validation Accuracy: 53.35%
Epoch [3/10], Training Loss: 1.256, Validation Accuracy: 53.93%
Epoch [4/10], Training Loss: 1.247, Validation Accuracy: 52.25%
Epoch [5/10], Training Loss: 1.235, Validation Accuracy: 53.51%
Epoch [6/10], Training Loss: 1.226, Validation Accuracy: 53.24%
Epoch [7/10], Training Loss: 1.227, Validation Accuracy: 54.00%
Epoch [8/10], Training Loss: 1.207, Validation Accuracy: 54.61%
Epoch [9/10], Training Loss: 1.200, Validation Accuracy: 53.91%
Epoch [10/10], Training Loss: 1.193, Validation Accuracy: 54.12%
Epoch [1/10], Training Loss: 1.245, Validation Accuracy: 54.70%
Epoch [2/10], Training Loss: 1.223, Validation Accuracy: 54.88%
Epoch [3/10], Training Loss: 1.212, Validation Accuracy: 54.29%
Epoch [4/10], Training Loss: 1.194, Validation Accuracy: 55.00%
Epoch [5/10], Training Loss: 1.189, Validation Accuracy: 54.32%
Epoch [6/10], Training Loss: 1.173, Validation Accuracy: 53.69%
Epoch [7/10], Training Loss: 1.165, Validation Accuracy: 53.73%
Epoch [8/10], Training Loss: 1.155, Validation Accuracy: 54.49%
Epoch [9/10], Training Loss: 1.145, Validation Accuracy: 55.24%
Epoch [10/10], Training Loss: 1.142, Validation Accuracy: 54.82%
Epoch [1/10], Training Loss: 1.268, Validation Accuracy: 55.58%
Epoch [2/10], Training Loss: 1.238, Validation Accuracy: 55.25%
Epoch [3/10], Training Loss: 1.217, Validation Accuracy: 55.81%
Epoch [4/10], Training Loss: 1.204, Validation Accuracy: 55.75%
Epoch [5/10], Training Loss: 1.190, Validation Accuracy: 55.15%
Epoch [6/10], Training Loss: 1.183, Validation Accuracy: 56.25%
Epoch [7/10], Training Loss: 1.171, Validation Accuracy: 55.15%
Epoch [8/10], Training Loss: 1.162, Validation Accuracy: 55.55%
Epoch [9/10], Training Loss: 1.154, Validation Accuracy: 55.53%
Epoch [10/10], Training Loss: 1.146, Validation Accuracy: 56.14%
Epoch [1/10], Training Loss: 1.214, Validation Accuracy: 55.82%
Epoch [2/10], Training Loss: 1.192, Validation Accuracy: 56.40%
Epoch [3/10], Training Loss: 1.173, Validation Accuracy: 56.66%
Epoch [4/10], Training Loss: 1.151, Validation Accuracy: 55.91%
Epoch [5/10], Training Loss: 1.143, Validation Accuracy: 56.48%
Epoch [6/10], Training Loss: 1.139, Validation Accuracy: 56.37%
Epoch [7/10], Training Loss: 1.126, Validation Accuracy: 56.08%
Epoch [8/10], Training Loss: 1.111, Validation Accuracy: 56.79%
Epoch [9/10], Training Loss: 1.107, Validation Accuracy: 56.20%
Epoch [10/10], Training Loss: 1.091, Validation Accuracy: 56.60%
Epoch [1/10], Training Loss: 1.214, Validation Accuracy: 56.96%
Epoch [2/10], Training Loss: 1.191, Validation Accuracy: 56.07%
Epoch [3/10], Training Loss: 1.170, Validation Accuracy: 57.41%
Epoch [4/10], Training Loss: 1.152, Validation Accuracy: 56.71%
Epoch [5/10], Training Loss: 1.142, Validation Accuracy: 56.14%
Epoch [6/10], Training Loss: 1.135, Validation Accuracy: 57.20%
Epoch [7/10], Training Loss: 1.119, Validation Accuracy: 57.25%
Epoch [8/10], Training Loss: 1.108, Validation Accuracy: 57.28%
Epoch [9/10], Training Loss: 1.101, Validation Accuracy: 57.02%
Epoch [10/10], Training Loss: 1.086, Validation Accuracy: 57.43%
Epoch [1/10], Training Loss: 1.189, Validation Accuracy: 57.45%
Epoch [2/10], Training Loss: 1.158, Validation Accuracy: 58.15%
Epoch [3/10], Training Loss: 1.143, Validation Accuracy: 58.04%
Epoch [4/10], Training Loss: 1.120, Validation Accuracy: 57.35%
Epoch [5/10], Training Loss: 1.109, Validation Accuracy: 57.40%
Epoch [6/10], Training Loss: 1.096, Validation Accuracy: 57.20%
Epoch [7/10], Training Loss: 1.082, Validation Accuracy: 56.95%
Epoch [8/10], Training Loss: 1.069, Validation Accuracy: 56.99%
Epoch [9/10], Training Loss: 1.069, Validation Accuracy: 57.48%
Epoch [10/10], Training Loss: 1.064, Validation Accuracy: 57.74%
Epoch [1/10], Training Loss: 1.149, Validation Accuracy: 57.60%
Epoch [2/10], Training Loss: 1.109, Validation Accuracy: 57.95%
Epoch [3/10], Training Loss: 1.086, Validation Accuracy: 58.11%
Epoch [4/10], Training Loss: 1.082, Validation Accuracy: 58.06%
Epoch [5/10], Training Loss: 1.063, Validation Accuracy: 57.10%
Epoch [6/10], Training Loss: 1.052, Validation Accuracy: 57.33%
Epoch [7/10], Training Loss: 1.036, Validation Accuracy: 58.53%
Epoch [8/10], Training Loss: 1.021, Validation Accuracy: 58.43%
Epoch [9/10], Training Loss: 1.012, Validation Accuracy: 57.59%
Epoch [10/10], Training Loss: 1.001, Validation Accuracy: 57.31%
Epoch [1/10], Training Loss: 1.161, Validation Accuracy: 58.03%
Epoch [2/10], Training Loss: 1.128, Validation Accuracy: 58.59%
Epoch [3/10], Training Loss: 1.106, Validation Accuracy: 58.09%
Epoch [4/10], Training Loss: 1.083, Validation Accuracy: 58.70%
Epoch [5/10], Training Loss: 1.070, Validation Accuracy: 58.12%
Epoch [6/10], Training Loss: 1.064, Validation Accuracy: 58.97%
Epoch [7/10], Training Loss: 1.049, Validation Accuracy: 58.49%
Epoch [8/10], Training Loss: 1.035, Validation Accuracy: 58.37%
Epoch [9/10], Training Loss: 1.028, Validation Accuracy: 58.39%
Epoch [10/10], Training Loss: 1.016, Validation Accuracy: 58.40%
Epoch [1/10], Training Loss: 1.121, Validation Accuracy: 58.49%
Epoch [2/10], Training Loss: 1.085, Validation Accuracy: 58.50%
Epoch [3/10], Training Loss: 1.059, Validation Accuracy: 58.46%
Epoch [4/10], Training Loss: 1.053, Validation Accuracy: 58.79%
Epoch [5/10], Training Loss: 1.025, Validation Accuracy: 59.09%
Epoch [6/10], Training Loss: 1.018, Validation Accuracy: 58.97%
Epoch [7/10], Training Loss: 0.996, Validation Accuracy: 58.70%
Epoch [8/10], Training Loss: 0.987, Validation Accuracy: 58.42%
Epoch [9/10], Training Loss: 0.972, Validation Accuracy: 58.51%
Epoch [10/10], Training Loss: 0.963, Validation Accuracy: 58.35%
Epoch [1/10], Training Loss: 1.132, Validation Accuracy: 57.56%
Epoch [2/10], Training Loss: 1.095, Validation Accuracy: 59.33%
Epoch [3/10], Training Loss: 1.061, Validation Accuracy: 58.23%
Epoch [4/10], Training Loss: 1.038, Validation Accuracy: 58.99%
Epoch [5/10], Training Loss: 1.024, Validation Accuracy: 59.32%
Epoch [6/10], Training Loss: 1.019, Validation Accuracy: 59.17%
Epoch [7/10], Training Loss: 1.000, Validation Accuracy: 58.38%
Epoch [8/10], Training Loss: 0.992, Validation Accuracy: 59.06%
Epoch [9/10], Training Loss: 0.979, Validation Accuracy: 59.38%
Epoch [10/10], Training Loss: 0.968, Validation Accuracy: 59.12%
Epoch [1/10], Training Loss: 1.108, Validation Accuracy: 59.91%
Epoch [2/10], Training Loss: 1.058, Validation Accuracy: 59.51%
Epoch [3/10], Training Loss: 1.048, Validation Accuracy: 59.60%
Epoch [4/10], Training Loss: 1.010, Validation Accuracy: 59.51%
Epoch [5/10], Training Loss: 0.998, Validation Accuracy: 59.51%
Epoch [6/10], Training Loss: 0.990, Validation Accuracy: 59.49%
Epoch [7/10], Training Loss: 0.970, Validation Accuracy: 59.00%
Epoch [8/10], Training Loss: 0.951, Validation Accuracy: 59.56%
Epoch [9/10], Training Loss: 0.939, Validation Accuracy: 59.66%
Epoch [10/10], Training Loss: 0.933, Validation Accuracy: 58.87%
Epoch [1/10], Training Loss: 1.054, Validation Accuracy: 59.66%
Epoch [2/10], Training Loss: 1.020, Validation Accuracy: 59.64%
Epoch [3/10], Training Loss: 0.989, Validation Accuracy: 59.59%
Epoch [4/10], Training Loss: 0.967, Validation Accuracy: 60.24%
Epoch [5/10], Training Loss: 0.941, Validation Accuracy: 59.93%
Epoch [6/10], Training Loss: 0.930, Validation Accuracy: 59.69%
Epoch [7/10], Training Loss: 0.919, Validation Accuracy: 59.87%
Epoch [8/10], Training Loss: 0.900, Validation Accuracy: 58.85%
Epoch [9/10], Training Loss: 0.889, Validation Accuracy: 59.52%
Epoch [10/10], Training Loss: 0.866, Validation Accuracy: 59.45%
Epoch [1/10], Training Loss: 1.083, Validation Accuracy: 59.64%
Epoch [2/10], Training Loss: 1.044, Validation Accuracy: 59.64%
Epoch [3/10], Training Loss: 1.016, Validation Accuracy: 59.82%
Epoch [4/10], Training Loss: 0.996, Validation Accuracy: 59.09%
Epoch [5/10], Training Loss: 0.975, Validation Accuracy: 59.20%
Epoch [6/10], Training Loss: 0.960, Validation Accuracy: 60.20%
Epoch [7/10], Training Loss: 0.945, Validation Accuracy: 59.93%
Epoch [8/10], Training Loss: 0.926, Validation Accuracy: 59.52%
Epoch [9/10], Training Loss: 0.917, Validation Accuracy: 59.53%
Epoch [10/10], Training Loss: 0.901, Validation Accuracy: 59.75%
Epoch [1/10], Training Loss: 1.039, Validation Accuracy: 60.36%
Epoch [2/10], Training Loss: 0.989, Validation Accuracy: 60.13%
Epoch [3/10], Training Loss: 0.964, Validation Accuracy: 59.62%
Epoch [4/10], Training Loss: 0.941, Validation Accuracy: 60.02%
Epoch [5/10], Training Loss: 0.922, Validation Accuracy: 59.46%
Epoch [6/10], Training Loss: 0.908, Validation Accuracy: 59.33%
Epoch [7/10], Training Loss: 0.899, Validation Accuracy: 59.47%
Epoch [8/10], Training Loss: 0.880, Validation Accuracy: 60.00%
Epoch [9/10], Training Loss: 0.855, Validation Accuracy: 59.58%
Epoch [10/10], Training Loss: 0.843, Validation Accuracy: 59.94%
Epoch [1/10], Training Loss: 1.058, Validation Accuracy: 59.44%
Epoch [2/10], Training Loss: 1.006, Validation Accuracy: 60.68%
Epoch [3/10], Training Loss: 0.969, Validation Accuracy: 60.58%
Epoch [4/10], Training Loss: 0.944, Validation Accuracy: 60.16%
Epoch [5/10], Training Loss: 0.931, Validation Accuracy: 59.64%
Epoch [6/10], Training Loss: 0.914, Validation Accuracy: 60.08%
Epoch [7/10], Training Loss: 0.888, Validation Accuracy: 60.49%
Epoch [8/10], Training Loss: 0.879, Validation Accuracy: 60.26%
Epoch [9/10], Training Loss: 0.876, Validation Accuracy: 60.08%
Epoch [10/10], Training Loss: 0.847, Validation Accuracy: 60.45%
Epoch [1/10], Training Loss: 1.037, Validation Accuracy: 60.23%
Epoch [2/10], Training Loss: 0.982, Validation Accuracy: 60.01%
Epoch [3/10], Training Loss: 0.944, Validation Accuracy: 60.26%
Epoch [4/10], Training Loss: 0.928, Validation Accuracy: 60.41%
Epoch [5/10], Training Loss: 0.910, Validation Accuracy: 60.20%
Epoch [6/10], Training Loss: 0.888, Validation Accuracy: 59.89%
Epoch [7/10], Training Loss: 0.864, Validation Accuracy: 60.42%
Epoch [8/10], Training Loss: 0.855, Validation Accuracy: 59.98%
Epoch [9/10], Training Loss: 0.841, Validation Accuracy: 60.05%
Epoch [10/10], Training Loss: 0.824, Validation Accuracy: 59.50%
Epoch [1/10], Training Loss: 0.986, Validation Accuracy: 59.41%
Epoch [2/10], Training Loss: 0.938, Validation Accuracy: 60.51%
Epoch [3/10], Training Loss: 0.899, Validation Accuracy: 60.24%
Epoch [4/10], Training Loss: 0.875, Validation Accuracy: 59.89%
Epoch [5/10], Training Loss: 0.851, Validation Accuracy: 60.46%
Epoch [6/10], Training Loss: 0.834, Validation Accuracy: 60.45%
Epoch [7/10], Training Loss: 0.817, Validation Accuracy: 59.87%
Epoch [8/10], Training Loss: 0.794, Validation Accuracy: 59.05%
Epoch [9/10], Training Loss: 0.781, Validation Accuracy: 60.17%
Epoch [10/10], Training Loss: 0.767, Validation Accuracy: 60.05%
Epoch [1/10], Training Loss: 1.030, Validation Accuracy: 60.51%
Epoch [2/10], Training Loss: 0.966, Validation Accuracy: 60.02%
Epoch [3/10], Training Loss: 0.924, Validation Accuracy: 60.54%
Epoch [4/10], Training Loss: 0.907, Validation Accuracy: 59.31%
Epoch [5/10], Training Loss: 0.881, Validation Accuracy: 59.63%
Epoch [6/10], Training Loss: 0.859, Validation Accuracy: 59.81%
Epoch [7/10], Training Loss: 0.842, Validation Accuracy: 59.72%
Epoch [8/10], Training Loss: 0.822, Validation Accuracy: 60.25%
Epoch [9/10], Training Loss: 0.808, Validation Accuracy: 60.20%
Epoch [10/10], Training Loss: 0.793, Validation Accuracy: 60.43%
Epoch [1/10], Training Loss: 0.977, Validation Accuracy: 60.45%
Epoch [2/10], Training Loss: 0.913, Validation Accuracy: 61.05%
Epoch [3/10], Training Loss: 0.877, Validation Accuracy: 60.51%
Epoch [4/10], Training Loss: 0.846, Validation Accuracy: 60.68%
Epoch [5/10], Training Loss: 0.836, Validation Accuracy: 60.13%
Epoch [6/10], Training Loss: 0.799, Validation Accuracy: 60.87%
Epoch [7/10], Training Loss: 0.795, Validation Accuracy: 60.48%
Epoch [8/10], Training Loss: 0.768, Validation Accuracy: 59.92%
Epoch [9/10], Training Loss: 0.756, Validation Accuracy: 59.54%
Epoch [10/10], Training Loss: 0.741, Validation Accuracy: 60.07%
Epoch [1/10], Training Loss: 1.000, Validation Accuracy: 59.77%
Epoch [2/10], Training Loss: 0.930, Validation Accuracy: 60.16%
Epoch [3/10], Training Loss: 0.893, Validation Accuracy: 60.95%
Epoch [4/10], Training Loss: 0.864, Validation Accuracy: 61.28%
Epoch [5/10], Training Loss: 0.834, Validation Accuracy: 60.79%
Epoch [6/10], Training Loss: 0.808, Validation Accuracy: 60.95%
Epoch [7/10], Training Loss: 0.797, Validation Accuracy: 61.02%
Epoch [8/10], Training Loss: 0.778, Validation Accuracy: 60.99%
Epoch [9/10], Training Loss: 0.763, Validation Accuracy: 60.78%
Epoch [10/10], Training Loss: 0.740, Validation Accuracy: 60.62%
Epoch [1/10], Training Loss: 0.979, Validation Accuracy: 60.17%
Epoch [2/10], Training Loss: 0.907, Validation Accuracy: 60.89%
Epoch [3/10], Training Loss: 0.875, Validation Accuracy: 61.28%
Epoch [4/10], Training Loss: 0.843, Validation Accuracy: 60.99%
Epoch [5/10], Training Loss: 0.816, Validation Accuracy: 60.89%
Epoch [6/10], Training Loss: 0.795, Validation Accuracy: 61.01%
Epoch [7/10], Training Loss: 0.776, Validation Accuracy: 61.29%
Epoch [8/10], Training Loss: 0.752, Validation Accuracy: 60.45%
Epoch [9/10], Training Loss: 0.735, Validation Accuracy: 60.28%
Epoch [10/10], Training Loss: 0.718, Validation Accuracy: 59.78%
Epoch [1/10], Training Loss: 0.939, Validation Accuracy: 60.89%
Epoch [2/10], Training Loss: 0.875, Validation Accuracy: 61.15%
Epoch [3/10], Training Loss: 0.821, Validation Accuracy: 60.44%
Epoch [4/10], Training Loss: 0.788, Validation Accuracy: 61.15%
Epoch [5/10], Training Loss: 0.769, Validation Accuracy: 61.36%
Epoch [6/10], Training Loss: 0.738, Validation Accuracy: 60.91%
Epoch [7/10], Training Loss: 0.724, Validation Accuracy: 60.19%
Epoch [8/10], Training Loss: 0.697, Validation Accuracy: 60.69%
Epoch [9/10], Training Loss: 0.673, Validation Accuracy: 60.23%
Epoch [10/10], Training Loss: 0.661, Validation Accuracy: 60.16%
Epoch [1/10], Training Loss: 0.975, Validation Accuracy: 60.62%
Epoch [2/10], Training Loss: 0.899, Validation Accuracy: 60.32%
Epoch [3/10], Training Loss: 0.854, Validation Accuracy: 61.46%
Epoch [4/10], Training Loss: 0.820, Validation Accuracy: 60.97%
Epoch [5/10], Training Loss: 0.794, Validation Accuracy: 60.20%
Epoch [6/10], Training Loss: 0.768, Validation Accuracy: 60.92%
Epoch [7/10], Training Loss: 0.750, Validation Accuracy: 60.07%
Epoch [8/10], Training Loss: 0.725, Validation Accuracy: 60.51%
Epoch [9/10], Training Loss: 0.707, Validation Accuracy: 59.95%
Epoch [10/10], Training Loss: 0.696, Validation Accuracy: 60.05%
Epoch [1/10], Training Loss: 0.924, Validation Accuracy: 59.99%
Epoch [2/10], Training Loss: 0.862, Validation Accuracy: 60.82%
Epoch [3/10], Training Loss: 0.804, Validation Accuracy: 60.53%
Epoch [4/10], Training Loss: 0.771, Validation Accuracy: 60.60%
Epoch [5/10], Training Loss: 0.744, Validation Accuracy: 60.34%
Epoch [6/10], Training Loss: 0.717, Validation Accuracy: 60.87%
Epoch [7/10], Training Loss: 0.695, Validation Accuracy: 60.95%
Epoch [8/10], Training Loss: 0.669, Validation Accuracy: 60.52%
Epoch [9/10], Training Loss: 0.654, Validation Accuracy: 60.11%
Epoch [10/10], Training Loss: 0.634, Validation Accuracy: 60.18%
Epoch [1/10], Training Loss: 0.938, Validation Accuracy: 60.37%
Epoch [2/10], Training Loss: 0.851, Validation Accuracy: 60.88%
Epoch [3/10], Training Loss: 0.805, Validation Accuracy: 60.58%
Epoch [4/10], Training Loss: 0.770, Validation Accuracy: 61.33%
Epoch [5/10], Training Loss: 0.739, Validation Accuracy: 60.35%
Epoch [6/10], Training Loss: 0.715, Validation Accuracy: 60.46%
Epoch [7/10], Training Loss: 0.688, Validation Accuracy: 61.35%
Epoch [8/10], Training Loss: 0.665, Validation Accuracy: 61.09%
Epoch [9/10], Training Loss: 0.652, Validation Accuracy: 61.31%
Epoch [10/10], Training Loss: 0.629, Validation Accuracy: 61.21%
Epoch [1/10], Training Loss: 0.927, Validation Accuracy: 61.01%
Epoch [2/10], Training Loss: 0.837, Validation Accuracy: 61.52%
Epoch [3/10], Training Loss: 0.783, Validation Accuracy: 61.39%
Epoch [4/10], Training Loss: 0.755, Validation Accuracy: 60.52%
Epoch [5/10], Training Loss: 0.723, Validation Accuracy: 61.21%
Epoch [6/10], Training Loss: 0.703, Validation Accuracy: 61.30%
Epoch [7/10], Training Loss: 0.673, Validation Accuracy: 61.16%
Epoch [8/10], Training Loss: 0.650, Validation Accuracy: 60.22%
Epoch [9/10], Training Loss: 0.633, Validation Accuracy: 60.57%
Epoch [10/10], Training Loss: 0.613, Validation Accuracy: 60.86%
Epoch [1/10], Training Loss: 0.897, Validation Accuracy: 60.27%
Epoch [2/10], Training Loss: 0.799, Validation Accuracy: 60.16%
Epoch [3/10], Training Loss: 0.742, Validation Accuracy: 61.33%
Epoch [4/10], Training Loss: 0.696, Validation Accuracy: 60.50%
Epoch [5/10], Training Loss: 0.686, Validation Accuracy: 60.48%
Epoch [6/10], Training Loss: 0.647, Validation Accuracy: 60.93%
Epoch [7/10], Training Loss: 0.620, Validation Accuracy: 60.92%
Epoch [8/10], Training Loss: 0.601, Validation Accuracy: 60.26%
Epoch [9/10], Training Loss: 0.592, Validation Accuracy: 59.47%
Epoch [10/10], Training Loss: 0.557, Validation Accuracy: 59.89%
Epoch [1/10], Training Loss: 0.940, Validation Accuracy: 60.66%
Epoch [2/10], Training Loss: 0.835, Validation Accuracy: 60.76%
Epoch [3/10], Training Loss: 0.784, Validation Accuracy: 61.05%
Epoch [4/10], Training Loss: 0.732, Validation Accuracy: 61.23%
Epoch [5/10], Training Loss: 0.701, Validation Accuracy: 61.02%
Epoch [6/10], Training Loss: 0.677, Validation Accuracy: 60.75%
Epoch [7/10], Training Loss: 0.647, Validation Accuracy: 60.71%
Epoch [8/10], Training Loss: 0.627, Validation Accuracy: 61.12%
Epoch [9/10], Training Loss: 0.607, Validation Accuracy: 60.59%
Epoch [10/10], Training Loss: 0.590, Validation Accuracy: 60.99%
Epoch [1/10], Training Loss: 0.871, Validation Accuracy: 61.10%
Epoch [2/10], Training Loss: 0.777, Validation Accuracy: 61.34%
Epoch [3/10], Training Loss: 0.721, Validation Accuracy: 61.01%
Epoch [4/10], Training Loss: 0.686, Validation Accuracy: 60.64%
Epoch [5/10], Training Loss: 0.652, Validation Accuracy: 60.99%
Epoch [6/10], Training Loss: 0.625, Validation Accuracy: 60.59%
Epoch [7/10], Training Loss: 0.605, Validation Accuracy: 60.79%
Epoch [8/10], Training Loss: 0.578, Validation Accuracy: 60.60%
Epoch [9/10], Training Loss: 0.557, Validation Accuracy: 60.78%
Epoch [10/10], Training Loss: 0.541, Validation Accuracy: 60.75%
Epoch [1/10], Training Loss: 0.902, Validation Accuracy: 60.82%
Epoch [2/10], Training Loss: 0.783, Validation Accuracy: 61.00%
Epoch [3/10], Training Loss: 0.736, Validation Accuracy: 60.98%
Epoch [4/10], Training Loss: 0.691, Validation Accuracy: 61.54%
Epoch [5/10], Training Loss: 0.649, Validation Accuracy: 61.91%
Epoch [6/10], Training Loss: 0.623, Validation Accuracy: 61.32%
Epoch [7/10], Training Loss: 0.603, Validation Accuracy: 61.24%
Epoch [8/10], Training Loss: 0.573, Validation Accuracy: 61.27%
Epoch [9/10], Training Loss: 0.548, Validation Accuracy: 61.04%
Epoch [10/10], Training Loss: 0.530, Validation Accuracy: 60.88%
Epoch [1/10], Training Loss: 0.894, Validation Accuracy: 61.79%
Epoch [2/10], Training Loss: 0.769, Validation Accuracy: 61.70%
Epoch [3/10], Training Loss: 0.707, Validation Accuracy: 61.22%
Epoch [4/10], Training Loss: 0.665, Validation Accuracy: 60.32%
Epoch [5/10], Training Loss: 0.634, Validation Accuracy: 61.22%
Epoch [6/10], Training Loss: 0.603, Validation Accuracy: 61.16%
Epoch [7/10], Training Loss: 0.576, Validation Accuracy: 61.02%
Epoch [8/10], Training Loss: 0.555, Validation Accuracy: 61.05%
Epoch [9/10], Training Loss: 0.530, Validation Accuracy: 60.82%
Epoch [10/10], Training Loss: 0.518, Validation Accuracy: 60.19%
Epoch [1/10], Training Loss: 0.856, Validation Accuracy: 60.07%
Epoch [2/10], Training Loss: 0.739, Validation Accuracy: 60.81%
Epoch [3/10], Training Loss: 0.668, Validation Accuracy: 60.41%
Epoch [4/10], Training Loss: 0.625, Validation Accuracy: 61.07%
Epoch [5/10], Training Loss: 0.588, Validation Accuracy: 60.78%
Epoch [6/10], Training Loss: 0.555, Validation Accuracy: 60.13%
Epoch [7/10], Training Loss: 0.524, Validation Accuracy: 60.87%
Epoch [8/10], Training Loss: 0.505, Validation Accuracy: 61.03%
Epoch [9/10], Training Loss: 0.485, Validation Accuracy: 60.60%
Epoch [10/10], Training Loss: 0.463, Validation Accuracy: 59.93%
Epoch [1/10], Training Loss: 0.914, Validation Accuracy: 60.92%
Epoch [2/10], Training Loss: 0.782, Validation Accuracy: 61.35%
Epoch [3/10], Training Loss: 0.701, Validation Accuracy: 61.00%
Epoch [4/10], Training Loss: 0.647, Validation Accuracy: 60.66%
Epoch [5/10], Training Loss: 0.613, Validation Accuracy: 61.26%
Epoch [6/10], Training Loss: 0.597, Validation Accuracy: 60.98%
Epoch [7/10], Training Loss: 0.553, Validation Accuracy: 61.01%
Epoch [8/10], Training Loss: 0.529, Validation Accuracy: 60.61%
Epoch [9/10], Training Loss: 0.504, Validation Accuracy: 60.75%
Epoch [10/10], Training Loss: 0.486, Validation Accuracy: 60.77%
Epoch [1/10], Training Loss: 0.836, Validation Accuracy: 60.59%
Epoch [2/10], Training Loss: 0.719, Validation Accuracy: 60.36%
Epoch [3/10], Training Loss: 0.645, Validation Accuracy: 60.84%
Epoch [4/10], Training Loss: 0.608, Validation Accuracy: 60.31%
Epoch [5/10], Training Loss: 0.573, Validation Accuracy: 61.11%
Epoch [6/10], Training Loss: 0.538, Validation Accuracy: 60.81%
Epoch [7/10], Training Loss: 0.508, Validation Accuracy: 60.49%
Epoch [8/10], Training Loss: 0.490, Validation Accuracy: 60.93%
Epoch [9/10], Training Loss: 0.459, Validation Accuracy: 60.15%
Epoch [10/10], Training Loss: 0.435, Validation Accuracy: 60.76%
Epoch [1/10], Training Loss: 0.842, Validation Accuracy: 60.50%
Epoch [2/10], Training Loss: 0.705, Validation Accuracy: 60.87%
Epoch [3/10], Training Loss: 0.643, Validation Accuracy: 61.42%
Epoch [4/10], Training Loss: 0.599, Validation Accuracy: 59.88%
Epoch [5/10], Training Loss: 0.565, Validation Accuracy: 60.83%
Epoch [6/10], Training Loss: 0.530, Validation Accuracy: 61.29%
Epoch [7/10], Training Loss: 0.501, Validation Accuracy: 61.10%
Epoch [8/10], Training Loss: 0.478, Validation Accuracy: 60.61%
Epoch [9/10], Training Loss: 0.459, Validation Accuracy: 60.17%
Epoch [10/10], Training Loss: 0.427, Validation Accuracy: 61.18%
Epoch [1/10], Training Loss: 0.840, Validation Accuracy: 61.18%
Epoch [2/10], Training Loss: 0.698, Validation Accuracy: 61.01%
Epoch [3/10], Training Loss: 0.629, Validation Accuracy: 61.28%
Epoch [4/10], Training Loss: 0.582, Validation Accuracy: 61.35%
Epoch [5/10], Training Loss: 0.538, Validation Accuracy: 60.65%
Epoch [6/10], Training Loss: 0.509, Validation Accuracy: 60.96%
Epoch [7/10], Training Loss: 0.484, Validation Accuracy: 60.31%
Epoch [8/10], Training Loss: 0.457, Validation Accuracy: 59.92%
Epoch [9/10], Training Loss: 0.439, Validation Accuracy: 60.74%
Epoch [10/10], Training Loss: 0.411, Validation Accuracy: 60.59%
Epoch [1/10], Training Loss: 0.828, Validation Accuracy: 59.25%
Epoch [2/10], Training Loss: 0.673, Validation Accuracy: 60.12%
Epoch [3/10], Training Loss: 0.597, Validation Accuracy: 60.68%
Epoch [4/10], Training Loss: 0.550, Validation Accuracy: 60.16%
Epoch [5/10], Training Loss: 0.503, Validation Accuracy: 60.91%
Epoch [6/10], Training Loss: 0.472, Validation Accuracy: 60.43%
Epoch [7/10], Training Loss: 0.444, Validation Accuracy: 60.29%
Epoch [8/10], Training Loss: 0.419, Validation Accuracy: 59.90%
Epoch [9/10], Training Loss: 0.392, Validation Accuracy: 60.46%
Epoch [10/10], Training Loss: 0.374, Validation Accuracy: 60.34%
"""

# Scan the training log once and collect every validation-accuracy figure
# as a float. The pattern only matches values written with a decimal point
# (e.g. "56.25%"), which is how each "Validation Accuracy" entry in the log
# above is formatted.
accuracies = [
    float(match.group(1))
    for match in re.finditer(r'Validation Accuracy: (\d+\.\d+)%', log)
]

# Report the parsed values followed by how many were found.
print("Accuracies:", accuracies)
print("Size of array:", len(accuracies))


Accuracies: [10.17, 10.24, 10.16, 10.14, 10.12, 10.24, 11.09, 12.03, 12.78, 14.08, 15.69, 17.64, 19.24, 19.88, 20.73, 22.25, 21.6, 22.45, 23.39, 24.34, 27.71, 28.13, 27.49, 28.12, 28.8, 29.36, 29.68, 30.34, 30.88, 31.75, 32.49, 32.3, 32.76, 33.83, 34.41, 34.35, 35.59, 36.21, 36.27, 37.12, 37.81, 38.31, 37.8, 38.92, 38.53, 38.86, 39.55, 39.75, 40.67, 40.43, 40.93, 41.43, 41.32, 41.16, 41.67, 42.35, 42.56, 42.47, 42.86, 42.42, 43.98, 43.37, 43.02, 44.81, 44.49, 44.85, 44.86, 44.82, 45.09, 45.14, 45.68, 46.31, 44.64, 45.5, 46.86, 46.66, 46.8, 46.5, 47.59, 47.65, 46.99, 47.8, 47.92, 48.09, 47.34, 47.69, 47.07, 48.6, 46.95, 48.71, 47.52, 48.75, 48.72, 50.37, 49.66, 50.0, 49.62, 49.95, 49.79, 49.17, 50.98, 51.37, 51.11, 51.07, 50.9, 51.97, 50.99, 51.29, 51.66, 49.1, 51.42, 52.03, 52.39, 52.85, 52.78, 52.3, 51.87, 50.89, 52.01, 52.6, 52.29, 52.48, 52.77, 52.62, 52.61, 52.93, 53.08, 53.41, 52.7, 53.77, 52.18, 53.35, 53.93, 52.25, 53.51, 53.24, 54.0, 54.61, 53.91, 54.12, 54.7, 54.88, 54.29, 55.