# In [None]:
'''This module implements a reinforcement learning environment (a charged particle moving in an electric field) together with a DDPG agent that learns in it.'''
import inspect
import os
import random
from dataclasses import *
from functools import wraps
from typing import Any, Callable, Dict, List, Tuple, Union, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Should obey Newton's laws in a homogeneous vector field
@dataclass
class EnforceClassTyping:
    '''Dataclass mixin that checks every field's value against its annotation after __init__.

    Raises TypeError naming the offending field. Fields declared on parent dataclasses
    are checked too. Annotations that isinstance() cannot handle (string forward
    references, subscripted generics such as List[int]) are skipped instead of crashing.
    Subclasses that override __post_init__ must call super().__post_init__() for the
    check to run.
    '''
    def __post_init__(self):
        # dataclasses.fields() walks the whole MRO, whereas self.__annotations__ only
        # sees the most-derived class, so inherited fields would go unchecked.
        for FieldInfo in fields(self):
            Value = getattr(self, FieldInfo.name)
            try:
                Matches = isinstance(Value, FieldInfo.type)
            except TypeError:
                # The annotation cannot be used with isinstance(); nothing to enforce.
                continue
            if not Matches:
                raise TypeError(f"The field `{FieldInfo.name}` was assigned by `{type(Value)}` instead of `{FieldInfo.type}`")
def EnforceMethodTyping(func: Callable) -> Callable:
    '''Decorator that enforces a method's parameter annotations at call time.

    Arguments are matched to parameters by name via inspect.signature, so
    unannotated parameters and the `return` annotation never shift the pairing.
    Default values are not checked. *args/**kwargs parameters and annotations that
    isinstance() cannot handle (string forward references, subscripted generics)
    are skipped. A mismatch raises TypeError before the method runs.
    '''
    arg_annotations = func.__annotations__
    if not arg_annotations:
        return func

    signature = inspect.signature(func)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    # Only parameters that carry an annotation; the annotation of *args/**kwargs
    # describes each element, not the packed tuple/dict, so those are excluded.
    checked = {
        name: param.annotation
        for name, param in signature.parameters.items()
        if param.annotation is not inspect.Parameter.empty and param.kind not in variadic
    }

    @wraps(func)
    def wrapper(self, *args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
        bound = signature.bind(self, *args, **kwargs)
        for name, value in bound.arguments.items():
            if name not in checked:
                continue
            annotation = checked[name]
            try:
                matches = isinstance(value, annotation)
            except TypeError:
                continue  # annotation is not usable with isinstance()
            if not matches:
                if name in kwargs:
                    raise TypeError(f"Expected {annotation} for keyword argument {name}, got {type(value)}.")
                raise TypeError(f"Expected {annotation} for argument {value}, got {type(value)}.")
        return func(self, *args, **kwargs)

    return wrapper
def EnforceFunctionTyping(func: Callable) -> Callable:
    '''Decorator that enforces a plain function's parameter annotations at call time.

    Arguments are matched to parameters by name via inspect.signature, so
    unannotated parameters and the `return` annotation never shift the pairing.
    Default values are not checked. *args/**kwargs parameters and annotations that
    isinstance() cannot handle are skipped. A mismatch raises TypeError.
    '''
    signature = inspect.signature(func)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    checked = {
        name: param.annotation
        for name, param in signature.parameters.items()
        if param.annotation is not inspect.Parameter.empty and param.kind not in variadic
    }

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            if name not in checked:
                continue
            annotation = checked[name]
            try:
                matches = isinstance(value, annotation)
            except TypeError:
                continue  # annotation is not usable with isinstance()
            if not matches:
                if name in kwargs:
                    # Keyword argument
                    raise TypeError(f"Expected {annotation} for {name}, got {type(value)}.")
                # Positional argument
                raise TypeError(f"Expected {annotation} for {value}, got {type(value)}.")
        return func(*args, **kwargs)

    return wrapper
 
@dataclass
class Particle(EnforceClassTyping):
    '''A point charge described by its mass, charge, position and velocity.

    Instances serve both as sources of an ElectricField and as the particle the
    agent moves around. Field types are validated when the object is created.
    '''
    Mass: float          # kg
    Charge: float        # C
    Position: T.Tensor   # m, (x, y)
    Velocity: T.Tensor   # m/s, (v_x, v_y)

# Example charged particle at rest at (1, 2). Its mass (1 kg) and charge (-1 nC) are
# illustrative values, not the physical constants of a real electron.
Electron = Particle(
    Mass=1.0,
    Charge=-1e-9,
    Position=T.tensor([1.0, 2.0]),
    Velocity=T.tensor([0.0, 0.0]),
)

@dataclass
class ElectricField:
    '''Electrostatic field generated by a collection of point charges in 2-D.

    FieldSources: list of Particle objects acting as point charges.
    FieldHighBound / FieldLowBound: scalar bounds stored with the field; this class
    does not use them itself.
    '''
    FieldSources: list
    FieldHighBound: float
    FieldLowBound: float

    def __call__(self, ObservationPosition: T.Tensor):
        '''Shorthand for ElectricFieldStrength(ObservationPosition).'''
        return self.ElectricFieldStrength(ObservationPosition)

    def ElectricFieldStrength(self, ObservationPosition: T.Tensor) -> T.Tensor:
        '''Superposed Coulomb field of all FieldSources at the given point(s).

        ObservationPosition holds the x coordinates in [0] and the y coordinates in
        [1]: shape (2,) for one point or (2, ...) for a grid of points. The result
        has the same shape and holds (E_x, E_y) in N/C. The field is infinite/NaN at
        a point that coincides with a source.
        '''
        CoulombConstant = 8.9875e9  # N*m^2/C^2
        if not all(isinstance(FieldSource, Particle) for FieldSource in self.FieldSources):
            raise TypeError("The input is not valid")
        if type(ObservationPosition[0]) != type(ObservationPosition[1]):
            raise TypeError("Incompatible Reference point data types")
        if not isinstance(ObservationPosition[0], T.Tensor):
            raise TypeError("Invalid Reference point data type")
        if ObservationPosition[0].size() != ObservationPosition[1].size():
            raise TypeError("Incompatible Reference point dimensions")

        ElectricFieldVector = T.zeros_like(ObservationPosition)
        for FieldSource in self.FieldSources:
            # Broadcast the source's (x, y) to the shape of the observation points.
            SourcePosition = T.stack([T.ones_like(ObservationPosition[0]) * FieldSource.Position[0].item(),
                                      T.ones_like(ObservationPosition[1]) * FieldSource.Position[1].item()])
            Displacement = ObservationPosition - SourcePosition
            DisplacementMagnitude = T.sqrt(Displacement[0]**2 + Displacement[1]**2)
            # Coulomb's law: E = k * q * r_vec / |r|^3
            ElectricFieldVector += (CoulombConstant * FieldSource.Charge) / DisplacementMagnitude**3 * Displacement
        return ElectricFieldVector  # N/C or V/m

    @EnforceMethodTyping
    def WorkDoneAgainstField(self, InitialPosition: T.Tensor, FinalPosition: T.Tensor, resolution: int = 5000) -> float:
        '''Line integral -∫ E·dl along the straight segment from InitialPosition to FinalPosition.

        This is the work per unit charge (J/C, i.e. a potential difference in volts)
        needed to move against the field; multiply by a charge to obtain joules.
        It is approximated with a left Riemann sum over `resolution` equal steps.

        Raises:
            ValueError: if resolution is not positive.
        '''
        if resolution <= 0:
            raise ValueError("resolution must be a positive integer")
        Start = InitialPosition[:2]
        Step = (FinalPosition[:2] - Start) / resolution
        # Left end-points of every sub-interval, evaluated in one vectorised call: shape (2, resolution).
        Fractions = T.arange(resolution, dtype=Step.dtype, device=Step.device)
        SamplePoints = Start.unsqueeze(1) + Step.unsqueeze(1) * Fractions
        FieldAtSamples = self.ElectricFieldStrength(SamplePoints)
        WorkDone = -(FieldAtSamples * Step.unsqueeze(1)).sum()
        return float(WorkDone)

    def KineticEnergy(self, Mass: float, Velocity: float) -> float:
        '''Return 0.5 * Mass * Velocity**2 (applied element-wise if Velocity is a tensor).'''
        return 0.5 * Mass * Velocity**2

    def PlotField(self, ObservationPosition: T.Tensor):
        '''Draw a quiver plot of the 2-D electric field over the observation grid.

        Arrows have unit length and show the field direction. Colour encodes
        log10 of the field magnitude |E| (N/C), because a Coulomb field spans many
        orders of magnitude near its sources.
        '''
        Ex, Ey = self.ElectricFieldStrength(ObservationPosition).detach().cpu()
        Magnitude = T.sqrt(Ex**2 + Ey**2)
        # Normalise both components by the same magnitude so every arrow has unit length.
        UnitX, UnitY = Ex / Magnitude, Ey / Magnitude
        Positions = ObservationPosition.detach().cpu()
        # Pass dpi to this figure directly; setting rcParams after creating the figure
        # would not affect it.
        fig, ax = plt.subplots(1, 1, dpi=150)
        Arrows = ax.quiver(Positions[0], Positions[1], UnitX, UnitY, T.log10(Magnitude))
        fig.colorbar(Arrows, ax=ax, label='log10 |E| (N/C)')
        plt.show()

class CriticNetwork(nn.Module):
    '''DDPG critic: maps a (state, action) pair to a scalar Q-value.

    State path: fc1 -> LayerNorm -> ReLU -> fc2 -> LayerNorm.
    Action path: action_value -> ReLU.
    The two paths are summed, passed through ReLU, then projected by `q` to one value.
    (`bn1`/`bn2` are LayerNorm layers; the names are kept so saved checkpoints still load.)
    '''
    def __init__(self, beta, input_dims, fc1_dims, fc2_dims, n_actions, name,
                 chkpt_dir='tmp/ddpg'):
        super().__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.checkpoint_file = os.path.join(chkpt_dir, name + '_ddpg')

        # Hidden layers use U(-1/sqrt(f), 1/sqrt(f)) with f the layer's *fan-in*
        # (DDPG paper); weight.size()[0] would be the fan-out.
        self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
        self._InitUniform(self.fc1, 1. / np.sqrt(self.fc1.in_features))
        self.bn1 = nn.LayerNorm(self.fc1_dims)

        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        self._InitUniform(self.fc2, 1. / np.sqrt(self.fc2.in_features))
        self.bn2 = nn.LayerNorm(self.fc2_dims)

        self.action_value = nn.Linear(self.n_actions, self.fc2_dims)
        # Small final-layer init keeps initial Q estimates near zero.
        self.q = nn.Linear(self.fc2_dims, 1)
        self._InitUniform(self.q, 0.003)

        self.optimizer = optim.Adam(self.parameters(), lr=beta)
        # Fall back to the CPU when CUDA is unavailable ('cuda:1' cannot work without CUDA).
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')

        self.to(self.device)

    @staticmethod
    def _InitUniform(layer, bound):
        '''Initialise a Linear layer's weight and bias uniformly in [-bound, bound].'''
        T.nn.init.uniform_(layer.weight.data, -bound, bound)
        T.nn.init.uniform_(layer.bias.data, -bound, bound)

    def forward(self, state, action):
        '''Return Q(state, action) with shape (..., 1).'''
        state_value = self.fc1(state)
        state_value = self.bn1(state_value)
        state_value = F.relu(state_value)
        state_value = self.fc2(state_value)
        state_value = self.bn2(state_value)

        action_value = F.relu(self.action_value(action))
        state_action_value = F.relu(T.add(state_value, action_value))
        state_action_value = self.q(state_action_value)

        return state_action_value

    def save_checkpoint(self):
        '''Write the state_dict to checkpoint_file, creating its directory if needed.'''
        print('... saving checkpoint ...')
        os.makedirs(os.path.dirname(self.checkpoint_file) or '.', exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        '''Load the state_dict from checkpoint_file onto this network's device.'''
        print('... loading checkpoint ...')
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))

class ActorNetwork(nn.Module):
    '''DDPG actor: maps a state to a deterministic action in [-1, 1]^n_actions.

    Architecture: fc1 -> LayerNorm -> ReLU -> fc2 -> LayerNorm -> ReLU -> mu -> tanh.
    (`bn1`/`bn2` are LayerNorm layers; the names are kept so saved checkpoints still load.)
    '''
    def __init__(self, alpha, input_dims, fc1_dims, fc2_dims, n_actions, name,
                 chkpt_dir='tmp/ddpg'):
        super().__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.checkpoint_file = os.path.join(chkpt_dir, name + '_ddpg')

        # Hidden layers use U(-1/sqrt(f), 1/sqrt(f)) with f the layer's *fan-in*
        # (DDPG paper); weight.size()[0] would be the fan-out.
        self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
        self._InitUniform(self.fc1, 1. / np.sqrt(self.fc1.in_features))
        self.bn1 = nn.LayerNorm(self.fc1_dims)

        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        self._InitUniform(self.fc2, 1. / np.sqrt(self.fc2.in_features))
        self.bn2 = nn.LayerNorm(self.fc2_dims)

        # Small final-layer init keeps initial actions near zero (tanh's linear region).
        self.mu = nn.Linear(self.fc2_dims, self.n_actions)
        self._InitUniform(self.mu, 0.003)

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        # Fall back to the CPU when CUDA is unavailable ('cuda:1' cannot work without CUDA).
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')

        self.to(self.device)

    @staticmethod
    def _InitUniform(layer, bound):
        '''Initialise a Linear layer's weight and bias uniformly in [-bound, bound].'''
        T.nn.init.uniform_(layer.weight.data, -bound, bound)
        T.nn.init.uniform_(layer.bias.data, -bound, bound)

    def forward(self, state):
        '''Return the action for `state`, squashed into [-1, 1] by tanh.'''
        x = self.fc1(state)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = T.tanh(self.mu(x))

        return x

    def save_checkpoint(self):
        '''Write the state_dict to checkpoint_file, creating its directory if needed.'''
        print('... saving checkpoint ...')
        os.makedirs(os.path.dirname(self.checkpoint_file) or '.', exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        '''Load the state_dict from checkpoint_file onto this network's device.'''
        print('... loading checkpoint ...')
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))

@dataclass
class ReplayBuffer(EnforceClassTyping):
    '''Fixed-capacity FIFO store of state transitions used to train the actor and critic networks.

    Each entry is [State, Action, NextState, Reward, TerminalState]. States must provide
    an Unwrap() method returning a 1-D tensor (as ParticleInField.State does). Once
    BufferSize entries are stored, each new entry evicts the oldest one.
    '''
    BufferSize: int
    Buffer: list = None

    def __post_init__(self):
        if self.Buffer is None:
            self.Buffer = []
        # Run the inherited field type check, which this override would otherwise bypass.
        super().__post_init__()
        if self.BufferSize <= 0:
            raise ValueError("BufferSize must be a positive integer")

    # State/NextState are annotated as `object`: the state class (ParticleInField.State)
    # is defined further down the file and cannot be referenced when this class is created.
    @EnforceMethodTyping
    def AddExperience(self, State: object, Action: T.Tensor, NextState: object, Reward: float, TerminalState: bool):
        '''Append one transition, evicting the oldest one when the buffer is full.'''
        if len(self.Buffer) >= self.BufferSize:
            self.Buffer.pop(0)
        self.Buffer.append([State, Action, NextState, Reward, TerminalState])

    @EnforceMethodTyping
    def SampleBuffer(self, BatchSize: int):
        '''Sample BatchSize distinct transitions uniformly at random and stack them into tensors.

        Returns:
            (StateBatch, ActionBatch, NextStateBatch, RewardsBatch, TerminalSignalsBatch).
            States are stacked Unwrap() tensors. Rewards and terminal signals have shape
            (BatchSize, 1), with terminal signals as 1.0 / 0.0.

        Raises:
            ValueError: if BatchSize is not positive or exceeds the number of stored transitions.
        '''
        if BatchSize <= 0:
            raise ValueError("BatchSize must be a positive integer")
        if len(self.Buffer) < BatchSize:
            raise ValueError('BatchSize too big')
        SampledBatch = random.sample(self.Buffer, BatchSize)
        States, Actions, NextStates, Rewards, TerminalSignals = zip(*SampledBatch)
        StateBatch = T.stack([State.Unwrap() for State in States])
        ActionBatch = T.stack(list(Actions))
        NextStateBatch = T.stack([NextState.Unwrap() for NextState in NextStates])
        RewardsBatch = T.stack([T.Tensor([Reward]) for Reward in Rewards])
        TerminalSignalsBatch = T.stack([T.Tensor([Signal]) for Signal in TerminalSignals])
        return StateBatch, ActionBatch, NextStateBatch, RewardsBatch, TerminalSignalsBatch

@dataclass
class ParticleInField(EnforceClassTyping):
    '''Reinforcement-learning environment: a charged particle moving through an electric field.

    Field:                the ElectricField that exerts a force on the particle.
    ChargedParticle:      the particle the agent controls (its Mass and Charge are used).
    Target:               the position the agent should reach.
    DistanceWeight, EnergyWeight, TerminalSignalWeight: reward weights (see RewardModel).
    CurrentState:         the current state; a random one is drawn when omitted.

    The viable learning region is taken to be the square
    [Field.FieldLowBound, Field.FieldHighBound] along both x and y.
    '''
    Field: ElectricField
    ChargedParticle: Particle

    @dataclass
    class State(EnforceClassTyping):
        '''State of the agent as it observes it: position, momentum and elapsed time.'''
        Position: T.Tensor  # m
        Momentum: T.Tensor  # kg*m/s
        Time: float  # s

        def Unwrap(self) -> T.Tensor:
            '''Concatenate position and momentum into one tensor for the networks (Time is not included).'''
            return T.cat([self.Position,
                          self.Momentum])

    Target: T.Tensor
    DistanceWeight: float
    EnergyWeight: float
    TerminalSignalWeight: float
    CurrentState: State = None

    def __post_init__(self):
        if self.CurrentState is None:
            self.CurrentState = self.RandomState()
        # Validate field types; overriding __post_init__ would otherwise skip the inherited check.
        super().__post_init__()

    @EnforceMethodTyping
    def TransitionModel(self, State: State, Action: T.Tensor, TimeStep: float) -> State:
        '''Advance `State` by TimeStep seconds under the applied force `Action` plus the electric force.

        Constant-acceleration kinematics over the step:
            a  = (Action + q * E(x)) / m
            x' = x + v * dt + a * dt^2 / 2
            v' = v + a * dt
        Action is treated as a force vector (N). NOTE(review): confirm this is the intended unit.
        Returns a new State; neither `State` nor `self.CurrentState` is modified.
        '''
        Mass = self.ChargedParticle.Mass
        Charge = self.ChargedParticle.Charge
        InitialVelocity = State.Momentum / Mass
        # Electric force q*E, evaluated at the position of the state being advanced.
        ElectricForce = Charge * self.Field(State.Position)
        Acceleration = (Action + ElectricForce) / Mass
        NewPosition = State.Position + InitialVelocity * TimeStep + 0.5 * Acceleration * TimeStep**2
        FinalVelocity = InitialVelocity + Acceleration * TimeStep
        return self.State(Position=NewPosition,
                          Momentum=FinalVelocity * Mass,
                          Time=State.Time + TimeStep)

    @EnforceMethodTyping
    def IsTerminalCondition(self, State: State) -> bool:
        '''Return True when State.Position lies outside the viable learning region.'''
        Low, High = self.Field.FieldLowBound, self.Field.FieldHighBound
        X, Y = float(State.Position[0]), float(State.Position[1])
        # Inside the region requires *both* coordinates to be within bounds.
        return not (Low <= X <= High and Low <= Y <= High)

    @EnforceMethodTyping
    def RewardModel(self, State: State, Action: T.Tensor, NextState: State, TerminalSignal: bool, Resolution: int = 5000) -> float:
        '''Reward for the transition State -> NextState.

            reward = DistanceWeight * (distance gained towards Target)
                     - EnergyWeight * (work done against the electric force, J)
                     - TerminalSignalWeight * TerminalSignal

        With positive weights, approaching the target is rewarded, while spending
        energy and leaving the viable region are penalised. Action is accepted for
        interface compatibility but not used. Resolution is the number of steps in
        the work integral.
        '''
        DistanceGainedTowardsTarget = T.norm(State.Position - self.Target) - T.norm(NextState.Position - self.Target)
        EnergyConsumed = self._WorkAgainstElectricForce(State.Position, NextState.Position, Resolution)
        Reward = (self.DistanceWeight * float(DistanceGainedTowardsTarget)
                  - self.EnergyWeight * EnergyConsumed
                  - self.TerminalSignalWeight * float(TerminalSignal))
        return float(Reward)

    def _WorkAgainstElectricForce(self, InitialPosition: T.Tensor, FinalPosition: T.Tensor, Resolution: int) -> float:
        '''Work (J) to move the particle along the straight segment against the electric force.

        W = -q * ∫ E·dl, approximated with a left Riemann sum of Resolution steps.
        Raises ValueError if Resolution is not positive.
        '''
        if Resolution <= 0:
            raise ValueError("Resolution must be a positive integer")
        Start = InitialPosition[:2]
        Step = (FinalPosition[:2] - Start) / Resolution
        Fractions = T.arange(Resolution, dtype=Step.dtype, device=Step.device)
        SamplePoints = Start.unsqueeze(1) + Step.unsqueeze(1) * Fractions  # (2, Resolution)
        LineIntegral = (self.Field(SamplePoints) * Step.unsqueeze(1)).sum()
        return -self.ChargedParticle.Charge * float(LineIntegral)

    @EnforceMethodTyping
    def RandomState(self) -> State:
        '''Draw a state with a uniformly random position in the viable region, momentum components in [0, 1) and Time 0.0.'''
        Low, High = self.Field.FieldLowBound, self.Field.FieldHighBound
        RandomPosition = T.Tensor([random.uniform(Low, High), random.uniform(Low, High)])
        RandomMomentum = T.squeeze(T.rand((1, 2)))
        return self.State(Position=RandomPosition, Momentum=RandomMomentum, Time=0.0)

    def Render(self):
        '''Placeholder: rendering is not implemented yet.'''
        pass

    def Run(self, RunDuration: float):
        '''Placeholder: running the environment for RunDuration is not implemented yet.'''
        pass

class Agent:
    '''DDPG agent: actor and critic networks, their slowly-updated target copies, and a replay buffer.

    `env` is used through the ParticleInField interface:
    RandomState(), TransitionModel(State, Action, TimeStep), IsTerminalCondition(State),
    RewardModel(State, Action, NextState, TerminalSignal) and the CurrentState attribute.
    `input_dims` is a sequence such as [4] (it is unpacked into nn.Linear). It must match
    the length of State.Unwrap().
    '''
    def __init__(self,
                 alpha,
                 beta,
                 input_dims,
                 tau,
                 env,
                 gamma=0.99,
                 n_actions=2,
                 max_size=1000000,
                 layer1_size=400,
                 layer2_size=300,
                 batch_size=64):
        self.gamma = gamma
        self.tau = tau
        # ReplayBuffer's only required field is its capacity.
        self.memory = ReplayBuffer(BufferSize=max_size)
        self.batch_size = batch_size

        self.actor = ActorNetwork(alpha, input_dims, layer1_size,
                                  layer2_size, n_actions=n_actions,
                                  name='Actor')
        self.critic = CriticNetwork(beta, input_dims, layer1_size,
                                    layer2_size, n_actions=n_actions,
                                    name='Critic')

        self.target_actor = ActorNetwork(alpha, input_dims, layer1_size,
                                         layer2_size, n_actions=n_actions,
                                         name='TargetActor')
        self.target_critic = CriticNetwork(beta, input_dims, layer1_size,
                                           layer2_size, n_actions=n_actions,
                                           name='TargetCritic')
        self.env = env
        # NOTE(review): OUActionNoise is not defined in the code shown here; confirm it is provided
        # elsewhere and that calling it returns an array of length n_actions.
        self.noise = OUActionNoise(mu=np.zeros(n_actions))

        # tau=1 makes the target networks exact copies of the online networks.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        '''Return the actor's action for `observation` plus exploration noise, as a NumPy array.'''
        self.actor.eval()
        observation = T.as_tensor(observation, dtype=T.float, device=self.actor.device)
        with T.no_grad():
            mu = self.actor(observation)
        mu_prime = mu + T.tensor(self.noise(), dtype=T.float, device=self.actor.device)
        self.actor.train()
        return mu_prime.cpu().numpy()

    def DDPGAlgorithm(self, n_episodes=1000, TimeStep=0.01, MaxSteps=1000):
        '''Train with DDPG and return the list of per-episode total rewards.

        Each episode starts from env.RandomState(). It ends when env.IsTerminalCondition
        reports a terminal state or after MaxSteps steps. TimeStep is the simulated time
        per step (s). The TimeStep and MaxSteps defaults are illustrative;
        NOTE(review): tune them to the field's scale.
        '''
        TimeStep = float(TimeStep)  # TransitionModel's type check requires a float
        score_history = []
        for _ in range(n_episodes):
            state = self.env.RandomState()
            self.env.CurrentState = state
            score = 0.0
            for _ in range(MaxSteps):
                action = T.as_tensor(self.choose_action(state.Unwrap()), dtype=T.float)
                new_state = self.env.TransitionModel(state, action, TimeStep)
                done = self.env.IsTerminalCondition(new_state)
                reward = self.env.RewardModel(state, action, new_state, done)
                self.remember(state, action, reward, new_state, done)
                self.learn()
                score += reward
                state = new_state
                self.env.CurrentState = state
                if done:
                    break
            score_history.append(score)
        return score_history

    def remember(self, state, action, reward, new_state, done):
        '''Store one transition in the replay buffer (action as a float tensor, reward as float, done as bool).'''
        self.memory.AddExperience(state, T.as_tensor(action, dtype=T.float), new_state,
                                  float(reward), bool(done))

    def learn(self):
        '''Perform one DDPG update of the critic, actor and target networks.

        Does nothing until the buffer holds at least batch_size transitions.
        '''
        if len(self.memory.Buffer) < self.batch_size:
            return
        state, action, new_state, reward, done = self.memory.SampleBuffer(self.batch_size)

        device = self.critic.device
        state = state.to(device=device, dtype=T.float)
        action = action.to(device=device, dtype=T.float)
        new_state = new_state.to(device=device, dtype=T.float)
        reward = reward.to(device=device, dtype=T.float)  # (batch, 1)
        done = done.to(device=device, dtype=T.float)      # (batch, 1), 1.0 at terminal transitions

        self.target_actor.eval()
        self.target_critic.eval()
        with T.no_grad():
            target_actions = self.target_actor(new_state)
            critic_value_ = self.target_critic(new_state, target_actions)
            # Bellman target; bootstrapping is cut off at terminal transitions.
            target = reward + self.gamma * critic_value_ * (1.0 - done)

        self.critic.train()
        self.critic.optimizer.zero_grad()
        critic_value = self.critic(state, action)
        critic_loss = F.mse_loss(critic_value, target)
        critic_loss.backward()
        self.critic.optimizer.step()

        self.critic.eval()
        self.actor.train()
        self.actor.optimizer.zero_grad()
        mu = self.actor(state)
        actor_loss = -self.critic(state, mu).mean()
        actor_loss.backward()
        self.actor.optimizer.step()

        self.update_network_parameters()

    def update_network_parameters(self, tau=None):
        '''Soft-update the target networks: target <- tau * online + (1 - tau) * target.

        tau defaults to self.tau; tau=1 copies the online weights exactly.
        '''
        if tau is None:
            tau = self.tau

        with T.no_grad():
            for online, target in ((self.critic, self.target_critic), (self.actor, self.target_actor)):
                for param, target_param in zip(online.parameters(), target.parameters()):
                    target_param.copy_(tau * param + (1 - tau) * target_param)

    def save_models(self):
        '''Save checkpoints of all four networks.'''
        self.actor.save_checkpoint()
        self.target_actor.save_checkpoint()
        self.critic.save_checkpoint()
        self.target_critic.save_checkpoint()

    def load_models(self):
        '''Load checkpoints of all four networks.'''
        self.actor.load_checkpoint()
        self.target_actor.load_checkpoint()
        self.critic.load_checkpoint()
        self.target_critic.load_checkpoint()

    def check_actor_params(self):
        '''Debug helper: compare the networks against stored snapshots, then wait for Enter.

        For each parameter, prints whether the actor/critic still equal self.original_actor /
        self.original_critic. This class never sets those snapshot attributes; assign them
        before calling.

        Raises:
            AttributeError: if the snapshot attributes are missing.
        '''
        if not (hasattr(self, 'original_actor') and hasattr(self, 'original_critic')):
            raise AttributeError("check_actor_params requires self.original_actor and self.original_critic to be set")
        current_actor_dict = dict(self.actor.named_parameters())
        original_actor_dict = dict(self.original_actor.named_parameters())
        original_critic_dict = dict(self.original_critic.named_parameters())
        current_critic_dict = dict(self.critic.named_parameters())
        print('Checking Actor parameters')

        for param in current_actor_dict:
            print(param, T.equal(original_actor_dict[param], current_actor_dict[param]))
        print('Checking critic parameters')
        for param in current_critic_dict:
            print(param, T.equal(original_critic_dict[param], current_critic_dict[param]))
        input()
   