# Solutions to Elliptic and Parabolic Problems via Finite Difference Based Unsupervised Small Linear Convolutional Neural Networks

Code for the parabolic problems presented in [*Solutions to Elliptic and Parabolic Problems via Finite Difference Based Unsupervised Small Linear Convolutional Neural Networks*](https://arxiv.org/abs/2311.00259).

In [None]:
import random
import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tensorflow as tf
from tensorflow.keras.utils import Progbar
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import Image

# Seed Python's, NumPy's and TensorFlow's random number generators with the same value
# (in the same order as before) so that runs are repeatable.
for _seed_rng in (random.seed, np.random.seed, tf.random.set_seed, tf.keras.utils.set_random_seed):
    _seed_rng(42)
del _seed_rng

# # Uncomment to set memory growth for GPU
# def set_memory_growth():
#     # Get GPUs
#     gpus = tf.config.list_physical_devices('GPU')

#     # For tensorflow 2.x.x allow memory growth on GPU
#     for gpu in gpus:
#         tf.config.experimental.set_memory_growth(gpu, True)
        
        
# set_memory_growth()

## Get problem data

In [None]:
# Grid size, uniform grid spacing on [0, 1] and time-step size
N = 128
h = 1/(N - 1)
time_step = 0.1

# Tensor-product grid on the unit square
x = np.linspace(0, 1, N)
y = np.linspace(0, 1, N)
X, Y = np.meshgrid(x, y)

# Trigonometric manufactured solution (we used n = 1 and n = 4)
n = 1


def _as_field(values):
    """Wrap an (N, N) NumPy array as a float32 TensorFlow constant of shape (1, N, N, 1)."""
    return tf.constant(values.reshape(1, N, N, 1), dtype=tf.float32)


def u(t):
    """Exact solution u(x, y, t) = cos(t) * sin(n*pi*x) * sin(n*pi*y) on the grid."""
    return _as_field(np.cos(t) * np.sin(n * np.pi * X) * np.sin(n * np.pi * Y))


def f(t):
    """Source term f = u_t - laplacian(u) matching the exact solution ``u``."""
    return _as_field(-np.sin(n*np.pi*X)*np.sin(n*np.pi*Y)*(-2*np.cos(t)*n**2*np.pi**2 + np.sin(t)))


# Gaussian function (uncomment to replace u and f above)
# u = lambda t: tf.constant((np.exp(-50*((2*X - 1)**2 + (2*Y - 1)**2)) * np.cos(t)).reshape(1, N, N, 1), dtype=tf.float32)
# f = lambda t: tf.constant((-np.exp(-50*(2*X - 1)**2 - 50*(2*Y - 1)**2)*(160000*np.cos(t)*X**2 - 160000*np.cos(t)*X + 160000*np.cos(t)*Y**2 - 160000*np.cos(t)*Y + 79200*np.cos(t) + np.sin(t))).reshape(1, N, N, 1), dtype=tf.float32)

# Initial condition
u0 = u(0)

## Define time dependent loss function

In [None]:
# Define loss function
class TimeDependentLoss(tf.keras.losses.Loss):
    """Unsupervised backward-Euler residual loss for u_t = laplacian(u) + f on [0, 1]^2.

    The prediction ``u_current`` at time ``t`` is penalized by

    * the mean squared interior residual
      ``u_current - u_previous - step_size * (laplacian_h(u_current) + f(t))``,
      where ``laplacian_h`` is the 5-point finite-difference Laplacian, and
    * the mean squared value of ``u_current`` on the four edges (driving it to 0),

    combined as ``alpha * interior + (1 - alpha) * boundary`` with ``alpha = 4 h^2``.
    """

    def __init__(self, N, step_size, f, **kwargs):
        """
        Parameters
        ----------
        N : int
            Number of grid points per side; the grid spacing is ``h = 1 / (N - 1)``.
        step_size : float
            Time-step size used in the backward-Euler residual.
        f : callable
            ``f(t)`` returns the source term on the full N x N grid as a rank-4
            ``(batch, N, N, 1)`` tensor (``call`` slices it as such).
        **kwargs
            Forwarded to ``tf.keras.losses.Loss``.
        """
        super().__init__(**kwargs)
        self.N = N
        self.h = 1./(N - 1.)
        self.step_size = step_size

        # Weight of the interior residual; the boundary term gets (1 - alpha). Tune this parameter.
        self.alpha = np.square(self.h) * 4

        # Source term callable f(t)
        self.f = f

        # 5-point Laplacian stencil scaled by 1/h^2, shaped
        # (height, width, in_channels, out_channels) for tf.nn.convolution.
        k_laplacian = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]]) / np.square(self.h)
        k_laplacian = tf.constant(k_laplacian, dtype=tf.float32)
        self.k_laplacian = tf.reshape(k_laplacian, [3, 3, 1, 1])

    def call(self, current_previous, t):
        """Return the scalar blended interior/boundary loss.

        Parameters
        ----------
        current_previous : tuple
            ``(u_current, u_previous)``, rank-4 tensors laid out ``(batch, N, N, 1)``;
            ``u_current`` is the prediction at time ``t``, ``u_previous`` the solution
            from the previous time step.
        t : float
            Time at which the source term is evaluated. ``self.f`` evaluates it with
            NumPy, so it should be a plain Python float.
            NOTE(review): Keras 3's ``Loss.__call__`` converts ``y_pred`` (here ``t``) to a
            tensor, which would break that inside ``tf.function`` — confirm the installed
            Keras version.
        """
        u_current, u_previous = current_previous

        # Source term at time t, restricted to interior nodes
        f_current = self.f(t)[:, 1:-1, 1:-1, :]

        u_current_interior = u_current[:, 1:-1, 1:-1, :]
        u_previous_interior = u_previous[:, 1:-1, 1:-1, :]

        # The default 'VALID' padding drops one node per side, so the discrete
        # Laplacian lines up with the interior slices above.
        rhs = tf.nn.convolution(u_current, self.k_laplacian, strides=1)

        interior = tf.reduce_mean(tf.square(u_current_interior - u_previous_interior - self.step_size*(rhs + f_current)))

        # Squared values on the left, right, bottom and top edges (corners appear twice).
        # Flattening with -1 instead of a fixed length N keeps this valid for batch sizes > 1.
        left_boundary = tf.square(tf.reshape(u_current[:, :, 0, :], [-1]))
        right_boundary = tf.square(tf.reshape(u_current[:, :, -1, :], [-1]))
        bottom_boundary = tf.square(tf.reshape(u_current[:, 0, :, :], [-1]))
        top_boundary = tf.square(tf.reshape(u_current[:, -1, :, :], [-1]))

        boundary = tf.concat([left_boundary,
                              right_boundary,
                              bottom_boundary,
                              top_boundary], axis=-1)
        boundary = tf.reduce_mean(boundary)

        # Blend interior and boundary terms
        loss = self.alpha*interior + (1 - self.alpha)*boundary

        return loss

## Build U-Net

In [None]:
def get_norm(name):
    """Return the normalization op selected by a substring of ``name``.

    A name containing "batch" yields a channel-axis ``BatchNormalization`` layer;
    otherwise a name containing "identity" yields ``tf.identity``.

    Raises:
        ValueError: if ``name`` contains neither substring.
    """
    if "batch" in name:
        return tf.keras.layers.BatchNormalization(axis=-1, center=True, scale=True)
    if "identity" not in name:
        raise ValueError("Invalid normalization layer")
    return tf.identity


def get_regularizer(name):
    """Return the kernel regularizer selected by a substring of ``name``.

    A name containing "l2" yields an ``L2(1e-7)`` regularizer; otherwise a name
    containing "none" yields ``None`` (no regularization).

    Raises:
        ValueError: if ``name`` contains neither substring.
    """
    if "l2" in name:
        return tf.keras.regularizers.L2(1e-7)
    if "none" not in name:
        raise ValueError("Invalid regularization layer")
    return None


def get_activation(name, **kwargs):
    """Return the activation op whose name exactly matches ``name``.

    "relu", "tanh" and "swish" yield a Keras ``Activation`` layer; "identity"
    yields ``tf.identity``. ``kwargs`` is accepted so callers can forward their
    configuration dict, but it is not used.

    Raises:
        ValueError: for any other name.
    """
    if name == "identity":
        return tf.identity
    if name in ("relu", "tanh", "swish"):
        return tf.keras.layers.Activation(name)
    raise ValueError("Invalid activation layer")


class ConvDownsample(tf.keras.layers.Layer):
    """Learned 2x downsampling: zero-pad by 1, stride-2 3x3 Conv2D, then norm and activation.

    Required keyword arguments: ``filters``, ``regularizer``, ``norm`` and
    ``activation`` (the last three are names understood by the ``get_*`` factories).
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.pad = tf.keras.layers.ZeroPadding2D(padding=(1, 1))
        kernel_reg = get_regularizer(kwargs["regularizer"])
        self.conv = tf.keras.layers.Conv2D(
            filters=kwargs["filters"],
            kernel_size=3,
            strides=2,
            kernel_regularizer=kernel_reg,
        )
        self.norm = get_norm(kwargs["norm"])
        self.activation = get_activation(kwargs["activation"], **kwargs)

    def call(self, x):
        # Apply the stages in order: pad -> conv -> norm -> activation
        for stage in (self.pad, self.conv, self.norm, self.activation):
            x = stage(x)
        return x


def get_downsample(name, **kwargs):
    """Return a 2x downsampling layer chosen by exact ``name``.

    "maxpool" / "avgpool" give 2x2, stride-2 pooling; "conv" gives a learned
    ``ConvDownsample`` built from ``kwargs``.

    Raises:
        ValueError: for any other name.
    """
    if name == 'conv':
        return ConvDownsample(**kwargs)
    if name in ('maxpool', 'avgpool'):
        pool_cls = (tf.keras.layers.MaxPooling2D if name == 'maxpool'
                    else tf.keras.layers.AveragePooling2D)
        return pool_cls(pool_size=(2, 2), strides=(2, 2))
    raise ValueError("Invalid downsampling layer!")

class ConvLayer(tf.keras.layers.Layer):
    """5x5 'same'-padded Conv2D followed by the configured normalization and activation.

    ``kwargs`` must provide the ``regularizer``, ``norm`` and ``activation`` names
    understood by ``get_regularizer``, ``get_norm`` and ``get_activation``.
    """

    def __init__(self, filters, **kwargs):
        super().__init__()
        kernel_reg = get_regularizer(kwargs["regularizer"])
        self.conv = tf.keras.layers.Conv2D(
            filters=filters,
            kernel_size=5,
            padding="same",
            kernel_regularizer=kernel_reg,
        )
        self.norm = get_norm(kwargs["norm"])
        self.activation = get_activation(kwargs["activation"], **kwargs)

    def call(self, x):
        # conv -> norm -> activation
        for op in (self.conv, self.norm, self.activation):
            x = op(x)
        return x


class EncoderBlock(tf.keras.layers.Layer):
    """Encoder stage: run ``block`` then downsample.

    ``call`` returns ``(skip, downsampled)``, where ``skip`` is the pre-downsampling
    feature map that the matching decoder stage concatenates back in.
    """

    def __init__(self, filters, block, **kwargs):
        super().__init__()
        self.block = block(filters, **kwargs)
        down_kind = kwargs["down_type"]
        self.down = get_downsample(down_kind, filters=filters, **kwargs)

    def call(self, x):
        features = self.block(x)
        return features, self.down(features)


class Bottleneck(tf.keras.layers.Layer):
    """Lowest-resolution stage of the U-Net: a single ``block`` with no resampling."""

    def __init__(self, filters, block, **kwargs):
        super().__init__()
        self.block = block(filters, **kwargs)

    def call(self, x):
        return self.block(x)


class DecoderBlock(tf.keras.layers.Layer):
    """Decoder stage: 2x upsample, concatenate the encoder skip on channels, then ``block``."""

    def __init__(self, filters, block, **kwargs):
        super().__init__()
        # Commented-out alternative (learned upsampling):
        # self.trans_conv = tf.keras.layers.Conv2DTranspose(filters=filters,
        #                                                   kernel_size=2,
        #                                                   strides=2)
        self.upsample = tf.keras.layers.UpSampling2D(size=(2, 2))
        self.block = block(filters, **kwargs)

    def call(self, skip, x):
        """Upsample ``x``, join it with ``skip`` along the last (channel) axis and apply ``block``."""
        up = self.upsample(x)
        # Plain tf.concat: the functional tf.keras.layers.concatenate would construct
        # a brand-new Concatenate layer on every forward call.
        concat = tf.concat([skip, up], axis=-1)
        out = self.block(concat)
        return out


class ConvLSTMBlock(tf.keras.layers.Layer):
    """5x5 'same'-padded ConvLSTM2D followed by the configured normalization and activation.

    ``return_sequences`` is forwarded unchanged to ``ConvLSTM2D``; ``kwargs`` must
    provide the ``norm`` and ``activation`` names. Note that this block is not
    referenced by the model code visible in this notebook.
    """

    def __init__(self, filters, return_sequences, **kwargs):
        super().__init__()
        self.convlstm = tf.keras.layers.ConvLSTM2D(
            filters=filters,
            kernel_size=5,
            padding='same',
            return_sequences=return_sequences,
        )
        self.norm = get_norm(kwargs["norm"])
        self.activation = get_activation(kwargs["activation"], **kwargs)

    def call(self, x):
        hidden = self.convlstm(x)
        hidden = self.norm(hidden)
        return self.activation(hidden)


class BaseModel(tf.keras.Model):
    """Generic U-Net: ``depth`` encoder stages, a bottleneck, ``depth`` decoder stages, 1x1 head.

    Level ``i`` uses ``init_filters * m**i`` feature maps, where ``m`` is 2 for a
    standard U-Net and 1 for a "pocket" network (constant width). ``block`` is a
    layer factory called as ``block(filters, **kwargs)``.
    """

    def __init__(self,
                 block,
                 n_classes,
                 init_filters,
                 depth,
                 pocket,
                 **kwargs):
        super().__init__()

        # User defined inputs
        self.n_classes = n_classes
        self.init_filters = init_filters
        self.depth = depth
        self.pocket = pocket

        # A pocket network keeps the feature count constant after each downsampling
        self.mul_on_downsample = 1 if self.pocket else 2

        def width(level):
            return self.init_filters * self.mul_on_downsample ** level

        # Sub-layers are created encoder -> bottleneck -> decoder -> head, shallowest level first
        self.encoder = [EncoderBlock(width(level), block, **kwargs)
                        for level in range(self.depth)]
        self.bottleneck = Bottleneck(width(self.depth), block, **kwargs)
        self.decoder = [DecoderBlock(width(level), block, **kwargs)
                        for level in reversed(range(self.depth))]

        self.out = tf.keras.layers.Conv2D(n_classes, kernel_size=1)

    def call(self, x):
        skips = []
        for encoder_block in self.encoder:
            skip, x = encoder_block(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # Deepest skip pairs with the first decoder stage
        for skip, decoder_block in zip(reversed(skips), self.decoder):
            x = decoder_block(skip, x)

        return self.out(x)

    
# Layer configuration shared by every U-Net block (consumed by the get_* factories above).
conv_kwargs = dict(
    regularizer="l2",       # get_regularizer -> L2(1e-7) kernel penalty
    norm="identity",        # get_norm -> tf.identity (no normalization)
    activation="identity",  # get_activation -> tf.identity, so the network is linear
    alpha=0.01,             # not read by any factory visible here (get_activation ignores kwargs)
    down_type="maxpool",    # get_downsample -> 2x2 stride-2 max pooling
)


class UNetBlock(tf.keras.layers.Layer):
    """Two stacked ``ConvLayer``s with the same filter count (the U-Net "double conv")."""

    def __init__(self, filters, **kwargs):
        super().__init__()
        self.conv1 = ConvLayer(filters, **kwargs)
        self.conv2 = ConvLayer(filters, **kwargs)

    def call(self, x):
        return self.conv2(self.conv1(x))



class UNet(tf.keras.Model):
    """``BaseModel`` built from ``UNetBlock``s with the module-level ``conv_kwargs`` configuration."""

    def __init__(self, n_classes, init_filters, depth, pocket):
        super().__init__()
        self.base_model = BaseModel(
            UNetBlock, n_classes, init_filters, depth, pocket, **conv_kwargs
        )

    @tf.function
    def call(self, x, **kwargs):
        # Graph-compiled forward pass delegating to the wrapped BaseModel
        return self.base_model(x, **kwargs)

## Put it all together

* Instantiate loss function and network
* Define training step
* Begin unsupervised training loop

In [None]:
loss_fn = TimeDependentLoss(N, time_step, f)

u_previous = u0
_n_steps = 250       # optimizer iterations per time step (the first step gets 4x as many)
n_time_steps = 10    # number of backward-Euler time steps to march
model = UNet(1, 32, 3, True)

initial_learning_rate = 0.0001
lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=initial_learning_rate,
    first_decay_steps=_n_steps,
    t_mul=1.0,
    m_mul=1.0,
    alpha=0.0
)
# Pass the clipping norm to the constructor instead of mutating the attribute afterwards
optimizer = tf.keras.optimizers.Adam(lr_schedule, global_clipnorm=0.001)


@tf.function
def train_step(u_previous, f, t):
    """One optimizer update; returns the loss and the prediction made *before* the update.

    ``t`` is a Python float, so tf.function retraces once per distinct time value.
    ``f`` is unused here (``loss_fn`` holds its own source term) but kept for callers.
    """
    with tf.GradientTape() as tape:
        p = model(u_previous)
        loss = loss_fn((p, u_previous), t)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss, p


sol = list()
sol.append(u_previous.numpy().reshape(N, N))

for t in range(n_time_steps):
    best = np.inf  # np.Inf was removed in NumPy 2.0
    u_next = None
    # The first step starts from an untrained network, so it gets a longer budget
    n_steps = 4*_n_steps if t == 0 else _n_steps

    prog_bar = Progbar(n_steps, stateful_metrics=["time_step", "loss"])
    for step in range(n_steps):
        loss, u_candidate = train_step(u_previous, f, time_step*(1 + t))
        prog_bar.add(1, values=[("time_step", int(t + 1)), ("loss", loss)])

        # Keep the prediction with the lowest loss seen during this time step
        if loss < best:
            best = loss
            u_next = u_candidate

    # Every loss was NaN (NaN < inf is False): fail loudly instead of reusing a stale solution
    if u_next is None:
        raise FloatingPointError(f"No finite loss at time step {t + 1}; training diverged.")

    u_previous = u_next
    sol.append(u_next.numpy().reshape(N, N))

## Compute error and plot solution

In [None]:
step = 5
# Pointwise error between the exact solution at t = step * time_step and the prediction sol[step]
_error = u(time_step * step).numpy().reshape(N, N) - sol[step]
# Print the discrete L2 norm (h * ||e||_2), then the max-norm, in 4-digit scientific notation
for _norm in (h*np.sqrt(np.sum(np.square(_error))), np.max(np.abs(_error))):
    print("{}".format(np.format_float_scientific(_norm, 4)))

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(17, 8), sharey=True, sharex=True)
fig.subplots_adjust(wspace=0.1, hspace=0.1)
current_time = 0.0
for i, ax in enumerate(axs.flatten()):
    # Guard against fewer stored solutions than panels
    if i < len(sol):
        # Alternative panels: exact solution u(t), or |u(t) - p(t)|, instead of the prediction
        # im = ax.imshow(u(time_step * i).numpy().reshape(N, N), cmap="jet", vmin=0, vmax=1, extent=[0, 1, 0, 1])
        im = ax.imshow(sol[i], cmap="jet", vmin=0, vmax=1, extent=[0, 1, 0, 1])
        # im = ax.imshow(np.abs(u(time_step * i).numpy().reshape(N, N) - sol[i]), cmap="jet", extent=[0, 1, 0, 1])

        ax.set_title(f"t = {np.round(current_time, 4)}", fontsize=20)

        # Keep only the first and last tick on each axis
        xticks = ax.get_xticks()
        ax.set_xticks([xticks[0], xticks[-1]])
        ax.set_xticklabels([int(xticks[0]), int(xticks[-1])])

        yticks = ax.get_yticks()
        ax.set_yticks([yticks[0], yticks[-1]])
        ax.set_yticklabels([int(yticks[0]), int(yticks[-1])])

    current_time += time_step

cbar_ax = fig.add_axes([0.91, 0.15, 0.05, 0.7])
# Build the colorbar first, then enlarge its tick labels; relabelling the empty
# axes before fig.colorbar() has no lasting effect on the colorbar's own ticks.
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.tick_params(labelsize=20)

# fig.text(0.08, 0.5, r'True solution at time $t$, $u^t$', va='center', rotation='vertical', fontsize=20)
fig.text(0.08, 0.5, r'Prediction at time $t$, $p^t$', va='center', rotation='vertical', fontsize=20)
# fig.text(0.08, 0.5, r'Difference, $|u^t - p^t|$', va='center', rotation='vertical', fontsize=20)

plt.show()

In [None]:
def create_gif(array_list, filename='animation.gif'):
    """Save a sequence of 2-D arrays as an animated heat-map GIF.

    Parameters
    ----------
    array_list : sequence of 2-D arrays
        Frames to animate. The colour limits are fixed by autoscaling the first
        frame; later frames only replace the image data.
    filename : str
        Output path, written with Matplotlib's Pillow writer.

    Raises
    ------
    ValueError
        If ``array_list`` is empty.
    """
    if len(array_list) == 0:
        raise ValueError("array_list must contain at least one frame")

    fig, ax = plt.subplots()
    im = ax.imshow(array_list[0], cmap="jet")

    def update(num):
        im.set_data(array_list[num])
        return (im,)

    anim = animation.FuncAnimation(fig, update, frames=len(array_list), repeat=True)
    try:
        anim.save(filename, writer='pillow')
    finally:
        # Close the figure so it does not leak or get rendered as an extra static plot
        plt.close(fig)

In [None]:
# Write the predicted solution frames to a 2-D heat-map GIF
create_gif(array_list=sol, filename="animation.gif")

In [None]:
# Display the saved GIF inline (last expression of the cell)
Image(filename="animation.gif")

In [None]:
def create_surface_gif(array_list, filename='animation.gif', cmap='jet', zlim=(-1.1, 1.1)):
    """Save a sequence of 2-D arrays as an animated 3-D surface GIF over [0, 1]^2.

    Parameters
    ----------
    array_list : sequence of 2-D arrays
        Surface heights per frame; all frames must share the first frame's shape.
    filename : str
        Output path, written with Matplotlib's Pillow writer at 3 fps.
    cmap : str
        Colormap for the surface.
    zlim : tuple of float
        Vertical axis limits applied to every frame.

    Raises
    ------
    ValueError
        If ``array_list`` is empty.
    """
    if len(array_list) == 0:
        raise ValueError("array_list must contain at least one frame")

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # meshgrid(x, y) has shape (len(y), len(x)), matching a (rows, cols) frame
    n_rows, n_cols = array_list[0].shape[:2]
    X, Y = np.meshgrid(np.linspace(0, 1, n_cols), np.linspace(0, 1, n_rows))

    def update(num):
        # ax.clear() also resets the axis limits, so reapply them on every frame
        ax.clear()
        surf = ax.plot_surface(X, Y, array_list[num], cmap=cmap, rstride=1, cstride=1, linewidth=0, antialiased=False)
        ax.set_zlim(*zlim)
        return (surf,)

    # No blitting: every frame clears and redraws the whole 3-D axes
    anim = animation.FuncAnimation(fig, update, frames=len(array_list), interval=100, blit=False, repeat=True)
    try:
        anim.save(filename, writer='pillow', fps=3)
    finally:
        # Close the figure so it does not leak or get rendered as an extra static plot
        plt.close(fig)

In [None]:
# Write the predicted solution frames to a 3-D surface GIF
create_surface_gif(array_list=sol, filename="animation_3d.gif")

In [None]:
# Display the saved 3-D GIF inline (last expression of the cell)
Image(filename="animation_3d.gif")