In [None]:
import pickle as pkl
import os

import torch
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.neighbors import NearestNeighbors

from data_heterogeneity.config import *

# Analysis

In [None]:

# Load the pickled results of one federated run for the figure below.
# NOTE(review): this is an absolute, machine-specific path; consider deriving it
# from a configurable data directory. pickle can execute arbitrary code on load,
# so only open result files produced by trusted experiment runs.
with open("/home/markhaoxiang/Projects/fl/florl/experiments/data_heterogeneity/iid/61279.pkl", "rb") as results_file:
    # The context manager closes the handle even if unpickling raises.
    federated_results = pkl.load(results_file)
# Analysis functions


def pendulum_dimensionality_reduction(states: torch.Tensor):
    """Project a batch of states onto two plotting coordinates.

    Args:
        states (torch.Tensor): 2-D tensor with at least three columns; only
            columns 0, 1 and 2 are read.

    Returns:
        tuple: ``(x, y)`` where ``x = sin(atan2(states[:, 0], states[:, 1]))``
        and ``y = states[:, 2] / 8.0``.
    """
    first, second, third = states[:, 0], states[:, 1], states[:, 2]
    # sin(atan2(a, b)) equals a / sqrt(a^2 + b^2), i.e. column 0 normalised by the
    # length of the (column 0, column 1) vector.
    # NOTE(review): if states are Pendulum observations laid out as
    # [cos(theta), sin(theta), angular velocity], this is cos(theta) rather than
    # sin(theta) - confirm the column layout against the axis labels.
    angle = torch.atan2(first, second)
    x = torch.sin(angle)
    # Fixed rescaling by 8.0 (presumably the maximum magnitude of column 2 - TODO confirm).
    y = third / 8.0
    return x, y


def kld(p: np.ndarray, q: np.ndarray, k: int = 5):
    """ Estimates D(p||q) from samples using k-nearest-neighbor distances.

    Implements
    Wang et al. Divergence Estimation for Multidimensional Densities Via Nearest-Neighbor Distances

        D(p||q) ~= (d / N) * sum_i log2(nu_k(i) / rho_k(i)) + log2(M / (N - 1))

    where rho_k(i) is the distance from p[i] to its k-th nearest neighbor among the
    *other* samples of p, and nu_k(i) is the distance from p[i] to its k-th nearest
    neighbor among the samples of q.

    In bits. Distance is Euclidean.

    Args:
        p (np.ndarray): Samples from the true distribution, shape (N, d).
        q (np.ndarray): Samples from the encoding distribution, shape (M, d).
        k (int): Which nearest neighbor to use. Requires N > k and M >= k.

    Returns:
        float: The estimated divergence in bits.

    Raises:
        ValueError: If there are too few samples for the requested k.
    """
    p, q = np.asarray(p), np.asarray(q)
    N, M = len(p), len(q)
    if N <= k:
        raise ValueError(f"kld needs more than k={k} samples in p, got {N}")
    if M < k:
        raise ValueError(f"kld needs at least k={k} samples in q, got {M}")
    dimension = p.shape[1]

    # Fit Nearest Neighbors
    nn_p = NearestNeighbors(n_neighbors=k).fit(p)
    nn_q = NearestNeighbors(n_neighbors=k).fit(q)
    # Distance to the k-th nearest neighbor. Calling kneighbors() without a query
    # returns the neighbors of the fitted points with each point excluded from its
    # own neighbor list. Querying with `p` instead would count every point as its
    # own nearest neighbor (distance 0), so the last column would be the (k-1)-th
    # neighbor and the estimate would be biased upwards.
    distances_e, _ = nn_p.kneighbors()
    distances_v, _ = nn_q.kneighbors(p)
    # Floor the distances so duplicated samples cannot produce log(0) or a division by zero.
    distances_e = np.maximum(distances_e[:, -1], 0.001)
    distances_v = np.maximum(distances_v[:, -1], 0.001)
    # Summation term, scaled by d / N
    result = np.log2(distances_v / distances_e).sum() * (dimension / N)
    # Final logarithmic term
    result += np.log2(M / (N - 1))
    return result

def jsd(batch_1: torch.Tensor, batch_2: torch.Tensor, k: int = 5):
    """ Symmetric divergence estimate between the distributions that produced two sample batches.

    The value returned is the mean of the two directed estimates,
    ``(kld(batch_1, batch_2) + kld(batch_2, batch_1)) / 2``, in bits.
    NOTE: this is a symmetrised KL divergence; it is not computed against the
    mixture distribution as in the textbook Jensen-Shannon definition.

    Args:
        batch_1 (torch.Tensor): Samples from the first distribution.
        batch_2 (torch.Tensor): Samples from the second distribution.
        k (int): Neighbor rank forwarded to ``kld``.
    """
    samples_a = np.array(batch_1)
    samples_b = np.array(batch_2)
    forward = kld(samples_a, samples_b, k)
    backward = kld(samples_b, samples_a, k)
    return (forward + backward) / 2


# Plotting
# Grid layout: one column per sampled round, one row per view of the data.
N_COL = 3
N_ROW = 3

sns.set_theme()
fig, axs = plt.subplots(N_ROW, N_COL, sharey='row',  gridspec_kw={'hspace': 0.2, 'wspace': 0.1})
fig.set_size_inches(N_COL * 3, N_ROW * 3)

fig.suptitle("Dataset Heterogeneity (Pendulum, I.I.D Env)", weight="bold", fontsize=23)
colors = sns.color_palette("husl", NUM_CLIENTS)
# One marker per client index. The base shapes are cycled so that indexing by
# client index cannot run past the end of the list when NUM_CLIENTS exceeds the
# number of distinct shapes.
base_markers = ["o", "s", "P", "^", "X"]
markers = [
    base_markers[position % len(base_markers)]
    for position in range(max(NUM_CLIENTS, len(base_markers)))
]

# Text annotations for each row, placed along the left edge of the figure.
row_annotations = (
    (0.77, "All, Replay Buffer"),
    (0.5, "Latest Episode, Replay Buffer"),
    (0.23, "All, Divergence"),
)
for vertical_position, row_label in row_annotations:
    fig.text(0.0, vertical_position, row_label, ha='center', va='center', rotation='vertical', fontsize=12)

# Map each client id reported in the first recorded round to a stable 0-based index.
# Assumes every entry of ids_[r][1]["all"] is a pair whose second element is the
# client id - TODO confirm against the code that records the "id" metric.
ids_ = federated_results.metrics_distributed_fit["id"]
id_to_index = {entry[1]: position for position, entry in enumerate(ids_[0][1]["all"])}


# Precompute divergence maps
# One NUM_CLIENTS x NUM_CLIENTS matrix of pairwise jsd() values per plotted column.
divergence_matrices = [None for _ in range(N_COL)]
for i in range(N_COL):
    # Rounds spread evenly from the first to the last recorded round.
    round_number = min(TOTAL_ROUNDS - 1, int(TOTAL_ROUNDS * (i / (N_COL - 1))))
    round_ids = ids_[round_number][1]["all"]
    # Each entry is a pickled per-client buffer (only unpickle trusted result files).
    transition_states_clients = [pkl.loads(x) for x in federated_results.metrics_distributed_fit['rb'][round_number][1]]
    n = int(federated_results.metrics_distributed_fit["rb_size"][round_number][1]['avg'])
    # Per-client work that does not depend on the pairing is done once, up front.
    truncated_states = [states[:n] for states in transition_states_clients]
    client_indices = [id_to_index[round_ids[j][1]] for j in range(len(truncated_states))]
    divergence_matrix = np.zeros((NUM_CLIENTS, NUM_CLIENTS))
    for j in range(len(truncated_states)):
        for k in range(j + 1, len(truncated_states)):
            # jsd() is symmetric in its two arguments, so each unordered pair is
            # evaluated once and written to both mirrored cells. The diagonal stays 0.
            pair_divergence = jsd(truncated_states[j], truncated_states[k])
            divergence_matrix[client_indices[j], client_indices[k]] = pair_divergence
            divergence_matrix[client_indices[k], client_indices[j]] = pair_divergence
    divergence_matrices[i] = divergence_matrix
# Global extrema across all columns (usable as shared heatmap colour limits).
maximum_divergence = max([x.max() for x in divergence_matrices])
minimum_divergence = min([x.min() for x in divergence_matrices])

# Fill one column of the figure per sampled round:
#   row 0 - scatter of each client's replay buffer (first n entries),
#   row 1 - scatter of each client's newest data,
#   row 2 - heatmap of the precomputed pairwise divergences.
for i in range(N_COL):
    # Rounds spread evenly from the first to the last recorded round.
    round_number = min(TOTAL_ROUNDS - 1, int(TOTAL_ROUNDS * (i / (N_COL - 1))))
    fit_metrics = federated_results.metrics_distributed_fit
    round_ids = ids_[round_number][1]["all"]
    # Each entry is a pickled per-client buffer (only unpickle trusted result files).
    transition_states_clients = [pkl.loads(x) for x in fit_metrics['rb'][round_number][1]]
    new_transition_states_clients = [pkl.loads(x) for x in fit_metrics['rb_new'][round_number][1]]
    n = int(fit_metrics["rb_size"][round_number][1]['avg'])

    axs[0][i].set_title(f"Round {round_number}")

    scatter_panels = (
        # Total: the full buffer, truncated to the first n entries.
        (axs[0][i], [states[:n] for states in transition_states_clients]),
        # New: the newest data of each client, untruncated.
        (axs[1][i], new_transition_states_clients),
    )
    for ax, client_batches in scatter_panels:
        client_indices = [id_to_index[round_ids[j][1]] for j in range(len(client_batches))]
        # Draw clients in ascending client index in both rows so that colours,
        # markers and legend entries appear in the same order everywhere.
        for j in sorted(range(len(client_batches)), key=client_indices.__getitem__):
            index = client_indices[j]
            x, y = pendulum_dimensionality_reduction(client_batches[j])
            # Wrap the marker lookup so more clients than marker shapes cannot raise IndexError.
            ax.scatter(x, y, color=colors[index], s=5, marker=markers[index % len(markers)], label=f"Client {index}")
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)

    # Similarity Matrix
    # Each heatmap uses its own colour scale; pass vmin=minimum_divergence and
    # vmax=maximum_divergence to sns.heatmap for a scale shared across columns.
    ax = axs[2][i]
    sns.heatmap(divergence_matrices[i], ax=ax, cbar=True, annot=True)

# Keep tick labels only on the outer edge of the grid.
for ax in axs.flat:
    ax.label_outer()
# Axis labels for the two scatter rows (the heatmap row is left unlabelled).
for i in range(2):
    axs[i][0].set_ylabel("CCW Torque (Normalised)")
for j in range(N_COL):
    axs[0][j].set_xlabel('Sin(θ)')
    axs[1][j].set_xlabel('Sin(θ)')

# Build the shared legend from the top scatter row: that row is drawn in ascending
# client index, so the legend entries come out in client order. (axs[-2, -1] is the
# middle row of the 3-row grid, which is drawn in per-round arrival order.)
handles, labels = axs[0, -1].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=NUM_CLIENTS, title="Clients", fontsize='small', title_fontsize='medium', markerscale=1)

fig.savefig("data_heterogeneity/plot_2.pdf", format="pdf", bbox_inches="tight")

In [None]:
def _mean_pairwise_jsd_per_round(directory):
    """Mean pairwise client divergence per round for every run stored in a directory.

    Every file in ``directory`` is unpickled as one federated run (only point this
    at trusted result files, since unpickling can execute code). For each round the
    per-client buffers are truncated to the round's average buffer size and jsd()
    is averaged over all ordered client pairs.

    Args:
        directory (str): Folder containing one pickled result object per run.

    Returns:
        np.ndarray: Array of shape (number of runs, TOTAL_ROUNDS).
    """
    per_run = []
    # Sorted so the row order of the result does not depend on filesystem order.
    for fn in sorted(os.listdir(directory)):
        with open(os.path.join(directory, fn), "rb") as results_file:
            run_results = pkl.load(results_file)
        fit_metrics = run_results.metrics_distributed_fit
        mean_jsd_all = []
        for round_number in range(TOTAL_ROUNDS):
            n = int(fit_metrics["rb_size"][round_number][1]['avg'])
            client_batches = [pkl.loads(x)[:n] for x in fit_metrics['rb'][round_number][1]]
            # The mean over all pairs does not depend on client identity or order,
            # so no id lookup is needed here (and none is taken from another run).
            total_jsd = 0
            for j in range(len(client_batches)):
                for k in range(j + 1, len(client_batches)):
                    # jsd() is symmetric: evaluate each unordered pair once and
                    # count it for both orderings.
                    total_jsd += 2 * jsd(client_batches[j], client_batches[k])
            mean_jsd_all.append(total_jsd / (NUM_CLIENTS * (NUM_CLIENTS - 1)))
        per_run.append(mean_jsd_all)
    return np.array(per_run)


# Both settings are computed in one pass so that a fresh "Restart & Run All"
# defines everything the comparison plot needs, without editing a flag and
# re-running this cell.
fixed_results = _mean_pairwise_jsd_per_round("data_heterogeneity/fixed")
fixed_results_mean = fixed_results.mean(axis=0)
fixed_results_std = fixed_results.std(axis=0)

iid_results = _mean_pairwise_jsd_per_round("data_heterogeneity/iid")
iid_results_mean = iid_results.mean(axis=0)
iid_results_std = iid_results.std(axis=0)

In [None]:
# Empirical baseline: the average jsd() estimate between two independent batches
# drawn from the same 2-D uniform distribution, i.e. the value the estimator
# reports for truly identically distributed data of this size.
BASELINE_SEED = 0
BASELINE_TRIALS = 100
# Seeded generator so the baseline is identical on every run of the notebook.
baseline_rng = np.random.default_rng(BASELINE_SEED)
baseline_samples = TOTAL_ROUNDS * FRAMES_PER_ROUND

baseline_estimates = []
for trial in range(BASELINE_TRIALS):
    A = torch.tensor(baseline_rng.random((baseline_samples, 2)))
    B = torch.tensor(baseline_rng.random((baseline_samples, 2)))
    baseline_estimates.append(jsd(A, B))
# Kept separate from the list above so `lower_bound` only ever holds the final scalar.
lower_bound = sum(baseline_estimates) / len(baseline_estimates)

In [None]:
# Mean pairwise divergence per round for both settings, with a +/- 1.96 standard
# deviation band around each curve and the same-distribution baseline as a
# dashed horizontal line.
fig, ax = plt.subplots()

x = range(TOTAL_ROUNDS)

# (mean, std, colour, legend label, hatch pattern, extra line style)
curve_specs = (
    (fixed_results_mean, fixed_results_std, "blue", "Fixed Reset (Heterogenous)", "/", {"linestyle": "dashdot"}),
    (iid_results_mean, iid_results_std, "orange", "I.I.D Environment", "\\", {}),
)
for curve_mean, curve_std, curve_colour, curve_label, band_hatch, line_style in curve_specs:
    band_half_width = curve_std * 1.96
    ax.plot(x, curve_mean, color=curve_colour, label=curve_label, **line_style)
    ax.fill_between(
        x=x,
        y1=curve_mean - band_half_width,
        y2=curve_mean + band_half_width,
        alpha=0.4,
        color="white",
        facecolor=curve_colour,
        hatch=band_hatch,
    )

ax.set_title("Reinforcement Learning Dataset Heterogeneity (Pendulum)", weight="bold", fontsize=14)
ax.set_xlabel("Round")
ax.set_ylabel("Dataset Divergence (Bits)")
ax.legend()

# Baseline reference line and its annotation.
ax.hlines(y=lower_bound, xmin=0, xmax=TOTAL_ROUNDS, color="black", linestyles="--")
ax.text(40, 0, f"I.I.D Data {round(lower_bound, 2)}")

fig.savefig("data_heterogeneity/plot.pdf", format="pdf", bbox_inches="tight")