Skip to content

Commit

Permalink
Merge pull request #136 from ReactionMechanismGenerator/scaling_factors
Browse files Browse the repository at this point in the history
Added a ZPE/freq scaling factor script to utils
  • Loading branch information
alongd committed Jun 23, 2019
2 parents b3b0cf8 + 722acfa commit 2658051
Show file tree
Hide file tree
Showing 35 changed files with 2,950 additions and 9,955 deletions.
2 changes: 2 additions & 0 deletions arc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import arc_exceptions
import main
from main import ARC
import common
import parser
import plotter
import processor
import scheduler
import settings
import utils

import species
import job
254 changes: 254 additions & 0 deletions arc/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
#!/usr/bin/env python
# encoding: utf-8

"""
ARC's common module
This module should not import any other module that has logging to avoid circular imports.
"""

from __future__ import (absolute_import, division, print_function, unicode_literals)
import logging
import warnings
import os
import sys
import time
import datetime
import shutil
import subprocess
import yaml

from arc.settings import arc_path, servers
from arc.arc_exceptions import InputError, SettingsError

##################################################################

# Module-level logger registered under the name 'arc'; it is configured by
# initialize_log() and handed out by get_logger().
logger = logging.getLogger('arc')

# ARC's version string; log_header() prints it in the log banner.
VERSION = '1.0.0'


def read_file(path):
    """
    Load an ARC YAML input file.

    Args:
        path (str, unicode): The path of the YAML input file.

    Returns:
        dict: The object parsed from the file by ``yaml.load``
              (a dictionary when the file's top level is a YAML mapping).

    Raises:
        InputError: If ``path`` does not point to an existing file.
    """
    if os.path.isfile(path):
        # NOTE(review): FullLoader is not meant for untrusted documents; this assumes
        # the input file is written by the user running ARC - confirm.
        with open(path, 'r') as yaml_file:
            return yaml.load(stream=yaml_file, Loader=yaml.FullLoader)
    raise InputError('Could not find the input file {0}'.format(path))


def time_lapse(t0):
    """
    A helper function returning the elapsed time since t0.

    Args:
        t0 (float): The initial time the count starts from, in seconds since the epoch
                    (i.e., a value previously obtained from ``time.time()``).

    Returns:
        str, unicode: A "HH:MM:SS" formatted time difference between now and t0,
                      prefixed with "D days, " when at least one full day has elapsed.
    """
    # Round the total once and continue with integers. Rounding each field separately
    # could print e.g. '00:00:60' for 59.7 s, and a float day count printed as '1.0 days, '.
    total_seconds = int(round(time.time() - t0))
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    prefix = '{0} days, '.format(days) if days > 0 else ''
    return '{0}{1:02d}:{2:02d}:{3:02d}'.format(prefix, hours, minutes, seconds)


def check_ess_settings(ess_settings=None):
    """
    Normalize and validate ARC's ESS settings dictionary.

    Every value is converted to a list of server names (a single server name given as
    a string is wrapped in a list) so that a job could be troubleshooted by trying a
    different server. Every key (the ESS software name) is stored in lower case,
    regardless of whether its value was given as a string or as a list.

    Args:
        ess_settings (dict, optional): ARC's ESS settings dictionary. Keys are ESS software names,
                                       values are either a server name or a list of server names.

    Returns:
        dict: An updated ARC ESS dictionary (empty if ``ess_settings`` is None or empty).

    Raises:
        SettingsError: If a value is neither a string nor a list of strings,
                       if the software is not a recognized ESS,
                       or if a server name is not a key of the ``servers`` settings dictionary.
    """
    if not ess_settings:
        return dict()
    recognized_ess = ['gaussian', 'qchem', 'molpro', 'onedmin', 'orca']
    settings = dict()
    for software, server_list in ess_settings.items():
        if isinstance(server_list, (str, unicode)):
            server_list = [server_list]
        elif isinstance(server_list, list):
            for server in server_list:
                if not isinstance(server, (str, unicode)):
                    raise SettingsError('Server name could only be a string. '
                                        'Got {0} which is {1}'.format(server, type(server)))
        else:
            raise SettingsError('Servers in the ess_settings dictionary could either be a string or a list of '
                                'strings. Got: {0} which is a {1}'.format(server_list, type(server_list)))
        # Lower-case the key in both branches so the returned dictionary is keyed consistently.
        settings[software.lower()] = server_list
    # run checks:
    for ess, server_list in settings.items():
        if ess not in recognized_ess:
            raise SettingsError('Recognized ESS software are Gaussian, QChem, Molpro, Orca or OneDMin. '
                                'Got: {0}'.format(ess))
        for server in server_list:
            # Every entry is guaranteed to be a string at this point (validated above).
            if server.lower() not in servers.keys():
                server_names = list(servers.keys())
                raise SettingsError('Recognized servers are {0}. Got: {1}'.format(server_names, server))
    logger.info('\nUsing the following ESS settings:\n{0}\n'.format(settings))
    return settings


def initialize_log(log_file, project, project_directory=None, verbose=logging.INFO):
    """
    Set up a logger for ARC.

    Any handler already attached to the ARC logger is detached and closed, an existing
    log file is archived (when a project directory is given), and fresh console and file
    handlers are attached. Finally the ARC header is written to the log.

    Args:
        log_file (str, unicode): The log file name.
        project (str, unicode): A name for the project.
        project_directory (str, unicode, optional): The path to the project directory. If given and
                                                    ``log_file`` already exists, the old log is copied
                                                    into its 'log_and_restart_archive' sub-folder
                                                    and then deleted.
        verbose (int, optional): The logging level applied to the logger and to both handlers.
    """
    # Detach AND close handlers left over from a previous initialization. Closing releases the
    # old log file descriptor (otherwise it leaks on every re-initialization) and makes sure the
    # old log file is no longer held open when it is copied and removed below.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # backup and delete an existing log file if needed
    if project_directory is not None and os.path.isfile(log_file):
        archive_directory = os.path.join(project_directory, 'log_and_restart_archive')
        if not os.path.isdir(archive_directory):
            os.mkdir(archive_directory)
        local_time = datetime.datetime.now().strftime("%H%M%S_%b%d_%Y")
        log_backup_name = 'arc.old.' + local_time + '.log'
        shutil.copy(log_file, os.path.join(archive_directory, log_backup_name))
        os.remove(log_file)

    logger.setLevel(verbose)
    # Do not pass records on to ancestor loggers' handlers.
    logger.propagate = False

    # Use custom level names for cleaner log output
    logging.addLevelName(logging.CRITICAL, 'Critical: ')
    logging.addLevelName(logging.ERROR, 'Error: ')
    logging.addLevelName(logging.WARNING, 'Warning: ')
    logging.addLevelName(logging.INFO, '')
    logging.addLevelName(logging.DEBUG, '')
    logging.addLevelName(0, '')

    # Create formatter and add to handlers
    formatter = logging.Formatter('%(levelname)s%(message)s')

    # Create console handler; send everything to stdout rather than stderr
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(verbose)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # Create file handler
    fh = logging.FileHandler(filename=log_file)
    fh.setLevel(verbose)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    log_header(project=project)

    # ignore Paramiko and cclib warnings:
    warnings.filterwarnings(action='ignore', module='.*paramiko.*')
    warnings.filterwarnings(action='ignore', module='.*cclib.*')
    # Keep warnings out of the logging system (they are not redirected to the log).
    logging.captureWarnings(capture=False)


def get_logger():
    """
    Return the module-level ARC logger, so that other modules share this
    single logger instead of creating their own.

    Returns:
        logging.Logger: The logger registered under the name 'arc'.
    """
    return logger


def log_header(project, level=logging.INFO):
    """
    Output a header containing identifying information about ARC to the log.

    The header consists of the execution start time, a boxed banner with ARC's name and
    version, the current git HEAD (hash, date and, if not 'master', the branch name) when
    it can be determined, and the project name.

    Args:
        project (str, unicode): The ARC project name to be logged in the header.
        level: The desired logging level for the header lines
               (the final 'Starting project' line is always logged at the INFO level).
    """
    border = '###############################################################'
    inner_width = len(border) - 2  # the room between the two '#' frame characters

    def boxed(text=''):
        """Return ``text`` framed by '#' characters and padded to the border's width."""
        return '#{0}#'.format(text.ljust(inner_width))

    # Every banner row is padded to the border's width, so the box stays rectangular
    # whatever the length of VERSION is.
    banner = [
        '',
        border,
        boxed(),
        boxed('Automatic Rate Calculator'.center(inner_width)),
        boxed('ARC'.center(inner_width)),
        boxed(),
        boxed('   Version: {0}'.format(VERSION)),
        boxed(),
        border,
        '',
    ]
    logger.log(level, 'ARC execution initiated on {0}'.format(time.asctime()))
    for row in banner:
        logger.log(level, row)

    # Extract HEAD git commit from ARC
    head, date = get_git_commit()
    branch_name = get_git_branch()
    if head != '' and date != '':
        logger.log(level, 'The current git HEAD for ARC is:')
        logger.log(level, '    {0}\n    {1}'.format(head, date))
    if branch_name and branch_name != 'master':
        logger.log(level, '    (running on the {0} branch)\n'.format(branch_name))
    else:
        logger.log(level, '\n')
    logger.info('Starting project {0}'.format(project))


def get_git_commit():
    """
    Get the recent git commit to be logged.

    Note:
        Returns empty strings if hash and date cannot be determined, i.e., if ARC's
        directory has no '.git' entry, if running git fails, or if git's output does
        not consist of exactly two lines.

    Returns:
        str, unicode: The git HEAD commit hash.
        str, unicode: The git HEAD commit date.
    """
    if not os.path.exists(os.path.join(arc_path, '.git')):
        return '', ''
    try:
        output = subprocess.check_output(['git', 'log', '--format=%H%n%cd', '-1'], cwd=arc_path)
    except (subprocess.CalledProcessError, OSError):
        return '', ''
    # check_output returns bytes; decode so callers always get text they can format.
    lines = output.decode('utf-8', 'replace').splitlines()
    if len(lines) != 2:
        # Callers unpack exactly two values; never hand them anything else.
        return '', ''
    return lines[0], lines[1]


def get_git_branch():
    """
    Get the git branch to be logged.

    Note:
        Returns an empty string if ARC's directory has no '.git' entry, if running git
        fails, or if git does not mark any branch as the current one.

    Returns:
        str, unicode: The git branch name.
    """
    if not os.path.exists(os.path.join(arc_path, '.git')):
        return ''
    try:
        output = subprocess.check_output(['git', 'branch'], cwd=arc_path)
    except (subprocess.CalledProcessError, OSError):
        return ''
    # check_output returns bytes; decode before searching for the '*' marker.
    for line in output.decode('utf-8', 'replace').splitlines():
        if '*' in line:
            # The current branch is listed as '* <name>'; drop the two-character marker.
            return line[2:]
    # No branch is marked as current: honor the documented string return type.
    return ''


def log_footer(execution_time, level=logging.INFO):
    """
    Write the closing lines of an ARC run to the log: an empty line,
    the total execution time, and the termination time stamp.

    Args:
        execution_time (str, unicode): The overall execution time for ARC.
        level: The desired logging level.
    """
    def emit(message):
        """Log a single footer line at the requested level."""
        logger.log(level, message)

    emit('')
    emit('Total execution time: {0}'.format(execution_time))
    emit('ARC execution terminated on {0}'.format(time.asctime()))
88 changes: 88 additions & 0 deletions arc/commonTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
This module contains unit tests for ARC's common module
"""

from __future__ import (absolute_import, division, print_function, unicode_literals)
import unittest
import os
import time

from arc.common import read_file, get_git_commit, time_lapse, check_ess_settings
from arc.settings import arc_path, servers
from arc.arc_exceptions import InputError, SettingsError

################################################################################


class TestARC(unittest.TestCase):
    """
    Contains unit tests for ARC's common module
    """

    def test_read_file(self):
        """Test the read_file() function"""
        restart_path = os.path.join(arc_path, 'arc', 'testing', 'restart(H,H2O2,N2H3,CH3CO2).yml')
        input_dict = read_file(restart_path)
        self.assertIsInstance(input_dict, dict)
        # assertIn reports the missing key and the container on failure.
        for key in ['reactions', 'freq_level', 'use_bac', 'ts_guess_level', 'running_jobs']:
            self.assertIn(key, input_dict)

        with self.assertRaises(InputError):
            read_file('nopath')

    def test_get_git_commit(self):
        """Test the get_git_commit() function"""
        git_commit = get_git_commit()
        # output format: ['fafdb957049917ede565cebc58b29899f597fb5a', 'Fri Mar 29 11:09:50 2019 -0400']
        self.assertEqual(len(git_commit[0]), 40)
        self.assertEqual(len(git_commit[1].split()), 6)

    def test_time_lapse(self):
        """Test the time_lapse() function"""
        # Shift the start time instead of sleeping: the test is instantaneous and
        # does not depend on how long the process was actually suspended.
        t0 = time.time() - 2
        lap = time_lapse(t0)
        self.assertEqual(lap, '00:00:02')

    def test_check_ess_settings(self):
        """Test the check_ess_settings function"""
        # Materialize the keys so they can be indexed (a keys view is not indexable).
        # NOTE(review): assumes the servers dictionary defines at least two servers
        # and a server named 'local' - confirm against arc.settings.
        server_names = list(servers.keys())
        ess_settings1 = {'gaussian': [server_names[0]], 'molpro': [server_names[1], server_names[0]],
                         'qchem': [server_names[0]]}
        ess_settings2 = {'gaussian': server_names[0], 'molpro': server_names[1], 'qchem': server_names[0]}
        ess_settings3 = {'gaussian': server_names[0], 'molpro': [server_names[1], server_names[0]],
                         'qchem': server_names[0]}
        ess_settings4 = {'gaussian': server_names[0], 'molpro': server_names[1], 'qchem': server_names[0]}
        ess_settings5 = {'gaussian': 'local', 'molpro': server_names[1], 'qchem': server_names[0]}

        ess_list = [check_ess_settings(ess_settings) for ess_settings in
                    [ess_settings1, ess_settings2, ess_settings3, ess_settings4, ess_settings5]]

        for ess in ess_list:
            for soft, server_list in ess.items():
                self.assertIn(soft, ['gaussian', 'molpro', 'qchem'])
                self.assertIsInstance(server_list, list)

        with self.assertRaises(SettingsError):
            ess_settings6 = {'nosoft': ['server1']}
            check_ess_settings(ess_settings6)
        with self.assertRaises(SettingsError):
            ess_settings7 = {'gaussian': ['noserver']}
            check_ess_settings(ess_settings7)

################################################################################


if __name__ == '__main__':
    # Run this module's tests with per-test (verbose) reporting when executed as a script.
    verbose_runner = unittest.TextTestRunner(verbosity=2)
    unittest.main(testRunner=verbose_runner)

0 comments on commit 2658051

Please sign in to comment.