Skip to content

Commit

Permalink
assort. fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fiona-naughton committed Aug 2, 2016
1 parent 67ffe1f commit 41f322d
Showing 1 changed file with 112 additions and 97 deletions.
209 changes: 112 additions & 97 deletions package/MDAnalysis/wham.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,6 @@
import datreant.core as dtr



def check_bundle_metadata(bundle, expected):
""" check each simulation in bundle has the expected metadata """
common_metadata = bundle.categories.keys()
for meta in expected:
if meta not in common_metadata:
raise ValueError("Not all simulations contain metadata {}."
"(Common metadata: {})".format(meta, common_metadata))
## TODO - also check if all the values are of the expected type?


def check_bundle_auxiliaries(bundle, expected):
""" check each simulation in bundle has expected auxiliary """
# when auxiliaries added to mdsynthesis, this might become more direct
for aux in expected:
for sim in bundle:
if aux not in sim.universe.trajectory.aux_list:
raise ValueError("Simulation {} does not contain auxiliary data"
" {}".format(sim.name, aux))
## TODO - also check it's got the right length/type?


## TODO - NAMING. Currently named in line with docs for Grossfield wham, but
## some of these aren't very clear/nice so likely to change...
# TODO - should probably split this up
Expand All @@ -38,10 +16,10 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
periodicity='', num_bins=200, tol=1e-6, numpad=0,
run_bootstrap=True, num_MC_trials=200,
start_time=None, end_time=None,
energy_units='kcal', keep_file=False):
energy_units='kcal', keep_files=False):
""" Wrapper for the Grossfield implementation of WHAM.
<link documentation>
[link documentation]
Each simulation must have the appropriate metadata (spring constant,
restrained value, temperature) and auxiliary data (timeseries data as
Expand All @@ -52,13 +30,8 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
be provided. If any names differ from the default values they must be
specified.
!! --> [[Currently doesn't add energy, so will only work when all simulations
use the same temperature]]
Various wham paramaters, etc etc [TBA]
Various wham paramaters can be specified [] + additional options [TBC...]
UNITS - ***
Paramerers
----------
Expand All @@ -69,7 +42,7 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
assuming biasing potential has form 1/2 k(x-x0)^2. [[<--format]]
[[Some simulation packages don't use the 1/2 - so have to pass 2*the
restraint const instead?]].
Must match the units used for energy + the reaction coordinate. [examples].
Must match the units used for energy + the reaction coordinate. [examples?].
loc_win_min : str, optional
Name of metadata field storing the reaction coordinate value that is
the minimum of the biasing potential in each window, ie x0 above.
Expand All @@ -83,14 +56,12 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
Name of the auxiliary containing the force/reaction coordinate value
throughout simulation.
timeseries_type : str, optional
What value is recorded in timeseries_data:
coord - value of the reaction coordinate x
force - value of the restraining force F = -kx
delta - difference between reaction coordiante and restrained
value x-x0
What value is recorded in timeseries_data; for available options see
``calc_reaction_cood``.
energy : str, optional
Name of the auxiliary containing potential energy. [At the same steps as
timeseries_data above]. [I assume units must match energy_units?]. Only
Name of the auxiliary containing potential energy. [Currently must be
at exactly the same same steps as timeseries above (ie same dt and
initial time)]. [I assume units must match energy_units?]. Only
required if simulations performed at different temperatures.
calc_temperature : float, optional
Temperature at which to perform wham calculation (in Kelvin). If not
Expand All @@ -102,7 +73,7 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
coordinate.
perodicity : str, optional
Periodicity of system. Default ('') indicates a nonperiodic reaction
coordiante. [[Currently passed straight on to wham, will likely change]]
coordiante. [[Currently passed straight on to wham, should change]]
num_bins : int, optional
Number of bins to use in histogram (= number of points in final profile).
tol : float, optional
Expand All @@ -126,13 +97,11 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
if not isinstance(bundle, dtr.Bundle):
TypeError('{} is not a bundle'.format(bundle))

if timeseries_type not in ['force', 'coord', 'delta']:
# TODO - check this later so don't have the possible options >1 place?
raise ValueError("Timeseries type should be 'force', 'coord', or "
"'distance' (see docs...)")
# TODO - temp catch timeseries_type is valid; will make this nicer
calc_reaction_coord(1, timeseries_type, 1 ,1)

if hist_min is not None and hist_max is not None:
if float(hist_min) >= float(hist_max): ## check float?
if float(hist_min) >= float(hist_max): ## check floatable?
raise ValueError('hist_min {} is greater than hist_max {}'.format(
hist_min, hist_max))

Expand All @@ -153,6 +122,7 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
if correl_time:
metadata.append(correl_time)

check_bundle_metadata(bundle, metadata)
# check if all simulations are at the same temperature...
temps = bundle.categories[temperature]
if all(t == temps[0] for t in temps):
Expand All @@ -173,8 +143,6 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
'performed at different temperatures')
auxiliaries.append(energy)

# check all the simulations have the metadata+auxiliary names specified
check_bundle_metadata(bundle, metadata)
check_bundle_auxiliaries(bundle, auxiliaries)

# TODO - check values for other options are valid (num_bins, numpad,
Expand All @@ -185,81 +153,69 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
# TODO - if keeping files, allow to specifiy file names/directory
# will the default file names always be valid? Particularly if sticking
# Sim name in it... (could just use index)
timeseriesfile_root = 'timeseries_{sim}.dat'
timeseriesfile_root = 'timeseries_{}.dat'
metadatafile = 'metadatafile.dat'
outfile = 'outfile.dat' # called freefile in Grossfield wham docs

#### WRITE INPUT FILES
## TODO - this is currently somewhat nasty, and currently only works
## for adding the force/coord data, not energy as well; will have
## TODO - getting the aux values is currently a bit nasty, will have
## a fiddle with auxiliary stuff to hopefully make this nicer...
with open(metadatafile, 'w') as meta_file:
global_min_val = None
global_max_val = None
passed_sims=[] #keep track of which simulations we actually feed through
# to wham (so we don't try run with none...)
for sim in bundle:
timeseries_file = timeseriesfile_root.format(sim=sim.name)
timeseries_file = timeseriesfile_root.format(sim.name)
k = sim.categories[spring]
x0 = sim.categories[loc_win_min]

# figure out the time range.
aux = sim.universe.trajectory._auxs[timeseries_data]
# figure out the time range. Assuming energy will be recorded at same points.
data = sim.universe.trajectory._auxs[timeseries_data]
step = 0
if start_time is None:
start_step = 0
else:
while aux.step_to_time(step) < start_time:
while data.step_to_time(step) < start_time:
step = step+1
if step == len(aux):
if step == len(data):
break
start_step = step
if end_time is None:
end_step = len(aux)
end_step = len(data)
else:
while aux.step_to_time(step) < end_time:
while data.step_to_time(step) < end_time:
step = step+1
if step == len(aux):
if step == len(data):
break
end_step = step
if start_step == len(aux) and end_time is not None:
if start_step == len(data) and end_time is not None:
warnings.warn('Simulation {} will be skipped (no data before '
'end_time ({} ps)'.format(sim.name, end_time))
elif end_step == 0 and start_time is not None:
warnings.warn('Simulation {} will be skipped (no data after '
'start_time ({} ps)'.format(sim.name, start_time))
else:
if end_step == len(aux):
end_step = end_step-1 #temp because errors with len(aux) - need to fix
if end_step == len(data):
end_step = end_step-1 # TODO temp because errors with len(aux) - need to fix
# write the timeseries file...
max_val = None
min_val = None
# assuming for now that we have energy/'timeseries data' at
# the exact same set of points.
with open(timeseries_file, 'w') as data_file:
for auxstep in aux[start_step:end_step]:
## IF IT'S MULTITEMP WE ALSO NEED AN ENERGY AUX - HAVE TO
## COORDINATE THE TWO SOMEHOW...
if timeseries_type == 'coord':
x = auxstep.data[0]
elif timeseries_type == 'force':
# F = -kx; does the missing 1/2 factor matter here?
x = -auxstep.data[0]/k
elif timeseries_type == 'delta':
# delta_x = x-x0
x = auxstep.data[0] + x0
data_file.write(str(auxstep.time)+' '+str(x)+'\n')

min_val = (x if min_val is None else x if x<min_val
else min_val)
max_val = (x if max_val is None else x if x>max_val
else max_val)


global_min_val = (min_val if global_min_val is None
else min_val if min_val < global_min_val
else min_val)
global_max_val = (max_val if global_max_val is None
else max_val if max_val > global_max_val
else max_val)
if multi_temp:
ener = sim.universe.trajectory._auxs[energy]
ener_data = [i.data[0] for i in ener[start_step:end_step]]
for i, as_data in enumerate(data[start_step:end_step]):
x = calc_reaction_coord(as_data.data[0],
timeseries_type, k, x0)
min_val = update_min(x, min_val)
max_val = update_max(x, max_val)
line = [as_data.time, x]
if multi_temp:
line.append(ener_data[i])
data_file.write(list_to_string(line)+'\n')

if hist_max is not None and min_val > float(hist_max):
warnings.warn('Simulation {} will be skipped (minimum value {} '
Expand All @@ -281,9 +237,12 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
metafile_info = [timeseries_file, x0, k, correl]
if multi_temp:
metafile_info.append(sim.categories[temperature])
meta_file.write(' '.join([str(i) for i in metafile_info])+'\n')
meta_file.write(list_to_string(metafile_info)+'\n')
passed_sims.append(sim.name)

global_min_val = update_min(min_val, global_min_val)
global_max_val = update_max(max_val, global_max_val)

if len(passed_sims) == 0:
raise ValueError('Aborting (all simulations skipped). Try increasing '
'time or reaction coordinate range.')
Expand All @@ -301,28 +260,84 @@ def wham(bundle, spring='spring', loc_win_min='loc_win_min',
randSeed = 1 # TODO - how to deal with random seed - should make an argument?
if run_bootstrap:
wham_args=wham_args+[num_MC_trials, randSeed]
os.system(wham_command+' '+' '.join([str(i) for i in wham_args]))
## TODO - switch to subprocess; check exit code - catch any errors?
os.system(wham_command+' '+list_to_string(wham_args))
## TODO - switch to subprocess; [+ catch any errors etc]


#### PARSE OUTPUT FILE
outfile = np.genfromtxt(outfile)
outfiledata = np.genfromtxt(outfile)
if not run_bootstrap:
profile = outfile[:,:2]
profile = outfiledata[:,:2]
if run_bootstrap:
profile = outfile[:,:3]
profile = outfiledata[:,:3]

#### CLEANUP FILES
#if not keep_files:
# # TODO - best way to remove files? (all in a temp directory?)
# os.remove(metadatafile)
# os.remove(outfile)
# for sim in passed_sims:
# os.remove(timeseriesfile_root.format(sim)) # or use a wildcard
if not keep_files:
# TODO - best way to remove files? (all in a temp directory?)
os.remove(metadatafile)
os.remove(outfile)
for sim in passed_sims:
os.remove(timeseriesfile_root.format(sim)) # or use a wildcard

####
return profile
# TODO - in the outfile we also get the probability + it's error in [:,3]
# and [:,4]; and the 'F-values' for each simulation (that we didn't skip);
# option to get prob instead of PMF? option to return Fvalues as well (tuple?)




def calc_reaction_coord(value, value_type, k, x0):
""" Calculate value of reaction coordinate corresponding to *value*.
Calculate the reaction coordinate at a particular time point from *value*,
depending on *value_type*. Currently allowed types are:
- coord: reaction coordinate value
- force: value of the restraining force; harmonic potential is assumed
so x = -F/k + x0
- delta: difference in reaction coord value and minimum of restraining
potential; x = delta_x + x0
[...]
"""
if value_type == 'coord':
x = value
elif value_type == 'force':
# F = -k delta_x; does the missing 1/2 factor matter here?
x = -value/k + x0
elif value_type == 'delta':
# delta_x = x-x0
x = value + x0
else:
raise ValueError('{} is not a valid timeseries data type'.format(value_type))
return x

def check_bundle_metadata(bundle, expected):
""" check each simulation in bundle has the expected metadata """
common_metadata = bundle.categories.keys()
for meta in expected:
if meta not in common_metadata:
raise ValueError("Not all simulations contain metadata {}."
"(Common metadata: {})".format(meta, common_metadata))
## TODO - also check if all the values are of the expected type?


def check_bundle_auxiliaries(bundle, expected):
""" check each simulation in bundle has expected auxiliary """
# when auxiliaries added to mdsynthesis, this might become more direct
for aux in expected:
for sim in bundle:
if aux not in sim.universe.trajectory.aux_list:
raise ValueError("Simulation {} does not contain auxiliary data"
" {}".format(sim.name, aux))
## TODO - also check it's got the right length/type?

def list_to_string(lst):
return ' '.join([str(i) for i in lst])

def update_min(new, curr):
return new if curr is None else new if new < curr else curr

def update_max(new, curr):
return new if curr is None else new if new > curr else curr

0 comments on commit 41f322d

Please sign in to comment.