Universal Solvers #190

Draft · wants to merge 1 commit into master
5 changes: 5 additions & 0 deletions neurodiffeq/conditions.py
@@ -262,6 +262,11 @@ def parameterize(self, output_tensor, t):
        :rtype: `torch.Tensor`
        """
        if self.u_0_prime is None:
+           if isinstance(self.u_0, list):
+               parameterized = torch.zeros_like(output_tensor)
+               for i in range(len(self.u_0)):
+                   parameterized[:, i] = (self.u_0[i] + (1 - torch.exp(-t + self.t_0)) * output_tensor[:, i].view(-1, 1))[:, 0]
+               return parameterized
            return self.u_0 + (1 - torch.exp(-t + self.t_0)) * output_tensor
        else:
            return self.u_0 + (t - self.t_0) * self.u_0_prime + ((1 - torch.exp(-t + self.t_0)) ** 2) * output_tensor
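With this change, a single `IVP` condition can carry a vector of initial values, one per output column. A minimal sketch of the behavior at `t = t_0` (the concrete numbers below are illustrative, not from this PR):

    import torch
    from neurodiffeq.conditions import IVP

    cond = IVP(t_0=0.0, u_0=[1.0, -2.0])   # one initial value per output column
    t = torch.zeros(5, 1)                  # evaluate at t = t_0
    raw = torch.randn(5, 2)                # unconstrained network output
    u = cond.parameterize(raw, t)
    print(u)                               # every row is [1., -2.]: the (1 - exp(-(t - t_0)))
                                           # factor vanishes at t = t_0, leaving only u_0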
312 changes: 307 additions & 5 deletions neurodiffeq/solvers.py
@@ -20,7 +20,7 @@
from .generators import Generator2D
from .generators import GeneratorND
from .function_basis import RealSphericalHarmonics
-from .conditions import BaseCondition
+from .conditions import BaseCondition, NoCondition
+from itertools import chain  # needed by the universal-solver heads added below
from .neurodiffeq import safe_diff as diff
from .losses import _losses

@@ -113,7 +113,7 @@ class BaseSolver(ABC, PretrainedSolver):
    def __init__(self, diff_eqs, conditions,
                 nets=None, train_generator=None, valid_generator=None, analytic_solutions=None,
                 optimizer=None, loss_fn=None, n_batches_train=1, n_batches_valid=4,
-                metrics=None, n_input_units=None, n_output_units=None,
+                metrics=None, n_input_units=None, n_output_units=None, system_parameters=None,
                 # deprecated arguments are listed below
                 shuffle=None, batch_size=None):
        # deprecate argument `shuffle`
@@ -130,6 +130,9 @@ def __init__(self, diff_eqs, conditions,
        )

        self.diff_eqs = diff_eqs
+       self.system_parameters = {}
+       if system_parameters is not None:
+           self.system_parameters = system_parameters
        self.conditions = conditions
        self.n_funcs = len(conditions)
        if nets is None:
@@ -376,7 +379,7 @@ def closure(zero_grad=True):
                for name in self.metrics_fn:
                    value = self.metrics_fn[name](*funcs, *batch).item()
                    metric_values[name] += value
-               residuals = self.diff_eqs(*funcs, *batch)
+               residuals = self.diff_eqs(*funcs, *batch, **self.system_parameters)
                residuals = torch.cat(residuals, dim=1)
                try:
                    loss = self.loss_fn(residuals, funcs, batch) + self.additional_loss(residuals, funcs, batch)
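With this change, any keyword parameters stored on the solver are forwarded to the residual function. A minimal sketch of how that is meant to be used (the oscillator equation and its `omega` parameter are illustrative assumptions, not code from this PR):

    import torch
    from neurodiffeq import diff
    from neurodiffeq.conditions import IVP
    from neurodiffeq.solvers import Solver1D

    # harmonic oscillator u'' + omega^2 u = 0, with omega supplied per solver
    def oscillator(u, t, omega=1.0):
        return [diff(u, t, order=2) + omega ** 2 * u]

    solver = Solver1D(
        ode_system=oscillator,
        conditions=[IVP(t_0=0.0, u_0=0.0, u_0_prime=1.0)],
        t_min=0.0, t_max=2.0,
        system_parameters={'omega': 2.0},  # residuals computed as oscillator(u, t, omega=2.0)
    )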
@@ -1105,7 +1108,7 @@ class Solver1D(BaseSolver):

    def __init__(self, ode_system, conditions, t_min=None, t_max=None,
                 nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
-                loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
+                loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1, system_parameters=None,
                 # deprecated arguments are listed below
                 batch_size=None, shuffle=None):

@@ -1136,6 +1139,7 @@ def __init__(self, ode_system, conditions, t_min=None, t_max=None,
            metrics=metrics,
            n_input_units=1,
            n_output_units=n_output_units,
+           system_parameters=system_parameters,
            shuffle=shuffle,
            batch_size=batch_size,
        )
@@ -1164,11 +1168,12 @@ def get_solution(self, copy=True, best=True):
        :rtype: BaseSolution
        """
        nets = self.best_nets if best else self.nets
        conditions = self.conditions
        if copy:
            nets = deepcopy(nets)
            conditions = deepcopy(conditions)

        return Solution1D(nets, conditions)

    def _get_internal_variables(self):
@@ -1590,3 +1595,300 @@ def _get_internal_variables(self):
            'xy_max': self.xy_max,
        })
        return available_variables

class _SingleSolver1D(GenericSolver):
    """Internal solver: trains one solution head (per initial condition / parameter set) on top of shared base networks."""

    class Head(nn.Module):
        """A base network followed by a trainable linear output layer."""

        def __init__(self, u_0, base, n_input, n_output=1):
            super().__init__()
            self.u_0 = u_0
            self.base = base
            self.last_layer = nn.Linear(n_input, n_output)

        def forward(self, x):
            x = self.base(x)
            x = self.last_layer(x)
            return x

    def __init__(self, bases, HeadClass, initial_conditions, n_last_layer_head, diff_eqs,
                 system_parameters=None,
                 optimizer=torch.optim.Adam, optimizer_args=None, optimizer_kwargs={"lr": 1e-3},
                 train_generator=None, valid_generator=None, n_batches_train=1, n_batches_valid=4,
                 loss_fn=None, metrics=None, is_system=False):

        if train_generator is None or valid_generator is None:
            raise ValueError("train_generator and valid_generator cannot be None")

        self.num = len(initial_conditions)
        self.bases = bases
        # use the default Head unless a custom head class is provided
        if HeadClass is None:
            HeadClass = self.Head
        if is_system:
            # one single-output head per equation, each on its own base network
            self.head = [HeadClass(initial_conditions[i], self.bases[i], n_last_layer_head) for i in range(self.num)]
        else:
            # a single multi-output head on one shared base network
            self.head = [HeadClass(torch.Tensor(initial_conditions).view(1, -1), self.bases, n_last_layer_head, len(initial_conditions))]

        self.optimizer_args = optimizer_args or ()
        self.optimizer_kwargs = optimizer_kwargs or {}

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        elif isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer):
            # optimize all head (and, transitively, base) parameters jointly
            params = chain.from_iterable(n.parameters() for n in self.head)
            self.optimizer = optimizer(params, *self.optimizer_args, **self.optimizer_kwargs)
        else:
            raise TypeError(f"Unknown optimizer instance/type {optimizer}")

        super().__init__(
            diff_eqs=diff_eqs,
            conditions=[NoCondition()] * self.num,
            train_generator=train_generator,
            valid_generator=valid_generator,
            nets=self.head,
            system_parameters=system_parameters,
            optimizer=self.optimizer,
            n_batches_train=n_batches_train,
            n_batches_valid=n_batches_valid,
            loss_fn=loss_fn,
            metrics=metrics,
        )

    def additional_loss(self, residuals, funcs, coords):
        # soft initial-condition penalty: evaluate each head at t = 0 (assumes t_0 = 0)
        # and penalize its deviation from the stored initial value u_0
        loss = 0
        for net in self.nets:
            out = net(torch.zeros((1, 1)))
            loss += ((net.u_0 - out) ** 2).mean()
        return loss
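
# Since the conditions are all `NoCondition`, nothing is hard-constrained: the total
# training objective of each _SingleSolver1D is the usual residual loss from BaseSolver
# plus the soft initial-condition penalty above; schematically,
#     L = MSE(diff_eqs(u, t), 0) + sum_i (head_i(0) - u_0_i) ** 2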


class UniversalSolver1D(ABC):
    r"""A solver class for solving a family of ODEs, i.e. the same system under different
    initial conditions and system parameters.

    :param diff_eqs:
        The ODE system to solve, which maps a torch.Tensor to a tuple of ODE residuals;
        both the input and output must have shape (n_samples, 1).
    :type diff_eqs: callable
    :param is_system:
        Whether ``diff_eqs`` is a system of equations (one base/head per equation). Defaults to True.
    :type is_system: bool
    """

    class Base(nn.Module):
        """Default base network: a fully-connected 1-10-10-10 net with tanh activations."""

        def __init__(self):
            super().__init__()
            self.linear_1 = nn.Linear(1, 10)
            self.linear_2 = nn.Linear(10, 10)
            self.linear_3 = nn.Linear(10, 10)

        def forward(self, x):
            x = torch.tanh(self.linear_1(x))
            x = torch.tanh(self.linear_2(x))
            x = torch.tanh(self.linear_3(x))
            return x

    def __init__(self, diff_eqs, is_system=True):
        self.diff_eqs = diff_eqs
        self.is_system = is_system

        self.t_min = None
        self.t_max = None
        self.train_generator = None
        self.valid_generator = None

    def build(self, u_0s=None,
              system_parameters=[{}],
              BaseClass=Base,
              HeadClass=None,
              n_last_layer_head=10,
              build_source=False,
              optimizer=torch.optim.Adam,
              optimizer_args=None, optimizer_kwargs={"lr": 1e-3},
              t_min=None,
              t_max=None,
              train_generator=None,
              valid_generator=None,
              n_batches_train=1,
              n_batches_valid=4,
              loss_fn=None,
              metrics=None):

r"""
:param system_parameters:
List of dictionaries of parameters for which the solver will be trained
:type system_parameters: list[dict]
:param BaseClass:
Neural network class for base networks
:type nets: torch.nn.Module
:param n_last_layer_head:
Number of neurons in the last layer for each network
:type n_last_layer_head: int
:param build_source:
Boolean value for training the base networks or freezing their weights
:type build_source: bool
:param optimizer:
Optimizer to be used for training.
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param t_min:
Lower bound of input (start time).
Ignored if ``train_generator`` and ``valid_generator`` are both set.
:type t_min: float, optional
:param t_max:
Upper bound of input (start time).
Ignored if ``train_generator`` and ``valid_generator`` are both set.
:type t_max: float, optional
:param train_generator:
Generator for sampling training points,
which must provide a ``.get_examples()`` method and a ``.size`` field.
``train_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
:type train_generator: `neurodiffeq.generators.BaseGenerator`, optional
:param valid_generator:
Generator for sampling validation points,
which must provide a ``.get_examples()`` method and a ``.size`` field.
``valid_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
:type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
:type n_batches_train: int, optional
:param n_batches_valid:
Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
Defaults to 4.
:type n_batches_valid: int, optional
:param loss_fn:
The loss function used for training.

- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
- If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
- If any other callable, it must map
A) a residual tensor (shape `(n_points, n_equations)`),
B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.

:type loss_fn:
str or `torch.nn.moduesl.loss._Loss` or callable
:param metrics:
Additional metrics to be logged (besides loss). ``metrics`` should be a dict where

- Keys are metric names (e.g. 'analytic_mse');
- Values are functions (callables) that computes the metric value.
These functions must accept the same input as the differential equation ``ode_system``.

:type metrics: dict[str, callable], optional
"""

        self.u_0s = u_0s
        self.system_parameters = system_parameters
        self.n_last_layer_head = n_last_layer_head

        if t_min is not None:
            self.t_min = t_min
        if t_max is not None:
            self.t_max = t_max

        # default generators over [t_min, t_max]; explicitly passed generators take precedence
        if self.t_min is not None and self.t_max is not None:
            self.train_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')
            self.valid_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')

        if train_generator is not None:
            self.train_generator = train_generator
        if valid_generator is not None:
            self.valid_generator = valid_generator

        if self.u_0s is None:
            raise ValueError("ICs (`u_0s`) must be specified")
        if self.train_generator is None or self.valid_generator is None:
            raise ValueError(
                "Train and valid generators cannot be None. Either provide `t_min` and `t_max` "
                "or pass the generators explicitly."
            )

        self.optimizer = optimizer
        self.optimizer_args = optimizer_args or ()
        self.optimizer_kwargs = optimizer_kwargs or {}

        if build_source:
            # fresh base networks: one per equation for systems, a single shared base otherwise
            if self.is_system:
                self.bases = [BaseClass() for _ in range(len(u_0s[0]))]
            else:
                self.bases = BaseClass()

            self.solvers_base = [
                _SingleSolver1D(
                    bases=self.bases,
                    HeadClass=HeadClass,
                    initial_conditions=self.u_0s[i],
                    n_last_layer_head=n_last_layer_head,
                    diff_eqs=self.diff_eqs,
                    train_generator=self.train_generator,
                    valid_generator=self.valid_generator,
                    system_parameters=self.system_parameters[p],
                    optimizer=optimizer, optimizer_args=optimizer_args, optimizer_kwargs=optimizer_kwargs,
                    n_batches_train=n_batches_train,
                    n_batches_valid=n_batches_valid,
                    loss_fn=loss_fn,
                    metrics=metrics,
                    is_system=self.is_system,
                )
                for i in range(len(self.u_0s))
                for p in range(len(self.system_parameters))
            ]
        else:
            # reuse the bases from an earlier build(build_source=True) call and create
            # head-only solvers for the (possibly new) conditions and parameters
            self.solvers_head = [
                _SingleSolver1D(
                    bases=self.bases,
                    HeadClass=HeadClass,
                    initial_conditions=self.u_0s[i],
                    n_last_layer_head=self.n_last_layer_head,
                    diff_eqs=self.diff_eqs,
                    train_generator=self.train_generator,
                    valid_generator=self.valid_generator,
                    system_parameters=self.system_parameters[p],
                    optimizer=optimizer, optimizer_args=optimizer_args, optimizer_kwargs=optimizer_kwargs,
                    n_batches_train=n_batches_train,
                    n_batches_valid=n_batches_valid,
                    loss_fn=loss_fn,
                    metrics=metrics,
                    is_system=self.is_system,
                )
                for i in range(len(self.u_0s))
                for p in range(len(self.system_parameters))
            ]


    def fit(self, epochs=10, freeze_source=True):
        r"""
        :param epochs:
            Number of epochs for training.
        :type epochs: int
        :param freeze_source:
            Whether to freeze the base networks and train only the heads (requires a prior
            ``build(build_source=False)`` call). If False, the base solvers from
            ``build(build_source=True)`` are trained end to end.
        :type freeze_source: bool
        """
        if not freeze_source:
            # train bases and heads jointly
            for solver in self.solvers_base:
                solver.fit(max_epochs=epochs)
        else:
            # freeze all base weights, then train only the heads
            if self.is_system:
                for net in self.bases:
                    for param in net.parameters():
                        param.requires_grad = False
            else:
                for param in self.bases.parameters():
                    param.requires_grad = False
            for solver in self.solvers_head:
                solver.fit(max_epochs=epochs)


    def get_solution(self, base=False):
        r"""
        :param base:
            If True, return the solutions of the base solvers (the conditions the bases were
            trained on); otherwise return the solutions of the head-only solvers.
        :type base: bool
        :rtype: list[BaseSolution]
        """
        if base:
            return [solver.get_solution() for solver in self.solvers_base]
        else:
            return [solver.get_solution() for solver in self.solvers_head]
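
Taken together, the intended workflow is: train base solvers on a few representative problems, freeze the bases, then fit lightweight heads for new initial conditions and parameters. A minimal end-to-end sketch (the decay equation, values, and epoch counts are illustrative assumptions, not part of this PR):

    import torch
    from neurodiffeq import diff
    from neurodiffeq.solvers import UniversalSolver1D

    # family of decay problems u' + a*u = 0, varying both a and u(0)
    def decay(u, t, a=1.0):
        return [diff(u, t) + a * u]

    solver = UniversalSolver1D(decay, is_system=True)

    # stage 1: train base networks (and their heads) on representative problems
    solver.build(u_0s=[[1.0], [2.0]], system_parameters=[{'a': 1.0}, {'a': 2.0}],
                 build_source=True, t_min=0.0, t_max=1.0)
    solver.fit(epochs=1000, freeze_source=False)

    # stage 2: freeze the bases and fit only the linear heads for a new problem
    solver.build(u_0s=[[1.5]], system_parameters=[{'a': 1.5}],
                 build_source=False, t_min=0.0, t_max=1.0)
    solver.fit(epochs=500, freeze_source=True)

    solutions = solver.get_solution()  # one callable solution per (u_0, parameter) pair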