In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Check and set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [2]:
# Federated learning hyperparameters
num_epochs, learning_rate = 5, 0.001
batch_sizes, communication_rounds = [64, 128], 100
overlap_percentage = 25 # Dmeta hyperparameter

In [4]:
# Define a deep learning model (CNN)
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # Modify the architecture based on your requirements
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(512, 10)  # Assuming 10 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        return x

In [5]:
# Dataset preparation
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Assuming you have the FEMNIST dataset downloaded and split for each client
# Replace 'path_to_client_data' with the actual path to the client data
client_dataset = torchvision.datasets.FashionMNIST(root='path_to_client_data', train=True, download=True, transform=transform)

# Split the client dataset into train and validation sets
train_size = int(0.8 * len(client_dataset))
val_size = len(client_dataset) - train_size
client_train_set, client_val_set = random_split(client_dataset, [train_size, val_size])

# DataLoader for training and validation
train_loaders = [DataLoader(client_train_set, batch_size=batch_size, shuffle=True, num_workers=4) for batch_size in batch_sizes]
val_loader = DataLoader(client_val_set, batch_size=batch_sizes[0], shuffle=False, num_workers=4)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to path_to_client_data\FashionMNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting path_to_client_data\FashionMNIST\raw\train-images-idx3-ubyte.gz to path_to_client_data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to path_to_client_data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting path_to_client_data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to path_to_client_data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to path_to_client_data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting path_to_client_data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to path_to_client_data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to path_to_client_data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting path_to_client_data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to path_to_client_data\FashionMNIST\raw



In [10]:
# Function to train the model for a given federated learning algorithm
def train_federated_learning_with_dmeta(model, train_loader, dmeta_loader, algorithm_name):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    performance_metrics = {'accuracy': [], 'loss': []}

    for epoch in range(num_epochs):
        running_loss = 0.0

        # Train on D
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs} - {algorithm_name} - D'):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Train on Dmeta
        for images, labels in tqdm(dmeta_loader, desc=f'Epoch {epoch + 1}/{num_epochs} - {algorithm_name} - Dmeta'):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Print average loss at the end of each epoch
        average_loss = running_loss / (len(train_loader) + len(dmeta_loader))
        print(f'Epoch [{epoch + 1}/{num_epochs}] - {algorithm_name}, Loss: {average_loss:.4f}')

        # Evaluate on validation set
        accuracy = evaluate_model(model, val_loader)
        print(f'Epoch [{epoch + 1}/{num_epochs}] - {algorithm_name}, Validation Accuracy: {accuracy * 100:.2f}%')

        performance_metrics['accuracy'].append(accuracy)
        performance_metrics['loss'].append(average_loss)

    return performance_metrics

# Function to evaluate the model on the validation set
def evaluate_model(model, val_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    return accuracy

In [7]:
# def train_federated_learningAdj(model, train_loader, algorithm_name, times):
#     model.train()
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
#     # Initialize the convolutional layer outside the loop
#     in_channels = train_loader.dataset[0][0].shape[0]  # Number of input channels
#     out_channels = max(1, in_channels - times) 
#     conv_reduce_channels = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    
#     performance_metrics = {'accuracy': [], 'loss': []}

#     for epoch in range(num_epochs):
#         running_loss = 0.0
#         for images, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs} - {algorithm_name}'):
#             images, labels = images.to(device), labels.to(device)
            
#             # Apply the convolutional layer to input images
#             data = conv_reduce_channels(images)
            
#             optimizer.zero_grad()
#             outputs = model(data)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item()

#         # Print average loss at the end of each epoch
#         average_loss = running_loss / len(train_loader)
#         print(f'Epoch [{epoch + 1}/{num_epochs}] - {algorithm_name}, Loss: {average_loss:.4f}')

#         # Evaluate on validation set
#         accuracy = evaluate_model(model, val_loader)
#         print(f'Epoch [{epoch + 1}/{num_epochs}] - {algorithm_name}, Validation Accuracy: {accuracy * 100:.2f}%')

#         performance_metrics['accuracy'].append(accuracy)
#         performance_metrics['loss'].append(average_loss)

#     return performance_metrics

# Function to train the model for a given federated learning algorithm
def train_federated_learning_with_dmetaAdj(model, train_loader, dmeta_loader, algorithm_name, times):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
     # Initialize the convolutional layer outside the loop
    in_channels = train_loader.dataset[0][0].shape[0]  # Number of input channels
    out_channels = max(1, in_channels - times) 
    conv_reduce_channels = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    performance_metrics = {'accuracy': [], 'loss': []}

    for epoch in range(num_epochs):
        running_loss = 0.0

        # Train on D
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs} - {algorithm_name} - D'):
            images, labels = images.to(device), labels.to(device)

            data = conv_reduce_channels(images)
            
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Train on Dmeta
        for images, labels in tqdm(dmeta_loader, desc=f'Epoch {epoch + 1}/{num_epochs} - {algorithm_name} - Dmeta'):
            images, labels = images.to(device), labels.to(device)

            data = conv_reduce_channels(images)
            
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Print average loss at the end of each epoch
        average_loss = running_loss / (len(train_loader) + len(dmeta_loader))
        print(f'Epoch [{epoch + 1}/{num_epochs}] - {algorithm_name}, Loss: {average_loss:.4f}')

        # Evaluate on validation set
        accuracy = evaluate_model(model, val_loader)
        print(f'Epoch [{epoch + 1}/{num_epochs}] - {algorithm_name}, Validation Accuracy: {accuracy * 100:.2f}%')

        performance_metrics['accuracy'].append(accuracy)
        performance_metrics['loss'].append(average_loss)

    return performance_metrics

# Function to evaluate the model on the validation set
def evaluate_model(model, val_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    return accuracy



In [17]:
# Number of writers in Dmeta
num_writers_dmeta = int(overlap_percentage / 100 * len(client_dataset))

# Select writers from D and the auxiliary dataset
writers_d = client_dataset.targets.unique().tolist()
writers_auxiliary = list(range(100))  # Assuming 100 writers in the auxiliary dataset

# Select a certain proportion of writers from D and the rest from the auxiliary dataset
selected_writers_d = writers_d[:len(writers_d) - num_writers_dmeta]
selected_writers_auxiliary = writers_auxiliary[:num_writers_dmeta]

# Combine selected writers to form Dmeta writers
writers_dmeta = selected_writers_d + selected_writers_auxiliary

# Sample 1% examples to form Dmeta
num_samples_dmeta = int(len(client_dataset))
indices_dmeta = torch.randperm(len(client_dataset))[:num_samples_dmeta]

# Create Dmeta by selecting examples from D and auxiliary dataset
data_dmeta = [client_dataset[i] for i in indices_dmeta]
dmeta_loader = DataLoader(data_dmeta, batch_size=batch_sizes[0], shuffle=True, num_workers=4)

In [22]:
fedmeta_model = CNNModel().to(device)
# Train & Evaluate FedMeta w/ UGA
fedmeta_uga_performance_dmeta = train_federated_learning_with_dmeta(fedmeta_model, train_loaders[0], dmeta_loader, 'FedMeta w/ UGA with Dmeta')

Epoch 1/5 - FedMeta w/ UGA with Dmeta - D:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 1/5 - FedMeta w/ UGA with Dmeta - D: 100%|██████████| 750/750 [00:36<00:00, 20.74it/s]
Epoch 1/5 - FedMeta w/ UGA with Dmeta - Dmeta: 100%|██████████| 938/938 [03:00<00:00,  5.21it/s]


Epoch [1/5] - FedMeta w/ UGA with Dmeta, Loss: 0.3344
Epoch [1/5] - FedMeta w/ UGA with Dmeta, Validation Accuracy: 92.25%


Epoch 2/5 - FedMeta w/ UGA with Dmeta - D: 100%|██████████| 750/750 [01:49<00:00,  6.88it/s]
Epoch 2/5 - FedMeta w/ UGA with Dmeta - Dmeta: 100%|██████████| 938/938 [01:37<00:00,  9.63it/s] 


Epoch [2/5] - FedMeta w/ UGA with Dmeta, Loss: 0.1959
Epoch [2/5] - FedMeta w/ UGA with Dmeta, Validation Accuracy: 94.05%


Epoch 3/5 - FedMeta w/ UGA with Dmeta - D: 100%|██████████| 750/750 [00:47<00:00, 15.84it/s]
Epoch 3/5 - FedMeta w/ UGA with Dmeta - Dmeta: 100%|██████████| 938/938 [01:10<00:00, 13.29it/s]

Epoch [3/5] - FedMeta w/ UGA with Dmeta, Loss: 0.1352





Epoch [3/5] - FedMeta w/ UGA with Dmeta, Validation Accuracy: 96.08%


Epoch 4/5 - FedMeta w/ UGA with Dmeta - D: 100%|██████████| 750/750 [00:47<00:00, 15.88it/s]
Epoch 4/5 - FedMeta w/ UGA with Dmeta - Dmeta: 100%|██████████| 938/938 [01:11<00:00, 13.05it/s]


Epoch [4/5] - FedMeta w/ UGA with Dmeta, Loss: 0.0881
Epoch [4/5] - FedMeta w/ UGA with Dmeta, Validation Accuracy: 96.99%


Epoch 5/5 - FedMeta w/ UGA with Dmeta - D: 100%|██████████| 750/750 [00:48<00:00, 15.58it/s]
Epoch 5/5 - FedMeta w/ UGA with Dmeta - Dmeta: 100%|██████████| 938/938 [02:18<00:00,  6.76it/s]


Epoch [5/5] - FedMeta w/ UGA with Dmeta, Loss: 0.0566
Epoch [5/5] - FedMeta w/ UGA with Dmeta, Validation Accuracy: 98.41%


In [24]:
fedmeta_model1 = CNNModel().to(device)
fedmeta_uga_performance_dmeta1 = train_federated_learning_with_dmetaAdj(fedmeta_model1, train_loaders[0], dmeta_loader, 'FedMeta w/ UGA with Dmeta(50 Removed Nodes)', 50)

Epoch 1/5 - FedMeta w/ UGA with Dmeta(50 Removed Nodes) - Dmeta:   0%|          | 0/938 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x0000029ADF054700>
Traceback (most recent call last):
  File "c:\Users\vinji\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\utils\data\dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "c:\Users\vinji\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\utils\data\dataloader.py", line 1424, in _shutdown_workers
    if self._persistent_workers or self._workers_status[worker_id]:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'
Epoch 1/5 - FedMeta w/ UGA with Dmeta(50 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [02:16<00:00,  6.85it/s]


Epoch [1/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Loss: 0.2144
Epoch [1/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Validation Accuracy: 80.03%


Epoch 2/5 - FedMeta w/ UGA with Dmeta(50 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [03:45<00:00,  4.16it/s]


Epoch [2/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Loss: 0.1378
Epoch [2/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Validation Accuracy: 82.66%


Epoch 3/5 - FedMeta w/ UGA with Dmeta(50 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [02:12<00:00,  7.07it/s]


Epoch [3/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Loss: 0.1138
Epoch [3/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Validation Accuracy: 85.12%


Epoch 4/5 - FedMeta w/ UGA with Dmeta(50 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [04:26<00:00,  3.52it/s]


Epoch [4/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Loss: 0.0957
Epoch [4/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Validation Accuracy: 88.22%


Epoch 5/5 - FedMeta w/ UGA with Dmeta(50 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [03:12<00:00,  4.86it/s]


Epoch [5/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Loss: 0.0787
Epoch [5/5] - FedMeta w/ UGA with Dmeta(50 Removed Nodes), Validation Accuracy: 88.63%


In [25]:
fedmeta_model2 = CNNModel().to(device)
fedmeta_uga_performance_dmeta2 = train_federated_learning_with_dmetaAdj(fedmeta_model2, train_loaders[0], dmeta_loader, 'FedMeta w/ UGA with Dmeta(100 Removed Nodes)', 100)

Epoch 1/5 - FedMeta w/ UGA with Dmeta(100 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [01:29<00:00, 10.51it/s]


Epoch [1/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Loss: 0.2872
Epoch [1/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Validation Accuracy: 9.12%


Epoch 2/5 - FedMeta w/ UGA with Dmeta(100 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [02:44<00:00,  5.70it/s]


Epoch [2/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Loss: 0.1812
Epoch [2/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Validation Accuracy: 19.40%


Epoch 3/5 - FedMeta w/ UGA with Dmeta(100 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [02:07<00:00,  7.37it/s]


Epoch [3/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Loss: 0.1529
Epoch [3/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Validation Accuracy: 23.42%


Epoch 4/5 - FedMeta w/ UGA with Dmeta(100 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [01:13<00:00, 12.78it/s]


Epoch [4/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Loss: 0.1360
Epoch [4/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Validation Accuracy: 19.38%


Epoch 5/5 - FedMeta w/ UGA with Dmeta(100 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [01:58<00:00,  7.93it/s]


Epoch [5/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Loss: 0.1234
Epoch [5/5] - FedMeta w/ UGA with Dmeta(100 Removed Nodes), Validation Accuracy: 27.45%


In [26]:
fedmeta_model3 = CNNModel().to(device)
fedmeta_uga_performance_dmeta3 = train_federated_learning_with_dmetaAdj(fedmeta_model3, train_loaders[0], dmeta_loader, 'FedMeta w/ UGA with Dmeta(200 Removed Nodes)', 200)

Epoch 1/5 - FedMeta w/ UGA with Dmeta(200 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [01:35<00:00,  9.83it/s]


Epoch [1/5] - FedMeta w/ UGA with Dmeta(200 Removed Nodes), Loss: 0.2433
Epoch [1/5] - FedMeta w/ UGA with Dmeta(200 Removed Nodes), Validation Accuracy: 77.87%


Epoch 2/5 - FedMeta w/ UGA with Dmeta(200 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [03:55<00:00,  3.99it/s]


Epoch [2/5] - FedMeta w/ UGA with Dmeta(200 Removed Nodes), Loss: 0.1540
Epoch [2/5] - FedMeta w/ UGA with Dmeta(200 Removed Nodes), Validation Accuracy: 78.96%


Epoch 3/5 - FedMeta w/ UGA with Dmeta(200 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [03:14<00:00,  4.83it/s]


Epoch [3/5] - FedMeta w/ UGA with Dmeta(200 Removed Nodes), Loss: 0.1294
Epoch [3/5] - FedMeta w/ UGA with Dmeta(200 Removed Nodes), Validation Accuracy: 78.88%


Epoch 4/5 - FedMeta w/ UGA with Dmeta(200 Removed Nodes) - Dmeta: 100%|██████████| 938/938 [01:55<00:00,  8.12it/s]


Epoch [4/5] - FedMeta w/ UGA with Dmeta(200 Removed Nodes), Loss: 0.1097
Epoch [4/5] - FedMeta w/ UGA with Dmeta(200 Removed Nodes), Validation Accuracy: 78.13%


Epoch 5/5 - FedMeta w/ UGA with Dmeta(200 Removed Nodes) - Dmeta:  98%|█████████▊| 923/938 [01:18<00:01, 11.72it/s]


KeyboardInterrupt: 

In [None]:
fedmeta_model4 = CNNModel().to(device)
fedmeta_uga_performance_dmeta4 = train_federated_learning_with_dmetaAdj(fedmeta_model4, train_loaders[0], dmeta_loader, 'FedMeta w/ UGA with Dmeta(400 Removed Nodes)', 400)


In [None]:
fedmeta_model5 = CNNModel().to(device)
fedmeta_uga_performance_dmeta5 = train_federated_learning_with_dmetaAdj(fedmeta_model5, train_loaders[0], dmeta_loader, 'FedMeta w/ UGA with Dmeta(800 Removed Nodes)', 800)

In [None]:
# Create comparison plots for FEMNIST
plt.figure(figsize=(12, 8))

# Test accuracy comparison
plt.subplot(2, 1, 1)
plt.title("Accuracy Comparison")
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta['accuracy'], label='FedMeta w/ UGA')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta1['accuracy'], label='FedMeta w/ UGA(50 Removed)')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta2['accuracy'], label='FedMeta w/ UGA(100 Removed)')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta3['accuracy'], label='FedMeta w/ UGA(200 Removed)')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta4['accuracy'], label='FedMeta w/ UGA(400 Removed)')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta5['accuracy'], label='FedMeta w/ UGA(800 Removed)')


# Loss comparison
plt.subplot(2, 1, 2)
plt.title("Loss Comparison")
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta['loss'], label='FedMeta w/ UGA')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta1['loss'], label='edMeta w/ UGA(50 Removed)')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta2['loss'], label='edMeta w/ UGA(100 Removed)')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta3['loss'], label='FedMeta w/ UGA(200 Removed)')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta4['loss'], label='FedMeta w/ UGA(400 Removed)')
sns.lineplot(x=range(0, communication_rounds, 20), y=fedmeta_uga_performance_dmeta5['loss'], label='FedMeta w/ UGA(800 Removed)')

plt.show()