In [None]:
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.cm import ScalarMappable

import seaborn as sns
import matplotlib.pyplot as plt

from plot_setting import global_setting
# Apply the shared plotting configuration from the local plot_setting module
# (presumably matplotlib rcParams / style defaults -- see plot_setting to confirm).
global_setting()

# Bohachevsky Function

In [None]:
def f_old(x, y):
    """Negated Bohachevsky function.

    Evaluates -(x^2 + 2*y^2 - 0.3*cos(3*pi*x) - 0.4*cos(4*pi*y) + 0.7)
    elementwise on torch tensors; x and y broadcast against each other.
    """
    # Keep the original left-to-right evaluation order so float rounding is unchanged.
    bowl = x**2 + 2*y**2
    bohachevsky = bowl - 0.3*torch.cos(3*torch.pi*x) - 0.4*torch.cos(4*torch.pi*y) + 0.7
    return -bohachevsky

In [None]:
function_name = "Bohachevsky"
xlim=(-40, 40)
ylim=(-40, 40)
n_x = 1001
n_y = 1001

# Dense grid for bias correction: the minimum of f_old over a 10001 x 10001 grid
# is used to shift f so it is non-negative, and the maximum gives the L-inf norm.
# The grid is evaluated in row chunks: materialising the full 10^8-point mesh
# plus every intermediate of f_old needs several GB of RAM.
n_dense = 10001
dense_chunk_rows = 500
x_dense = torch.linspace(xlim[0], xlim[1], n_dense)
y_dense = torch.linspace(ylim[0], ylim[1], n_dense)

min_val = torch.tensor(float("inf"))
max_val = torch.tensor(float("-inf"))
for x_rows in torch.split(x_dense, dense_chunk_rows):
    # (rows, 1) broadcast against (1, n_dense) yields exactly the values of the
    # matching rows of meshgrid(x_dense, y_dense, indexing="ij").
    z_rows = f_old(x=x_rows[:, None], y=y_dense[None, :])
    min_val = torch.minimum(min_val, z_rows.min())
    max_val = torch.maximum(max_val, z_rows.max())

def f(x, y):
    """f_old shifted by the dense-grid minimum `min_val`, so f >= 0 on that grid."""
    fval = f_old(x, y)
    return fval - min_val

l_inf = max_val - min_val
print(f"Maximum_Value (L-inf): {l_inf}")

# Coarser grid used for all surface/contour plots.
x = torch.linspace(xlim[0], xlim[1], n_x)
y = torch.linspace(ylim[0], ylim[1], n_y)

# indexing="ij" is torch's current default; passing it explicitly silences the
# deprecation warning without changing the grid.
XX, YY = torch.meshgrid(x, y, indexing="ij")
ZZ = f(x=XX, y=YY)

In [None]:
from mpl_toolkits import mplot3d  # registers the "3d" projection on older matplotlib

# 3-D surface of the shifted objective f on the plotting grid.
fig = plt.figure(figsize=(6, 2.8))
ax = fig.add_subplot(projection="3d")
surf = ax.plot_surface(XX.numpy(), YY.numpy(), ZZ.numpy(), cmap="coolwarm",
                       rstride=20, cstride=20, linewidth=0.25, edgecolor="k", antialiased=True)
fig.colorbar(surf, shrink=0.5, aspect=12, pad=0.1)
ax.set_xlabel(r"X")
ax.set_ylabel(r"Y")
ax.set_zlabel(r"F(x, y)")
fig.tight_layout()

# savefig does not create missing directories; make sure Figures/ exists.
fig_dir = Path("Figures")
fig_dir.mkdir(exist_ok=True)
fig.savefig(fig_dir / f"{function_name}_surface_plot.jpg", dpi=300, bbox_inches="tight")
plt.show()

# Computing Norms

In [None]:
def Lp_norm(residual, p=2):
    """Discrete L^p norm of a residual tensor: (mean |residual|^p)^(1/p).

    Args:
        residual: tensor of pointwise residuals; all elements are pooled.
        p: norm order, a positive number or float("inf") (max |residual|).

    Returns:
        0-d tensor with the norm.
    """
    # abs() makes the result a true norm for odd p and negative residuals;
    # for non-negative residuals and even p it is identical to residual**p.
    magnitude = residual.abs()
    if p == float("inf"):
        return magnitude.max()
    return (magnitude**p).mean() ** (1 / p)

# Dense Monte Carlo sample of the domain, used to estimate the L^p norms of f.
# Distinct names (not N_s / x_s / y_s / residuals) so re-running this cell cannot
# clobber the evolutionary-sampling population, and a private seeded generator
# so the estimate is reproducible without disturbing the global RNG stream.
n_mc = 10_000_000
mc_gen = torch.Generator().manual_seed(0)
x_mc = torch.empty(n_mc, 1).uniform_(xlim[0], xlim[1], generator=mc_gen)
y_mc = torch.empty(n_mc, 1).uniform_(ylim[0], ylim[1], generator=mc_gen)

residuals_mc = f(x = x_mc, y = y_mc)
l2_norm_val = Lp_norm(residuals_mc, p=2)
l4_norm_val = Lp_norm(residuals_mc, p=4)
l6_norm_val = Lp_norm(residuals_mc, p=6)

# The 10M-point tensors are not needed afterwards; release their memory.
del x_mc, y_mc, residuals_mc

In [None]:
# Report the reference norms of f (grid-based L-inf, Monte Carlo L2/L4/L6).
for norm_label, norm_value in (
    ("Maximum_Value (L-inf)", l_inf),
    ("L2", l2_norm_val),
    ("L4", l4_norm_val),
    ("L6", l6_norm_val),
):
    print(f"{norm_label}: {norm_value}")

# Evolutionary Sampling

In [None]:
# Initial population for evolutionary sampling: N_s points drawn uniformly over
# the domain. Seeding the global RNG here makes this population, and every
# resampling step of the following cell, reproducible.
torch.manual_seed(0)
N_s = 500
x_s = torch.zeros(N_s, 1).uniform_(xlim[0], xlim[1])
y_s = torch.zeros(N_s, 1).uniform_(ylim[0], ylim[1])

In [None]:
# Evolutionary sampling. Each epoch keeps the points whose fitness (the shifted
# objective f) is above the population mean and replaces the rest with fresh
# uniform samples. Epoch 0 is the initial population; epoch k is the population
# produced by the k-th update.
plot_iters = [1, 5, 10, 20, 100, 1000, 5000]  # epochs whose population is drawn

epochs = 20000
count = 0  # number of populations drawn after epoch 0

resampled_pop_norm_epoch = []
retained_pop_norm_epoch = []
size_retained = []

residuals = f(x = x_s, y = y_s)
l2_norm_init_pop = Lp_norm(residuals, p=2)
resampled_pop_norm_epoch.append(l2_norm_init_pop)
retained_pop_norm_epoch.append(l2_norm_init_pop)

# Keep the initial population for the epoch-0 panel (the loop rebinds x_s / y_s).
x_init, y_init = x_s, y_s

# Populations to draw, keyed by epoch; collected during the loop, plotted after it.
snapshot_epochs = set(plot_iters)
snapshots = {}

for epoch in range(1, epochs + 1):
    fitness = f(x = x_s, y = y_s)
    mask = fitness > fitness.mean()
    x_old = x_s[mask].unsqueeze(1)
    y_old = y_s[mask].unsqueeze(1)

    x_new = torch.zeros(N_s-len(x_old), 1).uniform_(xlim[0], xlim[1])
    y_new = torch.zeros(N_s-len(y_old), 1).uniform_(ylim[0], ylim[1])

    x_s = torch.concat([x_old, x_new], dim=0)
    y_s = torch.concat([y_old, y_new], dim=0)

    # `fitness` already holds f at the retained points; no need to re-evaluate.
    retained_pop_norm_epoch.append(Lp_norm(fitness[mask], p=2))
    resampled_pop_norm_epoch.append(Lp_norm(f(x = x_new, y = y_new), p=2))
    size_retained.append(len(x_old))

    if epoch in snapshot_epochs:
        snapshots[epoch] = (x_old, y_old, x_new, y_new)


def _draw_panel(ax, title):
    """Overlay contour lines of ZZ, a matching colorbar and `title` on `ax`."""
    qcs = ax.contour(XX, YY, ZZ, levels=20, cmap='RdGy_r', alpha=0.4)
    ax.figure.colorbar(ScalarMappable(norm=qcs.norm, cmap=qcs.cmap), ax=ax)
    ax.set_title(title, fontsize=15)


# One panel for epoch 0 plus one per entry of plot_iters, in 2 rows
# (ceil division so an odd panel count still fits).
nplots = len(plot_iters) + 1
ncols = (nplots + 1) // 2
fig, axes = plt.subplots(2, ncols, figsize=(ncols*4.2, 2 * 3.5), squeeze=False)
axes_flat = axes.ravel()

ax = axes_flat[0]
ax.scatter(x_init, y_init, s=20, c='b', marker='o', alpha=0.7)
_draw_panel(ax, "Epoch: 0")

for panel, epoch in enumerate(plot_iters, start=1):
    ax = axes_flat[panel]
    if epoch not in snapshots:  # epoch outside 1..epochs, never reached
        ax.set_visible(False)
        continue
    x_old_k, y_old_k, x_new_k, y_new_k = snapshots[epoch]
    ax.scatter(x_old_k, y_old_k, s=20, c='b', marker='o', alpha=0.7, label="Retained Population")
    ax.scatter(x_new_k, y_new_k, s=20, c='r', marker='^', alpha=0.7, label="Re-sampled Population")
    _draw_panel(ax, f"Epoch: {epoch}")
    count += 1

# Hide the spare panel left over when nplots is odd.
for ax in axes_flat[nplots:]:
    ax.set_visible(False)

fig.tight_layout()
# savefig does not create missing directories; make sure Figures/ exists.
fig_dir = Path("Figures")
fig_dir.mkdir(exist_ok=True)
fig.savefig(fig_dir / f"{function_name}_evosample.jpg", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# L2 norm of the retained and re-sampled populations per epoch, against the
# reference L^p norms of f. Entry k of each history is the population after k
# updates (k = 0 is the initial one); it is drawn at x = k + 1 so that the
# initial population fits on the log-scaled axis.
epochs_vals = np.arange(1, epochs+2)

# (value, legend label, zorder, colour) for each dashed reference line.
reference_norms = [
    (l2_norm_val, r"$L^2$ Norm", 5, 'k'),
    (l4_norm_val, r"$L^4$ Norm", 6, 'limegreen'),
    (l6_norm_val, r"$L^6$ Norm", 7, 'darkviolet'),
    (l_inf, r"$L^\infty$ Norm", 8, 'r'),
]

plt.figure(figsize=(4, 3))
plt.plot(epochs_vals, retained_pop_norm_epoch, lw=2.0, label="Retained Population")
plt.plot(epochs_vals, resampled_pop_norm_epoch, lw=2.0, label="Resampled Population")
for norm_value, norm_label, norm_zorder, norm_color in reference_norms:
    plt.hlines(norm_value, xmin=1, xmax=epochs+2, linestyle="dashed",
               label=norm_label, zorder=norm_zorder, color=norm_color)
plt.xscale("log")
plt.grid(True, alpha=0.2)
plt.xlabel("Epochs", fontsize=15)
plt.ylabel(r"$\mathcal{L}^2_r(\theta)$", fontsize=15)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.legend(loc="upper left", ncol=2, fontsize=9)
# Hand-tuned range for Bohachevsky on [-40, 40]^2; leaves room for the legend.
# Adjust when reusing this notebook for another function or domain.
plt.ylim([1000, 7000])
# savefig does not create missing directories; make sure Figures/ exists.
fig_dir = Path("Figures")
fig_dir.mkdir(exist_ok=True)
plt.savefig(fig_dir / f"{function_name}_dynamic_Lp_norm.jpg", dpi=100, bbox_inches="tight")
plt.show()

In [None]:
# Fraction of the population retained by each update: size_retained[k] comes
# from update k + 1 and is drawn at x = k + 1. The dashed line marks the whole
# population. N_s must still hold the evolutionary population size here.
update_vals = np.arange(1, epochs + 1)

plt.figure(figsize=(4, 3))
plt.plot(update_vals, np.array(size_retained)/N_s, lw=2.0)
plt.hlines(1.0, xmin=1, xmax=epochs+1, linestyle="dashed", label="Full population", zorder = 5)
plt.xscale("log")
plt.grid(True, alpha=0.2)
plt.xlabel("Epochs", fontsize=15)
plt.ylabel(r"$|\mathcal{P}^r|/|\mathcal{P}|$", fontsize=15)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.ylim([-0.1,1.1])
# savefig does not create missing directories; make sure Figures/ exists.
fig_dir = Path("Figures")
fig_dir.mkdir(exist_ok=True)
plt.savefig(fig_dir / f"{function_name}_retain_population_size.jpg", dpi=100, bbox_inches="tight")
plt.show()