Merged
88 changes: 50 additions & 38 deletions pyhdx/fitting.py
@@ -285,7 +285,7 @@ def check_bounds(fit_result):


def run_optimizer(inputs, output_data, optimizer_klass, optimizer_kwargs, model, criterion, regularizer,
epochs=EPOCHS, patience=PATIENCE, stop_loss=STOP_LOSS, tqdm=True):
epochs=EPOCHS, patience=PATIENCE, stop_loss=STOP_LOSS, callbacks=None, tqdm=True):
"""

Runs optimization/fitting of PyTorch model.
@@ -311,6 +311,10 @@ def run_optimizer(inputs, output_data, optimizer_klass, optimizer_kwargs, model,
Number of epochs with less progress than `stop_loss` before terminating optimization
stop_loss : :obj:`float`
Threshold of optimization value below which no progress is made
callbacks : :obj:`list` or :obj:`None`
List of callback functions, each called as cb(epoch, model, optimizer)
tqdm : :obj:`bool`
Toggle tqdm progress bar

Returns
-------
@@ -323,57 +327,61 @@ def run_optimizer(inputs, output_data, optimizer_klass, optimizer_kwargs, model,
np.random.seed(43)
torch.manual_seed(43)

mse_loss_list = [np.inf]
total_loss_list = [np.inf]
callbacks = callbacks or []
losses_list = [[np.inf]]

def closure():
output = model(*inputs)
loss = criterion(output, output_data)
mse_loss_list.append(loss.detach())
reg_loss = regularizer(model.deltaG)
total_loss = loss + reg_loss
total_loss_list.append(total_loss.detach()) # total_loss.item()
total_loss.backward()
return total_loss
losses_list.append([loss.item()]) # store mse loss
reg_loss_tuple = regularizer(model.deltaG)
for r in reg_loss_tuple:
loss += r

losses_list[-1] += [r.item() for r in reg_loss_tuple] # store reg losses

loss.backward()
return loss

stop = 0
iter = trange(epochs) if tqdm else range(epochs)
for epoch in iter:
optimizer_obj.zero_grad()
loss = optimizer_obj.step(closure)

diff = total_loss_list[-2] - total_loss_list[-1]
for cb in callbacks:
cb(epoch, model, optimizer_obj)

diff = sum(losses_list[-2]) - sum(losses_list[-1])
if diff < stop_loss:
stop += 1
if stop > patience:
break
else:
stop = 0

#par = model.deltaG.detach().numpy()
return np.array(mse_loss_list[1:]), np.array(total_loss_list[1:]), model
return np.array(losses_list[1:]), model
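
The loop now invokes every callback once per epoch, right after `optimizer_obj.step(closure)`. A minimal sketch of a conforming callback, assuming only the `cb(epoch, model, optimizer)` signature documented above (the class name and its `every` parameter are illustrative):

```python
class PrintProgress:
    """Illustrative callback: report mean deltaG every `every` epochs."""

    def __init__(self, every=100):
        self.every = every

    def __call__(self, epoch, model, optimizer):
        # model is the module being fitted; deltaG is its parameter tensor
        if epoch % self.every == 0:
            g = model.deltaG.detach()
            print(f"epoch {epoch}: mean deltaG = {g.mean().item():.3f}")
```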


def regularizer_1d(r1, param):
reg_loss = r1 * torch.mean(torch.abs(param[:-1] - param[1:]))
return reg_loss * REGULARIZATION_SCALING
return reg_loss * REGULARIZATION_SCALING,
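
The trailing comma is deliberate: each regularizer now returns a tuple of per-term scalar losses instead of a pre-summed value, so the closure in `run_optimizer` can add the terms to the loss and log them individually. A quick sketch of the convention, assuming `regularizer_1d` is importable from `pyhdx.fitting`:

```python
import torch

from pyhdx.fitting import regularizer_1d

param = torch.linspace(0.0, 1.0, 5, requires_grad=True)
reg_terms = regularizer_1d(0.5, param)  # a 1-tuple, thanks to the trailing comma

assert isinstance(reg_terms, tuple) and len(reg_terms) == 1
loss = sum(reg_terms)  # run_optimizer adds each term to the data loss
loss.backward()
```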


def regularizer_2d_mean(r1, r2, param):
#todo allow regularization wrt reference rather than mean
#param shape: Ns x Nr x 1
d_ax1 = torch.abs(param[:, :-1, :] - param[:, 1:, :])
d_ax2 = torch.abs(param - torch.mean(param, axis=0))
reg_loss = r1 * torch.mean(d_ax1) + r2 * torch.mean(d_ax2)
return reg_loss * REGULARIZATION_SCALING

return r1 * torch.mean(d_ax1) * REGULARIZATION_SCALING, r2 * torch.mean(d_ax2) * REGULARIZATION_SCALING


def regularizer_2d_reference(r1, r2, param):
#todo allow regularization wrt reference rather than mean
d_ax1 = torch.abs(param[:, :-1, :] - param[:, 1:, :])
d_ax2 = torch.abs(param - param[0])[1:]
reg_loss = r1 * torch.mean(d_ax1) + r2 * torch.mean(d_ax2)
return reg_loss * REGULARIZATION_SCALING

return r1 * torch.mean(d_ax1) * REGULARIZATION_SCALING, r2 * torch.mean(d_ax2) * REGULARIZATION_SCALING


def regularizer_2d_aligned(r1, r2, indices, param):
@@ -382,26 +390,23 @@ def regularizer_2d_aligned(r1, r2, indices, param):
d_ax1 = torch.abs(param[:, :-1, :] - param[:, 1:, :])
d_ax2 = torch.abs(param[0][i0] - param[1][i1])

reg_loss = r1 * torch.mean(d_ax1) + r2 * torch.mean(d_ax2)
return reg_loss * REGULARIZATION_SCALING
return r1 * torch.mean(d_ax1) * REGULARIZATION_SCALING, r2 * torch.mean(d_ax2) * REGULARIZATION_SCALING


def _loss_df(total_loss, mse_loss):
loss_dict = {
'total_loss': total_loss,
'mse_loss': mse_loss}
loss_dict['reg_loss'] = loss_dict['total_loss'] - loss_dict['mse_loss']
loss_dict['reg_percentage'] = loss_dict['reg_loss'] / loss_dict['total_loss'] * 100
def _loss_df(losses_array):
"""transforms losses array to losses dataframe
first column in losses array is mse loss, rest are regularzation losses
"""

loss_df = pd.DataFrame(loss_dict)
loss_df = pd.DataFrame(losses_array, columns=['mse_loss'] + [f'reg_{i + 1}' for i in range(losses_array.shape[1] - 1)])
loss_df.index.name = 'epoch'
loss_df.index += 1

return loss_df
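
With one column per term, downstream code no longer has to subtract the mse loss from a total to recover the regularization contribution. A sketch of the resulting frame, with made-up loss values:

```python
import numpy as np

# three epochs; one mse column plus two regularization columns
losses_array = np.array([
    [0.90, 0.10, 0.05],
    [0.70, 0.08, 0.04],
    [0.55, 0.07, 0.03],
])
df = _loss_df(losses_array)
print(df)
#        mse_loss  reg_1  reg_2
# epoch
# 1          0.90   0.10   0.05
# 2          0.70   0.08   0.04
# 3          0.55   0.07   0.03
```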


def fit_gibbs_global(hdxm, initial_guess, r1=R1, epochs=EPOCHS, patience=PATIENCE, stop_loss=STOP_LOSS,
optimizer='SGD', **optimizer_kwargs):
optimizer='SGD', callbacks=None, **optimizer_kwargs):
"""
Fit Gibbs free energies globally to all D-uptake data in the supplied hdxm

@@ -415,6 +420,8 @@ def fit_gibbs_global(hdxm, initial_guess, r1=R1, epochs=EPOCHS, patience=PATIENC
patience
stop_loss
optimizer : :obj:`str`
callbacks : :obj:`list` or :obj:`None`
List of callback objects. Call signature is cb(epoch, model, optimizer)
optimizer_kwargs

Returns
@@ -450,18 +457,18 @@ def fit_gibbs_global(hdxm, initial_guess, r1=R1, epochs=EPOCHS, patience=PATIENC
reg_func = partial(regularizer_1d, r1)

# returned_model is the same object as model
mse_loss, total_loss, returned_model = run_optimizer(inputs, output_data, optimizer_klass, optimizer_kwargs,
losses_array, returned_model = run_optimizer(inputs, output_data, optimizer_klass, optimizer_kwargs,
model, criterion, reg_func, epochs=epochs,
patience=patience, stop_loss=stop_loss)
losses = _loss_df(total_loss, mse_loss)
patience=patience, stop_loss=stop_loss, callbacks=callbacks)
losses = _loss_df(losses_array)
fit_kwargs.update(optimizer_kwargs)
result = TorchSingleFitResult(hdxm, model, losses=losses, **fit_kwargs)

return result
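
A hypothetical call wiring in the new `callbacks` argument; `hdxm` and `initial_guess` stand for an already-prepared `HDXMeasurement` and guess values, and `CheckPoint` is the callback class added to `pyhdx.fitting_torch` further down this diff:

```python
from pyhdx.fitting_torch import CheckPoint

checkpoint = CheckPoint(epoch_step=250)  # snapshot the model every 250 epochs
result = fit_gibbs_global(hdxm, initial_guess, r1=1, epochs=1000,
                          callbacks=[checkpoint])
print(result.losses.tail())  # per-epoch 'mse_loss' and 'reg_1' columns
```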


def fit_gibbs_global_batch(hdx_set, initial_guess, r1=R1, r2=R2, r2_reference=False, epochs=EPOCHS, patience=PATIENCE, stop_loss=STOP_LOSS,
optimizer='SGD', **optimizer_kwargs):
optimizer='SGD', callbacks=None, **optimizer_kwargs):
"""
Batch fit Gibbs free energies to multiple HDX measurements

@@ -477,6 +484,8 @@ def fit_gibbs_global_batch(hdx_set, initial_guess, r1=R1, r2=R2, r2_reference=Fa
patience
stop_loss
optimizer
callbacks : :obj:`list` or :obj:`None`
List of callback objects. Call signature is cb(epoch, model, optimizer)
optimizer_kwargs

Returns
@@ -485,7 +494,7 @@ def fit_gibbs_global_batch(hdx_set, initial_guess, r1=R1, r2=R2, r2_reference=Fa
"""
# todo still some repeated code with fit_gibbs single

fit_keys = ['r1', 'r2', 'r2_reference', 'epochs', 'patience', 'stop_loss', 'optimizer']
fit_keys = ['r1', 'r2', 'r2_reference', 'epochs', 'patience', 'stop_loss', 'optimizer', 'callbacks']
locals_dict = locals()
fit_kwargs = {k: locals_dict[k] for k in fit_keys}

@@ -497,8 +506,8 @@ def fit_gibbs_global_batch(hdx_set, initial_guess, r1=R1, r2=R2, r2_reference=Fa
return _batch_fit(hdx_set, initial_guess, reg_func, fit_kwargs, optimizer_kwargs)
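
The batch variant takes the same `callbacks` list; a hypothetical call, with `hdx_set` standing for a prepared `HDXMeasurementSet`:

```python
checkpoint = CheckPoint(epoch_step=500)
result = fit_gibbs_global_batch(hdx_set, initial_guess, r1=1, r2=1,
                                callbacks=[checkpoint])
print(result.losses.columns)  # mse_loss plus one column per regularization term
```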


def fit_gibbs_global_batch_aligned(hdx_set, initial_guess, r1=R1, r2=R2, epochs=EPOCHS, patience=PATIENCE, stop_loss=STOP_LOSS,
optimizer='SGD', **optimizer_kwargs):
def fit_gibbs_global_batch_aligned(hdx_set, initial_guess, r1=R1, r2=R2, epochs=EPOCHS, patience=PATIENCE,
stop_loss=STOP_LOSS, optimizer='SGD', callbacks=None, **optimizer_kwargs):
"""
Batch fit Gibbs free energies to two HDX measurements. The supplied HDXMeasurementSet must have alignment information
(supplied by HDXMeasurementSet.add_alignment)
@@ -513,6 +522,8 @@ def fit_gibbs_global_batch_aligned(hdx_set, initial_guess, r1=R1, r2=R2, epochs=
epochs
patience
stop_loss
callbacks : :obj:`list` or :obj:`None`
List of callback objects. Call signature is cb(epoch, model, optimizer)
optimizer
optimizer_kwargs

@@ -528,7 +539,7 @@ def fit_gibbs_global_batch_aligned(hdx_set, initial_guess, r1=R1, r2=R2, epochs=
indices = [torch.tensor(i, dtype=torch.long) for i in hdx_set.aligned_indices]
reg_func = partial(regularizer_2d_aligned, r1, r2, indices)

fit_keys = ['r1', 'r2', 'epochs', 'patience', 'stop_loss', 'optimizer']
fit_keys = ['r1', 'r2', 'epochs', 'patience', 'stop_loss', 'optimizer', 'callbacks']
locals_dict = locals()
fit_kwargs = {k: locals_dict[k] for k in fit_keys}

@@ -554,9 +565,10 @@ def _batch_fit(hdx_set, initial_guess, reg_func, fit_kwargs, optimizer_kwargs):
optimizer_klass = getattr(torch.optim, fit_kwargs['optimizer'])

loop_kwargs = {k: fit_kwargs[k] for k in ['epochs', 'patience', 'stop_loss']}
mse_loss, total_loss, returned_model = run_optimizer(inputs, output_data, optimizer_klass, optimizer_kwargs,
loop_kwargs['callbacks'] = fit_kwargs.pop('callbacks')
losses_array, returned_model = run_optimizer(inputs, output_data, optimizer_klass, optimizer_kwargs,
model, criterion, reg_func, **loop_kwargs)
losses = _loss_df(total_loss, mse_loss)
losses = _loss_df(losses_array)
fit_kwargs.update(optimizer_kwargs)
result = TorchBatchFitResult(hdx_set, model, losses=losses, **fit_kwargs)

54 changes: 47 additions & 7 deletions pyhdx/fitting_torch.py
@@ -1,12 +1,13 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
import torch as t
from scipy import constants
from copy import deepcopy

import numpy as np
import pandas as pd
from pyhdx.models import Protein
import torch as t
import torch.nn as nn
from scipy import constants

from pyhdx.fileIO import dataframe_to_file
from pyhdx.models import Protein


class DeltaGFit(nn.Module):
@@ -100,7 +101,7 @@ def mse_loss(self):
@property
def total_loss(self):
"""obj:`float`: Total loss value of the Lagrangian"""
total_loss = self.losses['total_loss'].iloc[-1]
total_loss = self.losses.iloc[-1].sum()
return float(total_loss)

@property
@@ -224,3 +225,42 @@ def __call__(self, timepoints):

output = self.model(*inputs)
return output.detach().numpy()


class Callback(object):

def __call__(self, epoch, model, optimizer):
pass


class CheckPoint(Callback):

def __init__(self, epoch_step=1000):
self.epoch_step = epoch_step
self.model_history = {}

def __call__(self, epoch, model, optimizer):
if epoch % self.epoch_step == 0:
self.model_history[epoch] = deepcopy(model.state_dict())

def to_dataframe(self, names=None, field='deltaG'):
"""convert history of `field` into dataframe.
names must be given for batch fits with length equal to number of states

"""
entry = next(iter(self.model_history.values()))
g = entry[field]
if g.ndim == 3:
num_states = entry[field].shape[0] # G shape is Ns x Nr x 1
if names is None or len(names) != num_states:
raise ValueError(f"Number of names provided must be equal to number of states ({num_states})")

dfs = []
for i in range(num_states):
df = pd.DataFrame({k: v[field].numpy()[i].squeeze() for k, v in self.model_history.items()})
dfs.append(df)
full_df = pd.concat(dfs, keys=names, axis=1)
else:
full_df = pd.DataFrame({k: v[field].numpy().squeeze() for k, v in self.model_history.items()})

return full_df
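
A sketch of the two output shapes, assuming `checkpoint` was populated during one of the fits above: for single fits the columns are the checkpointed epochs; for batch fits they become a (state, epoch) MultiIndex. The state names are placeholders:

```python
# single fit: rows are residues, columns are checkpointed epochs
history = checkpoint.to_dataframe()

# batch fit over two states: columns become a (state, epoch) MultiIndex
history = checkpoint.to_dataframe(names=['state_A', 'state_B'])
final_deltaG = history['state_A'].iloc[:, -1]  # last checkpoint of one state
```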
9 changes: 7 additions & 2 deletions pyhdx/web/apps.py
@@ -8,7 +8,8 @@
from pyhdx import VERSION_STRING
from pyhdx.web.base import STATIC_DIR
from pyhdx.web.sources import DataFrameSource
from pyhdx.web.transforms import RescaleTransform, RemoveValueTransform, ApplyCmapTransform, PeptideLayoutTransform, ResetIndexTransform
from pyhdx.web.transforms import RescaleTransform, RemoveValueTransform, ApplyCmapTransform, PeptideLayoutTransform, ResetIndexTransform, \
AccumulateRegularizersTransform
from pyhdx.web.opts import CmapOpts
from pyhdx.web.filters import UniqueValuesFilter, MultiIndexSelectFilter
import logging
@@ -286,11 +287,15 @@ def main_app(client='default'):
filters=[filters['losses_fit_id'], filters['losses_state_name']]
)
view_list.append(losses)

accumulate_reg_transform = AccumulateRegularizersTransform(name='accumulate_regularizers')
trs_list.append(accumulate_reg_transform)

opts = {'color': 'r', **opts}
reg_losses = hvPlotAppView(
source=source, name='reg_losses', x='index', y='reg_loss', kind='line',
table='losses', streaming=True, responsive=True, opts=opts, label='reg',
transforms=[reset_index_transform_loss],
transforms=[reset_index_transform_loss, accumulate_reg_transform],
filters=[filters['losses_fit_id'], filters['losses_state_name']]
)
view_list.append(reg_losses)
17 changes: 17 additions & 0 deletions pyhdx/web/transforms.py
@@ -49,6 +49,23 @@ def apply(self, table):
return table


class AccumulateRegularizersTransform(Transform):
"""
Very niche and temporary transform to accumulate regularization losses into one column
"""

transform_type = 'accumulate_regularizers'

def apply(self, table):
# first two columns are the index and mse_loss; remaining columns are per-term reg losses
reg_total = table.iloc[:, 2:].sum(axis=1)
reg_total.name = 'reg_loss'

result = pd.concat([table.iloc[:, :2], reg_total], axis=1)

return result
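
A toy check of the transform, using a table shaped like the assumed losses layout (an index column, then `mse_loss`, then any number of `reg_*` columns):

```python
import pandas as pd

table = pd.DataFrame({
    'index': [1, 2],
    'mse_loss': [0.90, 0.70],
    'reg_1': [0.10, 0.08],
    'reg_2': [0.05, 0.04],
})
out = AccumulateRegularizersTransform(name='acc').apply(table)
# out columns: index, mse_loss, reg_loss  (reg_1 + reg_2 summed row-wise)
```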


class ResetIndexTransform(Transform):

level = param.ClassSelector(class_=(int, list, str), doc="""
2 changes: 1 addition & 1 deletion pyhdx/web/views.py
@@ -18,7 +18,7 @@
class hvPlotAppView(hvPlotView):

def get_data(self):

# get data filter using pandas query syntax?
try:
data = super().get_data()
except (KeyError, ValueError) as e: