In [None]:
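'''This notebook implements weight normalization from scratch: a critic (state-action value) network with optional layer normalization, trained with hand-derived gradients (no autograd), compared against an unnormalized baseline critic.'''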
import inspect
import types
from dataclasses import *
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

In [None]:
@dataclass
class ReLU:
    """Rectified linear unit activation, max(x, 0) element-wise.

    `derivative` is evaluated on the activation's input: it yields 1.0 where x > 0 and 0.0
    elsewhere (x == 0 included), in the default floating dtype.
    """
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x.relu()

    def derivative(self, x: torch.Tensor) -> torch.Tensor:
        positive_mask = x > 0
        return positive_mask.to(torch.get_default_dtype())
@dataclass
class Tanh:
    """Hyperbolic-tangent activation with its element-wise derivative.

    Like `ReLU.derivative` and `Null.derivative`, `derivative` takes the activation's *input*
    (the network code in this notebook passes pre-activation values such as LayerConnections or
    MeanShiftedLayer). It therefore evaluates tanh itself: d/dx tanh(x) = 1 - tanh(x)**2.
    """
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x)

    def derivative(self, x: torch.Tensor) -> torch.Tensor:
        # Using 1 - x**2 here would only be right if x were already tanh(x).
        return 1 - torch.tanh(x) ** 2
@dataclass
class Null:
    """Identity ("linear") activation: passes its input through unchanged; derivative is 1 everywhere."""
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x  # same tensor object, no copy

    def derivative(self, x: torch.Tensor) -> torch.Tensor:
        # A tensor of ones with x's shape, dtype and device.
        return torch.full_like(x, 1)
def mse_grad(x, y):
    """Scalar loss derivative used as `LossDerivative` in this notebook: 2 * mean(x - y) as a 0-d tensor.

    For a single-element prediction `x` this equals d/dx (x - y)**2 (the per-sample gradient of
    `nn.MSELoss`). Note it is not the gradient of `mse`, whose 1/2 factor gives (x - y).
    """
    residual = y - x
    scaled = -2 * residual.mean()
    return scaled.mean()
def mse(x, y):
    """Half mean squared error, mean((x - y)**2 / 2), returned as a 0-d tensor."""
    halved_squares = (x - y) ** 2 / 2
    return halved_squares.mean().mean()
def LayerNormalization(InputData: torch.Tensor, epsilon= 0.00000000001):
    """Normalize along the last dimension to zero mean and unit population standard deviation.

    `epsilon` is added to the standard deviation so constant rows map to zeros instead of NaN.
    """
    centered = InputData - InputData.mean(-1, keepdim=True)
    spread = InputData.std(-1, keepdim=True, unbiased=False)
    return centered / (spread + epsilon)
def LayerNormalizationDerivative(x: torch.Tensor, epsilon= 0.00000000001):
    """Jacobian of `(x - mean) / std` (population std) for a single row.

    A 1-D input is treated as one row. For a row of width N the (N, N) result is
        (delta_ij - 1/N) / std - (x_i - mean)(x_j - mean) / (N * std**3),
    with `epsilon` added to both denominators. The outer-product term is built by broadcasting
    the centered row against its transpose, so only a single row is supported.
    """
    row = x.unsqueeze(dim=0) if x.ndim < 2 else x
    width = row.shape[1]
    mu = row.mean(-1, keepdim=True)
    sigma = row.std(-1, keepdim=True, unbiased=False)
    centered = row - mu
    diagonal_part = (width * torch.eye(width) - 1) / (width * sigma + epsilon)
    outer_part = (centered * centered.t()) / (width * sigma**3 + epsilon)
    return diagonal_part - outer_part

In [None]:
def _annotation_accepts(value: Any, annotation: Any) -> bool:
    """Return True when `value` satisfies `annotation` as far as a runtime isinstance check can tell.

    - Missing annotations, `Any` and string (forward-reference) annotations are not checked.
    - `Union[...]`, `Optional[...]` and `X | Y` accept a value matching any member.
    - Parameterized generics (`list[int]`, `Callable[[int], int]`, ...) are checked against their
      origin class only (`list`, `collections.abc.Callable`); element types are not inspected.
    - Annotations isinstance cannot evaluate (e.g. TypeVar, Literal) are treated as unchecked.
    """
    if annotation is inspect.Parameter.empty or annotation is Any or isinstance(annotation, str):
        return True
    origin = get_origin(annotation)
    if origin is Union or origin is getattr(types, "UnionType", Union):
        return any(_annotation_accepts(value, member) for member in get_args(annotation))
    try:
        return isinstance(value, annotation if origin is None else origin)
    except TypeError:
        # Not expressible as an isinstance check; leave it unchecked rather than fail the call.
        return True


def _annotated_arguments(signature: inspect.Signature, bound: inspect.BoundArguments):
    """Yield (name, value, annotation) for every explicitly passed argument whose parameter is annotated.

    `*args` / `**kwargs` parameters yield one entry per element; keyword elements use their keyword as name.
    Defaults that were not passed are not yielded.
    """
    for name, value in bound.arguments.items():
        parameter = signature.parameters[name]
        if parameter.annotation is inspect.Parameter.empty:
            continue
        if parameter.kind is inspect.Parameter.VAR_POSITIONAL:
            for item in value:
                yield name, item, parameter.annotation
        elif parameter.kind is inspect.Parameter.VAR_KEYWORD:
            for key, item in value.items():
                yield key, item, parameter.annotation
        else:
            yield name, value, parameter.annotation


@dataclass
class EnforceClassTyping:
    """Dataclass mix-in that validates the instance's annotated fields right after `__init__`.

    The check runs from `__post_init__`, which only a dataclass-generated `__init__` calls; a subclass
    that defines its own `__init__` is not checked.

    Raises:
    - TypeError: when a field value does not satisfy its annotation.
    """
    def __post_init__(self):
        for (name, field_type) in self.__annotations__.items():
            value = self.__dict__[name]
            if not _annotation_accepts(value, field_type):
                current_type = type(value)
                raise TypeError(f"The field `{name}` was assigned by `{current_type}` instead of `{field_type}`")
        # print("Check is passed successfully")


def EnforceMethodTyping(func: Callable) -> Callable:
    '''Decorator enforcing the type annotations of a method's arguments at call time.

    Each passed argument is matched to its own parameter by name (via `inspect.signature`), so
    unannotated parameters (including `self`) no longer shift the checks of later ones. Defaults that
    were not passed and the return annotation are not checked. A method without any annotations is
    returned unchanged.

    Raises:
    - TypeError: when an argument does not satisfy its parameter's annotation, or when the call does
      not match the method's signature.
    '''
    arg_annotations = func.__annotations__
    if not arg_annotations:
        return func
    signature = inspect.signature(func)

    @wraps(func)
    def wrapper(self, *args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
        bound = signature.bind(self, *args, **kwargs)
        for arg_name, arg_value, annotation in _annotated_arguments(signature, bound):
            if _annotation_accepts(arg_value, annotation):
                continue
            if arg_name in kwargs:
                raise TypeError(f"Expected {annotation} for keyword argument {arg_name}, got {type(arg_value)}.")
            raise TypeError(f"Expected {annotation} for argument {arg_value}, got {type(arg_value)}.")
        return func(self, *args, **kwargs)

    return wrapper


def EnforceFunctionTyping(func: Callable) -> Callable:
    '''Decorator enforcing the type annotations of a plain function's arguments at call time.

    Same checking rules as `EnforceMethodTyping`, but without a leading `self`.

    Raises:
    - TypeError: when an argument does not satisfy its parameter's annotation, or when the call does
      not match the function's signature.
    '''
    signature = inspect.signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        for arg_name, arg_value, annotation in _annotated_arguments(signature, bound):
            if _annotation_accepts(arg_value, annotation):
                continue
            if arg_name in kwargs:
                raise TypeError(f"Expected {annotation} for {arg_name}, got {type(arg_value)}.")
            raise TypeError(f"Expected {annotation} for {arg_value}, got {type(arg_value)}.")
        return func(*args, **kwargs)

    return wrapper

In [None]:
class NeuralNetwork(nn.Module):
    """Small PyTorch MLP: Linear(8 -> 10), ReLU, Linear(10 -> 1)."""

    def __init__(self):
        super().__init__()
        self.input_layer = nn.Linear(8, 10)
        self.hidden_layer = nn.Linear(10, 1)
        # Reference to PyTorch's private weight-norm function; stored but not used by `forward`.
        self.weigh = torch._weight_norm

    def forward(self, x):
        hidden = self.input_layer(x).relu()
        return self.hidden_layer(hidden)


In [None]:
class CriticNetwork(EnforceClassTyping):
    '''This object represents the Value Function (Critic) used to estimate the expected value of a state-action pair.
    This value function is a neural network that learns to more accurately predict the expected value of a state-action pair.
    It uses weight normalization (effective weight g * v / ||v||, norm taken over each output column) and optional
    layer normalization, with all gradients derived by hand (no autograd).

        Parameters:
        - LayerSizes: List of integers giving the number of neurons in each layer (input layer first).
        - LayerActivations: List of activation classes (e.g. ReLU, Tanh, Null), one per weight layer. Each class is
          instantiated when used and must provide `__call__` and `derivative`.
        - NormalizationLayers: List of 1/0 flags, one per weight layer, selecting which layers use layer
          normalization. Default: every hidden layer is normalized, the output layer is not.

        Attributes:
        - WeightParameters:  List of weight-direction matrices v, each of shape (LayerSizes[i], LayerSizes[i+1]).
        - WeightMagnitudes:  List of weight-magnitude row vectors g, each of shape (1, LayerSizes[i+1]).
        - Bias:   List of bias row vectors, one per weight layer (output layer included).
        - Gamma:  List of scale vectors applied after layer normalization.
        - Beta:   List of shift vectors applied after layer normalization.

        Methods:
        - forward:  Computes the output of the network for state and action inputs.
        - compute_gradients: Computes batch-averaged gradients of the loss with respect to the parameters.
        - update_model: Applies one gradient-descent step to the parameters.'''
    def __init__(self, LayerSizes: list[int], LayerActivations: list[Callable], NormalizationLayers: list[bool]= None):
        assert len(LayerSizes)-1 == len(LayerActivations)
        self.LayerSizes = LayerSizes
        # Weight normalization: layer i uses w = g * v / ||v||, with g = WeightMagnitudes[i] (1, out)
        # and v = WeightParameters[i] (in, out).
        self.WeightMagnitudes: list[torch.Tensor] = [torch.randn(1, LayerSizes[x+1]) for x in range(len(LayerSizes)-1)]
        self.WeightParameters: list[torch.Tensor] = [torch.randn(LayerSizes[x], LayerSizes[x+1]) for x in range(len(LayerSizes)-1)]
        self.Bias: list[torch.Tensor] = [torch.randn(1, LayerSizes[x+1]) for x in range(len(LayerSizes)-1)]
        self.LayerActivations = LayerActivations
        if NormalizationLayers is None:
            # Normalize every hidden layer; leave the output layer unnormalized.
            self.NormalizationLayers = [1]*(len(LayerSizes)-2) + [0]
        else:
            assert len(LayerSizes)-1 == len(NormalizationLayers)
            self.NormalizationLayers = NormalizationLayers
        self.Gamma = [torch.ones(1, LayerSizes[x+1]) for x in range(len(LayerSizes)-1)]
        self.Beta = [torch.zeros(1, LayerSizes[x+1]) for x in range(len(LayerSizes)-1)]

    @EnforceMethodTyping
    def forward(self,
                StateInput: torch.Tensor,
                ActionInput: torch.Tensor,
                full: bool= False)-> torch.Tensor | tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """
        Predict the expected return of a state-action pair (or a batch of them).

        Input:
        - StateInput: Tensor of shape (batch_size, state_dimension), or one 1-D state.
        - ActionInput: Tensor of shape (batch_size, action_dimension), or one 1-D action.
        - full: When False (default) return only the network output; when True also return every
          intermediate quantity needed by `compute_gradients`.

        Output (full=False):
        - The network output (last entry of LayerValue).

        Output (full=True), in this order:
        - LayerConnections: Pre-activation values x @ (g * v / ||v||) + b of every layer.
        - ActivatedLayer:   Activation outputs of every layer.
        - MeanShiftedLayer: Gamma * normalized + Beta for normalized layers (placeholder zeros elsewhere).
        - NormalizedLayer:  Layer-normalized values for normalized layers (placeholder zeros elsewhere).
        - LayerValue:       The input of every layer followed by the network output (one entry more than layers).
        """
        InputData: torch.Tensor = torch.cat([StateInput, ActionInput], dim=StateInput.ndim-1)
        if InputData.ndim < 2:
            InputData = InputData.unsqueeze(dim=0)
        LayerCount = len(self.WeightParameters)
        LayerConnections = [torch.zeros(1)] * LayerCount
        NormalizedLayer = [torch.zeros(1)] * LayerCount
        ActivatedLayer = [torch.zeros(1)] * LayerCount
        MeanShiftedLayer = [torch.zeros(1)] * LayerCount
        LayerValue = [InputData]

        def NormalizedReLULayerForward(LayerConnection: torch.Tensor,
                                       ActivationFunction: Callable,
                                       Gamma: torch.Tensor,
                                       Beta: torch.Tensor)-> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """
            Forward pass of a layer ordered activation -> layer normalization -> Gamma/Beta.

            Returns, in this order: NormalizedLayer, MeanShiftedLayer, ActivatedLayer.
            The mean-shifted value is the layer's output.
            """
            ActivatedLayer = ActivationFunction(LayerConnection)
            NormalizedLayer = LayerNormalization(ActivatedLayer)
            MeanShiftedLayer = (Gamma * NormalizedLayer) + Beta
            return NormalizedLayer, MeanShiftedLayer, ActivatedLayer

        def NormalizedLayerForward(LayerConnection: torch.Tensor,
                                   ActivationFunction: Callable,
                                   Gamma: torch.Tensor,
                                   Beta: torch.Tensor)-> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """
            Forward pass of a layer ordered layer normalization -> Gamma/Beta -> activation.

            Returns, in this order: NormalizedLayer, MeanShiftedLayer, ActivatedLayer.
            The activated value is the layer's output.
            """
            NormalizedLayer = LayerNormalization(LayerConnection)
            MeanShiftedLayer = (Gamma * NormalizedLayer) + Beta
            ActivatedLayer = ActivationFunction(MeanShiftedLayer)
            return NormalizedLayer, MeanShiftedLayer, ActivatedLayer

        for i in range(LayerCount):
            NormalizedWeight = self.WeightMagnitudes[i] * torch.div(self.WeightParameters[i], torch.norm(self.WeightParameters[i], dim=0, keepdim=True))
            LayerConnections[i] = torch.matmul(LayerValue[i], NormalizedWeight) + self.Bias[i]
            ActivationFunction: Callable = self.LayerActivations[i]()
            if self.NormalizationLayers[i] == 1:
                if self.LayerActivations[i] == ReLU:
                    NormalizedLayer[i], MeanShiftedLayer[i], ActivatedLayer[i] = NormalizedReLULayerForward(LayerConnections[i], ActivationFunction, self.Gamma[i], self.Beta[i])
                    LayerValue.append(MeanShiftedLayer[i])
                else:
                    NormalizedLayer[i], MeanShiftedLayer[i], ActivatedLayer[i] = NormalizedLayerForward(LayerConnections[i], ActivationFunction, self.Gamma[i], self.Beta[i])
                    LayerValue.append(ActivatedLayer[i])
            else:
                ActivatedLayer[i] = ActivationFunction(LayerConnections[i])
                LayerValue.append(ActivatedLayer[i])
        if full is False:
            return LayerValue[-1]
        else:
            return LayerConnections, ActivatedLayer, MeanShiftedLayer, NormalizedLayer, LayerValue

    @EnforceMethodTyping
    def compute_gradients(self, StateInput: torch.Tensor,
                          ActionInput: torch.Tensor,
                          OptimalReturn: torch.Tensor,
                          LossDerivative: Callable):
        """
        Compute batch-averaged gradients of the loss with respect to every parameter by manual
        backpropagation, processing one (state, action, return) sample at a time.

        Input:
        - StateInput:      Tensor of shape (batch_size, state_dimension)
        - ActionInput:     Tensor of shape (batch_size, action_dimension)
        - OptimalReturn:   Tensor with batch_size rows holding each sample's target return.
        - LossDerivative:  Callable (output, target) -> derivative of the loss with respect to the network
                           output. Its result must be convertible to a single scalar (it is wrapped into a (1, 1) tensor).
        Output, in this order (each a list with one tensor per layer, averaged over the batch):
        - CummulativeWeightParametersGradient: Gradient with respect to the weight directions v.
        - CummulativeWeightMagnitudesGradient: Gradient with respect to the weight magnitudes g.
        - CummulativeBiasGradient:  Gradient with respect to the biases.
        - CummulativeGammaGradient: Gradient with respect to Gamma (zero for unnormalized layers).
        - CummulativeBetaGradient:  Gradient with respect to Beta (zero for unnormalized layers).
        """
        CummulativeBiasGradient: list[torch.Tensor] = [torch.zeros_like(bias) for bias in self.Bias]
        CummulativeWeightMagnitudesGradient: list[torch.Tensor] = [torch.zeros_like(weight) for weight in self.WeightMagnitudes]
        CummulativeWeightParametersGradient: list[torch.Tensor] = [torch.zeros_like(weight) for weight in self.WeightParameters]
        CummulativeGammaGradient: list[torch.Tensor] = [torch.zeros_like(gamma) for gamma in self.Gamma]
        CummulativeBetaGradient: list[torch.Tensor] = [torch.zeros_like(beta) for beta in self.Beta]
        # Effective (normalized) weight of every layer, g * v / ||v||.
        NormalizedWeights: list[torch.Tensor] = [g * torch.div(v, torch.norm(v, dim=0, keepdim=True)) for v, g in zip(self.WeightParameters, self.WeightMagnitudes)]

        def WeightNormGradients(dEdWeight: torch.Tensor,
                                WeightParameters: torch.Tensor,
                                WeightMagnitudes: torch.Tensor)-> tuple[torch.Tensor, torch.Tensor]:
            """
            Map dE/dw for w = g * v / ||v|| onto the reparameterized parameters.

            Returns (dEdv, dEdg) with dEdg = sum(dEdw * v) / ||v|| per column and
            dEdv = (g / ||v||) * dEdw - (g * dEdg / ||v||**2) * v.
            """
            vnorm = torch.norm(WeightParameters, dim=0, keepdim=True)
            dEdg = torch.sum(dEdWeight * WeightParameters, dim=0, keepdim=True) / vnorm
            dEdv = (WeightMagnitudes / vnorm) * dEdWeight - (((WeightMagnitudes * dEdg) / vnorm**2) * WeightParameters)
            return dEdv, dEdg

        @EnforceFunctionTyping
        def NormalizedReLULayerBackward(dEdMS: torch.Tensor,
                                        ActivationFunction: Callable,
                                        NormalizedLayer: torch.Tensor,
                                        Gamma: torch.Tensor,
                                        ActivatedLayer: torch.Tensor,
                                        LayerConnections: torch.Tensor,
                                        LayerValue: torch.Tensor,
                                        NormalizedWeight: torch.Tensor,
                                        WeightParameters: torch.Tensor,
                                        WeightMagnitudes: torch.Tensor)-> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            """
            Back-propagate through a layer ordered activation -> layer normalization -> Gamma/Beta.

            Arguments:
            - dEdMS: Error gradient with respect to this layer's mean-shifted output (its output value).
            - ActivationFunction: Activation instance providing `derivative`.
            - NormalizedLayer: This layer's layer-normalized activations.
            - Gamma: This layer's normalization scale.
            - ActivatedLayer: This layer's activation outputs (the input of the normalization).
            - LayerConnections: This layer's pre-activation values.
            - LayerValue: This layer's input (previous layer's output), shape (1, in).
            - NormalizedWeight: Effective weight g * v / ||v||.
            - WeightParameters: Weight directions v.
            - WeightMagnitudes: Weight magnitudes g.

            Returns, in this order:
            - ErrorGradient: Error gradient with respect to this layer's input.
            - dEdv, dEdg: Error gradients of the weight directions and magnitudes.
            - dEdb: Error gradient of the bias.
            - dEdGamma, dEdBeta: Error gradients of Gamma and Beta.
            """
            dEdBeta = dEdMS
            dEdGamma = torch.mul(NormalizedLayer, dEdMS)
            dEdN = torch.mul(dEdMS, Gamma)
            dEdA = torch.matmul(dEdN, LayerNormalizationDerivative(ActivatedLayer))
            dEdLC = torch.mul(dEdA, ActivationFunction.derivative(LayerConnections))
            # Outer product (in, 1) x (1, out) -> (in, out).
            dEdWeight = torch.mul(dEdLC, LayerValue.t())
            dEdv, dEdg = WeightNormGradients(dEdWeight, WeightParameters, WeightMagnitudes)
            ErrorGradient = torch.matmul(dEdLC, NormalizedWeight.t())
            return ErrorGradient, dEdv, dEdg, dEdLC, dEdGamma, dEdBeta

        @EnforceFunctionTyping
        def NormalizedLayerBackward(dEdA: torch.Tensor,
                                    ActivationFunction: Callable,
                                    NormalizedLayer: torch.Tensor,
                                    Gamma: torch.Tensor,
                                    MeanShiftedLayer: torch.Tensor,
                                    LayerConnections: torch.Tensor,
                                    LayerValue: torch.Tensor,
                                    NormalizedWeight: torch.Tensor,
                                    WeightParameters: torch.Tensor,
                                    WeightMagnitudes: torch.Tensor)-> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            """
            Back-propagate through a layer ordered layer normalization -> Gamma/Beta -> activation.

            Arguments:
            - dEdA: Error gradient with respect to this layer's activated output (its output value).
            - ActivationFunction: Activation instance providing `derivative`.
            - NormalizedLayer: This layer's layer-normalized pre-activations.
            - Gamma: This layer's normalization scale.
            - MeanShiftedLayer: Gamma * normalized + Beta (the activation's input).
            - LayerConnections: This layer's pre-normalization values.
            - LayerValue: This layer's input (previous layer's output), shape (1, in).
            - NormalizedWeight: Effective weight g * v / ||v||.
            - WeightParameters: Weight directions v.
            - WeightMagnitudes: Weight magnitudes g.

            Returns, in this order:
            - ErrorGradient: Error gradient with respect to this layer's input.
            - dEdv, dEdg: Error gradients of the weight directions and magnitudes.
            - dEdb: Error gradient of the bias.
            - dEdGamma, dEdBeta: Error gradients of Gamma and Beta.
            """
            dEdMS = torch.mul(dEdA, ActivationFunction.derivative(MeanShiftedLayer))
            dEdBeta = dEdMS
            dEdGamma = torch.mul(dEdMS, NormalizedLayer)
            dEdN = torch.mul(dEdMS, Gamma)
            dEdLC = torch.matmul(dEdN, LayerNormalizationDerivative(LayerConnections))
            dEdWeight = torch.mul(LayerValue.t(), dEdLC)
            dEdv, dEdg = WeightNormGradients(dEdWeight, WeightParameters, WeightMagnitudes)
            ErrorGradient = torch.matmul(dEdLC, NormalizedWeight.t())
            return ErrorGradient, dEdv, dEdg, dEdLC, dEdGamma, dEdBeta

        @EnforceFunctionTyping
        def LayerBackward(ErrorGradient: torch.Tensor,
                          ActivationFunction: Callable,
                          LayerConnections: torch.Tensor,
                          LayerValue: torch.Tensor,
                          NormalizedWeight: torch.Tensor,
                          WeightParameters: torch.Tensor,
                          WeightMagnitudes: torch.Tensor)-> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            """
            Back-propagate through an unnormalized layer (weights -> activation).

            Arguments:
            - ErrorGradient: Error gradient with respect to this layer's activated output.
            - ActivationFunction: Activation instance providing `derivative`.
            - LayerConnections: This layer's pre-activation values.
            - LayerValue: This layer's input (previous layer's output), shape (1, in).
            - NormalizedWeight: Effective weight g * v / ||v||.
            - WeightParameters: Weight directions v.
            - WeightMagnitudes: Weight magnitudes g.

            Returns, in this order:
            - ErrorGradient: Error gradient with respect to this layer's input.
            - dEdv, dEdg: Error gradients of the weight directions and magnitudes.
            - dEdb: Error gradient of the bias.
            """
            LayerErrorGradient = torch.mul(ErrorGradient, ActivationFunction.derivative(LayerConnections))
            dEdWeight = torch.mul(LayerValue.t(), LayerErrorGradient)
            dEdv, dEdg = WeightNormGradients(dEdWeight, WeightParameters, WeightMagnitudes)
            InputErrorGradient = torch.matmul(LayerErrorGradient, NormalizedWeight.t())
            return InputErrorGradient, dEdv, dEdg, LayerErrorGradient

        batch_size = len(StateInput)
        LayerCount = len(self.WeightParameters)
        for State, Action, Return in zip(StateInput, ActionInput, OptimalReturn):
            dEdBias = [torch.zeros(1)] * len(self.Bias)
            dEdWeightParameters = [torch.zeros(1)] * LayerCount
            dEdWeightMagnitudes = [torch.zeros(1)] * len(self.WeightMagnitudes)
            dEdGamma = [torch.zeros(1)] * len(self.Gamma)
            dEdBeta = [torch.zeros(1)] * len(self.Beta)
            LayerConnections, ActivatedLayer, MeanShiftedLayer, NormalizedLayer, LayerValue = self.forward(State, Action, full=True)
            Output = LayerValue[-1]
            # dE/dOutput for this sample as a (1, 1) row vector.
            ErrorGradient = torch.tensor([[LossDerivative(Output, Return)]])
            for l in range(LayerCount):
                layer = -l - 1                      # walk from the output layer back to the first layer
                LayerInput = LayerValue[layer - 1]  # value fed into this layer
                ActivationFunction = self.LayerActivations[layer]()
                if self.NormalizationLayers[layer] == 1:
                    if self.LayerActivations[layer] == ReLU:
                        (ErrorGradient, dEdWeightParameters[layer], dEdWeightMagnitudes[layer],
                         dEdBias[layer], dEdGamma[layer], dEdBeta[layer]) = NormalizedReLULayerBackward(
                            ErrorGradient, ActivationFunction, NormalizedLayer[layer], self.Gamma[layer],
                            ActivatedLayer[layer], LayerConnections[layer], LayerInput,
                            NormalizedWeights[layer], self.WeightParameters[layer], self.WeightMagnitudes[layer])
                    else:
                        # The incoming ErrorGradient is dE/d(this layer's activated output).
                        (ErrorGradient, dEdWeightParameters[layer], dEdWeightMagnitudes[layer],
                         dEdBias[layer], dEdGamma[layer], dEdBeta[layer]) = NormalizedLayerBackward(
                            ErrorGradient, ActivationFunction, NormalizedLayer[layer], self.Gamma[layer],
                            MeanShiftedLayer[layer], LayerConnections[layer], LayerInput,
                            NormalizedWeights[layer], self.WeightParameters[layer], self.WeightMagnitudes[layer])
                else:
                    (ErrorGradient, dEdWeightParameters[layer], dEdWeightMagnitudes[layer],
                     dEdBias[layer]) = LayerBackward(
                        ErrorGradient, ActivationFunction, LayerConnections[layer], LayerInput,
                        NormalizedWeights[layer], self.WeightParameters[layer], self.WeightMagnitudes[layer])
            CummulativeBiasGradient = [total + grad / batch_size for total, grad in zip(CummulativeBiasGradient, dEdBias)]
            CummulativeWeightParametersGradient = [total + grad / batch_size for total, grad in zip(CummulativeWeightParametersGradient, dEdWeightParameters)]
            CummulativeWeightMagnitudesGradient = [total + grad / batch_size for total, grad in zip(CummulativeWeightMagnitudesGradient, dEdWeightMagnitudes)]
            CummulativeGammaGradient = [total + grad / batch_size for total, grad in zip(CummulativeGammaGradient, dEdGamma)]
            CummulativeBetaGradient = [total + grad / batch_size for total, grad in zip(CummulativeBetaGradient, dEdBeta)]
        return CummulativeWeightParametersGradient, CummulativeWeightMagnitudesGradient, CummulativeBiasGradient, CummulativeGammaGradient, CummulativeBetaGradient

    def update_model(self,
                     CummulativeWeightParameterGradient: list[torch.Tensor],
                     CummulativeWeightMagnitudeGradient: list[torch.Tensor],
                     CummulativeBiasGradient: list[torch.Tensor],
                     CummulativeGammaGradient: list[torch.Tensor],
                     CummulativeBetaGradient: list[torch.Tensor],
                     LearningRate=0.01):
        """
        Apply one gradient-descent step to every parameter: p <- p - LearningRate * grad.

        Input (each a list with one gradient tensor per layer, in the order returned by `compute_gradients`):
           - CummulativeWeightParameterGradient: gradients of the weight directions v.
           - CummulativeWeightMagnitudeGradient: gradients of the weight magnitudes g.
           - CummulativeBiasGradient: gradients of the biases.
           - CummulativeGammaGradient: gradients of the normalization scale Gamma.
           - CummulativeBetaGradient: gradients of the normalization shift Beta.
           - LearningRate: step size (default 0.01).
        The parameter lists receive new tensors; the previous tensors are not modified in place.
        """
        for i in range(len(self.WeightParameters)):
            self.WeightParameters[i] = self.WeightParameters[i] - LearningRate * CummulativeWeightParameterGradient[i]
            self.WeightMagnitudes[i] = self.WeightMagnitudes[i] - LearningRate * CummulativeWeightMagnitudeGradient[i]
            self.Bias[i] = self.Bias[i] - LearningRate * CummulativeBiasGradient[i]
            self.Gamma[i] = self.Gamma[i] - LearningRate * CummulativeGammaGradient[i]
            self.Beta[i] = self.Beta[i] - LearningRate * CummulativeBetaGradient[i]

In [None]:
class NormalCriticNetwork:
    '''This object represents the Value Function (Critic) used to estimate the expected value of a state-action pair.
    It is a plain fully-connected network (no weight or layer normalization) that learns to more accurately predict
    the expected value of a state-action pair. It is trained alongside CriticNetwork in this notebook for comparison.

    Parameters:
    - layer_sizes: number of neurons in each layer (input layer first).
    - layer_activations: one activation per weight layer: a plain callable (e.g. `torch.relu`), an activation
      instance (e.g. `ReLU()`), or an activation class (e.g. `ReLU`, instantiated here, as CriticNetwork accepts).
    - layer_activations_derivative: one derivative callable per weight layer, evaluated on pre-activation values.
      Optional: when omitted, each activation's own `.derivative` method is used.'''
    def __init__(self, layer_sizes: list,
                 layer_activations: list,
                 layer_activations_derivative: list = None):
        assert len(layer_sizes) - 1 == len(layer_activations), "need one activation per weight layer"
        self.layer_sizes = layer_sizes
        # NOTE(review): `2 * randn - 1` samples N(-1, 4); it looks adapted from the uniform `2 * rand - 1` idiom.
        # Kept unchanged so existing results stay reproducible. Confirm the intended initialisation.
        self.weights = [2 * torch.randn(layer_sizes[x], layer_sizes[x+1]) - 1 for x in range(len(layer_sizes)-1)]
        self.bias = [2 * torch.randn(1, layer_sizes[x+1]) - 1 for x in range(len(layer_sizes)-1)]
        self.layer_activations = [act() if isinstance(act, type) else act for act in layer_activations]
        if layer_activations_derivative is None:
            layer_activations_derivative = [act.derivative for act in self.layer_activations]
        assert len(layer_activations_derivative) == len(self.layer_activations), "need one derivative per weight layer"
        self.layer_activations_derivative = layer_activations_derivative

    def forward(self, StateInput: torch.Tensor,
                ActionInput: torch.Tensor,
                full: bool= False)-> torch.Tensor:
        '''Takes State Parameters and Action Parameters to output the expected return of the state-action pair.

        With full=False returns the network output. With full=True returns (LayerConnections, ActivatedNeuronLayer):
        the pre-activation value of every layer, and the input followed by every layer's activated output.'''
        InputData = torch.cat([StateInput, ActionInput], dim=StateInput.ndim-1)
        LayerConnections = []
        ActivatedNeuronLayer = [InputData]
        for i in range(len(self.weights)):
            LayerConnections.append(torch.matmul(ActivatedNeuronLayer[i], self.weights[i]) + self.bias[i])
            ActivatedNeuronLayer.append(self.layer_activations[i](LayerConnections[i]))
        if full is False:
            return ActivatedNeuronLayer[-1]
        else:
            return LayerConnections, ActivatedNeuronLayer

    def compute_gradients(self, StateInput: torch.Tensor,
                          ActionInput: torch.Tensor,
                          OptimalReturn: torch.Tensor,
                          loss_derivative: Callable):
        '''Compute batch-averaged gradients of the loss with respect to the weights and biases by manual backpropagation.

        `loss_derivative(output, target)` must be convertible to a single scalar; it is wrapped into a (1, 1) tensor.
        Returns (weight_grad, bias_grad), each a list with one tensor per layer.'''
        bias_grad = [torch.zeros_like(b) for b in self.bias]
        weight_grad = [torch.zeros_like(w) for w in self.weights]
        batch_size = len(StateInput)
        for x1, x2, y in zip(StateInput, ActionInput, OptimalReturn):
            dEdb = [0] * len(self.bias)
            dEdw = [0] * len(self.weights)
            LayerConnections, ActivatedNeuronLayer = self.forward(x1, x2, full=True)
            dEdA = torch.tensor([[loss_derivative(ActivatedNeuronLayer[-1], y)]])
            # A single sample arrives as a 1-D input; make it a (1, in) row so `.t()` yields a column.
            if ActivatedNeuronLayer[0].ndim < 2:
                ActivatedNeuronLayer[0] = ActivatedNeuronLayer[0].unsqueeze(dim=0)
            for l in range(len(self.weights)):
                z = LayerConnections[-l-1]
                dAdz = self.layer_activations_derivative[-l-1](z)
                dEdz = torch.mul(dEdA, dAdz)
                dzdw = ActivatedNeuronLayer[-l-2].t()
                dEdb[-l-1] = dEdz
                dEdw[-l-1] = torch.mul(dzdw, dEdz)   # outer product (in, 1) x (1, out) -> (in, out)
                dEdA = torch.matmul(dEdz, self.weights[-l-1].t())
            bias_grad = [nb + dnb / batch_size for nb, dnb in zip(bias_grad, dEdb)]
            weight_grad = [nw + dnw / batch_size for nw, dnw in zip(weight_grad, dEdw)]
        return weight_grad, bias_grad

    def update_model(self, weight_grad, bias_grad, learning_rate=0.01):
        '''Apply one in-place gradient-descent step; the default learning rate matches CriticNetwork.update_model.'''
        for i in range(len(self.weights)):
            self.weights[i] -= learning_rate * weight_grad[i]
            self.bias[i] -= learning_rate * bias_grad[i]


In [None]:
# Toy batch of 3 state-action pairs with random targets, plus the two critics to compare.
torch.manual_seed(0)  # reproducible data and weight initialisation
StateInputs = torch.randn(3, 6)
ActionInputs = torch.randn(3, 2)
input_data = torch.cat([StateInputs, ActionInputs], dim=1)
target_data = torch.rand(3, 1)
loss_function = nn.MSELoss()
WNCriticModel = CriticNetwork([8, 10, 5, 1], [ReLU, ReLU, Null], [1, 1, 0])
# NormalCriticNetwork calls its activations directly and takes a separate list of derivative callables,
# so it needs activation *instances* and their bound `derivative` methods. It has no normalization flags.
baseline_activations = [ReLU(), ReLU(), Null()]
CriticModel = NormalCriticNetwork([8, 10, 5, 1], baseline_activations, [act.derivative for act in baseline_activations])

In [None]:
# Train the weight-normalised critic with its hand-derived gradients and plot the loss curve.
wn_losses = []
for epoch in range(2000):
    wn_prediction = WNCriticModel.forward(StateInputs, ActionInputs)
    wn_losses.append(loss_function(wn_prediction, target_data).item())
    # compute_gradients returns (weight-direction, weight-magnitude, bias, gamma, beta) gradients,
    # the same positional order update_model expects.
    WNCriticModel.update_model(*WNCriticModel.compute_gradients(StateInputs, ActionInputs, target_data, mse_grad))
fig, ax = plt.subplots()
ax.plot(wn_losses)
ax.set(title="Weight-normalised critic: training loss", xlabel="epoch", ylabel="MSE loss")
plt.show()
wn_final = WNCriticModel.forward(StateInputs, ActionInputs)
print('guess', wn_final)
print('target', target_data)
loss_function(wn_final, target_data)

In [None]:
# Train the baseline (unnormalised) critic on the same data for comparison.
# NormalCriticNetwork.compute_gradients returns only (weight_grad, bias_grad), and update_model is given
# an explicit learning rate of 0.01, the default used by CriticNetwork.update_model.
baseline_losses = []
for epoch in range(2000):
    baseline_prediction = CriticModel.forward(StateInputs, ActionInputs)
    baseline_losses.append(loss_function(baseline_prediction, target_data).item())
    weight_grad, bias_grad = CriticModel.compute_gradients(StateInputs, ActionInputs, target_data, mse_grad)
    CriticModel.update_model(weight_grad, bias_grad, 0.01)
fig, ax = plt.subplots()
ax.plot(baseline_losses)
ax.set(title="Baseline critic: training loss", xlabel="epoch", ylabel="MSE loss")
plt.show()
baseline_final = CriticModel.forward(StateInputs, ActionInputs)
print('guess', baseline_final)
print('target', target_data)
loss_function(baseline_final, target_data)

In [None]:
# weightparameter_grad, g_grad, bias_grad, gamma_grad, beta_grad= WNCriticModel.compute_gradients(StateInputs, ActionInputs, target_data, mse_grad)
# print(weightparameter_grad)
# print(g_grad)
# print(bias_grad)
# print(gamma_grad)
# beta_grad