In [24]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure
import jax_dataclasses as jdc
import chex
from functools import partial
from typing import Callable
import os
from exciting_environments import PMSM_Physical,PMSM

In [25]:
random_seed_DQN = np.random.randint(0, 2**31)
random_key_DQN = jax.random.PRNGKey(seed=random_seed_DQN)
random_seed_motor = np.random.randint(0, 2**31)
random_key_motor = jax.random.PRNGKey(seed=random_seed_motor)

random_seed = np.random.randint(0, 2**31)
random_key = jax.random.PRNGKey(seed=random_seed)

In [26]:
new_motor = PMSM_Physical(control_state="currents", deadtime=1,batch_size=1,saturated=True)

In [27]:
new_motor_env = PMSM(new_motor, gamma=0.85,batch_size=1)
#obs, motor_env_state = new_motor_env.reset(random_key_motor)

### Env2 (Darius)

In [28]:
import jax
import jax.numpy as jnp

t32 = jnp.array([   [1, 0], 
                    [-0.5, 0.5 * jnp.sqrt(3)],  # only for alpha/beta -> abc
                    [-0.5, -0.5 * jnp.sqrt(3)]])
t23 = 2/3 * jnp.array([ [1, 0], 
                        [-0.5, 0.5 * jnp.sqrt(3)],  # only for abc -> alpha/beta
                        [-0.5, -0.5 * jnp.sqrt(3)]]).T

inverter_t_abc = jnp.array([[-0.5, -0.5, -0.5],
                            [0.5, -0.5, -0.5],
                            [0.5, 0.5, -0.5],
                            [-0.5, 0.5, -0.5],
                            [-0.5, 0.5, 0.5],
                            [-0.5, -0.5, 0.5],
                            [0.5, -0.5, 0.5],
                            [0.5, 0.5, 0.5]])

def t_dq_alpha_beta(eps):
    cos = jnp.cos(eps)
    sin = jnp.sin(eps)
    return jnp.column_stack((cos, sin, -sin, cos)).reshape(2, 2)

def dq2abc(u_dq, eps):
    u_abc = t32 @ dq2albet(u_dq, eps).T
    return u_abc.T

def dq2albet(u_dq, eps):
    q = t_dq_alpha_beta(-eps)
    u_alpha_beta = q @ u_dq.T

    return u_alpha_beta.T

def albet2dq(u_albet, eps):
    q_inv = t_dq_alpha_beta(eps)
    u_dq =  q_inv @ u_albet.T

    return u_dq.T

def abc2dq(u_abc, eps):
    u_alpha_beta =  t23 @ u_abc.T
    u_dq = albet2dq(u_alpha_beta.T, eps)
    return u_dq

def step_eps(eps, omega_el, tau, tau_scale=1.):
    eps += omega_el * tau * tau_scale
    eps %= (2*jnp.pi)
    boolean = eps > jnp.pi
    summation_mask = boolean * -2*jnp.pi
    eps = eps + summation_mask
    return eps
    
def clip_in_abc_coordinates(u_dq, u_dc, omega_el, eps, tau):
    eps_advanced = step_eps(eps,omega_el,tau, 0.5)
    u_abc = dq2abc(u_dq, eps_advanced)
    # clip in abc coordinates
    u_abc = jnp.clip(u_abc,-u_dc/2.0, u_dc/2.0)
    u_dq = abc2dq(u_abc, eps)
    return u_dq

def switching_state_to_dq(switching_state, u_dc, eps):
    u_abc = inverter_t_abc[switching_state] * u_dc
    u_dq = abc2dq(u_abc, eps)
    return u_dq[0]

def currents_to_torque(i_d, i_q, p, psi_p, l_d, l_q ):
    torque = 1.5 * p * (psi_p + (l_d - l_q) * i_d) * i_q
    return torque

def calc_max_torque(l_d, l_q, i_n, psi_p, p):
    i_d = jnp.where(l_d == l_q, 0, -psi_p / (4 * (l_d - l_q)) - jnp.sqrt((psi_p / (4 * (l_d - l_q)))**2 + i_n**2 / 2))
    i_q = jnp.sqrt(i_n**2 - i_d**2)
    max_torque = 1.5*p*(psi_p + (l_d-l_q)*i_d)*i_q
    return max_torque


class PMSM_Physical:
    def __init__(self, 
                 params: dict = {
                                "p": jnp.array([3]),
                                "r_s": jnp.array([1.]),
                                "l_d": jnp.array([0.37e-3]),
                                "l_q": jnp.array([1.2e-3]),
                                "psi_p": jnp.array([65.6e-3]),
                                "u_dc":jnp.array([400.]),
                                "i_n": jnp.array([400.]),
                                "omega_el": jnp.array([100/60*2*jnp.pi]),
                                "tau": jnp.array([1e-4])},
                 deadtime: int = 1,
                 control_state = "torque",
                 control_set = "fcs"):
        self.switching_state_to_dq_vmap = jax.vmap(switching_state_to_dq)
        self.ode_exp_euler_step_vmap = jax.vmap(self.ode_exp_euler_step)
        self.clip_in_abc_coordinates_vmap = jax.vmap(clip_in_abc_coordinates)
        self.step_eps_vmap = jax.vmap(step_eps)

        self.control_state = control_state
        self.control_set = control_set
        self.params = params #TODO: In the future, params will be part of the state because they can change over time
        self.batch_size = params["p"].shape[0]

        if control_set == "ccs":
            self._action_description = ["u_d", "u_q"]
        elif control_set == "fcs":
            self._action_description = ["switching_state"]

        self.deadtime = deadtime
        if self.deadtime > 0:
            if control_set == "ccs":
                self.initial_action_buffer = jnp.zeros((self.batch_size, self.deadtime, 2))
            elif control_set == "fcs":
                initial_switching_state = jnp.zeros((self.batch_size, 1), dtype=int)
                initial_eps = jnp.zeros((self.batch_size, 1))
                initial_u_dq = self.switching_state_to_dq_vmap(initial_switching_state, self.params["u_dc"], initial_eps)
                initial_u_dq_expanded = initial_u_dq[:, None, :]  # Reshape to (8, 1, 2)
                self.initial_action_buffer = jnp.tile(initial_u_dq_expanded, (1, self.deadtime, 1))
        else:
            self.initial_action_buffer = None
    
    def reset(self):
        return {
            "action_buffer": self.initial_action_buffer,
            "epsilon": jnp.zeros((self.batch_size, 1)),
            "i_d": jnp.zeros((self.batch_size, 1)),
            "i_q": jnp.zeros((self.batch_size, 1)),
            "torque": jnp.zeros((self.batch_size, 1))
        }
    
    def ode_exp_euler_step(self, system_state, u_dq, params):
        
        u_d = u_dq[0]
        u_q = u_dq[1]

        omega_el = system_state["omega"]
        i_d = system_state["i_d"]
        i_q = system_state["i_q"]
        eps = system_state["epsilon"]

        tau = params["tau"]
        l_d = params["l_d"]
        l_q = params["l_q"]
        psi_p = params["psi_p"]
        r_s = params["r_s"]
        p = params["p"]

        # ODE
        i_d_diff = (u_d + omega_el*l_q*i_q - r_s*i_d) / l_d
        i_q_diff = (u_q - omega_el*(l_d*i_d + psi_p) - r_s*i_q) / l_q

        next_system_state = system_state.copy()
        next_system_state.update({
            "epsilon": step_eps(eps, omega_el, tau, 1.),
            "i_d": i_d + tau * i_d_diff,
            "i_q": i_q + tau * i_q_diff,
            "torque": currents_to_torque(i_d, i_q, p, psi_p, l_d, l_q),
        })

        return next_system_state
    
    def simulation_step(self, system_state, action):
        action_buffer = system_state["action_buffer"]
        eps = system_state["epsilon"]

        if self.deadtime > 0:


            if self.control_set == "fcs":
                advanced_eps = self.step_eps_vmap(eps, 
                                system_state["omega"], 
                                self.params["tau"], 
                                tau_scale=jnp.tile(jnp.array([self.deadtime]), (self.batch_size, 1)))

                future_u_dq = self.switching_state_to_dq_vmap(action,
                                                    self.params["u_dc"],
                                                    advanced_eps)
            else:
                future_u_dq = action

            updated_buffer = jnp.concatenate([action_buffer[:, 1:, :], future_u_dq[:,None,:]], axis=1)
            u_dq = action_buffer[:,0,:]
        else:
            updated_buffer = action_buffer
            if self.control_set == "fcs":
                u_dq = self.switching_state_to_dq_vmap(action,
                                                    self.params["u_dc"],
                                                    eps)
            else:
                u_dq = action

        if self.control_set == "ccs":
            u_dq = self.clip_in_abc_coordinates_vmap(
                u_dq = u_dq,
                u_dc = self.params["u_dc"],
                omega_el=system_state["omega"],
                eps = system_state["epsilon"],
                tau = self.params["tau"],
            )

        next_system_state = self.ode_exp_euler_step_vmap(system_state, u_dq, self.params)
        next_system_state.update({"action_buffer": updated_buffer})            
        return next_system_state
    
    @property
    def action_description(self):
        return self._action_description

class PMSM_Env:
    def __init__(self,
                 pmsm: PMSM_Physical,
                 gamma: float,
                 p_omega: float = 0.00005,
                 p_reference: float = 0.0002,
                 p_reset: float = 1.,
                 i_lim_multiplier: float = 1.2,
                 omega_ramp_min: int = 20000,
                 omega_ramp_max: int = 25000
                 ):
        self.pmsm = pmsm
        self.gamma = gamma
        self.p_omega = p_omega
        self.p_reference = p_reference
        self.p_reset = p_reset 
        self.batch_size = pmsm.batch_size
        self.deadtime = self.pmsm.deadtime
        self.update_omegas_vmap = jax.vmap(self.update_omegas, in_axes=(0,0,0,0,None))
        self.update_reference_vmap = jax.vmap(self.update_reference, in_axes=(0,0,None))
        self.generate_observation_vmap = jax.vmap(self.generate_observation)
        self.calculate_reward_vmap = jax.vmap(self.calculate_reward)
        self.i_lim_multiplier = i_lim_multiplier
        self.omega_ramp_min = omega_ramp_min
        self.omega_ramp_max = omega_ramp_max

        max_torque = jax.vmap(calc_max_torque)(l_d = self.pmsm.params["l_d"],
                                    l_q = self.pmsm.params["l_q"],
                                    i_n = self.pmsm.params["i_n"],
                                    psi_p = self.pmsm.params["psi_p"],
                                    p = self.pmsm.params["p"])
            
        self.state_normalizer = {
                "i_d": self.pmsm.params["i_n"]*self.i_lim_multiplier,
                "i_q": self.pmsm.params["i_n"]*self.i_lim_multiplier,
                "omega": self.pmsm.params["omega_el"],
                "torque": max_torque
            }
        
        self.action_normalizer = 2*self.pmsm.params["u_dc"]/3
        #self.action_normalizer = self.pmsm.params["u_dc"]/2 #TODO: Remove

        if self.control_state == "currents":
            self._obs_description = ["i_d", "i_q", "cos_eps", "sin_eps", "omega_el", "i_d_ref", "i_q_ref"]
            
        elif self.control_state == "torque":
            self._obs_description = ["i_d", "i_q", "cos_eps", "sin_eps", "omega_el", "torque_ref"]

    def reset(self, random_key):
        physical_state = self.pmsm.reset()

        #As the physical system is not actually updating omegas I will pretend they are part of the environment instead
        omegas = jnp.zeros((self.batch_size, 1))
        omegas_add = jnp.zeros((self.batch_size, 1))
        omegas_count = jnp.zeros((self.batch_size, 1))

        keys = jax.random.split(random_key, self.pmsm.batch_size)
        references = jnp.zeros((self.batch_size, 1))

        if self.control_state == "currents":
            i_d_ref, keys = self.update_reference_vmap(references, keys, self.p_reset)
            i_q_ref, keys = self.update_reference_vmap(references, keys, self.p_reset)
            references = jnp.hstack((i_d_ref, i_q_ref))
        else:
            references,keys = self.update_reference_vmap(references, keys, self.p_reset)

        omegas, omegas_add, omegas_count, keys = self.update_omegas_vmap(
                            omegas, 
                            omegas_add, 
                            omegas_count,
                            keys,
                            self.p_reset)
        
        physical_state.update({"omega": omegas*self.state_normalizer["omega"]})
        system_state = {
            "physical_state": physical_state,
            "omega_add": omegas_add,
            "omega_count": omegas_count,
            "keys": keys,
            "references": references,
            }
        
        observations = self.generate_observation_vmap(system_state, self.state_normalizer, self.action_normalizer)
        return observations, system_state

    def generate_observation(self,system_state, state_normalizer, action_normalizer):
        eps = system_state["physical_state"]["epsilon"]
        cos_eps = jnp.cos(eps)
        sin_eps = jnp.sin(eps)
        obs = jnp.hstack((
            system_state["physical_state"]["i_d"] / state_normalizer["i_d"],
            system_state["physical_state"]["i_q"] / state_normalizer["i_q"],
            system_state["physical_state"]["omega"] / state_normalizer["omega"],
            cos_eps,
            sin_eps,
            system_state["references"],
            system_state["physical_state"]["action_buffer"].reshape(-1) /  action_normalizer        
            ))
        return obs
 
    def update_reference(self, reference, key, p):
        random_bool = jax.random.bernoulli(key, p=p)
        key, subkey = jax.random.split(key)
        new_reference = jnp.where(random_bool, jax.random.uniform(subkey, minval=-1.0, maxval=1.0), reference)
        key, subkey = jax.random.split(subkey)
        return new_reference, subkey

    #TODO: Make omega ramp more realistic - otherwise RLS might not work
    def update_omegas(self,omegas, omegas_add, omegas_count, key, p):
        random_bool = jax.random.bernoulli(key, p=p)
        key, subkey = jax.random.split(key)

        # Add value to omegas
        omegas += omegas_add

        # If new target omega has been reached stop adding values in the future
        omegas_count = jnp.where(omegas_count > 0, omegas_count - 1, omegas_count)
        omegas_add = jnp.where(omegas_count == 0, 0., omegas_add)

        # Generate new omega targets and define the ramp
        key, subkey = jax.random.split(subkey)
        omegas_new = jnp.where(random_bool & (omegas_add == 0.),
                                    jax.random.uniform(subkey, minval=-1.0, maxval=1.0), 
                                    omegas)
        
        key, subkey = jax.random.split(subkey)
        omegas_count = jnp.where(omegas_new != omegas,
                                jax.random.choice(subkey, 
                                                  jnp.arange(self.omega_ramp_min, self.omega_ramp_max),
                                                  replace=True, 
                                                  axis=0),
                                omegas_count)
        
        omegas_add += jnp.where(omegas_new != omegas, (omegas_new - omegas) / omegas_count, 0.)

        key, subkey = jax.random.split(subkey)

        return omegas, omegas_add, omegas_count, subkey
    
    #@partial(jax.jit, static_argnums=0)
    def step(self, system_state, actions):
        if self.pmsm.control_set == "ccs":
            actions *= self.action_normalizer
        next_physical_state = self.pmsm.simulation_step(system_state["physical_state"],actions)
        omegas, omegas_add, omegas_count, keys = self.update_omegas_vmap(
                    system_state["physical_state"]["omega"] / self.state_normalizer["omega"], 
                    system_state["omega_add"], 
                    system_state["omega_count"],
                    system_state["keys"],
                    self.p_omega)
        #omegas = jnp.zeros((self.batch_size, 1)) + 0.2 #TODO: Remove
        if self.control_state == "currents":
            i_d_ref, keys = self.update_reference_vmap(system_state["references"][:,0].reshape(-1,1), keys, self.p_reference)
            i_q_ref, keys = self.update_reference_vmap(system_state["references"][:,1].reshape(-1,1), keys, self.p_reference)
            references = jnp.hstack((i_d_ref, i_q_ref))
        else:
            references,keys = self.update_reference_vmap(system_state["references"], keys, self.p_reference)

        next_physical_state.update({"omega": omegas * self.state_normalizer["omega"]})
        next_system_state = {
            "physical_state": next_physical_state,
            "omega_add": omegas_add,
            "omega_count": omegas_count,
            "keys": keys,
            "references": references,
            }

        observations = self.generate_observation_vmap(next_system_state, self.state_normalizer, self.action_normalizer)
        rewards = self.calculate_reward_vmap(next_physical_state, system_state["references"], self.state_normalizer)
        dones = self.identify_system_limit_violations(next_physical_state, self.state_normalizer)

        return next_system_state, observations, rewards, dones

    def identify_system_limit_violations(self, physical_state, state_normalizer):
        i_d_norm = physical_state["i_d"] / state_normalizer["i_d"]
        i_q_norm = physical_state["i_q"] / state_normalizer["i_q"]
        i_s = jnp.sqrt(i_d_norm**2 + i_q_norm**2)
        return jnp.where(i_s > 1, True, False)

    def calculate_reward(self, physical_state, references, state_nomalizer):
        if self.pmsm.control_state == "currents":
            reward = self.current_reward_func(
                physical_state["i_d"] / state_nomalizer["i_d"],
                physical_state["i_q"] / state_nomalizer["i_q"],
                references[0],
                references[1],
            )
        elif self.pmsm.control_state == "torque":
            reward = self.torque_reward_func(
                physical_state["i_d"] / state_nomalizer["i_d"],
                physical_state["i_q"] / state_nomalizer["i_q"],
                physical_state["torque"] / state_nomalizer["torque"],
                references
            )

        return reward

    def current_reward_func(self, i_d, i_q, i_d_ref, i_q_ref):
        mse = 0.5*(i_d - i_d_ref)**2 + 0.5*(i_q - i_q_ref)**2
        return -1*(mse * (1-self.gamma))

    def torque_reward_func(self, i_d, i_q, torque, torque_ref):
        i_s = jnp.sqrt(i_d**2 + i_q**2)
        i_n = 1/self.i_lim_multiplier
        i_d_plus = 0.2*i_n
        torque_tol = 0.01
        rew = jnp.zeros_like(torque_ref)
        rew = jnp.where(i_s > 1, -1*jnp.abs(i_s), rew)
        rew = jnp.where((i_s < 1.) & (i_s > i_n), 0.5*(1-(i_s - i_n)/(1 - i_n)) - 1, rew)
        rew = jnp.where((i_s < i_n) & (i_d > i_d_plus), -0.5*((i_d - i_d_plus)/(i_n - i_d_plus)), rew)
        rew = jnp.where((i_s < i_n) & (i_d < i_d_plus) & (jnp.abs(torque - torque_ref) > torque_tol), 0.5*(1- jnp.abs((torque_ref - torque)/2)), rew)
        rew = jnp.where((i_s < i_n) & (i_d < i_d_plus) & (jnp.abs(torque - torque_ref) < torque_tol), 1 - 0.5*i_s, rew)
        return rew * (1-self.gamma)
    
    @property
    def action_description(self):
        return self.pmsm._action_description
    @property
    def observation_description(self):
        return self._obs_description
    @property
    def control_state(self):
        return self.pmsm.control_state


In [29]:
new_motor2 = PMSM_Physical(control_state="currents", control_set="ccs", deadtime=1)

In [30]:
new_motor_env2 = PMSM_Env(new_motor2, gamma=0.85)
#obs, motor_env_state = new_motor_env.reset(random_key_motor)

### Comparison

In [31]:
obs, motor_env_state = new_motor_env.reset(random_key_motor)

In [32]:
obs2, motor_env_state2 = new_motor_env2.reset(random_key_motor)

In [33]:
motor_env_state.optional.references

Array([[ 0.22062254, -0.46504974]], dtype=float32)

In [34]:
motor_env_state2

{'physical_state': {'action_buffer': Array([[[0., 0.]]], dtype=float32),
  'epsilon': Array([[0.]], dtype=float32),
  'i_d': Array([[0.]], dtype=float32),
  'i_q': Array([[0.]], dtype=float32),
  'torque': Array([[0.]], dtype=float32),
  'omega': Array([[0.]], dtype=float32)},
 'omega_add': Array([[3.4061984e-05]], dtype=float32),
 'omega_count': Array([[23383.]], dtype=float32),
 'keys': Array([[1610414785,  728627179]], dtype=uint32),
 'references': Array([[ 0.22062254, -0.46504974]], dtype=float32)}

### Step comparison

In [47]:
next_state,obs_next,reward,done=jax.vmap(new_motor_env.step,in_axes=(0,0,new_motor_env.in_axes_env_properties))(next_state,jnp.array([[1,1]]),new_motor_env.env_properties)

Traced<ShapedArray(float32[2,1])>with<BatchTrace(level=1/0)> with
  val = Array([[[125350.805],
        [137448.14 ]]], dtype=float32)
  batch_dim = 0


In [50]:
obs_next

Array([[ 5.1705545e-01,  5.1705545e-01,  2.7249588e-04,  1.0000000e+00,
         9.9874944e-07,  2.2062254e-01, -4.6504974e-01,  1.0000000e+00,
         1.0000000e+00]], dtype=float32)

In [37]:
next_state2,obs_next2,reward2,done2=new_motor_env2.step(motor_env_state2,jnp.array([[0,1]]))

In [38]:
obs_next2

Array([[ 0.0000000e+00,  0.0000000e+00,  3.4061984e-05,  1.0000000e+00,
         0.0000000e+00,  2.2062254e-01, -4.6504974e-01,  0.0000000e+00,
         1.0000000e+00]], dtype=float32)

In [16]:
next_state

PMSM.States(physical_state=PMSM_Physical.PhysicalState(action_buffer=Array([[[  0.     , 266.66666]]], dtype=float32), epsilon=Array([0.], dtype=float32), i_d=Array([0.], dtype=float32), i_q=Array([0.], dtype=float32), torque=Array([0.], dtype=float32), omega=Array([-0.00032567], dtype=float32)), PRNGKey=Array([[ 520295942, 1650227225]], dtype=uint32), optional=PMSM.Optional(omega_add=Array([[-3.109891e-05]], dtype=float32), omega_count=Array([[24666.]], dtype=float32), references=Array([[-0.9234686 ,  0.00126839]], dtype=float32)))

In [17]:
next_state2

{'physical_state': {'action_buffer': Array([[[  0.     , 266.66666]]], dtype=float32),
  'epsilon': Array([[0.]], dtype=float32),
  'i_d': Array([[0.]], dtype=float32),
  'i_q': Array([[0.]], dtype=float32),
  'omega': Array([[-0.00032567]], dtype=float32),
  'torque': Array([[0.]], dtype=float32)},
 'omega_add': Array([[-3.109891e-05]], dtype=float32),
 'omega_count': Array([[24666.]], dtype=float32),
 'keys': Array([[ 520295942, 1650227225]], dtype=uint32),
 'references': Array([[-0.9234686 ,  0.00126839]], dtype=float32)}