In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.cm as cm
import os 
import seaborn as sns
import glob
import natsort

In [None]:
import seaborn as sns
import copy

def load_data(save_path):
    """Load every result CSV matching a glob pattern into one plot-ready frame.

    Parameters
    ----------
    save_path : str
        Glob pattern passed to ``glob.glob``.

    Returns
    -------
    pandas.DataFrame
        The frame built by ``get_concat_glob`` from the naturally sorted
        matches, with columns ``a1`` and ``a3`` multiplied by 100 and a
        ``Distribution`` label column derived from ``distribution``.

    Raises
    ------
    FileNotFoundError
        If the pattern matches no file at all.
    ValueError
        If files matched but no rows were loaded from them.
    """
    matches = glob.glob(save_path)
    if not matches:
        # Fail here with a readable message instead of a TypeError further
        # down when an empty result is indexed by column name.
        raise FileNotFoundError(f"No files match pattern {save_path!r}")

    path_list = natsort.natsorted(matches)
    data_temp = get_concat_glob(path_list)
    if len(data_temp) == 0:
        raise ValueError(
            f"{len(path_list)} file(s) match {save_path!r} but no rows were loaded "
            "(all paths were filtered out or the files are empty)"
        )

    # Rescale the two metric columns by 100 (they are plotted as percentages).
    data_temp["a1"] = data_temp["a1"] * 100
    data_temp["a3"] = data_temp["a3"] * 100
    # Values of `distribution` other than 0 and 0.3 map to NaN.
    data_temp['Distribution'] = data_temp['distribution'].map({0: "IID partition",
                                                               0.3: "Non-IID partition (α=0.3)"})
    return data_temp

In [None]:
def get_concat_glob(paths):
    """Read the given CSV files and stack them into a single DataFrame.

    Paths containing any of the skip tokens below are ignored. Every loaded
    frame gets two extra columns:

    * ``Mode``    - copy of the file's ``mode`` column.
    * ``Scaling`` - label derived from the path:
        - path contains ``Output_GU``, ``__GU`` and ``LR``: the raw text
          between ``__GU`` and the next ``/``;
        - path contains ``__GU``: ``"GU"`` + that text parsed as a float;
        - path contains ``__GGU``: ``"GGU"`` + the text after ``__GGU``
          parsed as a float;
        - otherwise ``"vanilla"``.

    Parameters
    ----------
    paths : iterable of str
        CSV file paths. The parsing above splits on ``/``.

    Returns
    -------
    pandas.DataFrame
        All loaded frames concatenated in input order (original row indices
        are kept). An empty DataFrame is returned when nothing is loaded, so
        the return type no longer depends on the input.
    """
    # Substrings that mark runs excluded from these plots. They are matched
    # against the whole path, directory names included.
    skip_tokens = ("LN", "WD", "FCs", "G1", "Mean", "gs", "hs", "Mask", "LR001")

    frames = []
    for path in paths:
        if any(token in path for token in skip_tokens):
            continue

        data = pd.read_csv(path)
        data["Mode"] = data["mode"]

        if "Output_GU" in path and "__GU" in path and "LR" in path:
            scaling = path.split("__GU")[1].split("/")[0]
        elif "__GU" in path:
            # The float round-trip normalises the text, e.g. "111.10" -> "111.1".
            scaling = "GU" + str(float(path.split("__GU")[1].split("/")[0]))
        elif "__GGU" in path:
            scaling = "GGU" + str(float(path.split("__GGU")[1].split("/")[0]))
        else:
            scaling = "vanilla"
        data["Scaling"] = scaling

        frames.append(data)

    if not frames:
        return pd.DataFrame()
    # Concatenate once; calling concat inside the loop re-copies every row
    # accumulated so far on each iteration.
    return pd.concat(frames)

## BarPlot

### CIFAR100, All

In [None]:
# CIFAR100 / FedAvg: bar plot of column "a3" (shown as test accuracy) for every
# "Scaling" value present in the data, one panel per label distribution.
data_temp = load_data("../Output_GU_C100/CIFAR100-*-N100-*/s[1-9].csv")

# (panel title, keep the IID rows?) for the left and right axes.
panels = [
    ("CIFAR100 (IID Label Distribution)", True),
    ("CIFAR100 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so calling this after plt.subplots leaves the figure unstyled on
# a fresh kernel.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(20,7), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([10, 54])
        ax.set_ylabel("Test Accuracy (%)")
        ax.set_title(title)
        ax.set_xticklabels(ax.get_xticklabels(), rotation = 40)
    fig.tight_layout()
    plt.pause(0.1)

### GU111

In [None]:
# CIFAR100 / FedAvg, GU111 sweep: column "a3" (shown as test accuracy) for the
# vanilla baseline versus each GU111.* setting, one panel per label distribution.
data_temp = load_data("../Output_GU_C100/CIFAR100-*-N100-*/s[1-9].csv")

# Raw "Scaling" value -> x tick label; dict order is the bar order.
replace_icon =   {"GU111.01": "1",
                  "GU111.02": "2",
                  "GU111.04": "4",
                  "GU111.06": "6",
                  "GU111.08": "8",
                  "GU111.1": "10"
                 }
sk = ["vanilla"] + list(replace_icon.keys())

# (panel title, keep the IID rows?) for the left and right axes.
panels = [
    ("CIFAR100 (IID Label Distribution)", True),
    ("CIFAR100 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(15,6), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        # .loc[sk] keeps only the listed settings, baseline first.
        temp = temp.set_index('Scaling').loc[sk].reset_index()
        temp = temp.replace({"Scaling": replace_icon})
        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([46, 54])
        ax.set_ylabel("Test Accuracy (%)")
        # NOTE(review): "unfrozing" is probably a typo for "unfreezing" - confirm before changing the figure text.
        ax.set_xlabel("Number of local iters at each unfrozing stage")
        ax.set_title(title)
    fig.tight_layout()
    plt.pause(0.1)

### GU 112

In [None]:
# CIFAR100 / FedAvg, GU112 sweep: column "a3" (shown as test accuracy) for the
# vanilla baseline versus each GU112.* setting, one panel per label distribution.
data_temp = load_data("../Output_GU_C100/CIFAR100-*-N100-*/s[1-9].csv")

# Raw "Scaling" value -> x tick label; dict order is the bar order.
replace_icon =   {"GU112.01": "1",
                  "GU112.02": "2",
                  "GU112.04": "4",
                  "GU112.06": "6",
                  "GU112.08": "8",
                  "GU112.1": "10"
                 }
sk = ["vanilla"] + list(replace_icon.keys())

# (panel title, keep the IID rows?) for the left and right axes.
panels = [
    ("CIFAR100 (IID Label Distribution)", True),
    ("CIFAR100 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(15,6), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        # .loc[sk] keeps only the listed settings, baseline first.
        temp = temp.set_index('Scaling').loc[sk].reset_index()
        temp = temp.replace({"Scaling": replace_icon})
        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([25, 50])
        ax.set_ylabel("Test Accuracy (%)")
        # NOTE(review): "unfrozing" is probably a typo for "unfreezing" - confirm before changing the figure text.
        ax.set_xlabel("Number of local iters at each unfrozing stage")
        ax.set_title(title)
    fig.tight_layout()
    plt.pause(0.1)

### GGU 111

In [None]:
# CIFAR100 / FedAvg, GGU111 sweep: column "a3" (shown as test accuracy) for the
# vanilla baseline versus each GGU111.* setting, one panel per label distribution.
data_temp = load_data("../Output_GU_C100/CIFAR100-*-N100-*/s[1-9].csv")

# Raw "Scaling" value -> x tick label; dict order is the bar order.
replace_icon =   {"GGU111.97": "0.97",
                  "GGU111.95": "0.95",
                  "GGU111.93": "0.93",
                  "GGU111.91": "0.91",
                 }
sk = ["vanilla"] + list(replace_icon.keys())

# (panel title, keep the IID rows?). Previously both panels carried the same
# unterminated "IID" title even though the right panel shows the other rows.
panels = [
    ("CIFAR100 (IID Label Distribution)", True),
    ("CIFAR100 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(15,6), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        # .loc[sk] keeps only the listed settings, baseline first.
        temp = temp.set_index('Scaling').loc[sk].reset_index()
        temp = temp.replace({"Scaling": replace_icon})
        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([46, 54])
        ax.set_ylabel("Test Accuracy (%)")
        # NOTE(review): this label is shared with the GU cells, but the ticks here
        # are 0.91-0.97 - confirm it describes the GGU sweep ("unfrozing" also looks like a typo).
        ax.set_xlabel("Number of local iters at each unfrozing stage")
        ax.set_title(title)
    fig.tight_layout()
    plt.pause(0.1)

### GGU 121

In [None]:
# CIFAR100 / FedAvg, GGU121 sweep: column "a3" (shown as test accuracy) for the
# vanilla baseline versus each GGU121.* setting, one panel per label distribution.
data_temp = load_data("../Output_GU_C100/CIFAR100-*-N100-*/s[1-9].csv")

# Raw "Scaling" value -> x tick label; dict order is the bar order.
replace_icon =   {"GGU121.03": "0.03",
                  "GGU121.05": "0.05",
                  "GGU121.07": "0.07",
                  "GGU121.09": "0.09",
                 }
sk = ["vanilla"] + list(replace_icon.keys())

# (panel title, keep the IID rows?). Previously both panels carried the same
# unterminated "IID" title even though the right panel shows the other rows.
panels = [
    ("CIFAR100 (IID Label Distribution)", True),
    ("CIFAR100 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(15,6), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        # .loc[sk] keeps only the listed settings, baseline first.
        temp = temp.set_index('Scaling').loc[sk].reset_index()
        temp = temp.replace({"Scaling": replace_icon})
        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([46, 54])
        ax.set_ylabel("Test Accuracy (%)")
        # NOTE(review): this label is shared with the GU cells, but the ticks here
        # are 0.03-0.09 - confirm it describes the GGU sweep ("unfrozing" also looks like a typo).
        ax.set_xlabel("Number of local iters at each unfrozing stage")
        ax.set_title(title)
    fig.tight_layout()
    plt.pause(0.1)

# CIFAR10

### All BarPlot

In [None]:
# CIFAR10 / FedAvg: bar plot of column "a3" (shown as test accuracy) for every
# "Scaling" value present in the data, one panel per label distribution.
data_temp = load_data("../Output_GU_C10/CIFAR10-*-N100-*/s[1-9].csv")

# (panel title, keep the IID rows?). The old title said "CIFAR100", lacked its
# closing parenthesis, and labelled both panels as IID.
panels = [
    ("CIFAR10 (IID Label Distribution)", True),
    ("CIFAR10 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(20,7), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([79, 84])
        ax.set_ylabel("Test Accuracy (%)")
        ax.set_title(title)
        ax.set_xticklabels(ax.get_xticklabels(), rotation = 40)
    fig.tight_layout()
    plt.pause(0.1)

### GU 111

In [None]:
# CIFAR10 / FedAvg, GU111 sweep: column "a3" (shown as test accuracy) for the
# vanilla baseline versus each GU111.* setting, one panel per label distribution.
data_temp = load_data("../Output_GU_C10/CIFAR10-*-N100-*/s[1-9].csv")

# Raw "Scaling" value -> x tick label; dict order is the bar order.
replace_icon =   {"GU111.01": "1",
                  "GU111.02": "2",
                  "GU111.04": "4",
                  "GU111.06": "6",
                  "GU111.08": "8",
                  "GU111.1": "10"
                 }
sk = ["vanilla"] + list(replace_icon.keys())

# (panel title, keep the IID rows?). The old title said "CIFAR100", lacked its
# closing parenthesis, and labelled both panels as IID.
panels = [
    ("CIFAR10 (IID Label Distribution)", True),
    ("CIFAR10 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(15,6), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        # .loc[sk] keeps only the listed settings, baseline first.
        temp = temp.set_index('Scaling').loc[sk].reset_index()
        temp = temp.replace({"Scaling": replace_icon})

        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([79, 84])
        ax.set_ylabel("Test Accuracy (%)")
        # NOTE(review): "unfrozing" is probably a typo for "unfreezing" - confirm before changing the figure text.
        ax.set_xlabel("Number of local iters at each unfrozing stage")
        ax.set_title(title)
        ax.set_xticklabels(ax.get_xticklabels(), rotation = 40)
    fig.tight_layout()
    plt.pause(0.1)

### GU112

In [None]:
# CIFAR10 / FedAvg, GU112 sweep: column "a3" (shown as test accuracy) for the
# vanilla baseline versus each GU112.* setting, one panel per label distribution.
data_temp = load_data("../Output_GU_C10/CIFAR10-*-N100-*/s[1-9].csv")

# Raw "Scaling" value -> x tick label; dict order is the bar order.
replace_icon =   {"GU112.01": "1",
                  "GU112.02": "2",
                  "GU112.04": "4",
                  "GU112.06": "6",
                  "GU112.08": "8",
                  "GU112.1": "10"
                 }
sk = ["vanilla"] + list(replace_icon.keys())

# (panel title, keep the IID rows?) for the left and right axes.
panels = [
    ("CIFAR10 (IID Label Distribution)", True),
    ("CIFAR10 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(15,6), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        # .loc[sk] keeps only the listed settings, baseline first.
        temp = temp.set_index('Scaling').loc[sk].reset_index()
        temp = temp.replace({"Scaling": replace_icon})
        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([75, 84])
        ax.set_ylabel("Test Accuracy (%)")
        # NOTE(review): "unfrozing" is probably a typo for "unfreezing" - confirm before changing the figure text.
        ax.set_xlabel("Number of local iters at each unfrozing stage")
        ax.set_title(title)
    fig.tight_layout()
    plt.pause(0.1)

### GU 211

In [None]:
# CIFAR10 / FedAvg, GU211 sweep: column "a3" (shown as test accuracy) for the
# vanilla baseline versus each GU211.* setting, one panel per label distribution.
data_temp = load_data("../Output_GU_C10/CIFAR10-*-N100-*/s[1-9].csv")

# Raw "Scaling" value -> x tick label; dict order is the bar order.
replace_icon =   {"GU211.3": "x 0.3",
                  "GU211.1": "x 0.1",
                  "GU211.03": "x 0.03",
                  "GU211.01": "x 0.01",
                #   "GU211.0": "x 0.0",   (left out of this figure)
                 }
sk = ["vanilla"] + list(replace_icon.keys())

# (panel title, keep the IID rows?). The old title said "CIFAR100", lacked its
# closing parenthesis, and labelled both panels as IID.
panels = [
    ("CIFAR10 (IID Label Distribution)", True),
    ("CIFAR10 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(20,7), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        # .loc[sk] keeps only the listed settings, baseline first.
        temp = temp.set_index('Scaling').loc[sk].reset_index()
        temp = temp.replace({"Scaling": replace_icon})

        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([79, 84])
        ax.set_ylabel("Test Accuracy (%)")
        ax.set_title(title)
        ax.set_xticklabels(ax.get_xticklabels(), rotation = 40)
    fig.tight_layout()
    plt.pause(0.1)

### GU 112

In [None]:
# CIFAR10 / FedAvg, GU112 sweep (wide layout, without the GU112.1 setting):
# column "a3" (shown as test accuracy) for the vanilla baseline versus each
# listed setting, one panel per label distribution.
data_temp = load_data("../Output_GU_C10/CIFAR10-*-N100-*/s[1-9].csv")

# Raw "Scaling" value -> x tick label; dict order is the bar order.
replace_icon =   {"GU112.01": "1",
                  "GU112.02": "2",
                  "GU112.04": "4",
                  "GU112.06": "6",
                  "GU112.08": "8",
                #   "GU112.1": "10"   (left out of this figure)
                 }
sk = ["vanilla"] + list(replace_icon.keys())

# (panel title, keep the IID rows?). The old title said "CIFAR100", lacked its
# closing parenthesis, and labelled both panels as IID.
panels = [
    ("CIFAR10 (IID Label Distribution)", True),
    ("CIFAR10 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(20,7), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        # .loc[sk] keeps only the listed settings, baseline first.
        temp = temp.set_index('Scaling').loc[sk].reset_index()
        temp = temp.replace({"Scaling": replace_icon})

        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([79, 84])
        ax.set_ylabel("Test Accuracy (%)")
        ax.set_title(title)
        ax.set_xticklabels(ax.get_xticklabels(), rotation = 40)
    fig.tight_layout()
    plt.pause(0.1)

### GGU 121

In [None]:
# CIFAR10 / FedAvg, GGU121 sweep: column "a3" (shown as test accuracy) for the
# vanilla baseline versus each GGU121.* setting, one panel per label distribution.
data_temp = load_data("../Output_GU_C10/CIFAR10-*-N100-*/s[1-9].csv")

# Raw "Scaling" value -> x tick label; dict order is the bar order.
replace_icon =   {"GGU121.03": "0.03",
                  "GGU121.05": "0.05",
                  "GGU121.07": "0.07",
                  "GGU121.09": "0.09",
                 }
sk = ["vanilla"] + list(replace_icon.keys())

# (panel title, keep the IID rows?). The old title said "CIFAR100", lacked its
# closing parenthesis, and labelled both panels as IID.
panels = [
    ("CIFAR10 (IID Label Distribution)", True),
    ("CIFAR10 (Non-IID Label Distribution)", False),
]
is_iid = data_temp["Distribution"] == "IID partition"

# Set the style before any axes exist: seaborn styles only affect axes created
# afterwards, so this cell no longer depends on an earlier cell having set it.
sns.set_style("darkgrid")

for algo in ["fedavg"]:
    fig, axes = plt.subplots(1, 2, figsize=(20,7), dpi=200, sharey=True)
    for ax, (title, want_iid) in zip(axes, panels):
        keep = (
            (data_temp["act_prob"] == 0.01)
            & (data_temp["epoch"] == 999)
            & (data_temp["Mode"] == algo)
            & (is_iid if want_iid else ~is_iid)
        )
        # Boolean indexing returns a new frame, so no deepcopy is needed.
        temp = data_temp[keep]

        # .loc[sk] keeps only the listed settings, baseline first.
        temp = temp.set_index('Scaling').loc[sk].reset_index()
        temp = temp.replace({"Scaling": replace_icon})

        ax = sns.barplot(data=temp, x="Scaling", y="a3", ax=ax, palette="tab10", errorbar="se", linewidth = 0.8)
        ax.set_ylim([79, 84])
        ax.set_ylabel("Test Accuracy (%)")
        ax.set_title(title)
        ax.set_xticklabels(ax.get_xticklabels(), rotation = 40)
    fig.tight_layout()
    plt.pause(0.1)