In [62]:
!pip install torch pandas scipy matplotlib



In [63]:
import numpy as np
from scipy.spatial.distance import cdist
import torch 
import torch.optim as optim
import pandas as pd 

# modified import statements to access TFN-torch implementation 
import sys
import os
sys.path.append(os.path.abspath('../../models/'))
import tensorfieldnetworks.layers as layers
import tensorfieldnetworks.utils as utils

In [64]:
# Cell 2: Load formation energy and atom positions data
# Load formation energy from the Excel file
file_path = 'formation_energy.xlsx'  # Adjust to your file path
df = pd.read_excel(file_path)
formation_energy_dict = dict(zip(df.iloc[:, 0], df.iloc[:, 1]))

# Read the LAMMPS lattice file
lattice_file_path = 'trigger/my_lattice_prep.data'  # Adjust to your file path

with open(lattice_file_path, 'r') as f:
    lines = f.readlines()

# Find the line where atom data starts
start_index = None
for i, line in enumerate(lines):
    if line.strip() == "Atoms # atomic":
        start_index = i + 1
        break

# Extract atom positions from the file starting at the correct line
atom_positions = {}
atom_ids = []
for line in lines[start_index:]:
    parts = line.split()
    if len(parts) >= 5:
        atom_id = int(parts[0])
        x_pos = float(parts[2])
        y_pos = float(parts[3])
        z_pos = float(parts[4])
        atom_ids.append(atom_id)
        atom_positions[atom_id] = (x_pos, y_pos, z_pos)

# Convert atom_positions to a numpy array for vectorized operations
positions_array = np.array([atom_positions[atom_id] for atom_id in atom_ids])

# Calculate pairwise distances and closest atoms
distances = cdist(positions_array, positions_array, metric='euclidean')
np.fill_diagonal(distances, np.inf)
closest_indices = np.argsort(distances, axis=1)[:, :15]
closest_positions_array = positions_array[closest_indices]

closest_atoms_dict = {atom_id: (positions_array[i], closest_positions_array[i]) 
                      for i, atom_id in enumerate(atom_ids)}

# Let's check the data
closest_atoms_dict[1]  # Check closest atoms for atom 1

(array([1.16737936e-02, 1.20945465e+01, 3.78777050e-08]),
 array([[2.65081110e+00, 1.20945465e+01, 3.95771430e-08],
        [1.76301663e+00, 1.20945465e+01, 2.32759666e+00],
        [1.16737936e-02, 8.06303090e+00, 3.78767310e-08],
        [1.17051422e+00, 1.20945465e+01, 4.65519329e+00],
        [2.65081110e+00, 1.00787887e+01, 3.49139497e+00],
        [2.65081110e+00, 8.06303090e+00, 3.96354430e-08],
        [4.29503732e+00, 1.20945465e+01, 2.32759666e+00],
        [1.76301663e+00, 8.06303090e+00, 2.32759666e+00],
        [5.14107876e+00, 1.20945465e+01, 3.44934950e-08],
        [3.53785194e+00, 1.20945465e+01, 4.65519329e+00],
        [1.17051422e+00, 8.06303090e+00, 4.65519329e+00],
        [1.17051423e+00, 6.04727309e+00, 1.16379835e+00],
        [4.29503732e+00, 8.06303090e+00, 2.32759666e+00],
        [5.98061892e+00, 1.00787887e+01, 1.16379835e+00],
        [5.14107876e+00, 8.06303090e+00, 3.44355280e-08]]))

In [65]:
# Cell 3: Prepare input data for the network
# Radial Basis Function parameters
rbf_low = 0.0
rbf_high = 3.5
rbf_count = 4
rbf_spacing = (rbf_high - rbf_low) / rbf_count
centers = torch.Tensor(np.linspace(rbf_low, rbf_high, rbf_count))

# Function to compute input tensors (distances, rbf, etc.)
def get_inputs(positions):
    rij = utils.difference_matrix(positions)
    dij = utils.distance_matrix(positions)
    gamma = 1. / rbf_spacing
    rbf = torch.exp(-gamma * (dij.unsqueeze(-1) - centers)**2)
    return rij, dij, rbf

# Create the dataset and labels
inputs = []
labels = []
for atom_id, (atom_pos, closest_atoms_pos) in closest_atoms_dict.items():
    # Combine the atom and its closest neighbors to create input features
    positions = np.vstack([atom_pos, closest_atoms_pos])
    
    # Get the rij, dij, and rbf inputs
    rij, dij, rbf = get_inputs(torch.Tensor(positions))
    inputs.append((rij, rbf))  # Store the input tensors
    labels.append(formation_energy_dict.get(atom_id, 0))  # Fetch formation energy label

# Convert inputs and labels into tensors
inputs_tensor = [(rij.unsqueeze(0), rbf.unsqueeze(0)) for rij, rbf in inputs]
labels_tensor = torch.Tensor(labels)

# Let's check the input tensors
inputs_tensor[0][1].shape, labels_tensor[0]


(torch.Size([1, 16, 16, 4]), tensor(1292.7557))

In [66]:
# Cell 4: Define EGNN model architecture
class Readout(torch.nn.Module):
    def __init__(self, input_dims, num_classes):
        super(Readout, self).__init__()
        self.lin = torch.nn.Linear(input_dims, num_classes)
        
    def forward(self, inputs):
        inputs = torch.mean(inputs.squeeze(), dim=0)
        inputs = self.lin(inputs).unsqueeze(0)
        return inputs

class EGNN(torch.nn.Module):
    def __init__(self, rbf_dim=rbf_count, num_classes=1):
        super(EGNN, self).__init__()
        self.layer_dims = [1, 16, 16, 4]
        self.num_layers = len(self.layer_dims) - 1
        self.rbf_dim = rbf_dim
        self.embed = layers.SelfInteractionLayer(input_dim=1, output_dim=1, bias=False)
        
        self.layers = []
        for layer, (layer_dim_out, layer_dim_in) in enumerate(zip(self.layer_dims[1:], self.layer_dims[:-1])):
            self.layers.append(layers.Convolution(rbf_dim, layer_dim_in))
            self.layers.append(layers.Concatenation())
            self.layers.append(layers.SelfInteraction(layer_dim_in, layer_dim_out))
            self.layers.append(layers.NonLinearity(layer_dim_out))
        
        self.layers = torch.nn.ModuleList(self.layers)
        self.readout = Readout(self.layer_dims[-1], num_classes)
        
    def forward(self, rbf, rij):
        embed = self.embed(torch.ones(1, 16, 1, 1).repeat([rbf.size(0), 1, 1, 1]))
        input_tensor_list = {0: [embed]}
        for il, layer in enumerate(self.layers[::4]):
            input_tensor_list = self.layers[4*il](input_tensor_list, rbf, rij)
            input_tensor_list = self.layers[4*il+1](input_tensor_list)
            input_tensor_list = self.layers[4*il+2](input_tensor_list)
            input_tensor_list = self.layers[4*il+3](input_tensor_list)
        return self.readout(input_tensor_list[0][0])

# Instantiate the model
model = EGNN(num_classes=1)


In [67]:
# Cell 5: Train the EGNN model
# Set up loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 2000
for epoch in range(epochs):
    running_loss = 0.0
    for i, (inputs, label) in enumerate(zip(inputs_tensor, labels_tensor)):
        rij, rbf = inputs
        label = label.unsqueeze(0)  # Make label batch-like
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(rbf, rij)
        
        # Compute loss
        loss = criterion(outputs.squeeze(), label)
        loss.backward()

        # Update parameters
        optimizer.step()

        # Accumulate loss
        running_loss += loss.item()

    print(f'Epoch [{epoch}/{epochs}], Loss: {running_loss/len(inputs_tensor):.4f}')

print('Finished Training')


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [0/2000], Loss: 718352.0992
Epoch [1/2000], Loss: 717370.1749
Epoch [2/2000], Loss: 716158.0837
Epoch [3/2000], Loss: 714766.5982
Epoch [4/2000], Loss: 713230.4239
Epoch [5/2000], Loss: 711571.2733
Epoch [6/2000], Loss: 709802.7549
Epoch [7/2000], Loss: 707933.5029
Epoch [8/2000], Loss: 705968.1047
Epoch [9/2000], Loss: 703905.5629
Epoch [10/2000], Loss: 701737.3808
Epoch [11/2000], Loss: 699457.2885
Epoch [12/2000], Loss: 697068.0893
Epoch [13/2000], Loss: 694574.7957
Epoch [14/2000], Loss: 691981.7629
Epoch [15/2000], Loss: 689292.9004
Epoch [16/2000], Loss: 686511.8813
Epoch [17/2000], Loss: 683642.4549
Epoch [18/2000], Loss: 680688.3106
Epoch [19/2000], Loss: 677653.2428
Epoch [20/2000], Loss: 674541.1383
Epoch [21/2000], Loss: 671355.8980
Epoch [22/2000], Loss: 668101.5552
Epoch [23/2000], Loss: 664782.2246
Epoch [24/2000], Loss: 661402.0324
Epoch [25/2000], Loss: 657965.0869
Epoch [26/2000], Loss: 654475.5970
Epoch [27/2000], Loss: 650937.8135
Epoch [28/2000], Loss: 647356.

  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or key is 1:
  if key is 0 or key is 1:
  if key is 1:
  if key is 0 or ke

KeyboardInterrupt: 