In [5]:
!pip install torch_geometric
!pip install torch_scatter
!pip install torch_sparse
!pip install pymatgen



In [6]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing, global_mean_pool


class MegnetModule(MessagePassing):
    def __init__(self, edge_input_shape, node_input_shape, state_input_shape):
        super().__init__(aggr="mean")
        self.phi_e = nn.Sequential(
            nn.Linear(128, 64),
            nn.Softplus(),
            nn.Linear(64, 64),
            nn.Softplus(),
            nn.Linear(64, 32)
        )

        self.phi_u = nn.Sequential(
            nn.Linear(96, 64),
            nn.Softplus(),
            nn.Linear(64, 64),
            nn.Softplus(),
            nn.Linear(64, 32)
        )

        self.phi_v = nn.Sequential(
            nn.Linear(96, 64),
            nn.Softplus(),
            nn.Linear(64, 64),
            nn.Softplus(),
            nn.Linear(64, 32)
        )

        self.preprocess_e = nn.Sequential(
            nn.Linear(edge_input_shape, 64),
            nn.Softplus(),
            nn.Linear(64, 32)
        )

        self.preprocess_v = nn.Sequential(
            nn.Linear(node_input_shape, 64),
            nn.Softplus(),
            nn.Linear(64, 32)
        )

        self.preprocess_u = nn.Sequential(
            nn.Linear(state_input_shape, 64),
            nn.Softplus(),
            nn.Linear(64, 32)
        )

    def forward(self, x, edge_index, edge_attr, state, batch, bond_batch):
        x = self.preprocess_v(x)
        edge_attr = self.preprocess_e(edge_attr)
        state = self.preprocess_u(state)

        x_skip = x
        edge_attr_skip = edge_attr
        state_skip = state

        edge_attr = self.edge_updater(
            edge_index=edge_index, x=x, edge_attr=edge_attr, state=state, bond_batch=bond_batch
        )
        x = self.propagate(
            edge_index=edge_index, x=x, edge_attr=edge_attr, state=state, batch=batch
        )
        u_v = global_mean_pool(x, batch)
        u_e = global_mean_pool(edge_attr, bond_batch)
        state = self.phi_u(torch.cat((u_e, u_v, state), 1))
        return x + x_skip, edge_attr + edge_attr_skip, state + state_skip

    def message(self, edge_attr):
        return edge_attr

    def update(self, inputs, x, state, batch):
        return self.phi_v(torch.cat((inputs, x, state[batch, :]), 1))

    def edge_update(self, x_i, x_j, edge_attr, state, bond_batch):
        return self.phi_e(torch.cat((x_i, x_j, edge_attr, state[bond_batch, :]), 1))


In [7]:
import torch
import torch.nn as nn
from torch_geometric.nn import Set2Set


class MEGNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(95, 16)
        self.m1 = MegnetModule(100, 16, 2)
        self.m2 = MegnetModule(32, 32, 32)
        self.m3 = MegnetModule(32, 32, 32)
        self.se = Set2Set(32, 1)
        self.sv = Set2Set(32, 1)
        self.hiddens = nn.Sequential(
            nn.Linear(160, 32),
            nn.Softplus(),
            nn.Linear(32, 16),
            nn.Softplus(),
            nn.Linear(16, 1)
        )

    def forward(self, x, edge_index, edge_attr, state, batch, bond_batch):
        x = self.emb(x).squeeze()
        x, edge_attr, state = self.m1(x, edge_index, edge_attr, state, batch, bond_batch)
        x, edge_attr, state = self.m2(x, edge_index, edge_attr, state, batch, bond_batch)
        x, edge_attr, state = self.m3(x, edge_index, edge_attr, state, batch, bond_batch)
        x = self.sv(x, batch)
        edge_attr = self.se(edge_attr, bond_batch)
        tmp = torch.cat((x, edge_attr, state), 1)
        out = self.hiddens(tmp)
        return out


In [8]:
import numpy as np
from copy import copy
from pymatgen.io.cif import CifParser
import torch
import random


class Scaler:
    def __init__(self):
        self.mean = 0
        self.std = 1.0

    def fit(self, dataset, feature_name='y'):
        data = np.array([getattr(dataset.get(i), feature_name).data.numpy() for i in range(len(dataset))])
        self.mean = np.mean(data)
        self.std = np.std(data)

    def transform(self, data):
        data_copy = copy(data)
        return (data_copy - self.mean) / (self.std if abs(self.std) > 1e-7 else 1.)

    def inverse_transform(self, data):
        data_copy = copy(data)
        std = self.std if abs(self.std) > 1e-7 else 1.0
        return data_copy * std + self.mean


class String2StructConverter:
    def __init__(self, struct_target_name):
        self.target_name = struct_target_name

    def convert(self, elem):
        struct = CifParser.from_string(elem['structure']).get_structures()[0]
        struct.y = elem[self.target_name]
        return struct


def set_random_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


In [9]:
from pymatgen.core import Structure, Lattice
from pymatgen.optimization.neighbors import find_points_in_spheres
import numpy as np
from torch_geometric.data import Data
import torch


class SimpleCrystalConverter:
    def __init__(
            self,
            atom_converter=None,
            bond_converter=None,
            cutoff=5.0
    ):
        self.cutoff = cutoff
        self.atom_converter = atom_converter if atom_converter else DummyConverter()
        self.bond_converter = bond_converter if bond_converter else DummyConverter()

    def convert(self, d):
        lattice_matrix = np.ascontiguousarray(np.array(d.lattice.matrix), dtype=float)
        pbc = np.array([1, 1, 1], dtype=int)
        cart_coords = np.ascontiguousarray(np.array(d.cart_coords), dtype=float)

        center_indices, neighbor_indices, _, distances = find_points_in_spheres(
            cart_coords, cart_coords, r=self.cutoff, pbc=pbc, lattice=lattice_matrix, tol=1e-8
        )
        exclude_self = center_indices != neighbor_indices

        edge_index = torch.Tensor(np.stack((center_indices[exclude_self], neighbor_indices[exclude_self]))).long()
        if torch.numel(edge_index) == 0:
            return None

        x = torch.Tensor(self.atom_converter.convert(np.array([i.specie.Z for i in d]))).long()
        edge_attr = torch.Tensor(self.bond_converter.convert(distances[exclude_self]))
        state = getattr(d, "state", None) or [[0.0, 0.0]]
        y = d.y if hasattr(d, "y") else 0
        bond_batch = torch.Tensor([0 for _ in range(edge_index.shape[1])]).long()

        return Data(
            x=x, edge_index=edge_index, edge_attr=edge_attr, state=torch.Tensor(state), y=y, bond_batch=bond_batch
        )

    def __call__(self, d):
        return self.convert(d)


class DummyConverter:
    def convert(self, d):
        return d.reshape((-1, 1))


class GaussianDistanceConverter:
    def __init__(self, centers=np.linspace(0, 5, 100), sigma=0.5):
        self.centers = centers
        self.sigma = sigma

    def convert(self, d):
        return np.exp(
            -((d.reshape((-1, 1)) - self.centers.reshape((1, -1))) / self.sigma) ** 2
        )


In [11]:
import torch
from torch_geometric.data import InMemoryDataset
from monty.serialization import loadfn
import os.path as osp
from tqdm import tqdm


class MPDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["mp.2018.6.1.json"]

    @property
    def processed_file_names(self):
        return ["data.pt"]

    def process(self):
        raw_data = loadfn(osp.join(self.raw_dir, "mp.2018.6.1.json"))

        converter = String2StructConverter('formation_energy_per_atom')
        structures_list = [converter.convert(s) for s in tqdm(raw_data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in structures_list]
            data_list = [data for data in data_list if data]
        else:
            raise "you should give struct2graph converter"

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [12]:
import torch
import torch.nn as nn
from torch_geometric.nn import Set2Set
import math


class MEGNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=1,
                                                        nhead=7,
                                                        dropout=0,
                                                        batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
        self.seq_len = 987
        self.feature_size = 988
        self.W_pos = nn.Parameter(self._positional_encoding("sincos", True, self.seq_len, self.feature_size, normalize=True))
        self.dropout = nn.Dropout()
        self.emb = nn.Embedding(95, 16)
        self.m1 = MegnetModule(100, 16, 2)
        self.m2 = MegnetModule(32, 32, 32)
        self.m3 = MegnetModule(32, 32, 32)
        self.se = Set2Set(32, 1)
        self.sv = Set2Set(32, 1)
        self.hiddens = nn.Sequential(
            nn.Linear(160, 32),
            nn.Softplus(),
            nn.Linear(32, 16),
            nn.Softplus(),
            nn.Linear(16, 1)
        )
    def _positional_encoding(self, pe, learn_pe, q_len=1, d_model=1, normalize=True):
        if pe == None:
            W_pos = torch.zeros((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe
            learn_pe = False
        elif pe == 'zeros':
            W_pos = torch.empty((q_len, d_model))
            nn.init.uniform_(W_pos, -0.02, 0.02)
        elif pe == 'normal' or pe == 'gauss':
            W_pos = torch.zeros((q_len, d_model))
            torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
        elif pe == 'uniform':
            W_pos = torch.zeros((q_len, d_model))
            nn.init.uniform_(W_pos, a=0.0, b=0.1)
        elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=normalize)
        elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=normalize)
        elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=normalize)
        return nn.Parameter(W_pos, requires_grad=learn_pe)

    def forward(self, x, edge_index, edge_attr, state, batch, bond_batch):
        self.seq_len = x.shape[1]
        x = self.dropout(x + self.W_pos)

        output = self.transformer_encoder(x, self.mask)
        output = self.decoder(output)
        x = self.emb(output).squeeze()
        x, edge_attr, state = self.m1(x, edge_index, edge_attr, state, batch, bond_batch)
        x, edge_attr, state = self.m2(x, edge_index, edge_attr, state, batch, bond_batch)
        x, edge_attr, state = self.m3(x, edge_index, edge_attr, state, batch, bond_batch)
        x = self.sv(x, batch)
        edge_attr = self.se(edge_attr, bond_batch)
        tmp = torch.cat((x, edge_attr, state), 1)
        out = self.hiddens(tmp)
        return out

In [13]:
class MEGNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(95, 16)
        self.m1 = MegnetModule(100, 16, 2)
        self.m2 = MegnetModule(32, 32, 32)
        self.m3 = MegnetModule(32, 32, 32)
        self.se = Set2Set(32, 1)
        self.sv = Set2Set(32, 1)
        self.hiddens = nn.Sequential(
            nn.Linear(160, 32),
            nn.Softplus(),
            nn.Linear(32, 16),
            nn.Softplus(),
            nn.Linear(16, 1)
        )
    def forward(self, x, edge_index, edge_attr, state, batch, bond_batch):
        x = self.emb(x).squeeze()
        x, edge_attr, state = self.m1(x, edge_index, edge_attr, state, batch, bond_batch)
        x, edge_attr, state = self.m2(x, edge_index, edge_attr, state, batch, bond_batch)
        x, edge_attr, state = self.m3(x, edge_index, edge_attr, state, batch, bond_batch)
        x = self.sv(x, batch)
        edge_attr = self.se(edge_attr, bond_batch)
        tmp = torch.cat((x, edge_attr, state), 1)
        out = self.hiddens(tmp)
        return out

In [14]:
data = loadfn("bulk_moduli.json")
structures = data["structures"]
target = torch.log10_(torch.Tensor(data["bulk_moduli"]))

converter = SimpleCrystalConverter(bond_converter=GaussianDistanceConverter())

structures_converted = [converter.convert(s) for s in structures]
for s in structures_converted:
    s.bond_batch = torch.Tensor([0 for _ in range(s.edge_index.shape[1])]).long()

batch = Batch.from_data_list(structures_converted)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

M = MEGNet()
opt = torch.optim.Adam(M.parameters(), lr=1e-3)
batch = batch
target = target

M.train(True)
for i in range(20):
    preds = M(batch.x, batch.edge_index, batch.edge_attr, batch.state, batch.batch, batch.bond_batch).squeeze()

    loss = F.mse_loss(preds, target)
    loss.backward()
    opt.step()
    opt.zero_grad()
    print("Loss", loss.data.numpy())

Loss 4.588823
Loss 4.157692
Loss 3.7317421
Loss 3.2950416
Loss 2.8360705
Loss 2.3471415
Loss 1.8260355
Loss 1.2821169
Loss 0.748089
Loss 0.30355796
Loss 0.1188554
Loss 0.42694682
Loss 0.7378708
Loss 0.664472
Loss 0.43101892
Loss 0.23518999
Loss 0.13668957
Loss 0.1186971
Loss 0.1455581
Loss 0.18769422
