In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import tqdm
from scipy import integrate
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

class DGM_Layer(nn.Module):
    
    def __init__(self, dim_x, dim_S, activation='Tanh'):
        super(DGM_Layer, self).__init__()
        
        if activation == 'ReLU':
            self.activation = nn.ReLU()
        elif activation == 'Tanh':
            self.activation = nn.Tanh()
        elif activation == 'Sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'LogSigmoid':
            self.activation = nn.LogSigmoid()
        else:
            raise ValueError("Unknown activation function {}".format(activation))
            

        self.gate_Z = self.layer(dim_x+dim_S, dim_S)
        self.gate_G = self.layer(dim_x+dim_S, dim_S)
        self.gate_R = self.layer(dim_x+dim_S, dim_S)
        self.gate_H = self.layer(dim_x+dim_S, dim_S)
            
    def layer(self, nIn, nOut):
        l = nn.Sequential(nn.Linear(nIn, nOut), self.activation)
        return l
    
    def forward(self, x, S):
        x_S = torch.cat([x,S],1)
        Z = self.gate_Z(x_S)
        G = self.gate_G(x_S)
        R = self.gate_R(x_S)
        
        input_gate_H = torch.cat([x, S*R],1)
        H = self.gate_H(input_gate_H)
        
        output = ((1-G))*H + Z*S
        return output


class Net_DGM(nn.Module):

    def __init__(self, dim_x, dim_S, activation='Tanh'):
        super(Net_DGM, self).__init__()

        self.dim = dim_x
        if activation == 'ReLU':
            self.activation = nn.ReLU()
        elif activation == 'Tanh':
            self.activation = nn.Tanh()
        elif activation == 'Sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'LogSigmoid':
            self.activation = nn.LogSigmoid()
        else:
            raise ValueError("Unknown activation function {}".format(activation))

        self.input_layer = nn.Sequential(nn.Linear(dim_x+1, dim_S), self.activation)

        self.DGM1 = DGM_Layer(dim_x=dim_x+1, dim_S=dim_S, activation=activation)
        self.DGM2 = DGM_Layer(dim_x=dim_x+1, dim_S=dim_S, activation=activation)
        self.DGM3 = DGM_Layer(dim_x=dim_x+1, dim_S=dim_S, activation=activation)

        self.output_layer = nn.Linear(dim_S, 1)

    def forward(self,t,x):
        tx = torch.cat([t,x], 1)
        S1 = self.input_layer(tx)
        S2 = self.DGM1(tx,S1)
        S3 = self.DGM2(tx,S2)
        S4 = self.DGM3(tx,S3)
        output = self.output_layer(S4)
        return output

class DGM_layer(nn.Module):
    def __init__(self, in_features, out_feature, residual = False):
        super(DGM_layer, self).__init__()
        self.residual = residual
        
        self.Z = nn.Linear(out_feature, out_feature); self.UZ = nn.Linear(in_features, out_feature, bias = False)
        self.G = nn.Linear(out_feature, out_feature); self.UG = nn.Linear(in_features, out_feature, bias = False)
        self.R = nn.Linear(out_feature, out_feature); self.UR = nn.Linear(in_features, out_feature, bias = False)
        self.H = nn.Linear(out_feature, out_feature); self.UH = nn.Linear(in_features, out_feature, bias = False)
        
    def forward(self, x, s):
        z = torch.tanh(self.UZ(x)+self.Z(s))
        g = torch.tanh(self.UG(x)+self.G(s))
        r = torch.tanh(self.UR(x)+self.R(s))
        h = torch.tanh(self.UH(x)+self.H(s))
        return (1-g)*h+z*s  
    
class DGM_Net(nn.Module):
    def __init__(self, in_dim, out_dim, n_layers, n_neurons, residual = False): 
        super(DGM_Net, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_layers = n_layers
        self.n_neurons = n_neurons
        self.residual = residual
    
        self.input_layer = nn.Linear(in_dim, n_neurons)
        self.dgm_layers = nn.ModuleList([DGM_layer(self.in_dim, self.n_neurons, self.residual) for i in range(self.n_layers)])
        self.output_layer = nn.Linear(n_neurons, out_dim)
    
    def forward(self, x):
        s = torch.tanh(self.input_layer(x))
        for i, dgm_layer in enumerate(self.dgm_layers):
            s = dgm_layer(x, s)
        return self.output_layer(s)
    
def get_gradient(output, x):
    grad = torch.autograd.grad(output, x, grad_outputs=torch.ones_like(output), create_graph=True, retain_graph=True, only_inputs=True)[0]
    return grad

def get_laplacian(grad, x):
    hess_diag = []
    for d in range(x.shape[1]):
        v = grad[:,d].view(-1,1)
        grad2 = torch.autograd.grad(v,x,grad_outputs=torch.ones_like(v), only_inputs=True, create_graph=True, retain_graph=True)[0]
        hess_diag.append(grad2[:,d].view(-1,1))    
    hess_diag = torch.cat(hess_diag,1)
    laplacian = hess_diag.sum(1, keepdim=True)
    return laplacian


In [2]:
class Bellman_pde():
    '''
    Approximating the Bellman PDE on [0,T]*[x1_l,x1_r]*[x2_l,x2_r]
    '''
    def __init__(self, net, x_interval, y_interval, H, M, C, D, R, T, sigma, a):
        self.net = net 
        self.x1_l = x_interval[0].item() # torch tensor, dim = 3
        self.x1_r = x_interval[1].item()
        self.x2_l = y_interval[0].item()
        self.x2_r = y_interval[1].item()
        self.H = H # H, M, C, D, R: torch tensors, dim = 2*2
        self.M = M 
        self.C = C 
        self.D = D 
        self.R = R         
        self.T = T # integer
        self.sigma = sigma # sigma, a: torch tensors, dim = 1*2
        self.a = a 
        
    def sample(self, size):
    
        t_x = torch.cat((torch.rand([size, 1])*self.T, (self.x1_l - self.x1_r) * torch.rand([size, 1]) + self.x1_r, (self.x2_l - self.x2_r) * torch.rand([size, 1]) + self.x2_r), dim=1)
        
        x_boundary = torch.cat((torch.ones(size, 1)*self.T, (self.x1_l - self.x1_r) * torch.rand([size, 1]) + self.x1_r, (self.x2_l - self.x2_r) * torch.rand([size, 1]) + self.x2_r), dim=1)     
        return t_x, x_boundary
    
    def mat_ext(self, mat, size):
        if mat.shape == torch.Size([2, 2]):
            return mat.unsqueeze(0).repeat(size,1,1)
        elif mat.shape == torch.Size([1, 2]):
            return mat.t().unsqueeze(0).repeat(size,1,1)
        
    def get_hessian(self, grad_x, x):
        hessian = torch.zeros(len(x),2,2)
        dxx = torch.autograd.grad(grad_x[0][:,1], x, grad_outputs=torch.ones_like(grad_x[0][:,1]), allow_unused=True, retain_graph=True)[0][:,1]
        dxy = torch.autograd.grad(grad_x[0][:,1], x, grad_outputs=torch.ones_like(grad_x[0][:,1]), allow_unused=True, retain_graph=True)[0][:,2]
        dyx = torch.autograd.grad(grad_x[0][:,2], x, grad_outputs=torch.ones_like(grad_x[0][:,2]), allow_unused=True, retain_graph=True)[0][:,1]
        dyy = torch.autograd.grad(grad_x[0][:,2], x, grad_outputs=torch.ones_like(grad_x[0][:,2]), allow_unused=True, retain_graph=True)[0][:,2]
        hessian[:,0,0] = dxx 
        hessian[:,0,1] = dxy
        hessian[:,1,0] = dyx
        hessian[:,1,1] = dyy
        return hessian  
        
        
    def loss_func(self, size):
        loss = nn.MSELoss() # MSE 
        
        # Extend the input matrices
        H = self.mat_ext(self.H, size) # H, M, C, D, R: dim = batchsize*2*2
        M = self.mat_ext(self.M, size)
        C = self.mat_ext(self.C, size)
        D = self.mat_ext(self.D, size)
        R = self.mat_ext(self.R, size) # control: dim = batchsize*2*1          
        T = self.T
        a = self.a
        sig = self.sigma.t()

        x, x_boundary = self.sample(size=size)
        x = x.requires_grad_(True) # Track gradients during automatic differentiation

        # gradients
        grad = torch.autograd.grad(self.net(x), x, grad_outputs=torch.ones_like(self.net(x)), create_graph=True)
        
        du_dt = grad[0][:,0].reshape(-1, 1)  # derivative w.r.t. time, dim = batchsize*1
        
        du_dx = grad[0][:,1:] # derivative w.r.t. space, dim = batchsize*2 
                
        # Hessian matrix
        hessian = self.get_hessian(grad,x)
        
        # Error from the equation
        sig2_ext = self.mat_ext(torch.matmul(sig,sig.t()), size) # dim = batchsize*2*2
        prod = torch.bmm(sig2_ext,hessian) # sigma*sigma^T*2nd derivatives
        trace = torch.diagonal(prod, dim1=1, dim2=2).sum(dim=1).unsqueeze(0).t() # trace, dim = batchsize*1
        x_space = x[:,1:].unsqueeze(1).reshape(size,2,1) # extract (x1,x2)^T, dim = batchsize*2*1
        x_space_t = x_space.reshape(size,1,2) # dim = batchsize*1*2
        du_dx_ext_t = du_dx.unsqueeze(1) # dim=batchsize*1*2
        
        pde = du_dt+0.5*trace+torch.bmm(du_dx_ext_t,torch.bmm(H,x_space)).squeeze(1)\
                +torch.bmm(du_dx_ext_t,torch.bmm(M,self.a)).squeeze(1)\
                +torch.bmm(x_space_t,torch.bmm(C,x_space)).squeeze(1)\
                +torch.bmm(a.reshape(size,1,2),torch.bmm(D,a)).squeeze(1) # dim = batchsize*1
 
        pde_err = loss(pde, torch.zeros(size,1))
        
        # Error from the boundary condition
        x_bound = x_boundary[:,1:].unsqueeze(1).reshape(size,2,1) # extract (x1,x2)^T, dim = batchsize*2*1
        x_bound_t = x_bound.reshape(size,1,2) # dim = batchsize*1*2
        
        boundary_err = loss(self.net(x_boundary), torch.bmm(x_bound_t,torch.bmm(R,x_bound)).squeeze(1))
        
        return pde_err + boundary_err


class Train():
    def __init__(self, net, PDE, BATCH_SIZE):
        self.errors = []
        self.BATCH_SIZE = BATCH_SIZE
        self.net = net
        self.model = PDE

    def train(self, epoch, lr):
        optimizer = optim.Adam(self.net.parameters(), lr) # Import the parameters, lr: learning rate
        avg_loss = 0
        for e in range(epoch):
            optimizer.zero_grad()
            loss = self.model.loss_func(self.BATCH_SIZE)
            avg_loss = avg_loss + float(loss.item())
            loss.backward()
            optimizer.step()
            if (e+1) % 100 == 0:
                loss = avg_loss/100
                print("epoch {} - lr {} - loss: {}".format(e, lr, loss))
                avg_loss = 0

                error = self.model.loss_func(self.BATCH_SIZE)
                self.errors.append(error.detach())

    def get_errors(self):
        return self.errors

In [3]:
class FFN(nn.Module):

    def __init__(self, sizes, activation=nn.ReLU, output_activation=nn.Identity, batch_norm=False):
        super().__init__()
        
        layers = [nn.BatchNorm1d(sizes[0]),] if batch_norm else []
        for j in range(len(sizes)-1):
            layers.append(nn.Linear(sizes[j], sizes[j+1]))
            if batch_norm:
                layers.append(nn.BatchNorm1d(sizes[j+1], affine=True))
            if j<(len(sizes)-2):
                layers.append(activation())
            else:
                layers.append(output_activation())

        self.net = nn.Sequential(*layers)

    def freeze(self):
        for p in self.parameters():
            p.requires_grad=False

    def unfreeze(self):
        for p in self.parameters():
            p.requires_grad=True

    def forward(self, x):
        return self.net(x)

In [4]:
class LQR:
    def __init__(self, H, M, D, C, R, sigma, T):
        self.H = H.double()
        self.M = M.double()
        self.D = D.double()
        self.C = C.double()
        self.R = R.double()
        self.sigma = sigma.reshape(1,-1).t() # reshape the input sigma as a 2*1 matrix
        self.T = T


    def riccati_ode(self, t, Q):
        if type(self.C) == torch.Tensor:
             self.C = self.C.numpy()

        # Rewrite the imput 1*4 vector as a 2*2 matrix
    
        Q_matrix = Q.reshape((2,2))
        
        # RHS of the ode
        quadratic_term = -np.linalg.multi_dot([Q_matrix,self.M,np.linalg.inv(self.D),self.M,Q_matrix])
        linear_term = 2*np.dot(np.transpose(self.H),Q_matrix)
        constant_term = self.C
        
        # Riccati ode in the matrix form
        dQ_dt_matrix = linear_term + quadratic_term + constant_term
        
        # Rewrite the matrix ode as a 1*4 vector
        dQ_dt = dQ_dt_matrix.reshape(4,)
        
        return dQ_dt

    
    def riccati_solver(self, time_grid):
        if type(time_grid) == torch.Tensor:
            time_grid = time_grid.numpy()

        Q_0 = self.R.reshape(4,) # initial condition: Q(0)=S(T)=R

        # Solving S(r) on [t,T] is equivalent to solving Q(r)=S(T-r) on [0,T-t] 
        time_grid_Q = np.flip(self.T-time_grid) 
        interval = np.array([time_grid_Q[0], time_grid_Q[-1]]) 
        sol = integrate.solve_ivp(self.riccati_ode, interval, Q_0, t_eval=time_grid_Q)

        t_val = self.T - sol.t # do the time-reversal to get the solution S(t)

        return np.flip(t_val), np.flip(sol.y)
        
    def riccati_plot(self, time_grid):
        sol_t, sol_y = self.riccati_solver(time_grid)
        plt.plot(sol_t,sol_y[0],label='S[0,0]',color='blue')
        plt.plot(sol_t,sol_y[1],label='S[0,1]',color='red')
        plt.plot(sol_t,sol_y[2],label='S[1,0]',color='yellow')
        plt.plot(sol_t,sol_y[3],label='S[1,1]',color='purple')

        plt.xlabel('time')
        plt.ylabel('S(t)')
        plt.legend(['S[0,0]','S[0,1]','S[1,0]','S[1,1]'])
        plt.show()

    def value_function(self, t, x):
        n = 500 # Fix the number of steps to be 500
        val_func = torch.zeros((len(x),1), dtype=torch.float64) 
        x = x.double()

        for j in range(len(x)):
            initial_time = t[j].double().item() 
            step = (self.T-initial_time)/n # step = (T-t)/n
            time_grid = torch.arange(initial_time, self.T+step, step) # generate the time grid on [t,T]
            t_val, S_r = self.riccati_solver(time_grid)   
            S_t = torch.tensor([[S_r[0,0], S_r[1,0]], [S_r[2,0], S_r[3,0]]]) 
            S_t = S_t.double()

            # Assuming sigma is 2x1
            sig = torch.matmul(self.sigma, self.sigma.t()) # 2*2 matrix
            sig = sig.double()
            integral = 0
            for i in range(len(t_val)-1):
                S_i = torch.tensor([[S_r[0,i], S_r[1,i]], [S_r[2,i], S_r[3,i]]])
                S_i_1 = torch.tensor([[S_r[0,i+1], S_r[1,i+1]], [S_r[2,i+1], S_r[3,i+1]]])
                difference = S_i_1-S_i
                integral += torch.trace(torch.matmul(sig,difference))*(t_val[i+1] - t_val[i])

            x_j = x[j].reshape(1,-1).t()
            x_j_t = x_j.t()
            val_func[j] = torch.linalg.multi_dot([x_j_t,S_t,x_j]) + integral

        return val_func
        

    def optimal_control(self, t, x):
        n = 500
        a = torch.zeros(len(x), 2)
        x = x.double()

        for i in range(len(x)):
            init_time = t[i].double().item() 
            step = (self.T-init_time)/n # step = (T-t)/n
            time_grid = torch.arange(init_time, self.T+step, step) # generate the time grid on [t,T]
            S_r = self.riccati_solver(time_grid)[1]
            S_t = torch.tensor([[S_r[0,0], S_r[1,0]], [S_r[2,0], S_r[3,0]]]) 
            S_t = S_t.double()
            x_i = x[i].reshape(1,-1).t() 

            # The product is 2*1, need to flatten it first before appending the value to a_star
            a[i] = -torch.flatten(torch.linalg.multi_dot([self.D,self.M.t(),S_t,x_i])) 
            
        return a

In [41]:
n_epochs = 1000
batch_size = 2

# Set parameters
H = torch.tensor([[1.0,0],[0,1.0]])
M = torch.tensor([[1.0,0],[0,1.0]])
D = torch.tensor([[0.1,0],[0,0.1]])
C = torch.tensor([[0.1,0],[0,0.1]])
R = torch.tensor([[1.0,0],[0,1.0]])
sigma = torch.tensor([[0.05, 0.05]])

# Create data t and x
T = 1.0
x_range = torch.tensor([-3, 3])
y_range = torch.tensor([-3, 3])
t = np.random.uniform(0, T, size=batch_size)
x = np.random.uniform(-3, 3, size=(batch_size, 1, 2))
t0 = torch.from_numpy(np.array([t]).T).float()
x0 = torch.from_numpy(x.reshape(batch_size, 2)).float()
tx = torch.cat([t0,x0], dim=1)

# Convert numpy to torch tensor
t = torch.from_numpy(t)
x = torch.from_numpy(x)

# Determine the value function for the samples of t and x
lqr_equation = LQR(H, M, D, C, R, sigma, T)
opt_control = lqr_equation.optimal_control(t ,x).float()
value_func = lqr_equation.value_function(t, x,).float()
 
# Input for FFN neural network (control function)
dim = [3,100,100,2] 
# Input for Net_DGM  neural network (value function)
value_dim_input = 2 
value_dim_hidden = 100

# Input for DGM_Net neural network (PDE)
dim_input = 3
dim_output = 1
num_layers = 3
num_neurons = 50
learning_rate = 0.001

# Initialize the control model, loss function, and Adam optimizer
control_model = FFN(sizes=dim)
# control_loss_fn = nn.MSELoss()
control_optimizer = optim.Adam(control_model.parameters(), lr=learning_rate)

# Initailize the value function model, loss function, and Adam optimizer
value_model = Net_DGM(value_dim_input, value_dim_hidden)
value_optimizer = optim.Adam(value_model.parameters(), lr=learning_rate)

alpha_pred = control_model(tx)
print(alpha_pred)

alpha_pred = alpha_pred.unsqueeze(1).reshape(batch_size,2,1).clone().detach()

print(alpha_pred.type())
net = DGM_Net(dim_input, dim_output, num_layers, num_neurons)
Bellman = Bellman_pde(net, x_range, y_range, H, M, C, D, R, T, sigma, alpha_pred)
train = Train(net, Bellman, BATCH_SIZE=batch_size)
train.train(epoch=n_epochs, lr=learning_rate)

tensor([[-0.1508,  0.0705],
        [-0.1754,  0.1189]], grad_fn=<AddmmBackward0>)
torch.FloatTensor
epoch 99 - lr 0.001 - loss: 19.595554146766663
epoch 199 - lr 0.001 - loss: 4.105608345270157
epoch 299 - lr 0.001 - loss: 3.2331372701376675
epoch 399 - lr 0.001 - loss: 2.1250549018383027
epoch 499 - lr 0.001 - loss: 2.038394311517477
epoch 599 - lr 0.001 - loss: 1.4185540039278566
epoch 699 - lr 0.001 - loss: 1.2136523062363267
epoch 799 - lr 0.001 - loss: 0.7073227990511805
epoch 899 - lr 0.001 - loss: 0.6233518862724304
epoch 999 - lr 0.001 - loss: 1.308500952757895
