In [1]:
import numpy as np
import os
import csv
import matplotlib.pyplot as plt
import argparse

In [2]:
dict_IDDPG = {'update_period_1': ['IDDPG-20230222-200936',
                      'IDDPG-20230223-004434',
                      'IDDPG-20230223-052433'],
              'update_period_3': ['IDDPG-20230223-100613',
                      'IDDPG-20230223-143859',
                      'IDDPG-20230223-191829'],
              'update_period_5': ['IDDPG-20230223-234842',
                      'IDDPG-20230224-045714',
                      'IDDPG-20230224-102218']}
dict_SNDDPG = {'update_period_1': ['SNDDPG-20230222-214715',
                       'SNDDPG-20230223-022142',
                       'SNDDPG-20230223-070429'],
              'update_period_3': ['SNDDPG-20230223-113915',
                      'SNDDPG-20230223-161720',
                      'SNDDPG-20230223-205630'],
              'update_period_5': ['SNDDPG-20230224-012708',
                      'SNDDPG-20230224-065046',
                      'SNDDPG-20230224-121620']}
dict_FLDDPG = {'update_period_1': ['FLDDPG-20230222-230953',
                       'FLDDPG-20230223-034437',
                       'FLDDPG-20230223-082832'],
              'update_period_3': ['FLDDPG-20230223-130117',
                      'FLDDPG-20230223-174100',
                      'FLDDPG-20230223-221115'],
              'update_period_5': ['FLDDPG-20230224-030532',
                      'FLDDPG-20230224-082944']}
dict_local_update = {'update_period_1': ['SwarmDDPG-20230222-201113',
                             'SwarmDDPG-20230222-232657',
                             'SwarmDDPG-20230223-023904'],
              'update_period_3': ['SwarmDDPG-20230223-055714',
                      'SwarmDDPG-20230223-091550',
                      'SwarmDDPG-20230223-122657'],
              'update_period_5': ['SwarmDDPG-20230223-154414',
                      'SwarmDDPG-20230223-185850',
                      'SwarmDDPG-20230223-221257']}
dict_round_update = {'update_period_1': ['SwarmDDPG-20230222-214845',
                             'SwarmDDPG-20230223-010202',
                             'SwarmDDPG-20230223-041703'],
              'update_period_3': ['SwarmDDPG-20230223-073531',
                     'SwarmDDPG-20230223-105045',
                     'SwarmDDPG-20230223-140444'],
              'update_period_5': ['SwarmDDPG-20230223-172029',
                      'SwarmDDPG-20230223-203537',
                      'SwarmDDPG-20230223-221257']}
dict_pair_update = {'update_period_1': ['SwarmDDPG-20230224-161632',
                            'SwarmDDPG-20230224-180631',
                            'SwarmDDPG-20230224-195743'],
              'update_period_3': ['SwarmDDPG-20230224-214856',
                      'SwarmDDPG-20230224-233835', 
                      'SwarmDDPG-20230225-012700'],
              'update_period_5': ['SwarmDDPG-20230225-030542',
                      'SwarmDDPG-20230225-044001',
                      'SwarmDDPG-20230225-061518']}

dict_exp = {'IDDPG': dict_IDDPG, 'SNDDPG': dict_SNDDPG, 'FLDDPG': dict_FLDDPG,
            'local_update': dict_local_update, 'round_update': dict_round_update,
            'pair_update': dict_pair_update}
dict_algorithm_up = {'IDDPG': 'update_period_1', 'SNDDPG': 'update_period_3',
                     'FLDDPG': 'update_period_1', 'local_update': 'update_period_1',
                     'round_update': 'update_period_1', 'pair_update': 'update_period_1'}


In [20]:
def calculate_num_catastrophic_interference(array_reward, threshold_ci):
    '''
    input array of rewards (4 x 80)
    output - array of number of catastrophic interference (ci) for each agent (4 x 1)
    '''
    print(f"array_reward.shape: {array_reward.shape}")
    list_range_max = [np.max(agent_reward) for agent_reward in array_reward]
    list_range_min = [np.min(agent_reward) for agent_reward in array_reward]
    
    list_total_abs_change = []
    list_num_ci = []
    for agent_id, agent_reward in enumerate(array_reward):
        list_abs_change = []
        for episode in range(len(agent_reward)-1):
            abs_change = abs(agent_reward[episode] - agent_reward[episode+1])/(list_range_max[agent_id]-list_range_min[agent_id])
            list_abs_change.append(abs_change)
        array_abs_change = np.array(list_abs_change)
        list_total_abs_change.append(list_abs_change)

        num_ci = len(array_abs_change[array_abs_change > threshold_ci])
        list_num_ci.append(num_ci)
    return list_num_ci

In [21]:
HOME = os.environ['HOME']
path_data = HOME + '/catkin_ws/src/fl4sr/src/data'
dir_name = 'experiment_20230222'
path_parent = os.path.join(path_data, dir_name)
dict_key_value = {'algorithm': ['IDDPG', 'SNDDPG', 'FLDDPG', 'local_update', 'round_update', 'pair_update'], 
                  'update_period': [1,3,5],
                  'random_seed': [101, 102, 103]}

In [28]:
def save_individual_reward_plot(path_parent):
    list_exp = os.listdir(path_parent)

    # Using a directory to read the file names.
    for exp_name in list_exp:
        path_log = get_path_log(path_parent, exp_name)
        path_reward = get_path_reward(path_parent, exp_name)
        is_path_exist = os.path.exists(path_reward)
        if is_path_exist == True:
            print(f"data is found at: {path_reward}")
            data_reward = np.load(path_reward)
            list_num_ci = calculate_num_catastrophic_interference(data_reward.T, 0.3)
            print(list_num_ci)
        else: print(f"No data is found at: {path_reward}")
    


In [29]:
def get_path_log(path_data: str, 
    exp_name: str
    ) -> str:
    path_exp = os.path.join(path_data, exp_name)
    path_log = os.path.join(path_exp, 'log')
    return path_log

def get_path_reward(path_data: str, 
    exp_name: str
    ) -> str:
    path_exp = os.path.join(path_data, exp_name)
    path_log = os.path.join(path_exp, 'log')
    path_reward = os.path.join(path_log, 'rewards.npy')
    return path_reward

In [30]:
save_individual_reward_plot(path_parent)

data is found at: /home/swacil/catkin_ws/src/fl4sr/src/data/experiment_20230222/IDDPG-20230224-102218/log/rewards.npy
array_reward.shape: (4, 80)
[0, 0, 0, 0]
data is found at: /home/swacil/catkin_ws/src/fl4sr/src/data/experiment_20230222/IDDPG-20230222-200936/log/rewards.npy
array_reward.shape: (4, 80)
[0, 0, 0, 0]
data is found at: /home/swacil/catkin_ws/src/fl4sr/src/data/experiment_20230222/IDDPG-20230223-234842/log/rewards.npy
array_reward.shape: (4, 80)
[0, 0, 0, 0]
No data is found at: /home/swacil/catkin_ws/src/fl4sr/src/data/experiment_20230222/FLDDPG/log/rewards.npy
data is found at: /home/swacil/catkin_ws/src/fl4sr/src/data/experiment_20230222/SNDDPG-20230224-012708/log/rewards.npy
array_reward.shape: (4, 80)
[0, 0, 0, 0]
No data is found at: /home/swacil/catkin_ws/src/fl4sr/src/data/experiment_20230222/total_reward_plot.png/log/rewards.npy
data is found at: /home/swacil/catkin_ws/src/fl4sr/src/data/experiment_20230222/SwarmDDPG-20230224-233835/log/rewards.npy
array_reward.s

In [31]:
dict_test = {}

In [37]:
wow = dict.fromkeys(['23','24'])

{'23': None, '24': None}