In [None]:
'''This repository contains a detailed implementation of the Reinforcement Learning Replay Buffer'''
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import *
import torch 
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
from functools import wraps
import random


In [None]:
@dataclass
class EnforceClassTyping:
    """Dataclass mixin that validates field types right after construction.

    Subclasses must themselves be decorated with ``@dataclass``. Every
    dataclass field, including those inherited from parent dataclasses, is
    checked with ``isinstance`` against its annotation.

    Raises:
        TypeError: If a field value is not an instance of its annotated type.
    """
    def __post_init__(self):
        # fields() walks the whole dataclass hierarchy, whereas the class
        # __annotations__ mapping only lists the names annotated on the
        # most-derived class and would leave inherited fields unchecked.
        for field_info in fields(self):
            value = getattr(self, field_info.name)
            if not isinstance(value, field_info.type):
                current_type = type(value)
                raise TypeError(f"The field `{field_info.name}` was assigned by `{current_type}` instead of `{field_info.type}`")
def EnforceMethodTyping(func: Callable) -> Callable:
    """Enforce the type annotations of a method's arguments at call time.

    Each supplied argument is matched to its parameter by name through the
    method's signature, so unannotated parameters (including ``self``) and
    the return annotation never shift the comparison. Arguments left to their
    defaults are not checked, and the return value is never checked.

    Args:
        func: The method to wrap; its first parameter receives the instance.

    Returns:
        ``func`` itself when it annotates no parameters, otherwise a wrapper
        that validates the arguments before delegating to ``func``.

    Raises:
        TypeError: From the wrapper, when an argument is not an instance of
            its annotation or the arguments do not fit the signature.
    """
    import inspect  # only needed by this decorator, so kept local to it

    # The 'return' entry describes the result, not a parameter.
    arg_annotations = {
        name: annotation
        for name, annotation in func.__annotations__.items()
        if name != 'return'
    }
    if not arg_annotations:
        return func
    signature = inspect.signature(func)
    parameters = signature.parameters

    @wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> Any:
        bound = signature.bind(self, *args, **kwargs)
        for name, value in bound.arguments.items():
            if name not in arg_annotations:
                continue
            annotation = arg_annotations[name]
            kind = parameters[name].kind
            # An annotation on *args / **kwargs describes each collected
            # item, not the tuple / dict that holds them.
            if kind is inspect.Parameter.VAR_POSITIONAL:
                candidates = value
            elif kind is inspect.Parameter.VAR_KEYWORD:
                candidates = value.values()
            else:
                candidates = (value,)
            for candidate in candidates:
                if not isinstance(candidate, annotation):
                    raise TypeError(f"Expected {annotation} for argument {name}, got {type(candidate)}.")

        return func(self, *args, **kwargs)

    return wrapper
def EnforceFunctionTyping(func: Callable) -> Callable:
    """Enforce the type annotations of a plain function's arguments at call time.

    Each supplied argument is matched to its parameter by name through the
    function's signature, so unannotated parameters and the return annotation
    never shift the comparison. Arguments left to their defaults are not
    checked, and the return value is never checked.

    Args:
        func: The function to wrap.

    Returns:
        A wrapper that validates the arguments before delegating to ``func``.

    Raises:
        TypeError: From the wrapper, when an argument is not an instance of
            its annotation or the arguments do not fit the signature.
    """
    import inspect  # only needed by this decorator, so kept local to it

    # The 'return' entry describes the result, not a parameter.
    arg_annotations = {
        name: annotation
        for name, annotation in func.__annotations__.items()
        if name != 'return'
    }
    signature = inspect.signature(func)
    parameters = signature.parameters

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            if name not in arg_annotations:
                continue
            annotation = arg_annotations[name]
            kind = parameters[name].kind
            # An annotation on *args / **kwargs describes each collected
            # item, not the tuple / dict that holds them.
            if kind is inspect.Parameter.VAR_POSITIONAL:
                candidates = value
            elif kind is inspect.Parameter.VAR_KEYWORD:
                candidates = value.values()
            else:
                candidates = (value,)
            for candidate in candidates:
                if not isinstance(candidate, annotation):
                    raise TypeError(f"Expected {annotation} for {name}, got {type(candidate)}.")

        return func(*args, **kwargs)

    return wrapper


In [None]:
@dataclass
class Source(EnforceClassTyping):
    """A source of the electric field: where it sits (Position) and how strong it is (Charge).

    Field types are validated on construction by the inherited
    ``EnforceClassTyping.__post_init__``, which raises ``TypeError`` on a mismatch.
    """
    Position: torch.Tensor # m; ElectricField reads index 0 and index 1 as the source's two coordinates
    Charge: float #C; magnitude of the source

@EnforceFunctionTyping
def ElectricField(FieldSources: list, ObservationPosition: torch.Tensor)->torch.Tensor:
    """Return the net 2D electric field of the given sources at the observation point(s).

    The contribution of every source is k * q * r / |r|**3, where r is the
    displacement from the source to the observation point, and the
    contributions are summed. This determines the physics of the field (an
    electric field in this case).

    Args:
        FieldSources: List of ``Source`` instances (subclasses are accepted).
        ObservationPosition: Tensor whose index 0 holds the first coordinate(s)
            and index 1 the second coordinate(s) of the observation point(s).

    Returns:
        The field vector(s) in N/C (equivalently V/m), shaped like
        ``ObservationPosition``. With no sources this is all zeros.

    Raises:
        TypeError: If an element of ``FieldSources`` is not a ``Source`` or
            the two coordinate components of ``ObservationPosition`` are
            inconsistent.
    """
    CoulombConstant = 8.9875e9 #N*m^2/C^2
    if not all(isinstance(FieldSource, Source) for FieldSource in FieldSources):
        raise TypeError("The input is not valid")
    FirstCoordinates, SecondCoordinates = ObservationPosition[0], ObservationPosition[1]
    if type(FirstCoordinates) != type(SecondCoordinates):
        raise TypeError("Incompatible Reference point data types")
    if type(FirstCoordinates) != torch.Tensor:
        raise TypeError("Invalid Reference point data type")
    if FirstCoordinates.size() != SecondCoordinates.size():
        raise TypeError("Incompatible Reference point dimensions")

    ElectricFieldVector = torch.zeros_like(ObservationPosition)
    for FieldSource in FieldSources:
        PositionMatrices = torch.stack([torch.ones_like(FirstCoordinates) * FieldSource.Position[0].item(),
                                        torch.ones_like(SecondCoordinates) * FieldSource.Position[1].item()])
        DisplacementVector = ObservationPosition - PositionMatrices
        DisplacementMagnitude = torch.sqrt(DisplacementVector[0]**2 + DisplacementVector[1]**2)  # Magnitude of the displacement vector
        # Accumulate out of place: an in-place add cannot store the floating
        # point contribution in an integer accumulator, which is what
        # zeros_like yields for an integer-valued observation grid.
        ElectricFieldVector = ElectricFieldVector + (CoulombConstant * FieldSource.Charge) / DisplacementMagnitude**3 * DisplacementVector
    return ElectricFieldVector #N/C or V/m

@EnforceFunctionTyping
def PlotField(Sources: list, ObservationPosition: torch.Tensor):
    """Plot the 2D electric vector field of the given sources.

    Arrows are drawn with unit length so that only the direction of the field
    is encoded in them; the field strength is shown through the arrow colour
    and the attached colour bar.

    Args:
        Sources: List of ``Source`` instances producing the field.
        ObservationPosition: Tensor whose index 0 and index 1 hold the two
            coordinate grids at which the field is evaluated and drawn.

    Side effects:
        Sets the global ``figure.dpi`` rcParam to 150 and shows the figure.
    """
    xd, yd = ElectricField(Sources, ObservationPosition)
    # Both components must be divided by the same magnitude, taken before
    # either of them is overwritten, for the arrows to keep the field direction.
    FieldMagnitude = torch.sqrt(xd**2 + yd**2)
    xd = xd / FieldMagnitude
    yd = yd / FieldMagnitude
    # The dpi default only applies to figures created after it has been set.
    plt.rcParams['figure.dpi'] = 150
    fig, ax = plt.subplots(1,1)
    # Colour by the magnitude of the raw field: the normalised arrows all
    # have length one and would carry no strength information.
    cp = ax.quiver(ObservationPosition[0],ObservationPosition[1],xd,yd,FieldMagnitude)
    fig.colorbar(cp)
    plt.show()

@dataclass
class State(EnforceClassTyping):
    """Observable state of the Agent: its Position, the Field Strength it experiences there and its Momentum.

    These are the parameters the agent is able to observe; together they uniquely define its state.
    """
    Position: torch.Tensor # m
    FieldStrength: torch.Tensor #N/C or V/m
    Momentum: torch.Tensor #kg*m/s
    def Unwrap(self)->torch.Tensor:
        """Concatenate Position, FieldStrength and Momentum, in that order, into one tensor for processing."""
        Components = (self.Position, self.FieldStrength, self.Momentum)
        return torch.cat(Components)

@dataclass
class ReplayBuffer(EnforceClassTyping):
    """Bounded first-in-first-out store of state transitions used to train the Actor and Critic Networks.

    Each transition is kept as the list
    ``[State, Action, NextState, Reward, TerminalState]``. Once ``BufferSize``
    transitions are held, adding another one discards the oldest.

    Raises:
        TypeError: On construction, if ``BufferSize`` is not an ``int`` or
            ``Buffer`` is neither ``None`` nor a ``list``.
        ValueError: On construction, if ``BufferSize`` is not positive.
    """
    BufferSize: int
    Buffer: list = None
    def __post_init__(self):
        if self.Buffer is None:
            self.Buffer = []
        # Overriding __post_init__ hides the inherited field validation, so
        # run it explicitly, after Buffer has been given its real value.
        super().__post_init__()
        # A non-positive capacity would make AddExperience pop from an empty list.
        if self.BufferSize <= 0:
            raise ValueError('BufferSize must be a positive integer')
    @EnforceMethodTyping
    def AddExperience(self, State: State, Action: torch.Tensor, NextState: State, Reward: float, TerminalState: bool):
        """Append a state transition, discarding the oldest one first when the buffer is full."""
        if len(self.Buffer) >= self.BufferSize:
            self.Buffer.pop(0)
        self.Buffer.append([State, Action, NextState, Reward, TerminalState])
    @EnforceMethodTyping
    def SampleBuffer(self, BatchSize: int):
        """Randomly sample ``BatchSize`` distinct transitions and return them as batched tensors.

        Returns:
            The tuple ``(StateBatch, ActionBatch, NextStateBatch, RewardsBatch,
            TerminalSignalsBatch)``. States are unwrapped to single tensors
            before stacking; each reward and terminal signal becomes a
            one-element tensor before stacking.

        Raises:
            ValueError: If fewer than ``BatchSize`` transitions are stored.
        """
        if len(self.Buffer) < BatchSize:
            raise ValueError('BatchSize too big')
        SampledBatch = random.sample(self.Buffer, BatchSize)
        StateBatch = torch.stack([Transition[0].Unwrap() for Transition in SampledBatch])
        ActionBatch = torch.stack([Transition[1] for Transition in SampledBatch])
        NextStateBatch = torch.stack([Transition[2].Unwrap() for Transition in SampledBatch])
        RewardsBatch = torch.stack([torch.Tensor([Transition[3]]) for Transition in SampledBatch])
        TerminalSignalsBatch = torch.stack([torch.Tensor([Transition[4]]) for Transition in SampledBatch])
        return StateBatch, ActionBatch, NextStateBatch, RewardsBatch, TerminalSignalsBatch