## 2D MENT reconstruction from 1D projections: an attempt to replicate Liwen's work

In [None]:
import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import scipy.interpolate
from tqdm.notebook import tqdm

import mentflow as mf

In [None]:
# Global proplot settings applied to every figure in this notebook.
for _rc_key, _rc_value in {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}.items():
    pplt.rc[_rc_key] = _rc_value

Generate the measurement data: 1D histograms of a 2D toy distribution, viewed at `n_meas` evenly spaced rotation angles.

In [None]:
def rotation_matrix(angle):
    """Return the 2x2 matrix [[cos(angle), sin(angle)], [-sin(angle), cos(angle)]].

    `angle` is in radians. Applied as ``matrix @ point`` it rotates the point
    clockwise by `angle`.
    """
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, s], [-s, c]])

In [None]:
# Ground-truth toy distribution and samples drawn from it.
data_name = "circles"
data_size = 1_000_000
data_noise = None
xmax = 3.0
n_bins = 64
n_meas = 12

dist = mf.data.toy.gen_dist(data_name, noise=data_noise)
x0 = dist.sample(data_size)

# One rotation per measurement, evenly spaced over [0, pi).
angles = np.linspace(0.0, np.pi, n_meas, endpoint=False)
transfer_matrices = [rotation_matrix(angle) for angle in angles]

# Each measurement is the density-normalized histogram of the first
# coordinate of the transformed samples, on a shared grid over [-xmax, xmax].
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)
measurements = []
for matrix in transfer_matrices:
    projected = np.matmul(x0, matrix.T)[:, 0]
    measurements.append(np.histogram(projected, bins=bin_edges, density=True)[0])

In [None]:
n_iterations = 10

# Short aliases in MENT notation: G holds the measured projections and A the
# transfer matrix of each projection.
G = measurements
A = transfer_matrices

n_points = n_bins  # samples per projection (one per bin)
n_constraints = n_meas  # number of projections
V = bin_edges[-1] - bin_edges[0]  # total length of the measurement axis
deltaV = V / n_points  # bin width

# Bin centers: the points at which every h function is sampled.
bin_coords = (bin_edges[:-1] + bin_edges[1:]) / 2

In [None]:
def interpolate_1d(coords=None, values=None, x=0.0):
    """Linearly interpolate the samples (`coords`, `values`) at `x`.

    Points outside [min(coords), max(coords)] evaluate to 0.0. `x` may be a
    scalar or an array; the result has the same shape as `x`.

    Uses ``np.interp`` rather than constructing a ``scipy.interpolate.interp1d``
    object on every call. This function is called inside tight Python loops,
    where building a new interpolator each time is comparatively expensive.
    """
    coords = np.asarray(coords, dtype=float)
    values = np.asarray(values, dtype=float)
    # np.interp requires increasing sample points; interp1d sorted them itself.
    order = np.argsort(coords, kind="stable")
    return np.interp(x, coords[order], values[order], left=0.0, right=0.0)

In [None]:
# Sanity check of interpolate_1d: interpolate a coarsely sampled Gaussian onto
# a finer grid and overlay the original samples.
coords = np.linspace(-3.0, 3.0, 10)
values = np.exp(-(coords ** 2))

coords_int = np.linspace(coords[0], coords[-1], 100)
values_int = [interpolate_1d(coords, values, point) for point in coords_int]

fig, ax = pplt.subplots(figsize=(3, 2))
ax.plot(coords_int, values_int, color="gray")
ax.plot(coords, values, lw=0, marker=".", color="black")
plt.show()

In [None]:
def rotate(u_i, T, v_i_coords=None, h_j_values=None, h_j_coords=None):
    """
    Evaluate h_j along a vertical line u = u_i of the i-th coordinate system.

    For every integration point v_i in `v_i_coords`, the point (u_i, v_i) is
    mapped into the j-th coordinate system with `T`. h_j is then evaluated at
    the resulting u_j by linear interpolation of (`h_j_coords`, `h_j_values`).
    Points outside the sampled range evaluate to 0.

    Parameters
    ----------
    u_i : float
        Fixed coordinate on the i-th measurement axis.
    T : (2, 2) array-like
        Matrix mapping i-frame coordinates to j-frame coordinates. Only its
        first row is used, since only u_j = T[0] . (u_i, v_i) is needed.
    v_i_coords : 1D array-like
        Integration points along the v_i axis.
    h_j_values, h_j_coords : 1D array-like
        Samples of h_j and the coordinates at which they were taken.

    Returns
    -------
    ndarray of float with one entry per element of `v_i_coords`.
    """
    v_i_coords = np.asarray(v_i_coords, dtype=float)
    # Vectorized over all integration points: one interpolation call instead
    # of one per point.
    u_j = T[0][0] * u_i + T[0][1] * v_i_coords
    return np.asarray(interpolate_1d(h_j_coords, h_j_values, u_j), dtype=float)

In [None]:
# Matrices connecting the coordinate systems: T[i][j] = A[i] @ inv(A[j]) maps
# coordinates in the j-th measurement frame to the i-th frame.
T = np.zeros((n_constraints, n_constraints, 2, 2))
for i in range(n_constraints):
    for j in range(n_constraints):
        T[i][j] = np.matmul(A[i], np.linalg.inv(A[j]))

# Initialize each component (h) function to a uniform, normalized array.
H = []
for i in range(n_constraints):
    h = np.ones(len(G[i]))
    h = h / np.sum(h)
    H.append(h)

# Every measurement uses the same grid, so the integration points along v_i
# and the sample points of every h_k are the bin centers (loop-invariant).
v_i_coords = bin_coords
h_k_coords = bin_coords

# Outer iterations. Each h_i is updated in turn from the current values of
# all other h_k (the update of h_i never reads h_i itself):
#     h_i(u_i) = g_i(u_i) / sum_{v_i} prod_{k != i} h_k(u_k(u_i, v_i))
for iteration in range(n_iterations):
    for i in range(n_constraints):
        print(f"iter={iteration:03.0f} proj={i:02.0f}")

        for j in range(n_points):
            # Location on the i-th measurement axis (same grid for all projections).
            u_i = bin_coords[j]

            # Product of the other component functions along the line u = u_i,
            # sampled at the integration points v_i.
            product = np.ones(n_points)
            for k in range(n_constraints):
                if k != i:
                    product *= rotate(u_i, T[k][i], v_i_coords, H[k], h_k_coords)

            # Integrate (sum) over the v_i axis. The constant bin width is
            # omitted because it cancels when H[i] is normalized below.
            integral = np.sum(product)

            # Update the component (h) function.
            if G[i][j] == 0.0:
                H[i][j] = 0.0
            else:
                if integral == 0.0:
                    raise ValueError(
                        "A singular point occurred when calculating the h function "
                        f"on iteration {iteration + 1}, direction {i + 1}: "
                        f"the integral is zero at u = {u_i}."
                    )
                H[i][j] = G[i][j] / integral

            # Stop on overflow or NaN. NOTE: no fallback to an earlier, better
            # iterate is implemented; the run simply stops here.
            if not np.isfinite(H[i][j]) or H[i][j] > 1.00e+308:
                raise ValueError(
                    f"The h function exploded on iteration {iteration + 1}, "
                    f"direction {i + 1}: H[i, j] = {H[i][j]}, "
                    f"G[i, j] = {G[i][j]}, integral = {integral}."
                )

        # Normalize h_i; an all-zero h_i (e.g. an all-zero measurement) cannot be.
        norm = np.sum(H[i])
        if norm == 0.0:
            raise ValueError(
                f"The h function for direction {i + 1} is identically zero "
                f"on iteration {iteration + 1}; cannot normalize."
            )
        H[i] = H[i] / norm

In [None]:
def evaluate_prob(x, y, prior=None):
    """Return the (unnormalized) MENT density at the point (x, y).

    The density is the product over constraints of h_i(u_i). Here u_i is the
    first component of A[i] @ [x, y], and h_i is linearly interpolated from
    H[i] on the bin centers. `prior` is accepted but currently ignored.
    """
    density = 1.0
    for index in range(n_constraints):
        u_coord = np.matmul(A[index], [x, y])[0]
        density *= interpolate_1d(bin_coords, H[index], u_coord)
    return density

In [None]:
# Evaluate the reconstructed density on a 150 x 150 grid over [-3, 3]^2 and
# normalize it to unit sum.
shape = (150, 150)
coords = [np.linspace(-3.0, 3.0, s) for s in shape]
prob = np.zeros(shape)
for i, x in enumerate(coords[0]):
    for j, y in enumerate(coords[1]):
        prob[i, j] = evaluate_prob(x, y)
prob /= np.sum(prob)

# prob is indexed [x, y]; pcolormesh expects [y, x], hence the transpose.
fig, ax = pplt.subplots()
ax.pcolormesh(coords[0], coords[1], prob.T)
plt.show()