From 1d249193039086144bd96222a053c49a9a40a6ee Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 24 Apr 2024 07:58:51 +0200 Subject: [PATCH 1/5] Added tutorial for PyTorch tensor datatype --- .github/workflows/ci_pipeline.yml | 2 +- docs/source/tutorial/doc_step_7_D.rst | 7 ++ etc/environment-pytorch.yml | 16 ++++ pySDC/playgrounds/ML_initial_guess/ml_heat.py | 10 ++- pySDC/playgrounds/ML_initial_guess/tensor.py | 38 +++++--- pySDC/tests/test_tutorials/test_step_7.py | 7 ++ pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py | 89 +++++++++++++++++++ pySDC/tutorial/step_7/README.rst | 13 +++ pyproject.toml | 1 + 9 files changed, 168 insertions(+), 15 deletions(-) create mode 100644 docs/source/tutorial/doc_step_7_D.rst create mode 100644 etc/environment-pytorch.yml create mode 100644 pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py diff --git a/.github/workflows/ci_pipeline.yml b/.github/workflows/ci_pipeline.yml index 6f645caa94..618de9648d 100644 --- a/.github/workflows/ci_pipeline.yml +++ b/.github/workflows/ci_pipeline.yml @@ -61,7 +61,7 @@ jobs: strategy: matrix: python: ['3.8', '3.9', '3.10'] - env: ['base', 'fenics', 'mpi4py', 'petsc'] + env: ['base', 'fenics', 'mpi4py', 'petsc', 'pytorch'] defaults: run: diff --git a/docs/source/tutorial/doc_step_7_D.rst b/docs/source/tutorial/doc_step_7_D.rst new file mode 100644 index 0000000000..1a9e6216a7 --- /dev/null +++ b/docs/source/tutorial/doc_step_7_D.rst @@ -0,0 +1,7 @@ +Full code: `pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py <https://github.com/Parallel-in-Time/pySDC/blob/master/pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py>`_ + +.. literalinclude:: ../../../pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py + +Results: + +.. 
literalinclude:: ../../../data/step_7_D_out.txt diff --git a/etc/environment-pytorch.yml b/etc/environment-pytorch.yml new file mode 100644 index 0000000000..b80e21c5de --- /dev/null +++ b/etc/environment-pytorch.yml @@ -0,0 +1,16 @@ +name: pySDC +channels: + - conda-forge + - defaults +dependencies: + - numpy + - scipy>=0.17.1 + - sympy>=1.0 + - pytorch + - matplotlib>=3.0 + - pytest + - pytest-benchmark + - pytest-timeout + - pytest-order + - coverage[toml] + - sphinx diff --git a/pySDC/playgrounds/ML_initial_guess/ml_heat.py b/pySDC/playgrounds/ML_initial_guess/ml_heat.py index c3286869f0..76ef47a086 100644 --- a/pySDC/playgrounds/ML_initial_guess/ml_heat.py +++ b/pySDC/playgrounds/ML_initial_guess/ml_heat.py @@ -79,6 +79,7 @@ class HeatEquationModel(nn.Module): def __init__(self, problem, hidden_size=64): self.input_size = problem.nvars * 3 self.output_size = problem.nvars + self.problem = problem super().__init__() @@ -93,8 +94,8 @@ def __init__(self, problem, hidden_size=64): def forward(self, x, t, dt): # prepare individual tensors x = x.float() - _t = torch.ones_like(x) * t - _dt = torch.ones_like(x) * dt + _t = torch.ones(x.shape) * t + _dt = torch.ones(x.shape) * dt # Concatenate t and dt with the input x _x = torch.cat((x, _t, _dt), dim=0) @@ -104,6 +105,11 @@ def forward(self, x, t, dt): _x = self.fc2(_x) return _x + def __call__(self, *args, **kwargs): + me = self.problem.u_init + me[:] = super().__call__(*args, **kwargs) + return me + def train_at_collocation_nodes(): """ diff --git a/pySDC/playgrounds/ML_initial_guess/tensor.py b/pySDC/playgrounds/ML_initial_guess/tensor.py index c28c321213..d59ef4b8d2 100644 --- a/pySDC/playgrounds/ML_initial_guess/tensor.py +++ b/pySDC/playgrounds/ML_initial_guess/tensor.py @@ -4,9 +4,6 @@ from pySDC.core.Errors import DataError try: - # TODO : mpi4py cannot be imported before dolfin when using fenics mesh - # see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590 - # This should be dealt 
with at some point from mpi4py import MPI except ImportError: MPI = None @@ -26,7 +23,7 @@ class Tensor(torch.Tensor): @staticmethod def __new__(cls, init, val=0.0, *args, **kwargs): """ - Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh. + Instantiates new datatype. This ensures that even when manipulating data, the result is still a tensor. Args: init: either another mesh or a tuple containing the dimensions, the communicator and the dtype @@ -52,6 +49,31 @@ def __new__(cls, init, val=0.0, *args, **kwargs): raise NotImplementedError(type(init)) return obj + def __add__(self, *args, **kwargs): + res = super().__add__(*args, **kwargs) + res._comm = self.comm + return res + + def __sub__(self, *args, **kwargs): + res = super().__sub__(*args, **kwargs) + res._comm = self.comm + return res + + def __lmul__(self, *args, **kwargs): + res = super().__lmul__(*args, **kwargs) + res._comm = self.comm + return res + + def __rmul__(self, *args, **kwargs): + res = super().__rmul__(*args, **kwargs) + res._comm = self.comm + return res + + def __mul__(self, *args, **kwargs): + res = super().__mul__(*args, **kwargs) + res._comm = self.comm + return res + @property def comm(self): """ @@ -59,14 +81,6 @@ def comm(self): """ return self._comm - def __array_finalize__(self, obj): - """ - Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator. 
- """ - if obj is None: - return - self._comm = getattr(obj, '_comm', None) - def __abs__(self): """ Overloading the abs operator diff --git a/pySDC/tests/test_tutorials/test_step_7.py b/pySDC/tests/test_tutorials/test_step_7.py index 559831d1c4..cd2f4ec8b3 100644 --- a/pySDC/tests/test_tutorials/test_step_7.py +++ b/pySDC/tests/test_tutorials/test_step_7.py @@ -120,3 +120,10 @@ def test_C_2x2(): for line in p.stderr: print(line) assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (p.returncode, num_procs) + + +@pytest.mark.pytorch +def test_D(): + from pySDC.tutorial.step_7.D_pySDC_with_PyTorch import train_at_collocation_nodes + + train_at_collocation_nodes() diff --git a/pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py b/pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py new file mode 100644 index 0000000000..83094d3e39 --- /dev/null +++ b/pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py @@ -0,0 +1,89 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from pySDC.playgrounds.ML_initial_guess.ml_heat import HeatEquationModel, Train_pySDC +from pySDC.playgrounds.ML_initial_guess.heat import Heat1DFDTensor + + +def train_at_collocation_nodes(): + """ + For the first proof of concept, we want to train the model specifically to the collocation nodes we use in SDC. + If successful, the initial guess would already be the exact solution and we would need no SDC iterations. + + What we find is that we can train the network to predict the solution to one very specific problem rather well. + See the error during training for what happens when we ask the network to solve for exactly what it just trained. + However, if we train for something else, i.e. solving to a different step size in this case, we can only use the + model to predict the solution of what it's been trained for last and it loses the ability to solve for previously + learned things. 
This is solely because we chose an overly simple model that is unsuitable to the task at hand and + is likely easily solved with a bit of patience. This is just a demonstration of the interface between pySDC and + PyTorch. If you want to do a project with this, feel free to take this as a starting point and do things that + actually do something! + + The output shows the training loss during training and, after each of three training sessions is complete, the error + of the prediction with the current state of the network. To demonstrate the forgetfulness, we finally print the + error of all learned predictions after training is complete. + """ + out = '' + errors_mid_training = [] + errors_post_training = [] + + # instantiate the pySDC problem and a model for PyTorch + problem = Heat1DFDTensor() + model = HeatEquationModel(problem) + + # setup neural network + lr = 0.001 + num_epochs = 250 + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + + # setup initial conditions + t = 0 + initial_condition = problem.u_exact(t) + + # train the model to predict the solution at certain collocation nodes + collocation_nodes = np.array([0.15505102572168285, 0.6449489742783183, 1]) * 1e-2 + for dt in collocation_nodes: + + # get target condition from implicit Euler step + target_condition = problem.solve_system(initial_condition, dt, initial_condition, t) + + # do the training + for epoch in range(num_epochs): + predicted_state = model(initial_condition, t, dt) + loss = criterion(predicted_state.float(), target_condition.float()) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if (epoch + 1) % 50 == 0: + out += f'Training for {dt=:.2e}: Epoch [{epoch+1:4d}/{num_epochs:4d}], Loss: {loss.item():.4e}\n' + + # evaluate model to compute error + model_prediction = model(initial_condition, t, dt) + errors_mid_training += [abs(target_condition - model_prediction)] + out += f'Error of prediction at {dt:.2e} during training: 
{abs(target_condition-model_prediction):.2e}\n' + + # compare model and problem + for dt in collocation_nodes: + target_condition = problem.solve_system(initial_condition, dt, initial_condition, t) + model_prediction = model(initial_condition, t, dt) + errors_post_training += [abs(target_condition - model_prediction)] + out += f'Error of prediction at {dt:.2e} after training: {abs(target_condition-model_prediction):.2e}\n' + + print(out) + with open('data/step_7_D_out.txt', 'w') as file: + file.write(out) + + # test that the training went as expected + assert np.greater([1e-2, 1e-4, 1e-5], errors_mid_training).all(), 'Errors during training are larger than expected' + assert np.greater([1e0, 1e0, 1e-5], errors_post_training).all(), 'Errors after training are larger than expected' + + # save the model to use it throughout pySDC + torch.save(model.state_dict(), 'data/heat_equation_model.pth') + + +if __name__ == '__main__': + train_at_collocation_nodes() diff --git a/pySDC/tutorial/step_7/README.rst b/pySDC/tutorial/step_7/README.rst index c52b373f5a..d21f6cd7e0 100644 --- a/pySDC/tutorial/step_7/README.rst +++ b/pySDC/tutorial/step_7/README.rst @@ -51,3 +51,16 @@ Important things to note: - Below, we run the code 3 times: with 1 and 2 processors in space as well as 4 processors (2 in time and 2 in space). Do not expect scaling due to the CI environment. .. include:: doc_step_7_C.rst + + +Part D: pySDC and PyTorch +------------------------- + +PyTorch is a library for machine learning. Its data structure is called a tensor and allows running on CPUs as well as GPUs, in addition to providing access to various machine learning methods. +Since the potential for use in pySDC is very large, we have started on a datatype that allows using PyTorch tensors throughout pySDC. + +This example trains a network to predict the results of implicit Euler solves for the heat equation. 
It is too simple to do anything useful, but demonstrates how to use tensors in pySDC and then apply the enormous PyTorch infrastructure. +This is work in progress in very early stages! The tensor datatype is the simplest possible implementation, rather than an efficient one. +If you want to work on this, your input is appreciated! + +.. include:: doc_step_7_D.rst diff --git a/pyproject.toml b/pyproject.toml index 37f2be3aad..49473e153f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ markers = [ 'cupy: tests for cupy on GPUs', 'libpressio: tests using the libpressio library', 'monodomain: tests the monodomain project, which requires previous compilation of c++ code', + 'pytorch: tests for PyTorch related things in pySDC' ] timeout = 300 From fd7ad4328d14aeee89fc2295c5efb91100880740 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 24 Apr 2024 08:06:11 +0200 Subject: [PATCH 2/5] Added dill to environment --- etc/environment-pytorch.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/etc/environment-pytorch.yml b/etc/environment-pytorch.yml index b80e21c5de..514e129cd9 100644 --- a/etc/environment-pytorch.yml +++ b/etc/environment-pytorch.yml @@ -8,6 +8,7 @@ dependencies: - sympy>=1.0 - pytorch - matplotlib>=3.0 + - dill - pytest - pytest-benchmark - pytest-timeout From a81d2e25b360dbe265063b17423651455e73b359 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 24 Apr 2024 08:32:04 +0200 Subject: [PATCH 3/5] Removed dill dependency --- etc/environment-pytorch.yml | 1 - .../test_TOMS/test_AllenCahn_contracting_circle.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/etc/environment-pytorch.yml b/etc/environment-pytorch.yml index 514e129cd9..b80e21c5de 100644 --- a/etc/environment-pytorch.yml +++ b/etc/environment-pytorch.yml @@ -8,7 +8,6 @@ dependencies: - sympy>=1.0 - pytorch - matplotlib>=3.0 - - dill - pytest - 
pytest-benchmark - pytest-timeout diff --git a/pySDC/tests/test_projects/test_TOMS/test_AllenCahn_contracting_circle.py b/pySDC/tests/test_projects/test_TOMS/test_AllenCahn_contracting_circle.py index d5dd28d9ce..57e14d38bc 100644 --- a/pySDC/tests/test_projects/test_TOMS/test_AllenCahn_contracting_circle.py +++ b/pySDC/tests/test_projects/test_TOMS/test_AllenCahn_contracting_circle.py @@ -1,6 +1,4 @@ import pytest -import dill -import os results = {} @@ -21,6 +19,8 @@ def test_AllenCahn_contracting_circle(variant, inexact): @pytest.mark.base @pytest.mark.order(2) def test_show_results(): + import dill + import os from pySDC.projects.TOMS.AllenCahn_contracting_circle import show_results # dump result From a194fd98c90234972bd74a6f7946f65281ad7e61 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 24 Apr 2024 08:37:35 +0200 Subject: [PATCH 4/5] Added dill again --- etc/environment-pytorch.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/etc/environment-pytorch.yml b/etc/environment-pytorch.yml index b80e21c5de..514e129cd9 100644 --- a/etc/environment-pytorch.yml +++ b/etc/environment-pytorch.yml @@ -8,6 +8,7 @@ dependencies: - sympy>=1.0 - pytorch - matplotlib>=3.0 + - dill - pytest - pytest-benchmark - pytest-timeout From c31e10e2ed3d3889aad23a0c4120c24de89c50ca Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Mon, 6 May 2024 13:43:04 +0200 Subject: [PATCH 5/5] Removed unnecessary modules in PyTorch environment --- etc/environment-pytorch.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/etc/environment-pytorch.yml b/etc/environment-pytorch.yml index 514e129cd9..5c5658c807 100644 --- a/etc/environment-pytorch.yml +++ b/etc/environment-pytorch.yml @@ -9,9 +9,3 @@ dependencies: - pytorch - matplotlib>=3.0 - dill - - pytest - - pytest-benchmark - - pytest-timeout - - pytest-order - - coverage[toml] - - sphinx