In [None]:
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import os
import sympy as sp
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import jax
jax.config.update("jax_enable_x64", True)

### Simulation parameters
# Physical parameters
delta = 0.012 # interface width
Lambda = jnp.sqrt(2)*delta/4 # length scale derived from the interface width; enters every coefficient formula below
# gamma_ab values swept by the analysis; each value also names a data directory
g_ab = [2e-3, 4e-3, 6e-3, 8e-3, 1e-2, 1.2e-2, 1.4e-2, 1.6e-2]
gamma_bc = 1e-2 # presumably the B-C interface energy, by analogy with gamma_ab -- TODO confirm
gamma_ca = 1e-2  # presumably the C-A interface energy, by analogy with gamma_ab -- TODO confirm

image_num = 50 # not referenced in the visible cells; presumably the number of images per path -- TODO confirm
N = 256 # number of grid nodes per axis
r = jnp.linspace(0, 1, N+1) # N+1 equally spaced edges on [0, 1]
r_mid = (r[1:] + r[:-1]) / 2 # N midpoints between consecutive edges
dr = r_mid[1] - r_mid[0] # uniform spacing between midpoints


# Angular wavenumbers of an N-point FFT grid with spacing dr, and the
# squared-wavenumber array k_square that cal_gradsquare multiplies in Fourier space.
k = 2*jnp.pi*jnp.fft.fftfreq(N, d=dr)
kx, ky = jnp.meshgrid(k, k)
k_square = kx**2 + ky**2

In [None]:
# One value per entry of g_ab, in the same order (the two lists are zipped together later).
sf = [0.873, 0.747, 0.624, 0.505, 0.391, 0.285, 0.188, 0.104] # theoretical shape factor

In [None]:
def cal_gradsquare(x):
    """Return ``x`` times the inverse FFT of ``k_square * FFT(x)``, pointwise.

    The 2-D transforms act on the last two axes of ``x``; ``k_square`` is the
    module-level array of squared wavenumbers. Multiplying by ``|k|^2`` in
    Fourier space applies minus the Laplacian, so the result is
    ``x * (-laplacian(x))``, whose average over a periodic domain equals the
    average of ``|grad x|^2`` (integration by parts).
    """
    spectrum = jnp.fft.fft2(x)
    minus_laplacian = jnp.fft.ifft2(k_square * spectrum).real
    return x * minus_laplacian


def main(case):
    """Compute one energy barrier per entry of the module-level ``g_ab`` list.

    For every ``gamma_ab`` the stored concentration array of the requested
    ``case`` is loaded, the free-energy density (bulk part plus gradient part)
    is averaged over axes 1 and 2 of each field, and the barrier is the largest
    of those averages minus the average at index 0.

    Parameters
    ----------
    case : str
        Name of the data sub-directory (the notebook passes "heterogeneous"
        and "homogeneous").

    Returns
    -------
    list of float
        Barriers in the same order as ``g_ab``.
    """
    root2 = jnp.sqrt(2)

    def barrier_for(gamma_ab):
        # Coefficients of the three double-well terms; the formula set depends
        # on which side of 1e-2 gamma_ab lies.
        if gamma_ab > 1e-2:
            coef_a = (9*gamma_ab + 12*(gamma_ca - gamma_bc))/(4*root2*Lambda)
            coef_b = (9*gamma_ab - 12*(gamma_ca - gamma_bc))/(4*root2*Lambda)
            coef_c = 3*root2*gamma_ca/Lambda - 4*coef_a/3 - coef_b/3
        else:
            coef_b = root2*(9*gamma_bc + 12*(gamma_ab - gamma_ca))/(8*Lambda)
            coef_c = root2*(9*gamma_bc - 12*(gamma_ab - gamma_ca))/(8*Lambda)
            coef_a = 3*root2*gamma_ab/Lambda - 4*coef_b/3 - coef_c/3
        coefs = jnp.array([coef_a, coef_b, coef_c])
        coupling_coef = jnp.max(coefs + jnp.roll(coefs, 1))/6

        fields = jnp.array(np.load(f"Nucleation_ZTS/TernaryMix/2d/Output/SymmetricCase/Data-InterfaceWidth-0.012/gamma_ab-{gamma_ab}/{case}/concentration.npy"))
        ua, ub, uc = fields[0], fields[1], fields[2]

        # Bulk density: three double wells plus the coupling polynomial.
        sq_a, sq_b, sq_c = ua**2, ub**2, uc**2
        double_wells = coef_a*sq_a*(ua - 1)**2 + coef_b*sq_b*(ub - 1)**2 + coef_c*sq_c*(uc - 1)**2
        coupling = sq_a*sq_b + sq_a*sq_c + sq_c*sq_b + (1 - ua)**2*(1 - ub)**2*(1 - uc)**2
        bulk_density = double_wells + coupling_coef*coupling

        # Gradient density, one weight per component.
        lambda_sq = Lambda**2
        weight_a = lambda_sq*(coef_a + coupling_coef)
        weight_b = lambda_sq*(coef_b + coupling_coef)
        weight_c = lambda_sq*(coef_c + coupling_coef)
        gradient_density = (
            weight_a*cal_gradsquare(ua)/2
            + weight_b*cal_gradsquare(ub)/2
            + weight_c*cal_gradsquare(uc)/2
        )

        # One averaged energy per leading-axis entry; the barrier is measured from entry 0.
        energies = jnp.mean(bulk_density + gradient_density, axis=(1, 2))
        peak = jnp.argmax(energies)
        return float(energies[peak] - energies[0])

    return [barrier_for(gamma_ab) for gamma_ab in g_ab]

# Numerical barriers for both stored data sets (heterogeneous first, then homogeneous).
eb_het = main("heterogeneous")
eb_hom = main("homogeneous")
print(f"Het:{eb_het}", f"Hom:{eb_hom}", sep="\n")

In [None]:
from scipy.optimize import fsolve

# Lookup table: gamma_ab value -> theoretical shape factor (g_ab and sf are paired by position).
g_sf_dict = dict(zip(g_ab, sf))
def cal_theo_eb_hom(g_ab):
    """Estimate the critical radius and energy barrier for one ``gamma_ab``.

    The background composition is measured from image 0 of the stored
    homogeneous-case concentration file for this ``gamma_ab`` (row-averaged
    over fixed row bands). With ``df`` the bulk energy difference between the
    state (ua, ub, uc) = (0, 0, 1) and that background, and ``k`` the
    chemical-potential term evaluated at the background, the excess energy of
    a disc of radius ``R`` is modelled as

        DF(R) = df*pi*R^2 + k*pi*R^2*(1 - 2*pi*R^2) + 2*pi*R*gamma_ca

    The returned radius is the root of ``dDF/dR = 0`` found by ``fsolve``
    starting from ``R = 0``; ``fsolve`` emits a ``RuntimeWarning`` if it does
    not converge.

    Parameters
    ----------
    g_ab : float
        Interface energy ``gamma_ab``; it also selects the data directory
        ``.../gamma_ab-{g_ab}/homogeneous/``.

    Returns
    -------
    tuple of (float, float)
        ``(Rc, DF(Rc))`` as plain Python floats.

    Raises
    ------
    AssertionError
        If a stored field slice is not of shape ``(N, N)`` with ``N = 256``.
    """
    # Physical parameters, kept local so the function does not depend on notebook state.
    delta = 0.012  # interface width
    Lambda = jnp.sqrt(2)*delta/4
    N = 256  # grid nodes per axis of the stored fields

    gamma_ab = g_ab  # interface energy between components A and B
    gamma_bc = 1e-2
    gamma_ca = 1e-2

    # Double-well coefficients; the formula set depends on which side of 1e-2 gamma_ab lies.
    if gamma_ab > 1e-2:
        Aa = (9*gamma_ab + 12*(gamma_ca - gamma_bc))/(4*jnp.sqrt(2)*Lambda)
        Ab = (9*gamma_ab - 12*(gamma_ca - gamma_bc))/(4*jnp.sqrt(2)*Lambda)
        Ac = 3*jnp.sqrt(2)*gamma_ca/Lambda - 4*Aa/3 - Ab/3
    else:
        Ab = jnp.sqrt(2)*(9*gamma_bc + 12*(gamma_ab - gamma_ca))/(8*Lambda)
        Ac = jnp.sqrt(2)*(9*gamma_bc - 12*(gamma_ab - gamma_ca))/(8*Lambda)
        Aa = 3*jnp.sqrt(2)*gamma_ab/Lambda - 4*Ab/3 - Ac/3
    A = jnp.array([Aa, Ab, Ac])
    C = jnp.max(A + jnp.roll(A, 1))/6

    eps = 1e-14  # keeps the P/(1-u)^2 quotient finite when u == 1

    def ChemicalPotential(i, ua, ub, uc):
        """Partial derivative of ``energy`` with respect to component ``i`` (0=a, 1=b, 2=c)."""
        P = jnp.power(1-ua, 2)*jnp.power(1-ub, 2)*jnp.power(1-uc, 2)
        S = jnp.power(ua, 2)+jnp.power(ub, 2)+jnp.power(uc, 2)
        u = [ua, ub, uc]
        mu_i = 2*A[i]*u[i]*(1 - u[i])*(1 - 2*u[i]) + 2*C*(u[i]*(S - u[i]**2) - (1-u[i])*(P + eps)/(jnp.power(1-u[i], 2) + eps))
        return mu_i

    def energy(ua, ub, uc):
        """Bulk free-energy density: three double wells plus the coupling polynomial."""
        bulk_term = Aa*jnp.power(ua, 2)*jnp.power(ua-1, 2) + Ab*jnp.power(ub, 2)*jnp.power(ub-1, 2) + Ac*jnp.power(uc, 2)*jnp.power(uc-1, 2) + C*(jnp.power(ua,2)*jnp.power(ub, 2) + jnp.power(ua,2)*jnp.power(uc,2) + jnp.power(uc,2)*jnp.power(ub,2) + jnp.power(1-ua,2)*jnp.power(1-ub,2)*jnp.power(1-uc,2))
        return bulk_term

    def _row_band_mean(field, rows):
        """Average of the per-row means of ``field`` over the given row indices."""
        assert field.shape == (N, N), "shape wrong"
        row_means = [np.mean(field[i, :]) for i in rows]
        return sum(row_means)/len(row_means)

    def critical_Delta_F():
        """Return ``(Rc, DF(Rc))`` for the background measured from the stored image 0."""
        u = np.load(f"Nucleation_ZTS/TernaryMix/2d/Output/SymmetricCase/Data-InterfaceWidth-0.012/gamma_ab-{gamma_ab}/homogeneous/concentration.npy")
        # Component c is sampled on every 10th row; a and b on rows 150..235 in steps of 5.
        mean_uc = _row_band_mean(u[2, 0, :, :], range(10, N, 10))
        mean_ua = _row_band_mean(u[0, 0, :, :], range(150, 240, 5))
        mean_ub = _row_band_mean(u[1, 0, :, :], range(150, 240, 5))
        print(f"case gamma_ab = {gamma_ab}, mean_ua = {mean_ua}, mean_ub = {mean_ub}, mean_uc = {mean_uc}, sum {mean_ua+ mean_ub + mean_uc}")

        # State 0: measured background composition.
        ua0 = mean_ua
        ub0 = mean_ub
        uc0 = mean_uc

        # State 1: (0, 0, 1); state 2: (1, 0, 0).
        ua1, ub1, uc1 = 0, 0, 1
        ua2, ub2, uc2 = 1, 0, 0

        f0 = energy(ua=ua0, ub=ub0, uc=uc0)
        f1 = energy(ua=ua1, ub=ub1, uc=uc1)

        mu_a0 = ChemicalPotential(0, ua=ua0, ub=ub0, uc=uc0)
        mu_b0 = ChemicalPotential(1, ua=ua0, ub=ub0, uc=uc0)
        mu_c0 = ChemicalPotential(2, ua=ua0, ub=ub0, uc=uc0)

        # Plain floats: the solver callback then needs no JAX dispatch and the
        # results below are genuine scalars.
        df = float(f1 - f0)
        k = float((ua2 - ua1)*mu_a0 + (ub2 - ub1)*mu_b0 + (uc2 - uc1)*mu_c0)

        def equation(x):
            # dDF/dR divided by 2*pi
            return (df + k) * x - 4 * jnp.pi * k * x**3 + gamma_ca

        # fsolve returns a length-1 array; take the scalar root explicitly
        # instead of relying on float() of a 1-element array.
        Rc = float(fsolve(equation, 0.0)[0])

        DFc = df*jnp.pi*Rc**2 + k*(jnp.pi*Rc**2)*(1 - 2*jnp.pi*Rc**2) + 2*jnp.pi*Rc*gamma_ca
        return Rc, DFc

    theo_r, theo_eb_hom = critical_Delta_F()

    return float(theo_r), float(theo_eb_hom)




# Theoretical barriers and critical radii, one entry per gamma_ab in g_ab.
theo_hom, theo_het = [], []
theo_cr_hom, theo_cr_het = [], []

for gamma_ab in g_ab:
    theo_r, theo_eb_hom = cal_theo_eb_hom(g_ab=gamma_ab)
    # Heterogeneous barrier = homogeneous barrier scaled by the shape factor.
    theo_eb_het = theo_eb_hom*g_sf_dict[gamma_ab]
    theo_cr_hom += [theo_r]
    theo_hom += [theo_eb_hom]
    theo_het += [theo_eb_het]

print(f"Theo Het:{theo_het}")
print(f"Theo Hom:{theo_hom}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.pylab as pylabs
from matplotlib.ticker import ScalarFormatter
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.Set1.colors)

# Typography and figure size shared by both panels.
myparams = {
   'axes.labelsize': '13',
   'xtick.labelsize': '11',
   'ytick.labelsize': '11',
   'lines.linewidth': 1,
   'legend.fontsize': '10',
   'font.family': 'Times New Roman',
   'figure.figsize': '9, 4'
}
pylabs.rcParams.update(myparams)

fig = plt.figure(dpi=1000)

# Left panel: theoretical vs. computed barriers on log-log axes.
ax_barrier = fig.add_subplot(1, 2, 1)
ax_barrier.plot(g_ab, theo_hom, 'o--', markersize=4, label='Theoretical $E_{b}^{hom}$')
ax_barrier.plot(g_ab, eb_hom, 's-', markersize=6, linewidth=0.7, label='Experimental $E_{b}^{hom}$')
ax_barrier.plot(g_ab, theo_het, 'x--', markersize=4, label='Theoretical $E_{b}^{het}$')
ax_barrier.plot(g_ab, eb_het, 'v-', markersize=6, linewidth=0.7, label='Experimental $E_{b}^{het}$')

ax_barrier.set_xlabel('Interface Energy $\\gamma_{ab}$')
ax_barrier.set_ylabel('Energy Barrier ($E_{b}^{hom / het}$)')
ax_barrier.set_title("Energy Barriers", fontsize=14)
# Applied while the axis is still linear, exactly as in the pyplot version.
ax_barrier.ticklabel_format(axis='y', style='sci', scilimits=(0,0))

ax_barrier.legend(loc='lower left', frameon=True)
ax_barrier.grid(True, linestyle='--', linewidth=0.5)

ax_barrier.set_yscale('log', base=10)
ax_barrier.set_xscale('log', base=10)

# Shape factor = heterogeneous barrier / homogeneous barrier, element-wise.
interface_energy_vector = np.array(g_ab)
numerical_shape_factor_vector = np.array(eb_het)/np.array(eb_hom)
theoretical_shape_factor_vector = np.array(sf)

# Right panel: theoretical vs. numerical shape factor.
ax_shape = fig.add_subplot(1, 2, 2)
ax_shape.plot(interface_energy_vector, theoretical_shape_factor_vector, "o--", markersize=4, label='Theoretical shape factor')
ax_shape.plot(interface_energy_vector, numerical_shape_factor_vector, "s-", linewidth=0.7, markersize=6, label='Numerical shape factor')
ax_shape.legend(fontsize=12)
ax_shape.set_title("Shape factor", fontsize=14)
ax_shape.set_xlabel('Interface Energy $\\gamma_{ab}$')
ax_shape.set_ylabel('Shape Factor')
ax_shape.grid(True, linestyle='--', linewidth=0.5)

fig.tight_layout()
# NOTE(review): absolute, machine-specific output path; the directory must already exist.
fig.savefig("/home/ms/akrito/string-method-nucleation/correction/2d/sym/results.pdf", dpi=1000, bbox_inches='tight')


# Relative deviation of the numerical shape factor from the theoretical one.
relative_err = np.abs(numerical_shape_factor_vector - theoretical_shape_factor_vector)/theoretical_shape_factor_vector
print(f"low:{np.min(relative_err)}\n high:{np.max(relative_err)}")
print(relative_err)