In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import Phi

# Compute device used by every tensor and the network below: first visible GPU if any, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")

In [5]:
class DeepLandau:
    """Particle solver that learns a potential network for a Landau-type JKO step.

    Particles are stored as rows of ``x``: ``d`` velocity columns followed by one
    time column. ``net.Grad_Hess(x, justGrad=False)`` is used below as returning
    ``(u, hess)`` with ``u`` indexed as (n, d) and ``hess`` as (n, d, d); with
    ``justGrad=True`` it returns only ``u``.
    NOTE(review): those shapes are inferred from the indexing in this class — confirm against ``Phi``.

    Attributes:
        dt:        JKO time step (0-d float tensor).
        gamma:     exponent of the interaction weight |v_i - v_j|**gamma (0-d float tensor).
        x:         particle states, shape (nex, d + 1).
        nex:       number of particles.
        d:         velocity dimension.
        net:       ``Phi.Phi`` potential network.
        optimizer: Adam optimizer over ``net`` parameters.
    """

    def __init__(self, x, gamma, dt, lr=1e-3):
        """Move the particles to ``device`` and build the network and optimizer.

        Args:
            x:     (nex, d + 1) tensor; last column is time.
            gamma: interaction exponent.
            dt:    JKO time step.
            lr:    Adam learning rate (default matches the original hard-coded 1e-3).
        """
        self.dt = torch.tensor(dt).float().to(device)  # JKO time step
        self.gamma = torch.tensor(gamma).float().to(device)  # exponent in |v_i - v_j|**gamma

        self.x = x.to(device)  # velocity columns + time column
        self.nex = x.shape[0]  # number of particles
        self.d = x.shape[1] - 1  # velocity dimension

        self.net = Phi.Phi(nTh=3, m=20, d=self.d).to(device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)

    # ------------------------------------------------------------------ helpers

    def _kernel_weight(self, norm):
        """Return ``norm**gamma`` with zero-distance entries set to 0.

        Pairs with z = 0 (the i == j term) contribute nothing: their projection and
        velocity difference are both zero. For gamma < 0, however, ``0**gamma`` is inf
        and ``inf * 0`` is NaN. The double ``where`` keeps the excluded branch finite,
        so neither the value nor any gradient through it becomes NaN.
        """
        nonzero = norm > 0
        safe_norm = torch.where(nonzero, norm, torch.ones_like(norm))
        return torch.where(nonzero, safe_norm ** self.gamma, torch.zeros_like(norm))

    def _pair_geometry(self, vel, i):
        """Pairwise quantities between particle ``i`` and every row of ``vel``.

        Returns:
            v_diff: (n, d) rows ``vel[i] - vel[j]``.
            norm:   (n, 1) Euclidean norms of those rows.
            proj:   (n, d, d) ``|z|^2 I - z z^T`` per row z.
            weight: (n, 1) ``|z|**gamma``, with 0 where z = 0.
        """
        v_diff = vel[i] - vel
        norm = torch.norm(v_diff, dim=1, keepdim=True)
        return v_diff, norm, self.proj(v_diff, norm), self._kernel_weight(norm)

    def _trace_term(self, v_diff, norm, proj, weight, u_diff, hess_i):
        """Un-normalized sum over j of the log-det contribution for one particle i."""
        temp1 = torch.sum(proj * hess_i, dim=(1, 2))  # tr(P(z) H_i) per pair
        # NOTE(review): the divergence of |z|^2 I - z z^T is -(d-1) z, which suggests a
        # factor (d-1) here rather than |z|^2 — TODO confirm against the derivation.
        temp2 = torch.sum(norm ** 2 * v_diff * u_diff, dim=1)
        return torch.dot(weight.squeeze(1), temp1 - temp2)

    def _score_loss(self):
        """Squared distance between the network gradient and -v, averaged over particles.

        -v is the score of a standard Gaussian.
        """
        u = self.net.Grad_Hess(self.x, justGrad=True)
        return torch.norm(u + self.x[:, :-1]) ** 2 / self.nex

    # ------------------------------------------------------------------ model

    def Landau_loss(self, x):
        """JKO objective on the particle set ``x``: transport cost minus ``2*dt`` times a log-det term.

        Args:
            x: (batch, d + 1) particle states (velocities, then time).
        Returns:
            Scalar tensor, differentiable w.r.t. the network parameters.
        """
        u, grad_u = self.net.Grad_Hess(x, justGrad=False)
        vel = x[:, :-1]
        batch_size = x.shape[0]

        loss = 0
        for i in range(batch_size):
            v_diff, norm, proj, weight = self._pair_geometry(vel, i)
            u_diff = u[i] - u

            # transport cost: sum_j |z|^gamma (u_i - u_j)^T P(z) (u_i - u_j) / (2 * batch)
            # The inner sum runs over the particles in x, so normalize by x's row count;
            # this equals the original self.nex when x is the full particle set.
            quad = torch.einsum('nk,nkl,nl->n', u_diff, proj, u_diff)
            transport_cost = torch.dot(weight.squeeze(1), quad) / (2 * batch_size)

            # entropy (log-det) term for particle i
            logdet = self._trace_term(v_diff, norm, proj, weight, u_diff, grad_u[i]) / batch_size

            loss = loss + (transport_cost - 2 * self.dt * logdet) / batch_size
        return loss

    def proj(self, v_diff, norm):
        """Unnormalized projection ``|z|^2 I - z z^T`` for each row z of ``v_diff``.

        Args:
            v_diff: (n, d) difference vectors.
            norm:   (n, 1) their norms.
        Returns:
            (n, d, d) tensor on the same device/dtype as ``v_diff``.
        """
        eye = torch.eye(v_diff.shape[1], dtype=v_diff.dtype, device=v_diff.device)
        outer = v_diff[:, :, None] * v_diff[:, None, :]
        return (norm ** 2)[:, :, None] * eye - outer

    # ------------------------------------------------------------------ training

    def train(self, n_iter=1000, print_every=1):
        """Minimize ``Landau_loss`` on all particles, logging the score loss alongside.

        Args:
            n_iter:      last iteration index; runs ``n_iter + 1`` optimizer steps
                         (same count as the original ``while iter <= 1000``).
            print_every: print both losses every ``print_every`` iterations.
        """
        for it in range(n_iter + 1):
            log = it % print_every == 0
            if log:
                # diagnostic only: no backward pass through this value
                print('Iter %d, Score Loss: %.5e' % (it, self._score_loss().item()))

            self.optimizer.zero_grad()
            loss = self.Landau_loss(self.x)
            if log:
                print('Iter %d, Loss: %.5e' % (it, loss.item()))
            loss.backward()
            self.optimizer.step()

    def train_score(self, n_iter=10000, print_every=100):
        """Fit the network gradient to -v (the score of a standard Gaussian) on all particles.

        Args:
            n_iter:      last iteration index; runs ``n_iter + 1`` optimizer steps.
            print_every: print the loss every ``print_every`` iterations.
        """
        for it in range(n_iter + 1):
            self.optimizer.zero_grad()
            loss = self._score_loss()
            loss.backward()
            if it % print_every == 0:
                print('Iter %d, Loss: %.5e' % (it, loss.item()))
            self.optimizer.step()

    def batch(self, x, batch_size):
        """Return a shuffled DataLoader over ``x``; each item is a 1-tuple ``(batch_x,)``."""
        dataset = torch.utils.data.TensorDataset(x)
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

    # ------------------------------------------------------------------ push-forward

    def compute_det(self, v):
        """Per-particle ``exp(logdet)`` of the learned update.

        Args:
            v: (n, d + 1) particle states (or (n, d) velocities, if the net accepts them).
               It is passed to the net as-is; only its first ``d`` columns are used as velocities.
        Returns:
            (n,) tensor on ``v``'s device.
        """
        n = v.shape[0]
        vel = v[:, :self.d]
        u, grad_u = self.net.Grad_Hess(v, justGrad=False)

        logdet = torch.zeros(n, dtype=vel.dtype, device=vel.device)
        for i in range(n):
            v_diff, norm, proj, weight = self._pair_geometry(vel, i)
            logdet[i] = self._trace_term(v_diff, norm, proj, weight, u[i] - u, grad_u[i]) / n
        return torch.exp(logdet)

    def compute_v(self, v):
        """Velocities after one update: ``v_i - mean_j |z|^gamma (u_i - u_j)^T P(z)``.

        Args:
            v: (n, d + 1) particle states (or (n, d) velocities); passed to the net as-is.
        Returns:
            (n, d) updated velocities on ``v``'s device.

        NOTE(review): unlike ``Landau_loss``, no ``dt`` factor appears here — confirm intent.
        """
        n = v.shape[0]
        vel = v[:, :self.d]
        u = self.net.Grad_Hess(v, justGrad=True)

        v_new = torch.zeros((n, self.d), dtype=vel.dtype, device=vel.device)
        for i in range(n):
            _, _, proj, weight = self._pair_geometry(vel, i)
            flux = weight * torch.einsum('nk,nkl->nl', u[i] - u, proj)
            v_new[i] = vel[i] - flux.sum(dim=0) / n
        return v_new

    def compute_f(self, v, f):
        """Push the per-particle density values ``f`` forward: ``f / det``."""
        det = self.compute_det(v)
        return f / det

In [6]:
# Experiment: nex particles with standard-Gaussian velocities in d dimensions;
# the extra last column of x is the time coordinate and starts at 0.
SEED = 0
torch.manual_seed(SEED)  # makes the initial particles and the network initialization reproducible

nex = 100  # number of particles
d = 2      # velocity dimension

x = torch.zeros(nex, d + 1)
v = torch.randn(nex, d)
x[:, :d] = v

model = DeepLandau(x, gamma=0, dt=1e-3)
model.train_score()  # first fit grad(Phi) to -v before the Landau steps

Iter 0, Loss: 8.35033e+00
Iter 100, Loss: 2.93457e-02
Iter 200, Loss: 1.90268e-02
Iter 300, Loss: 1.21458e-02
Iter 400, Loss: 8.00944e-03
Iter 500, Loss: 5.64439e-03
Iter 600, Loss: 4.23359e-03
Iter 700, Loss: 3.30305e-03
Iter 800, Loss: 2.62168e-03
Iter 900, Loss: 2.09120e-03
Iter 1000, Loss: 1.66914e-03
Iter 1100, Loss: 1.33187e-03
Iter 1200, Loss: 1.06284e-03
Iter 1300, Loss: 8.49368e-04
Iter 1400, Loss: 6.81364e-04
Iter 1500, Loss: 5.50592e-04
Iter 1600, Loss: 4.50135e-04
Iter 1700, Loss: 3.74054e-04
Iter 1800, Loss: 3.17192e-04
Iter 1900, Loss: 2.75073e-04
Iter 2000, Loss: 2.43911e-04
Iter 2100, Loss: 2.20610e-04
Iter 2200, Loss: 2.02770e-04
Iter 2300, Loss: 1.88625e-04
Iter 2400, Loss: 1.76950e-04
Iter 2500, Loss: 1.66933e-04
Iter 2600, Loss: 1.58065e-04
Iter 2700, Loss: 1.50030e-04
Iter 2800, Loss: 1.42635e-04
Iter 2900, Loss: 1.35755e-04
Iter 3000, Loss: 1.29308e-04
Iter 3100, Loss: 1.23234e-04
Iter 3200, Loss: 1.17487e-04
Iter 3300, Loss: 1.12034e-04
Iter 3400, Loss: 1.06848e-

In [7]:
# Optimize the Landau JKO objective, starting from the score-matched network above.
DeepLandau.train(model)

Iter 0, Score Loss: 3.91562e-06
Iter 0, Loss: -4.70257e-02
Iter 1, Score Loss: 3.40750e-03
Iter 1, Loss: -4.88087e-02
Iter 2, Score Loss: 2.26070e-02
Iter 2, Loss: -5.01600e-02
Iter 3, Score Loss: 8.76243e-02
Iter 3, Loss: -2.16747e-02
Iter 4, Score Loss: 9.22067e-02
Iter 4, Loss: -4.82626e-02
Iter 5, Score Loss: 1.27343e-01
Iter 5, Loss: -3.44796e-02
Iter 6, Score Loss: 1.89890e-01
Iter 6, Loss: -6.08837e-02
Iter 7, Score Loss: 2.88950e-01
Iter 7, Loss: -4.16734e-02
Iter 8, Score Loss: 3.52405e-01
Iter 8, Loss: -6.02099e-02
Iter 9, Score Loss: 4.04511e-01
Iter 9, Loss: -6.12354e-02
Iter 10, Score Loss: 4.91038e-01
Iter 10, Loss: -5.48132e-02
Iter 11, Score Loss: 6.28527e-01
Iter 11, Loss: -7.14050e-02
Iter 12, Score Loss: 8.10497e-01
Iter 12, Loss: -6.61798e-02
Iter 13, Score Loss: 9.71954e-01
Iter 13, Loss: -6.79173e-02
Iter 14, Score Loss: 1.10237e+00
Iter 14, Loss: -7.92622e-02
Iter 15, Score Loss: 1.26485e+00
Iter 15, Loss: -7.37631e-02
Iter 16, Score Loss: 1.49070e+00
Iter 16, Lo

KeyboardInterrupt: 