In [1]:
import asyncio
from fl_quic_transport_flower import FLQuicServer
import tensorflow as tf
from flwr_datasets import FederatedDataset
from fl_quic_transport_flower import fl_client_quic

async def run_client():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

    fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
    partition = fds.load_partition(partition_id=0, split="train")
    partition.set_format("numpy")
    partition = partition.train_test_split(test_size=0.2, seed=42)
    x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]

    await asyncio.sleep(1)
    await fl_client_quic("localhost", 4433, model, x_train, y_train)

await run_client()


2025-04-25 19:56:33.990046: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm
  super().__init__(**kwargs)
Generating train split: 100%|██████████| 50000/50000 [00:00<00:00, 104880.10 examples/s]
Generating test split: 100%|██████████| 10000/10000 [00:00<00:00, 108811.54 examples/s]


ConnectionError: 

In [None]:
# Install the required packages if you haven't already. Use %pip (not !pip)
# so they are installed into this kernel's environment:
# %pip install aioquic flwr flwr-datasets tensorflow

# -------------------------
# SERVER SIDE (run this first; ideally in its own cell)
# -------------------------
import asyncio
from fl_quic_transport_flower import FLQuicServer

server = FLQuicServer()
loop = asyncio.get_event_loop()
# Keep a reference to the task. The event loop only holds weak references to
# tasks, so an unreferenced background task can be garbage-collected before it
# finishes. In Jupyter the kernel's loop is already running, so the server
# runs in the background once this cell returns control to the loop.
server_task = loop.create_task(server.run())
# Interrupting the kernel does not cancel a background task; stop it explicitly.
print("[Server] QUIC Federated Server running... (call server_task.cancel() to stop)")

# -------------------------
# CLIENT SIDE (run this after the server; ideally in a separate cell)
# -------------------------
import asyncio
import tensorflow as tf
from flwr_datasets import FederatedDataset
from fl_quic_transport_flower import fl_client_quic

# Build model: MobileNetV2 trained from scratch (weights=None) for 10 classes.
model = tf.keras.applications.MobileNetV2(input_shape=(32, 32, 3), classes=10, weights=None)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Load partition 0 of 3 of the CIFAR-10 train split and keep a seeded 80% for training.
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
partition = fds.load_partition(partition_id=0, split="train")
partition.set_format("numpy")
partition = partition.train_test_split(test_size=0.2, seed=42)
# Scale pixels to [0, 1] (assumes 0-255 pixel values); float32 avoids a float64 copy.
x_train = partition["train"]["img"].astype("float32") / 255.0
y_train = partition["train"]["label"]

# Run QUIC client.
client_coro = fl_client_quic("localhost", 4433, model, x_train, y_train)
try:
    asyncio.get_running_loop()
except RuntimeError:
    # No event loop is running (plain Python script): asyncio.run is safe here.
    asyncio.run(client_coro)
else:
    # Jupyter/IPython: the kernel's event loop is already running, and
    # asyncio.run() would raise "cannot be called from a running event loop".
    # Schedule the client on that loop instead, next to the server task, and
    # keep a reference to it. Run `await client_task` in a later cell to get
    # its result or surface any exception.
    client_task = asyncio.create_task(client_coro)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Example dummy data (replace with your real measured values)
rounds = np.array([1, 2, 3, 4, 5])
grpc_times = np.array([10.5, 9.8, 10.1, 9.6, 10.3])  # Replace with your gRPC round durations
quic_times = np.array([6.3, 5.9, 6.1, 5.7, 6.2])    # Replace with your QUIC round durations

# Create the plot with the explicit Axes API.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(rounds, grpc_times, marker='o', linestyle='--', label='gRPC over TCP')
ax.plot(rounds, quic_times, marker='s', linestyle='-', label='QUIC Optimized')

# Rounds are discrete. Tick only at the measured rounds; the default locator
# would otherwise label fractional rounds such as 1.5.
ax.set_xticks(rounds)
# Start the y-axis at zero so the visual gap between the lines is proportional
# to the actual difference in round durations.
ax.set_ylim(bottom=0)

# Labels and Title
ax.set_xlabel('Training Round')
ax.set_ylabel('Round Duration (seconds)')
ax.set_title('Federated Learning Communication Round Durations')
ax.legend()
ax.grid(True)

plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Example dummy data (replace with your actual measured network sizes)
methods = ['gRPC (TCP)', 'QUIC']
avg_bytes_sent = [520000, 310000]  # Replace with average bytes sent per client per round

# Create the bar plot with the explicit Axes API.
fig, ax = plt.subplots(figsize=(6, 5))
bars = ax.bar(methods, avg_bytes_sent, color=['gray', 'blue'])

# Annotate each bar in KB (1 KB = 1000 bytes here). bar_label offsets the text
# by a fixed number of points above the bar, so placement works at any data
# scale. A hard-coded +10000-byte offset only suits values near 500 KB.
ax.bar_label(bars, labels=[f'{value / 1000:.1f} KB' for value in avg_bytes_sent], padding=3)
# Leave head-room above the tallest bar so its label is not clipped.
ax.margins(y=0.1)

# Labels and Title
ax.set_ylabel('Average Data Sent per Round (bytes)')
ax.set_title('Network Overhead Comparison: gRPC vs QUIC')
ax.grid(axis='y')
ax.set_axisbelow(True)  # draw grid lines behind the bars

plt.show()


In [None]:
# Replace with your measured values
grpc_avg_time = 10.2  # seconds
quic_avg_time = 6.1   # seconds

if grpc_avg_time <= 0 or quic_avg_time <= 0:
    raise ValueError("Average round times must be positive.")

# Share of the gRPC round time saved by QUIC. This is a *time reduction*, not a
# speedup: going from 10.2 s to 6.1 s saves ~40% of the time, which is a ~1.67x speedup.
time_reduction_percent = ((grpc_avg_time - quic_avg_time) / grpc_avg_time) * 100
speedup_factor = grpc_avg_time / quic_avg_time
# Kept for backward compatibility with code that reads the old name. Note that
# it holds the time reduction, not a speedup ratio.
speedup_percent = time_reduction_percent

print(f"Round-time reduction: {time_reduction_percent:.2f}%")
print(f"Communication Speedup: {speedup_factor:.2f}x")
