# Import and settings

In [None]:
import gzip
import pickle

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import set_loglevel
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.analysis.phase_diagram import PhaseDiagram, PDPlotter

In [None]:
# Matplotlib defaults applied to every figure in this notebook, assembled
# group by group (key order is the same as a single flat literal).
# Alternatives kept for reference:
#   'mathtext.fontset': 'stix'
#   'svg.fonttype': 'none'   # keep text as text; assumes fonts are installed on the machine
rcParams_dict = {}
rcParams_dict.update({    # ---------- figure
    'figure.figsize': (8, 6),
    'figure.dpi': 120,
    'figure.facecolor': 'white',
})
rcParams_dict.update({    # ---------- axes
    'axes.grid': True,
    'axes.linewidth': 1.5,
})
rcParams_dict.update({    # ---------- ticks
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.major.width': 1.0,
    'ytick.major.width': 1.0,
    'xtick.major.size': 8.0,
    'ytick.major.size': 8.0,
})
rcParams_dict.update({    # ---------- lines
    'lines.linewidth': 1.5,
    'lines.markersize': 8,
})
rcParams_dict.update({    # ---------- grid
    'grid.linestyle': ':',
})
rcParams_dict.update({    # ---------- font
    'font.family': ['Times New Roman', 'Liberation Serif'],
    'mathtext.fontset': 'cm',
    'font.size': 16,
    'axes.labelsize': 20,
    'legend.fontsize': 20,
    'svg.fonttype': 'path',  # embed characters as paths
    'pdf.fonttype': 42,  # embed fonts in PDF as Type 42 (TrueType)
})

# Only report matplotlib messages at 'error' level or above
# (presumably to silence warnings such as font-fallback notices).
set_loglevel('error')
plt.rcParams.update(rcParams_dict)

# Load data

In [None]:
def load_data(filename):
    """Load and return a pickled object from *filename*.

    Files whose name ends in ``.gz`` are read through gzip; any other file is
    opened as a plain binary file.

    Args:
        filename: Path to the pickle file, given as ``str``, ``bytes`` or any
            ``os.PathLike`` object (e.g. ``pathlib.Path``).

    Returns:
        The unpickled object.

    Note:
        ``pickle.load`` can execute arbitrary code while unpickling; only use
        this on files from a trusted source.
    """
    import os  # local import: the notebook's import cell does not import os

    # fsdecode normalises str / bytes / PathLike to str, so the suffix check
    # below works for all of them (a pathlib.Path has no .endswith method).
    path = os.fsdecode(filename)
    opener = gzip.open if path.endswith('.gz') else open
    with opener(path, 'rb') as f:
        return pickle.load(f)

In [None]:
# Results table of all searched structures (one row per ID).
rslt_data = load_data(filename='./pkl_data/rslt_data.pkl')

In [None]:
# Display the loaded results DataFrame as this cell's output.
rslt_data

In [None]:
# ---------- current generation
# Rows belonging to the latest generation (largest 'Gen'), and their IDs.
gen = rslt_data['Gen'].max()
c_rslt = rslt_data.loc[rslt_data['Gen'] == gen]
cgen_ids = c_rslt.index.to_numpy()    # current IDs [array]
cgen_ids

In [None]:
# {ID: energy per atom} from the results table, plus the per-ID atom counts
# and the pickled input settings.
e_all = rslt_data['E_eV_atom'].to_dict()
nat_data, rin = map(load_data, ('./pkl_data/nat_data.pkl', './pkl_data/input_data.pkl'))

In [None]:
# Hull / plot settings taken from the pickled input object.
(atype, end_point, emax_ea, emin_ea,
 show_max, label_stable, vmax) = (
    rin.atype, rin.end_point, rin.emax_ea, rin.emin_ea,
    rin.show_max, rin.label_stable, rin.vmax,
)

# Echo each setting on its own line, in the order above.
for setting in (atype, end_point, emax_ea, emin_ea, show_max, label_stable, vmax):
    print(setting)

# ---------- manually set
# atype = ('Cu', 'Sn', 'S')
# end_point = (0.0, 0.0, 0.0)
# emax_ea = 20
# emin_ea = -20
# show_max = 0.05
# label_stable = True
# vmax = 0.05

# Phase diagram

## Calculate phase diagram and hull distance

In [None]:
# ---------- calculate data
# One ComputedEntry per ID whose energy is finite and lies inside
# [emin_ea, emax_ea] (each bound skipped when None). Entries are keyed by ID
# so hull distances can be mapped back to IDs.
entries = {}
for cid, e in e_all.items():
    if np.isnan(e):
        continue    # no usable energy for this ID
    if emax_ea is not None and e > emax_ea:
        print(f'Eliminate ID {cid} from convex hull: {e} > emax_ea')
        continue
    if emin_ea is not None and e < emin_ea:
        print(f'Eliminate ID {cid} from convex hull: {e} < emin_ea')
        continue
    nat = nat_data[cid]
    # formula string such as "Cu2Sn1S3" built from atype and the atom counts
    composition = "".join(f"{element}{nat_i}" for element, nat_i in zip(atype, nat))
    # energy per atom -> total energy of the cell
    entries[cid] = ComputedEntry(composition, e * sum(nat), entry_id=cid)

# ---------- end points: one single-atom entry per element
end_entry_values = [
    ComputedEntry(element, end_e)
    for element, end_e in zip(atype, end_point)
]

# ---------- PhaseDiagram and hull distance
pd = PhaseDiagram([*end_entry_values, *entries.values()])
hdist = {cid: pd.get_e_above_hull(entry) for cid, entry in entries.items()}

In [None]:
# Display hull distances as {ID: energy above hull (eV/atom)}.
hdist

## Plotly

In [None]:
# Interactive plotly phase diagram. Unstable entries are shown up to
# `show_max` above the hull. This is the same cutoff the ternary matplotlib
# plot applies, rather than a hard-coded 0.05.
plotter = PDPlotter(pd, show_unstable=show_max)
plotter.show()

## Ternary system, 3D plot, Plotly

In [None]:
# 3D ternary phase diagram (plotly). Pass the configured `show_max` cutoff
# for unstable entries explicitly. Otherwise PDPlotter's own default would
# apply and the cutoff would be inconsistent with the other plots.
plotter = PDPlotter(pd, show_unstable=show_max, ternary_style='3d')
plotter.show()

## Binary system, matplotlib
This is only for binary systems.

In [None]:
# Binary convex hull with matplotlib. PDPlotter draws the hull and the stable
# points; unstable entries are overlaid as a scatter colored by hull distance,
# and entries of the current generation are marked with '+'.
# show_unstable=0.0 presumably keeps PDPlotter from drawing unstable points
# itself, since they are drawn by the custom scatter below.
fig, ax = plt.subplots(1, 1)
plotter_mpl = PDPlotter(pd, show_unstable=0.0, backend='matplotlib', linewidth=1.5, markerfacecolor='darkslateblue', markersize=12)
plotter_mpl.get_plot(label_stable=label_stable, label_unstable=False, ax=ax)

ax.set_axisbelow(True)    # draw ticks/grid lines beneath the plotted data
# ---------- hline
# dashed reference line at zero formation energy
ax.axhline(y=0, xmin=0, xmax=1, color='black', linestyle='--', zorder=1)


# ---------- label for only binary system
# default fontweight is 'bold' in PDPlotter, so set 'normal'
ax.set_xlabel('Composition', fontsize=24, fontweight='normal')
ax.set_ylabel('Formation energy (eV/atom)', fontsize=24, fontweight='normal')

# ---------- texts
# restyle every text artist already on the axes (presumably the entry labels
# added by get_plot)
for text in ax.texts:
    text.set_fontsize(14)
    text.set_fontweight('normal')    # bold --> normal

# ---------- scatter: unstable entries
scat_x = []
scat_y= []
scat_c = []
# pd_plot_data -> (lines, stable_entries, unstable_entries):
#   stable_entries maps plot coordinate -> entry,
#   unstable_entries maps entry -> plot coordinate.
# `lines` is not used in this cell.
lines, stable_entries, unstable_entries = plotter_mpl.pd_plot_data
for entry, coord in unstable_entries.items():
    # skip entries without an ID (the elemental end points are built without
    # entry_id, and only ID'd entries have a value in hdist)
    if entry.entry_id is not None:
        scat_x.append(coord[0])
        scat_y.append(coord[1])
        scat_c.append(hdist[entry.entry_id])
# color = hull distance; the color scale spans 0 .. vmax
mappable = ax.scatter(scat_x, scat_y, s=50, c=scat_c, vmin=0, vmax=vmax, cmap='Oranges_r', marker='D', edgecolors='black', zorder=2)
cbar = fig.colorbar(mappable, ax=ax, shrink=0.8, pad=0.05)
cbar.ax.tick_params(labelsize=14)
cbar.set_label('Hull distance (eV/atom)', size=20, rotation=270, labelpad=30)

# ---------- mark the current generation
# entry_id -> plot coordinate, for stable and for unstable entries
stable_compos = {entry.entry_id: compos for compos, entry in stable_entries.items()}
unstable_compos = {entry.entry_id: compos for entry, compos in unstable_entries.items()}
for cid in cgen_ids:
    # IDs excluded from the hull (NaN / outside emin_ea..emax_ea) are in
    # neither dict and are silently skipped
    if cid in stable_compos:
        mx, my = stable_compos[cid][0], stable_compos[cid][1]
        ax.plot(mx, my, '+', markeredgecolor='white')
    elif cid in unstable_compos:
        mx, my = unstable_compos[cid][0], unstable_compos[cid][1]
        ax.plot(mx, my, '+', markersize=10, markeredgewidth=0.5,  markeredgecolor='navy')

# ---------- ylim
# lower limit: 0.01 below the lowest stable point; upper limit: show_max.
# NOTE(review): here show_max bounds the formation-energy axis, whereas the
# ternary cell uses it as a hull-distance cutoff. Confirm this dual use is intended.
stable_y = list(stable_entries.keys())    # (x, y) coordinates of stable points, despite the name
ymin = min(stable_y, key=lambda x: x[1])[1] -0.01
ax.set_ylim(ymin, show_max)

In [None]:
# ---------- save figure
#fig.savefig('hull_distance.png', bbox_inches='tight')    # PNG
#fig.savefig('hull_distance.png', bbox_inches='tight', dpi=300)    # high dpi PNG
#fig.savefig('hull_distance.svg', bbox_inches='tight')    # SVG
#fig.savefig('hull_distance.pdf', bbox_inches='tight')    # PDF

## Ternary system, matplotlib

This is only for ternary systems.

In [None]:
# Ternary convex hull with matplotlib. PDPlotter draws the hull and the stable
# points. Unstable entries within show_max of the hull are overlaid as a
# scatter colored by hull distance. Current-generation entries within
# show_max are marked with '+'.
# show_unstable=0.0 presumably keeps PDPlotter from drawing unstable points
# itself, since they are drawn by the custom scatter below.
fig, ax = plt.subplots(1, 1)
plotter_mpl = PDPlotter(pd, show_unstable=0.0, backend='matplotlib', linewidth=1.5, markerfacecolor='darkslateblue', markersize=10)
plotter_mpl.get_plot(label_stable=label_stable, label_unstable=False, ax=ax)

# ---------- texts
# restyle every text artist already on the axes (presumably the entry labels
# added by get_plot)
for text in ax.texts:
    text.set_fontsize(14)
    text.set_fontweight('normal')    # bold --> normal

# ---------- scatter: unstable entries
scat_x = []
scat_y= []
scat_c = []
# pd_plot_data -> (lines, stable_entries, unstable_entries):
#   stable_entries maps plot coordinate -> entry,
#   unstable_entries maps entry -> plot coordinate.
# `lines` is not used in this cell.
lines, stable_entries, unstable_entries = plotter_mpl.pd_plot_data
for entry, coord in unstable_entries.items():
    # skip entries without an ID (the elemental end points are built without
    # entry_id, and only ID'd entries have a value in hdist)
    if entry.entry_id is not None:
        # keep only entries within show_max of the hull
        if hdist[entry.entry_id] <= show_max:
            scat_x.append(coord[0])
            scat_y.append(coord[1])
            scat_c.append(hdist[entry.entry_id])
# color = hull distance; the color scale spans 0 .. vmax.
# zorder 3: above the stable-entry '+' markers (zorder 2), below the unstable ones (zorder 4)
mappable = ax.scatter(scat_x, scat_y, s=30, c=scat_c, vmin=0, vmax=vmax, cmap='Oranges_r', marker='D', edgecolors='black', zorder=3)
# negative pad moves the colorbar closer to the triangle than the default spacing
cbar = fig.colorbar(mappable, ax=ax, shrink=0.6, pad=-0.1)
cbar.ax.tick_params(labelsize=14)
cbar.set_label('Hull distance (eV/atom)', size=20, rotation=270, labelpad=30)

# ---------- mark the current generation
# entry_id -> plot coordinate, for stable and for unstable entries
stable_compos = {entry.entry_id: compos for compos, entry in stable_entries.items()}
unstable_compos = {entry.entry_id: compos for entry, compos in unstable_entries.items()}
for cid in cgen_ids:
    # only IDs that entered the hull (present in hdist) and lie within show_max
    if cid in hdist and hdist[cid] <= show_max:
        if cid in stable_compos:
            mx, my = stable_compos[cid][0], stable_compos[cid][1]
            ax.plot(mx, my, '+', markeredgecolor='white', zorder=2)
        elif cid in unstable_compos:
            mx, my = unstable_compos[cid][0], unstable_compos[cid][1]
            ax.plot(mx, my, '+', markersize=6, markeredgewidth=0.5,  markeredgecolor='navy', zorder=4)

In [None]:
# ---------- save figure
#fig.savefig('hull_distance.png', bbox_inches='tight')    # PNG
#fig.savefig('hull_distance.png', bbox_inches='tight', dpi=300)    # high dpi PNG
#fig.savefig('hull_distance.svg', bbox_inches='tight')    # SVG
#fig.savefig('hull_distance.pdf', bbox_inches='tight')    # PDF