Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
strategy:
fail-fast: false
matrix:
env: ['base', 'fenics', 'mpi4py', 'petsc']
env: ['base', 'fenics', 'mpi4py', 'petsc', 'pytorch']
python: ['3.8', '3.9', '3.10', '3.11', '3.12']

defaults:
Expand Down
7 changes: 7 additions & 0 deletions docs/source/tutorial/doc_step_7_D.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Full code: `pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py <https://github.com/Parallel-in-Time/pySDC/blob/master/pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py>`_

.. literalinclude:: ../../../pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py

Results:

.. literalinclude:: ../../../data/step_7_D_out.txt
11 changes: 11 additions & 0 deletions etc/environment-pytorch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: pySDC
channels:
- conda-forge
- defaults
dependencies:
- numpy
- scipy>=0.17.1
- sympy>=1.0
- pytorch
- matplotlib>=3.0
- dill
10 changes: 8 additions & 2 deletions pySDC/playgrounds/ML_initial_guess/ml_heat.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class HeatEquationModel(nn.Module):
def __init__(self, problem, hidden_size=64):
self.input_size = problem.nvars * 3
self.output_size = problem.nvars
self.problem = problem

super().__init__()

Expand All @@ -93,8 +94,8 @@ def __init__(self, problem, hidden_size=64):
def forward(self, x, t, dt):
# prepare individual tensors
x = x.float()
_t = torch.ones_like(x) * t
_dt = torch.ones_like(x) * dt
_t = torch.ones(x.shape) * dt
_dt = torch.ones(x.shape) * dt

# Concatenate t and dt with the input x
_x = torch.cat((x, _t, _dt), dim=0)
Expand All @@ -104,6 +105,11 @@ def forward(self, x, t, dt):
_x = self.fc2(_x)
return _x

def __call__(self, *args, **kwargs):
    """
    Evaluate the network and return the result in a container taken from the problem.

    Args:
        *args: positional arguments forwarded unchanged to the parent class' ``__call__``
        **kwargs: keyword arguments forwarded unchanged to the parent class' ``__call__``

    Returns:
        the object obtained from ``self.problem.u_init`` after it has been filled with the network output
    """
    # fetch the container first, then overwrite all of its entries with the output of the network
    result = self.problem.u_init
    result[:] = super().__call__(*args, **kwargs)
    return result


def train_at_collocation_nodes():
"""
Expand Down
38 changes: 26 additions & 12 deletions pySDC/playgrounds/ML_initial_guess/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from pySDC.core.Errors import DataError

try:
# TODO : mpi4py cannot be imported before dolfin when using fenics mesh
# see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590
# This should be dealt with at some point
from mpi4py import MPI
except ImportError:
MPI = None
Expand All @@ -26,7 +23,7 @@ class Tensor(torch.Tensor):
@staticmethod
def __new__(cls, init, val=0.0, *args, **kwargs):
"""
Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.
Instantiates new datatype. This ensures that even when manipulating data, the result is still a tensor.

Args:
init: either another mesh or a tuple containing the dimensions, the communicator and the dtype
Expand All @@ -52,21 +49,38 @@ def __new__(cls, init, val=0.0, *args, **kwargs):
raise NotImplementedError(type(init))
return obj

def __add__(self, *args, **kwargs):
    """
    Overloading the addition operator such that the result carries the communicator of this tensor

    Args:
        *args: positional arguments forwarded to the parent class' ``__add__``
        **kwargs: keyword arguments forwarded to the parent class' ``__add__``

    Returns:
        the result of the parent class' addition with ``_comm`` set, or ``NotImplemented`` if the parent declines
    """
    res = super().__add__(*args, **kwargs)
    # If the parent declines the operation, hand ``NotImplemented`` back untouched so that Python can try the
    # reflected operand; assigning an attribute to it would raise an unrelated AttributeError instead.
    if res is NotImplemented:
        return res
    res._comm = self.comm
    return res

def __sub__(self, *args, **kwargs):
    """
    Overloading the subtraction operator such that the result carries the communicator of this tensor

    Args:
        *args: positional arguments forwarded to the parent class' ``__sub__``
        **kwargs: keyword arguments forwarded to the parent class' ``__sub__``

    Returns:
        the result of the parent class' subtraction with ``_comm`` set, or ``NotImplemented`` if the parent declines
    """
    res = super().__sub__(*args, **kwargs)
    # If the parent declines the operation, hand ``NotImplemented`` back untouched so that Python can try the
    # reflected operand; assigning an attribute to it would raise an unrelated AttributeError instead.
    if res is NotImplemented:
        return res
    res._comm = self.comm
    return res

def __lmul__(self, *args, **kwargs):
    """
    Multiplication helper that keeps the communicator of this tensor

    ``__lmul__`` is not one of the operator hooks of the Python data model, so Python never calls it implicitly
    and ``torch.Tensor`` is not expected to provide an implementation to delegate to. To make explicit calls
    work, this forwards to the parent class' ``__mul__`` instead.

    Args:
        *args: positional arguments forwarded to the parent class' ``__mul__``
        **kwargs: keyword arguments forwarded to the parent class' ``__mul__``

    Returns:
        the result of the parent class' multiplication with ``_comm`` set, or ``NotImplemented`` if the parent
        declines
    """
    res = super().__mul__(*args, **kwargs)
    # Pass ``NotImplemented`` on untouched rather than failing on the attribute assignment below.
    if res is NotImplemented:
        return res
    res._comm = self.comm
    return res

def __rmul__(self, *args, **kwargs):
    """
    Overloading the reflected multiplication operator such that the result carries the communicator of this tensor

    Args:
        *args: positional arguments forwarded to the parent class' ``__rmul__``
        **kwargs: keyword arguments forwarded to the parent class' ``__rmul__``

    Returns:
        the result of the parent class' multiplication with ``_comm`` set, or ``NotImplemented`` if the parent
        declines
    """
    res = super().__rmul__(*args, **kwargs)
    # If the parent declines the operation, hand ``NotImplemented`` back untouched so that Python raises its
    # regular TypeError; assigning an attribute to it would raise an unrelated AttributeError instead.
    if res is NotImplemented:
        return res
    res._comm = self.comm
    return res

def __mul__(self, *args, **kwargs):
    """
    Overloading the multiplication operator such that the result carries the communicator of this tensor

    Args:
        *args: positional arguments forwarded to the parent class' ``__mul__``
        **kwargs: keyword arguments forwarded to the parent class' ``__mul__``

    Returns:
        the result of the parent class' multiplication with ``_comm`` set, or ``NotImplemented`` if the parent
        declines
    """
    res = super().__mul__(*args, **kwargs)
    # If the parent declines the operation, hand ``NotImplemented`` back untouched so that Python can try the
    # reflected operand; assigning an attribute to it would raise an unrelated AttributeError instead.
    if res is NotImplemented:
        return res
    res._comm = self.comm
    return res

@property
def comm(self):
    """
    Getter for the communicator

    Returns:
        the object stored in ``_comm``, or None if no communicator has been attached to this instance
    """
    # Fall back to None instead of raising an AttributeError when ``_comm`` has never been set on this instance.
    return getattr(self, '_comm', None)

def __array_finalize__(self, obj):
    """
    Copy the communicator from the object this instance was derived from.

    Args:
        obj: the source object, or None, in which case nothing is done
    """
    if obj is not None:
        # a source without a communicator yields None
        self._comm = getattr(obj, '_comm', None)

def __abs__(self):
"""
Overloading the abs operator
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import pytest
import dill
import os

results = {}

Expand All @@ -21,6 +19,8 @@ def test_AllenCahn_contracting_circle(variant, inexact):
@pytest.mark.base
@pytest.mark.order(2)
def test_show_results():
import dill
import os
from pySDC.projects.TOMS.AllenCahn_contracting_circle import show_results

# dump result
Expand Down
7 changes: 7 additions & 0 deletions pySDC/tests/test_tutorials/test_step_7.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,10 @@ def test_C_2x2():
for line in p.stderr:
print(line)
assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (p.returncode, num_procs)


@pytest.mark.pytorch
def test_D():
    """
    Run part D of the step 7 tutorial; the checks live inside the tutorial function itself.
    """
    # import lazily so that collecting this test module does not require the tutorial's dependencies
    import pySDC.tutorial.step_7.D_pySDC_with_PyTorch as tutorial

    tutorial.train_at_collocation_nodes()
89 changes: 89 additions & 0 deletions pySDC/tutorial/step_7/D_pySDC_with_PyTorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from pySDC.playgrounds.ML_initial_guess.ml_heat import HeatEquationModel, Train_pySDC
from pySDC.playgrounds.ML_initial_guess.heat import Heat1DFDTensor


def train_at_collocation_nodes():
    """
    For the first proof of concept, we want to train the model specifically to the collocation nodes we use in SDC.
    If successful, the initial guess would already be the exact solution and we would need no SDC iterations.

    What we find is that we can train the network to predict the solution to one very specific problem rather well.
    See the error during training for what happens when we ask the network to solve for exactly what it just trained.
    However, if we train for something else, i.e. solving to a different step size in this case, we can only use the
    model to predict the solution of what it's been trained for last and it loses the ability to solve for previously
    learned things. This is solely because we chose an overly simple model that is unsuitable to the task at hand and
    is likely easily solved with a bit of patience. This is just a demonstration of the interface between pySDC and
    PyTorch. If you want to do a project with this, feel free to take this as a starting point and do things that
    actually do something!

    The output shows the training loss during training and, after each of three training sessions is complete, the error
    of the prediction with the current state of the network. To demonstrate the forgetfulness, we finally print the
    error of all learned predictions after training is complete.

    Side effects:
        Prints the report, writes it to ``data/step_7_D_out.txt`` and stores the trained weights in
        ``data/heat_equation_model.pth``. The ``data`` directory is created if it does not exist.

    Raises:
        AssertionError: if the errors during or after training exceed the expected bounds
    """
    from pathlib import Path

    # collect the report line by line and join once at the end
    report = []
    errors_mid_training = []
    errors_post_training = []

    # instantiate the pySDC problem and a model for PyTorch
    problem = Heat1DFDTensor()
    model = HeatEquationModel(problem)

    # setup neural network
    lr = 0.001
    num_epochs = 250
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # setup initial conditions
    t = 0
    initial_condition = problem.u_exact(t)

    # train the model to predict the solution at certain collocation nodes
    collocation_nodes = np.array([0.15505102572168285, 0.6449489742783183, 1]) * 1e-2
    for dt in collocation_nodes:
        # get target condition from implicit Euler step
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)

        # do the training
        for epoch in range(num_epochs):
            predicted_state = model(initial_condition, t, dt)
            loss = criterion(predicted_state.float(), target_condition.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (epoch + 1) % 50 == 0:
                report.append(
                    f'Training for {dt=:.2e}: Epoch [{epoch+1:4d}/{num_epochs:4d}], Loss: {loss.item():.4e}\n'
                )

        # evaluate model to compute error; compute the norm once and reuse it for the record and the report
        model_prediction = model(initial_condition, t, dt)
        error = abs(target_condition - model_prediction)
        errors_mid_training.append(error)
        report.append(f'Error of prediction at {dt:.2e} during training: {error:.2e}\n')

    # compare model and problem
    for dt in collocation_nodes:
        target_condition = problem.solve_system(initial_condition, dt, initial_condition, t)
        model_prediction = model(initial_condition, t, dt)
        error = abs(target_condition - model_prediction)
        errors_post_training.append(error)
        report.append(f'Error of prediction at {dt:.2e} after training: {error:.2e}\n')

    out = ''.join(report)
    print(out)

    # make sure the output directory exists, otherwise opening the file fails in a fresh working directory
    Path('data').mkdir(parents=True, exist_ok=True)
    with open('data/step_7_D_out.txt', 'w') as file:
        file.write(out)

    # test that the training went as expected
    assert np.greater([1e-2, 1e-4, 1e-5], errors_mid_training).all(), 'Errors during training are larger than expected'
    assert np.greater([1e0, 1e0, 1e-5], errors_post_training).all(), 'Errors after training are larger than expected'

    # save the model to use it throughout pySDC
    torch.save(model.state_dict(), 'data/heat_equation_model.pth')


# run the tutorial only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    train_at_collocation_nodes()
13 changes: 13 additions & 0 deletions pySDC/tutorial/step_7/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,16 @@ Important things to note:
- Below, we run the code 3 times: with 1 and 2 processors in space as well as 4 processors (2 in time and 2 in space). Do not expect scaling due to the CI environment.

.. include:: doc_step_7_C.rst


Part D: pySDC and PyTorch
-------------------------

PyTorch is a library for machine learning. Its data structure is called a tensor; tensors can be used on CPUs as well as GPUs and give access to a wide range of machine learning methods.
Since the potential for use in pySDC is very large, we have started on a datatype that allows using PyTorch tensors throughout pySDC.

This example trains a network to predict the results of implicit Euler solves for the heat equation. It is too simple to do anything useful, but demonstrates how to use tensors in pySDC and then apply the enormous PyTorch infrastructure.
This is work in progress at a very early stage! The tensor datatype is the simplest possible implementation rather than an efficient one.
If you want to work on this, your input is appreciated!

.. include:: doc_step_7_D.rst
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ markers = [
'cupy: tests for cupy on GPUs',
'libpressio: tests using the libpressio library',
'monodomain: tests the monodomain project, which requires previous compilation of c++ code',
'pytorch: tests for PyTorch related things in pySDC'
]
timeout = 300

Expand Down