In [364]:
import numpy as np
import escnn.gspaces as gspaces
import escnn.nn as enn
import escnn
import torch
import torch.nn as nn
import torch.optim as optim

import random
import math
import time
import tqdm
import matplotlib.pyplot as plt
from scipy import sparse
import seaborn as sns
from sympy.combinatorics import Permutation, PermutationGroup

import netket as nk
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import json

PATH = os.getcwd()

In [365]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
torch.set_default_device(device)

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0),1), 'B')
    print('Reserved: ', round(torch.cuda.memory_reserved(0),1), 'B')
    print()
    torch.set_default_device('cuda')

# test
T = torch.randn(1, 4).to(device)
print(T)
print('Allocated:', round(torch.cuda.memory_allocated(0),1), 'B')

Using device: cpu
tensor([[-0.5981,  0.0977,  0.2979, -1.0878]])
Allocated: 0 B


### Generate Table of All Possible States Indexed in Lexigraphical Order

In [366]:
def generateStateTable(n,n0,N): # assumes ab > n0 > 0 and N = n choose n0
    states = []
    state = np.concatenate((np.ones(n-n0),-1*np.ones(n0))).astype(int)
    for i in range(0,N):
        states.append(np.copy(state))
        j = 0
        flag = True
        flip_count = 0
        up_count = 0
        while (flag): 
            if (j == n-1):
                for m in range(0,n):
                    if (state[m] != 1 and state[m+1] == 1):
                        flip_count += 1
                if (flip_count == 1):
                    break
            if (state[j] != 1):
                j += 1  
            elif (state[j+1] != 1):
                for m in range(0,j):
                    if (state[m] != 1 and state[m+1] == 1):
                        flip_count += 1
                    if (state[m] == 1):
                        up_count += 1
                if (flip_count == 1):
                    state[j],state[j+1] = state[j+1],state[j]
                    for k in range(1,up_count+1):
                        state[j-k],state[k-1] = state[k-1],state[j-k]
                else:
                    state[j],state[j+1] = state[j+1],state[j]
                flag = False
            else:
                j += 1
    return torch.tensor(states, dtype=torch.float, requires_grad=True)

### Generate Adjacency Matrices

In [367]:
class Node:
    def __init__(self, position, xadj, yadj):
        self.position = position
        self.xadj = xadj
        self.yadj = yadj
    
    def __str__(self):
        return f"{self.position} {self.xadj} {self.yadj}"

In [368]:
def squareAdjacencyList(a,b): # constructs a periodic adjacency graph with width a and height b
    nodes = []
    for j in range(0,b):
        for i in range(0,a):
            xadj = [[(i-1) % a,j],[(i+1) % a,j]]
            yadj = [[i,(j-1) % b],[i,(j+1) % b]]
            nodes.append(Node([[i,j]],xadj,yadj))

    return nodes

### First Neighbors Adjacency Matrix

In [369]:
def firstneighbors(a,b): 
    nodes = squareAdjacencyList(a,b)
    N = a*b
    J = [[0 for col in range(N)] for row in range(N)]

    for i in range(0,N-1):
        for j in range(i+1,N):
            flag = False
            for xptr in nodes[i].xadj:
                if xptr in nodes[j].position:
                    flag = True
            for yptr in nodes[i].yadj:
                if yptr in nodes[j].position:
                    flag = True
            if flag:
                J[i][j] = 1
                J[j][i] = 1
        
    return torch.tensor(J, dtype=torch.float, requires_grad=True)

### Second Neighbors (Euclidean) Adjacency Matrix

In [370]:
def secondneighbors(a,b):
    nodes = squareAdjacencyList(a,b)
    N = a*b
    J = [[0 for col in range(N)] for row in range(N)]

    for i in range(0,N-1):
        for j in range(i+1,N):
            flag = False

            for xptr in nodes[i].xadj:
                try:
                    for k in range(0,N):
                        if xptr in nodes[k].position:
                            intermediate = k
                    for yptr in nodes[intermediate].yadj:
                        if yptr in nodes[j].position:
                            flag = True
                except UnboundLocalError:
                    pass

            if flag:
                J[i][j] = 1
                J[j][i] = 1

    return torch.tensor(J, dtype=torch.float, requires_grad=True)

### Model Parameters

In [371]:
a = 4 # x-range of supercell
b = 4 # y-range of supercell
N1 = firstneighbors(a,b)
N2 = secondneighbors(a,b)
n = a*b # number of sites in lattice

n0 = n // 2 # number of down spins in the string (taken as floor(n/2))
N = int(math.factorial(n)/(math.factorial(n0)*math.factorial(n-n0))) # number of states

stateTable = generateStateTable(n,n0,N)

In [372]:
def stateValBin(state):
    sum = 0
    for i in range(len(state)):
        if state[i] == 1:
            sum += 2**i
    return sum

def searchState(state):
    high = N-1
    low = 0
    while True:
        mid = math.floor((high+low)/2)
        if np.array_equal(state,stateTable[mid]):
            return mid
        elif stateValBin(state) > stateValBin(stateTable[mid]):
            low = mid+1
        else:
            high = mid-1

### Define G-CNN Model

In [373]:
# Define the group and input/output types
gspace = gspaces.rot2dOnR2(N=4)  # Example for rotation group of order 4 (D4)
input_type = enn.FieldType(gspace, [gspace.trivial_repr])  # Scalar fields
output_type = enn.FieldType(gspace, [gspace.trivial_repr])  # Single scalar output

In [374]:
class PeriodicConvLayer(enn.EquivariantModule):
    def __init__(self, in_type, out_type, kernel_size):
        super(PeriodicConvLayer, self).__init__()
        
        self.kernel_size = kernel_size
        self.conv = enn.R2Conv(in_type, out_type, kernel_size, padding=0)
        
    def periodic_padding(self, x, padding):
        return torch.nn.functional.pad(x, padding, mode='circular')
        
    def forward(self, x):
        padding = self.kernel_size // 2
        x = self.periodic_padding(x.tensor, (padding, padding, padding, padding))
        x = enn.GeometricTensor(x, self.conv.in_type)
        return self.conv(x)
    
    def evaluate_output_shape(self, input_shape):
        padding = self.kernel_size // 2
        return self.conv.evaluate_output_shape(input_shape[:-2] + (input_shape[-2] + 2 * padding, input_shape[-1] + 2 * padding))

In [375]:
# Define the network
class GCNN(nn.Module):
    def __init__(self):
        super(GCNN, self).__init__()
        
        self.block1 = PeriodicConvLayer(input_type, enn.FieldType(gspace, 8*[gspace.regular_repr]), kernel_size=4)
        self.block2 = PeriodicConvLayer(enn.FieldType(gspace, 8*[gspace.regular_repr]), enn.FieldType(gspace, 16*[gspace.regular_repr]), kernel_size=4)
        self.block3 = PeriodicConvLayer(enn.FieldType(gspace, 16*[gspace.regular_repr]), output_type, kernel_size=4)
        self.pool = enn.PointwiseAvgPoolAntialiased(output_type, sigma=0.66, stride=1)
        
        self.linear = nn.Linear(49, 1)
        
    def forward(self, x):
        x = enn.GeometricTensor(x, input_type)
        x = self.block1(x)
        x = nn.functional.relu(x.tensor)
        x = enn.GeometricTensor(x, self.block2.conv.in_type)
        x = self.block2(x)
        x = nn.functional.relu(x.tensor)
        x = enn.GeometricTensor(x, self.block3.conv.in_type)
        x = self.block3(x)
        x = self.pool(x)
        x = torch.flatten(x.tensor, 1)
        x = self.linear(x)
        return x

### Build Model

In [376]:
model = GCNN().to(device)

### Energy Estimation Sampling

In [377]:
def findAdjStates(initState): ### takes some basis state and returns a list of all other states that have a nonzero Hamiltonian with it
    otherStates = []
    
    for i in range(n-1):
        for j in range(i+1,n):
            if (N1[i][j] != 0 or N2[i][j] !=0) and initState[i]*initState[j] == -1:
                tempState = initState.clone()
                tempState[i],tempState[j] = initState[j],initState[i]
                otherStates.append(tempState)                

    return otherStates
    
def computeExplicitHamEntry(state1,state2,J1,J2,diagonal):
    sum = 0.0
    if diagonal:
        for k in range(0,n-1):
            for l in range(k+1,n):
                if (N1[k][l] != 0):
                    sum += J1*state1[k]*state2[l]
                if (N2[k][l] != 0):
                    sum += J2*state1[k]*state2[l]          
    else:
        tempState = state1 * state2
        if (torch.count_nonzero(tempState == -1) == 2):
            indices = np.where(tempState == -1)
            e,f = indices[0][0],indices[0][1]
            if (N1[e][f] != 0):
                sum += 2*J1
            if (N2[e][f] != 0):
                sum += 2*J2
    
    return sum

def locEnergy(model,initState,coeff,J1,J2):
    sum = 0.0
    others = findAdjStates(initState)
    for x in others:
        c_x = model(torch.tensor(x.reshape(1,1,a,b), dtype=torch.float, requires_grad=True)).item()
        H_x = computeExplicitHamEntry(initState,x,J1,J2,False)
        sum += c_x*H_x
    sum /= coeff
    sum += computeExplicitHamEntry(initState,initState,J1,J2,True)
    return sum

def metropolis(model,initState,coeff):
    proposedState = torch.tensor([1]*(n-n0)+[-1]*n0, dtype=torch.float, requires_grad=True)
    proposedState[torch.randperm(proposedState.size(0))]
    newCoeff = model(torch.tensor(proposedState.reshape(1,1,a,b), dtype=torch.float, requires_grad=True)).item()
    acceptanceProb = min(1,(newCoeff/coeff)**2)
    bernTrial = torch.bernoulli(torch.tensor([acceptanceProb],dtype=torch.float, requires_grad=True)).item()
    if bernTrial == 1:
        return proposedState
    else:
        return initState

def sampleEnergy(model,batchSize,J1,J2): ### takes in basis state index
    ignore = 10 # number of MCMC steps to ignore (in order to reduce correlation)
    state = torch.tensor([1]*(n-n0)+[-1]*n0, dtype=torch.float, requires_grad=True)
    state = state[torch.randperm(state.size(0))]
    for i in range(ignore):
        c = model(torch.tensor(state.reshape(1,1,a,b), dtype=torch.float, requires_grad=True)).item()
        state = metropolis(model,state,c)

    sum = 0.0
    for step in range(batchSize):
        c = model(torch.tensor(state.reshape(1,1,a,b), dtype=torch.float, requires_grad=True)).item()
        sum += locEnergy(model,state,c,J1,J2)
        state = metropolis(model,state,c)
    
    return sum / batchSize

### Train Model

In [378]:
def trainModel(model, J1, J2, batchSize, numEpochs, lr):
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in tqdm.tqdm(range(numEpochs)):
        model.train()
        optimizer.zero_grad()

        loss = sampleEnergy(model,batchSize,J1,J2)

        loss.backward()
        optimizer.step()

In [379]:
J1 = 1
J2 = 0
batchSize = 50
numEpochs = 100
lr = 0.001

trainModel(model,J1,J2,batchSize,numEpochs,lr)

  return func(*args, **kwargs)
100%|██████████| 100/100 [30:27<00:00, 18.27s/it]  


### Save Model

In [381]:
torch.save(model.state_dict(), PATH + '/4x4model.pth')

: 

### Load Model

In [None]:
model = GCNN().to(device)
model.load_state_dict(torch.load(PATH + '/4x4model.pth'))