Reference: https://epubs.siam.org/doi/10.1137/20M1382386

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from itertools import chain
from tqdm import tqdm

In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Seed both random number generators so initial states, noise draws and
# network weights are reproducible across runs.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

Parameters

In [None]:
# Problem data. The rollouts below use x_{t+1} = A x_t + B u_t + w_t, with
# running cost weights Q_t (state) and R_t (control), terminal weight Q_T and
# diagonal noise covariance W.
A = torch.tensor([[0.5, 0.05, 0.1, 0.2], [0., 0.2, 0.3, 0.1], [0.06, 0.1, 0.2, 0.4], [0.05, 0.2, 0.15, 0.1]])
B = torch.tensor([[-0.05, -0.01], [-0.005, -0.01], [-1, -0.01], [-0.01, -0.9]])
Q_t = torch.tensor([[1., 0.5, -0.01, 0.], [-0.1, 1.1, 0.2, 0.], [0., 0.1, 0.9, -0.06], [0.03, 0., -0.1, 0.88]])
R_t = torch.tensor([[0.4, -0.2], [-0.3, 0.7]])
W = torch.diag(torch.tensor([0.1, 0.5, 0.2, 0.3]))
Q_T = Q_t  # terminal weight equals the running state weight

T = 10  # finite time horizon
d, k = B.shape  # state dimension (4) and control dimension (2)

# Block-diagonal weights, so the cost of a stacked trajectory of T states
# (or T controls) is a single quadratic form.
Q_t_block_diag = torch.block_diag(*[Q_t] * T)  # (d*T, d*T)
R_t_block_diag = torch.block_diag(*[R_t] * T)  # (k*T, k*T)

epochs = 1000
num_samples = 25000  # Monte-Carlo samples used to estimate the expectation

Initialization

In [None]:
def gen_init_state(size: int) -> torch.Tensor:
    """Sample ``size`` independent initial states x_0, stored column-wise.

    Each of the four coordinates is drawn independently from a normal
    distribution with means (5, 2, 8, 5) and variances (0.1, 0.3, 1.0, 0.5).

    Args:
        size: number of samples (columns) to draw.

    Returns:
        Float tensor of shape (4, size) on PyTorch's default device; column j
        is one initial state.
    """
    # (mean, variance) per coordinate. The draws happen in this order, so
    # results are reproducible under a fixed seed.
    moments = ((5.0, 0.1), (2.0, 0.3), (8.0, 1.0), (5.0, 0.5))
    rows = [
        torch.normal(mean=mean, std=var ** 0.5, size=(1, size))
        for mean, var in moments
    ]
    return torch.cat(rows, dim=0)

# Sanity check: one batch of initial states has shape (4, num_samples).
print(gen_init_state(size=num_samples).shape)

Noise $w_t \sim \mathcal{N}(0, W)$

In [None]:
# Per-coordinate noise standard deviations sqrt(W_ii), kept as a (d, 1)
# column so they broadcast across the sample axis.
w = W.diagonal().sqrt().unsqueeze(1)
print(w)

# One draw of w ~ N(0, W) per sample (W is diagonal), shape (d, num_samples).
# NOTE(review): this single realisation is added at every time step and every
# training epoch downstream; if w_t is meant to be i.i.d. per step, resample
# it inside the rollouts instead.
w_noise = torch.randn(d, num_samples)
w_stack = w * w_noise

Neural network for the control

In [None]:
class ControlNN(nn.Module):
    """Open-loop control network: maps an initial state to a whole control sequence.

    Input is a batch of initial states stored column-wise, shape
    (state_dim, batch). Output is the stacked controls u_0, ..., u_{horizon-1},
    shape (control_dim * horizon, batch); rows t*control_dim:(t+1)*control_dim
    hold u_t.

    Args:
        state_dim: size of x_0; defaults to the module-level ``d``.
        control_dim: size of each u_t; defaults to the module-level ``k``.
        horizon: number of control steps; defaults to the module-level ``T``.
        hidden: width of the hidden layers (previously hard-coded to 100).
        bound_output: if True, also apply tanh to the output layer, which
            confines every control entry to (-1, 1). This was the behaviour
            before the argument existed. Defaults to False because the exact
            LQR control u_t = -K_t x_t is linear in the state and unbounded.
            With initial-state means as large as 8, the optimal controls are
            not expected to stay inside (-1, 1); confirm against the Riccati
            solution. Reload checkpoints trained with the bounded head using
            bound_output=True.
    """

    def __init__(self, state_dim=None, control_dim=None, horizon=None, *,
                 hidden=100, bound_output=False):
        super().__init__()
        state_dim = d if state_dim is None else state_dim
        control_dim = k if control_dim is None else control_dim
        horizon = T if horizon is None else horizon

        # The attribute names are unchanged, so state_dict checkpoints saved
        # by earlier versions of this class still load.
        self.input_layer = nn.Linear(state_dim, hidden)
        self.hidden_layer1 = nn.Linear(hidden, hidden)
        self.hidden_layer2 = nn.Linear(hidden, hidden)
        self.output_layer = nn.Linear(hidden, control_dim * horizon)
        self.act = nn.Tanh()
        self.bound_output = bound_output

    def forward(self, x):
        """Map initial states (state_dim, batch) to controls (control_dim*horizon, batch)."""
        h = x.T  # nn.Linear acts on the last dimension: (batch, state_dim)
        h = self.act(self.input_layer(h))
        h = self.act(self.hidden_layer1(h))
        h = self.act(self.hidden_layer2(h))
        h = self.output_layer(h)
        if self.bound_output:
            h = self.act(h)
        return h.T


In [None]:
controlNN = ControlNN()
# One forward pass on a fresh CPU batch (the model is moved to `device` later),
# printed to check the output shape (k*T, num_samples).
control = controlNN(gen_init_state(size=num_samples))
print(control.shape)

Move tensors and model to `device` (GPU if available)

In [None]:
# Move the problem matrices, the fixed noise draw and the network to `device`.
A, B, Q_t, R_t, Q_T = (m.to(device) for m in (A, B, Q_t, R_t, Q_T))
Q_t_block_diag, R_t_block_diag = (m.to(device) for m in (Q_t_block_diag, R_t_block_diag))
w_stack = w_stack.to(device)
controlNN = controlNN.to(device)

Optimizer

In [None]:
# Adam over all network parameters with a fixed learning rate.
optim = torch.optim.Adam(params=controlNN.parameters(), lr=1e-4)

Training

In [None]:
epochs = 200_000  # overrides the value set in the parameters cell

In [None]:
# Fixed training batch of initial states, shape (d, num_samples).
x0 = gen_init_state(size=num_samples).to(device)

In [None]:
# Train the control network by directly minimising the sample average of
#   sum_t (x_t^T Q_t x_t + u_t^T R_t u_t) + x_T^T Q_T x_T
# over the fixed batch of initial states x0.
controlNN.train()

for epoch in tqdm(range(epochs)):
    optim.zero_grad()

    control = controlNN(x0)  # (k*T, num_samples); rows t*k:(t+1)*k hold u_t
    x = x0
    states = [x]  # x_0, ..., x_{T-1}, each (d, num_samples)
    for t in range(T - 1):
        u = control[t * k:(t + 1) * k, :]  # (k, num_samples)
        # NOTE(review): the same noise draw w_stack is added at every step and
        # every epoch; resample here if w_t is meant to be i.i.d. per step.
        x = A @ x + B @ u + w_stack  # (d, num_samples)
        states.append(x)
    state = torch.cat(states, dim=0)  # (d*T, num_samples)

    u = control[(T - 1) * k:T * k, :]
    x_terminal = A @ x + B @ u + w_stack  # terminal state x_T

    # Per-sample quadratic forms v_j^T M v_j, computed as column sums of
    # v * (M @ v). The form (v.T @ M @ v).diagonal() would materialise a
    # num_samples x num_samples matrix (about 2.5 GB in float32 for
    # num_samples = 25000) only to keep its diagonal.
    avg_loss_state = (state * (Q_t_block_diag @ state)).sum(dim=0).mean()
    avg_loss_control = (control * (R_t_block_diag @ control)).sum(dim=0).mean()
    avg_loss_terminal_state = (x_terminal * (Q_T @ x_terminal)).sum(dim=0).mean()

    loss = avg_loss_state + avg_loss_control + avg_loss_terminal_state
    if epoch % 5000 == 0:
        # Periodic checkpoint (overwrites the same file) and progress report.
        torch.save(controlNN.state_dict(), 'LQR.pt')
        print(f'epoch: {epoch}, objective: {loss.item()}')
    loss.backward()
    optim.step()

In [None]:
u_pred = control  # controls from the last training forward pass
print('shape of control:', u_pred.shape)
print('shape of state:', state.shape)

Exact LQR solution (backward Riccati recursion)

In [None]:
def get_sol_P(A_mat=None, B_mat=None, Q_run=None, Q_term=None, R_run=None, horizon=None):
    """Solve the finite-horizon discrete Riccati recursion backwards in time.

        P_T = Q_T
        P_t = Q + A^T P_{t+1} A - A^T P_{t+1} B (B^T P_{t+1} B + R)^{-1} B^T P_{t+1} A

    The weights are symmetrised first (M -> (M + M^T) / 2). The cost only sees
    x^T Q x and u^T R u, which depend on the symmetric parts alone, and the
    recursion above gives the minimiser only for symmetric weights. Q_t and
    R_t in this notebook are not symmetric.

    Every argument is optional and falls back to the notebook globals
    (A, B, Q_t, Q_T, R_t, T), so ``get_sol_P()`` keeps its original use.

    Returns:
        List [P_0, ..., P_T] of length horizon + 1.
    """
    A_mat = A if A_mat is None else A_mat
    B_mat = B if B_mat is None else B_mat
    Q_run = Q_t if Q_run is None else Q_run
    Q_term = Q_T if Q_term is None else Q_term
    R_run = R_t if R_run is None else R_run
    horizon = T if horizon is None else horizon

    Q_sym = 0.5 * (Q_run + Q_run.T)
    R_sym = 0.5 * (R_run + R_run.T)

    P_next = 0.5 * (Q_term + Q_term.T)  # terminal condition P_T
    P_list = [P_next]
    for _ in range(horizon):  # produces P_{T-1}, ..., P_0
        BtP = B_mat.T @ P_next
        # (B^T P B + R)^{-1} B^T P A via a linear solve, not an explicit inverse.
        gain = torch.linalg.solve(BtP @ B_mat + R_sym, BtP @ A_mat)
        # Keep the whole update in one expression. A leading '-' on an
        # unbracketed continuation line is parsed as a separate, discarded
        # statement, which silently drops the subtraction term.
        P_next = Q_sym + A_mat.T @ P_next @ A_mat - A_mat.T @ P_next @ B_mat @ gain
        P_list.append(P_next)

    P_list.reverse()  # time order: P_0, ..., P_T
    return P_list

# Riccati matrices P_0^*, ..., P_T^* in time order, as returned by get_sol_P.
P_sol = get_sol_P()

In [None]:
def get_sol_K(P, A_mat=None, B_mat=None, R_run=None):
    """Return the optimal feedback gains [K_0, ..., K_{T-1}] from [P_0, ..., P_T].

        K_t = (B^T P_{t+1} B + R)^{-1} B^T P_{t+1} A,   so that u_t = -K_t x_t.

    R and each P_{t+1} enter through their symmetric parts. The quadratic
    cost depends only on those, and the formula is the minimiser only for
    symmetric weights. R_t in this notebook is not symmetric.

    The input list ``P`` is not modified. The horizon is len(P) - 1.
    A_mat, B_mat and R_run default to the notebook globals A, B and R_t.
    """
    A_mat = A if A_mat is None else A_mat
    B_mat = B if B_mat is None else B_mat
    R_run = R_t if R_run is None else R_run
    R_sym = 0.5 * (R_run + R_run.T)

    K_list = []
    for P_next in P[1:]:  # P_{t+1} for t = 0, ..., T-1
        P_next = 0.5 * (P_next + P_next.T)
        BtP = B_mat.T @ P_next
        # Linear solve instead of forming the explicit inverse.
        K_list.append(torch.linalg.solve(BtP @ B_mat + R_sym, BtP @ A_mat))
    return K_list

# Optimal feedback gains K_0^*, ..., K_{T-1}^*.
K_sol = get_sol_K(P_sol)

In [None]:
def get_sol_u(K, x_0, A_mat=None, B_mat=None, noise=None):
    """Roll out the linear feedback u_t = -K_t x_t from x_0 and return the controls.

    The dynamics are x_{t+1} = A x_t + B u_t + noise, with the same ``noise``
    tensor added at every step, as in the training rollout.

    Args:
        K: gains [K_0, ..., K_{T-1}], each of shape (k, d).
        x_0: initial states, shape (d, batch).
        A_mat, B_mat, noise: default to the notebook globals A, B and w_stack.

    Returns:
        Tensor of shape (k * len(K), batch); rows t*k:(t+1)*k hold u_t.
        The result is shaped from K and x_0 only, so it does not depend on any
        training-time tensors.
    """
    A_mat = A if A_mat is None else A_mat
    B_mat = B if B_mat is None else B_mat
    noise = w_stack if noise is None else noise

    if not K:
        return x_0.new_zeros((0, x_0.shape[1]))

    x_t = x_0
    controls = []
    for K_t in K:
        u_t = -K_t @ x_t
        controls.append(u_t)
        x_t = A_mat @ x_t + B_mat @ u_t + noise
    return torch.cat(controls, dim=0)
    
# Optimal controls for the training batch x0.
u_sol = get_sol_u(K_sol, x0)

In [None]:
# torch.norm(control-u_sol, p=2)/num_samples
# torch.norm(u_pred - u_sol, p='fro')/num_samples
# loss_tr = loss_fn(u_pred, u_sol)

# print(f'Test loss: {loss_test.item()}')

Test

In [None]:
# Switch the network out of training mode (equivalent to .eval()).
controlNN.train(False)

In [None]:
# Fresh batch of initial states for evaluation, shape (d, num_samples).
x0_test = gen_init_state(size=num_samples).to(device)

In [None]:
# Evaluation only: autograd is disabled, so no graph is recorded.
with torch.set_grad_enabled(False):
    u_pred_test = controlNN(x0_test)


In [None]:
# Exact LQR controls for the evaluation batch: Riccati matrices, then gains,
# then the closed-loop rollout from x0_test.
P_sol_test = get_sol_P()
K_sol_test = get_sol_K(P_sol_test)
u_sol_test = get_sol_u(K_sol_test, x0_test)

In [None]:
loss_fn = nn.MSELoss(reduction='mean')  # element-wise mean squared error

In [None]:
# Mean squared error between the network's controls and the exact LQR
# controls on the evaluation batch.
loss_test = loss_fn(u_pred_test, u_sol_test)
print(f'Test loss: {loss_test.item()}')