In [7]:
#!/usr/bin/env python

import sys
# import os
import numpy as np
import pandas as pd
import mongo
import time
import pickle
import math
import torch
from torch.optim import Adam, SGD
# from torch.utils.data import Dataset, DataLoader
import multiprocess as mp
# import mongo
from cgcnn.data import StructureData, ListDataset, StructureDataTransformer
import tqdm
# from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import ShuffleSplit
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from skorch.callbacks import Checkpoint, LoadInitState 
from cgcnn.data import collate_pool, MergeDataset
from cgcnn.model import CrystalGraphConvNet
from skorch import NeuralNetRegressor
import skorch.callbacks.base
from skorch.dataset import CVSplit
from skorch.callbacks.lr_scheduler import WarmRestartLR, LRScheduler
from adamwr.adamw import AdamW
from adamwr.cosine_scheduler import CosineLRWithRestarts

def round_(n, decimals=0):
    '''
    Python can't round for jack. We use someone else's home-brew to do
    rounding. Credit goes to David Amos
    (<https://realpython.com/python-rounding/#rounding-half-up>).
    '''
    multiplier = 10 ** decimals
    return math.floor(n*multiplier + 0.5) / multiplier

def get_surface_from_doc(doc):
    '''
    Some of our functions parse by "surface", which we identify by mpid, Miller
    index, shift, and whether it's on the top or bottom of the slab. This
    helper function parses an aggregated/projected Mongo document for you and
    gives you back a tuple that contains these surface identifiers.

    Arg:
        doc     A Mongo document (dictionary) that contains the keys 'mpid',
                'miller', 'shift', and 'top'.
    Returns:
        surface A 4-tuple whose elements are the mpid, Miller index, shift, and
                a Boolean indicating whether the surface is on the top or
                bottom of the slab. Note that the Miller indices will be
                formatted as a string, and the shift will be rounded to 2
                decimal places.
    '''
    surface = (doc['mpid'], str(doc['miller']), round_(doc['shift'], 2), doc['top'])
    return surface


def get_docs_file(dataset, num_docs):
    start = time.time()
#    print('Shyam get_docs_file 1')
    docs_all = pickle.load(open(dataset, 'rb'))
#    print('Shyam get_docs_file 2')
    total_docs = len(docs_all)
    for doc in docs_all:
        doc['surface'] = get_surface_from_doc(doc)
#    print('Shyam get_docs_file 3')
    docs = docs_all[:num_docs]
#    print('Shyam get_docs_file 4')
    target_list = np.array([doc['energy'] for doc in docs]).reshape(-1,1)
#    target_file = open('./input/target_list.pkl','wb')
  #  pickle.dump(target_list, target_file)
#    print('Shyam get_docs_file 5')
    end = time.time()
    Docs_time = end - start
    print('Documents loaded')
    return docs, Docs_time, total_docs

## Shyam's function ##
def make_SDT_list(dataset):
    SDT = StructureDataTransformer(atom_init_loc='./input/atom_init.json',
                                  max_num_nbr=12,
                                   step=0.2,
                                  radius=1,
                                  use_tag=False,
                                  use_fixed_info=False,
                                  use_distance=True,
                                  train_geometry = 'final-adsorbate'
                                  )

    SDT_out = SDT.transform(dataset)

    structures = SDT_out[0]

    #Settings necessary to build the model (since they are size of vectors as inputs)
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    SDT_out = SDT.transform(dataset)
#    print(len(SDT_out))
#    SDT_lists = list()
    with mp.Pool(4) as pool:
        SDT_list = list(tqdm.tqdm(pool.imap(lambda x: SDT_out[x],range(len(SDT_out)),chunksize=40),total=len(SDT_out)))

    with open('./input/SDT_list.pkl','wb') as file_out:
        pickle.dump(SDT_list, file_out)
        
    print('Shyam make SDT list pickle file created')


def get_SDT_list(dataset):
    start = time.time()
    SDT = StructureDataTransformer(atom_init_loc='./input/atom_init.json',
                                  max_num_nbr=12,
                                   step=0.2,
                                  radius=1,
                                  use_tag=False,
                                  use_fixed_info=False,
                                  use_distance=True,
                                  train_geometry = 'final-adsorbate'
                                  )

    SDT_out = SDT.transform(dataset)

    structures = SDT_out[0]

    #Settings necessary to build the model (since they are size of vectors as inputs)
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    SDT_out = SDT.transform(dataset)

    with mp.Pool(4) as pool:
        SDT_list = list(tqdm.tqdm(pool.imap(lambda x: SDT_out[x],range(len(SDT_out)),chunksize=40),total=len(SDT_out)))
        
    end = time.time()
    SDT_time = end-start
    print('SDT list created')
    return SDT_time

def get_device():
    cuda = torch.cuda.is_available()
    if cuda:
        device = torch.device("cuda")
    else:
        device='cpu'

    return device

def shuffle(SDT_list, target_list):
    indices = np.arange(len(SDT_list))
#    indices = len(SDT_list)
#    indices = np.arange(SDT_list)
#    print('Shyam shuffle 1')
    print(indices)
    SDT_training, SDT_test, target_training, target_test, train_idx, test_idx = \
    train_test_split(SDT_list, target_list,indices, test_size=0.2, random_state=42)
    print('Shyam shuffle end')
    return SDT_training, SDT_test, target_training, target_test

def training(device, num_epochs, SDT_training, target_training):
    structures = SDT_training[0]
    print('Shyam %s'%type(structures))
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    print('Shyam training 1')
#   train_test_splitter = ShuffleSplit(test_size=0.25, random_state=42)
    train_test_splitter = ShuffleSplit(n_splits= 1, test_size=0.25, train_size = 0.25)

    
    # warm restart scheduling from https://arxiv.org/pdf/1711.05101.pdf
#    LR_schedule = LRScheduler(CosineLRWithRestarts, batch_size=214, epoch_size=len(SDT_training), \
#                              restart_period=10, t_mult=1.2)    
    LR_schedule = LRScheduler(CosineLRWithRestarts, batch_size=10, epoch_size=len(SDT_training), \
                              restart_period=10, t_mult=1.2)
    
    #Make a checkpoint to save parameters every time there is a new best for validation lost
    cp = Checkpoint(monitor='valid_loss_best',fn_prefix='valid_best_')
    #Callback to load the checkpoint with the best validation loss at the end of training
    class train_end_load_best_valid_loss(skorch.callbacks.base.Callback):
        def on_train_end(self, net, X, y):
            net.load_params('valid_best_params.pt')
    load_best_valid_loss = train_end_load_best_valid_loss()
    
    print('Shyam training 2')
    
    # To extract intermediate features, set the forward takes only the first return value to calculate loss
    class MyNet(NeuralNetRegressor):
        def get_loss(self, y_pred, y_true, **kwargs):
            y_pred = y_pred[0] if isinstance(y_pred, tuple) else y_pred  # discard the 2nd output
            return super().get_loss(y_pred, y_true, **kwargs)
    print('Shyam training 3')    
    
    '''
    net = MyNet(
        CrystalGraphConvNet,
        module__orig_atom_fea_len = orig_atom_fea_len,
        module__nbr_fea_len = nbr_fea_len,
        batch_size=214,
        module__classification=False,
        lr=0.0056,
        max_epochs= num_epochs, 
        module__atom_fea_len=46,
        module__h_fea_len=83,
        module__n_conv=8, #8
        module__n_h=4,
        optimizer__weight_decay=1e-5,
        optimizer=AdamW,
        iterator_train__pin_memory=True,
        iterator_train__num_workers=0,
        iterator_train__collate_fn = collate_pool,
        iterator_valid__pin_memory=True,
        iterator_valid__num_workers=0,
        iterator_valid__collate_fn = collate_pool,
        device=device,
    #     criterion=torch.nn.MSELoss,
        criterion=torch.nn.L1Loss,
        dataset=MergeDataset,
        train_split = CVSplit(cv=train_test_splitter),
        callbacks=[cp, load_best_valid_loss, LR_schedule]
    )
    '''

#    data = MergeDataset(SDT_training, target_training)
    data = MergeDataset

    net = MyNet(
        CrystalGraphConvNet,
        module__orig_atom_fea_len = orig_atom_fea_len,
        module__nbr_fea_len = nbr_fea_len,
        batch_size=10,
        module__classification=False,
        lr=0.0056,
        max_epochs= num_epochs, 
        module__atom_fea_len=46,
        module__h_fea_len=83,
        module__n_conv=8, #8
        module__n_h=4,
        optimizer__weight_decay=1e-5,
        optimizer=AdamW,
        iterator_train__pin_memory=True,
        iterator_train__num_workers=0,
        iterator_train__collate_fn = collate_pool,
        iterator_valid__pin_memory=True,
        iterator_valid__num_workers=0,
        iterator_valid__collate_fn = collate_pool,
        device=device,
    #     criterion=torch.nn.MSELoss,
        criterion=torch.nn.L1Loss,
        dataset=data,
        train_split = CVSplit(cv=train_test_splitter),
        callbacks=[cp, load_best_valid_loss, LR_schedule]
    )    

    print('Shyam training 4')
    
    start = time.time()
    net.initialize()
    print('Shyam training 5 -- net done')

#    print('Shyam training 5')

    net.fit(SDT_training, target_training)
    print('Shyam training end')
    end = time.time()
    training_time = end-start
    return net, train_test_splitter, training_time


def prediction(dataset, SDT_list, target_list, num_docs, num_SDT, num_epochs, device):
    docs, Docs_time, total_docs = get_docs_file(dataset, num_docs)
    SDT_time = get_SDT_list(docs)
    print('SDT_list and target_list are loaded')

#    SDT_list = pickle.load(open(SDT_list, 'rb'))
#    target_list = pickle.load(open(target_list, 'rb'))

    SDT_list = SDT_list[:num_SDT]
    target_list = target_list[:num_SDT]
    
 #   print('Shyam prediction 2')

    docs, Docs_time, total_docs = get_docs_file(dataset, num_docs)
    SDT_time = get_SDT_list(docs)
    print('Shyam prediction 3')
    
    SDT_training, SDT_test, target_training, target_test = shuffle(SDT_list, target_list)
    net, train_test_splitter, training_time = training(device, num_epochs, SDT_training, target_training)    
    train_indices, valid_indices = next(train_test_splitter.split(SDT_training))
    print('Shyam prediction 4')
    train_error = mean_absolute_error(target_training[train_indices].reshape(-1), 
                                      net.predict(SDT_training)[train_indices].reshape(-1))
    val_error = mean_absolute_error(target_training[valid_indices].reshape(-1), 
                                      net.predict(SDT_training)[valid_indices].reshape(-1))
    test_error = mean_absolute_error(target_test.reshape(-1), 
                                      net.predict(SDT_test).reshape(-1))
    start = time.time()
    
    measure_pred_time = mean_absolute_error(target_list[:num_docs].reshape(-1), 
                                      net.predict(SDT_list[:num_docs]).reshape(-1))
    end = time.time()
    pred_time = end - start
    times = (Docs_time, SDT_time, training_time, pred_time)
    errors = (train_error, val_error, test_error)
    return errors, times, SDT_training, total_docs

def figure_of_merit(dataset, SDT_list, target_list, num_docs, num_SDT, num_epochs):
    device = get_device()
    print('Shyam figure of merit 1')
    errors, times, SDT_training, total_docs = prediction(dataset, SDT_list, target_list, num_docs, num_SDT, num_epochs, device)
    print('Shyam figure of merit 2')
    Docs_time, SDT_time, training_time, pred_time = times
    print('Shyam figure of merit 3')
    train_error, val_error, test_error = errors
    print('\n')   
    print('BENCHMARK RESULTS')
    print('Current device:', device)
    print('Time to load %d documents: %f seconds\n' %(total_docs, Docs_time))
    print('Time to convert %d documents into SDT list: %f seconds' %(num_docs, SDT_time))
    print('Time to train the model using %d training examples for %d epochs: %f seconds' %(len(SDT_training), num_epochs, training_time))
    print('Training error: %f ev'  %train_error)
    print('Validation error: %f ev'  %val_error)
    print('Test error: %f ev'  %test_error)
    print('Time to predict energy for %d documents: %f seconds' %(num_docs, pred_time))
    with open('result.md', 'w') as f:
        f.writelines('# Benchmark Test \n\n')
        f.writelines('## Measurement on Edison\n\n')
        f.writelines('```C\n') 
        f.writelines('sbatch run_benchmark.sh \n')
        f.writelines('```\n\n')
        f.writelines('## Benchmark Results \n')
        if device == 'cpu':
            f.writelines('Current device:' + device + '\n')
        else:
            f.writelines('Current device: ' + device.type + '\n')
        f.writelines('\nTime to load %d documents: %f seconds\n' %(total_docs, Docs_time))
        f.writelines('\nTime to convert %d documents into SDT list: %f seconds\n' %(num_docs, SDT_time))
        f.writelines('\nTime to train the model using %d training examples: %f seconds\n' %(len(SDT_training), training_time))
        f.writelines('\nTraining error: %f ev\n'  %train_error)
        f.writelines('\nValidation error: %f ev\n'  %val_error)
        f.writelines('\nTest error: %f ev\n'  %test_error)
        f.writelines('\nTime to predict energy for %d documents: %f seconds\n' %(num_docs, pred_time))
    return


In [8]:
num_docs=10
num_SDT=5

# Read dataset from file
file = open('./input/mat_10.pkl','rb')
datas = pickle.load(file)

# Write target energies to file
target_array = np.array([data['energy'] for data in datas]).reshape(-1, 1)
print(target_array)
with open('./input/target_list.pkl','wb') as target_file:
    pickle.dump(target_array, target_file)

# Write SDT file
make_SDT_list(datas)

#b = get_docs_file(dataset, num_docs)


dataset='./input/mat_10.pkl'
SDT_list='./input/SDT_list.pkl'
target_list='./input/target_list.pkl'

#num_SDT=20771
num_epochs=2
figure_of_merit(dataset, SDT_list, target_list, num_docs, num_SDT, num_epochs)

[[-0.49401431]
 [-0.21384117]
 [-0.01293161]
 [ 0.09524271]
 [-1.03625676]
 [ 0.0541775 ]
 [ 0.00765951]
 [ 0.36665157]
 [ 0.33331164]
 [-0.61652263]]


100%|██████████| 10/10 [00:33<00:00,  3.31s/it]


Shyam make SDT list pickle file created
Shyam figure of merit 1
Documents loaded


100%|██████████| 10/10 [00:48<00:00,  4.82s/it]


SDT list created
SDT_list and target_list are loaded
Documents loaded


100%|██████████| 10/10 [01:06<00:00,  6.63s/it]

SDT list created
Shyam prediction 3
[0 1 2 3 4]
Shyam shuffle end
Shyam <class 'str'>





AttributeError: 'str' object has no attribute 'shape'