In [2]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt 
from mpl_toolkits import mplot3d
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.init as init

# Fix the torch RNG so the network initialisation below is reproducible.
torch.manual_seed(24)

# Physical constants and unit-conversion factors used throughout the notebook.
# NOTE(review): MASS = HBAR = 1 suggests atomic units; AUTOEV / AUTONM are presumably
# the atomic-unit -> eV and atomic-unit -> nm conversion factors - confirm.
MASS = 1.0
HBAR = 1.0
AUTOEV = 27.2114
AUTONM = 0.05291772108

# qGrid: NQGRID uniformly spaced q values from Q_MIN to Q_MAX (both end points included).
# NOTE(review): an earlier version of this comment described a 0.0 to 5.0 grid with
# NGRID = 1024 and an rGrid running from 0.0 to 0.02*TWOPI / (NGRID * dq); no rGrid is
# built in this notebook.
NQGRID = 2048
Q_MIN, Q_MAX = 0.0, 10.0
qGrid = torch.linspace(Q_MIN, Q_MAX, NQGRID, dtype=torch.float64)
# torch.linspace spreads NQGRID points over NQGRID - 1 intervals, so the grid spacing is
# (Q_MAX - Q_MIN) / (NQGRID - 1). Dividing by NQGRID does not match qGrid[1] - qGrid[0].
dq = (Q_MAX - Q_MIN) / (NQGRID - 1)


In [3]:
# Reference zb-CdSe energies at the Gamma point (eV); every level is listed twice.
Eref_zbCdSe_Gamma = torch.tensor(
    [
        -20.1798943671376740, -20.1798943671376740,
        -6.3095256693774902, -6.3095256693774902,
        -6.3095256693768960, -6.3095256693768960,
        -6.3024330398428914, -6.3024330398428914,
        -4.4148875987097256, -4.4148875987097256,
        1.3233415397783173, 1.3233415397783173,
        1.3233415397789470, 1.3233415397789470,
        1.3235075638553759, 1.3235075638553759,
    ],
    dtype=torch.float64,
)

# Reference zb-CdSe energies for all k points: every column of the file except the first.
Eref_zbCdSe = torch.tensor(
    np.loadtxt("bandStruct_zbCdSe.par")[:, 1:]
)

In [4]:
def plotBandStruct(bandStruct_array, marker_array, label_array): 
    """Plot one or more band structures on a shared axis and show the figure.

    Parameters
    ----------
    bandStruct_array : sequence of 2-D torch tensors
        Each entry is indexed as [k-point, band]; it is detached from autograd
        before plotting.
    marker_array : sequence of str
        Matplotlib format string for each band structure (e.g. "bo", "r-").
    label_array : sequence of str
        Legend label for each band structure.

    Returns
    -------
    None
    """
    fig, axs = plt.subplots(1, 1, figsize=(3, 3))
    for bandStructIndex, bandStruct in enumerate(bandStruct_array):
        bands = bandStruct.detach().numpy()
        numKpts, numBands = bands.shape
        kptIndex = np.arange(numKpts)
        for i in range(numBands):
            # Only the first band carries the label so that each band structure
            # appears once in the legend.
            axs.plot(kptIndex, bands[:, i], marker_array[bandStructIndex],
                     label=label_array[bandStructIndex] if i == 0 else None)
    # The callers in this notebook pass energies already converted to eV.
    axs.set(xlabel="k-point index", ylabel="Energy (eV)")
    axs.legend(frameon=False)
    plt.show()
    # This function is called once per training epoch; release the figure so
    # figures do not pile up on backends where show() does not close them.
    plt.close(fig)
    return

def pot_func(x, params): 
    """Evaluate the analytic pseudopotential form at x.

    Computes params[0] * (x**2 - params[1]) / (params[2] * exp(params[3] * x**2) - 1).
    Only the first four entries of params are read.

    Parameters
    ----------
    x : torch.Tensor, numpy array or float
        Point(s) at which to evaluate the potential.
    params : indexable of at least four numbers or tensor entries

    Returns
    -------
    Same kind of object as the inputs (torch tensor for tensor inputs).
    """
    xSquared = x * x
    exponent = params[3] * xSquared
    # np.exp on a torch tensor converts it through NumPy, which raises for tensors
    # that require grad. Use torch.exp for tensors so the result stays differentiable.
    exp = torch.exp if isinstance(exponent, torch.Tensor) else np.exp
    pot = (params[0] * (xSquared - params[1]) / (params[2] * exp(exponent) - 1.0))
    return pot
    
def plotPP(q_array, vq_Cd_array, vq_Se_array, label_array, lineshape_array): 
    """Plot Cd and Se pseudopotential curves v(q) for several data sets.

    Entry i of every argument describes one data set: q_array[i] is the q grid,
    vq_Cd_array[i] / vq_Se_array[i] are the Cd / Se curves on that grid (torch
    tensors, detached before plotting), label_array[i] is the legend suffix and
    lineshape_array[i] is the matplotlib format string shared by both curves.
    """
    fig, axs = plt.subplots(1, 1, figsize=(3, 3))
    for idx, cdCurve in enumerate(vq_Cd_array):
        qValues = q_array[idx].detach().numpy()
        cdValues = cdCurve.detach().numpy()
        seValues = vq_Se_array[idx].detach().numpy()
        style = lineshape_array[idx]
        suffix = label_array[idx]
        # Cd is drawn first, then Se, both with the same line style.
        for prefix, values in (("Cd ", cdValues), ("Se ", seValues)):
            axs.plot(qValues, values, style, label=prefix + suffix)
    axs.set(xlabel=r"$q$", ylabel=r"$v(q)$")
    axs.legend(frameon=False)
    plt.show()
    return

In [6]:
# Create Net model class
# The triple-quoted string below is a disabled earlier variant of the network and is
# never executed. It builds the linear layers from a list of layer sizes, zero-initialises
# the biases and applies batch norm AFTER the ReLU, with no batch norm on the output layer.
'''
class Net(nn.Module):
    # Constructor
    def __init__(self, Layers):
        super(Net, self).__init__()
        self.hidden = nn.ModuleList()
        self.batch_norm = nn.ModuleList()
        for input_size, output_size in zip(Layers, Layers[1:]):
            layer = nn.Linear(input_size, output_size, dtype=torch.float64)
            init.xavier_normal_(layer.weight)
            init.constant_(layer.bias, 0)
            self.hidden.append(layer)
            if output_size != Layers[-1]:
                self.batch_norm.append(nn.BatchNorm1d(output_size, dtype=torch.float64))
    # Prediction
    def forward(self, activation):
        L = len(self.hidden)
        for (l, linear_transform) in enumerate(self.hidden):
            if l < L - 1:
                activation = torch.relu(linear_transform(activation))
                activation = self.batch_norm[l](activation)
            else:
                activation = linear_transform(activation)
        return activation

Layers = [1, 20, 2]
PPmodel = Net(Layers)
'''

class Net(nn.Module):
    """Fully connected float64 network mapping one input feature to two outputs.

    Layout: Linear(1, 20) -> BatchNorm -> ReLU -> Linear(20, 20) -> BatchNorm -> ReLU
    -> Linear(20, 20) -> BatchNorm -> ReLU -> Linear(20, 2).  All linear weights are
    re-initialised with Xavier-normal values; biases keep the nn.Linear default.

    Because of the BatchNorm1d layers, a forward pass in training mode needs more
    than one sample per batch.
    """

    def __init__(self):
        super().__init__()
        width = 20
        f64 = torch.float64

        # Layers are registered in the same order as they are applied in forward().
        self.input_layer = nn.Linear(1, width, dtype=f64)
        self.bn1 = nn.BatchNorm1d(width, dtype=f64)

        self.hidden_layer1 = nn.Linear(width, width, dtype=f64)
        self.bn2 = nn.BatchNorm1d(width, dtype=f64)

        self.hidden_layer2 = nn.Linear(width, width, dtype=f64)
        self.bn3 = nn.BatchNorm1d(width, dtype=f64)

        self.output_layer = nn.Linear(width, 2, dtype=f64)

        # Xavier initialization for weights, from the input layer to the output layer
        linear_layers = (
            self.input_layer,
            self.hidden_layer1,
            self.hidden_layer2,
            self.output_layer,
        )
        for layer in linear_layers:
            init.xavier_normal_(layer.weight)

    def forward(self, x):
        """Return the (batch, 2) network output for a (batch, 1) float64 input."""
        stages = (
            (self.input_layer, self.bn1),
            (self.hidden_layer1, self.bn2),
            (self.hidden_layer2, self.bn3),
        )
        for linear, norm in stages:
            # batch norm is applied before the ReLU
            x = torch.relu(norm(linear(x)))
        return self.output_layer(x)

# Module-level pseudopotential network; the Hamiltonian and training code below read
# this global directly.
PPmodel = Net()


# print("list(model.parameters()):\n ", list(PPmodel.parameters()))
# print("\nmodel.state_dict():\n ", PPmodel.state_dict())
# Smoke test on a two-sample batch: one row of two outputs per input value.
print(PPmodel(torch.tensor([[1.0], [2.0]], dtype=torch.float64)))
# NOTE(review): the single-sample call below is left disabled. The ValueError recorded
# in this cell's output ("Expected more than 1 value per channel when training") is
# presumably what it produces, since BatchNorm1d rejects a batch of one in training mode.
# print(PPmodel(torch.tensor([[1.0]], dtype=torch.float64)))

tensor([[-0.0331,  0.6281],
        [-1.2901, -0.0408]], dtype=torch.float64, grad_fn=<AddmmBackward0>)


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 20])

In [None]:
nSystem = 1

# read system
# Real-space lattice: three primitive vectors (one per row of unitCellVectors),
# each multiplied by the lattice constant `scale`.
scale = 11.4485278
unitCellVector1 = torch.tensor([0.0, 0.5, 0.5], dtype=torch.float64) * scale
unitCellVector2 = torch.tensor([0.5, 0.0, 0.5], dtype=torch.float64) * scale
unitCellVector3 = torch.tensor([0.5, 0.5, 0.0], dtype=torch.float64) * scale
unitCellVectors = torch.stack((unitCellVector1, unitCellVector2, unitCellVector3), dim=0)
# torch.cross without an explicit `dim` is deprecated; the vectors are 1-D, so dim=0.
cellVolume = torch.dot(unitCellVector1, torch.cross(unitCellVector2, unitCellVector3, dim=0))
# print(cellVolume)

# Reciprocal lattice vectors (one per row of gVectors): g_i . a_j = 2*pi*delta_ij.
prefactor = 2 * np.pi / cellVolume
gVector1 = prefactor * torch.cross(unitCellVectors[1], unitCellVectors[2], dim=0)
gVector2 = prefactor * torch.cross(unitCellVectors[2], unitCellVectors[0], dim=0)
gVector3 = prefactor * torch.cross(unitCellVectors[0], unitCellVectors[1], dim=0)
gVectors = torch.stack((gVector1, gVector2, gVector3), dim=0)

nAtoms = 2
atomTypes = np.array(["Cd", "Se"])
# Atom positions are given in units of the lattice vectors and converted to Cartesian
# coordinates; the two stages get separate names so re-reading the cell is unambiguous.
atomPosFrac = torch.tensor([[0.125, 0.125, 0.125],
                            [-0.125, -0.125, -0.125]], dtype=torch.float64)
atomPos = atomPosFrac @ unitCellVectors
# print(atomPos)

# read kPoints
# The file rows are multiplied by gVectors, i.e. each k point becomes
# k1*gVector1 + k2*gVector2 + k3*gVector3.
kpt_zbCdSe = torch.tensor(np.loadtxt("ZB_kpoints.par")) @ gVectors
nkpt = kpt_zbCdSe.size(0)

Gamma = torch.zeros(3, dtype=torch.float64)

# Kinetic-energy cutoff passed to basis() and number of bands kept per k point.
maxKE = 10
nBands = 16

In [None]:
def basis(maxKE, scale, unitCellVectors, cellVolume): 
    """Build the plane-wave basis: all reciprocal lattice vectors G below the cutoff.

    A vector G = i*g1 + j*g2 + k*g3 (integer i, j, k) is kept when
    HBAR**2 * |G|**2 / (2 * MASS) < maxKE.

    Parameters
    ----------
    maxKE : float
        Kinetic-energy cutoff (strict inequality).
    scale : float
        Unused; kept so existing calls keep working.
    unitCellVectors : torch.Tensor, shape (3, 3)
        Real-space lattice vectors, one per row.
    cellVolume : float or 0-dim tensor
        Volume of the unit cell, a1 . (a2 x a3).

    Returns
    -------
    torch.Tensor, shape (nBasis, 3)
        The retained G vectors sorted by increasing norm; ties are ordered by
        x, then y, then z component (equivalent to
        np.lexsort((G[:, 2], G[:, 1], G[:, 0], norm))).
    """
    prefactor = 2 * np.pi / cellVolume

    gVector1 = prefactor * torch.cross(unitCellVectors[1], unitCellVectors[2], dim=0)
    gVector2 = prefactor * torch.cross(unitCellVectors[2], unitCellVectors[0], dim=0)
    gVector3 = prefactor * torch.cross(unitCellVectors[0], unitCellVectors[1], dim=0)
    recipVectors = torch.stack((gVector1, gVector2, gVector3), dim=0)

    # Nothing can satisfy a strict cutoff that is zero or negative.
    if maxKE <= 0:
        return recipVectors.new_zeros((0, 3))

    # Largest |G| allowed by the cutoff.
    gCut = float(np.sqrt(2.0 * MASS * maxKE)) / HBAR
    # The integer index along g_i equals G . a_i / (2*pi), so it is bounded by
    # |G| * |a_i| / (2*pi).  Bounding by gCut / min|g_i| instead is too small for
    # non-orthogonal cells and silently drops plane waves inside the cutoff.
    maxCellLength = float(torch.norm(unitCellVectors, dim=1).max())
    numMaxBasisVectors = int(np.floor(gCut * maxCellLength / (2 * np.pi)))

    # All integer triples (i, j, k); i varies slowest and k fastest.
    indexRange = torch.arange(-numMaxBasisVectors, numMaxBasisVectors + 1, dtype=recipVectors.dtype)
    allGrid = torch.cartesian_prod(indexRange, indexRange, indexRange)
    # Rows of recipVectors are g1, g2, g3, so this is i*g1 + j*g2 + k*g3.  Multiplying
    # by the transpose only gives the same result when the matrix is symmetric.
    allBasisSet = allGrid @ recipVectors

    kineticEnergy = HBAR**2 * 0.5 * torch.norm(allBasisSet, dim=1)**2 / MASS
    basisSet = allBasisSet[kineticEnergy < maxKE]

    # Successive stable sorts: z, then y, then x, and finally the norm, so that the
    # norm is the primary key and x, y, z break ties in that order.
    for column in (2, 1, 0):
        basisSet = basisSet[torch.argsort(basisSet[:, column], stable=True)]
    row_norms = torch.norm(basisSet, dim=1)
    sorted_basisSet = basisSet[torch.argsort(row_norms, stable=True)]

    return sorted_basisSet

# construct hamiltonian at a certain k-point (kVector). Quicker through vectorization. 
def calcHamiltonianMatrix_NN(basisStates, kVector, nAtoms, cellVolume):
    """Build the plane-wave Hamiltonian matrix at one k point.

    Diagonal: kinetic energy HBAR**2 / (2*MASS) * |G_i + k|**2.
    All entries: sum over atoms of v_atom(|G_i - G_j|) * exp(1j * (G_i - G_j) . tau) / cellVolume,
    where v_atom is the PPmodel output column selected by the atom's position in PP_order.

    Reads the module-level names HBAR, MASS, atomPos, atomTypes, PP_order and PPmodel.

    Parameters
    ----------
    basisStates : torch.Tensor, shape (n, 3)
        Plane-wave vectors G.
    kVector : torch.Tensor, shape (3,)
        The k point.
    nAtoms : int
        Number of atoms taken from atomPos / atomTypes.
    cellVolume : float or 0-dim tensor
        Volume of the unit cell.

    Returns
    -------
    torch.Tensor, shape (n, n), dtype complex128

    Raises
    ------
    ValueError
        If an atom type does not appear exactly once in PP_order.
    """
    n = basisStates.shape[0]

    # Kinetic energy: computed for all basis states at once and placed on the diagonal.
    kinetic = HBAR**2 / (2*MASS) * torch.sum((basisStates + kVector)**2, dim=1)
    HMatrix = torch.diag(kinetic).to(torch.complex128)

    # Local potential
    # gDiff[i, j] = G_i - G_j via broadcasting.
    gDiff = basisStates.unsqueeze(1) - basisStates.unsqueeze(0)

    # The network input does not depend on the atom, so evaluate it once and pick
    # the column of each atom inside the loop.
    allFF = PPmodel(torch.norm(gDiff, dim=2).view(-1, 1)) if nAtoms > 0 else None

    for k in range(nAtoms): 
        thisAtomIndex = np.where(atomTypes[k]==PP_order)[0]
        if len(thisAtomIndex)!=1: 
            raise ValueError(f"Atom type {atomTypes[k]!r} must appear exactly once in PP_order, found {len(thisAtomIndex)} matches. ")
        thisAtomIndex = thisAtomIndex[0]

        gDiffDotTau = torch.sum(gDiff * atomPos[k], dim=2)
        structFact = 1/cellVolume * (torch.cos(gDiffDotTau) + 1j*torch.sin(gDiffDotTau))

        atomFF = allFF[:, thisAtomIndex].view(n, n)
        # atomFF = pot_func(torch.norm(gDiff, dim=2), totalParams[thisAtomIndex])

        HMatrix = HMatrix + atomFF * structFact
    return HMatrix

def calcBandStruct(basisStates, nkpt, kpts_coord, nAtoms, cellVolume, nBands): 
    """Return the band energies at the first nkpt k points.

    For each k point the Hamiltonian from calcHamiltonianMatrix_NN is diagonalised
    with torch.linalg.eigvalsh, the eigenvalues are multiplied by AUTOEV, each one is
    duplicated (2-fold degeneracy due to spin) and the first nBands values are kept.

    Returns
    -------
    torch.Tensor, shape (nkpt, nBands), dtype float64
    """
    bandStruct = torch.zeros((nkpt, nBands), dtype=torch.float64)
    for kptIdx in range(nkpt):
        hamiltonian = calcHamiltonianMatrix_NN(basisStates, kpts_coord[kptIdx], nAtoms, cellVolume)

        # eigenvalues of the Hermitian matrix, scaled by AUTOEV
        levelsEV = torch.linalg.eigvalsh(hamiltonian) * AUTOEV

        # duplicate every level for spin, then truncate to the requested band count
        bandStruct[kptIdx] = torch.repeat_interleave(levelsEV, 2)[:nBands]

    return bandStruct

In [None]:
# Validating on the initialized NN model

# Untrained network evaluated on the whole q grid; column 0 is compared with the
# Cd curve and column 1 with the Se curve below.
NN_init = PPmodel(qGrid.unsqueeze(-1))

# Parameter sets for pot_func; the rows of totalParams follow PP_order.
CdParams = torch.tensor([-31.4518, 1.3890, -0.0502, 1.6603, 0.0586], dtype=torch.float64)
SeParams = torch.tensor([8.4921, 4.3513, 1.3600, 0.3227, 0.1746], dtype=torch.float64)
SParams = torch.tensor([7.6697, 4.5192, 1.3456, 0.3035, 0.2087], dtype=torch.float64)
PP_order = np.array(["Cd", "Se", "S"])
totalParams = torch.stack((CdParams, SeParams, SParams), dim=0)

# Analytic curves on the same grid, plotted against the network output.
CdPP = pot_func(qGrid, CdParams)
SePP = pot_func(qGrid, SeParams)
plotPP(
    [qGrid, qGrid],
    [CdPP, NN_init[:, 0]],
    [SePP, NN_init[:, 1]],
    ["ZungerForm", "NN_init"],
    ["-", ":"],
)


# Band structure from the untrained network versus the reference.
basisStates = basis(maxKE, scale, unitCellVectors, cellVolume)

NN_init_BandStruct = calcBandStruct(basisStates, nkpt, kpt_zbCdSe, nAtoms, cellVolume, nBands)
plotBandStruct(
    [Eref_zbCdSe, NN_init_BandStruct],
    ["bo", "r-"],
    ["Reference zbCdSe", "NN_init"],
)

In [None]:
# Optimiser over all PPmodel parameters.
# NOTE(review): a learning rate of 2 is very large for Adam - confirm it is intended.
learning_rate = 2
optimizer = torch.optim.Adam(PPmodel.parameters(), lr = learning_rate)

# Plane-wave basis read by train_model() through the module-level name basisStates.
basisStates = basis(maxKE, scale, unitCellVectors, cellVolume)

# torch.autograd.set_detect_anomaly(False)

def train_model(epochs):
    """Run `epochs` optimisation steps of PPmodel against the reference band structure.

    Every epoch computes the band structure from the current network, plots it next
    to Eref_zbCdSe, prints the mean squared error and the gradient norm of each
    parameter, and takes one optimizer step.  Works on the module-level PPmodel,
    optimizer, basisStates and k-point / cell settings; returns None.
    """
    for epoch in range(epochs):
        NN_BandStruct = calcBandStruct(basisStates, nkpt, kpt_zbCdSe, nAtoms, cellVolume, nBands)
        plotBandStruct(
            [Eref_zbCdSe, NN_BandStruct],
            ["bo", "r-"],
            ["Reference zbCdSe", f"NN_{epoch}"],
        )

        # mean squared deviation from the reference energies
        residual = Eref_zbCdSe - NN_BandStruct
        loss = residual.pow(2).mean()
        print(f"Loss: {loss}")

        optimizer.zero_grad()
        loss.backward()

        # report gradient magnitudes before the parameters are updated
        for name, param in PPmodel.named_parameters():
            if param.grad is None:
                continue
            print(f"Parameter: {name}, Gradient Norm: {param.grad.norm().item()}")

        optimizer.step()

train_model(30)

# Re-evaluate the network on the q grid after training and compare with the analytic curves.
NN_latest = PPmodel(qGrid.unsqueeze(-1))
plotPP(
    [qGrid, qGrid],
    [CdPP, NN_latest[:, 0]],
    [SePP, NN_latest[:, 1]],
    ["ZungerForm", "NN_latest"],
    ["-", ":"],
)

In [None]:
# Continue training the same PPmodel / optimizer for 50 more epochs.
train_model(50)

# Re-evaluate the network on the q grid and compare with the analytic curves.
NN_latest = PPmodel(qGrid.unsqueeze(-1))
plotPP(
    [qGrid, qGrid],
    [CdPP, NN_latest[:, 0]],
    [SePP, NN_latest[:, 1]],
    ["ZungerForm", "NN_latest"],
    ["-", ":"],
)