# Map a function with adaptive sampling and segmentation

This is a more advanced approach to learning a function with adaptive sampling.


The learning domain is subdivided into a number of regions (4 in the current case) and the function is sampled adaptively in each region.
The regions are then combined into a single learner.
Each region is sampled until it reaches a predefined loss goal. The combined learner is finally sampled until it reaches the same loss goal.

The learners at each step are saved and can be reloaded to continue the learning process.

In [None]:
import os
from copy import deepcopy

import adaptive
import matplotlib.pyplot as plt
from backend import adaptive_tools
from backend.processclass2 import Experiment1D_Dimensionless

%matplotlib notebook
adaptive.notebook_extension()


## 1. Define the function to be learned

The function to learn maps the input parameters `p_o` and `tau_r` to `r_max_n`, which is obtained by solving the RDE equation of the Continuum model.
Here we use the `Experiment1D_Dimensionless` class that represents the dimensionless version of the RDE equation.

First, we need to set up the RDE solver with necessary parameters:

In [None]:
# Configure the module-level solver instance; `rde_r_max` deep-copies it for each evaluation.
# NOTE(review): the meaning and units of these settings are defined by
# Experiment1D_Dimensionless in backend.processclass2 — confirm there.
pr = Experiment1D_Dimensionless()
pr.beam_type = 'super_gauss'
pr.f0 = 1e6
pr.fwhm = 500
pr.step = 2
pr.order = 1

Then, we can define a simple function that receives the input parameters, sets them in the RDE solver and returns the result:

In [None]:
def rde_r_max(xy):
    """Solve the steady state for one parameter pair and return `r_max_n`.

    Parameters:
        xy: the sampled point as a pair `(p_o, tau_r)`.

    Returns:
        The solver's `r_max_n` attribute after `solve_steady_state()`.
    """
    p_o, tau_r = xy
    # Work on a private copy of the configured module-level solver `pr`.
    # The global is neither rebound nor mutated, so every evaluation starts
    # from the same configuration regardless of which points came before.
    solver = deepcopy(pr)
    solver.p_o = p_o
    solver.tau_r = tau_r
    solver.solve_steady_state()
    return solver.r_max_n

## 2. Define the segmentation of the function domain

Next, we define the domain of the function to be learned and divide it into regions.
Each region is defined by its x-range and y-range, i.e. `((x_min, x_max), (y_min, y_max))`.

The domain subdivision is based on the preliminary coarse sampling of the function.
In the current case, the domain is divided into 4 regions:

   * bottom-left segment (segment 3) contains a corner region with an abrupt edge,
   * top-left segment (segment 1) contains a thin long region with an abrupt edge,
   * bottom-right segment (segment 2) contains a gradual slope,
   * top-right segment (segment 4) contains a rather flat region.

##### Zoom in to see the regions better

In [None]:
# Define the bounds for each segment as ((x_min, x_max), (y_min, y_max)),
# where x is `p_o` and y is `tau_r` (the order in which `rde_r_max` unpacks a point).
segments = [
    ((0.23, 0.3), (1000, 10000)),
    ((0.4, 20), (1.0001, 1000)),
    ((0.28, 0.4), (1.0001, 1000)),
    ((0.3, 20), (1000, 10000))
]

In [None]:
# Load the saved preliminary (coarse) learner; it is used as the background of the segmentation plot.
prelim_fname = r'examples/r_max_interp_1.0.int'
preliminary_map = adaptive_tools.learner_load_full(prelim_fname)

In [None]:
# Overlay the segments on top of the preliminary coarse map
plt.figure(figsize=(10, 8))

# One translucent rectangle per segment, coloured in segment order
colors = ['red', 'green', 'blue', 'orange']
for number, (x_range, y_range) in enumerate(segments, start=1):
    x_lo, x_hi = x_range
    y_lo, y_hi = y_range
    plt.fill_betweenx(
        [y_lo, y_hi], x_lo, x_hi,
        color=colors[number - 1], alpha=0.3, label=f'Segment {number}',
    )

# Interpolate the preliminary learner on a regular grid for the background image
X, Y, Z = preliminary_map.interpolated_on_grid(5000)
Z = Z.T  # transposed before being handed to imshow
extent = [X.min(), X.max(), Y.min(), Y.max()]
plt.imshow(Z, extent=extent, origin='lower', aspect='auto', cmap='viridis')

# Axis labels, title and legend
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Segmentation of the learned function domain')
plt.legend()

plt.show()

#### Plotting function to visualize the learning process

In [None]:
def plot(learner, npoints=300, tri_alpha=0.2, width=300, height=300, xlim=None, ylim=None):
    """Render `learner.plot(...)` with a fixed size and optional axis limits.

    `npoints` and `tri_alpha` are forwarded to `learner.plot`. `xlim` and
    `ylim` are applied only when given; `width` and `height` are always
    applied. Returns the object produced by `learner.plot`.
    """
    figure = learner.plot(npoints, tri_alpha=tri_alpha)
    # Axis limits are optional; skip the ones left at None.
    for option, limits in (('xlim', xlim), ('ylim', ylim)):
        if limits is not None:
            figure.opts(**{option: limits})
    figure.opts(width=width, height=height)
    return figure

## 3. Learn the function

### 3.1 Set up file names to save the learners
Associate each segment with the file name under which its learner is saved.

In [None]:
segments

In [None]:
# Save files for the per-segment learners, in the same order as `segments`
# (fname1 goes with segments[0], ..., fname4 with segments[3]).
fname1 = 'r_max_interp_1.0_vertical.int'
fname2 = 'r_max_interp_1.0_horizontal.int'
fname3 = 'r_max_interp_1.0_corner.int'
fname4 = 'r_max_interp_1.0_rest.int'
# Final combined domain
fname5 = 'maps_temp/r_max_interp_1.0_full.int'

### 3.2 Learn the function in each segment separately

#### Segment 1

In [None]:
# Resume from the saved state when it exists, otherwise start a fresh learner
if not os.path.isfile(fname1):
    learner1 = adaptive.Learner2D(rde_r_max, bounds=segments[0])
else:
    learner1 = adaptive_tools.learner_load_full(fname1)
    learner1.function = rde_r_max

# Sample until the loss goal is reached, saving the learner every 60 seconds
runner1 = adaptive.Runner(learner1, ntasks=2, loss_goal=0.001)
runner1.live_info()
runner1.start_periodic_saving(interval=60, save_kwargs={'fname': fname1})

In [None]:
# runner1.live_plot(plotter=plot, update_interval=0.5)

In [None]:
learner1.to_numpy()

In [None]:
plot(learner1, npoints=1500, tri_alpha=0.2, width=800, height=800)

#### Segment 2

In [None]:
# Resume from the saved state when it exists, otherwise start a fresh learner
if not os.path.isfile(fname2):
    learner2 = adaptive.Learner2D(rde_r_max, bounds=segments[1])
else:
    learner2 = adaptive_tools.learner_load_full(fname2)
    learner2.function = rde_r_max

# Sample until the loss goal is reached, saving the learner every 60 seconds
runner2 = adaptive.Runner(learner2, ntasks=2, loss_goal=0.001)
runner2.live_info()
runner2.start_periodic_saving(interval=60, save_kwargs={'fname': fname2})

In [None]:
# runner2.live_plot(plotter=plot, update_interval=0.5)

In [None]:
plot(learner2, npoints=1500, tri_alpha=0.2, width=800, height=800)

#### Segment 3

In [None]:
# Resume from the saved state when it exists, otherwise start a fresh learner
if not os.path.isfile(fname3):
    learner3 = adaptive.Learner2D(rde_r_max, bounds=segments[2])
else:
    learner3 = adaptive_tools.learner_load_full(fname3)
    learner3.function = rde_r_max

# Sample until the loss goal is reached, saving the learner every 60 seconds
runner3 = adaptive.Runner(learner3, ntasks=2, loss_goal=0.001)
runner3.live_info()
runner3.start_periodic_saving(interval=60, save_kwargs={'fname': fname3})

In [None]:
# runner3.live_plot(plotter=plot, update_interval=0.5)

In [None]:
plot(learner3, npoints=1500, tri_alpha=0.2, width=800, height=800)

#### Segment 4

In [None]:
# Resume from the saved state when it exists, otherwise start a fresh learner
if not os.path.isfile(fname4):
    learner4 = adaptive.Learner2D(rde_r_max, bounds=segments[3])
else:
    learner4 = adaptive_tools.learner_load_full(fname4)
    learner4.function = rde_r_max

# Sample until the loss goal is reached (with logging enabled), saving every 60 seconds
runner4 = adaptive.Runner(learner4, ntasks=2, loss_goal=0.001, log=True)
runner4.live_info()
runner4.start_periodic_saving(interval=60, save_kwargs={'fname': fname4})

In [None]:
runner4.live_plot(plotter=plot, update_interval=0.5)

In [None]:
plot(learner4, npoints=500, tri_alpha=0.2, width=800, height=800)

### 3.3 Combine the learners into a single domain

In [None]:
# Merge the four segment learners two at a time; `learner1234` holds the result of the last merge.
# NOTE(review): the merge semantics live in backend.adaptive_tools.combine_learners — confirm there.
learner12 = adaptive_tools.combine_learners(learner1, learner2)
learner123 = adaptive_tools.combine_learners(learner12, learner3)
learner1234 = adaptive_tools.combine_learners(learner123, learner4)

As you can see, the segments combined nicely and correctly constitute the learned domain:

In [None]:
plot(learner1234, npoints=2000, tri_alpha=0.2, width=800, height=800)

In [None]:
adaptive_tools.plot_learner(learner1234, n_points=7000, dpi=100)

### 3.4 Learn the function in the combined domain

Some regions were deliberately excluded to focus the learners on the features of the function.
In this step the learning is finalized by sampling the function in the whole domain.

In [None]:
whole_domain = ((0, 20), (1.0001, 10000))

In [None]:
# Re-bound the combined learner to the whole domain and attach the function to it.
learner_full = adaptive_tools.learner_rebound(learner1234, bounds=whole_domain)
learner_full.function = rde_r_max
# Round-trip the sampled data through a DataFrame into a freshly constructed
# Learner2D over the same bounds.
# NOTE(review): presumably done so the learner's internal state is rebuilt
# for the new bounds — confirm against adaptive_tools.learner_rebound.
df = learner_full.to_dataframe()
learner_full = adaptive.Learner2D(rde_r_max, bounds=whole_domain)
learner_full.load_dataframe(df)


In [None]:
learner_full.to_numpy().shape

In [None]:
learner4.data

In [None]:
# Resume the full-domain learner from disk when a saved state exists;
# otherwise keep the `learner_full` already in memory.
if os.path.isfile(fname5):
    learner_full = adaptive_tools.learner_load_full(fname5)
# Attach the function once; this covers both the loaded and the in-memory learner.
learner_full.function = rde_r_max

# Sample the whole domain until the loss goal is reached, saving every 60 seconds
runner5 = adaptive.Runner(learner_full, loss_goal=0.001, ntasks=4, log=True)
runner5.live_info()
runner5.start_periodic_saving(save_kwargs=dict(fname=fname5), interval=60)

In [None]:
plot(learner_full, 1000, tri_alpha=0.2, width=800, height=800)

In [None]:
plot(learner_full, 1000, tri_alpha=0.2, width=800, height=800)

In [None]:
adaptive_tools.plot_learner(learner_full, n_points=9000, dpi=100)