In [1]:
# Restrict this kernel to a single GPU. CUDA_DEVICE_ORDER=PCI_BUS_ID makes the index in
# CUDA_VISIBLE_DEVICES refer to PCI bus order (presumably matching nvidia-smi numbering).
# NOTE: these only take effect if set before TensorFlow (imported in a later cell) initializes CUDA.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=3


In [2]:
import ast
import os

import matplotlib.pyplot as plt
import pandas as pd
import pymatgen

In [3]:
# Load the pickled structures table. Unpickling can execute arbitrary code,
# so only load this file from a trusted source.
# NOTE(review): with compression="infer", pandas recognises extensions such as ".gz",
# not ".gzip", so this file is read as an uncompressed pickle. Confirm that matches
# how it was written.
structures = pd.read_pickle(
    "datasets/structures_defects.pickle.gzip"
)

In [4]:
# Base material to study; matched against the "base" column of the descriptors table.
target_material: str = "MoS2"

In [5]:
# Load the defect descriptors, indexed by "_id".
# The "cell" and "defects" columns are presumably stored as reprs of Python
# lists/dicts. ast.literal_eval only evaluates literals (numbers, strings, tuples,
# lists, dicts, sets, booleans, None), so unlike eval it cannot execute arbitrary
# code embedded in the CSV. A value that is not a pure literal raises
# (ValueError/SyntaxError) instead of being run. Fix the data rather than
# reverting to eval.
defects = pd.read_csv(
    "datasets/dichalcogenides_innopolis_202105/descriptors.csv",
    index_col="_id",
    converters={"cell": ast.literal_eval, "defects": ast.literal_eval},
)

In [6]:
# Keep only the structures whose descriptor row has base == target_material.
# to_numpy() drops the descriptor-id index, so the boolean mask is applied
# positionally to `structures` rather than being aligned by label.
selected_structures = structures.loc[
    defects.loc[structures.descriptor_id, "base"].eq(target_material).to_numpy()
]

In [7]:
from megnet.models import MEGNetModel
from megnet.data.graph import GaussianDistance
from megnet.data.crystal import CrystalGraph
from megnet.utils.preprocessing import StandardScaler
from megnet.callbacks import ModelCheckpointMAE
from pymatgen.core import Lattice, Structure, Molecule

import tensorflow as tf
import numpy as np

In [8]:
import wandb
from wandb.keras import WandbCallback

In [9]:
from sklearn.model_selection import train_test_split

In [10]:
# TODO(kazeevn) pass structure as the global state
# Hold out a quarter of the selected structures; the fixed seed keeps the split reproducible.
train, test = train_test_split(
    selected_structures,
    random_state=42,
    test_size=0.25,
)

In [None]:
# Start a Weights & Biases run; the WandbCallback passed to model.train logs into it.
wandb.init(entity='kazeev', project='ai4material_design')

[34m[1mwandb[0m: Currently logged in as: [33mkazeev[0m (use `wandb login --relogin` to force relogin)


In [None]:
# Experiment settings are kept in the W&B run config so they are recorded with the run.
# `target` names the column of `train`/`test` that the model is fit to.
config = wandb.config
config["target"] = "homo"

In [None]:
# Build the MEGNet graph converter and model.
# Each bond distance is expanded onto `nfeat_edge` Gaussians of width 0.5 whose
# centres are evenly spaced on [0, 5]; neighbours are searched up to `cutoff`
# (presumably in Å, as for pymatgen structures).
# NOTE(review): the Gaussian centres stop at 5 while cutoff is 15, so bonds much
# longer than ~5 get near-zero edge features. Confirm this is intended.
# nfeat_global=2 presumably matches the 2-component default state that CrystalGraph
# attaches to each graph. Verify against megnet.
nfeat_edge = 100
gc = CrystalGraph(
    cutoff=15,
    bond_converter=GaussianDistance(np.linspace(0, 5, nfeat_edge), 0.5),
)
model = MEGNetModel(
    npass=1,
    graph_converter=gc,
    nfeat_global=2,
    nfeat_edge=nfeat_edge,
)

In [None]:
# Fit the target standardisation on the training split only, so no statistics
# from `test` leak into training.
# is_intensive=True presumably means the target is scaled per structure rather
# than per atom. See megnet StandardScaler.
scaler = StandardScaler.from_training_data(
    train.defect_representation,
    train[config.target],
    is_intensive=True,
)
model.target_scaler = scaler

In [None]:
# Fit the model on the training split.
# NOTE(review): the held-out `test` split doubles as the validation set here. If
# megnet uses validation loss for checkpointing or model selection, the test-set
# evaluation afterwards is optimistic.
# NOTE(review): patience equals epochs, presumably so patience-based stopping
# never triggers. Confirm.
model.train(
    train.defect_representation,
    train[config.target],
    epochs=1000,
    patience=1000,
    verbose=1,
    validation_structures=test.defect_representation,
    validation_targets=test[config.target],
    callbacks=[WandbCallback()],
)

In [None]:
# Parity plot on the training split: true vs predicted target.
# predict_structures may return an (n, 1) array; ravel it to one value per structure.
train['predicted'] = np.ravel(model.predict_structures(train.defect_representation))
fig, ax = plt.subplots()
# Use the configured target column rather than a hard-coded "homo", so the plot
# stays correct if config.target changes.
ax.scatter(train[config.target], train['predicted'])
ax.set(xlabel=f"true {config.target}", ylabel=f"predicted {config.target}", title="Train");

In [None]:
# Parity plot on the held-out split: true vs predicted target.
# predict_structures may return an (n, 1) array; ravel it to one value per structure.
test['predicted'] = np.ravel(model.predict_structures(test.defect_representation))
fig, ax = plt.subplots()
# Use the configured target column rather than a hard-coded "homo", so the plot
# stays correct if config.target changes.
ax.scatter(test[config.target], test['predicted'])
ax.set(xlabel=f"true {config.target}", ylabel=f"predicted {config.target}", title="Test");