In [1]:
%run "model.py"

In [2]:
def train_dense(batch_size, epoch_len, echo, seed, learn_rate, hamiltonian, hidden_features, nonlinearity, bias, max_dist):
    """Train a HolographicPixelGCN against an energy model and print its free energy.

    The model wraps ``EnergyModel(hamiltonian, SymmetricGroup(2),
    Lattice(4, 2, max_dist))`` and is optimised with Adam. Running averages of
    the loss, the mean free energy and its variance are printed every ``echo``
    steps. After training, ``echo`` fresh batches are sampled and the per-batch
    mean of ``model.free_energy`` is printed together with the mean and
    standard deviation across those batches.

    Args:
        batch_size: Value passed to ``model.sample`` for every batch.
        epoch_len: Number of optimisation steps to run.
        echo: Reporting interval in steps. It is also reused as the number of
            evaluation batches drawn after training. Statistics of a trailing
            partial window (``epoch_len % echo != 0``) are never printed.
        seed: Seed handed to ``torch.manual_seed`` before the model is built.
        learn_rate: Learning rate of the Adam optimiser.
        hamiltonian: Hamiltonian forwarded to ``EnergyModel``.
        hidden_features: Forwarded to ``HolographicPixelGCN``.
        nonlinearity: Forwarded to ``HolographicPixelGCN``.
        bias: Forwarded to ``HolographicPixelGCN``.
        max_dist: Forwarded to ``Lattice`` as its third argument.

    Returns:
        None. All results are written to stdout.
    """
    torch.manual_seed(seed)
    model = HolographicPixelGCN(
                EnergyModel(
                    hamiltonian, # Ising critical point
                    SymmetricGroup(2),
                    Lattice(4, 2, max_dist)),
                hidden_features, nonlinearity, bias)
    optimizer = optim.Adam(model.parameters(), lr = learn_rate)

    # Running sums over the current reporting window, kept as Python floats.
    train_loss = 0.
    free_energy = 0.
    tot_var = 0.

    for epoch in range(epoch_len):

        x = model.sample(batch_size)
        log_prob = model.log_prob(x)
        energy = model.energy(x)
        # log_prob is detached here so that it only enters the weight below,
        # not the gradient path.
        # NOTE(review): energy is not detached; if model.energy(x) carries
        # gradient it also contributes to the update -- confirm this is intended.
        free = energy + log_prob.detach()
        meanfree = free.mean()
        # log_prob weighted by the mean-centred free energy, summed over the batch.
        loss = (log_prob * (free - meanfree)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        free_energy += meanfree.item()
        with torch.no_grad():
            var = ((free - meanfree) ** 2).mean()
        # Accumulate a float (not a 0-dim tensor), consistent with the other
        # two running sums.
        tot_var += var.item()
        if (epoch+1)%echo == 0:
            print('{:5} loss: {:8.4f}, free energy: {:8.4f}, Variance: {:8.4f}'.format(epoch+1, train_loss/echo, free_energy/echo, tot_var/echo))
            train_loss = 0.
            free_energy = 0.
            tot_var = 0.

    # Post-training evaluation: one mean free energy per freshly sampled batch.
    with torch.no_grad():
        F = []
        for i in range(echo):
            F.append(model.free_energy(model.sample(batch_size)).mean())
    F = torch.tensor(F)
    print(F)
    print(F.mean(), " ", F.std())

In [3]:
def H(J):
    """Return -J times the sum of two TwoBody terms with offsets (1,0) and (0,1)."""
    bond_10 = TwoBody(torch.tensor([1.,-1.]), (1,0))
    bond_01 = TwoBody(torch.tensor([1.,-1.]), (0,1))
    return -J*(bond_10 + bond_01)


# Run configuration; every value below is forwarded to train_dense.
batch_size = 100
epoch_len = 2000
echo = 100                     # reporting interval
seed = 0
learn_rate = 0.01
hamiltonian = H(0.440686793)   # J ~ ln(1 + sqrt(2)) / 2; train_dense labels this the Ising critical point
hidden_features = [4,4]
nonlinearity = 'Tanh'
bias = False
max_dist = None

train_dense(
    batch_size=batch_size,
    epoch_len=epoch_len,
    echo=echo,
    seed=seed,
    learn_rate=learn_rate,
    hamiltonian=hamiltonian,
    hidden_features=hidden_features,
    nonlinearity=nonlinearity,
    bias=bias,
    max_dist=max_dist,
)

  100 loss:   1.6659, free energy: -12.8410, Variance:   3.4485
  200 loss:   4.4070, free energy: -13.0259, Variance:   3.0128
  300 loss:   2.7476, free energy: -13.0695, Variance:   2.9648
  400 loss:   1.4274, free energy: -13.3414, Variance:   2.7925
  500 loss:   2.8977, free energy: -13.3349, Variance:   2.7162
  600 loss:   2.8323, free energy: -13.3029, Variance:   2.7188
  700 loss:   4.1935, free energy: -13.3644, Variance:   2.6363
  800 loss:   1.7819, free energy: -13.3364, Variance:   2.6291
  900 loss:   2.7482, free energy: -13.4219, Variance:   2.5757
 1000 loss:   1.1555, free energy: -13.3038, Variance:   2.5400
 1100 loss:   3.9713, free energy: -13.3160, Variance:   2.5455
 1200 loss:   2.0384, free energy: -13.3680, Variance:   2.5481
 1300 loss:   2.1540, free energy: -13.4215, Variance:   2.5502
 1400 loss:  -0.7605, free energy: -13.3845, Variance:   2.5150
 1500 loss:   3.7681, free energy: -13.3875, Variance:   2.4830
 1600 loss:   4.8029, free energy: -13.4