In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import torch 

# Seed every random number generator used in the notebook so runs are reproducible.
seed = 2023
random.seed(seed)  # Python's built-in `random` module
torch.manual_seed(seed)  # PyTorch RNG
np.random.seed(seed)  # NumPy's legacy global RNG
# NOTE: Python reads PYTHONHASHSEED only at interpreter start-up, so setting it
# here does not change hash randomisation in the running kernel; it only affects
# child processes started afterwards (they inherit os.environ).
os.environ["PYTHONHASHSEED"] = str(seed)
import dolfin as df
import time
from utils import *
from utils_training import *
import prepare_data
from utils_compare_methods import *
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from prepare_data import rotate, outside_ball
import pandas as pd
import gc

# Plot styling: seaborn's default theme with the compact "paper" context.
sns.set_theme()
sns.set_context("paper")
# Shared colour resources: a list of hex colours from the "mako" palette and a
# continuous viridis colormap object (not used in the visible cells; presumably
# consumed by later cells of the notebook).
colors = sns.color_palette("mako").as_hex()
my_cmap = sns.color_palette("viridis", as_cmap=True)

# Automatically reload edited project modules before each cell is executed.
%load_ext autoreload
%autoreload 2


In [None]:
def border_torch(domains):
    """
    Determine the pixels on the (inner) boundary of multiple domains using PyTorch.

    A pixel is a boundary pixel when it lies inside a domain (non-zero value)
    and at least one of its 8 neighbours (horizontal, vertical and diagonal)
    holds a different value. The image border is treated as outside (the input
    is zero-padded), so inside pixels touching the image edge are boundary
    pixels too.

    Parameters:
        domains (tensor): Tensor defining the domains (0 outside, non-zero
            inside) of shape (..., H, W). The last two dimensions are the
            image; any leading dimensions (e.g. a batch axis) are processed
            independently.

    Returns:
        tensor: int32 tensor with the same shape and device as ``domains``,
        with 1 at boundary pixels and 0 otherwise.
    """
    # Pad only the two spatial dimensions so every pixel has a full 3x3
    # neighbourhood; out-of-image neighbours count as "outside" (value 0).
    pad_domains = torch.nn.functional.pad(
        domains, pad=(1, 1, 1, 1), mode="constant", value=0
    )
    height, width = domains.shape[-2], domains.shape[-1]
    center = pad_domains[..., 1:-1, 1:-1]

    # True wherever the pixel differs from at least one of its 8 neighbours.
    differs = torch.zeros_like(center, dtype=torch.bool)
    for dy in (0, 1, 2):
        for dx in (0, 1, 2):
            if dy == 1 and dx == 1:
                continue  # skip the pixel itself
            neighbour = pad_domains[..., dy : dy + height, dx : dx + width]
            differs |= center != neighbour

    # Only pixels inside a domain can be boundary pixels. Building this mask
    # from ``domains`` itself keeps every intermediate on the input's device.
    return (differs & (domains != 0)).int()


def compute_boundaries(phi, level, threshold=3e-16):
    """
    Build the domain mask defined by ``phi`` and, optionally, its eroded copies.

    The domain is the set of pixels where ``phi <= threshold``. Each further
    level removes the boundary pixels (as flagged by ``border_torch``) from the
    previous mask, i.e. peels off one layer of pixels.

    Parameters:
        phi (tensor or array-like): level-set values. Converted with
            ``torch.as_tensor`` (no copy when already a tensor). For
            ``level >= 1`` it must have a shape accepted by ``border_torch``.
        level (int): 0 -> return ``domain``; 1 -> return ``(domain, domain_1)``;
            any other value (including levels > 2) -> return
            ``(domain, domain_1, domain_2)``.
        threshold (float): pixels with ``phi <= threshold`` are inside the
            domain. Defaults to 3e-16, the value previously hard-coded.

    Returns:
        bool tensor, or a tuple of 2 or 3 bool tensors, on ``phi``'s device.
    """
    phi = torch.as_tensor(phi)
    domain = phi <= threshold
    if level == 0:
        return domain

    def _erode(mask):
        # Remove the boundary layer of ``mask``. border_torch only flags
        # pixels inside the mask, so this is a plain set difference.
        boundary = border_torch(mask).to(mask.device)
        return mask & (boundary == 0)

    domain_1 = _erode(domain)
    if level == 1:
        return domain, domain_1
    domain_2 = _erode(domain_1)
    return domain, domain_1, domain_2

In [None]:
# Generate sample data with the project helper ``create_FG_numpy`` (star-imported
# at the top; its defining module is not visible here).
# NOTE(review): 10 and 64 look like the number of samples and the grid
# resolution — confirm against create_FG_numpy's signature.
F, phi, G, params = create_FG_numpy(10, 64)

# Copy phi into a torch tensor and derive the domain mask plus two successively
# eroded masks (level=2 returns all three).
domain, domain_1, domain_2 = compute_boundaries(torch.tensor(phi), 2)

In [None]:
# Visualise one sample: the level-set function phi next to the nested masks.
# domain ⊇ domain_1 ⊇ domain_2, so their sum is 0 outside the domain, 1 on the
# outermost layer, 2 on the next layer and 3 in the remaining interior.
SAMPLE_IDX = 4  # arbitrary sample of the generated batch
FIGURE_PATH = "./images/masks.pdf"

fig, (ax_phi, ax_masks) = plt.subplots(1, 2, figsize=(8, 4))

im_phi = ax_phi.imshow(phi[SAMPLE_IDX, :, :], origin="lower", cmap="viridis")
fig.colorbar(im_phi, ax=ax_phi, shrink=0.75)
ax_phi.grid(False)
ax_phi.set_title(r"$\phi$", fontsize=15)

mask_levels = (
    domain[SAMPLE_IDX, :, :].int()
    + domain_1[SAMPLE_IDX, :, :].int()
    + domain_2[SAMPLE_IDX, :, :].int()
)
im_masks = ax_masks.imshow(mask_levels, origin="lower", cmap="viridis")
fig.colorbar(im_masks, ax=ax_masks, shrink=0.75)
ax_masks.grid(False)
ax_masks.set_title(r"$\mathcal{S}_0$, $\mathcal{S}_1$ and $\mathcal{S}_2$", fontsize=15)

fig.tight_layout()
# Create the output directory first so the cell also works on a fresh checkout.
os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)
fig.savefig(FIGURE_PATH)
plt.show()