In [1]:
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib as mpl
import matplotlib.pyplot as plt
import math

%matplotlib inline

In [62]:
class Block(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Block, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
class OddProjBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(OddProjBlock, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim + input_dim, output_dim, bias = False)
        
    def forward(self, x):
        x = torch.cat([F.relu(self.fc1(x)) - F.relu(self.fc1(-x)), x], dim = 1)
        x = self.fc2(x)
        return x
    
class Symmetric(nn.Module):
    def __init__(self, input_dim, hidden_dim, symmetric_dim, output_dim):
        super(Symmetric, self).__init__()
        
        self.phi = Block(input_dim, hidden_dim, symmetric_dim)
        self.rho = Block(symmetric_dim, hidden_dim, output_dim)
    
    
    def forward(self, x):        
        batch_size, input_set_dim, input_dim = x.shape
        
        x = x.view(-1, input_dim)
        z = self.phi(x)
        z = z.view(batch_size, input_set_dim, -1)
        z = torch.mean(z, 1)
        return self.rho(z)

In [4]:
class SlaterDeterminant(nn.Module):
    def __init__(self, n, input_dim, hidden_dim):
        super(SlaterDeterminant, self).__init__()
        self.orbitals = Block(input_dim, hidden_dim, n)
        
        self.input_dim = input_dim
        self.n = n
    
    def forward(self, x):
        x = x.view(-1, self.input_dim)
        sd = self.orbitals(x)
        sd = sd.view(-1, n, n)
        return torch.det(sd)

In [5]:
class MultiSlaterDeterminant(nn.Module):
    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super(MultiSlaterDeterminant, self).__init__()
        self.orbitals = nn.ModuleList([Block(input_dim, hidden_dim, n) for _ in range(anti_dim)])
        
        self.input_dim = input_dim
        self.n = n
        
    def forward(self,x):        
        #x = x.view(-1, self.input_dim)
        #sds = [f(x).view(-1, self.n, self.n) for f in self.orbitals]
        sds = [f(x) for f in self.orbitals]
        sds = torch.stack(sds,1)
        sds = torch.det(sds)
        return torch.sum(sds, dim = 1)

In [6]:
class AntiNet(nn.Module):
    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super(AntiNet, self).__init__()
        self.orbitals = nn.ModuleList([Block(input_dim, hidden_dim, n) for _ in range(anti_dim)])
        self.g = OddProjBlock(anti_dim, hidden_dim, 1)
        
        self.input_dim = input_dim
        self.n = n
        
    def forward(self,x):
        sds = [f(x) for f in self.orbitals]
        sds = torch.stack(sds,1)
        sds = torch.det(sds)
        return torch.flatten(self.g(sds))


In [7]:
class DeepAntiNet(nn.Module):
    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super(DeepAntiNet, self).__init__()
        self.orbitals = nn.ModuleList([Block(input_dim, hidden_dim, n) for _ in range(anti_dim)])
        self.g1 = OddProjBlock(anti_dim, hidden_dim, hidden_dim)
        self.g2 = OddProjBlock(hidden_dim, hidden_dim, 1)
        
        self.input_dim = input_dim
        self.n = n
        
    def forward(self,x):
        sds = [f(x) for f in self.orbitals]
        sds = torch.stack(sds,1)
        sds = torch.det(sds)
        sds = self.g1(sds)
        return torch.flatten(self.g2(sds))


In [73]:
class MultiBackflow(nn.Module):
    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super(MultiBackflow, self).__init__()
        self.sym = Symmetric(input_dim, hidden_dim, hidden_dim, hidden_dim)
        self.push = Block(input_dim + hidden_dim, hidden_dim, input_dim)
        self.orbitals = nn.ModuleList([Block(input_dim, hidden_dim, n) for _ in range(anti_dim)])
        
        self.input_dim = input_dim
        self.n = n
        
    def forward(self,x):
        batch_dim, set_dim, input_dim = x.shape
        sym_feature = self.sym(x).unsqueeze(1).repeat(1, set_dim, 1)
        z = torch.cat([x, sym_feature], 2)
        z = self.push(z)
        
        sds = [f(z) for f in self.orbitals]
        sds = torch.stack(sds,1)
        sds = torch.det(sds)
        return torch.sum(sds, dim = 1)
    
class DeepMultiBackflow(nn.Module):
    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super(DeepMultiBackflow, self).__init__()
        self.sym = Symmetric(input_dim, hidden_dim, hidden_dim, hidden_dim)
        self.push = Block(input_dim + hidden_dim, hidden_dim, input_dim)
        self.orbitals = nn.ModuleList([Block(input_dim, hidden_dim, n) for _ in range(anti_dim)])
        self.g1 = OddProjBlock(anti_dim, hidden_dim, hidden_dim)
        self.g2 = OddProjBlock(hidden_dim, hidden_dim, 1)        
        
        self.input_dim = input_dim
        self.n = n
        
    def forward(self,x):
        batch_dim, set_dim, input_dim = x.shape
        sym_feature = self.sym(x).unsqueeze(1).repeat(1, set_dim, 1)
        z = torch.cat([x, sym_feature], 2)
        z = self.push(z)
        
        sds = [f(z) for f in self.orbitals]
        sds = torch.stack(sds,1)
        sds = torch.det(sds)
        sds = self.g1(sds)
        return torch.flatten(self.g2(sds))

In [74]:
#Validate batching

n = 5
d = 3
hidden_dim = 20

x = 10 * torch.normal(mean = 0, std = 1, size = (2, n, d))

x0 = x[:1]
x1 = x[1:]

SD = DeepMultiBackflow(n, d, hidden_dim, 4)
#SD = AntiNet(n, d, hidden_dim, 4)
print(SD(x))
print(SD(x0))
print(SD(x1))

tensor([-4.3524e-05,  7.5264e-07], grad_fn=<ReshapeAliasBackward0>)
tensor([-4.3524e-05], grad_fn=<ReshapeAliasBackward0>)
tensor([7.5263e-07], grad_fn=<ReshapeAliasBackward0>)


In [75]:
#Validate antisymmetry



x = 10 * torch.normal(mean = 0, std = 1, size = (n, d))
P = torch.eye(n)
P[0,0] = P[1,1] = 0
P[0,1] = P[1,0] = 1
x_ = torch.mm(P, x)
x = torch.unsqueeze(x, 0)
x_ = torch.unsqueeze(x_, 0)

SD = MultiBackflow(n, d, hidden_dim, 3)
y = SD(x)
y_ = SD(x_)

ANN = AntiNet(n, d, hidden_dim, 3)
y = ANN(x)
y_ = ANN(x_)
print(y)
print(y_)

tensor([16.6430], grad_fn=<ReshapeAliasBackward0>)
tensor([-16.6430], grad_fn=<ReshapeAliasBackward0>)


In [13]:
def train(model, x, y, iterations, lr=0.005):
    model.train()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    losses = []
    for i in range(iterations):
        outputs = model(x)

        optimizer.zero_grad()
        loss = criterion(outputs, y)
        loss.backward()
                
        optimizer.step()

        losses.append(loss.item())
    
    model.eval()
    return losses

In [87]:
n = 5
d = 3
hidden_dim = 15
anti_dim = 5

iterations = 10000
samples = 4000

In [88]:
teacher = MultiSlaterDeterminant(n, d, hidden_dim, 200)
train_x = 5 * torch.normal(mean = 0, std = 1, size = (samples, n, d))
train_y = teacher(train_x).detach()

In [89]:
for _ in range(1):
    student = MultiSlaterDeterminant(n, d, hidden_dim, anti_dim)
    losses = train(student, train_x, train_y, iterations, lr = 0.0025)
    print(losses[::50])
    print(min(losses))

[5856.53662109375, 741.1084594726562, 530.5269165039062, 448.66168212890625, 392.57049560546875, 342.07476806640625, 302.68682861328125, 274.0189514160156, 255.272705078125, 240.46038818359375, 226.79737854003906, 215.61814880371094, 207.5760955810547, 201.0840606689453, 195.84963989257812, 190.97569274902344, 186.99057006835938, 185.81324768066406, 180.8656005859375, 178.109619140625, 175.42984008789062, 172.9220733642578, 170.07476806640625, 168.19583129882812, 166.38568115234375, 164.96507263183594, 163.480224609375, 161.69554138183594, 159.85218811035156, 158.37608337402344, 157.13900756835938, 155.96759033203125, 155.01658630371094, 153.89830017089844, 153.07733154296875, 152.18309020996094, 151.37498474121094, 150.84228515625, 149.6619415283203, 149.0463104248047, 148.6286163330078, 147.42213439941406, 146.4218292236328, 145.58621215820312, 147.1170654296875, 144.22564697265625, 143.80368041992188, 143.79898071289062, 143.66879272460938, 142.33067321777344, 142.01803588867188, 14

In [90]:
for _ in range(1):
    student = AntiNet(n, d, hidden_dim, anti_dim)
    losses = train(student, train_x, train_y, iterations, lr = 0.0025)
    print(losses[::50])
    print(min(losses))

[3512.325927734375, 911.1068115234375, 612.9573974609375, 386.7205810546875, 283.54998779296875, 244.0565643310547, 222.72642517089844, 206.0544891357422, 190.17880249023438, 177.4953155517578, 168.604736328125, 163.58834838867188, 159.72425842285156, 157.037841796875, 154.3712921142578, 152.3917694091797, 150.23475646972656, 148.1197509765625, 146.00425720214844, 143.3270721435547, 142.13088989257812, 139.1960906982422, 137.6366424560547, 135.72891235351562, 134.44398498535156, 133.7918701171875, 133.135986328125, 131.8527069091797, 131.41253662109375, 130.63121032714844, 129.99705505371094, 129.47955322265625, 129.00271606445312, 128.6268310546875, 127.66271209716797, 127.19969940185547, 126.20470428466797, 125.23307037353516, 124.30564880371094, 123.51802062988281, 122.73004913330078, 122.12548065185547, 121.77183532714844, 121.06544494628906, 120.6657943725586, 120.52877044677734, 120.57437133789062, 119.75520324707031, 119.6533203125, 119.30360412597656, 119.18016052246094, 118.71

In [91]:
for _ in range(1):
    student = DeepAntiNet(n, d, hidden_dim, anti_dim)
    losses = train(student, train_x, train_y, iterations, lr = 0.0025)
    print(losses[::50])
    print(min(losses))

[3515.30712890625, 815.9725952148438, 444.9671936035156, 336.3922424316406, 273.1974182128906, 246.13539123535156, 228.58990478515625, 215.8326873779297, 205.8565216064453, 197.8935089111328, 190.03326416015625, 183.65663146972656, 178.40951538085938, 173.64035034179688, 168.87057495117188, 164.59393310546875, 159.90093994140625, 156.3222198486328, 154.01736450195312, 152.18991088867188, 148.59304809570312, 145.6527557373047, 145.48806762695312, 142.29397583007812, 140.3756103515625, 139.18350219726562, 139.64926147460938, 136.72361755371094, 135.5692596435547, 134.34730529785156, 133.5185089111328, 133.3640594482422, 134.2925567626953, 133.21299743652344, 131.08200073242188, 130.5001983642578, 129.0056915283203, 128.78836059570312, 127.6392822265625, 126.78963470458984, 126.06910705566406, 125.72554779052734, 124.49893188476562, 124.94474792480469, 123.21211242675781, 122.40718841552734, 122.21017456054688, 121.43399047851562, 121.7424545288086, 120.92613220214844, 119.41935729980469,

In [92]:
for _ in range(1):
    student = MultiBackflow(n, d, hidden_dim, anti_dim)
    losses = train(student, train_x, train_y, 2*iterations, lr = 0.0025)
    print(losses[::50])
    print(min(losses))

[3511.876708984375, 1247.783447265625, 724.7862548828125, 442.59100341796875, 337.2234191894531, 280.14752197265625, 241.66787719726562, 215.65403747558594, 197.8311767578125, 184.82493591308594, 175.7493896484375, 166.92166137695312, 159.8018341064453, 153.41368103027344, 148.80221557617188, 144.73773193359375, 140.6356201171875, 137.72434997558594, 134.7498321533203, 131.93875122070312, 128.04904174804688, 125.74459075927734, 123.20875549316406, 121.30916595458984, 119.7239990234375, 118.46581268310547, 117.53044891357422, 116.1927719116211, 116.91909790039062, 113.79862213134766, 112.72754669189453, 112.24043273925781, 110.91645812988281, 111.03224182128906, 108.77552032470703, 109.06781005859375, 106.99089050292969, 105.93807983398438, 105.76836395263672, 105.075439453125, 103.90969848632812, 103.87273406982422, 104.49620056152344, 102.5136947631836, 102.17745971679688, 102.99374389648438, 101.02153015136719, 100.4216537475586, 99.1388931274414, 98.62106323242188, 99.34546661376953

In [93]:
for _ in range(1):
    student = DeepMultiBackflow(n, d, hidden_dim, anti_dim)
    losses = train(student, train_x, train_y, 2*iterations, lr = 0.0025)
    print(losses[::50])
    print(min(losses))

[3511.873046875, 1656.4500732421875, 655.002197265625, 429.4568176269531, 322.79608154296875, 270.04620361328125, 237.44671630859375, 215.8983612060547, 197.660888671875, 183.63722229003906, 172.4657440185547, 164.4317169189453, 155.42352294921875, 149.62098693847656, 146.12986755371094, 142.19808959960938, 141.0067596435547, 137.94581604003906, 136.9471893310547, 134.96791076660156, 135.78794860839844, 134.2774200439453, 130.5178680419922, 130.853271484375, 129.52540588378906, 127.67672729492188, 125.22744750976562, 124.9001235961914, 124.3214340209961, 125.58372497558594, 121.34986877441406, 122.76814270019531, 120.48391723632812, 119.18077850341797, 120.06259155273438, 118.76380157470703, 119.8555679321289, 116.71315002441406, 114.62187957763672, 115.18939208984375, 114.1116943359375, 114.04415893554688, 111.82836151123047, 112.62871551513672, 111.61647033691406, 111.0267562866211, 109.58422088623047, 110.46817779541016, 109.59723663330078, 108.55722045898438, 107.20478820800781, 11

In [None]:
a = np.array([6.588473796844482, 6.398560047149658, 7.056000232696533])
b = np.array([6.899078845977783, 5.879907608032227, 5.7301530838012695])
c = np.array([4.987086296081543, 4.876344203948975, 4.408130645751953])

x_pos = np.arange(3)
names = ["Default", "One Extra Layer", "Two Extra Layers"]
means = [np.mean(a), np.mean(b), np.mean(c)]
stds = [np.std(a), np.std(b), np.std(c)]


fig, ax = plt.subplots()
ax.bar(x_pos, means, yerr=stds, align='center', alpha=0.5, ecolor='black', capsize=10)
ax.set_ylabel('Mean Squared Error')
ax.set_xticks(x_pos)
ax.set_xticklabels(names)
ax.yaxis.grid(True)

# Save the figure and show
plt.tight_layout()
plt.savefig('bar_plot_with_error_bars.png')
plt.show()