In [36]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import pandas as pd
import ast
from ase.data import atomic_numbers  # Import ASE's atomic numbers dictionary

# Set target elements (optional filtering)
TARGET_ELEMENTS = None  # Set to None to include all elements, or {"Ni", "Ta", "Mn", "Sb"} for specific elements

# Load and filter data
def load_data(filepath):
    df = pd.read_csv(filepath)
    return df

# Preprocess to create graph data objects
def create_material_graph(structure_data, formation_energy):
    atoms = structure_data['data']['atoms']
    node_features = []
    positions = []

    # Create node features for each atom
    for atom in atoms:
        if TARGET_ELEMENTS is None or atom['element'] in TARGET_ELEMENTS:
            atomic_num = get_atomic_number(atom['element'])
            pos = [atom['x'], atom['y'], atom['z']]
            dist = sum([x**2 for x in pos]) ** 0.5
            node_features.append([atomic_num] + pos + [dist])
            positions.append(pos)
    
    # Create edges based on distance threshold
    edge_index = []
    threshold = 2e-10  # You may lower this threshold if needed
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            if sum([(positions[i][k] - positions[j][k]) ** 2 for k in range(3)]) ** 0.5 < threshold:
                edge_index.append([i, j])
                edge_index.append([j, i])

    # Check if edge_index is empty, add self-loops if needed
    if not edge_index:
        edge_index = [[i, i] for i in range(len(positions))]  # Add self-loops

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    x = torch.tensor(node_features, dtype=torch.float)
    y = torch.tensor([formation_energy], dtype=torch.float)

    return Data(x=x, edge_index=edge_index, y=y)

# Function to get atomic numbers for elements using ASE
def get_atomic_number(element_symbol):
    return atomic_numbers[element_symbol]

# Prepare dataset
def prepare_dataset(filepath):
    dataset = []
    data = load_data(filepath)
    for _, row in data.iterrows():
        try:
            structure_data = ast.literal_eval(row['structure'])[0]
            formation_energy = eval(row['formation_energy'])['value']
            graph = create_material_graph(structure_data, formation_energy)
            dataset.append(graph)            
        except: 
            # print(row)
            pass

    return dataset

# Define model
class MaterialGraph(torch.nn.Module):
    def __init__(self):
        super(MaterialGraph, self).__init__()
        self.conv1 = GCNConv(in_channels=5, out_channels=64)
        self.conv2 = GCNConv(in_channels=64, out_channels=64)
        self.fc1 = torch.nn.Linear(64, 32)
        self.fc2 = torch.nn.Linear(32, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, data.batch)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Training and testing functions
def train(loader, model, optimizer):
    model.train()
    total_loss = 0
    for data in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = F.mse_loss(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def test(loader, model):
    model.eval()
    total_loss = 0
    for data in loader:
        output = model(data)
        loss = F.mse_loss(output, data.y)
        total_loss += loss.item()
    return total_loss / len(loader)



In [37]:
# Main script
filepath = '../Data/1_MatDX/MatDX_EF.csv'  # Replace with actual path

In [38]:
dataset = prepare_dataset(filepath)

In [39]:
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)



In [40]:
model = MaterialGraph()

In [41]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [42]:
# Training loop
epochs = 100
for epoch in range(1, epochs + 1):
    train_loss = train(train_loader, model, optimizer)
    if epoch % 10 == 0:
        test_loss = test(train_loader, model)
        print(f'Epoch {epoch}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

print("Training completed!")

  loss = F.mse_loss(output, data.y)
  loss = F.mse_loss(output, data.y)
  loss = F.mse_loss(output, data.y)
  loss = F.mse_loss(output, data.y)


Epoch 10, Train Loss: 40.4484, Test Loss: 40.4347
Epoch 20, Train Loss: 40.4403, Test Loss: 40.4345
Epoch 30, Train Loss: 40.4416, Test Loss: 40.4291
Epoch 40, Train Loss: 40.4403, Test Loss: 40.4310
Epoch 50, Train Loss: 40.4395, Test Loss: 40.4315


KeyboardInterrupt: 