In [None]:
import torch.autograd
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import *
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
from functools import wraps
import os
import random
from abc import ABC, abstractmethod
from collections import deque, namedtuple

In [None]:
@dataclass
class Environment(ABC):
    """Abstract interface an environment must implement.

    Concrete subclasses provide every abstract method below; the dataclass
    machinery generates ``__init__(InitialState, CurrentState)``.
    """

    class State:
        """Placeholder state type; it declares no attributes of its own."""

    # Dataclass fields (order defines the generated __init__ signature).
    InitialState: State
    CurrentState: State

    @abstractmethod
    def TransitionModel(self, State: State, Action) -> State:
        """Return a ``State`` computed from ``State`` and ``Action``."""

    @abstractmethod
    def RewardModel(self, State: State, Action, NextState: State, TerminalSignal: bool) -> float:
        """Return a float reward for the given transition and terminal flag."""

    @abstractmethod
    def IsTerminalCondition(self, State: State) -> bool:
        """Return whether ``State`` satisfies the terminal condition."""

    @abstractmethod
    def StateTransition(self, State: State, Action) -> tuple[float, State, bool]:
        """Return a ``(float, State, bool)`` triple for one step.

        Presumably (reward, next state, terminal flag) - confirm against
        concrete implementations.
        """

    @abstractmethod
    def SampleTrajectory(self, RunDuration: float) -> list[State]:
        """Return a list of states for a run of length ``RunDuration``."""

In [None]:
@dataclass
class Agent(ABC):
    """Abstract interface for an agent bound to an ``Environment``."""

    # The environment instance this agent is attached to (sole dataclass field).
    AgentEnvironment: Environment

    @abstractmethod
    def Act(self, Observation: T.Tensor) -> T.Tensor:
        """Return an action tensor for the given observation tensor."""

    @abstractmethod
    def Observe(self) -> T.Tensor:
        """Return an observation as a tensor."""

    @abstractmethod
    def Learn(self):
        """Improve the agent by updating its models."""

    @abstractmethod
    def LearningAlgorithm(self):
        """Abstract hook; its contract is not specified by this interface."""


In [None]:
class CriticNetwork(nn.Module):
    def __init__(self, learning_rate, state_dims, fc1_dims, fc2_dims, n_actions, name, chkpt_dir='tmp/ddpg'):
        super(CriticNetwork, self).__init__() 
        self.checkpoint_file = os.path.join(chkpt_dir,name+'_ddpg')

        self.fc1 = T.nn.utils.parametrizations.weight_norm(nn.Linear(state_dims+n_actions, fc1_dims)) 
        self.bn1 = nn.LayerNorm(fc1_dims)
        self.fc2 = T.nn.utils.parametrizations.weight_norm(nn.Linear(fc1_dims, fc2_dims))
        self.bn2 = nn.LayerNorm(fc2_dims)
        self.fc3 = T.nn.utils.parametrizations.weight_norm(nn.Linear(fc2_dims, 1))

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state, action):
        x = T.cat([state, action], dim=-1)
        x = T.relu(self.bn1(self.fc1(x)))
        x = T.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        return x

    def save_checkpoint(self):
        print('... saving checkpoint ...')
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        print('... loading checkpoint ...')
        self.load_state_dict(T.load(self.checkpoint_file))
 

In [None]:
class ActorNetwork(nn.Module):
    def __init__(self, learning_rate, state_dims, fc1_dims, fc2_dims, n_actions, name, chkpt_dir='tmp/ddpg'):
        super(ActorNetwork, self).__init__()
        self.checkpoint_file = os.path.join(chkpt_dir,name+'_ddpg')

        self.fc1 = T.nn.utils.parametrizations.weight_norm(nn.Linear(state_dims , fc1_dims)) 
        self.bn1 = nn.LayerNorm(fc1_dims)
        self.fc2 = T.nn.utils.parametrizations.weight_norm(nn.Linear(fc1_dims, fc2_dims))
        self.bn2 = nn.LayerNorm(fc2_dims)
        self.fc3 = T.nn.utils.parametrizations.weight_norm(nn.Linear(fc2_dims, n_actions))

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        x = T.relu(self.bn1(self.fc1(state)))
        x = T.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        return x

    def save_checkpoint(self):
        print('... saving checkpoint ...')
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        print('... loading checkpoint ...')
        self.load_state_dict(T.load(self.checkpoint_file))


In [None]:
@dataclass(kw_only=True)
class DDPGAgent(Agent):
  """Deep Deterministic Policy Gradient agent.

  Holds an online actor/critic pair, their target copies, a replay buffer
  and an exploration-noise source. All fields are keyword-only.
  """
  AgentEnvironment: Environment = NotImplemented
  Actor: ActorNetwork = NotImplemented
  Critic: CriticNetwork = NotImplemented
  TargetActor: ActorNetwork = NotImplemented
  TargetCritic: CriticNetwork = NotImplemented
  # Transitions are unpacked in Learn as
  # (state, action, next_state, reward, done).
  ReplayBuffer: deque = NotImplemented
  BatchSize: int = 64
  # Weight of the online parameters in the soft target update.
  SoftUpdateRate: float = 0.1
  # Zero-argument callable returning the exploration noise added in Act.
  # It must be annotated to be a dataclass field; without the annotation it
  # cannot be passed to the constructor.
  Noise: Callable[[], Any] = NotImplemented
  DiscountRate: float = 0.67
  # Factor applied to (action + noise) in Act. The default reproduces the
  # previously hard-coded value; the reason for this magnitude is not visible
  # in this notebook - TODO confirm.
  ActionScale: float = 1e-7

  def __post_init__(self):
    """Validate that every required collaborator was supplied."""
    assert isinstance(self.AgentEnvironment, Environment), "Must be an instance of Environment"
    assert isinstance(self.Actor, ActorNetwork), "Must be an ActorNetwork"
    assert isinstance(self.Critic, CriticNetwork), "Must be a CriticNetwork"
    assert isinstance(self.TargetActor, ActorNetwork), "Must be an ActorNetwork"
    assert isinstance(self.TargetCritic, CriticNetwork), "Must be a CriticNetwork"
    assert isinstance(self.ReplayBuffer, deque), "Must be a deque"
    # Identity check: `==` could be overloaded by the noise object.
    if self.Noise is NotImplemented:
      raise NotImplementedError("Must define Noise")

  def Observe(self, State=None) -> T.Tensor:
    """Convert a state (or a list/tuple of states) into a tensor.

    Args:
      State: An ``AgentEnvironment.State`` instance, or a (possibly nested)
        list/tuple of them. When omitted, the environment's ``CurrentState``
        is observed, which matches the argument-less ``Agent.Observe``.

    Returns:
      ``Position`` and ``Momentum`` concatenated for a single state; the
      per-item observations stacked along a new first dimension for a
      list/tuple.

    Raises:
      TypeError: If ``State`` is of an unsupported type.
    """
    if State is None:
      State = self.AgentEnvironment.CurrentState
    if isinstance(State, (tuple, list)):
      return T.stack([self.Observe(Item) for Item in State])
    if isinstance(State, self.AgentEnvironment.State):
      return T.cat([State.Position, State.Momentum])
    raise TypeError(
      f"Cannot observe object of type {type(State).__name__}; "
      "expected an environment State or a list/tuple of States"
    )

  def Act(self, Observation: T.Tensor) -> T.Tensor:
    """Return a noisy, scaled action for ``Observation``.

    The actor output plus a sample from ``Noise`` is multiplied by
    ``ActionScale``. The result is detached and moved to the CPU.
    """
    Device = self.Actor.device
    self.Actor.eval()
    try:
      # No graph is needed for acting; the result is detached anyway.
      with T.no_grad():
        Observation = T.as_tensor(Observation, dtype=T.float, device=Device)
        Action = self.Actor(Observation)
        Perturbation = T.as_tensor(self.Noise(), dtype=T.float, device=Device)
        NoisyAction = self.ActionScale * (Action + Perturbation)
    finally:
      # Restore training mode even if the forward pass or the noise fails.
      self.Actor.train()
    return NoisyAction.cpu().detach()

  @staticmethod
  def _ToBatch(Items, Device) -> T.Tensor:
    """Build a float batch tensor on ``Device`` from a sequence of samples."""
    if all(isinstance(Item, T.Tensor) for Item in Items):
      # T.tensor() cannot convert a sequence of multi-element tensors.
      Batch = T.stack([Item.detach() for Item in Items])
    else:
      Batch = T.tensor(Items, dtype=T.float)
    return Batch.to(device=Device, dtype=T.float)

  def Learn(self):
    """Run one critic update, one actor update and a soft target update.

    Does nothing until the replay buffer holds at least ``BatchSize``
    transitions.
    """
    if len(self.ReplayBuffer) < self.BatchSize:
      return

    Batch = random.sample(self.ReplayBuffer, self.BatchSize)
    States, Actions, NextStates, Rewards, Dones = zip(*Batch)

    Device = self.Critic.device
    State = self._ToBatch(States, Device)
    Action = self._ToBatch(Actions, Device)
    NextState = self._ToBatch(NextStates, Device)
    # Column vectors so they broadcast against the critic's (batch, 1) output.
    Reward = T.tensor(Rewards, dtype=T.float).reshape(-1, 1).to(Device)
    Done = T.tensor(Dones, dtype=T.float).reshape(-1, 1).to(Device)

    # Critic update: regress Q(s, a) onto the bootstrapped target.
    self.TargetActor.eval()
    self.TargetCritic.eval()
    with T.no_grad():
      # NOTE(review): Act stores actions scaled by ActionScale, while the
      # actor outputs used here and in the actor update are unscaled - confirm
      # this asymmetry is intended.
      TargetAction = self.TargetActor(NextState)
      TargetValue = self.TargetCritic(NextState, TargetAction)
      QTarget = Reward + self.DiscountRate * TargetValue * (1 - Done)

    self.Critic.train()
    QExpected = self.Critic(State, Action)
    CriticLoss = nn.MSELoss()(QExpected, QTarget)
    self.Critic.optimizer.zero_grad()
    CriticLoss.backward()
    self.Critic.optimizer.step()

    # Actor update: ascend the critic's value of the actor's own actions.
    # Gradients this leaves on the critic are cleared by the zero_grad()
    # call that precedes the next critic update.
    self.Actor.train()
    ActorLoss = -self.Critic(State, self.Actor(State)).mean()
    self.Actor.optimizer.zero_grad()
    ActorLoss.backward()
    self.Actor.optimizer.step()

    self.update_network_parameters()

  def update_network_parameters(self, tau=None):
    """Soft-update both target networks towards their online counterparts.

    Each target parameter becomes ``tau * online + (1 - tau) * target``.

    Args:
      tau: Interpolation weight; defaults to ``SoftUpdateRate``.
    """
    Tau = self.SoftUpdateRate if tau is None else tau
    Pairs = ((self.Actor, self.TargetActor), (self.Critic, self.TargetCritic))
    with T.no_grad():
      for Online, Target in Pairs:
        for Source, Destination in zip(Online.parameters(), Target.parameters()):
          Destination.copy_(Tau * Source + (1.0 - Tau) * Destination)

  def RewardModel(self) -> float:
      """Placeholder; not implemented (returns None)."""
      ...

  def LearningAlgorithm(self):
      """Placeholder satisfying the abstract ``Agent.LearningAlgorithm``."""
      ...