In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

from environment import TrackingEnv

# Seed the PyTorch and NumPy global random number generators with the same value.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Hyperparameters shared by the networks, the agent and the training loop.
MAX_EPISODES = 1200  # default value of train_a2c's num_episodes argument
NUM_NEURONS = 128  # width of the hidden layers in PolicyNet and ValueNet
LR_CRITIC = 0.0005  # Adam learning rate for the critic (ValueNet)
LR_ACTOR = 0.0001  # Adam learning rate for the actor (PolicyNet)
GAMMA = 0.99  # discount applied to the next-state value in Agent.update
#ENTROPY_COEF = 0.01
MAX_STEP_EXPLORATION = 1e5  # total step count at which the exploration term reaches zero
EARLY_STOPPING_EPISODES = 30  # episode count used by the early-stopping check in train_a2c

class PolicyNet(nn.Module):
    """Gaussian policy network.

    Two ReLU hidden layers feed two linear heads: one yields the mean of the
    action distribution, the other a raw value that is turned into a strictly
    positive standard deviation in ``forward``.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()

        # Shared trunk.
        self.fc1 = nn.Linear(state_dim, NUM_NEURONS)
        self.fc2 = nn.Linear(NUM_NEURONS, NUM_NEURONS)

        # Head for the mean (mu).
        self.mu_layer = nn.Linear(NUM_NEURONS, action_dim)

        # Head for the standard deviation. Despite the attribute name, its
        # output goes through softplus in forward(), not exp().
        self.log_sigma_layer = nn.Linear(NUM_NEURONS, action_dim)

    def forward(self, state, exploration_term):
        """Return ``(mu, sigma)`` for ``state``.

        ``exploration_term`` is added to the softplus output, so it widens the
        standard deviation by that amount.
        """
        hidden = F.relu(self.fc2(F.relu(self.fc1(state))))

        # Mean of the actions.
        mean = self.mu_layer(hidden)

        # Standard deviation: softplus keeps it positive, and the 1e-5 offset
        # keeps it away from zero so log-probabilities stay finite.
        raw_sigma = self.log_sigma_layer(hidden)
        std = F.softplus(raw_sigma) + exploration_term + 1e-5

        return mean, std

class ValueNet(nn.Module):
    """State-value network: two ReLU hidden layers followed by a scalar output."""

    def __init__(self, num_inputs):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, NUM_NEURONS)
        self.fc2 = nn.Linear(NUM_NEURONS, NUM_NEURONS)
        self.fc3 = nn.Linear(NUM_NEURONS, 1)

    def forward(self, x):
        """Return the value estimate for ``x`` (last dimension has size 1)."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        value = self.fc3(hidden)
        return value

class Agent(nn.Module):
    """One-step actor-critic agent.

    Bundles a Gaussian policy (``PolicyNet``), a state-value critic
    (``ValueNet``) and one Adam optimizer for each. It also keeps two
    counters of how many steps decreased or did not decrease the distance
    computed in ``reward_function``.
    """

    def __init__(self, num_inputs, num_actions):
        super().__init__()
        self.target = []    # the target will be provided by the environment
        self.actor = PolicyNet(num_inputs, num_actions)
        self.critic = ValueNet(num_inputs)
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=LR_ACTOR)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=LR_CRITIC)

        # Step counters: "avvicinamento" = distance decreased,
        # "allontanamento" = distance did not decrease.
        self.counter_avvicinamento = 0
        self.counter_allontanamento = 0

    def sample_action(self, state, exploration_term):
        """Sample an action from the current policy.

        Returns:
            ``(action, log_prob)``. ``action`` carries no gradient, while
            ``log_prob`` (summed over the last dimension) is differentiable
            with respect to the policy parameters.
        """
        mu, sigma = self.actor(state, exploration_term)
        dist = torch.distributions.Normal(mu, sigma)
        # The actor loss is a score-function (REINFORCE-style) estimator, so
        # the sample must be treated as a constant. With rsample() the
        # dependence of the action on mu cancels inside log_prob and the mean
        # head receives zero gradient.
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)  # sum log-probs over the action dimension
        return action, log_prob

    def reward_function(self, state, next_state):
        """Return the decrease in distance between ``state`` and ``next_state``.

        The distance is the norm of ``s[:2] - s[2:]``; per the original
        comment, the target position is stored in the state. The reward is
        positive when the distance shrank. Also updates the approach /
        retreat counters.
        """
        # Rewards are training signals, never differentiated through.
        with torch.no_grad():
            state = torch.as_tensor(state, dtype=torch.float32)
            next_state = torch.as_tensor(next_state, dtype=torch.float32)
            dist_before = torch.norm(state[:2] - state[2:])
            dist_after = torch.norm(next_state[:2] - next_state[2:])

        if dist_before > dist_after:
            self.counter_avvicinamento += 1
        else:
            self.counter_allontanamento += 1

        return dist_before - dist_after

    def get_exploration_term(self, current_step, max_steps):
        """Return an exploration bonus decaying linearly from 0.2 to 0.

        The value reaches 0.0 once ``current_step >= max_steps``. A
        non-positive ``max_steps`` disables exploration instead of dividing
        by zero.
        """
        if max_steps <= 0:
            return 0.0
        return max(0.0, 0.2 * (1 - current_step / max_steps))

    def update(self, state, action, log_prob, next_state, reward, done):
        """Perform one TD(0) update of the critic and one policy-gradient step.

        ``action`` is accepted for interface compatibility but not used:
        the policy gradient only needs ``log_prob``.
        """
        done = torch.as_tensor(done, dtype=torch.float32)

        value = self.critic(state).squeeze()
        # The bootstrap target is a constant for both losses, so no graph is
        # built for it.
        with torch.no_grad():
            next_value = self.critic(next_state).squeeze()
            target_value = reward + next_value * GAMMA * (1 - done)
        advantage = (target_value - value).detach()

        critic_loss = F.mse_loss(value, target_value)
        actor_loss = -log_prob * advantage

        self.optimizer_critic.zero_grad()
        critic_loss.backward()
        self.optimizer_critic.step()

        self.optimizer_actor.zero_grad()
        actor_loss.backward()
        self.optimizer_actor.step()


def _should_stop_early(reward_history, window, threshold=1.0):
    """Return True when the mean of the last ``window`` rewards exceeds ``threshold``."""
    if window <= 0 or len(reward_history) < window:
        return False
    recent = reward_history[-window:]
    return sum(recent) / window > threshold


def train_a2c(env=None, num_episodes=MAX_EPISODES):
    """Train an actor-critic ``Agent`` and return it.

    Args:
        env: Gymnasium-style environment. A ``TrackingEnv`` is created when
            omitted. Its observations are indexed as ``obs[:2]`` and
            ``obs[2:]``, and the distance between the two slices drives both
            the reward and episode termination.
        num_episodes: Maximum number of episodes to run.

    Returns:
        The trained ``Agent``.

    The environment is closed even if training is interrupted. After
    training, the per-episode total rewards are plotted.
    """
    if env is None:
        env = TrackingEnv()
    inputs_dim = env.observation_space.shape[0]
    actions_dim = env.action_space.shape[0]

    agent = Agent(inputs_dim, actions_dim)
    reward_history = []
    total_step = 0
    counter = 0  # number of episodes that ended within `tolerance`
    tolerance = 0.05

    try:
        for episode in range(num_episodes):
            agent.counter_avvicinamento = 0
            agent.counter_allontanamento = 0
            state, _ = env.reset()
            state = torch.tensor(state, dtype=torch.float32)
            done = False
            total_reward = 0.0
            step = 0
            actions = []

            while not done:
                step += 1
                total_step += 1
                exploration_term = agent.get_exploration_term(total_step, MAX_STEP_EXPLORATION)
                action, log_prob = agent.sample_action(state, exploration_term)
                # The environment and the statistics only need the values,
                # so no autograd graph is handed over or kept alive.
                env_action = action.detach()
                actions.append(env_action)
                next_state, _, terminated, truncated, _ = env.step(env_action)
                reward = agent.reward_function(state, next_state)
                next_state = torch.tensor(next_state, dtype=torch.float32)

                distance = torch.norm(next_state[:2] - next_state[2:])
                # A true terminal state is one the environment reports, or
                # one where the distance is above 2 or below tolerance.
                terminal = bool(terminated) or bool(distance > 2) or bool(distance < tolerance)
                # A truncation (e.g. a time limit) ends the episode too, but
                # it is not passed to update() as terminal, so the critic
                # still bootstraps from next_state.
                done = terminal or bool(truncated)

                agent.update(state, action, log_prob, next_state, reward, terminal)
                state = next_state
                total_reward += float(reward)

            # Mean over every action taken in the episode. This assumes at
            # least two action dimensions, as the original print did.
            mean_action = torch.stack(actions).mean(dim=0)
            print(f"mean action x: {mean_action[0].item()}, mean action y: {mean_action[1].item()}")

            print(f"Episode: {episode}, Step: {step}, Counter: {counter}, Avv: {agent.counter_avvicinamento} All: {agent.counter_allontanamento}, Total reward: {total_reward}, final state: {state}")
            if distance < tolerance:
                counter += 1
            # NOTE(review): with the initial tolerance of 0.05 this schedule
            # raises the tolerance to 0.1 and then holds it there. Kept as
            # in the original; confirm the intended direction.
            if counter % 100 == 0 and counter != 0:
                tolerance = round(max(0.1, tolerance - 0.1), 2)
            reward_history.append(total_reward)

            # NOTE(review): the original compared a list slice with 1, which
            # raises TypeError. This reads it as "the mean total reward of
            # the last EARLY_STOPPING_EPISODES episodes is above 1".
            if _should_stop_early(reward_history, EARLY_STOPPING_EPISODES):
                break

            if episode % 10 == 0:
                print(f"Episode {episode}, Total Reward: {total_reward}")
    finally:
        env.close()

    plt.plot(reward_history)
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.title('Training Progress Translation Agent')
    plt.show()

    return agent

# Script entry point: train with the default environment and episode count.
if __name__ == "__main__":
    trained_agent = train_a2c()

mean action x: -0.2319922298192978
Episode: 0, Step: 5226, Counter: 0, Avv: 2016 All: 3210, Total reward: -1.5767444372177124, final state: tensor([-0.6237,  2.0750,  0.3000,  0.3000])
Episode 0, Total Reward: -1.5767444372177124
mean action x: -0.313232421875
Episode: 1, Step: 5068, Counter: 0, Avv: 1872 All: 3196, Total reward: -1.57728111743927, final state: tensor([-1.2523,  1.5636,  0.3000,  0.3000])
mean action x: 0.05774411931633949
Episode: 2, Step: 1889, Counter: 0, Avv: 504 All: 1385, Total reward: -1.5758062601089478, final state: tensor([-1.4504,  1.2677,  0.3000,  0.3000])


KeyboardInterrupt: 