In [1]:
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib as mpl
import matplotlib.pyplot as plt
import math

%matplotlib inline

In [2]:
class Block(nn.Module):
    """Two-layer perceptron computing ``fc2(relu(fc1(x)))``.

    Used in this notebook as the generic learnable map (orbital networks,
    the phi/rho maps of ``Symmetric``, the backflow ``push``). ``nn.Linear``
    acts on the last dimension, so ``x`` may carry any leading batch dims.

    Args:
        input_dim: size of the last dimension of ``x``.
        hidden_dim: width of the hidden ReLU layer.
        output_dim: size of the last dimension of the result.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Registration order (fc1, then fc2) is kept so seeded initialisation
        # and state_dict keys are unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
    
class OddProjBlock(nn.Module):
    """Learnable odd function, i.e. ``f(-x) == -f(x)``.

    The hidden features ``relu(fc1(x)) - relu(fc1(-x))`` flip sign with ``x``.
    They are concatenated with ``x`` itself (a linear skip path) and mixed by
    the bias-free linear layer ``fc2``, which preserves oddness. Heads built
    from this block therefore map sign-flipped determinant vectors to
    sign-flipped outputs.

    Args:
        input_dim: features of ``x``; expects ``x`` of shape
            ``(batch, input_dim)`` because concatenation is along ``dim=1``.
        hidden_dim: number of odd hidden features.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # No bias: a constant offset would break f(-x) == -f(x).
        self.fc2 = nn.Linear(hidden_dim + input_dim, output_dim, bias = False)

    def forward(self, x):
        odd_features = F.relu(self.fc1(x)) - F.relu(self.fc1(-x))
        stacked = torch.cat((odd_features, x), dim = 1)
        return self.fc2(stacked)
    
class Symmetric(nn.Module):
    """Permutation-invariant set function ``rho(mean_i phi(x_i))``.

    ``phi`` embeds every set element independently. The embeddings are
    averaged over the set dimension, so reordering the elements does not
    change the result. ``rho`` then maps the pooled vector to the output.

    Args:
        input_dim: features per set element.
        hidden_dim: hidden width of both the ``phi`` and ``rho`` MLPs.
        symmetric_dim: size of the pooled embedding.
        output_dim: size of the output per set.
    """

    def __init__(self, input_dim, hidden_dim, symmetric_dim, output_dim):
        super(Symmetric, self).__init__()

        self.phi = Block(input_dim, hidden_dim, symmetric_dim)
        self.rho = Block(symmetric_dim, hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, set_size, input_dim)`` to ``(batch, output_dim)``.

        Raises:
            ValueError: if ``x`` is not 3-D (from the shape unpacking).
        """
        batch_size, input_set_dim, input_dim = x.shape

        # reshape (not view): view raises on non-contiguous inputs such as
        # transposed/permuted tensors; reshape copies only when it must.
        z = self.phi(x.reshape(-1, input_dim))
        z = z.reshape(batch_size, input_set_dim, -1)
        # Mean over the set dimension -> permutation invariance.
        z = torch.mean(z, 1)
        return self.rho(z)

In [3]:
class SlaterDeterminant(nn.Module):
    """Single determinant of learned orbitals.

    ``self.orbitals`` maps each of the ``n`` rows of a sample (``input_dim``
    features each) to ``n`` orbital values. That gives one ``n x n`` matrix
    per sample, and its determinant is returned. Swapping two rows of a
    sample swaps two matrix rows, so the output changes sign.

    Args:
        n: number of rows per sample; also the number of orbitals.
        input_dim: features per row.
        hidden_dim: hidden width of the orbital MLP.
    """

    def __init__(self, n, input_dim, hidden_dim):
        super(SlaterDeterminant, self).__init__()
        self.orbitals = Block(input_dim, hidden_dim, n)

        self.input_dim = input_dim
        self.n = n

    def forward(self, x):
        """Return one determinant per sample, shape ``(batch,)``.

        ``x`` must hold ``batch * n * input_dim`` elements, e.g. shape
        ``(batch, n, input_dim)``.
        """
        # reshape tolerates non-contiguous inputs, unlike view.
        orbital_values = self.orbitals(x.reshape(-1, self.input_dim))
        # Use the instance's n; the original read a notebook-global `n`, which
        # fails in a fresh namespace and is wrong when it differs from self.n.
        matrices = orbital_values.reshape(-1, self.n, self.n)
        return torch.det(matrices)

In [4]:
class MultiSlaterDeterminant(nn.Module):
    """Sum of ``anti_dim`` determinants, each built by its own orbital MLP.

    Args:
        n: output width of every orbital MLP (columns of each matrix).
        input_dim: features per row of the input.
        hidden_dim: hidden width of the orbital MLPs.
        anti_dim: number of determinants summed.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        self.orbitals = nn.ModuleList(
            [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        )

        self.input_dim = input_dim
        self.n = n

    def forward(self, x):
        # Each orbital net acts on the last dim of x; with x of shape
        # (batch, n, input_dim) as in the training cells, each result is a
        # batch of n x n matrices. Stacked on dim 1 -> (batch, anti_dim, n, n).
        matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim = 1)
        per_determinant = torch.det(matrices)
        return per_determinant.sum(dim = 1)

In [5]:
class AntiNet(nn.Module):
    """``anti_dim`` determinants combined by a learned odd function ``g``.

    Each orbital MLP produces one matrix per sample, and its determinant
    flips sign when two rows of the input are swapped. ``g`` is an
    ``OddProjBlock`` (``g(-v) == -g(v)``), so the scalar output flips sign too.

    Args:
        n: output width of every orbital MLP.
        input_dim: features per row of the input.
        hidden_dim: hidden width of the orbital MLPs and of ``g``.
        anti_dim: number of determinants fed to ``g``.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        self.orbitals = nn.ModuleList(
            [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        )
        self.g = OddProjBlock(anti_dim, hidden_dim, 1)

        self.input_dim = input_dim
        self.n = n

    def forward(self, x):
        matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim = 1)
        determinants = torch.det(matrices)  # (batch, anti_dim)
        return self.g(determinants).flatten()


In [6]:
class DeepAntiNet(nn.Module):
    """Like ``AntiNet`` but with two stacked odd heads, ``g2(g1(dets))``.

    Both heads are ``OddProjBlock``s. A composition of odd functions is odd,
    so swapping two input rows still flips the sign of the output.

    Args:
        n: output width of every orbital MLP.
        input_dim: features per row of the input.
        hidden_dim: hidden width of the orbital MLPs and of both heads.
        anti_dim: number of determinants fed to ``g1``.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        self.orbitals = nn.ModuleList(
            [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        )
        self.g1 = OddProjBlock(anti_dim, hidden_dim, hidden_dim)
        self.g2 = OddProjBlock(hidden_dim, hidden_dim, 1)

        self.input_dim = input_dim
        self.n = n

    def forward(self, x):
        matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim = 1)
        determinants = torch.det(matrices)  # (batch, anti_dim)
        hidden = self.g1(determinants)
        return self.g2(hidden).flatten()


In [None]:
class MultiBackflow(nn.Module):
    """Sum of determinants evaluated on backflow-transformed rows.

    A permutation-invariant ``Symmetric`` summary of the whole set is attached
    to every row. ``push`` then maps each (row, summary) pair to new
    ``input_dim`` coordinates. The orbital MLPs and determinants act on those
    new coordinates. Because the summary is shared by all rows and ``push``
    is applied row-wise, permuting the input rows permutes the transformed
    rows the same way.

    Args:
        n: output width of every orbital MLP.
        input_dim: features per row of the input.
        hidden_dim: width of the set summary and of all hidden layers.
        anti_dim: number of determinants summed.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        self.sym = Symmetric(input_dim, hidden_dim, hidden_dim, hidden_dim)
        self.push = Block(input_dim + hidden_dim, hidden_dim, input_dim)
        self.orbitals = nn.ModuleList(
            [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        )

        self.input_dim = input_dim
        self.n = n

    def forward(self, x):
        # Unpacking enforces a 3-D input (batch, set_size, input_dim).
        _, set_size, _ = x.shape
        # Broadcast the per-sample summary to every row; torch.cat materialises it.
        context = self.sym(x).unsqueeze(1).expand(-1, set_size, -1)
        moved = self.push(torch.cat((x, context), dim = 2))

        matrices = torch.stack([orbital(moved) for orbital in self.orbitals], dim = 1)
        return torch.det(matrices).sum(dim = 1)
    
class DeepMultiBackflow(nn.Module):
    """Backflow determinants combined by two stacked odd heads, ``g2(g1(dets))``.

    The backflow stage matches ``MultiBackflow``: a shared permutation-invariant
    summary is appended to every row and ``push`` maps each row to new
    coordinates. Instead of summing the determinants, this class passes them
    through two ``OddProjBlock`` heads.

    Args:
        n: output width of every orbital MLP.
        input_dim: features per row of the input.
        hidden_dim: width of the set summary and of all hidden layers.
        anti_dim: number of determinants fed to ``g1``.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        self.sym = Symmetric(input_dim, hidden_dim, hidden_dim, hidden_dim)
        self.push = Block(input_dim + hidden_dim, hidden_dim, input_dim)
        self.orbitals = nn.ModuleList(
            [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        )
        self.g1 = OddProjBlock(anti_dim, hidden_dim, hidden_dim)
        self.g2 = OddProjBlock(hidden_dim, hidden_dim, 1)

        self.input_dim = input_dim
        self.n = n

    def forward(self, x):
        # Unpacking enforces a 3-D input (batch, set_size, input_dim).
        _, set_size, _ = x.shape
        context = self.sym(x).unsqueeze(1).expand(-1, set_size, -1)
        moved = self.push(torch.cat((x, context), dim = 2))

        matrices = torch.stack([orbital(moved) for orbital in self.orbitals], dim = 1)
        determinants = torch.det(matrices)  # (batch, anti_dim)
        return self.g2(self.g1(determinants)).flatten()

In [None]:
# Validate batching: evaluating a batch must give the same outputs as
# evaluating each sample on its own (no information leaks across samples).

n = 5
d = 3
hidden_dim = 20

x = 10 * torch.normal(mean = 0, std = 1, size = (2, n, d))

x0 = x[:1]
x1 = x[1:]

SD = DeepMultiBackflow(n, d, hidden_dim, 4)
y_batch = SD(x)
y_single = torch.cat([SD(x0), SD(x1)])
print(y_batch)
print(y_single)
# Loose tolerance: batched vs. single BLAS/LU calls may differ in the last bits.
assert torch.allclose(y_batch, y_single, rtol = 1e-3, atol = 1e-4), "batched and per-sample outputs differ"

In [None]:
# Validate antisymmetry: swapping two rows (rows 0 and 1) of a sample must
# flip the sign of each model's output. Both models are checked; previously
# the MultiBackflow outputs were overwritten before they were ever shown.

x_single = 10 * torch.normal(mean = 0, std = 1, size = (n, d))
perm = torch.arange(n)
perm[[0, 1]] = perm[[1, 0]]  # transposition of rows 0 and 1 (same as the old P @ x)
x = x_single.unsqueeze(0)
x_ = x_single[perm].unsqueeze(0)

SD = MultiBackflow(n, d, hidden_dim, 3)
ANN = AntiNet(n, d, hidden_dim, 3)
for name, model in [("MultiBackflow", SD), ("AntiNet", ANN)]:
    y = model(x)
    y_ = model(x_)
    print(name, y, y_)
    # Loose tolerance: mean pooling / matmul over permuted rows may round differently.
    assert torch.allclose(y_, -y, rtol = 1e-3, atol = 1e-4), f"{name} is not antisymmetric"

In [None]:
def train(model, x, y, iterations, lr=0.005):
    """Fit ``model`` to ``(x, y)`` with full-batch Adam on an MSE loss.

    Args:
        model: module whose output on ``x`` is compared against ``y``.
        x: inputs, fed to ``model`` as one batch at every step.
        y: regression targets.
        iterations: number of optimizer steps.
        lr: Adam learning rate.

    Returns:
        List of the training loss (Python float) at every step, in order.
        The model is switched to eval mode before returning.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    model.train()

    history = []
    for _ in range(iterations):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        history.append(loss.item())

    model.eval()
    return history

In [None]:
# Experiment configuration (redefines n, d, hidden_dim used by the validation cells).
n = 5             # rows per sample; each orbital matrix is n x n
d = 3             # features per row (input_dim)
hidden_dim = 15   # hidden width of every MLP
anti_dim = 5      # number of orbital networks / determinants in each student

iterations = 10000
samples = 4000

# Seed the RNG so teacher, data, and student initialisations are reproducible.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Teacher: a wide (200-determinant) MultiSlaterDeterminant whose outputs are the
# regression targets. Built before sampling train_x, so the RNG is consumed in the same order.
teacher = MultiSlaterDeterminant(n, d, hidden_dim, 200)
train_x = 5 * torch.normal(mean = 0, std = 1, size = (samples, n, d))
with torch.no_grad():
    train_y = teacher(train_x)

In [None]:
# Student: MultiSlaterDeterminant with anti_dim determinants, fit to the teacher.
# (Removed a single-pass `for _ in range(1)` wrapper that rebound IPython's `_`.)
student = MultiSlaterDeterminant(n, d, hidden_dim, anti_dim)
losses = train(student, train_x, train_y, iterations, lr = 0.0025)
print(losses[::50])  # every 50th training loss
print(min(losses))   # best training loss seen

In [None]:
# Student: AntiNet (determinants combined by an odd head), fit to the teacher.
# (Removed a single-pass `for _ in range(1)` wrapper that rebound IPython's `_`.)
student = AntiNet(n, d, hidden_dim, anti_dim)
losses = train(student, train_x, train_y, iterations, lr = 0.0025)
print(losses[::50])  # every 50th training loss
print(min(losses))   # best training loss seen

In [None]:
# Student: DeepAntiNet (two stacked odd heads), fit to the teacher.
# (Removed a single-pass `for _ in range(1)` wrapper that rebound IPython's `_`.)
student = DeepAntiNet(n, d, hidden_dim, anti_dim)
losses = train(student, train_x, train_y, iterations, lr = 0.0025)
print(losses[::50])  # every 50th training loss
print(min(losses))   # best training loss seen

In [None]:
# Student: MultiBackflow, trained for twice as many steps as the cells above.
# (Removed a single-pass `for _ in range(1)` wrapper that rebound IPython's `_`.)
student = MultiBackflow(n, d, hidden_dim, anti_dim)
losses = train(student, train_x, train_y, 2*iterations, lr = 0.0025)
print(losses[::50])  # every 50th training loss
print(min(losses))   # best training loss seen

In [None]:
# Student: DeepMultiBackflow, trained for twice as many steps as the non-backflow cells.
# (Removed a single-pass `for _ in range(1)` wrapper that rebound IPython's `_`.)
student = DeepMultiBackflow(n, d, hidden_dim, anti_dim)
losses = train(student, train_x, train_y, 2*iterations, lr = 0.0025)
print(losses[::50])  # every 50th training loss
print(min(losses))   # best training loss seen

In [None]:
# Bar chart of mean +/- std MSE over three values per configuration.
# NOTE(review): these numbers are hard-coded (presumably copied from earlier
# training runs) — record which models/settings produced them.
a = np.array([6.588473796844482, 6.398560047149658, 7.056000232696533])
b = np.array([6.899078845977783, 5.879907608032227, 5.7301530838012695])
c = np.array([4.987086296081543, 4.876344203948975, 4.408130645751953])

names = ["Default", "One Extra Layer", "Two Extra Layers"]
x_pos = np.arange(len(names))
means = [np.mean(run) for run in (a, b, c)]
# np.std defaults to ddof=0 (population std); with 3 values consider ddof=1.
stds = [np.std(run) for run in (a, b, c)]

fig, ax = plt.subplots()
ax.bar(x_pos, means, yerr=stds, align='center', alpha=0.5, ecolor='black', capsize=10)
ax.set_ylabel('Mean Squared Error')
ax.set_xticks(x_pos)
ax.set_xticklabels(names)
ax.yaxis.grid(True)

# Save the figure, then show it.
fig.tight_layout()
fig.savefig('bar_plot_with_error_bars.png')
plt.show()