In [10]:
import numpy as np
import torch
import torch.nn as nn
import time 
import logging
import torch.optim as optim
import os
from scipy.stats import multivariate_normal as normal
import torch.nn.functional as F
from torch.nn import Parameter

In [11]:
# Pick the compute device once and report it so the run log records it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

data_type = torch.float32

# Update factor for the running statistics of nn.BatchNorm1d.  PyTorch computes
#     running = (1 - momentum) * running + momentum * batch_statistic
# so a slowly moving average needs a SMALL value.  The previous 0.99 (the
# "decay" convention used by some other frameworks) made the running statistics
# follow only the most recent batch, which breaks the layers in eval() mode.
# Training-mode forward passes use batch statistics and are unaffected.
MOMENTUM = 0.01
# Variance stabiliser (eps) passed to nn.BatchNorm1d.
EPSILON = 1e-6

Using cpu device


In [12]:
class Config(object):
    """Hyper-parameters shared by the networks and the training loop.

    All values are class attributes, so the class object itself is what gets
    passed around as the configuration.
    """
    n_layer = 4
    batch_size = 1024   # number of simulated paths per training step
    valid_size = 1024

    dim = 100           # dimension of the state (and width of the control output)
    Ntime = 10          # number of time steps
    # Length of the time horizon.  Kept in one place so that the step size and
    # the noise scale below can never drift apart.
    total_time = 0.1
    delta = total_time / Ntime      # time step
    sqrt_deltaT = np.sqrt(delta)    # scale applied to the standard-normal noise

    logging_frequency = 100
    verbose = True
    y_init_range = [0, 1]

    # Layer widths of each control network: input, hidden ..., output.
    num_hiddens = [dim, 256, 256, dim]
    
def get_config(name):
    """Look up ``name`` among this module's globals and return the object found.

    Raises:
        KeyError: if no global with that name exists.
    """
    namespace = globals()
    try:
        found = namespace[name]
    except KeyError:
        raise KeyError("config not defined.")
    else:
        return found

In [13]:
# Resolve the configuration class by name; its class attributes are read directly.
cfg = get_config("Config")

In [14]:
"""
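This class defines one building block of the neural network: a linear layer, optionally followed by batch normalisation and a tanh activation.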
"""
class Dense(nn.Module):
    """Linear layer, optionally followed by BatchNorm1d and a tanh activation.

    Args:
        cin: number of input features of the linear layer.
        cout: number of output features of the linear layer.
        batch_norm: when true, a BatchNorm1d over ``cout`` features is applied
            after the linear layer.
        activate: when true, ``tanh`` is applied to the block's output.
    """

    def __init__(self, cin, cout, batch_norm=False, activate=True):
        super().__init__()
        self.cin = cin
        self.cout = cout
        self.activate = activate

        # Affine map; the weights keep nn.Linear's default initialisation.
        self.linear = nn.Linear(cin, cout)
        # The normalisation acts on the linear layer's output, hence ``cout`` features.
        self.bn = (
            nn.BatchNorm1d(cout, eps=EPSILON, momentum=MOMENTUM) if batch_norm else None
        )

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        out = self.linear(x)
        if self.bn is not None:
            out = self.bn(out)
        return torch.tanh(out) if self.activate else out

In [15]:
"""
Constructing the Policy control

The control takes the state X as its input (one network is built per time step, so the
time t is not concatenated to it), and so the input dimension is dim
"""
class controlNN(nn.Module):
    """Feed-forward policy network.

    The input is normalised by a BatchNorm1d over ``config.num_hiddens[0]``
    features and then passed through a stack of ``Dense`` blocks whose widths are
    consecutive entries of ``config.num_hiddens``.  Every block applies tanh
    except the last one, which is purely linear.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        widths = config.num_hiddens

        self.bn = nn.BatchNorm1d(widths[0], eps=EPSILON, momentum=MOMENTUM)

        # Hidden blocks (tanh-activated): widths[0]->widths[1], ..., widths[-3]->widths[-2].
        blocks = [Dense(n_in, n_out) for n_in, n_out in zip(widths[:-2], widths[1:-1])]
        # Output block without activation: widths[-2]->widths[-1].
        blocks.append(Dense(widths[-2], widths[-1], activate=False))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Return the network output for a batch ``x`` with ``num_hiddens[0]`` features."""
        return self.layers(self.bn(x))

In [16]:
class NeuralNet(nn.Module):
    """Discretised forward-backward system with one control network per time step.

    ``forwardX`` simulates the controlled state forward in time, ``backwardYZ``
    rolls the adjoint pair (Y, Z) backward from the terminal condition,
    ``Hamcompute`` returns the negated Hamiltonian used as the training objective
    and ``computeLoss`` evaluates the cost that is reported during training.

    Path tensors are laid out as (batch, dim, time index).  The batch size is
    inferred from the tensors passed in, so batches of any size can be used.
    """

    def __init__(self, config):
        super(NeuralNet, self).__init__()
        self.config = config
        self.batch_size = self.config.batch_size
        self.dim = self.config.dim
        self.Ntime = self.config.Ntime
        self.delta = self.config.delta
        self.sqrt_deltaT = self.config.sqrt_deltaT

        # One independent control network per time step.
        self.mList = nn.ModuleList([controlNN(config) for _ in range(self.Ntime)])

        # Time grid k * delta replicated per sample, shape (batch_size, 1, Ntime + 1).
        # No method of this class reads it as the code stands.
        self.time_stamp = (
            torch.ones([self.batch_size, 1, self.Ntime + 1])
            * torch.arange(0, self.Ntime + 1)
            * self.delta
        )

    def forwardX(self, x):
        """Simulate the controlled state path.

        Args:
            x: initial state; any tensor that can be reshaped to (batch, dim).

        Returns:
            (xMat, wMat): the states, shape (batch, dim, Ntime + 1), and the noise
            increments that drove them, shape (batch, dim, Ntime).
        """
        xinit = torch.reshape(torch.clone(x), [-1, self.dim])
        batch = xinit.shape[0]

        xMat = torch.zeros([batch, self.dim, self.Ntime + 1])

        # Standard-normal draws scaled by sqrt(delta).  scipy squeezes singleton axes
        # (e.g. when dim == 1), so the layout is restored explicitly.
        noise = np.asarray(normal.rvs(size=[batch, self.dim, self.Ntime]) * self.sqrt_deltaT)
        wMat = torch.as_tensor(noise, dtype=torch.float32).reshape(batch, self.dim, self.Ntime)

        xMat[:, :, 0] = xinit

        for i in range(self.Ntime):
            control_temp = self.mList[i](xMat[:, :, i])
            # Euler-Maruyama step: drift (-0.25 x + u) * delta, diffusion (0.2 x + u) * dW.
            xMat[:, :, i + 1] = (
                xMat[:, :, i]
                + (-0.25 * xMat[:, :, i] + control_temp) * self.delta
                + (0.2 * xMat[:, :, i] + control_temp) * wMat[:, :, i]
            )

        return xMat, wMat

    def backwardYZ(self, xMat, wMat):
        """Roll the adjoint processes backward along a simulated path.

        Args:
            xMat: states from ``forwardX``, shape (batch, dim, Ntime + 1).
            wMat: noise increments from ``forwardX``, shape (batch, dim, Ntime).

        Returns:
            (yMat, zMat) with shapes (batch, dim, Ntime + 1) and (batch, dim, Ntime).
        """
        batch = xMat.shape[0]
        yMat = torch.zeros([batch, self.dim, self.Ntime + 1])
        zMat = torch.zeros([batch, self.dim, self.Ntime])

        # Terminal condition Y_T = -X_T (component-wise).
        yMat[:, :, -1] = -xMat[:, :, -1]

        for i in range(self.Ntime - 1, -1, -1):
            zMat[:, :, i] = wMat[:, :, i] * yMat[:, :, i + 1] / self.delta

            X = xMat[:, :, i]
            ctrl = self.mList[i](X)

            # Control-dependent part of the Hamiltonian, y*u + z*u - u*u.  Y and Z are
            # detached so that only the dependence through the control is differentiated.
            hami = torch.sum(
                yMat[:, :, i + 1].detach() * ctrl
                + zMat[:, :, i].detach() * ctrl - ctrl * ctrl,
                dim=1, keepdim=True)

            # Derivative of that part with respect to the state at step i.
            hami_x = torch.autograd.grad(
                outputs=[hami], inputs=[X], grad_outputs=torch.ones_like(hami),
                allow_unused=True, retain_graph=True, create_graph=True)[0]

            yMat[:, :, i] = yMat[:, :, i + 1] + (
                -0.5 * xMat[:, :, i] - 0.25 * yMat[:, :, i + 1]
                + 0.2 * zMat[:, :, i] + hami_x
            ) * self.delta

        return yMat, zMat

    # Design note: the maximisation of the Hamiltonian has to be carried out
    # separately at each time step, so the controls of different time blocks are
    # independent of one another.  Updating all of them together through the
    # parameters of a single shared control network can get relatively close,
    # but it should be avoided.

    def Hamcompute(self, xMat, yMat, zMat):
        """Return minus the batch-mean Hamiltonian summed over time steps.

        States and adjoints are detached, so gradients of the result reach only
        the parameters of the control networks.
        """
        ham = 0.0
        for i in range(self.Ntime):
            temp = self.mList[i](xMat[:, :, i].detach())
            # NOTE(review): this pairs the control at step i with Y at index i, whereas
            # backwardYZ uses Y at index i + 1 in the same expression -- confirm intended.
            ham += torch.mean(torch.sum(
                yMat[:, :, i].detach() * temp + zMat[:, :, i].detach() * temp - temp * temp,
                dim=1, keepdim=True))
        return -ham

    def computeLoss(self, xMat):
        """Return the batch-mean cost along ``xMat`` (running terms plus terminal term)."""
        loss = 0.0
        for i in range(self.Ntime):
            X = xMat[:, :, i]
            tempctrl = self.mList[i](X)
            # NOTE(review): delta multiplies only the state term here, not the control
            # term -- confirm whether (0.25 x^2 + u^2) * delta was intended.
            loss += torch.sum(0.25 * torch.square(X) * self.delta + torch.square(tempctrl),
                              dim=1, keepdim=True)

        # NOTE(review): the terminal term is 0.5 * (sum of components)^2, while the
        # terminal condition in backwardYZ is component-wise -X_T -- confirm which
        # terminal cost is intended.
        sum_temp = torch.sum(xMat[:, :, -1], dim=1, keepdim=True)

        return torch.mean(loss + 0.5 * torch.square(sum_temp))

In [19]:
def train(cfg, num_epochs=1000, lr=1e-4, log_every=10):
    """Train the per-step control networks by maximising the Hamiltonian.

    Each iteration simulates a fresh batch of paths from the all-ones initial
    state, rolls the adjoints backward and takes one Adam step on the negated
    Hamiltonian.  The learning rate is multiplied by 0.2 from iteration 500 on.

    Args:
        cfg: configuration object providing ``batch_size`` and ``dim`` plus
            whatever ``NeuralNet`` reads from it.
        num_epochs: number of optimisation steps.
        lr: initial Adam learning rate.
        log_every: progress is printed every ``log_every`` iterations.
    """
    model = NeuralNet(cfg)
    x0 = torch.ones([cfg.batch_size, cfg.dim])

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[500], gamma=0.2)

    for i in range(num_epochs):
        optimizer.zero_grad()

        xmat, wmat = model.forwardX(x0)
        ymat, zmat = model.backwardYZ(xmat, wmat)
        # Negated Hamiltonian: minimising it maximises the Hamiltonian.
        ham = model.Hamcompute(xmat, ymat, zmat)

        ham.backward()
        optimizer.step()
        # Advance the schedule.  Without this call the milestone is never reached
        # and the learning rate stays at its initial value for the whole run.
        scheduler.step()

        if i % log_every == 0:
            # Monitoring only: the reported cost needs no autograd graph.
            with torch.no_grad():
                cost = model.computeLoss(xmat).item()
            print("Iter:", i, torch.mean(ymat[:,:,0]).item(), cost)

In [None]:
# Launch training with the configuration object; progress lines are printed by train().
train(cfg)

Iter: 0 -1.004112720489502 4778.1923828125
Iter: 10 -1.0025595426559448 4767.36279296875
Iter: 20 -1.0002959966659546 4750.10546875
Iter: 30 -1.0003409385681152 4746.36962890625
Iter: 40 -0.9981225728988647 4732.51904296875
Iter: 50 -0.9966267943382263 4721.23095703125
Iter: 60 -0.9936010837554932 4706.5830078125
Iter: 70 -0.993013858795166 4692.7685546875
Iter: 80 -0.9916049242019653 4684.3193359375
Iter: 90 -0.9886389374732971 4673.2021484375
Iter: 100 -0.9861000776290894 4658.08544921875
Iter: 110 -0.9836843609809875 4643.08349609375
Iter: 120 -0.9817636609077454 4634.248046875
Iter: 130 -0.9785168170928955 4621.12158203125
Iter: 140 -0.9766543507575989 4609.80859375
Iter: 150 -0.9736326336860657 4601.67724609375
Iter: 160 -0.9715568423271179 4590.04736328125
Iter: 170 -0.9695888757705688 4581.1533203125
Iter: 180 -0.9677788019180298 4576.05615234375
Iter: 190 -0.965873122215271 4571.2783203125
Iter: 200 -0.9647007584571838 4567.2333984375
Iter: 210 -0.9631655812263489 4560.84960937