In [1]:
from __future__ import division
import os, sys, time, random
import math
import scipy
from scipy import constants
import torch
from torch import nn, optim
from torch import autograd
from torch.autograd import grad
import autograd.numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
from torch.nn import functional as F
from scipy.constants import pi

In [3]:
class Potential(nn.Module):
    """Small residual MLP that maps a 2-D point to a scalar in (0, 1).

    Layout: ``Linear(2, 128) -> Tanh``, one residual
    ``Linear(128, 128) -> Tanh`` block, then ``Linear(128, 1) -> Sigmoid``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is so that
        # state_dict keys and seeded initialisation stay unchanged.
        self.hidden0 = nn.Sequential(nn.Linear(2, 128), nn.Tanh())
        self.hidden1 = nn.Sequential(nn.Linear(128, 128), nn.Tanh())
        self.out = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())

    def forward(self, x):
        """Evaluate the network on ``x`` (last dimension must be 2).

        Returns a tensor whose last dimension is 1; the final Sigmoid
        bounds every value to the open interval (0, 1).
        """
        embedded = self.hidden0(x)
        # Residual connection around the second hidden layer.
        mixed = embedded + self.hidden1(embedded)
        return self.out(mixed)

In [4]:
def hermite(n,x):
    """Evaluate the physicists' Hermite polynomial H_n at ``x``.

    Uses the three-term recurrence
    ``H_k(x) = 2*x*H_{k-1}(x) - 2*(k-1)*H_{k-2}(x)`` with ``H_0 = 1`` and
    ``H_1 = 2*x``, evaluated iteratively so the cost is linear in ``n``
    (the naive double recursion is exponential).

    Args:
        n: non-negative integer order (an integral float such as ``2.0``
            is accepted).
        x: a number or a tensor; any type supporting ``*`` and ``-``.

    Returns:
        ``H_n(x)``. For ``n == 0`` this is the plain scalar ``1``
        regardless of the type of ``x``.

    Raises:
        ValueError: if ``n`` is negative or not an integer value.
    """
    if n < 0 or n != int(n):
        raise ValueError("hermite order n must be a non-negative integer, got %r" % (n,))
    n = int(n)
    if n == 0:
        return 1
    previous, current = 1, 2*x
    for k in range(2, n + 1):
        # Same operand order as the recurrence above, so floating-point
        # results match the recursive formulation.
        previous, current = current, 2*x*current - 2*(k-1)*previous
    return current

def harmonic(m,h,w,n,x):
    """Evaluate the n-th 1-D harmonic-oscillator eigenfunction at ``x``.

    Computes
    ``(m*w/(pi*h))**(1/4) / sqrt(2**n * n!) * H_n(sqrt(m*w/h)*x) * exp(-m*w*x**2/(2*h))``.

    The Hermite polynomial is evaluated at the scaled coordinate
    ``sqrt(m*w/h) * x``. Evaluating it at the raw ``x`` is only correct
    when ``m*w/h == 1``; for that case the scaling factor is exactly 1.0,
    so results are unchanged.

    Args:
        m: mass.
        h: the constant that appears where the reduced Planck constant
            stands in the formula above.
        w: angular frequency.
        n: non-negative integer quantum number (passed to
            ``math.factorial``).
        x: torch tensor of positions (``torch.exp`` is applied to it).

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Normalisation prefactor (m*w / (pi*h))^(1/4).
    norm = ((m*w)/(math.pi*h))**(1/4)
    term1 = (math.factorial(n))*(2**n)
    # Scaled coordinate for the polynomial part.
    xi = math.sqrt((m*w)/h) * x
    term2 = (hermite(n, xi)/math.sqrt(term1))
    # Gaussian envelope exponent -m*w*x^2 / (2*h).
    expterms = (-1.0*m*w*x*x)/(2*h)
    evalh = norm*term2*torch.exp(expterms)
    return evalh
def init_wave_function(x,y):
    """Separable 2-D wave function: product of two n=2 oscillator states.

    Both factors are ``harmonic(1, 1, 1, 2, .)``, evaluated at ``x`` and
    ``y`` respectively, and multiplied element-wise.
    """
    psi_x = harmonic(1, 1, 1, 2, x)
    psi_y = harmonic(1, 1, 1, 2, y)
    return psi_x * psi_y


In [5]:
# Trainable potential network and the Adam optimiser that updates its
# parameters (learning rate 1e-3).
potential = Potential()
optimizer = optim.Adam(params=potential.parameters(), lr=1e-3)



def _second_derivative(values, coord):
    """Return the second derivative of ``values`` with respect to ``coord``.

    ``grad_outputs`` of ones sums the outputs before differentiating.
    ``create_graph=True`` keeps the result differentiable (and implies
    ``retain_graph``), so it can take part in a later ``backward()``.
    """
    first = grad(values, coord, grad_outputs=torch.ones_like(coord),
                 create_graph=True)[0]
    second = grad(first, coord, grad_outputs=torch.ones_like(coord),
                  create_graph=True)[0]
    return second


def conservation_energy(batch):
    """Return ``laplacian(psi) / (2 * psi) - V`` at every point of ``batch``.

    ``psi`` is ``init_wave_function`` evaluated on the two columns of
    ``batch``; ``V`` is the module-level ``potential`` network evaluated
    on ``batch``. Derivatives are taken with autograd.

    Args:
        batch: float tensor of shape (N, 2); column 0 is x, column 1 is y.
            Side effect: ``requires_grad`` is switched on for ``batch``
            in place.

    Returns:
        Tensor of shape (N,) that stays attached to the autograd graph.

    Notes:
        * With H = -1/2 * laplacian + V, the returned quantity is the
          negative of the local energy (H psi) / psi.
        * The division by ``psi`` is unguarded, so values blow up where
          the wave function is zero.
    """
    batch.requires_grad_(True)
    # Column views inherit requires_grad from `batch`, so no further
    # requires_grad_ calls are needed.
    x_coord = batch[:, 0]
    y_coord = batch[:, 1]

    output = init_wave_function(x_coord, y_coord)
    potential_energy = potential(batch).squeeze()

    # Laplacian of psi. A missing dependency on x or y now raises inside
    # autograd instead of yielding None and failing on the addition.
    kinetic_energy = (_second_derivative(output, x_coord)
                      + _second_derivative(output, y_coord))

    conserve_energy = kinetic_energy/(2*output) - potential_energy
    return conserve_energy
      

In [8]:
# Step size used by the central finite differences in
# taylor_approx_x / taylor_approx_y.
h = 1e-2

def _central_difference(batch, axis, step):
    """Central finite difference of ``conservation_energy`` along one column.

    Builds two copies of the first two columns of ``batch`` in which
    column ``axis`` is shifted by ``+step`` and ``-step``, and returns
    ``(f(forward) - f(back)) / (2 * step)``.
    """
    batch.requires_grad_(True)
    columns = [batch[:, 0], batch[:, 1]]

    forward_cols = list(columns)
    backward_cols = list(columns)
    forward_cols[axis] = columns[axis] + step
    backward_cols[axis] = columns[axis] - step

    # Stacking 1-D columns along dim 1 rebuilds an (N, 2) batch.
    batch_forward = torch.stack(forward_cols, dim=1)
    batch_back = torch.stack(backward_cols, dim=1)

    return (conservation_energy(batch_forward)
            - conservation_energy(batch_back)) / (2 * step)


def taylor_approx_x(batch, step=None):
    """Approximate d(conservation_energy)/dx by a central difference.

    Args:
        batch: float tensor of shape (N, 2) with columns (x, y). Side
            effect: ``requires_grad`` is switched on for it in place.
        step: finite-difference step; defaults to the module-level ``h``
            read at call time.

    Returns:
        Tensor of shape (N,), attached to the autograd graph.
    """
    return _central_difference(batch, 0, h if step is None else step)


def taylor_approx_y(batch, step=None):
    """Approximate d(conservation_energy)/dy by a central difference.

    Args:
        batch: float tensor of shape (N, 2) with columns (x, y). Side
            effect: ``requires_grad`` is switched on for it in place.
        step: finite-difference step; defaults to the module-level ``h``
            read at call time.

    Returns:
        Tensor of shape (N,), attached to the autograd graph.
    """
    return _central_difference(batch, 1, h if step is None else step)
    



In [18]:
# NOTE(review): `MyDataset` is not defined anywhere else in this notebook,
# so this cell raised NameError on a fresh kernel. The minimal map-style
# dataset below returns one row of the tensor per index, which is what the
# training loop needs (each batch collates to a (B, 2) tensor). Confirm it
# matches the originally intended class.
class MyDataset(Dataset):
    """Map-style dataset that serves the rows of an in-memory tensor."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]


# 5000 training points drawn uniformly from [0, 1)^2, in shuffled
# mini-batches of 32.
data = torch.rand(5000,2)
dataset = MyDataset(data)
loader = DataLoader(dataset, batch_size = 32, shuffle = True)

In [None]:
# Train the potential network so that conservation_energy is spatially
# constant: the objective is the mean squared finite-difference gradient
# of that quantity along x and y.
num_epochs = 2000
loss = []  # error of the last batch of each epoch, stored as a Python float
for epoch in range(num_epochs):
    for n_batch, batch in enumerate(loader):
        # Fresh leaf tensor that tracks gradients (replaces the deprecated
        # Variable wrapper).
        n_data = batch.detach().requires_grad_(True)

        optimizer.zero_grad()

        error = (taylor_approx_x(n_data)**2 + taylor_approx_y(n_data)**2).mean()

        # A new graph is built every iteration, so it need not be retained
        # after the backward pass.
        error.backward()

        optimizer.step()
    # .item() stores a plain float; appending the tensor itself would keep
    # every epoch's autograd graph alive.
    loss.append(error.item())
    

In [6]:
# 100 evaluation points drawn uniformly from [0, 1)^2, and the network
# output at those points.
x = torch.rand((100, 2))
p = potential(x)

In [7]:
# Evaluate laplacian(psi)/(2*psi) - V at the evaluation points `x`.
# Side effect: conservation_energy switches requires_grad on for `x`.
p1 = conservation_energy(x)

In [11]:
# Mean squared error (not yet square-rooted) between the learned energies
# and the reference: p1 is compared against -5, so the residual is p1 + 5.
# NOTE(review): 5 is presumably the exact energy of the n_x = n_y = 2
# state; confirm.
((p1 + 5) ** 2).mean()

tensor(1.3127e-05, grad_fn=<MeanBackward0>)

In [13]:
# Split the evaluation points into coordinate columns and evaluate the
# reference potential 0.5 * (x^2 + y^2) on them.
x_coord, y_coord = x[:, 0], x[:, 1]
ground = .5 * (x_coord ** 2 + y_coord ** 2)

In [16]:
# Mean squared error (not yet square-rooted) between the reference
# potential `ground` and the network output at the same points.
((ground - potential(x).squeeze()) ** 2).mean()

tensor(1.3131e-05, grad_fn=<MeanBackward0>)

In [19]:
# RMSE of the learned energies, computed from `p1` rather than from a
# constant pasted out of an earlier output (which goes stale on re-run).
math.sqrt(torch.mean((p1 + 5) ** 2).item())

0.003623120202256613

The rest of the notebook can be ignored as it is used to generate the 2d energy plot.

In [None]:
# NOTE(review): this cell called `sample_x` and `conserve_energy`, neither
# of which is defined in the notebook. `conserve_energy` is only a local
# name inside `conservation_energy`, so that function is called instead.
# `sample_x(4000, 2)` is assumed to draw uniform points on [0, 1)^2 like
# the training data; confirm.
# A separate name is used so the 1-D `x_coord` column that the plotting
# cell expects is not overwritten.
energy_points = torch.rand(4000, 2)
learned_energy1 = -conservation_energy(energy_points).detach().numpy()
learned_energy1[3000], energy_points.detach().numpy()[3000]

In [21]:
pip install plotly

Collecting plotly
  Downloading plotly-4.8.1-py2.py3-none-any.whl (11.5 MB)
[K     |████████████████████████████████| 11.5 MB 3.8 MB/s eta 0:00:01
Collecting retrying>=1.3.3
  Downloading retrying-1.3.3.tar.gz (10 kB)
Building wheels for collected packages: retrying
  Building wheel for retrying (setup.py) ... [?25ldone
[?25h  Created wheel for retrying: filename=retrying-1.3.3-py3-none-any.whl size=11430 sha256=3a920103c2786fb7ef615811ccd13459f7f8979e121845361d6c09780d7b06b5
  Stored in directory: /Users/arijitsehanobish/Library/Caches/pip/wheels/f9/8d/8d/f6af3f7f9eea3553bc2fe6d53e4b287dad18b06a861ac56ddf
Successfully built retrying
Installing collected packages: retrying, plotly
Successfully installed plotly-4.8.1 retrying-1.3.3
Note: you may need to restart the kernel to use updated packages.


In [47]:
# NOTE(review): `plot` and `fig` are only defined by the plotting cell
# further down in the notebook, so on a fresh kernel run top-to-bottom
# this cell raises NameError. It worked only because of out-of-order
# execution; consider moving it below that cell or removing it.
plot(fig, filename='./2dsystem.html')

'./2dsystem.html'

In [66]:
# NOTE: from here on `np` is plain numpy (it was autograd.numpy before).
import numpy as np
from itertools import product
import plotly.graph_objects as go
import plotly.offline as py
from plotly.offline import plot

py.init_notebook_mode(connected=True)

# Reference surface: a flat plane at z = REFERENCE_ENERGY over a regular
# GRID_POINTS x GRID_POINTS grid on [0, 1]^2. The axis is generated rather
# than pasted twice as a hand-rounded literal.
GRID_POINTS = 50
REFERENCE_ENERGY = 5
grid_axis = torch.linspace(0, 1, GRID_POINTS).tolist()
Z = (torch.ones(GRID_POINTS, GRID_POINTS) * REFERENCE_ENERGY).tolist()

trace1 = go.Surface(
    contours={
        "x": {"show": True, "start": 0, "end": 1, "size": 0.1, "color": "white"},
        "y": {"show": True, "start": 0, "end": 1, "size": 0.1, "color": "white"},
    },
    x=grid_axis,
    y=grid_axis,
    z=Z,
    colorscale='YlGnBu',
    showscale=False,
)

# Learned values: -p1 plotted over the evaluation points (x_coord, y_coord).
trace2 = go.Scatter3d(
    x=x_coord.tolist(),
    y=y_coord.tolist(),
    z=(-p1).tolist(),
    mode="markers",
    marker=dict(
        opacity=.99
    ),
)

traces = [trace1, trace2]
fig = go.Figure(data=traces)
# Writes a standalone HTML file; the returned path is the cell output.
plot(fig, filename='./2d_system.html')



'./2d_system.html'