In [None]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from pymatgen.core.structure import Structure
import matplotlib.pyplot as plt
import csv
import random
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score


In [None]:
# Run-wide settings read by the dataset, data loaders and training loop.
config = dict(
    data_dir="./structure",   # folder holding the structure files, id_prop.csv and the atom embedding JSON
    epochs=200,               # number of passes over the training split
    batch_size=32,            # crystals per mini-batch
    learning_rate=0.001,      # Adam step size
    radius=8.0,               # neighbor search cutoff passed to the dataset
    show_progress=True,       # print per-epoch losses while training
)

# Fix the torch and NumPy global generators so runs are repeatable.
torch.manual_seed(52)
np.random.seed(52)


In [None]:
class GaussianDistance:
    """Expand scalar distances onto a grid of Gaussian basis functions.

    The grid centers are ``np.arange(dmin, dmax + step, step)``; the Gaussian
    width is ``var`` when it is truthy, otherwise ``step``.
    """

    def __init__(self, dmin, dmax, step, var=None):
        # Falls back to `step` for var=None (and any other falsy value such as 0).
        self.var = var or step
        self.filter = np.arange(dmin, dmax + step, step)

    def expand(self, distances):
        """Return an array with one extra trailing axis of size ``len(self.filter)``.

        Each entry is ``exp(-(d - center)**2 / var**2)`` for the matching
        distance ``d`` and grid center.
        """
        offsets = distances[..., np.newaxis] - self.filter
        return np.exp(-(offsets ** 2) / self.var ** 2)
    

In [None]:
class POSCARDataset(Dataset):
    """Dataset of crystal structures read from files in ``data_dir``.

    ``data_dir`` must contain:
      * ``atom_init-enhanced.json`` - maps atomic number to an embedding vector,
      * ``id_prop.csv`` - one ``<structure file name>,<target value>`` row per sample,
      * the structure files named in ``id_prop.csv``.

    Each item is ``((atom_fea, nbr_fea, nbr_idx), target, material_id)`` where
    ``atom_fea`` holds one embedding row per site, ``nbr_fea`` the Gaussian
    expansion of the distances to each site's ``max_nbr`` nearest neighbors,
    ``nbr_idx`` the site indices of those neighbors, and ``target`` a
    one-element float tensor.

    Parsed samples are memoized per instance, so every structure file is read
    at most once. The cache is unbounded and returns the same tensor objects on
    repeated access; callers must not modify them in place.
    """

    def __init__(self, data_dir, max_nbr=12, radius=8.0):
        self.data_dir = data_dir
        self.max_nbr = max_nbr
        self.radius = radius
        self.gdf = GaussianDistance(dmin=0, dmax=radius, step=0.2)
        print(f"Number of Gaussian filters: {len(self.gdf.filter)}")

        with open(os.path.join(data_dir, 'atom_init-enhanced.json')) as f:
            elem_embed = json.load(f)
        self.elem_dict = {int(k): np.array(v, dtype=np.float32) for k, v in elem_embed.items()}

        # Skip blank lines (e.g. a trailing newline), which would otherwise
        # fail when unpacked into (material_id, target).
        with open(os.path.join(data_dir, 'id_prop.csv')) as f:
            self.materials = [line.strip().split(',') for line in f if line.strip()]

        # Per-instance sample cache. The original decorated __getitem__ with
        # functools.lru_cache without importing functools (NameError at class
        # definition); a plain dict also avoids a class-level cache holding
        # references to every dataset instance.
        self._cache = {}

    def __len__(self):
        return len(self.materials)

    def __getitem__(self, idx):
        # Samplers may hand over NumPy integers; normalise so the cache key
        # is the same regardless of integer type.
        key = int(idx)
        sample = self._cache.get(key)
        if sample is None:
            sample = self._load_sample(key)
            self._cache[key] = sample
        return sample

    def _load_sample(self, idx):
        """Read structure ``idx`` from disk and build its feature tensors."""
        material_id, target = self.materials[idx]
        crystal = Structure.from_file(os.path.join(self.data_dir, material_id))

        atom_features = np.array([self.elem_dict[site.specie.number] for site in crystal])

        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)
        nbr_distances = []
        nbr_indices = []
        for nbr_list in all_nbrs:
            # Each neighbor entry is indexed as [1] = distance, [2] = site index.
            nearest = sorted(nbr_list, key=lambda nbr: nbr[1])[:self.max_nbr]
            # Pad sites with too few neighbors: a distance just beyond the
            # cutoff and index 0 as placeholder (pad_num is 0 when full).
            pad_num = self.max_nbr - len(nearest)
            nbr_distances.append([nbr[1] for nbr in nearest] + [self.radius + 1] * pad_num)
            nbr_indices.append([nbr[2] for nbr in nearest] + [0] * pad_num)

        return (
            torch.FloatTensor(atom_features),
            torch.FloatTensor(self.gdf.expand(np.array(nbr_distances))),
            torch.LongTensor(nbr_indices)
        ), torch.FloatTensor([float(target)]), material_id

# Build the dataset from the configured directory and neighbor cutoff radius
# (max_nbr keeps its default), then report how many samples were listed.
dataset = POSCARDataset(config["data_dir"], radius=config["radius"])
print(f"Dataset loading completed，{len(dataset)}samples")


In [None]:
def collate_pool(batch):
    """Merge per-crystal samples into one flat batch.

    Parameters:
    batch (iterable): items of the form
        ``((atom_fea, nbr_fea, nbr_idx), target, cif_id)``.

    Returns:
    tuple: ``((atom_fea, nbr_fea, nbr_idx, crystal_atom_idx), targets, ids)``
        where the three feature tensors are concatenated along dim 0,
        neighbor indices are shifted so they address rows of the concatenated
        atom tensor, ``crystal_atom_idx`` lists the row indices belonging to
        each crystal, ``targets`` is the stacked target tensor and ``ids`` is
        the list of sample identifiers.
    """
    atoms, bonds, neighbours = [], [], []
    rows_per_crystal, targets, ids = [], [], []
    offset = 0

    for features, target, cif_id in batch:
        atom_fea, nbr_fea, nbr_idx = features
        count = atom_fea.shape[0]

        atoms.append(atom_fea)
        bonds.append(nbr_fea)
        # Re-base local neighbor indices onto the batch-wide atom numbering.
        neighbours.append(nbr_idx + offset)
        rows_per_crystal.append(torch.arange(offset, offset + count))
        targets.append(target)
        ids.append(cif_id)

        offset += count

    inputs = (
        torch.cat(atoms, dim=0),
        torch.cat(bonds, dim=0),
        torch.cat(neighbours, dim=0),
        rows_per_crystal,
    )
    return inputs, torch.stack(targets, dim=0), ids

def split_dataset(dataset, ratios=(0.8, 0.1, 0.1)):
    """Randomly partition ``dataset`` into train/val/test data loaders.

    Parameters:
    dataset (Dataset): dataset to split; all three loaders wrap this object.
    ratios (tuple): fractions for train and validation (first two entries);
        the test split receives every remaining index.

    Returns:
    dict: ``{"train": ..., "val": ..., "test": ...}`` of ``DataLoader``
        objects using ``config["batch_size"]``, a ``SubsetRandomSampler`` over
        the split's indices, and ``collate_pool`` as the collate function.
    """
    total = len(dataset)
    shuffled = np.random.permutation(total)
    n_train = int(ratios[0] * total)
    n_val = int(ratios[1] * total)

    spans = {
        "train": slice(None, n_train),
        "val": slice(n_train, n_train + n_val),
        "test": slice(n_train + n_val, None),
    }
    return {
        name: DataLoader(
            dataset,
            batch_size=config["batch_size"],
            sampler=SubsetRandomSampler(shuffled[span]),
            collate_fn=collate_pool,
        )
        for name, span in spans.items()
    }

# Create the train/val/test loaders with the default split ratios and report
# the number of batches in each.
loaders = split_dataset(dataset)
print("Data loader successfully created.")
print(f"Train set batch: {len(loaders['train'])} | Validation set batch: {len(loaders['val'])} | Test set batch: {len(loaders['test'])}")


In [None]:
# Pull one training batch; the trailing dimensions of atom_fea and nbr_fea
# are used later to size the model's input layers.
sample_batch = next(iter(loaders['train']))
(atom_fea, nbr_fea, nbr_idx, crystal_idx), target, cif_id = sample_batch


In [None]:
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    """Channel-wise gating of a 2-D feature tensor.

    ``forward`` takes ``x`` of shape ``(b, c)`` and returns a tensor of shape
    ``(b, c, 1)``. The input is given a trailing axis of length 1 and passed
    through adaptive average and max pooling to size 1; since that axis already
    has length 1, both pools hand back the input values. Each pooled tensor
    goes through the same bottleneck MLP (ending in a sigmoid), the two gates
    are added, and the sum multiplies the input.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)

        bottleneck = channels // reduction
        self.mlp = nn.Sequential(
            nn.Linear(channels, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels = x.size()
        stacked = x.unsqueeze(2)

        # Same shared MLP applied to the average-pooled and max-pooled views.
        gates = [
            self.mlp(pool(stacked).view(batch, channels)).view(batch, channels, 1)
            for pool in (self.avg_pool, self.max_pool)
        ]
        attention = gates[0] + gates[1]
        return stacked * attention.expand_as(stacked)

class ConvLayer(nn.Module):
    """
    Gated neighbor-aggregation layer followed by channel attention.

    For every atom the layer concatenates its own features, each neighbor's
    features and the corresponding bond features, maps them through a linear
    layer and batch norm, splits the result into a sigmoid "filter" half and a
    softplus "core" half, and sums their product over neighbors. The
    batch-normalised sum is added to the input features, re-weighted by
    ``ChannelAttention`` and passed through a final softplus.
    """
    def __init__(self, atom_fea_len, nbr_fea_len, reduction=8):
        super().__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len

        gate_width = 2 * atom_fea_len
        self.fc_full = nn.Linear(gate_width + nbr_fea_len, gate_width)
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(gate_width)
        self.bn2 = nn.BatchNorm1d(atom_fea_len)
        self.softplus2 = nn.Softplus()

        self.ca = ChannelAttention(atom_fea_len, reduction)

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """Return updated atom features with the same shape as ``atom_in_fea``.

        ``nbr_fea_idx`` is 2-D (atoms x neighbors) and indexes rows of
        ``atom_in_fea``; ``nbr_fea`` carries one feature vector per
        (atom, neighbor) pair.
        """
        n_atoms, n_nbrs = nbr_fea_idx.shape
        gate_width = 2 * self.atom_fea_len

        centre = atom_in_fea.unsqueeze(1).expand(n_atoms, n_nbrs, self.atom_fea_len)
        neighbours = atom_in_fea[nbr_fea_idx, :]
        pair_fea = torch.cat([centre, neighbours, nbr_fea], dim=2)

        gated = self.fc_full(pair_fea)
        # BatchNorm1d expects 2-D input, so flatten the (atom, neighbor) axes
        # for normalisation and restore them afterwards.
        gated = self.bn1(gated.view(-1, gate_width)).view(n_atoms, n_nbrs, gate_width)

        filter_half, core_half = gated.chunk(2, dim=2)
        filter_half = self.sigmoid(filter_half)
        core_half = self.softplus1(core_half)

        message = torch.sum(filter_half * core_half, dim=1)
        message = self.bn2(message)

        updated = atom_in_fea + message
        # ChannelAttention returns a trailing singleton axis; drop it again.
        updated = self.ca(updated).squeeze(2)
        return self.softplus2(updated)

class AC_CGCNN(nn.Module):
    """
    Crystal graph network built from attention-augmented ``ConvLayer`` blocks.

    Pipeline: linear atom embedding -> ``n_conv`` convolution layers ->
    mean pooling per crystal -> linear + softplus -> ``n_h - 1`` further
    hidden linear + softplus layers -> linear output of width 1.
    """
    def __init__(self, orig_dim, nbr_dim, hidden_dim=64, h_fea_len=64,
                 n_h=2, n_conv=3, reduction=8):
        super().__init__()
        self.atom_embed = nn.Linear(orig_dim, hidden_dim)

        conv_stack = []
        for _ in range(n_conv):
            conv_stack.append(
                ConvLayer(atom_fea_len=hidden_dim, nbr_fea_len=nbr_dim, reduction=reduction)
            )
        self.conv_layers = nn.ModuleList(conv_stack)

        self.conv_to_fc = nn.Linear(hidden_dim, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()

        # range() is empty for n_h <= 1, so no extra hidden layers are added then.
        n_extra = n_h - 1
        self.fcs = nn.ModuleList([nn.Linear(h_fea_len, h_fea_len) for _ in range(n_extra)])
        self.softpluses = nn.ModuleList([nn.Softplus() for _ in range(n_extra)])

        self.fc_out = nn.Linear(h_fea_len, 1)

    def forward(self, atom_fea, nbr_fea, nbr_idx, crystal_idx):
        """Predict one value per crystal; returns a tensor of shape (batch_size, 1)."""
        atom_state = self.atom_embed(atom_fea)
        for layer in self.conv_layers:
            atom_state = layer(atom_state, nbr_fea, nbr_idx)

        pooled = self.pooling(atom_state, crystal_idx)
        hidden = self.conv_to_fc_softplus(self.conv_to_fc(pooled))
        for linear, activation in zip(self.fcs, self.softpluses):
            hidden = activation(linear(hidden))

        return self.fc_out(hidden)

    def pooling(self, atom_fea, crystal_idx):
        """
        Average atom features within each crystal.

        Parameters:
        atom_fea (Tensor): Atom feature vectors (N, atom_fea_len)
        crystal_idx (list): One index tensor per crystal selecting its rows

        Returns:
        Tensor: Crystal feature vectors (batch_size, atom_fea_len)
        """
        per_crystal = [
            atom_fea[rows].mean(dim=0, keepdim=True)
            for rows in crystal_idx
        ]
        return torch.cat(per_crystal, dim=0)
    

In [None]:
# Instantiate the network. Input widths are taken from the last dimension of
# the sample batch's atom and neighbor feature tensors; n_conv=4 overrides
# the class default of 3 convolution layers.
model = AC_CGCNN(
    orig_dim=atom_fea.shape[-1],
    nbr_dim=nbr_fea.shape[-1],
    hidden_dim=64,
    n_conv=4
)


In [None]:
import json
import matplotlib.pyplot as plt
import os
import torch
from torch.optim import Adam
import torch.nn as nn

def _plot_history(history):
    """Plot per-epoch training and validation loss from a history dict."""
    epochs = range(1, len(history["train"]) + 1)
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epochs, history["train"], label="Train")
    ax.plot(epochs, history["val"], label="Validation")
    ax.set(xlabel="Epoch", ylabel="MSE loss", title="Training history")
    ax.legend()
    plt.show()


def _run_epoch(loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    With an ``optimizer`` the global ``model`` is put in training mode and
    updated after every batch; without one it is put in eval mode and
    gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for (atom_fea, nbr_fea, nbr_idx, crystal_idx), target, _ in loader:
            target = target.to(device)
            if training:
                optimizer.zero_grad()
            pred = model(
                atom_fea.to(device),
                nbr_fea.to(device),
                nbr_idx.to(device),
                # Keep the pooling indices on the same device as the features.
                [rows.to(device) for rows in crystal_idx],
            )
            loss = criterion(pred, target)
            if training:
                loss.backward()
                optimizer.step()
            # criterion averages over the batch; weight by batch size so the
            # epoch figure is a true per-sample mean.
            total_loss += loss.item() * len(target)

    # max(..., 1) avoids ZeroDivisionError when a split is empty.
    return total_loss / max(len(loader.sampler), 1)


def train_model():
    """Train the global ``model`` on ``loaders["train"]`` and track losses.

    Uses Adam with ``config["learning_rate"]`` and MSE loss for
    ``config["epochs"]`` epochs, evaluating on ``loaders["val"]`` after each
    epoch. When finished, the history is written to ``training_history.json``
    and the loss curves are plotted.

    Returns:
    tuple: ``(model, history)`` where ``history`` has the keys ``"train"`` and
        ``"val"`` (lists of per-epoch mean losses) and ``"config"``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = Adam(model.parameters(), lr=config["learning_rate"])
    criterion = nn.MSELoss()

    history = {
        "train": [],
        "val": [],
        # Snapshot the settings so later edits to `config` do not alter the record.
        "config": dict(config)
    }

    for epoch in range(1, config["epochs"]+1):
        avg_train = _run_epoch(loaders["train"], criterion, device, optimizer)
        avg_val = _run_epoch(loaders["val"], criterion, device)
        history["train"].append(avg_train)
        history["val"].append(avg_val)

        if config["show_progress"]:
            print(f"Epoch {epoch:03d}/{config['epochs']} | "
                  f"Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")

    # Persist first so a plotting problem cannot lose the training record.
    save_history(history, "training_history.json")
    # The original called `visualize_history`, which is not defined in the
    # visible notebook and raised NameError once training had finished.
    _plot_history(history)

    return model, history

def save_history(history, filename):
    """Write the training history to ``filename`` as indented JSON.

    Only the ``"train"``, ``"val"`` and ``"config"`` entries of ``history``
    are written. The parent directory of ``filename`` is created when it is
    non-empty and does not exist yet.
    """
    parent = os.path.dirname(filename)
    needs_parent = bool(parent) and not os.path.exists(parent)
    if needs_parent:
        os.makedirs(parent, exist_ok=True)

    payload = {key: history[key] for key in ("train", "val", "config")}

    with open(filename, 'w') as out:
        json.dump(payload, out, indent=4)

    print(f"Training history has been saved to {filename}")


# Run the full training loop; returns the trained model and the loss history.
trained_model, training_history = train_model()

# Write the loss history to disk.
# NOTE(review): train_model() already calls save_history with this same
# filename before returning, so this call rewrites the file a second time.
save_history(training_history, "training_history.json")



In [None]:
def evaluate_model(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode and score its predictions.

    Parameters:
    model (nn.Module): network called as ``model(atom_fea, nbr_fea, nbr_idx, crystal_idx)``.
    loader (iterable): yields ``((atom_fea, nbr_fea, nbr_idx, crystal_idx), target, ids)``.
    device (torch.device): device the feature tensors are moved to.

    Returns:
    dict: ``"mse"``, ``"mae"`` and ``"r2"`` scores plus the concatenated
        ``"targets"`` and ``"preds"`` NumPy arrays.
    """
    model.eval()
    target_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for inputs, target, _ in loader:
            atom_fea, nbr_fea, nbr_idx, crystal_idx = inputs
            output = model(
                atom_fea.to(device),
                nbr_fea.to(device),
                nbr_idx.to(device),
                crystal_idx
            )
            target_chunks.append(target.numpy())
            pred_chunks.append(output.cpu().numpy())

    targets = np.concatenate(target_chunks, axis=0)
    preds = np.concatenate(pred_chunks, axis=0)

    scorers = (
        ("mse", mean_squared_error),
        ("mae", mean_absolute_error),
        ("r2", r2_score),
    )
    results = {name: scorer(targets, preds) for name, scorer in scorers}
    results["targets"] = targets
    results["preds"] = preds
    return results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Score the trained model on every split.
train_metrics = evaluate_model(trained_model, loaders["train"], device)
val_metrics = evaluate_model(trained_model, loaders["val"], device)
test_metrics = evaluate_model(trained_model, loaders["test"], device)

# Report all three splits; the validation scores were previously computed
# but never shown.
splits = (
    ("Training Set", train_metrics),
    ("Validation Set", val_metrics),
    ("Test Set", test_metrics),
)
header = f"{'Metric':<10} | " + " | ".join(f"{name:>14}" for name, _ in splits)
print(header)
print("-" * len(header))
for metric in ["mse", "mae", "r2"]:
    row = " | ".join(f"{scores[metric]:>14.4f}" for _, scores in splits)
    print(f"{metric.upper():<10} | {row}")
    

In [None]:
# Persist only the learned parameters (state_dict), not the module object;
# loading them back requires constructing the model with matching layer sizes.
torch.save(trained_model.state_dict(), 'pretrained model.pth')
