In [1]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import math

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state
from jax.nn import initializers
import optax

%matplotlib inline

In [26]:
class Block(nn.Module):
    hidden_dim : int
    output_dim : int

    def setup(self):
        self.linear1 = nn.Dense(features=self.hidden_dim, kernel_init = initializers.kaiming_uniform(), 
                                )
        self.linear2 = nn.Dense(features=self.output_dim, kernel_init = initializers.kaiming_uniform(), 
                                )
    def __call__(self, x):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        return x
    
class OddProjBlock(nn.Module):
    hidden_dim : int
    output_dim : int
    
    def setup(self):
        self.fc1 = nn.Dense(features=self.hidden_dim, kernel_init = initializers.kaiming_uniform(), 
                                )
        self.fc2 = nn.Dense(features = self.output_dim, kernel_init = initializers.kaiming_uniform(),  use_bias = False)
    
    def __call__(self, x):
        x = jnp.concatenate([nn.relu(self.fc1(x)) - nn.relu(self.fc1(-x)), x], axis = 1)
        x = self.fc2(x)
        return x    

In [27]:
class SlaterDeterminant(nn.Module):
    hidden_dim : int
    n : int

    def setup(self):
        self.orbitals = Block(hidden_dim=self.hidden_dim, output_dim=self.n)

    def __call__(self, x):
        x = self.orbitals(x)
        x = jax.scipy.linalg.det(x)
        return x

In [28]:
class MultiSlaterDeterminant(nn.Module):
    hidden_dim : int
    anti_dim : int
    n : int

    def setup(self):
        self.orbitals = [Block(hidden_dim=self.hidden_dim, output_dim=self.n) for _ in range(self.anti_dim)]

    def __call__(self, x):
        sds = [f(x) for f in self.orbitals]
        sds = jnp.stack(sds, axis = 1)
        x = jax.scipy.linalg.det(sds)
        return jnp.sum(x, axis = 1)

In [29]:
class AntiNet(nn.Module):
    hidden_dim : int
    anti_dim : int
    n : int

    def setup(self):
        self.orbitals = [Block(hidden_dim=self.hidden_dim, output_dim=self.n) for _ in range(self.anti_dim)]
        self.g = OddProjBlock(hidden_dim=self.hidden_dim, output_dim=1)

    def __call__(self, x):
        sds = [f(x) for f in self.orbitals]
        sds = jnp.stack(sds, axis = 1)
        x = jax.scipy.linalg.det(sds)
        x = self.g(x)
        return jnp.ravel(x)

In [30]:
class DeepAntiNet(nn.Module):
    hidden_dim : int
    anti_dim : int
    n : int

    def setup(self):
        self.orbitals = [Block(hidden_dim=self.hidden_dim, output_dim=self.n) for _ in range(self.anti_dim)]
        self.g1 = OddProjBlock(hidden_dim=self.hidden_dim, output_dim=self.hidden_dim)
        self.g2 = OddProjBlock(hidden_dim=self.hidden_dim, output_dim=1)


    def __call__(self, x):
        sds = [f(x) for f in self.orbitals]
        sds = jnp.stack(sds, axis = 1)
        x = jax.scipy.linalg.det(sds)
        x = self.g1(x)
        x = self.g2(x)
        return jnp.ravel(x)

In [31]:
#Validate Batching
n = 5
d = 3
anti_dim = 4
hidden_dim = 20
rng = jax.random.PRNGKey(42)

# SD = SlaterDeterminant(hidden_dim = hidden_dim, n = n)
# SD = MultiSlaterDeterminant(hidden_dim = hidden_dim, n = n, anti_dim = anti_dim)
SD = AntiNet(hidden_dim = hidden_dim, n = n, anti_dim = anti_dim)


rng, inp_rng, init_rng = jax.random.split(rng, 3)
inp = jax.random.normal(inp_rng, (2, n, d))
params = SD.init(init_rng, inp)



print(SD.apply(params, inp))
print(SD.apply(params, inp[:1]))
print(SD.apply(params, inp[1:]))

[-1.0705532 49.222343 ]
[-1.0705531]
[49.22234]


In [32]:
#Validate antisymmetry

#SD = SlaterDeterminant(hidden_dim = hidden_dim, n = n)
#SD = MultiSlaterDeterminant(hidden_dim = hidden_dim, n = n, anti_dim = anti_dim)
SD = DeepAntiNet(hidden_dim = hidden_dim, n = n, anti_dim = anti_dim)


rng, inp_rng, init_rng = jax.random.split(rng, 3)
inp = jax.random.normal(inp_rng, (n, d))

P = jnp.eye(n)
P = P.at[0,0].set(0)
P = P.at[0,1].set(1)
P = P.at[1,0].set(1)
P = P.at[1,1].set(0)

inp_ = jnp.dot(P, inp)
inp = jnp.expand_dims(inp, 0)
inp_ = jnp.expand_dims(inp_, 0)

print(inp)
print(inp_)

print(inp.shape)
params = SD.init(init_rng, inp)



y = SD.apply(params, inp)
y_ = SD.apply(params, inp_)

print(y,y_)

[[[-0.57649416  0.24882974  2.0914953 ]
  [ 0.7997544   0.12270966  2.237492  ]
  [-0.48119685  0.66588587 -1.2084068 ]
  [ 0.17537078  0.08046822  1.1909232 ]
  [-0.8474577  -2.2422376  -0.07337539]]]
[[[ 0.7997544   0.12270966  2.237492  ]
  [-0.57649416  0.24882974  2.0914953 ]
  [-0.48119685  0.66588587 -1.2084068 ]
  [ 0.17537078  0.08046822  1.1909232 ]
  [-0.8474577  -2.2422376  -0.07337539]]]
(1, 5, 3)
[-12.200166] [12.200166]


In [59]:
def calculate_loss(state, params, batch):
    x, y = batch
    outputs = state.apply_fn(params, x)
    loss = 2.0 * optax.l2_loss(outputs, y).mean()
    return loss

@jax.jit
def train_step(state, batch):
    grad_fn = jax.value_and_grad(calculate_loss,
                                 argnums=1  # Parameters are second argument of the function
                                )
    loss, grads = grad_fn(state, state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, loss

def train(model, params, x, y, iterations, lr=0.005):
    optimizer = optax.adam(learning_rate=lr)
    batch = (x,y)
    
    state = train_state.TrainState.create(apply_fn=model.apply,
                                            params=params,
                                            tx=optimizer)

    losses = []
    for i in range(iterations):
                
        state, loss = train_step(state, batch)
        losses.append(loss)
    
    return losses, state

In [60]:
##########################################################################################

In [83]:
n = 5
d = 3
hidden_dim = 25
anti_dim = 10

iterations = 10000
samples = 4000

In [84]:
teacher = MultiSlaterDeterminant(hidden_dim = hidden_dim, n = n, anti_dim = 200)
rng, inp_rng, init_rng = jax.random.split(rng, 3)
inp = jax.random.normal(inp_rng, (2, n, d))
teacher_params = teacher.init(init_rng, inp)

#train_x = 5 * jax.random.normal(inp_rng, (2000, n, d))  ### This lead to too large y values??
train_x = jax.random.normal(inp_rng, (samples, n, d))
train_y = teacher.apply(teacher_params, train_x)

In [85]:
# student = MultiSlaterDeterminant(hidden_dim = hidden_dim, n = n, anti_dim = anti_dim)
# rng, inp_rng, init_rng = jax.random.split(rng, 3)
# inp = jax.random.normal(inp_rng, (2, n, d))
# student_params = teacher.init(init_rng, inp)

# losses, state = train(student, student_params, train_x, train_y, iterations, lr = 0.005)
# print(losses[::50])
# print(min(losses))

In [86]:
# test_y = student.apply(state.params, train_x)
# print(test_y)
# print(train_y)

In [87]:
########################################################################################

In [89]:
for _ in range(3):
    student = MultiSlaterDeterminant(hidden_dim = hidden_dim, n = n, anti_dim = anti_dim)
    rng, inp_rng, init_rng = jax.random.split(rng, 3)
    inp = jax.random.normal(inp_rng, (2, n, d))
    student_params = student.init(init_rng, inp)

    losses, state = train(student, student_params, train_x, train_y, iterations, lr = 0.0025)
    #print(losses[::50])
    print(min(losses))

64.72303
80.0083
70.748505


In [90]:
for _ in range(3):
    student = AntiNet(hidden_dim = hidden_dim, n = n, anti_dim = anti_dim)
    rng, inp_rng, init_rng = jax.random.split(rng, 3)
    inp = jax.random.normal(inp_rng, (2, n, d))
    student_params = student.init(init_rng, inp)

    losses, state = train(student, student_params, train_x, train_y, iterations, lr = 0.0025)
    #print(losses[::50])
    print(min(losses))

68.18513
64.681435
63.388844


In [91]:
for _ in range(3):
    student = DeepAntiNet(hidden_dim = hidden_dim, n = n, anti_dim = anti_dim)
    rng, inp_rng, init_rng = jax.random.split(rng, 3)
    inp = jax.random.normal(inp_rng, (2, n, d))
    student_params = student.init(init_rng, inp)

    losses, state = train(student, student_params, train_x, train_y, iterations, lr = 0.0025)
    #print(losses[::50])
    print(min(losses))

57.11024
57.373543
60.08668


In [None]:
a = np.array([6.588473796844482, 6.398560047149658, 7.056000232696533])
b = np.array([6.899078845977783, 5.879907608032227, 5.7301530838012695])
c = np.array([4.987086296081543, 4.876344203948975, 4.408130645751953])

x_pos = np.arange(3)
names = ["Default", "One Extra Layer", "Two Extra Layers"]
means = [np.mean(a), np.mean(b), np.mean(c)]
stds = [np.std(a), np.std(b), np.std(c)]


fig, ax = plt.subplots()
ax.bar(x_pos, means, yerr=stds, align='center', alpha=0.5, ecolor='black', capsize=10)
ax.set_ylabel('Mean Squared Error')
ax.set_xticks(x_pos)
ax.set_xticklabels(names)
ax.yaxis.grid(True)

# Save the figure and show
plt.tight_layout()
plt.savefig('bar_plot_with_error_bars.png')
plt.show()