In [23]:
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib as mpl
import matplotlib.pyplot as plt
import math

%matplotlib inline

In [24]:
class Block(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of the hidden layer.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in the order fc1, fc2 so that parameter names
        # and random initialisation order stay as they were.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return fc2(relu(fc1(x))); acts on the last dimension of ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
    
class OddProjBlock(nn.Module):
    """Odd block: the output satisfies f(-x) == -f(x).

    The hidden features are antisymmetrised as relu(fc1(x)) - relu(fc1(-x)),
    concatenated with ``x`` itself along dim 1, and passed through a
    bias-free linear layer.

    Args:
        input_dim: number of features of the (2-D) input.
        hidden_dim: width of the hidden layer.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # No bias on the output layer: a constant offset would break oddness.
        self.fc2 = nn.Linear(input_dim + hidden_dim, output_dim, bias=False)

    def forward(self, x):
        """Evaluate the odd map on ``x`` of shape (batch, input_dim)."""
        plus = F.relu(self.fc1(x))
        minus = F.relu(self.fc1(-x))
        # Hidden (odd) features first, raw input second.
        features = torch.cat((plus - minus, x), dim=1)
        return self.fc2(features)
    
class Symmetric(nn.Module):
    """Permutation-invariant set network: rho(mean over elements of phi(x_i)).

    Args:
        input_dim: number of features per set element.
        hidden_dim: hidden width of both inner Blocks.
        symmetric_dim: size of the pooled (mean) representation.
        output_dim: size of the final output.
    """

    def __init__(self, input_dim, hidden_dim, symmetric_dim, output_dim):
        super().__init__()
        # phi acts on each element independently; rho acts on the pooled vector.
        self.phi = Block(input_dim, hidden_dim, symmetric_dim)
        self.rho = Block(symmetric_dim, hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, set_size, input_dim) to (batch, output_dim)."""
        n_batch, n_elems, n_feats = x.shape

        # Flatten batch and set axes so phi sees one element per row.
        per_elem = self.phi(x.view(-1, n_feats))
        # Restore the set axis and average over it (order-independent pooling).
        pooled = per_elem.view(n_batch, n_elems, -1).mean(dim=1)
        return self.rho(pooled)

In [25]:
class SlaterDeterminant(nn.Module):
    """Single determinant of learned orbitals.

    Every particle (row of the input) is sent through one shared ``Block``
    that outputs ``n`` orbital values; the resulting ``n x n`` matrix of each
    configuration is reduced to its determinant.

    Args:
        n: number of particles per configuration (and number of orbitals).
        input_dim: number of coordinates per particle.
        hidden_dim: hidden width of the orbital network.
    """

    def __init__(self, n, input_dim, hidden_dim):
        super(SlaterDeterminant, self).__init__()
        self.orbitals = Block(input_dim, hidden_dim, n)

        self.input_dim = input_dim
        self.n = n

    def forward(self, x):
        """Return one determinant per configuration.

        Args:
            x: tensor whose elements can be grouped as
                (batch, n, input_dim).

        Returns:
            Tensor of shape (batch,).
        """
        # reshape (rather than view) also accepts non-contiguous input.
        x = x.reshape(-1, self.input_dim)
        sd = self.orbitals(x)
        # Use the size stored on the module; the previous code read a bare
        # name `n`, which only resolved if a matching global happened to exist.
        sd = sd.reshape(-1, self.n, self.n)
        return torch.det(sd)

In [26]:
class MultiSlaterDeterminant(nn.Module):
    """Sum of ``anti_dim`` determinants, each with its own orbital network.

    Args:
        n: number of orbitals produced per particle; the determinant requires
            the set axis of the input to have this size as well.
        input_dim: number of coordinates per particle.
        hidden_dim: hidden width of every orbital network.
        anti_dim: number of determinants that are summed.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        orbital_nets = [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        self.orbitals = nn.ModuleList(orbital_nets)

        self.n = n
        self.input_dim = input_dim

    def forward(self, x):
        """Return the sum of determinants for ``x`` of shape (batch, n, input_dim)."""
        # Each network acts on the last axis, giving (batch, n, n) matrices;
        # stacking adds a determinant axis at position 1.
        matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim=1)
        dets = torch.det(matrices)
        return dets.sum(dim=1)

In [27]:
class AntiNet(nn.Module):
    """Determinants combined by a single odd block.

    ``anti_dim`` determinants are computed from separate orbital networks and
    fed to an ``OddProjBlock`` with one output, so a sign flip of all
    determinants flips the sign of the result.

    Args:
        n: number of orbitals per particle (and expected set size).
        input_dim: number of coordinates per particle.
        hidden_dim: hidden width of the orbital networks and of the odd block.
        anti_dim: number of determinants.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        orbital_nets = [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        self.orbitals = nn.ModuleList(orbital_nets)
        self.g = OddProjBlock(anti_dim, hidden_dim, 1)

        self.n = n
        self.input_dim = input_dim

    def forward(self, x):
        """Map ``x`` of shape (batch, n, input_dim) to a tensor of shape (batch,)."""
        matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim=1)
        dets = torch.det(matrices)  # (batch, anti_dim)
        combined = self.g(dets)     # (batch, 1)
        return torch.flatten(combined)


In [28]:
class DeepAntiNet(nn.Module):
    """Determinants combined by two stacked odd blocks.

    Same construction as a single odd combination of determinants, but the
    determinants pass through ``g1`` (to ``hidden_dim`` features) and then
    ``g2`` (to one output). A composition of odd maps is odd.

    Args:
        n: number of orbitals per particle (and expected set size).
        input_dim: number of coordinates per particle.
        hidden_dim: hidden width used throughout.
        anti_dim: number of determinants.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        orbital_nets = [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        self.orbitals = nn.ModuleList(orbital_nets)
        self.g1 = OddProjBlock(anti_dim, hidden_dim, hidden_dim)
        self.g2 = OddProjBlock(hidden_dim, hidden_dim, 1)

        self.n = n
        self.input_dim = input_dim

    def forward(self, x):
        """Map ``x`` of shape (batch, n, input_dim) to a tensor of shape (batch,)."""
        matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim=1)
        dets = torch.det(matrices)  # (batch, anti_dim)
        hidden = self.g1(dets)      # (batch, hidden_dim)
        return torch.flatten(self.g2(hidden))


In [29]:
class MultiBackflow(nn.Module):
    """Sum of determinants evaluated on backflow-transformed coordinates.

    A permutation-invariant feature of the whole configuration (``sym``) is
    appended to every particle's coordinates; ``push`` maps the result back
    to ``input_dim`` coordinates, which are then fed to the orbital networks.

    Args:
        n: number of orbitals per particle (and expected set size).
        input_dim: number of coordinates per particle.
        hidden_dim: hidden width, also the size of the symmetric feature.
        anti_dim: number of determinants that are summed.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        self.sym = Symmetric(input_dim, hidden_dim, hidden_dim, hidden_dim)
        self.push = Block(input_dim + hidden_dim, hidden_dim, input_dim)
        orbital_nets = [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        self.orbitals = nn.ModuleList(orbital_nets)

        self.n = n
        self.input_dim = input_dim

    def forward(self, x):
        """Map ``x`` of shape (batch, n, input_dim) to a tensor of shape (batch,)."""
        # Unpacking also enforces that the input is 3-D.
        _, set_size, _ = x.shape

        # One feature vector per configuration, copied to every particle.
        shared = self.sym(x).unsqueeze(1).repeat(1, set_size, 1)
        moved = self.push(torch.cat([x, shared], 2))

        matrices = torch.stack([orbital(moved) for orbital in self.orbitals], dim=1)
        dets = torch.det(matrices)
        return dets.sum(dim=1)
    
class DeepMultiBackflow(nn.Module):
    """Backflow-transformed determinants combined by two stacked odd blocks.

    Coordinates are first shifted by ``push`` using a permutation-invariant
    feature from ``sym``; ``anti_dim`` determinants are computed on the
    shifted coordinates and then combined by ``g1`` followed by ``g2``.

    Args:
        n: number of orbitals per particle (and expected set size).
        input_dim: number of coordinates per particle.
        hidden_dim: hidden width, also the size of the symmetric feature.
        anti_dim: number of determinants.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        self.sym = Symmetric(input_dim, hidden_dim, hidden_dim, hidden_dim)
        self.push = Block(input_dim + hidden_dim, hidden_dim, input_dim)
        orbital_nets = [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        self.orbitals = nn.ModuleList(orbital_nets)
        self.g1 = OddProjBlock(anti_dim, hidden_dim, hidden_dim)
        self.g2 = OddProjBlock(hidden_dim, hidden_dim, 1)

        self.n = n
        self.input_dim = input_dim

    def forward(self, x):
        """Map ``x`` of shape (batch, n, input_dim) to a tensor of shape (batch,)."""
        # Unpacking also enforces that the input is 3-D.
        _, set_size, _ = x.shape

        # One feature vector per configuration, copied to every particle.
        shared = self.sym(x).unsqueeze(1).repeat(1, set_size, 1)
        moved = self.push(torch.cat([x, shared], 2))

        matrices = torch.stack([orbital(moved) for orbital in self.orbitals], dim=1)
        dets = torch.det(matrices)  # (batch, anti_dim)
        hidden = self.g1(dets)      # (batch, hidden_dim)
        return torch.flatten(self.g2(hidden))

In [30]:
# Validate batching: evaluating two configurations in one batch should give
# the same values as evaluating each configuration on its own.

# n, d and hidden_dim are notebook globals that later cells read as well.
n = 5
d = 3
hidden_dim = 20

# Two random configurations of n particles with d coordinates each.
x = 10 * torch.normal(mean = 0, std = 1, size = (2, n, d))

# Slicing (rather than indexing) keeps the leading batch axis of size 1.
x0 = x[:1]
x1 = x[1:]

# Freshly initialised (untrained) model with 4 determinants.
SD = DeepMultiBackflow(n, d, hidden_dim, 4)
#SD = AntiNet(n, d, hidden_dim, 4)
# Expected: the first print shows the two values printed by the next two.
print(SD(x))
print(SD(x0))
print(SD(x1))

tensor([-1.0010e-05,  1.9118e-06], grad_fn=<ReshapeAliasBackward0>)
tensor([-1.0013e-05], grad_fn=<ReshapeAliasBackward0>)
tensor([1.9118e-06], grad_fn=<ReshapeAliasBackward0>)


In [31]:
# Validate antisymmetry: exchanging two particles should flip the sign of the
# model output and leave its magnitude unchanged.
# Relies on n, d and hidden_dim defined in an earlier cell.


x = 10 * torch.normal(mean = 0, std = 1, size = (n, d))
# P is the permutation matrix that swaps rows 0 and 1 and fixes the rest.
P = torch.eye(n)
P[0,0] = P[1,1] = 0
P[0,1] = P[1,0] = 1
# x_ is x with particles 0 and 1 exchanged.
x_ = torch.mm(P, x)
# Add the batch axis the models expect: (1, n, d).
x = torch.unsqueeze(x, 0)
x_ = torch.unsqueeze(x_, 0)

# Freshly initialised (untrained) model with 3 determinants.
SD = DeepMultiBackflow(n, d, hidden_dim, 3)
y = SD(x)
y_ = SD(x_)

# ANN = AntiNet(n, d, hidden_dim, 3)
# y = ANN(x)
# y_ = ANN(x_)
# Expected: y_ equals -y.
print(y)
print(y_)

tensor([-6.8524e-06], grad_fn=<ReshapeAliasBackward0>)
tensor([6.8524e-06], grad_fn=<ReshapeAliasBackward0>)


In [32]:
def train(model, x, y, iterations, lr=0.005):
    """Fit ``model`` to ``(x, y)`` with full-batch Adam on the MSE loss.

    Args:
        model: module to optimise in place.
        x: inputs passed to ``model`` on every iteration (the whole set).
        y: targets compared with ``model(x)``.
        iterations: number of optimisation steps.
        lr: Adam learning rate.

    Returns:
        List with one float per iteration: the loss measured before that
        iteration's parameter update.
    """
    model.train()
    mse = nn.MSELoss()
    adam = optim.Adam(model.parameters(), lr=lr)

    history = []
    for _ in range(iterations):
        # Clear gradients left over from the previous step.
        adam.zero_grad()
        step_loss = mse(model(x), y)
        step_loss.backward()
        adam.step()
        history.append(step_loss.item())

    # Leave the model in evaluation mode for the caller.
    model.eval()
    return history

In [49]:
# Experiment configuration shared by the teacher/student cells below.
# n: particles per configuration (also the number of orbitals per determinant).
n = 5
# d: coordinates per particle.
d = 3
# hidden_dim: hidden width of every network.
hidden_dim = 20
# anti_dim: number of determinants in each student model.
anti_dim = 5

# iterations: optimisation steps per training run.
iterations = 20000
# samples: number of training configurations.
samples = 5000

In [50]:
# Teacher: a randomly initialised sum of 200 determinants; it is never trained.
# NOTE(review): no random seed is set anywhere visible, so the teacher and the
# data differ on every run.
teacher = MultiSlaterDeterminant(n, d, hidden_dim, 200)
# Training inputs: `samples` configurations drawn from a normal with std 3.
train_x = 3 * torch.normal(mean = 0, std = 1, size = (samples, n, d))
# Targets are the teacher outputs, detached so no graph is kept for them.
train_y = teacher(train_x).detach()

In [51]:
# for _ in range(1):
#     student = MultiSlaterDeterminant(n, d, hidden_dim, anti_dim)
#     losses = train(student, train_x, train_y, iterations, lr = 0.0025)
#     print(losses[::50])
#     print(min(losses))

In [52]:
# for _ in range(1):
#     student = AntiNet(n, d, hidden_dim, anti_dim)
#     losses = train(student, train_x, train_y, iterations, lr = 0.0025)
#     print(losses[::50])
#     print(min(losses))

In [53]:
# Train a DeepAntiNet student on the teacher data.
# The loop runs once; raise the range to repeat with fresh initialisations.
for _ in range(1):
    student = DeepAntiNet(n, d, hidden_dim, anti_dim)
    losses = train(student, train_x, train_y, iterations, lr = 0.0025)
    # Print every 50th loss, then the lowest loss reached.
    print(losses[::50])
    print(min(losses))

[74.98603820800781, 12.448062896728516, 6.464196681976318, 4.879818916320801, 4.030492782592773, 3.530247211456299, 3.208014488220215, 2.978376865386963, 2.790480613708496, 2.6464672088623047, 2.5403292179107666, 2.4579055309295654, 2.3777809143066406, 2.302720308303833, 2.2359743118286133, 2.1958868503570557, 2.140211343765259, 2.102665901184082, 2.096693754196167, 2.0359768867492676, 2.0037131309509277, 1.9792101383209229, 1.9611250162124634, 1.9563833475112915, 1.8965578079223633, 1.8754156827926636, 1.8572744131088257, 1.8880958557128906, 1.8251302242279053, 1.814673662185669, 1.8229738473892212, 1.7819308042526245, 1.8033159971237183, 1.819546103477478, 1.735672116279602, 1.7281690835952759, 1.7095849514007568, 1.6916066408157349, 1.6836519241333008, 1.6800060272216797, 1.6810171604156494, 1.66530179977417, 1.6315369606018066, 1.6138352155685425, 1.6022917032241821, 1.6097099781036377, 1.5965417623519897, 1.5813536643981934, 1.544289469718933, 1.5281130075454712, 1.564332246780395

In [54]:
# Train a MultiBackflow student on the teacher data.
# The loop runs once; raise the range to repeat with fresh initialisations.
for _ in range(1):
    student = MultiBackflow(n, d, hidden_dim, anti_dim)
    losses = train(student, train_x, train_y, iterations, lr = 0.0025)
    # Print every 50th loss, then the lowest loss reached.
    print(losses[::50])
    print(min(losses))

[75.35175323486328, 26.88316535949707, 10.985112190246582, 6.651232719421387, 5.025967121124268, 4.157585144042969, 3.60455584526062, 3.2522475719451904, 3.007519245147705, 2.7913570404052734, 2.617907762527466, 2.477581262588501, 2.371523141860962, 2.2269725799560547, 2.1166398525238037, 2.045684814453125, 1.938610315322876, 1.8895525932312012, 1.8510980606079102, 1.7752336263656616, 1.7310562133789062, 1.6829882860183716, 1.703217625617981, 1.6123077869415283, 1.5615037679672241, 1.5194311141967773, 1.4945552349090576, 1.484784722328186, 1.433695673942566, 1.4250519275665283, 1.4044640064239502, 1.3823468685150146, 1.358427882194519, 1.3484439849853516, 1.3306024074554443, 1.3096555471420288, 1.3183000087738037, 1.292242169380188, 1.278598427772522, 1.3161906003952026, 1.2624006271362305, 1.2468044757843018, 1.2346489429473877, 1.2465819120407104, 1.2361586093902588, 1.2222607135772705, 1.2185897827148438, 1.1802247762680054, 1.1862661838531494, 1.1698143482208252, 1.1598702669143677

In [None]:
# Train a DeepMultiBackflow student on the teacher data.
# The loop runs once; raise the range to repeat with fresh initialisations.
for _ in range(1):
    student = DeepMultiBackflow(n, d, hidden_dim, anti_dim)
    losses = train(student, train_x, train_y, iterations, lr = 0.0025)
    # Print every 50th loss, then the lowest loss reached.
    print(losses[::50])
    print(min(losses))

In [None]:
# Bar chart of mean +/- standard deviation for three groups of three values.
# NOTE(review): the numbers are hard-coded; presumably each array holds the
# final MSE of three runs of one architecture - confirm their provenance.
a = np.array([6.588473796844482, 6.398560047149658, 7.056000232696533])
b = np.array([6.899078845977783, 5.879907608032227, 5.7301530838012695])
c = np.array([4.987086296081543, 4.876344203948975, 4.408130645751953])

# Bar positions and labels; names[i] labels the i-th of (a, b, c).
x_pos = np.arange(3)
names = ["Default", "One Extra Layer", "Two Extra Layers"]
means = [np.mean(a), np.mean(b), np.mean(c)]
# np.std uses ddof=0 by default, i.e. the population standard deviation.
stds = [np.std(a), np.std(b), np.std(c)]


fig, ax = plt.subplots()
# Error bars show one standard deviation above and below the mean.
ax.bar(x_pos, means, yerr=stds, align='center', alpha=0.5, ecolor='black', capsize=10)
ax.set_ylabel('Mean Squared Error')
ax.set_xticks(x_pos)
ax.set_xticklabels(names)
ax.yaxis.grid(True)

# Save the figure to the current working directory, then display it.
plt.tight_layout()
plt.savefig('bar_plot_with_error_bars.png')
plt.show()