In [1]:
import torch.autograd
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import *
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
from functools import wraps
import os
import random
from abc import ABC, abstractmethod
from collections import deque, namedtuple

In [None]:
@dataclass
class Environment(ABC):
  """Abstract dataclass describing the interface of an environment.

  The generated constructor takes ``InitialState`` and ``CurrentState`` (in
  that order). Every method below is abstract, so only subclasses that
  implement all of them can be instantiated.
  """

  class State:
      """Placeholder state type; defines no attributes or behaviour itself."""

  # Dataclass fields: both are required constructor arguments.
  InitialState: State
  CurrentState: State

  @abstractmethod
  def TransitionModel(self, State: State, Action) -> State:
      """Abstract: map a ``(State, Action)`` pair to a ``State``.

      NOTE(review): presumably the successor state of taking ``Action`` in
      ``State`` -- confirm against concrete subclasses.
      """

  @abstractmethod
  def RewardModel(self, State: State, Action, NextState: State, TerminalSignal: bool) -> float:
      """Abstract: return a ``float`` for the given transition.

      The transition is described by ``State``, ``Action``, ``NextState`` and
      the boolean ``TerminalSignal``.
      """

  @abstractmethod
  def IsTerminalCondition(self, State: State) -> bool:
      """Abstract: return a ``bool`` for the given ``State``."""

  @abstractmethod
  def StateTransition(self, State: State, Action) -> tuple[float, State, bool]:
      """Abstract: return a ``(float, State, bool)`` triple for ``(State, Action)``.

      NOTE(review): presumably ``(reward, next state, terminal flag)`` --
      confirm against concrete subclasses.
      """

  @abstractmethod
  def SampleTrajectory(self, RunDuration: float, Policy: Optional[Callable]) -> list[State]:
      """Abstract: return a list of ``State`` objects.

      Takes a ``RunDuration`` (float) and an optional ``Policy`` callable.
      """

In [None]:
@dataclass
class MPController(ABC):
  """Abstract dataclass for a controller bundling an environment, a model and a policy.

  The generated constructor takes ``EnvironmentModel``, ``InternalModel`` and
  ``Policy`` (in that order). Every method below is abstract, so only
  subclasses that implement all of them can be instantiated.
  """
  # Dataclass fields: all three are required constructor arguments.
  EnvironmentModel: Environment
  InternalModel: Callable
  Policy: Callable

  @abstractmethod
  def Act(self, Observation: T.Tensor) -> T.Tensor:
      """Produce an action based on the observation of the current state of the environment."""

  # The state type is spelled ``Environment.State``: ``EnvironmentModel`` is
  # only annotated in this class body (never assigned), so it is not a bound
  # name here and ``EnvironmentModel.State`` would raise NameError when the
  # annotation is evaluated.
  @abstractmethod
  def Plan(self) -> tuple[list[Environment.State], list[T.Tensor]]:
      """Produce a sequence of predicted states and a sequence of actions.

      Both are based on the observation of the current state of the
      environment. Returns ``(states, actions)``: a list of
      ``Environment.State`` and a list of tensors.
      """

  @abstractmethod
  def Observe(self) -> T.Tensor:
      """Produce a vector encoding the observable information of the current state of the environment."""

  @abstractmethod
  def Learn(self):
      """Improve the agent by updating its models."""

  @abstractmethod
  def LearningAlgorithm(self):
      """Abstract hook; its contract is not specified in this class.

      NOTE(review): presumably the update rule applied by ``Learn`` --
      confirm against concrete subclasses.
      """
