In [12]:
import numpy as np
import numpy.random as rnd
from scipy.linalg import expm
from random import sample
import jax.numpy as jnp
from jax import jacobian
import jax
from jax import config
config.update("jax_enable_x64", True)

import scipy
import scipy.linalg as la

import time
import gc
import matplotlib.pyplot as plt
# Global matplotlib font sizes shared by every figure in this notebook
plt.rcParams['axes.titlesize'] = 20    # axes title
plt.rcParams['axes.labelsize'] = 20    # x and y axis labels
plt.rcParams['xtick.labelsize'] = 16   # x tick labels
plt.rcParams['ytick.labelsize'] = 16   # y tick labels
plt.rcParams['legend.fontsize'] = 16   # legend entries

$\textbf{Auxiliary Functions:}$

In [13]:
def cart_pole_dynamics(state, control, m_c = 1.0, m_p = 1.0, l = 1.0, g = 1.0):
    """
    Compute the continuous-time dynamics of the cart-pole system.

    The two accelerations are obtained by solving the 2x2 linear system
    ``a1 @ [x_ddot, theta_ddot] = b1`` assembled below.

    Parameters:
    state (array): The state vector [x, x_dot, theta, theta_dot].
    control (array): The control vector [F], where F is the force applied to the cart.
    m_c (float): Mass parameter (presumably the cart mass).
    m_p (float): Mass parameter (presumably the pole mass).
    l (float): Length parameter (presumably the pole length).
    g (float): Gravity constant.

    Returns:
    array: The derivative of the state vector [x_dot, x_ddot, theta_dot, theta_ddot].
    """
    # Unpack the state vector (the cart position itself does not enter the dynamics)
    _, x_dot, theta, theta_dot = state

    # Unpack the control vector
    F = control[0]

    sin_theta = jnp.sin(theta)
    cos_theta = jnp.cos(theta)

    # Coupled equations of motion written as a1 @ [x_ddot, theta_ddot] = b1.
    # NOTE(review): the theta_dot**2 term carries no sin(theta) factor here; this
    # has no effect on the linearization at theta_dot = 0 -- confirm against the
    # intended model before using this function away from the origin.
    a1 = jnp.array([[m_c + m_p, m_p*l*cos_theta], [m_p*cos_theta, m_p*l]])
    b1 = jnp.array([[m_p*l*theta_dot**2+F], [m_p*g*sin_theta]])

    # Solve the linear system directly instead of forming the explicit inverse
    accel = jnp.linalg.solve(a1, b1)
    x_ddot = accel[0,0]
    theta_ddot = accel[1,0]
    return jnp.array([x_dot, x_ddot, theta_dot, theta_ddot])


def subspace_distance(Phi_1,Phi_2):
    """Return the spectral norm (largest singular value) of ``Phi_1.T @ Phi_2``."""
    cross_product = Phi_1.T @ Phi_2
    return la.norm(cross_product, ord=2)

def LQR(A,B,Q=None, R=None):
    """
    Compute the discrete-time LQR gain for the pair (A, B).

    Parameters:
    A (array): State matrix.
    B (array): Input matrix.
    Q (array, optional): State cost matrix; defaults to the identity of size B.shape[0].
    R (array, optional): Input cost matrix; defaults to the identity of size B.shape[1].

    Returns:
    array: Gain K = -(B'PB + R)^{-1} B'PA, where P solves the discrete algebraic
    Riccati equation. The sign is such that the closed loop is A + B @ K.
    """
    # `is None` rather than `not Q`: the truth value of a multi-element array is
    # ambiguous and raises ValueError.
    if Q is None:
        Q = np.eye(B.shape[0])
    if R is None:
        R = np.eye(B.shape[1])
    P = la.solve_discrete_are(A,B,Q,R)
    # Solve the linear system instead of forming the explicit inverse
    return -la.solve(B.T@P@B+R, B.T@P@A)

$\textbf{Generate multiple tasks:}$

In [14]:
H = 100 #Number of tasks

#System dimensions
dx = 4
du = 1

#Linearize the cartpole system around the origin
x_0 = np.zeros(4)
u_0 = np.array([0.0])

dt = 0.25
g = 1.0

#Nominal systems; each tuple is unpacked downstream as (m_c, m_p, ell)
params_nominal = [(0.4, 1.0, 1.0), (1.6, 1.3, 0.3), (1.3, 0.7, 0.65), (0.2, 0.055, 1.36), (0.2, 0.47, 1.825)]

#params_nominal = [(0.1, 0.5, 1.36), (1.6, 1.3, 0.3), (1.3, 0.7, 0.65), (0.2, 0.055, 1.36), (0.2, 0.47, 1.825)]

nsys = len(params_nominal)

#Seeded local generator so that Restart & Run All regenerates the same tasks
task_rng = np.random.default_rng(0)

#Cycle through the nominal systems to assign one to each task
params = [params_nominal[i % nsys] for i in range(H)]

#The first nsys tasks stay nominal; the others get a uniform perturbation in [0, 0.01)
for i in range(nsys,H):
    params[i] = tuple(np.array(params[i]) + 0.01*task_rng.random(3))

In [15]:
#Setting the same cost matrices for all tasks
Q = np.diag(np.array([1,1,1,1]))
R = np.eye(1)

#The Jacobian transforms do not depend on the task, so build them once
jac_state = jacobian(cart_pole_dynamics, 0)
jac_control = jacobian(cart_pole_dynamics, 1)

dynamics = []
K_star_tasks = []
K0_tasks = []
for param in params:
    m_c, m_p, ell = param
    #Forward-Euler discretization of the linearization around (x_0, u_0)
    Atemp = np.eye(4) + dt*np.array(jac_state(x_0, u_0, m_c , m_p, ell, g=1.0))
    Btemp = dt*np.array(jac_control(x_0, u_0, m_c, m_p, ell, g=1.0))
    #Computing the optimal controller for each generated cartpole system
    #(LQR's default identity costs coincide with Q and R above)
    K_star_tasks.append(LQR(Atemp,Btemp))
    #Scaling the task optimal controller to obtain a suboptimal initial controller
    K0_tasks.append(0.7*K_star_tasks[-1])
    dynamics.append((Atemp,Btemp))
    #Checking if K0h stabilizes the generated task
    assert np.max(np.abs(la.eigvals(Atemp+Btemp@K0_tasks[-1]))) < 0.999


#Generating the ground truth representation Phi_star
A1, B1 = dynamics[0]
Atemp = np.copy(A1); Btemp = np.copy(B1)

#Each basis vector is the flattened [A.T; B.T] of the first task with the four
#entries A[1,2], A[3,2], B[1,0], B[3,0] replaced by a unit vector (i < 4) or by
#zeros (i = 4).
bases = np.zeros((20,0))
for i in range(5):
    ei = np.eye(4,1,-i)
    #Index down to scalars: assigning a size-1 array to a single element is deprecated in NumPy
    Atemp[1,2] = ei[0,0]; Atemp[3,2] = ei[1,0]; Btemp[1,0] = ei[2,0]; Btemp[3,0] = ei[3,0]
    basis = np.vstack([Atemp.T, Btemp.T]).flatten()
    bases = np.hstack([bases, np.expand_dims(basis,1)])


#Orthonormal basis of the span of `bases` and of its orthogonal complement
U, S, Vh = la.svd(bases)
Phi_star = U[:,:bases.shape[1]]
Phi_star_perp = U[:,bases.shape[1]:]


#Perturb the last column of Phi_star towards the orthogonal complement so that
#the initial subspace distance equals init_dist
init_dist = 0.6
Phi_0 = np.copy(Phi_star)
Phi_0[:,-1] = Phi_0[:,-1] + np.sqrt(1/(1/init_dist**2 - 1))*Phi_star_perp[:,0]
Phi_0[:,-1] = Phi_0[:,-1]/la.norm(Phi_0[:,-1])



#Initial subspace distance
print(f"Initial subspace distance {subspace_distance(Phi_0, Phi_star_perp)}")

Initial subspace distance 0.5999999999999998


$\textbf{Auxiliary Functions:}$

In [23]:
def state_update(x,u,w,A,B):
    """
    Advance a batch of states by one step of the linear dynamics.

    Rows of x, u and w are individual realizations; the process noise w is
    scaled by sqrt(0.01). Returns the next states with the same row layout as x.
    """
    noise_scale = np.sqrt(0.01)
    next_as_columns = A@x.T + B@u.T + noise_scale*w.T
    return next_as_columns.T


def get_input(x, K, K_0, v, abort):
    """
    Compute one scalar input per realization.

    Realizations flagged in `abort` use the gain K_0 with no exploration noise;
    the others use K plus the exploration term v. The result has shape (N, 1),
    where N is the number of rows of x.
    """
    n_rows = x.shape[0]
    aborted = abort.reshape(n_rows,1)
    active = 1-aborted

    gains = aborted*K_0 + active*K
    feedback = jnp.sum(gains*x, axis=1)
    exploration = (v*active).squeeze()

    return (feedback + exploration).reshape(n_rows,1)


def step(carry, noise):
    """
    One simulation step for a batch of realizations (the body of jax.lax.scan).

    Parameters:
    carry (tuple): (x, abort, K, K_0, sigma, cost, zszs, zsxs, x_b, K_b, A, B).
    noise (array): Per-realization noise for this step; all columns but the last
        are the process noise, the last column (scaled by sigma) is the input
        exploration noise.

    Returns:
    tuple: (carry, ()) with the state, abort flags, accumulated cost and the
    accumulated statistics zszs / zsxs updated; the remaining entries pass through.
    """
    x, abort, K, K_0, sigma, cost, zszs, zsxs, x_b, K_b, A, B = carry

    # Abort flags are sticky: once the state norm or the gain norm reaches its
    # bound, the flag stays set.
    abort = jnp.maximum(abort, jnp.maximum(jnp.linalg.norm(x, axis=1) >= x_b, jnp.linalg.norm(K, axis=1) >= K_b))

    #process and input noise
    w = noise[:,:-1]; v = sigma*noise[:,-1:]

    u = get_input(x, K, K_0, v, abort)

    xnext = state_update(x,u,w,A,B)

    # Quadratic stage cost with identity weights
    cost = cost + jnp.linalg.norm(x, axis=1)**2 + jnp.linalg.norm(u,axis=1)**2

    # Accumulate the per-realization outer products z z' and z xnext',
    # the latter flattened row-major.
    z = jnp.hstack([x, u])
    zszs = zszs + jnp.expand_dims(z, axis = 1)*jnp.expand_dims(z, axis = 2)
    # Flattened size derived from the arrays instead of a hard-coded 20
    flat_dim = z.shape[1]*xnext.shape[1]
    zsxs = zsxs + (jnp.expand_dims(z, axis = 2)*jnp.expand_dims(xnext, axis = 1)).reshape(x.shape[0], flat_dim)

    carry = (xnext, abort, K, K_0, sigma, cost, zszs, zsxs, x_b, K_b, A, B)

    return carry, ()


def LS(Phi,  zszs, zsxs):
    """
    Least-squares estimate of (A, B) restricted to the column space of Phi.

    Parameters:
    Phi (array): Representation matrix with n_z*n_x rows.
    zszs (array): Accumulated z z' statistic, shape (n_z, n_z).
    zsxs (array): Accumulated, flattened z x_next' statistic, length n_z*n_x.

    Returns:
    tuple: (A_hat, B_hat), the transposes of the first n_x rows and of the
    remaining rows of the (n_z, n_x) reshape of ``Phi @ theta``.
    """
    zszs = np.asarray(zszs)
    zsxs = np.asarray(zsxs)
    # Dimensions derived from the inputs instead of hard-coded 5 and 4
    n_z = zszs.shape[0]
    n_x = Phi.shape[0] // n_z

    Lambda = Phi.T@np.kron(zszs, np.eye(n_x))@Phi
    rhs = Phi.T@zsxs
    # Solve the normal equations directly instead of forming the inverse
    theta = la.solve(Lambda, rhs)

    dyn_matrices = np.asarray(Phi@theta).reshape(n_z, n_x, order='C')
    return dyn_matrices[:n_x].T, dyn_matrices[n_x:].T



def LS_DFW_1(Phi,  zszs, zsxs):
    """
    Least-squares coefficients of one task in the coordinates of Phi.

    Solves ``(Phi' (zszs kron I) Phi) theta = Phi' zsxs`` and returns theta
    (one entry per column of Phi).
    """
    zszs = np.asarray(zszs)
    zsxs = np.asarray(zsxs)
    # Identity size derived from the inputs instead of a hard-coded 4
    n_x = Phi.shape[0] // zszs.shape[0]

    Lambda = Phi.T@np.kron(zszs, np.eye(n_x))@Phi
    rhs = Phi.T@zsxs
    # Solve the normal equations directly instead of forming the inverse
    return la.solve(Lambda, rhs)

def LS_DFW_2(zszs, zsxs):
    """
    Unrestricted least-squares solution for one task.

    Solves ``(zszs kron I) theta = zsxs`` and returns theta, which has the same
    length as zsxs.
    """
    zszs = np.asarray(zszs)
    zsxs = np.asarray(zsxs)
    # Identity size derived from the inputs instead of a hard-coded 4
    n_x = zsxs.shape[0] // zszs.shape[0]

    Lambda = np.kron(zszs, np.eye(n_x))
    # Solve the normal equations directly instead of forming the inverse
    return la.solve(Lambda, zsxs)


def DFW(Phi_hat, DFW_N, eta, H, zszs1, zsxs1, zszs2, zsxs2, Phi_star_perp):
    """
    Refine the representation estimate of every realization by gradient steps
    followed by re-orthonormalization.

    Parameters:
    Phi_hat (list): One representation matrix per realization; entries are
        replaced in place and the same list is returned.
    DFW_N (int): Number of gradient iterations per realization.
    eta (float): Step size (the gradient is scaled by eta / H).
    H (int): Number of tasks.
    zszs1, zsxs1 (list): Per-task statistics, indexed [task][realization], used
        for the task coefficients (LS_DFW_1).
    zszs2, zsxs2 (list): Per-task statistics, indexed [task][realization], used
        for the unrestricted least-squares targets (LS_DFW_2).
    Phi_star_perp: Not used by the update; kept for interface compatibility.

    Returns:
    list: Phi_hat with every entry updated.
    """
    N = len(Phi_hat)

    for i in range(N):

        # Statistics of realization i for every task
        zszsh1 = [zszs1[h][i] for h in range(H)]
        zsxsh1 = [zsxs1[h][i] for h in range(H)]

        # The unrestricted least-squares targets do not depend on Phi_hat, so
        # they are computed once per realization instead of once per iteration.
        lstsq_sols_hs = np.vstack([np.expand_dims(LS_DFW_2(zszs2[h][i], zsxs2[h][i]), 0) for h in range(H)])

        for n in range(DFW_N):

            # Task coefficients for the current representation, one row per task
            thetas = np.vstack([np.expand_dims(LS_DFW_1(Phi_hat[i], zszsh1[h], zsxsh1[h]), 0) for h in range(H)])

            grad_Phi = (thetas.T@(thetas@Phi_hat[i].T - lstsq_sols_hs)).T
            Phi_hat[i]= Phi_hat[i] - (eta/H)*grad_Phi

            #Get a orthonormal representation
            Phi_hat[i] = la.svd(Phi_hat[i],False)[0]

    return Phi_hat

def play_controller(init,all_noise):
    """
    Scan `step` over `all_noise` starting from the carry `init`.

    Returns the state, abort flags, accumulated cost and the zszs / zsxs
    statistics (entries 0, 1, 5, 6 and 7 of the final carry).
    """
    final_carry, _ = jax.lax.scan(step, init, all_noise)
    x, abort, _K, _K_0, _sigma, cost, zszs, zsxs = final_carry[:8]
    return x, abort, cost, zszs, zsxs

$\textbf{Shared Representation Certainty Equivalent Control with Continual Exploration:}$

In [24]:
def adaptive_control(Phi_0, Phi_star_perp, K0_tasks, tau_1, x_b, K_b, dx, du, exploration_sequence, DFW_N, H, dynamics, Q, R, eta, key):
    """
    Run the epoch-based multi-task adaptive controller and record the cost of task 0.

    In every epoch each task is rolled out for N parallel realizations with its
    current gain plus exploration noise. Each rollout chunk is divided into three
    consecutive segments whose statistics are accumulated separately: segments 1
    and 2 feed the representation update (DFW), segment 3 feeds the per-task
    least-squares fit (LS) from which a new LQR gain is computed.

    Parameters:
    Phi_0 (array): Initial representation, used as the estimate of every realization.
    Phi_star_perp (array): Used only to report subspace distances.
    K0_tasks (list): Initial gain per task; also the fallback gain after an abort.
    tau_1 (int): Length of the first epoch; the epoch end time doubles afterwards.
    x_b, K_b: State-norm and gain-norm bounds that trigger an abort in `step`.
    dx, du (int): State and input dimensions (`step` / `get_input` handle a single
        exploration column, so du is expected to be 1).
    exploration_sequence (sequence): One exploration level per epoch; its square
        root scales the input noise and its length sets the number of epochs.
    DFW_N (int), eta (float): Iteration count and step size forwarded to DFW.
    H (int): Number of tasks to simulate (the first H entries of `dynamics` and `K0_tasks`).
    dynamics (list): (A, B) pair per task.
    Q, R: Accepted for interface compatibility; not used (LQR runs with its defaults).
    key: JAX PRNG key from which all rollout noise is derived.

    Returns:
    tuple: (costs, times); costs[j] holds the accumulated cost of task 0 for each
    of the N realizations after times[j] steps.
    """
    print(f"Initial subspace distance: {round(subspace_distance(Phi_0,Phi_star_perp),3)}")

    N = 1000 #Number of realizations to compute the expected regret
    max_chunk = 1000 #Maximum number of steps simulated per scan call

    kfin = len(exploration_sequence) #number of epochs

    tau_ = 0
    tau = tau_1 #Initial epoch length

    #Per-task simulation state (JAX arrays are immutable, so sharing the initial arrays is safe)
    aborth = [jnp.zeros(N) for _ in range(H)]
    costh = [jnp.zeros(N) for _ in range(H)]
    xh = [jnp.zeros((N, dx)) for _ in range(H)]
    K_0h = [jnp.stack([K0_tasks[h].squeeze()]*N) for h in range(H)]
    Kh = [jnp.stack([K0_tasks[h].squeeze()]*N) for h in range(H)]

    #Representation (one estimate per realization; DFW replaces the entries)
    Phi_hat = [Phi_0 for _ in range(N)]

    T = 0

    costs = []
    times = []

    costs.append(costh[0]) #We will evaluate the cost of the first task
    times.append(T)

    for k in range(kfin):
        start_time = time.time()

        sigma = jnp.sqrt(exploration_sequence[k])

        zszs = jnp.zeros((N, dx+du, dx+du))
        zsxs = jnp.zeros((N, dx*(dx+du)))

        #One independent list per data split; aliasing a single list would make
        #the three splits overwrite each other.
        zszsh1 = [zszs]*H; zsxsh1 = [zsxs]*H
        zszsh2 = [zszs]*H; zsxsh2 = [zsxs]*H
        zszsh3 = [zszs]*H; zsxsh3 = [zsxs]*H
        splits = ((zszsh1, zsxsh1), (zszsh2, zsxsh2), (zszsh3, zsxsh3))

        timestep = 0
        while timestep < tau - tau_:
            temp = min(max_chunk, tau-tau_ - timestep)
            timestep += temp

            T += temp

            #Split the data for the DFW and LS; the last segment takes the
            #remainder so that exactly `temp` steps are simulated.
            temp1 = int(3*temp/8); temp2 = int(3*temp/8)
            temp3 = temp - temp1 - temp2
            split_lengths = (temp1, temp2, temp3)

            for h in range(H):
                #Fresh, independent noise for every task, chunk and segment
                key, *noise_keys = jax.random.split(key, 4)

                for n_steps, (zszs_split, zsxs_split), noise_key in zip(split_lengths, splits, noise_keys):
                    all_noise = jax.random.normal(noise_key, (n_steps, N, dx+du))
                    init = (xh[h], aborth[h], Kh[h], K_0h[h], sigma, costh[h], zszs_split[h], zsxs_split[h], x_b, K_b, dynamics[h][0], dynamics[h][1])
                    xh[h], aborth[h], costh[h], zszs_split[h], zsxs_split[h] = play_controller(init, all_noise)

            #free up memory
            gc.collect()

            costs.append(costh[0])
            times.append(T)

        #DFW - Learning the representation with the multi-task dataset
        #Convert once to NumPy so that DFW does not index JAX arrays element by element
        Phi_hat = DFW(Phi_hat, DFW_N, eta, H,
                      [np.asarray(s) for s in zszsh1], [np.asarray(s) for s in zsxsh1],
                      [np.asarray(s) for s in zszsh2], [np.asarray(s) for s in zsxsh2],
                      Phi_star_perp)

        average_dist_error = 0
        for i in range(N):
            average_dist_error+= subspace_distance(Phi_hat[i],Phi_star_perp)

        average_est_error = 0 #for the first task
        Kh = []
        for h in range(H):
            aborted = np.asarray(aborth[h])
            K_fallback = np.asarray(K_0h[h])
            zszs_fit = np.asarray(zszsh3[h])
            zsxs_fit = np.asarray(zsxsh3[h])

            K = []
            for i in range(N):
                if not aborted[i]:
                    #Certainty-equivalent gain from the least-squares estimate
                    Ahat_h , Bhat_h  = LS(Phi_hat[i], zszs_fit[i], zsxs_fit[i])

                    if h == 0:
                        average_est_error +=  la.norm(Ahat_h - dynamics[0][0], 'fro')**2 + la.norm(Bhat_h - dynamics[0][1], 'fro')**2
                    Knew = LQR(Ahat_h,Bhat_h)

                else:
                    #Aborted realizations keep the initial gain
                    Knew = K_fallback[i]

                K.append(Knew.squeeze())

            Kh.append(np.stack(K))

        end_time = time.time()

        print(f"Epoch: {k+1}, time: {round(end_time - start_time,2)}, avg dist error: {round(average_dist_error/N,3)}, avg est error: {round(average_est_error/N,5)}, # aborts: {int(jnp.sum(aborth[0]))}, cost: {int(jnp.mean(costh[0]))}")

        tau_ = tau
        tau = 2*tau

    return costs, times

In [25]:
#Algorithm parameters

x_b = 50                #state bound
K_b = 15                #controller bound
tau_1 = int(4 * 1024)   #initial epoch length
kfin = 8                #number of epochs
DFW_N = 100             #number of DFW iterations
eta = 0.05              #step-size for the DFW iterations

#Root PRNG key handed to adaptive_control
key = jax.random.PRNGKey(0)

#One exploration level per epoch: 1/sqrt(2^k)
exploration_sequence = [1 / jnp.sqrt(2 ** k) for k in range(kfin)]

In [None]:
#Multi-task experiment
H = 10

costs, times = adaptive_control(Phi_0, Phi_star_perp, K0_tasks, tau_1, x_b, K_b, dx, du, exploration_sequence, DFW_N, H, dynamics, Q, R, eta, key)

#Riccati solution of the first task, used for the baseline cost subtracted below
A_1 = dynamics[0][0]
B_1 = dynamics[0][1]
P = la.solve_discrete_are(A_1,B_1,Q,R)

#Mean and 20% / 80% quantiles of the regret at every recorded time
regret_mean = [np.mean(cost - 0.01*T*np.trace(P)) for (cost,T) in zip(costs,times)]
regret_quantiles = [np.quantile(cost - 0.01*T*np.trace(P), [0.2, 0.8]) for (cost,T) in zip(costs,times)]

regret_quantiles_1 = [lower for (lower, _) in regret_quantiles]
regret_quantiles_2 = [upper for (_, upper) in regret_quantiles]

Initial subspace distance: 0.6


In [None]:
#Per-realization regret of the first task: one row per recorded time, one column
#per realization (the plotting cell averages over axis 1).
regret_points = np.array([np.asarray(cost) - 0.01*T*np.trace(P) for (cost,T) in zip(costs,times)])
np.save('regret_multi.npy', regret_points)
np.save('times_multi.npy', times)

In [None]:
#Single-task experiment
H = 1

costs, times = adaptive_control(Phi_0, Phi_star_perp, K0_tasks, tau_1, x_b, K_b, dx, du, exploration_sequence, DFW_N, H, dynamics, Q, R, eta, key)

#Riccati solution of the first task, used for the baseline cost subtracted below
A_1 = dynamics[0][0]
B_1 = dynamics[0][1]
P = la.solve_discrete_are(A_1,B_1,Q,R)

#Mean and 20% / 80% quantiles of the regret at every recorded time
regret_mean = [np.mean(cost - 0.01*T*np.trace(P)) for (cost,T) in zip(costs,times)]
regret_quantiles = [np.quantile(cost - 0.01*T*np.trace(P), [0.2, 0.8]) for (cost,T) in zip(costs,times)]

regret_quantiles_1 = [lower for (lower, _) in regret_quantiles]
regret_quantiles_2 = [upper for (_, upper) in regret_quantiles]

In [None]:
#Per-realization regret of the first task: one row per recorded time, one column
#per realization (the plotting cell averages over axis 1).
regret_points = np.array([np.asarray(cost) - 0.01*T*np.trace(P) for (cost,T) in zip(costs,times)])
np.save('regret_single.npy', regret_points)
np.save('times_single.npy', times)

$\textbf{Comparison between multi-task and single task:}$

In [None]:
regret_single = np.load('regret_single.npy')
regret_multi = np.load('regret_multi.npy')

#Load the time axes from disk so that this cell does not rely on leftover kernel state
times_single = np.load('times_single.npy')
times_multi = np.load('times_multi.npy')

#Number of tasks of the multi-task run; the global H equals 1 after the single-task cell
H_multi = 10


def _mean_and_ci95(data):
    """Return the mean over realizations (axis 1) and the 95% confidence half-width."""
    mean = np.mean(data, axis=1)
    # Standard error of the mean: divide by sqrt(number of realizations), i.e. axis 1
    sem = np.std(data, axis=1) / np.sqrt(data.shape[1])
    # 1.96 is the two-sided 95% quantile of the standard normal
    return mean, 1.96*sem


#Regret single task
mean_1, ci_1 = _mean_and_ci95(regret_single)

#Regret multi-task
mean_2, ci_2 = _mean_and_ci95(regret_multi)


fig, ax = plt.subplots(figsize=(16,10))

line_1, = ax.plot(times_single, mean_1, label='Single task (H = 1)', color='red')
line_2, = ax.plot(times_multi, mean_2, label=f"Multi-task (H = {H_multi})", color='blue')

ax.fill_between(times_single, mean_1-ci_1, mean_1+ci_1, color='red', alpha=0.2)
ax.fill_between(times_multi, mean_2-ci_2, mean_2+ci_2, color='blue', alpha=0.2)

ax.set_ylabel(r'Regret',fontsize=30)
ax.set_xlabel('T',fontsize=30)
ax.tick_params(axis='both', labelsize=25)

#Only the two mean curves appear in the legend
ax.legend(handles=[line_1, line_2], fontsize=20)
fig.subplots_adjust(bottom=0.20)
ax.grid()