In [None]:
!pip install pymatgen megnet

import yaml
import json
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from pathlib import Path
import shutil

from pymatgen.core.composition import Composition
from pymatgen.core import Structure

from sklearn.model_selection import train_test_split
from megnet.models import MEGNetModel
from megnet.data.crystal import CrystalGraph, CrystalGraphWithBondTypes
from megnet.data.graph import GaussianDistance, StructureGraph
from megnet.data.molecule import MolecularGraph
from tensorflow.keras.models import load_model
from megnet.layers import _CUSTOM_OBJECTS
from megnet.models.base import GraphModel
import gc

import matplotlib.pyplot as plt 
import seaborn as sns
# NOTE: `from tensorflow import set_random_seed` fails to import on TF 2.x; the seeding API there is `tf.random.set_seed`.


from google import colab
colab.drive.mount('/content/gdrive')

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Collecting pymatgen
  Downloading pymatgen-2022.0.17.tar.gz (40.6 MB)
[K     |████████████████████████████████| 40.6 MB 1.7 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting megnet
  Downloading megnet-1.3.0-py3-none-any.whl (114 kB)
[K     |████████████████████████████████| 114 kB 65.4 MB/s 
Collecting uncertainties>=3.1.4
  Downloading uncertainties-3.1.6-py2.py3-none-any.whl (98 kB)
[K     |████████████████████████████████| 98 kB 10.2 MB/s 
Collecting ruamel.yaml>=0.15.6
  Downloading ruamel.yaml-0.17.21-py3-none-any.whl (109 kB)
[K     |████████████████████████████████| 109 kB 72.4 MB/s 
Collecting spglib>=1.9.9.44
  Downloading spglib-1.16.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (292 kB)
[K     |████████████████████████████████| 292 kB 72.9 MB/s 
Collecting scipy>=1.5.0

In [None]:
def read_pymatgen_dict(file):
    """Build a pymatgen ``Structure`` from a JSON file.

    Args:
        file: path of a JSON document; its parsed content is handed to
            ``Structure.from_dict``.

    Returns:
        The ``Structure`` created from the parsed JSON content.
    """
    with open(file, "r") as handle:
        payload = json.load(handle)
    structure = Structure.from_dict(payload)
    return structure


def energy_within_threshold(prediction, target):
    """Return the fraction of entries whose absolute error is below 0.02.

    Args:
        prediction: predicted values.
        target: reference values, same shape as ``prediction``.

    Returns:
        Count of entries with ``|target - prediction| < 0.02`` divided by the
        number of entries in ``target``.
    """
    e_thresh = 0.02
    abs_error = tf.math.abs(target - prediction)

    # Number of entries inside the tolerance, and the total number of entries
    # (cast to int64 so that both operands of the division share a dtype).
    n_within = tf.math.count_nonzero(abs_error < e_thresh)
    n_total = tf.cast(tf.size(target), tf.int64)
    return n_within / n_total


def idx_to_coords(i, j, k):
    """Map grid indices ``(i, j, k)`` to a coordinate triple ``(x, y, z)``.

    ``i`` and ``j`` advance ``x`` and ``y`` in steps of 0.125.  ``k`` selects
    the layer: ``k == 1`` gives ``z = 0.25`` and uses the offsets
    (0.041666667, 0.08333333); ``k == 0`` gives ``z = 0.144826`` and any other
    ``k`` gives ``z = 0.355174``, both with the offsets swapped.
    """
    if k == 1:
        return 0.041666667 + 0.125*i, 0.08333333 + 0.125*j, 0.25

    if k == 0:
        z = 0.144826
    else:
        z = 0.355174
    return 0.08333333 + 0.125*i, 0.041666667 + 0.125*j, z


################################# MODIFIED HERE ##################################################################
def prepare_dataset(dataset_path, split=None, fill_holes=True, remove_common_atoms=True, mode="train"):
    """Load the structures (and, for training, the targets) of a dataset.

    Args:
        dataset_path: directory holding a ``structures`` sub-directory with one
            JSON file per structure and, when ``mode == "train"``, a
            ``targets.csv`` file whose first column is used as the index.
        split: mapping with ``"train"`` and ``"test"`` entries listing the
            index labels of each subset. Required when ``mode == "train"``.
        fill_holes: when True, a site of species ``1`` is appended at every
            empty cell of the 3 x 8 x 8 occupancy grid of each structure.
        remove_common_atoms: when True, Mo and S sites are removed from every
            structure and structures left without any site are dropped.
        mode: ``"train"`` returns the two subsets described by ``split``; any
            other value returns the whole frame.

    Returns:
        ``(train, test)`` DataFrames when ``mode == "train"``, otherwise a
        single DataFrame. Frames are indexed by structure id and have a
        ``structures`` column (plus ``targets`` in train mode).

    Raises:
        ValueError: if ``mode == "train"`` and ``split`` is None.
    """
    if mode == "train" and split is None:
        raise ValueError("`split` argument must not be None when mode='train'")
    dataset_path = Path(dataset_path)

    def _structure_id(filename):
        # `str.strip(".json")` removes any of the characters '.', 'j', 's',
        # 'o', 'n' from BOTH ends of the name, which can eat part of the id;
        # only the literal ".json" suffix must be removed.
        suffix = ".json"
        return filename[:-len(suffix)] if filename.endswith(suffix) else filename

    structures_by_id = {
        _structure_id(item.name): read_pymatgen_dict(item)
        for item in (dataset_path / "structures").iterdir()
    }

    data = pd.DataFrame(columns=["structures"], index=structures_by_id.keys())
    if mode == "train":
        targets = pd.read_csv(dataset_path / "targets.csv", index_col=0)
        data = data.assign(structures=structures_by_id.values(), targets=targets)
    else:
        data = data.assign(structures=structures_by_id.values())

    atomic_numbers = {"Mo": 42, "W": 74, "Se": 34, "S": 16}
    if fill_holes:
        filled_structures = []
        for structure in data.structures:
            # Occupancy grid indexed as [layer (c), a, b], weighted by atomic number.
            abc = np.array([[site.c, site.a, site.b] for site in structure])
            species = np.array([atomic_numbers[str(site.specie)] for site in structure])
            occupancy = np.histogramdd(abc, bins=(3, 8, 8), weights=species)[0]

            filled = structure.copy()
            # Each row of argwhere is (layer, a-index, b-index); idx_to_coords
            # expects (i, j, k) with the layer LAST, exactly as struct_augment
            # calls it, so the indices are reordered here.
            for k, i, j in np.argwhere(occupancy == 0):
                filled.append(1, idx_to_coords(i, j, k))
            filled_structures.append(filled)
        data["structures"] = filled_structures

    if remove_common_atoms:
        Mo = Composition("Mo")
        S = Composition("S")
        kept_structures = []
        emptied_positions = []
        for position, structure in enumerate(data.structures):
            structure.remove_species(Mo)
            structure.remove_species(S)
            if len(structure) > 0:
                kept_structures.append(structure)
            else:
                emptied_positions.append(position)
        # Shapes before/after dropping the structures that ended up empty.
        print(data.shape)
        data.drop(data.iloc[emptied_positions].index, inplace=True)
        print(data.shape)
        data["structures"] = kept_structures

    if mode == "train":
        return data.loc[split["train"]], data.loc[split["test"]]
    return data

#######################################################################################################

def prepare_model(split,
    cutoff,
    nfeat_bond,
    gaussian_width,
    npass,
    nblocks,
    n1,
    n2,
    n3, 
    embedding_dim,
    dropout,
    lr,
    output_dir,
    seed,
    prefix=''
    ):
    '''
    Build a MEGNetModel and the run directory that stores its configuration.

    split: label of the data split; only used in the run directory name
    cutoff: cutoff passed to CrystalGraph; the Gaussian centers span [0, cutoff + 1]
    nfeat_bond: (int) number of Gaussian centers
    gaussian_width: width passed to the model for the Gaussian expansion
    npass: (int) number of recurrent steps in Set2Set layer
    nblocks: (int) number of MEGNetLayer blocks
    n1: (int) number of hidden units in layer 1 in MEGNetLayer
    n2: (int) number of hidden units in layer 2 in MEGNetLayer
    n3: (int) number of hidden units in layer 3 in MEGNetLayer
    embedding_dim: (int) number of embedding dimension
    dropout: dropout rate passed to the model
    lr: learning rate passed to the model
    output_dir: directory under which the run directory is created
    seed: seed applied to the NumPy and TensorFlow random generators
    prefix: free text inserted in the run directory name

    Returns (model, dirname). `dirname` is `output_dir/<split>_<prefix><key-value_...>`
    and receives a `config.json` dump of the hyper-parameters.
    '''

    model_config = {
      # graph params
      'cutoff':cutoff,
      'nfeat_bond':nfeat_bond,
      'gaussian_width':gaussian_width,
      # model params
      'npass':npass,
      'nblocks': nblocks,
      'n1': n1,
      'n2': n2,
      'n3': n3,
      'embedding_dim': embedding_dim,
      'dropout':dropout,
      'lr':lr,
      # training_params
      'seed':seed
    }

    # Seed both generators: NumPy alone does not make weight initialisation
    # or dropout reproducible.
    np.random.seed(seed)
    tf.random.set_seed(seed)

    gaussian_centers = np.linspace(0, cutoff + 1, nfeat_bond)

    # The run directory name encodes every hyper-parameter of the config.
    test_name = f"{split}_{prefix}" + "".join(
        f"{val_name}-{value}_" for val_name, value in model_config.items()
    )
    dirname = f"{output_dir}/{test_name}"
    # makedirs also creates a missing `output_dir` and tolerates re-runs.
    os.makedirs(dirname, exist_ok=True)

    model = MEGNetModel(
        nblocks = nblocks,
        n1 = n1,
        n2 = n2,
        n3 = n3,
        embedding_dim = embedding_dim,
        # Previously `dropout` was written to the config and the directory
        # name but never reached the model.
        dropout = dropout,
        graph_converter=CrystalGraph(cutoff=cutoff),
        centers=gaussian_centers,
        width=gaussian_width,
        loss=["MAE"],
        npass=npass,
        lr=lr,
        metrics=energy_within_threshold,
    )

    with open(f"{dirname}/config.json", 'w') as outfile:
      json.dump(model_config, outfile)

    return model, dirname


def get_minimal_dec(struct, empty_struct):
  """Find, per axis, the periodic shift giving the most compact structure.

  For each of the 8 shifts along one axis (the other shift kept at 0) the
  structure is rebuilt with `struct_augment`, and the spread (max - min) of
  the sites' `a` coordinate (x shift) or `b` coordinate (y shift) is measured.

  Returns `(posx, posy)`: the shift with the smallest spread on each axis, or
  -1 for an axis where no shift gives a spread below 1.
  """
  def best_shift(shift_name, coord_name):
    best, smallest = -1, 1
    for dec in np.arange(0,8,1):
      shifts = {'dec_x': 0, 'dec_y': 0}
      shifts[shift_name] = dec
      shifted = struct_augment(
        struct.copy(),
        empty_struct.copy(),
        transpose = False,
        rotation = 0,
        sandwich_flip = False,
        **shifts
        )
      values = [getattr(site, coord_name) for site in shifted]
      spread = max(values)-min(values)
      # Strict comparison: the first shift reaching the minimum is kept.
      if spread < smallest:
        best, smallest = dec, spread
    return best

  posx = best_shift('dec_x', 'a')
  posy = best_shift('dec_y', 'b')
  return posx, posy

In [None]:
# Read the split definitions from Drive; the training loop looks them up by "split_<n>".
splits = json.loads(Path("/content/gdrive/MyDrive/IDAO/IDAO_2022/splits.json").read_text())

In [None]:
path = "/content/gdrive/MyDrive/IDAO/IDAO_2022"

# Unpack both archives relative to the current working directory.
shutil.unpack_archive(f"{path}/data/dichalcogenides_public.tar.gz", "data")
shutil.unpack_archive(f"{path}/data/dichalcogenides_private.tar.gz", "data_private")

# Read the targets from the directory that was just unpacked (relative path,
# like the prepare_dataset calls) instead of a hard-coded absolute /content
# path that only matches when the working directory happens to be /content.
targets = pd.read_csv("data/dichalcogenides_public/targets.csv")
targets.rename(columns={"band_gap": "targets","_id":"id"}, inplace=True)
data_private = prepare_dataset("data_private/dichalcogenides_private", fill_holes=False, remove_common_atoms=False, mode='test')


# Build a site-free template structure: copy the first private structure and
# remove every Mo, S, W and Se site from the copy.
Mo = Composition("Mo")
S = Composition("S")
W = Composition("W")
Se = Composition("Se")
# `.iloc[0]` = first row by position; plain `[0]` on a label-indexed Series
# only works through a positional fallback.
struct = data_private.structures.iloc[0].copy()
struct.remove_species(Mo)
struct.remove_species(S)
struct.remove_species(W)
struct.remove_species(Se)
empty_struct = struct.copy()

In [None]:
# get minimal pos 

# Disabled one-off cell (guarded by `if False`): for every structure of `data`
# it calls get_minimal_dec and writes the resulting (id, dec_x, dec_y) table to
# minimal_dec_public.csv on Drive.
# NOTE(review): `data` is not defined in the visible cells -- presumably the
# public dataset frame; confirm before re-enabling this cell.
if False : 
  list_id = []
  list_dec_x = []
  list_dec_y = []
  for struct, id in tqdm(zip(data.structures, data.structures.index), position=0, total=data.shape[0]):
    x,y = get_minimal_dec(struct, empty_struct)
    list_id += [id]
    list_dec_x += [x]
    list_dec_y += [y]
  df_min_dec = pd.DataFrame({'id':list_id,'dec_x':list_dec_x,'dec_y':list_dec_y})
  df_min_dec.to_csv('/content/gdrive/MyDrive/IDAO/IDAO_2022/minimal_dec_public.csv',index=False)

In [None]:
# Load the precomputed per-structure shift tables from Drive.
# The public file is the one written by the (disabled) generation cell, with
# columns id, dec_x, dec_y.
# NOTE(review): the private file is presumably produced the same way; the code
# generating it is not visible here -- confirm its column layout.
df_min_dec_public = pd.read_csv('/content/gdrive/MyDrive/IDAO/IDAO_2022/minimal_dec_public.csv')
df_min_dec_private = pd.read_csv('/content/gdrive/MyDrive/IDAO/IDAO_2022/minimal_dec_private.csv')

# struct augmentation / reduction

In [None]:
def struct_augment(
  struct,
  empty_struct,
  dec_x = 0,
  dec_y = 0,
  transpose = False,
  rotation = 0,
  sandwich_flip = False,
  mirror_x = False,
  mirror_y = False
  ):
  """Rebuild `struct` on a 3 x 8 x 8 grid after applying grid transformations.

  The sites of `struct` are binned into a grid indexed [layer, x, y] and
  weighted by atomic number; cells left at 0 are set to 1. The grid is then
  transformed in this fixed order: periodic shift (`dec_x`, `dec_y`),
  `transpose` of the two in-plane axes, `rotation` quarter turns in the plane,
  `sandwich_flip` (half turn in the layer/x plane), `mirror_x`, `mirror_y`.

  Every cell is appended to a copy of `empty_struct` with the cell value as
  species and `idx_to_coords` as coordinates, then Mo and S sites are removed.
  Returns the new structure; `struct` and `empty_struct` are not modified here.
  """
  atomic_numbers = {"Mo": 42, "W":74, "Se":34, "S":16}

  # Rasterise the sites: one cell per (layer, x index, y index).
  positions = np.array([[site.c, site.a, site.b] for site in struct])
  weights = np.array([atomic_numbers[str(site.specie)] for site in struct])
  grid, _ = np.histogramdd(positions, bins=(3, 8, 8), weights=weights)
  grid[grid == 0] = 1

  # Periodic translation along the two in-plane axes.
  if dec_x != 0:
    grid = np.roll(grid, dec_x, axis=1)
  if dec_y != 0:
    grid = np.roll(grid, dec_y, axis=2)

  # Swap the two in-plane axes.
  if transpose:
    grid = np.transpose(grid, (0, 2, 1))

  # In-plane rotation by `rotation` quarter turns.
  if rotation != 0:
    grid = np.rot90(grid, rotation, (1, 2))

  # Half turn in the (layer, x) plane.
  if sandwich_flip:
    grid = np.rot90(grid, 2, (0, 1))

  # Mirror flips along x and/or y.
  if mirror_x:
    grid = np.flip(grid, 1)
  if mirror_y:
    grid = np.flip(grid, 2)

  # Write every cell back as a site; ndindex iterates i, then j, then k,
  # which is the same order as three nested range loops.
  rebuilt = empty_struct.copy()
  for i, j, k in np.ndindex(8, 8, 3):
    rebuilt.append(int(grid[k, i, j]), idx_to_coords(i, j, k))

  rebuilt.remove_species(Composition("Mo"))
  rebuilt.remove_species(Composition("S"))
  return rebuilt


def dataset_augment(data, df_min_dec, transposes, rotations, sandwiches, mirrors):
  """Generate augmented copies of every structure in `data`.

  Args:
    data: DataFrame with a `structures` column, indexed by structure id.
    df_min_dec: DataFrame with an `id` column followed by the x and y shift
      columns; the first row matching each id is used.
    transposes, rotations, sandwiches: values to combine (full product) for
      the `transpose`, `rotation` and `sandwich_flip` arguments of
      `struct_augment`.
    mirrors: values to combine (full product with itself) for `mirror_x` and
      `mirror_y`; these variants use no transpose/rotation/sandwich flip.

  Returns:
    DataFrame with one row per generated structure and the columns `id`,
    `structures`, `transpose`, `rotation`, `sandwich`, `mirror_x`, `mirror_y`.
    The function previously built these lists but returned nothing, although
    its callers use `.shape`, `.merge(..., on=['id'])` and `.structures`.

  NOTE(review): when the mirror values include False, the un-mirrored variant
  repeats the identity variant already produced by the first product if that
  product contains (False, 0, False) -- confirm whether this duplicate is wanted.
  """
  # Site-free template: first structure (by position) with every Mo, S, W and
  # Se site removed, so the orientation of the original cell is kept.
  empty_struct = data.structures.iloc[0].copy()
  for symbol in ("Mo", "S", "W", "Se"):
    empty_struct.remove_species(Composition(symbol))

  list_struct = []
  list_id = []
  list_transpose = []
  list_rotation = []
  list_sandwich = []
  list_mirror_x = []
  list_mirror_y = []

  def record(new_struct, struct_id, transpose, rotation, sandwich, mirror_x, mirror_y):
    # Store one augmented structure together with the parameters that produced it.
    list_struct.append(new_struct)
    list_id.append(struct_id)
    list_transpose.append(transpose)
    list_rotation.append(rotation)
    list_sandwich.append(sandwich)
    list_mirror_x.append(mirror_x)
    list_mirror_y.append(mirror_y)

  for struct, struct_id in tqdm(zip(data.structures, data.structures.index), position=0, total=data.structures.shape[0]):
    # Shifts of this structure: the two columns following `id`.
    dec_x, dec_y = df_min_dec[df_min_dec.id == struct_id].values[0, 1:3]

    for transpose in transposes:
      for rotation in rotations:
        # A transpose XOR an odd quarter-turn rotation exchanges the two
        # in-plane axes, so the shifts are exchanged as well.
        axes_swapped = bool(transpose) != (rotation == 1 or rotation == 3)
        if axes_swapped:
          new_dec_x, new_dec_y = dec_y, dec_x
        else:
          new_dec_x, new_dec_y = dec_x, dec_y

        for sandwich in sandwiches:
          new_struct = struct_augment(
                struct.copy(),
                empty_struct.copy(),
                dec_x = new_dec_x,
                dec_y = new_dec_y,
                transpose = transpose,
                rotation = rotation,
                sandwich_flip = sandwich,
                mirror_x = False,
                mirror_y = False
                )
          record(new_struct, struct_id, transpose, rotation, sandwich, False, False)

    for mirror_x in mirrors:
      for mirror_y in mirrors:
        new_struct = struct_augment(
          struct.copy(),
          empty_struct.copy(),
          dec_x = dec_x,
          dec_y = dec_y,
          transpose = False,
          rotation = 0,
          sandwich_flip = False,
          mirror_x = mirror_x,
          mirror_y = mirror_y
        )
        record(new_struct, struct_id, False, 0, False, mirror_x, mirror_y)

  return pd.DataFrame({
    'id': list_id,
    'structures': list_struct,
    'transpose': list_transpose,
    'rotation': list_rotation,
    'sandwich': list_sandwich,
    'mirror_x': list_mirror_x,
    'mirror_y': list_mirror_y,
  })

In [None]:
def get_results(preds,test, title='',save_path='',plot=True):
  """Score predictions against `test.targets` and optionally plot the errors.

  The score is `energy_within_threshold(preds, test.targets.values)` and is
  always returned. With `plot` enabled, a two-panel figure is drawn (error
  versus target with the +/-0.02 band, and the absolute-error histogram); it
  is shown when `save_path` is empty, otherwise saved as `<save_path>/res.png`.
  """
  ewt = energy_within_threshold(preds, test.targets.values)
  if not plot:
    return ewt

  plt.figure(figsize=(20,10))

  # Left panel: signed error against the target, with the tolerance band.
  plt.subplot(1,2,1)
  plt.title(f'{title} ewt : {ewt:.3f}')
  for bound in (0.02, -0.02):
    plt.plot([0,2],[bound,bound],'r', alpha=.5)
  errors = preds-test.targets
  plt.plot(test.targets,errors,'.')
  plt.xlabel('Targets')
  plt.ylabel('Error')
  plt.grid()

  # Right panel: distribution of the absolute errors.
  plt.subplot(1,2,2)
  plt.title('Error distribution')
  plt.grid()
  sns.histplot(np.abs(errors).values)

  if save_path != '':
    plt.savefig(f"{save_path}/res.png")
  else:
    plt.show()
  return ewt

In [None]:
# Training schedule shared by every split.
nb_epochs = 800
lr_reduce_patience = 150
lr_stop_patience = 500

# Train and evaluate one model per predefined split.
for i_split in range(10):
    SPLIT = f"split_{i_split}"
    train, test = prepare_dataset("data/dichalcogenides_public", split=splits[SPLIT], fill_holes=False, remove_common_atoms=False, mode='train')

    # Training set: 4 in-plane rotations x 2 sandwich flips, plus the mirror pass.
    train_aug = dataset_augment(
      data = train,
      df_min_dec = df_min_dec_public,
      transposes = [False],
      rotations = [0,1,2,3],
      sandwiches = [False,True],
      mirrors = [False]
      )
    print(train_aug.shape)
    train_aug = train_aug.merge(targets, on=['id'])
    print(train_aug.shape)

    # Test set: only shifted, no rotation or flip.
    test_aug = dataset_augment(
      data = test,
      df_min_dec = df_min_dec_public,
      transposes = [False],
      rotations = [0],
      sandwiches = [False],
      mirrors = [False]
      )
    print(test_aug.shape)
    test_aug = test_aug.merge(targets, on=['id'])
    print(test_aug.shape)

    # energy_within_threshold is a higher-is-better metric. With the default
    # mode='auto' Keras only infers 'max' from names such as accuracy, so the
    # callbacks would treat every improvement as a regression; mode='max' is
    # therefore set explicitly on both.
    # NOTE(review): the monitored quantity is the TRAINING metric; use
    # 'val_energy_within_threshold' if validation should drive the schedule.
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='energy_within_threshold',
        mode='max',
        factor=0.5, 
        verbose=1,
        patience=lr_reduce_patience,
        min_lr=0.00001
        )

    early_stopping = tf.keras.callbacks.EarlyStopping(
      monitor='energy_within_threshold', mode='max', min_delta=0, patience=lr_stop_patience, verbose=1
      )

    # NOTE(review): every other Drive path in this notebook is under
    # MyDrive/IDAO/IDAO_2022 -- confirm that MyDrive/IDAO_2022 is intended here.
    model, dirname = prepare_model(split=SPLIT,
                                   cutoff=20,
                                   nfeat_bond=150,
                                   gaussian_width=.5,
                                   npass=3,
                                   nblocks=4,
                                   n1=64,
                                   n2=32,
                                   n3=32, 
                                   embedding_dim=16,
                                   dropout=.2,
                                   lr=5e-4,
                                   output_dir="/content/gdrive/MyDrive/IDAO_2022/new_callbacks",
                                   seed=42,
                                   prefix='rot_dwich'
                                   )
    print(dirname)
    model.train(train_aug.structures,
                train_aug.targets,
                validation_structures=test_aug.structures,
                validation_targets=test_aug.targets,
                epochs=nb_epochs,
                batch_size=128,
                dirname=dirname,
                patience=120,
                callbacks=[reduce_lr, early_stopping])
    preds = model.predict_structures(test_aug.structures)
    ewt = get_results(preds[:,0], test_aug, 'test', dirname, plot=True)

    # Release the model before building the next one.
    del model
    gc.collect()