# In [None]:
# An agent trying to move through a field using the least amount of energy
'''This module contains a detailed implementation of the Deep Deterministic Policy Gradient (DDPG) algorithm, a model-free off-policy actor-critic reinforcement learning algorithm using pytorch's neural network tools.'''
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import *
import torch
import torch.nn as nn
import torch.autograd
import torch.optim as optim
import torch.nn.functional as F 
import torch.autograd
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
from functools import wraps
import random

@dataclass
class EnforceClassTyping:
    '''Dataclass base that type-checks the instance's fields right after construction.

    Every name found in `self.__annotations__` is read from the instance `__dict__` and
    compared against its annotation with `isinstance`.

    Raises:
        TypeError: for the first field whose value is not an instance of its annotation.
    '''
    def __post_init__(self):
        instance_values = vars(self)
        for name, field_type in self.__annotations__.items():
            value = instance_values[name]
            if isinstance(value, field_type):
                continue
            current_type = type(value)
            raise TypeError(f"The field `{name}` was assigned by `{current_type}` instead of `{field_type}`")
def EnforceMethodTyping(func: Callable) -> Callable:
    '''Enforces type annotation/hints for class methods.

    The decorated method's arguments are checked with `isinstance` against the annotation
    of the parameter they are bound to. Parameters without an annotation are not checked,
    and the `return` annotation is never used for argument checking.

    Args:
        func: the method to wrap; its first parameter is treated as the instance.

    Returns:
        `func` itself when it has no annotated parameters, otherwise a checking wrapper.

    Raises:
        TypeError: (from the wrapper) when an argument is not an instance of its annotation.
    '''
    import inspect  # function-scope import so the decorator stays self-contained

    # The 'return' entry describes the result, not an argument, so it must not be checked.
    arg_annotations = {name: hint for name, hint in func.__annotations__.items() if name != 'return'}
    if not arg_annotations:
        return func

    # Positional arguments are matched to parameters by NAME (skipping the instance),
    # so an unannotated parameter cannot shift the annotations onto the wrong argument.
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    positional_names = [
        parameter.name
        for parameter in list(inspect.signature(func).parameters.values())[1:]
        if parameter.kind in positional_kinds
    ]

    @wraps(func)
    def wrapper(self, *args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
        for arg_name, arg in zip(positional_names, args):
            if arg_name in arg_annotations:
                annotation = arg_annotations[arg_name]
                if not isinstance(arg, annotation):
                    raise TypeError(f"Expected {annotation} for argument {arg}, got {type(arg)}.")

        for arg_name, arg_value in kwargs.items():
            if arg_name in arg_annotations:
                annotation = arg_annotations[arg_name]
                if not isinstance(arg_value, annotation):
                    raise TypeError(f"Expected {annotation} for keyword argument {arg_name}, got {type(arg_value)}.")

        return func(self, *args, **kwargs)

    return wrapper
def EnforceFunctionTyping(func: Callable) -> Callable:
    '''Enforces type annotation/hints for other functions.

    Each argument is checked with `isinstance` against the annotation of the parameter it
    is bound to. Parameters without an annotation are not checked, and the `return`
    annotation is never used for argument checking.

    Args:
        func: the plain function to wrap.

    Returns:
        A wrapper that validates the arguments and then calls `func`.

    Raises:
        TypeError: (from the wrapper) when an argument is not an instance of its annotation.
    '''
    import inspect  # function-scope import so the decorator stays self-contained

    # The 'return' entry describes the result, not an argument, so it must not be checked.
    arg_annotations = {name: hint for name, hint in func.__annotations__.items() if name != 'return'}

    # Positional arguments are matched to parameters by NAME, so an unannotated
    # parameter cannot shift the annotations onto the wrong argument.
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    positional_names = [
        parameter.name
        for parameter in inspect.signature(func).parameters.values()
        if parameter.kind in positional_kinds
    ]

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Check positional arguments
        for arg_name, arg in zip(positional_names, args):
            if arg_name in arg_annotations:
                annotation = arg_annotations[arg_name]
                if not isinstance(arg, annotation):
                    raise TypeError(f"Expected {annotation} for {arg}, got {type(arg)}.")

        # Check keyword arguments
        for arg_name, arg_value in kwargs.items():
            if arg_name in arg_annotations:
                annotation = arg_annotations[arg_name]
                if not isinstance(arg_value, annotation):
                    raise TypeError(f"Expected {annotation} for {arg_name}, got {type(arg_value)}.")

        return func(*args, **kwargs)

    return wrapper

def tanh_derivative(x: torch.Tensor|int|float)-> torch.Tensor|int|float:
    return 1-x**2
def neg_relu(x):
    '''Return the non-positive part of `x`, i.e. `min(0, x)`.

    Tensors are handled element-wise (each entry is capped at 0) and a tensor is
    returned; the built-in `min` cannot compare a multi-element tensor against 0.
    Plain numbers go through the built-in `min`.
    '''
    if isinstance(x, torch.Tensor):
        return torch.clamp(x, max=0)
    return min(0, x)
def f(x):
    '''Identity function: hand back the argument unchanged.'''
    unchanged = x
    return unchanged
def f_grad(x):
    '''Gradient of the identity function: a tensor of ones shaped like `x`.'''
    unit_gradient = torch.ones_like(x)
    return unit_gradient
def relu_derivative(tensor):
    '''Element-wise derivative of ReLU: 1.0 where `tensor > 0`, otherwise 0.0.'''
    is_positive = tensor > 0
    slope_on = torch.tensor(1.0)
    slope_off = torch.tensor(0.0)
    return torch.where(is_positive, slope_on, slope_off)
def mse_grad(x, y):
    '''Return `-2 * (x - y) / len(x)`.

    For 1-D inputs this is the gradient of `mean((x - y)**2)` with respect to `y`.
    The divisor is `len(x)`, i.e. the size of the first dimension of `x`.
    '''
    residual = x - y
    count = len(x)
    return (-2 * residual) / count

@dataclass
class Source(EnforceClassTyping):
    '''An electric field source.

    Position is where the source sits in the field and Charge is the magnitude of the source.
    No __post_init__ is defined here, so the field types are validated on construction by
    the one inherited from EnforceClassTyping.'''
    Position: torch.Tensor # m; read as Position[0] and Position[1] by ElectricField
    Charge: float #C

@EnforceFunctionTyping
def ElectricField(FieldSources: list, ObservationPosition: torch.Tensor)->torch.Tensor:
    '''Compute the electric field produced by a list of sources at the given point(s).

    This determines the physics of the field (an electric field in this case).

    Args:
        FieldSources: list whose elements must all be exactly of type Source.
        ObservationPosition: tensor whose row 0 and row 1 are read as the two position
            components; the rows may be scalars or arrays of points.

    Returns:
        A tensor shaped like ObservationPosition holding the summed field of all sources (N/C or V/m).

    Raises:
        TypeError: if a source is not a Source, or if the two position rows fail the
            type/size checks below.
    '''
    coulomb_constant = 8.9875e9 #N*m^2/C^2
    if any(type(candidate) is not Source for candidate in FieldSources):
        raise TypeError("The input is not valid")

    first_row, second_row = ObservationPosition[0], ObservationPosition[1]
    if type(first_row) != type(second_row):
        raise TypeError("Incompatible Reference point data types")
    if type(first_row) != torch.Tensor:
        raise TypeError("Invalid Reference point data type")
    if first_row.size() != second_row.size():
        raise TypeError("Incompatible Reference point dimensions")

    total_field = torch.zeros_like(ObservationPosition)
    for source in FieldSources:
        # Broadcast the source coordinates to the shape of the observation rows.
        source_first = torch.ones_like(first_row) * source.Position[0].item()
        source_second = torch.ones_like(second_row) * source.Position[1].item()
        offset = ObservationPosition - torch.stack([source_first, source_second])
        distance = torch.sqrt(offset[0]**2 + offset[1]**2)  # Magnitude of the displacement vector
        # Coulomb's law: k*q * r_vec / |r|^3, accumulated over all sources.
        total_field += (coulomb_constant * source.Charge) / distance**3 * offset
    return total_field #N/C or V/m

@EnforceFunctionTyping
def PlotField(Sources: list, ObservationPosition: torch.Tensor):
    '''Plot the 2D electric vector field as a quiver plot.

    Arrows show the field direction (unit length) and their colour shows the field
    magnitude returned by ElectricField.

    Args:
        Sources: list of field sources, passed straight to ElectricField.
        ObservationPosition: tensor whose row 0 / row 1 hold the arrow coordinates.
    '''
    xd, yd = ElectricField(Sources, ObservationPosition)
    # The magnitude is computed ONCE from the raw components. Normalising xd first and then
    # reusing it to normalise yd would divide yd by the wrong length.
    magnitude = torch.sqrt(xd**2 + yd**2)
    xd = xd / magnitude
    yd = yd / magnitude
    # Set the dpi before the figure exists, otherwise it does not apply to this figure.
    plt.rcParams['figure.dpi'] = 150
    fig, ax = plt.subplots(1,1)
    # After normalisation every arrow has length 1, so the colour carries the raw magnitude.
    cp = ax.quiver(ObservationPosition[0],ObservationPosition[1],xd,yd,magnitude)
    fig.colorbar(cp)
    plt.show()

@dataclass
class State(EnforceClassTyping):
    '''The observable state of the Agent: its Position, its Momentum and the Field Strength it experiences at its Position.

    These are the parameters the agent is able to observe; together they define the state of the agent.'''
    Position: torch.Tensor # m
    FieldStrength: torch.Tensor #N/C or V/m
    Momentum: torch.Tensor #kg*m/s
    def Unwrap(self)->torch.Tensor:
        '''Concatenate Position, FieldStrength and Momentum (in that order) into one tensor for processing.'''
        components = (self.Position, self.FieldStrength, self.Momentum)
        return torch.cat(components)

class CriticNetwork(nn.Module):
    '''Three-layer fully connected network that scores a (state, action) pair.

    The state and action batches are concatenated along dimension 1, passed through two
    ReLU-activated hidden layers and a final linear layer with no activation.
    '''
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)

    def forward(self, state, action):
        '''Return the network output for batched `state` and `action` (joined on dim 1).'''
        joint_input = torch.cat((state, action), dim=1)
        hidden = F.relu(self.linear1(joint_input))
        hidden = F.relu(self.linear2(hidden))
        return self.linear3(hidden)

class ActorNetwork(nn.Module):
    '''Three-layer fully connected policy network.

    Two ReLU-activated hidden layers are followed by a linear layer squashed with tanh,
    so every output component lies in [-1, 1].
    '''
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)

    def forward(self, state):
        '''Return the tanh-squashed network output for `state`.'''
        hidden = F.relu(self.linear1(state))
        hidden = F.relu(self.linear2(hidden))
        return torch.tanh(self.linear3(hidden))

@dataclass
class ReplayBuffer(EnforceClassTyping):
    '''The replay buffer: a bounded FIFO store of state transitions
    (State, Action, NextState, Reward, Terminal Signal) used to train the Actor and Critic Networks.

    BufferSize is the maximum number of transitions kept; once it is reached the oldest
    transition is dropped for every new one. Buffer holds the transitions as
    [State, Action, NextState, Reward, TerminalState] lists and starts empty when not supplied.'''
    BufferSize: int
    Buffer: list = None
    def __post_init__(self):
        # A fresh list per instance; None is only the "not supplied" marker.
        if self.Buffer is None:
            self.Buffer = []
    @EnforceMethodTyping
    def AddExperience(self, State: State, Action: torch.Tensor, NextState: State, Reward: float, TerminalState: bool):
        '''This method adds a state transition to the replay buffer, evicting the oldest one when full.

        The action is stored detached from autograd: a stored transition is training data,
        and keeping its graph would retain the network activations that produced it.'''
        if len(self.Buffer) >= self.BufferSize:
            self.Buffer.pop(0)
        self.Buffer.append([State, Action.detach(), NextState, Reward, TerminalState])
    @EnforceMethodTyping
    def SampleBuffer(self, BatchSize: int):
        '''This method randomly samples the replay buffer (without replacement) and outputs batches of state transition variables.

        Returns:
            (StateBatch, ActionBatch, NextStateBatch, RewardsBatch, TerminalSignalsBatch), each
            stacked along a new first dimension of length BatchSize. Rewards and terminal signals
            are wrapped as one-element tensors before stacking. All batches are detached from autograd.

        Raises:
            ValueError: if the buffer holds fewer than BatchSize transitions.'''
        if len(self.Buffer) < BatchSize:
            raise ValueError('BatchSize too big')
        SampledBatch = random.sample(self.Buffer, BatchSize)
        # States may still carry the graph of the action that produced them, so they are
        # detached here; otherwise a later backward pass would run through stale graphs.
        StateBatch = torch.stack([Transition[0].Unwrap().detach() for Transition in SampledBatch])
        ActionBatch = torch.stack([Transition[1].detach() for Transition in SampledBatch])
        NextStateBatch = torch.stack([Transition[2].Unwrap().detach() for Transition in SampledBatch])
        RewardsBatch = torch.stack([torch.Tensor([Transition[3]]) for Transition in SampledBatch])
        TerminalSignalsBatch = torch.stack([torch.Tensor([Transition[4]]) for Transition in SampledBatch])
        return StateBatch, ActionBatch, NextStateBatch, RewardsBatch, TerminalSignalsBatch

@dataclass
class Agent(EnforceClassTyping):
    '''This class represents the agent which will interact with the environment to create state transitions which it will use to learn a good policy and value function.

    The Mass and Charge parameters determine how the agent interacts with its environment.
    The LearningRate, SoftUpdateRate, DiscountRate and MemorySize parameters determine its learning behaviour.
    CurrentState is the state the agent is currently in.

    The networks, their targets, the replay memory, the two Adam optimizers and the MSE
    criterion are created in __post_init__. Note that defining __post_init__ here replaces
    the field type check inherited from EnforceClassTyping.'''
    Charge: float
    Mass: float
    LearningRate: float
    SoftUpdateRate: float
    DiscountRate: float
    MemorySize: int
    CurrentState: State 
    Memory: ReplayBuffer = field(init=False) 
    ActorModel: ActorNetwork = field(init=False)
    CriticModel: CriticNetwork = field(init=False)
    ActorTargetModel: ActorNetwork = field(init=False)
    CriticTargetModel: CriticNetwork = field(init=False)
    def __post_init__(self):
        self.Memory= ReplayBuffer(self.MemorySize)
        # Actor: 6 state components in, 2 action components out.
        self.ActorModel= ActorNetwork(6, 10, 2)
        # Targets must be SEPARATE networks that start as copies; aliasing the main
        # network would make every soft update a no-op and the target never lag.
        self.ActorTargetModel= ActorNetwork(6, 10, 2)
        self.ActorTargetModel.load_state_dict(self.ActorModel.state_dict())
        # Critic: 6 state + 2 action components in, one scalar Q-value out.
        self.CriticModel= CriticNetwork(8, 10, 1)
        self.CriticTargetModel= CriticNetwork(8, 10, 1)
        self.CriticTargetModel.load_state_dict(self.CriticModel.state_dict())
        self.actor_optimizer  = optim.Adam(self.ActorModel.parameters(), lr=self.LearningRate)
        self.critic_optimizer = optim.Adam(self.CriticModel.parameters(), lr=self.LearningRate)
        self.critic_criterion= torch.nn.MSELoss()
    def ForceGenerator(self, Action: torch.Tensor)-> torch.Tensor:
        '''Scale an action by a fixed factor of 20 to obtain the force vector.'''
        ForceVector= Action* 20
        return ForceVector
    @EnforceMethodTyping
    def UpdateCritic(self, StateBatch: torch.Tensor, ActionBatch: torch.Tensor, NextStateBatch: torch.Tensor, RewardBatch: torch.Tensor, TerminalSignalsbatch: torch.Tensor):
        'Updates the main critic network parameters by minimizing the difference between the bellman optimal expected return and the expected return predicted by the main critic network'
        # Batches are plain training data: cut any graph they may still carry so this
        # backward pass only touches the critic.
        StateBatch = StateBatch.detach()
        ActionBatch = ActionBatch.detach()
        NextStateBatch = NextStateBatch.detach()
        # The Bellman target is a fixed regression target; no gradient may flow into
        # the target networks through it.
        with torch.no_grad():
            NextAction= self.ActorTargetModel.forward(NextStateBatch)
            BellmanOptimalReturn= RewardBatch+ (1-TerminalSignalsbatch)*self.DiscountRate*self.CriticTargetModel.forward(NextStateBatch, NextAction)
        ExpectedReturn= self.CriticModel.forward(StateBatch, ActionBatch)
        critic_loss= self.critic_criterion(ExpectedReturn, BellmanOptimalReturn)

        # Also clears critic gradients accumulated by the previous actor update.
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
    @EnforceMethodTyping
    def UpdateActor(self, StateBatch: torch.Tensor):
        'Updates the main actor network parameters by maximizing the Expected Q-value predicted by the main critic network'
        StateBatch = StateBatch.detach()
        policy_loss = -self.CriticModel.forward(StateBatch, self.ActorModel.forward(StateBatch)).mean()

        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()
    @EnforceMethodTyping
    def UpdateTargetCritic(self):
        'Updates the target critic network parameters by making it lag behind the main critic network updates'
        with torch.no_grad():
            for target_param, param in zip(self.CriticTargetModel.parameters(), self.CriticModel.parameters()):
                target_param.copy_(param * self.SoftUpdateRate + target_param * (1.0 - self.SoftUpdateRate))
    @EnforceMethodTyping 
    def UpdateTargetActor(self):
        'Updates the target actor network parameters by making it lag behind the main actor network updates'
        with torch.no_grad():
            for target_param, param in zip(self.ActorTargetModel.parameters(), self.ActorModel.parameters()):
                target_param.copy_(param * self.SoftUpdateRate + target_param * (1.0 - self.SoftUpdateRate))

@dataclass
class Environment(EnforceClassTyping):
    '''This class represents the environment(i.e. the Space and Physics) the agent will learn from. 
    
    The UppperBoundX, LowerBoundX, UpperBoundY, and LowerBoundY determine the dimensions of the viable learning region of the environment.
    The FieldType determines the physics/dynamics of the environment; it is called as FieldType(FieldSources, Position).
    The FieldSources shape the field '''
    UppperBoundX: float
    LowerBoundX: float
    UpperBoundY: float
    LowerBoundY: float
    FieldSources: list
    FieldType: Callable #[[list, torch.Tensor], torch.Tensor]
    def KineticEnergy(self, Mass: float, Velocity: float)-> float:
        '''Return the kinetic energy 0.5 * Mass * Velocity**2.'''
        return 0.5* Mass* Velocity**2
    @EnforceMethodTyping
    def ForceFieldStrength(self, Position: torch.Tensor)-> torch.Tensor:
        '''This method determines the field strength at any given position based on the field type and field sources'''
        return self.FieldType(self.FieldSources, Position)
    @EnforceMethodTyping
    def WorkDoneAgainstField(self, InitialPosition: torch.Tensor, FinalPosition: torch.Tensor, resolution: int= 5000)-> float:
        '''This method determines the amount of work required to get from one position to another in the field.

        The straight segment between the two positions is split into `resolution` equal steps
        and -(field . step) is summed using the field at the start of each step.

        Returns:
            The accumulated sum. NOTE: this is a 0-d tensor (sum of tensor terms), not a Python float.

        Raises:
            ValueError: if resolution is not positive (the step size would be undefined).'''
        if resolution <= 0:
            raise ValueError('resolution must be a positive integer')
        XInterval= (FinalPosition[0] - InitialPosition[0]) / resolution
        YInterval= (FinalPosition[1] - InitialPosition[1]) / resolution
        WorkDone = 0
        for i in range(resolution):
            # Sample point i is computed on the fly instead of materialising full position lists.
            SamplePosition = torch.Tensor([InitialPosition[0] + i * XInterval, InitialPosition[1] + i * YInterval])
            PositionFieldStrength = self.ForceFieldStrength(SamplePosition)
            WorkDone += - (PositionFieldStrength[0]*XInterval + PositionFieldStrength[1]*YInterval)
        return WorkDone
    @EnforceMethodTyping
    def TransitionModel(self, LearningAgent: Agent, CurrentState: State, Action: torch.Tensor, TimeStep:float)-> State:
        '''This function determines how the state of the agent changes after a given period given the agents state and parameters.

        Constant acceleration is assumed over TimeStep:
            v1 = v0 + a*t,   x1 = x0 + v0*t + a*t**2/2
        The field strength and momentum of the new state are evaluated at x1 and v1.'''
        InitialVelocity= CurrentState.Momentum/LearningAgent.Mass
        # NOTE(review): the field force (FieldStrength*Charge) is subtracted from the action;
        # confirm this sign convention against the intended physics.
        Acceleration= (Action- CurrentState.FieldStrength*LearningAgent.Charge)/LearningAgent.Mass
        FinalVelocity= InitialVelocity+ Acceleration*TimeStep
        # The displacement is added to the CURRENT position, and the acceleration term
        # carries the same sign as in the velocity update above.
        NewPosition= CurrentState.Position+ InitialVelocity*TimeStep+ (Acceleration*TimeStep**2)/2
        NewFieldForce= self.ForceFieldStrength(NewPosition)
        ResultantMomemntum= FinalVelocity*LearningAgent.Mass
        NewState= State(NewPosition, NewFieldForce, ResultantMomemntum)
        return NewState
    @EnforceMethodTyping
    def RewardModel(self, CurrentState: State, Action: torch.Tensor, NextState: State, Target: torch.Tensor, TerminalSignal: bool, DistanceSignificance: float, EnergySignificance: float, TerminalSignalSignificance: float, Resolution: int= 5000)-> float:
        '''This method determines how the agent is rewarded given a state transition. The reward determines the behaviour the agent should learn(i.e getting to the target and using the least amount of energy).

        The reward is the negative of a weighted cost:
            cost = DistanceSignificance * (increase in distance to Target)
                 + EnergySignificance * (work done against the field along the move)
                 + TerminalSignalSignificance * TerminalSignal
        Action is accepted for interface symmetry but is not used in the computation.'''
        # Positive when the move ends FARTHER from the target, so that with a positive
        # DistanceSignificance approaching the target lowers the cost (raises the reward).
        DistanceGainedFromTarget= torch.norm(NextState.Position-Target)- torch.norm(CurrentState.Position-Target)
        EnergyConsumed= self.WorkDoneAgainstField(CurrentState.Position, NextState.Position, Resolution)
        Cost= DistanceSignificance* DistanceGainedFromTarget+ EnergySignificance* EnergyConsumed+ TerminalSignalSignificance* TerminalSignal
        return -Cost.item()
    @EnforceMethodTyping
    def IsTerminalCondition(self, Position: torch.Tensor)-> bool:
        '''This method returns True when a position lies OUTSIDE the viable learning region of the environment.

        A position is inside only if BOTH coordinates are within their bounds (inclusive).'''
        InsideX = bool(self.LowerBoundX <= Position[0] <= self.UppperBoundX)
        InsideY = bool(self.LowerBoundY <= Position[1] <= self.UpperBoundY)
        return not (InsideX and InsideY)
    @EnforceMethodTyping
    def RandomState(self)->State:
        '''This method generates a random state within the viable learning region.

        The position is uniform within the bounds, the field strength is evaluated at that
        position and each momentum component is drawn from torch.rand (uniform on [0, 1)).'''
        RandomPosition= torch.Tensor([random.uniform(self.LowerBoundX, self.UppperBoundX), random.uniform(self.LowerBoundY, self.UpperBoundY)])
        RandomFieldStrength= self.ForceFieldStrength(RandomPosition)
        RandomMomentum= torch.squeeze(torch.rand((1, 2)))
        return State(RandomPosition, RandomFieldStrength, RandomMomentum)

@dataclass
class DDPG(EnforceClassTyping):
    '''Training driver that lets LearningAgent learn from AgentEnvironment with the DDPG algorithm.

    Target is the position the agent should reach. NumberOfEpisodes and EpisodeDuration set the
    length of training, BatchSize the number of transitions sampled per update and TimeStep the
    simulated duration of one step. alpha, beta and gamma are passed to the environment's
    RewardModel as the distance, energy and terminal-signal significances respectively.'''
    LearningAgent: Agent
    AgentEnvironment: Environment
    Target: torch.Tensor
    NumberOfEpisodes: int
    EpisodeDuration: int
    BatchSize: int
    alpha: float
    beta: float
    gamma: float
    TimeStep: float
    def __post_init__(self):
        if self.TimeStep< 0:
            raise ValueError('Time step cant be Negative')
    def _Step(self, Explore: bool):
        '''Simulate one transition from the agent's current state without changing that state.

        The actor output (optionally perturbed and clipped back to the actor's tanh range
        [-1, 1]) is the action that gets returned; the environment is driven by the force
        ForceGenerator derives from it. Everything runs under no_grad because a transition
        is training data and must not carry an autograd graph.'''
        LearningAgent = self.LearningAgent
        with torch.no_grad():
            Action = LearningAgent.ActorModel.forward(LearningAgent.CurrentState.Unwrap())
            if Explore:
                Action = torch.clamp(self.ActionNoiseGenerator(Action), -1.0, 1.0)
            Force = LearningAgent.ForceGenerator(Action)
            NewState = self.AgentEnvironment.TransitionModel(LearningAgent,
                                                             LearningAgent.CurrentState,
                                                             Force,
                                                             self.TimeStep)
            TerminalSignal = self.AgentEnvironment.IsTerminalCondition(NewState.Position)
            Reward = self.AgentEnvironment.RewardModel(LearningAgent.CurrentState,
                                                       Action,
                                                       NewState,
                                                       self.Target,
                                                       TerminalSignal,
                                                       self.alpha,
                                                       self.beta,
                                                       self.gamma)
        return LearningAgent.CurrentState, Action, NewState, Reward, TerminalSignal
    def CreateExperience(self):
        '''Return one noise-free transition (State, Action, NewState, Reward, TerminalSignal) from the agent's current state.

        The action is turned into a force with ForceGenerator before it is applied, the same
        way TrainModel does.'''
        return self._Step(Explore=False)
    def ActionNoiseGenerator(self, Action: torch.Tensor, theta:float= 0.5, Mean: float= 0, Sigma: float= 0.2)-> torch.Tensor:
        '''Perturb an action with one Ornstein-Uhlenbeck style step of size TimeStep.

        The step is theta*(Mean - Action)*dt + Sigma*sqrt(dt)*N(0, 1). The random term is scaled
        by Sigma, so exploration noise is present with the default arguments. The step is applied
        directly to the action; no separate noise state is kept between calls.'''
        Drift = theta*(Mean - Action)*self.TimeStep
        Diffusion = Sigma*(self.TimeStep**0.5)*torch.randn_like(Action)
        NoisyAction = Action+ Drift+ Diffusion
        return NoisyAction
    def TrainModel(self):
        '''This method runs the DDPG algorithm by letting it learn from the environment over the episodes.

        Every episode restarts from the state the agent had when training began and ends after
        EpisodeDuration steps or as soon as a terminal state is reached. The per-step rewards of
        each episode are plotted when the episode ends.'''
        LearningAgent = self.LearningAgent
        InitialState= LearningAgent.CurrentState
        for _ in range(self.NumberOfEpisodes):
            LearningAgent.CurrentState = InitialState
            ReturnValues= []
            for _ in range(self.EpisodeDuration):
                CurrentState, Action, NewState, Reward, TerminalSignal = self._Step(Explore=True)
                # Store the action that was actually executed, on the actor's output scale,
                # which is the scale the critic sees during the actor update.
                LearningAgent.Memory.AddExperience(CurrentState, Action, NewState, Reward, TerminalSignal)
                # SampleBuffer accepts a buffer holding exactly BatchSize transitions.
                if len(LearningAgent.Memory.Buffer) >= self.BatchSize:
                    StateBatch, ActionBatch, NextStateBatch, RewardBatch, TerminalSignalsbatch= LearningAgent.Memory.SampleBuffer(self.BatchSize)
                    LearningAgent.UpdateCritic(StateBatch, ActionBatch, NextStateBatch, RewardBatch, TerminalSignalsbatch)
                    LearningAgent.UpdateActor(StateBatch)   
                    LearningAgent.UpdateTargetCritic() 
                    LearningAgent.UpdateTargetActor() 
                LearningAgent.CurrentState = NewState
                ReturnValues.append(Reward)
                if TerminalSignal:
                    # The agent left the viable region: the episode is over.
                    break
            plt.plot(ReturnValues)
            plt.show()

# Demo run: two opposite point charges on the x-axis and one agent trained with DDPG.
# Guarded so that importing this module only defines things and does not start training.
if __name__ == "__main__":
    Charge1= Source(torch.tensor([-1, 0]), -1e-9)
    Charge2= Source(torch.tensor([1, 0]), 1e-9)
    ChargeSources= [Charge1, Charge2]
    TestEnvironment= Environment(UppperBoundX=25.0,
                                 LowerBoundX=-25.0,
                                 UpperBoundY=25.0,
                                 LowerBoundY=-25.0,
                                 FieldSources=ChargeSources,
                                 FieldType=ElectricField)
    TestState= TestEnvironment.RandomState()
    TestNextState= TestEnvironment.RandomState()
    TestAgent= Agent(Charge=2.0,
                     Mass=2.0,
                     LearningRate=0.2,
                     SoftUpdateRate=0.2,
                     DiscountRate=0.2,
                     MemorySize=64,
                     CurrentState=TestState)
    # RewardModel returns the NEGATIVE of the weighted cost, so gamma must be positive for
    # reaching a terminal (out-of-region) state to lower the reward instead of raising it.
    TestDDPG= DDPG(LearningAgent=TestAgent,
                   AgentEnvironment=TestEnvironment,
                   Target=torch.Tensor([-10, 10]),
                   NumberOfEpisodes=10,
                   EpisodeDuration=20,
                   BatchSize=10,
                   alpha=10.5,
                   beta=10.2,
                   gamma=100.0,
                   TimeStep=0.5)
    TestDDPG.TrainModel()