In [None]:
import numpy as np
import gym
from gym import spaces
import csv
# import drl_utils as dr
from collections import OrderedDict
import os
import select
import torch as T
from torch.distributions.utils import logits_to_probs
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal, Categorical

: 

In [88]:
class CustomObservationSpace(spaces.Dict):
    """Dict observation space: a per-node feature table plus a scalar FL accuracy.

    Keys:
        "current_state": Box of shape (total_nodes, 8), one row of features per node.
        "FL_accuracy": Box bounded to [0.0, 1.0].
    """

    def __init__(self, total_nodes):
        """Build the per-node bounds for ``total_nodes`` rows and register both sub-spaces."""
        self.total_nodes = total_nodes

        # (low, high) bounds of the eight feature columns, in column order.
        feature_bounds = (
            (0, total_nodes),    # column 0: replaced by the row index in sample()
            (0.0, 1.0),
            (-500.0, 500.0),     # honesty
            (100.0, 1000.0),     # data
            (50.0, 300.0),       # freq
            (150.0, 1000.0),     # rate
            (0.0, 1.0),
            (0.0, 100.0),
        )
        lows = np.array([low for low, _ in feature_bounds], dtype=np.float32)
        highs = np.array([high for _, high in feature_bounds], dtype=np.float32)

        # Repeat the single-row bounds once per node.
        node_table_space = spaces.Box(
            low=np.tile(lows, (total_nodes, 1)),
            high=np.tile(highs, (total_nodes, 1)),
            dtype=np.float32,
        )
        accuracy_space = spaces.Box(low=0.0, high=1.0, dtype=np.float32)

        super().__init__(
            spaces.Dict(
                {
                    "current_state": node_table_space,  # table of nodes with features
                    "FL_accuracy": accuracy_space,
                }
            )
        )

    def sample(self):
        """Draw a random observation, then overwrite column 0 with the row indices."""
        sampled = super().sample()
        node_table = sampled["current_state"]
        node_table[:, 0] = np.arange(node_table.shape[0])
        return sampled

    def preprocess_observation(self, current_state, fl_accuracy):
        """Return ``current_state`` flattened with ``fl_accuracy`` appended at the end."""
        return np.hstack((current_state.flatten(), fl_accuracy))

class CustomActionSpace(spaces.Space):
    """Action space that picks ``num_selected`` distinct nodes out of ``total_nodes``.

    Note: ``__init__`` does not call the base ``Space`` initializer; only the
    attributes set below and the ``shape`` property are defined here.
    """

    def __init__(self, total_nodes, num_selected):
        self.total_nodes = total_nodes
        self.num_selected = num_selected
        self.high = np.ones(self.num_selected)

    @property
    def shape(self):
        """Shape of the one-hot selection vector: one entry per node."""
        return (self.total_nodes,)

    def sample(self):
        """Randomly pick ``num_selected`` distinct node indices.

        Returns:
            A tuple ``(mask, indices)``: ``mask`` is a length-``total_nodes`` array
            holding 1 at every picked index and 0 elsewhere, ``indices`` is the
            array of picked node indices.
        """
        picked = np.random.choice(self.total_nodes, size=self.num_selected, replace=False)
        mask = np.zeros(self.total_nodes)
        mask[picked] = 1
        return mask, picked

In [None]:
class FLNodeSelectionEnv(gym.Env):
    """Gym environment in which an agent selects nodes for federated-learning rounds.

    The node table has one row per node with ``num_features + 1`` columns: the
    features read from the CSV file plus one trailing column reserved for the
    node's local-model accuracy.

    NOTE(review): ``step`` calls ``dr.get_nodes_withaccuracy`` but the
    ``import drl_utils as dr`` line at the top of the notebook is commented out,
    so ``dr`` must be made available before ``step`` can run.
    """

    def __init__(self,total_nodes,num_selected , num_features,target,max_rounds,aggregator_ratio=0.3):
        """Create the environment.

        Args:
            total_nodes: number of nodes in the network (rows of the node table).
            num_selected: number of nodes selected in each round.
            num_features: number of feature columns read from the CSV per node.
            target: FL accuracy at which an episode is considered finished.
            max_rounds: maximum number of rounds per episode.
            aggregator_ratio: fraction of the selected nodes counted as aggregators;
                the remaining selected nodes are counted as trainers.
        """
        super().__init__()
        self.total_nodes = total_nodes
        self.num_selected = num_selected
        self.num_features = num_features
        self.aggregator_ratio = aggregator_ratio
        self.target_accuracy = target
        # Split the selected nodes into aggregators and trainers based on the ratio.
        self.num_aggregators = int(num_selected * aggregator_ratio)
        self.num_trainers = num_selected - self.num_aggregators
        # One extra column leaves room for the node accuracy next to the features.
        self.current_state = np.zeros((total_nodes, num_features+1))
        self.current_state[:, 0] = np.arange(total_nodes)  # Set node IDs
        self.observation_space = CustomObservationSpace(total_nodes)
        self.action_space = CustomActionSpace(total_nodes, num_selected)
        # Initial state.
        self.fl_accuracy = 0.0
        self.current_observation = {
                "current_state":self.current_state,
                "FL_accuracy": self.fl_accuracy}

        self.current_round=0
        self.max_rounds: int =max_rounds

    def set_act(self, act):
        """Store ``act`` on the environment as ``self.act``."""
        self.act = act

    @staticmethod
    def _parse_csv_value(value):
        """Map a CSV boolean cell (any letter case) to 1/0; return other cells unchanged."""
        lowered = value.strip().lower()
        if lowered == "true":
            return 1
        if lowered == "false":
            return 0
        return value

    def reset(self): #CALLED TO INITIATE NEW EPISODE
        """Start a new episode from the node table stored in ``generated_nodes.csv``.

        The first CSV row is treated as the header and skipped; the next
        ``total_nodes`` rows supply the first ``num_features`` columns of each node.
        A zero-filled accuracy column is appended. Episode bookkeeping
        (round counter, FL accuracy, current observation) is reset as well so a
        second episode does not start from the previous episode's state.

        Returns:
            ``(observation, info)`` where ``observation`` is a dict with keys
            ``"current_state"`` and ``"FL_accuracy"``.

        Raises:
            ValueError: if the CSV holds fewer than ``total_nodes`` data rows.
        """
        csv_filename = "generated_nodes.csv"  # Replace with your CSV file name
        with open(csv_filename, "r", newline="") as csv_file:
            rows = list(csv.reader(csv_file))

        # Skip the header row and keep exactly one row per node.
        current_state_rows = rows[1:self.total_nodes+1]
        if len(current_state_rows) < self.total_nodes:
            raise ValueError(
                f"{csv_filename} has {len(current_state_rows)} node rows, "
                f"expected {self.total_nodes}"
            )

        # Booleans are stored as text in the CSV; turn them into 1/0 before the float cast.
        current_state_preprocessed = [
            [self._parse_csv_value(value) for value in row[:self.num_features]]
            for row in current_state_rows
        ]
        new_current_state = np.array(current_state_preprocessed, dtype=np.float32)[:, :self.num_features+1]
        # Trailing column for the accuracy of each node's local model.
        accuracy_column = np.zeros((new_current_state.shape[0], 1))
        current_state = np.append(new_current_state, accuracy_column, axis=1)

        self.current_state[:self.total_nodes] = current_state

        # Reset the per-episode bookkeeping; without this the round counter keeps
        # growing across episodes and _get_obs() returns the previous episode's state.
        self.current_round = 0
        self.fl_accuracy = 0.0
        current_observation = {
            "current_state": current_state,
            "FL_accuracy": 0.0
        }
        self._observation = current_observation
        self.current_observation = current_observation
        info = {"msg" : "success"}
        return current_observation , info

    def step(self, action, accuracies, nodes, losses, fl_accuracy):
        """Apply one FL round's results to the environment.

        Args:
            action: iterable of selected node indices.
            accuracies: per-node accuracies, indexed by node index.
            nodes: node data forwarded to ``dr.get_nodes_withaccuracy``.
            losses: not used by this method.
            fl_accuracy: global FL accuracy after the round.

        Returns:
            ``(next_observation, agent_reward, done, node_rewards)``.
        """
        updated_nodes =dr.get_nodes_withaccuracy(nodes, self.total_nodes,accuracies)
        updated_fl_accuracy =fl_accuracy

        # Update the state of the environment with received updates
        # (this also stores the observation in self.current_observation).
        next_observation = self.update_environment_state_with_network_updates(updated_nodes, fl_accuracy)
        self.current_observation = next_observation

        # Per-node rewards for the selected nodes; the agent reward is their sum.
        node_rewards = self.calculate_reward(action,accuracies)
        agent_reward = sum(node_rewards)

        self.current_round += 1
        # The episode ends when the round budget is used up or the target accuracy is reached.
        done = self.current_round >= self.max_rounds or self.target_accuracy_achieved(updated_fl_accuracy)

        return next_observation, agent_reward,done, node_rewards


    def update_environment_state_with_network_updates(self,nodes,FL_accuracy):
        """Store ``nodes`` (cast to float32) and ``FL_accuracy`` as the current observation and return it."""
        nodes= nodes.astype(np.float32) # cast the array to accept npfloat
        obs= {
            "current_state": nodes,
            "FL_accuracy": FL_accuracy
        }
        self.current_observation = obs
        return obs


    def agent_reward(self, node_rewards):
        """Return the sum of the per-node rewards."""
        return sum(node_rewards)

    def calculate_reward(self, selected_nodes, updated_accuracies):
        """Compute a reward for every selected node.

        A selected node's reward is column 2 of its row in the current state,
        multiplied by the node's updated accuracy when that accuracy is non-zero.
        Nodes that were not selected keep a reward of 0.

        Returns:
            Array of length ``total_nodes`` with the per-node rewards.
        """
        state = self.current_observation["current_state"]
        node_rewards = np.zeros(self.total_nodes)
        for node_index in selected_nodes:
            node_rewards[node_index] = state[node_index][2]
            if (updated_accuracies[node_index] != 0) :
                node_rewards[node_index] *= updated_accuracies[node_index]
        return node_rewards

    def target_accuracy_achieved(self, updated_accuracy):
        """Return True when ``updated_accuracy`` has reached the target accuracy."""
        return updated_accuracy >= self.target_accuracy

    def _get_obs(self):
        """Return the most recently stored observation dict."""
        return self.current_observation

: 

In [90]:

class ReplayBuffer():
    """Fixed-size circular store of (state, action, reward, next state, done) transitions.

    Args:
        max_size: number of transitions kept before the oldest are overwritten.
        input_shape: shape of a single state.
        n_actions: accepted for interface compatibility; not used by the buffer.
        max_action: length of a stored action vector.
    """

    def __init__(self,max_size,input_shape,n_actions, max_action):
        self.mem_size = max_size
        self.mem_cntr = 0  # total number of transitions stored so far

        state_buffer_shape = (self.mem_size, *input_shape)
        self.state_memory = np.zeros(state_buffer_shape)
        self.new_state_memory = np.zeros(state_buffer_shape)
        self.action_memory = np.zeros((self.mem_size, max_action))
        self.reward_memory = np.zeros(self.mem_size)
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)  # done flags

    def preprocess_observation(self, observation):
        """Flatten ``observation['current_state']`` and append ``observation['FL_accuracy']``."""
        node_table = observation['current_state']
        accuracy = observation['FL_accuracy']
        return np.concatenate((node_table.flatten(), [accuracy]))

    def store_transition(self,state,action,reward,state_,done):
        """Write one transition into the next slot, wrapping around when the buffer is full."""
        slot = self.mem_cntr % self.mem_size
        writes = (
            (self.state_memory, state),
            (self.new_state_memory, state_),
            (self.action_memory, action),
            (self.reward_memory, reward),
            (self.terminal_memory, done),
        )
        for memory, item in writes:
            memory[slot] = item
        self.mem_cntr += 1

    def sample_buffer(self,batch_size):
        """Draw ``batch_size`` random stored transitions and return the first one drawn.

        Returns:
            ``(state, action, reward, next_state, done)`` of a single transition;
            only element 0 of the drawn batch is returned.
        """
        filled = min(self.mem_cntr, self.mem_size)
        picks = np.random.choice(filled, batch_size)
        memories = (
            self.state_memory,
            self.action_memory,
            self.reward_memory,
            self.new_state_memory,
            self.terminal_memory,
        )
        return tuple(memory[picks][0] for memory in memories)


In [12]:
# Global accuracy reported in each block, collected in block order.
global_accuracy = []

for json_obj in data:
    block_id = json_obj['BlockId']

    # The first transaction whose content has message_type == 2 and type == 2
    # carries the accuracy; blocks without such a transaction contribute nothing.
    for transaction in json_obj['Transactions']:
        content = transaction['Content']
        if content.get('message_type') != 2 or content.get('type') != 2:
            continue
        global_accuracy.append(content.get('accuracy', 0.0))
        break

# Print or use global_accuracy as needed
print("Global Accuracy Values:", global_accuracy)

Global Accuracy Values: [60.0, 73.53, 74.71, 83.0, 83.0, 85.0, 86.0, 87.0, 85.0, 90.0, 88.0, 89.0, 87.0, 89.69, 89.84, 92.0, 92.0, 93.0, 93.0, 91.88, 92.0, 93.0, 94.24, 91.21, 93.27, 93.0, 91.76, 91.0, 93.45]


In [13]:
# Selected nodes (aggregators followed by trainers) for every block that
# contains a selection transaction, in block order.
selected_nodes = []
# Block id of each entry in selected_nodes. Kept in a parallel list so the
# printed labels stay correct even when some block has no selection transaction
# (zipping against every block id would shift the labels in that case).
selected_block_ids = []

for json_obj in data:
    block_id = json_obj['BlockId']

    # The first transaction with message_type == 4 carries the selection.
    aggregator_transaction = next(
        (
            transaction
            for transaction in json_obj['Transactions']
            if transaction['Content'].get('message_type') == 4
        ),
        None,
    )
    if aggregator_transaction is None:
        continue

    content = aggregator_transaction['Content']
    aggregators = content.get('aggregators', [])
    trainers = content.get('trainers', [])

    selected_nodes.append(aggregators + trainers)
    selected_block_ids.append(block_id)

# Print or use selected_nodes as needed
for block_id, selected_node_list in zip(selected_block_ids, selected_nodes):
    print(f"Block {block_id}: {selected_node_list}")

Block 1/0: [32, 1, 7, 30, 35, 13, 42, 6, 4, 27, 14, 36, 16, 38, 3, 15, 23]
Block 1/1: [37, 1, 7, 30, 35, 46, 19, 8, 13, 42, 43, 6, 4, 49, 27, 36, 16]
Block 1/2: [37, 30, 35, 19, 42, 49, 27, 1, 36, 16, 7, 38, 3, 12, 15, 23, 9]
Block 1/3: [32, 37, 30, 35, 19, 42, 13, 43, 27, 4, 36, 12, 15, 23, 9, 26, 5]
Block 1/4: [32, 37, 30, 35, 19, 42, 46, 13, 43, 27, 49, 4, 36, 16, 38, 12, 15]
Block 1/5: [32, 30, 19, 46, 13, 43, 42, 49, 36, 16, 38, 15, 26, 29, 2, 45, 48]
Block 1/6: [32, 30, 19, 35, 46, 13, 42, 49, 27, 38, 15, 23, 26, 29, 45, 48, 6]
Block 1/7: [32, 30, 37, 19, 35, 46, 13, 42, 43, 49, 27, 16, 38, 15, 23, 26, 29]
Block 1/8: [32, 30, 37, 19, 35, 46, 13, 42, 43, 49, 27, 36, 16, 38, 15, 12, 23]
Block 1/9: [32, 37, 30, 19, 35, 46, 13, 42, 43, 27, 36, 16, 38, 15, 12, 23, 26]
Block 1/10: [32, 37, 30, 19, 35, 46, 13, 43, 27, 16, 38, 15, 12, 23, 26, 29, 45]
Block 1/11: [32, 37, 30, 46, 13, 43, 42, 49, 16, 38, 36, 12, 23, 26, 29, 45, 48]
Block 1/12: [32, 30, 19, 46, 13, 42, 27, 16, 38, 36, 15, 1

In [14]:
# Build one (obs, action, reward, next obs, done) transition per pair of
# consecutive blocks, plus the cumulated reward of every round.
round_reward = []
transitions = []

# Every transition needs a following block, so the last block only serves as "obs_".
num_transitions = len(node_info_by_block) - 1

for block_id in range(num_transitions):
    nodes_rewards = []

    for node_index in selected_nodes[block_id]:
        node_data = nodes_by_block[block_id][node_index]
        accuracy = node_data[-1]  # last column of the node row
        honesty = node_data[2]    # third column of the node row

        # Reward is column 2 scaled by the last column, or column 2 alone when
        # the last column is 0.
        reward = accuracy * honesty if accuracy != 0.0 else honesty
        nodes_rewards.append(reward)

    cumulated_reward = sum(nodes_rewards)
    round_reward.append(cumulated_reward)

    transition = {
        "obs": nodes_by_block[block_id],
        "obs_": nodes_by_block[block_id + 1],
        "action": selected_nodes[block_id],
        "reward": cumulated_reward,
        # The loop stops at num_transitions - 1, so that index is the terminal
        # transition. Comparing against len(node_info_by_block) - 1 never matched
        # and left every transition non-terminal.
        "done": block_id == num_transitions - 1
    }
    transitions.append(transition)

for transition in transitions :
    print(transition)

{'obs': [[0, 1, 0.0, 320, 190.0, 260.0, 0.0, 0.0], [1, 1, 0.0, 360, 290.0, 770.0, 0.0, 0.0], [2, 1, 0.0, 510, 100.0, 730.0, 0.0, 0.0], [3, 1, 0.0, 450, 180.0, 450.0, 0.0, 56.67], [4, 1, 0.0, 800, 280.0, 640.0, 0.0, 68.75], [5, 1, 0.0, 130, 190.0, 270.0, 0.0, 0.0], [6, 1, 0.0, 760, 240.0, 830.0, 0.0, 69.74], [7, 1, 0.0, 160, 120.0, 670.0, 0.0, 0.0], [8, 0, 0.0, 380, 260.0, 540.0, 0.0, 0.0], [9, 1, 0.0, 300, 60.0, 990.0, 0.0, 0.0], [10, 1, 0.0, 610, 50.0, 700.0, 0.0, 0.0], [11, 1, 0.0, 410, 250.0, 180.0, 0.0, 0.0], [12, 0, 0.0, 520, 200.0, 460.0, 0.0, 0.0], [13, 1, 0.0, 840, 300.0, 750.0, 0.0, 0.0], [14, 1, 0.0, 700, 180.0, 890.0, 0.0, 72.86], [15, 1, 0.0, 830, 270.0, 470.0, 0.0, 63.25], [16, 1, 0.0, 450, 150.0, 550.0, 0.0, 63.33], [17, 0, 0.0, 460, 70.0, 360.0, 0.0, 0.0], [18, 1, 0.0, 650, 160.0, 240.0, 0.0, 0.0], [19, 0, 0.0, 550, 240.0, 970.0, 0.0, 0.0], [20, 1, 0.0, 210, 70.0, 160.0, 0.0, 0.0], [21, 1, 0.0, 580, 290.0, 180.0, 0.0, 0.0], [22, 1, 0.0, 500, 220.0, 190.0, 0.0, 0.0], [23,

In [91]:

class CriticNetwork(nn.Module):
    """Q-network: maps a (flattened state, action) pair to a single Q-value.

    Args:
        beta: learning rate of the Adam optimizer.
        input_shape: shape of one state; its element count is the state input size.
        n_actions: length of the action vector concatenated to the state.
        fc1_dims: width of the first hidden layer.
        fc2_dims: width of the second hidden layer.
        name: network name, used to build the checkpoint file name.
        chkpt_dir: directory in which checkpoints are written.
    """

    def __init__ (self,beta,input_shape,n_actions,
    fc1_dims=256 , fc2_dims = 256, name='critic',chkpt_dir='tmp/sac' ):
        super(CriticNetwork,self).__init__()
        # Plain int so layer sizes do not depend on NumPy scalar types.
        self.input_dims = int(np.prod(input_shape))
        self.n_actions = n_actions
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.name = name
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir,self.name+'_sac')
        # Two hidden layers followed by a scalar Q head.
        self.fc1 = nn.Linear(self.input_dims+n_actions,self.fc1_dims)
        self.fc2 = nn.Linear(self.fc1_dims,self.fc2_dims)
        self.q = nn.Linear(self.fc2_dims,1)
        self.optimizer = optim.Adam(self.parameters(),lr=beta)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self,state,action):
        """Return Q(state, action).

        ``state`` and ``action`` are joined along their last dimension, so both
        unbatched 1-D inputs and batched ``(batch, features)`` inputs work
        (for 1-D inputs this is identical to concatenating along dim 0).
        """
        state_action = T.cat([state,action],dim=-1)
        action_value = F.relu(self.fc1(state_action))
        action_value = F.relu(self.fc2(action_value))
        q = self.q(action_value)

        return q

    def save_checkpoint(self):
        """Write the network weights to ``checkpoint_file``, creating its directory if needed."""
        print('...saving checkpoint...')
        parent_dir = os.path.dirname(self.checkpoint_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        T.save(self.state_dict(),self.checkpoint_file)

    def load_checkpoint(self):
        """Load the network weights from ``checkpoint_file`` onto this network's device."""
        print('...loading checkpoint...')
        # map_location lets a checkpoint saved on another device be restored here.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))

class ValueNetwork(nn.Module):
    """State-value network: maps a flattened state to a single value estimate.

    Args:
        beta: learning rate of the Adam optimizer.
        input_shape: shape of one state; its element count is the input size.
        fc1_dims: width of the first hidden layer.
        fc2_dims: width of the second hidden layer.
        name: network name, used to build the checkpoint file name.
        chkpt_dir: directory in which checkpoints are written.
    """

    def __init__(self,beta,input_shape,fc1_dims=256,fc2_dims=256,name='value',chkpt_dir='tmp/sac'):
        super(ValueNetwork,self).__init__()
        # Plain int so layer sizes do not depend on NumPy scalar types.
        self.input_dims = int(np.prod(input_shape))

        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.name = name
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir,self.name+'_sac')
        # Two hidden layers followed by a scalar value head.
        self.fc1 = nn.Linear(self.input_dims, self.fc1_dims)
        self.fc2 = nn.Linear(self.fc1_dims,self.fc2_dims)
        self.v = nn.Linear(self.fc2_dims,1)
        self.optimizer = optim.Adam(self.parameters(),lr=beta)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self,state):
        """Return V(state) for a flattened state tensor whose last dimension is ``input_dims``."""
        state_value = F.relu(self.fc1(state))
        state_value = F.relu(self.fc2(state_value))
        v = self.v(state_value)

        return v

    def save_checkpoint(self):
        """Write the network weights to ``checkpoint_file``, creating its directory if needed."""
        print('...saving checkpoint...')
        parent_dir = os.path.dirname(self.checkpoint_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        T.save(self.state_dict(),self.checkpoint_file)

    def load_checkpoint(self):
        """Load the network weights from ``checkpoint_file`` onto this network's device."""
        print('...loading checkpoint...')
        # map_location lets a checkpoint saved on another device be restored here.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))


class ActorNetwork(nn.Module):
    """Policy network that proposes which nodes to select.

    Args:
        alpha: learning rate of the Adam optimizer.
        input_shape: shape of one state; its element count is the input size.
        max_actions: number of selected nodes; also the width of the two output heads.
        fc1_dims: width of the first hidden layer.
        fc2_dims: width of the second hidden layer.
        n_actions: stored on the instance; not used by the layers defined here.
        name: network name, used to build the checkpoint file name.
        chkpt_dir: directory in which checkpoints are written.

    NOTE(review): ``sample_normal`` and ``forward`` call ``dr.flatten_nodes`` and
    ``dr.array_to_state``, but the ``import drl_utils as dr`` line at the top of
    the notebook is commented out, so ``dr`` must be made available first.
    """

    def __init__(self,alpha,input_shape,max_actions,fc1_dims=256,fc2_dims=256,n_actions=2,name='actor',chkpt_dir='tmp/sac'):
        super(ActorNetwork,self).__init__()
        # Plain int so layer sizes do not depend on NumPy scalar types.
        self.input_dims = int(np.prod(input_shape))
        self.input_shape = input_shape
        self.max_actions = max_actions # number of selected nodes
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.name = name
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir,self.name+'_sac')
        self.reparam_noise = 1e-6
        # Two hidden layers shared by both output heads.
        self.fc1 = nn.Linear(self.input_dims, self.fc1_dims)
        self.fc2 = nn.Linear(self.fc1_dims,self.fc2_dims)
        # Output heads, each of width max_actions.
        self.mu = nn.Linear(self.fc2_dims,self.max_actions)
        self.sigma = nn.Linear(self.fc2_dims,self.max_actions)
        self.optimizer = optim.Adam(self.parameters(),lr=alpha)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def sample_normal(self, state, num_selected_nodes, exploration_noise=0.025):
        """Sample a set of node indices and the log-probabilities of the last draw.

        Despite the name, sampling uses a ``Categorical`` distribution built from
        the probabilities returned by ``forward``.

        Args:
            state: node table passed through ``dr.flatten_nodes`` before the forward pass.
            num_selected_nodes: minimum number of indices to return; missing ones
                are filled with random picks among the available nodes.
            exploration_noise: scale of the noise term computed below.

        Returns:
            ``(selected_indices, log_probs)``: a NumPy array of node indices and
            the log-probabilities of the masked samples from the last loop iteration.

        The statement order is kept exactly as written because several steps
        draw from the torch / NumPy random streams.
        """
        state = dr.flatten_nodes(state)
        action_probs, action_mean, action_log_std = self.forward(state)
        # The third output of forward() is treated as a log standard deviation here.
        action_std = action_log_std.exp()
        state = dr.array_to_state(state, 8)
        # Column 1 of the node table marks availability: non-zero means available.
        availability_mask = state[:, 1] != 0
        availability_mask = availability_mask.long()
        # NOTE(review): noisy_logits is never read afterwards. It still draws from
        # the torch RNG, so removing it would change the sampled actions.
        noisy_logits = action_mean + exploration_noise * action_std * T.randn_like(action_mean)
        # NaNs are replaced by 0. The round trip through NumPy detaches the
        # probabilities from the autograd graph, and .numpy() requires a CPU tensor.
        action_probs_clean = T.tensor(np.nan_to_num(action_probs.T.detach().numpy(), nan=0.0))

        action_dist = Categorical(action_probs_clean)
        sampled_actions = action_dist.sample()
        print('sampled actions from categorical', sampled_actions)

        # Get indices of selected nodes
        selected_indices = sampled_actions.nonzero().squeeze()
        # Draw ten more times and accumulate the unique non-zero positions of the
        # masked samples.
        for i in range(10):
            new_samples = action_dist.sample()
            new_samples = new_samples*availability_mask  # Apply availability mask
            new_indices = new_samples.nonzero().squeeze()

            # NOTE(review): selected_indices is already a tensor here, so
            # T.tensor(...) copies it (PyTorch emits a UserWarning for this).
            selected_indices_tensor = T.tensor(selected_indices)

            # Combine the selected indices and new indices while removing duplicates
            combined_indices = T.cat((selected_indices_tensor, new_indices))
            unique_combined_indices = T.unique(combined_indices)

            selected_indices = unique_combined_indices
        selected_indices = np.array(selected_indices)
        # Top up with random available nodes when too few indices were collected.
        # NOTE(review): the top-up does not exclude indices already selected.
        if len(selected_indices) < num_selected_nodes:
            additional_indices = np.random.choice(availability_mask.nonzero().squeeze(), size=num_selected_nodes - len(selected_indices), replace=False)
            print(additional_indices)
            selected_indices = np.concatenate((selected_indices, additional_indices))
        # Log-probabilities of the masked samples from the final loop iteration.
        log_probs = action_dist.log_prob(new_samples)
        return selected_indices, log_probs

    def forward(self, state):
        """Run the network on a flattened state.

        Returns:
            ``(selection_probabilities, mu, sigma)`` where ``mu`` is the ReLU of
            the ``mu`` head, ``sigma`` is the raw output of the ``sigma`` head and
            ``selection_probabilities`` is derived from ``mu`` as shown below.
        """
        prob = self.fc1(state)
        prob = F.relu(prob)
        prob = self.fc2(prob)
        prob = F.relu(prob)

        mu = self.mu(prob)
        sigma = self.sigma(prob)

        # NOTE(review): the availability mask is computed but not applied to any
        # output of this method.
        state = dr.array_to_state(state, 8)
        availability_mask = state[:, 1] != 0
        availability_mask = availability_mask.float()

        # Apply ReLU activation to mu to ensure non-negative values
        mu = F.relu(mu)

        # Divide by the number of rows in input_shape. ReLU output has no upper
        # bound, so the result is non-negative but not guaranteed to stay <= 1.
        scaled_actions = mu / self.input_shape[0]

        # Calculate the probability of not selecting the node
        prob_not_selected = 1 - scaled_actions

        # prob_not_selected + scaled_actions is (1 - x) + x, so this normalisation
        # leaves scaled_actions essentially unchanged.
        total_probabilities = prob_not_selected + scaled_actions
        selection_probabilities = scaled_actions / total_probabilities  # Only the probability of selecting the node

        return selection_probabilities, mu, sigma

    def save_checkpoint(self):
        """Write the network weights to ``checkpoint_file``, creating its directory if needed."""
        print('...saving checkpoint...')
        parent_dir = os.path.dirname(self.checkpoint_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        T.save(self.state_dict(),self.checkpoint_file)

    def load_checkpoint(self):
        """Load the network weights from ``checkpoint_file`` onto this network's device."""
        print('...loading checkpoint...')
        # map_location lets a checkpoint saved on another device be restored here.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))



In [92]:

# Default for Agent's `alpha` argument. Agent passes it to ActorNetwork, which
# uses it as the learning rate of the actor's Adam optimizer.
# NOTE(review): the name suggests an initial entropy temperature; a learning
# rate of 1.0 looks unintended -- confirm.
ALPHA_INITIAL = 1.
# Default discount factor (Agent's `gamma`).
DISCOUNT_RATE = 0.99
# Default for Agent's `beta`: learning rate of the critic and value networks.
LEARNING_RATE = 10 ** -4
# Default for Agent's `tau`: weight of the value network in the soft update of
# the target value network.
SOFT_UPDATE_INTERPOLATION_FACTOR = 0.01
class Agent ():
    """SAC-style agent: one actor, two critics, a value network and its target copy.

    NOTE(review): ``learn`` calls ``dr.flatten_nodes`` but the
    ``import drl_utils as dr`` line at the top of the notebook is commented out.
    """
    def __init__(self,env,alpha=ALPHA_INITIAL,beta=LEARNING_RATE,input_shape=[8],gamma = DISCOUNT_RATE,n_actions=2,max_actions=1,max_size=200,tau=SOFT_UPDATE_INTERPOLATION_FACTOR,
    layer1_size=256,layer2_size=256,batch_size=1,reward_scale=2):
        """Create the replay buffer and all networks.

        Args:
            env: accepted but not used in this constructor.
            alpha: learning rate handed to the actor network.
            beta: learning rate handed to the critic and value networks.
            input_shape: shape of one state.
            gamma: discount factor.
            n_actions: forwarded to the replay buffer and the actor.
            max_actions: number of selected nodes; also the action length used
                by the replay buffer and the critics.
            max_size: capacity of the replay buffer.
            tau: soft-update factor for the target value network.
            layer1_size, layer2_size: accepted but not used in this constructor.
            batch_size: number of transitions requested from the buffer per learn step.
            reward_scale: factor applied to the reward in the critic target.
        """
        # Checkpoint directory used by the networks' default chkpt_dir.
        os.makedirs('./tmp/sac', exist_ok=True)
        self.gamma = gamma
        self.tau = tau
        self.memory = ReplayBuffer(max_size, input_shape=input_shape, n_actions=n_actions,max_action=max_actions)
        self.batch_size = batch_size
        self.n_actions = n_actions
        self.max_actions= max_actions
        self.scale = reward_scale

        self.actor = ActorNetwork(alpha, input_shape,  n_actions=n_actions, name='actor', max_actions=self.max_actions)

        # Both critics take an action vector of length max_actions.
        self.critic_1 = CriticNetwork(beta, input_shape , n_actions=max_actions, name='critic_1')

        self.critic_2 = CriticNetwork(beta, input_shape, n_actions=max_actions, name='critic_2')

        self.value = ValueNetwork(beta, input_shape, name='value')

        self.target_value = ValueNetwork(beta, input_shape, name='target_value')

        # tau=1 makes the target value network an exact copy of the value network.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        """Return the node indices sampled by the actor for ``observation``."""
        state = T.tensor(observation, dtype=T.float).to(self.actor.device)
        # The log-probabilities are not needed when only acting.
        actions, log_probs = self.actor.sample_normal(state, self.max_actions)
        return actions

    def remember(self,state,action,reward,next_state,done):
        """Store one transition in the replay buffer."""
        self.memory.store_transition(state,action,reward,next_state,done)

    def update_network_parameters(self, tau=None):
        """Soft-update the target value network towards the value network.

        Each target parameter becomes ``tau * value + (1 - tau) * target``;
        ``tau`` defaults to ``self.tau``.
        """
        if tau is None:
            tau = self.tau

        target_value_params = self.target_value.named_parameters()
        value_params = self.value.named_parameters()

        target_value_state_dict = dict(target_value_params)
        value_state_dict = dict(value_params)

        # The blended tensors replace the entries of the local dict only; the
        # value network itself is not modified.
        for name in value_state_dict:
            value_state_dict[name] = tau*value_state_dict[name].clone() + (1-tau)*target_value_state_dict[name].clone()

        self.target_value.load_state_dict(value_state_dict)

    def save_models(self):
        """Save a checkpoint of all five networks."""
        print('.... saving models ....')
        self.actor.save_checkpoint()
        self.value.save_checkpoint()
        self.target_value.save_checkpoint()
        self.critic_1.save_checkpoint()
        self.critic_2.save_checkpoint()


    def load_models(self):
        """Load the checkpoints of all five networks."""
        print('.... loading models ....')
        self.actor.load_checkpoint()
        self.value.load_checkpoint()
        self.target_value.load_checkpoint()
        self.critic_1.load_checkpoint()
        self.critic_2.load_checkpoint()

    def learn(self):
        """Run one update of the value, actor and critic networks.

        Returns early (after printing a message) while the buffer holds fewer
        than ``batch_size`` transitions. Otherwise updates, in this order: the
        value network, the actor, both critics, then the target value network,
        and finally saves all checkpoints.
        """
        if self.memory.mem_cntr < self.batch_size:
            print('not enough memories to learn from!')
            return

        # sample_buffer returns a single transition (the first one it draws).
        state, action, reward, new_state, done = self.memory.sample_buffer(self.batch_size)
        reward = T.tensor(reward, dtype=T.float).to(self.actor.device)
        done = T.tensor(done).to(self.actor.device)
        state_ = T.tensor(new_state, dtype=T.float).to(self.actor.device)
        state = T.tensor(state, dtype=T.float).to(self.actor.device)
        action = T.tensor(action, dtype=T.float).to(self.actor.device)
        flat_state = dr.flatten_nodes(state)
        flat_state_= dr.flatten_nodes(state_)


        # --- value network update ---
        value = self.value(flat_state).view(-1)
        value_ = self.target_value(flat_state_).view(-1)
        # A terminal transition has no future value.
        value_[done] = 0.0
        actions, log_probs = self.actor.sample_normal(state, self.max_actions)
        # Keep at most max_actions indices so the critic input has the expected length.
        actions = actions[:self.max_actions]
        actions = T.tensor(actions)
        q1_new_policy = self.critic_1.forward(flat_state, actions)
        q2_new_policy = self.critic_2.forward(flat_state, actions)
        # Use the smaller of the two critic estimates.
        critic_value = T.min(q1_new_policy, q2_new_policy)
        critic_value = critic_value.view(-1)
        self.value.optimizer.zero_grad()
        # NOTE(review): value_target is not detached, so this backward pass also
        # writes gradients into the critics; `value` and `value_target` may also
        # differ in size, in which case mse_loss broadcasts -- confirm intent.
        value_target = critic_value - log_probs
        value_loss = 0.5 * F.mse_loss(value, value_target)
        value_loss.backward(retain_graph=True)
        self.value.optimizer.step()

        # --- actor update ---
        log_probs = log_probs.view(-1)
        q1_new_policy = self.critic_1.forward(flat_state, actions)
        q2_new_policy = self.critic_2.forward(flat_state, actions)
        critic_value = T.min(q1_new_policy, q2_new_policy)
        critic_value = critic_value.view(-1)

        # NOTE(review): sample_normal builds log_probs from a NumPy copy of the
        # probabilities, so this loss may not carry gradients back to the actor
        # parameters -- confirm.
        actor_loss = log_probs - critic_value
        actor_loss = T.mean(actor_loss)
        self.actor.optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        self.actor.optimizer.step()

        # --- critic update ---
        # zero_grad here discards the critic gradients accumulated by the two
        # backward passes above.
        self.critic_1.optimizer.zero_grad()
        self.critic_2.optimizer.zero_grad()
        # Target: scaled reward plus discounted target-network value of the next state.
        q_hat = self.scale * reward + self.gamma * value_
        q1_old_policy = self.critic_1.forward(flat_state, action).view(-1)
        q2_old_policy = self.critic_2.forward(flat_state, action).view(-1)
        critic_1_loss = 0.5 * F.mse_loss(q1_old_policy, q_hat)
        critic_2_loss = 0.5 * F.mse_loss(q2_old_policy, q_hat)
        critic_loss = critic_1_loss + critic_2_loss
        critic_loss.backward()
        self.critic_1.optimizer.step()
        self.critic_2.optimizer.step()

        self.update_network_parameters()
        # Checkpoints of all five networks are written on every learn() call.
        self.save_models()
        print('updated the networks')

In [86]:
# import sys
# sys.path.insert(0, './')  # Add the directory containing "DRL" to the path

# from agent import node_selection_agent as ag
# from environment import node_selection_env as nds

# Experiment configuration.
total_nodes = 50
max_actions = 17      # number of nodes selected per round
num_features = 8
target_accuracy = 99.9
max_rounds = 30

# Create the node-selection environment according to the observation.
envNodeSelect = FLNodeSelectionEnv(
    total_nodes=total_nodes,
    num_selected=max_actions,
    num_features=num_features,
    target=target_accuracy,
    max_rounds=max_rounds,
)
# A sampled observation supplies the shape of the node table.
obs_shape = envNodeSelect.observation_space.sample()["current_state"].shape
agent = Agent(
    env=envNodeSelect,
    input_shape=obs_shape,
    n_actions=obs_shape[0],
    max_actions=max_actions,
)

# Train the agent: store each recorded transition, then learn from the buffer.
for transition in transitions:
    agent.remember(
        state=transition["obs"],
        action=transition["action"],
        reward=transition["reward"],
        next_state=transition["obs_"],
        done=transition["done"],
    )
    agent.learn()

checpoint file tmp/sac/actor_sac
in init value networks
value network input dimensions 400
in init value networks
value network input dimensions 400
added the transition 1
learning agent in learn
batch_size 1
maxmeme in sample buffer 1
the batch in sample ? [0]
action from batch [[32.  1.  7. 30. 35. 13. 42.  6.  4. 27. 14. 36. 16. 38.  3. 15. 23.]]
rewards from batch [0.]
dones from batch [False]
Sampled transitions: (50, 8) (17,) () (50, 8) ()
done False
reward 0.0
made them tensors in leanr
passed value
passed other value
value_done? tensor([], size=(0, 1), grad_fn=<IndexBackward0>)
in sample_normal
numselected in sample normal 17
action probs in sample normazle torch.Size([17])
action_dist from categorical  Categorical(probs: torch.Size([17]))
sampled actions from categorical tensor(14)
log probs tensor([ -0.6938,  -0.6938,  -0.6938,  -0.6938,  -0.6938,  -0.6938,  -0.6938,
         -0.6938, -15.9424,  -0.6938,  -0.6938,  -0.6938, -15.9424,  -0.6938,
         -0.6938,  -0.6938,  -0.

  done = T.tensor(done).to(self.actor.device)
  selected_indices_tensor = T.tensor(selected_indices)
  value_loss = 0.5 * F.mse_loss(value, value_target)


.... saving models ....
...saving checkpoint...
...saving checkpoint...
...saving checkpoint...
...saving checkpoint...
...saving checkpoint...
updated the networks
added the transition 4
learning agent in learn
batch_size 1
maxmeme in sample buffer 4
the batch in sample ? [0]
action from batch [[32.  1.  7. 30. 35. 13. 42.  6.  4. 27. 14. 36. 16. 38.  3. 15. 23.]]
rewards from batch [0.]
dones from batch [False]
Sampled transitions: (50, 8) (17,) () (50, 8) ()
done False
reward 0.0
made them tensors in leanr
passed value
passed other value
value_done? tensor([], size=(0, 1), grad_fn=<IndexBackward0>)
in sample_normal
numselected in sample normal 17
action probs in sample normazle torch.Size([17])
action_dist from categorical  Categorical(probs: torch.Size([17]))
sampled actions from categorical tensor(10)
log probs tensor([ -0.6938,  -0.6938,  -0.6938,  -0.6938,  -0.6938,  -0.6938,  -0.6938,
         -0.6938, -15.9424,  -0.6938,  -0.6938,  -0.6938, -15.9424,  -0.6938,
         -0.6938