In [1]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.cm as cm
import matplotlib.colorbar as colorbar
import matplotlib.patches as mpatches
import seaborn as sns

from tqdm import tqdm
from ase.io import read
from collections import Counter
from ase.visualize.plot import plot_atoms
from ase.neighborlist import NeighborList

import warnings
warnings.filterwarnings('ignore')

In [3]:
def extract_best_name_from_file(filepath):
    """Return a display name for the structure stored in a CIF file.

    The name comes from the first rule that yields a non-empty string:

    1. The quoted (``'...'`` / ``"..."``) or semicolon-delimited (``;...;``) value of
       ``_chemical_name_common``, then ``_chemical_name_systematic``, then
       ``_chemical_name_mineral``. Unquoted values are not recognised.
    2. The rest of the line after a ``TITL`` token (same line only).
    3. The value of ``_cod_data_source_block``.
    4. The file's base name up to its first ``.``.

    Whitespace runs inside a rule-1 name (e.g. a multi-line text field) are collapsed to
    single spaces. If the file cannot be opened or read, rule 4 is used.

    Parameters
    ----------
    filepath : str
        Path to the CIF file.

    Returns
    -------
    str
    """
    fallback = os.path.basename(filepath).split('.')[0]
    try:
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
    except Exception:
        # Best effort: an unreadable file still gets a usable label.
        return fallback

    for field in ('_chemical_name_common', '_chemical_name_systematic', '_chemical_name_mineral'):
        match = re.search(rf"{field}\s+(?:'([^']+)'|\"([^\"]+)\"|;(.*?);)", content, re.DOTALL)
        if not match:
            continue
        # Exactly one alternative matched; the others are None. A ";;" field matches with an
        # empty group and whitespace-only values collapse to "", so try the next field
        # instead of returning an empty name.
        raw = next((group for group in match.groups() if group is not None), "")
        name = " ".join(raw.split())
        if name:
            return name

    # [ \t]+ rather than \s+ so an empty TITL line does not swallow the following line.
    titl_match = re.search(r'TITL[ \t]+(.+)', content)
    if titl_match:
        return titl_match.group(1).strip()

    cod_block = re.search(r'_cod_data_source_block\s+(\S+)', content)
    if cod_block:
        return cod_block.group(1).strip()

    return fallback

In [4]:
def load_all_structures(cif_dir, n):
    """Read the first ``n`` ``.cif`` files of ``cif_dir``, in sorted file-name order.

    Parameters
    ----------
    cif_dir : str
        Directory to scan (not recursive).
    n : int
        Maximum number of files to load.

    Returns
    -------
    (structures, titles) : tuple of two lists
        ``structures[k]`` is the result of ``read`` on the k-th file, and ``titles[k]`` is
        its label from ``extract_best_name_from_file``.
    """
    cif_names = sorted(name for name in os.listdir(cif_dir) if name.endswith('.cif'))[:n]

    structures, titles = [], []
    for cif_name in tqdm(cif_names, desc="Loading CIFs"):
        cif_path = os.path.join(cif_dir, cif_name)
        structures.append(read(cif_path))
        titles.append(extract_best_name_from_file(cif_path))
    return structures, titles

In [5]:
def show_grid_plot(structures, titles, plot_func, supercell=None, main_title=None, x_label=None, y_label=None):
    """Draw one panel per structure on a near-square grid of subplots and show the figure.

    Parameters
    ----------
    structures : sequence
        Structures to plot (ASE ``Atoms`` in this notebook).
    titles : sequence of str
        Panel titles, index-aligned with ``structures``.
    plot_func : callable
        Called as ``plot_func(atoms, ax)`` to draw one structure into one Axes.
    supercell : optional
        If truthy, each structure is replaced by ``structure * supercell`` before plotting,
        e.g. ``(2, 2, 2)``.
    main_title : str, optional
        Figure-level title.
    x_label, y_label : str, optional
        If given, applied to every panel. They override any label ``plot_func`` set.

    Returns None. An empty ``structures`` draws nothing.
    """
    n = len(structures)
    if n == 0:
        # plt.subplots(0, 0) raises, and there is nothing to show anyway.
        return

    ncols = int(np.ceil(np.sqrt(n)))
    # Only as many rows as needed, so no row consists solely of hidden panels.
    nrows = int(np.ceil(n / ncols))
    # squeeze=False always yields a 2-D array of Axes. With the default, a 1x1 grid returns
    # a bare Axes object, and .flatten() would fail.
    fig, axs = plt.subplots(nrows, ncols, figsize=(5 * ncols, 6 * nrows), squeeze=False)
    axs = axs.flatten()

    for i, ax in enumerate(axs[:n]):
        atoms = structures[i] * supercell if supercell else structures[i]
        plot_func(atoms, ax)
        ax.set_title(titles[i], fontsize=9, pad=2)
        if x_label is not None:
            ax.set_xlabel(x_label)
        if y_label is not None:
            ax.set_ylabel(y_label)
        ax.tick_params(axis='both', which='both', length=4, width=1)
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(True)

    # Hide the unused trailing panels of the last row.
    for ax in axs[n:]:
        ax.set_visible(False)

    if main_title:
        fig.suptitle(main_title, fontsize=18, y=1.02)

    fig.subplots_adjust(left=0.05, right=0.95, top=0.93, bottom=0.07, hspace=0.4, wspace=0.5)
    plt.show()

In [6]:
def plot_structure(atoms, ax):
    """Draw ``atoms`` into ``ax`` with ASE's ``plot_atoms`` (radii=0.3, show_unit_cell=2)."""
    draw_options = {"radii": 0.3, "show_unit_cell": 2}
    plot_atoms(atoms, ax=ax, **draw_options)


def plot_supercell(atoms, ax):
    """Draw ``atoms`` into ``ax`` with ``plot_atoms`` using smaller atoms (radii=0.25) and the unit cell (show_unit_cell=2)."""
    draw_options = {"radii": 0.25, "show_unit_cell": 2}
    plot_atoms(atoms, ax=ax, **draw_options)


def plot_topdown(atoms, ax):
    """Draw ``atoms`` into ``ax`` with ``plot_atoms``, view rotated by "90x,0y,0z" (radii=0.3, show_unit_cell=2)."""
    view_options = {"radii": 0.3, "show_unit_cell": 2, "rotation": "90x,0y,0z"}
    plot_atoms(atoms, ax=ax, **view_options)

In [7]:
def plot_bond_length_hist(atoms, ax, title="Bond Length Distribution", cutoff=1.5, bins=20):
    """Histogram of interatomic bond lengths in ``atoms``.

    Every atom gets the same neighbour-list cutoff radius ``cutoff``. Two atoms count as
    bonded when their cutoff spheres overlap, i.e. their separation is below ``2 * cutoff``
    (3.0 Å by default). Bonds to periodic images use the image offsets reported by the
    neighbour list, and each bonded pair is counted once.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to analyse.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    title : str
        Axes title.
    cutoff : float
        Per-atom cutoff radius (Å) passed to ``NeighborList``.
    bins : int
        Number of histogram bins.
    """
    # skin=0.0: the list is built once. A non-zero skin only makes it return extra
    # neighbours beyond the requested cutoff.
    # bothways=False: each pair is reported from one side only, so no bond is counted twice.
    nl = NeighborList([cutoff] * len(atoms), skin=0.0, self_interaction=False, bothways=False)
    nl.update(atoms)

    positions = atoms.get_positions()
    cell = np.asarray(atoms.get_cell(), dtype=float)
    bond_lengths = []
    for i in range(len(atoms)):
        indices, offsets = nl.get_neighbors(i)
        # Build the bond vector from the image the neighbour list actually found. A
        # minimum-image distance would give 0 for a bond to atom i's own periodic image and
        # the wrong length for any non-minimum image.
        vectors = positions[indices] + offsets @ cell - positions[i]
        bond_lengths.extend(np.linalg.norm(vectors, axis=1))

    ax.hist(bond_lengths, bins=bins, color='mediumseagreen', edgecolor='black', label='Bond Lengths', alpha=0.85)
    ax.set_xlabel("Bond Length (Å)")
    ax.set_ylabel("Frequency")
    ax.set_title(title, fontsize=10)
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
    ax.legend()
    ax.tick_params(axis='both', which='both', length=4, width=1)


In [None]:
def plot_lattice_system_vs_anisotropy(df):
    """Box plot of ``cell_anisotropy`` per crystal lattice system, on a log-scaled y axis.

    ``df['sg_encoded']`` is mapped to a lattice system by space-group number:
    1-2 Triclinic, 3-15 Monoclinic, 16-74 Orthorhombic, 75-142 Tetragonal,
    143-167 Trigonal, 168-194 Hexagonal, 195-230 Cubic. Anything else is "Unknown".
    NOTE(review): this assumes ``sg_encoded`` holds raw space-group numbers. If it is a
    label encoding, the grouping is wrong — confirm against whatever builds ``df``.

    Nothing is plotted if ``sg_encoded`` or ``cell_anisotropy`` is missing. ``df`` itself is
    not modified; the classification is added to a copy.
    """
    if not all(column in df.columns for column in ('sg_encoded', 'cell_anisotropy')):
        return

    systems_by_space_group = (
        ("Triclinic", range(1, 3)),
        ("Monoclinic", range(3, 16)),
        ("Orthorhombic", range(16, 75)),
        ("Tetragonal", range(75, 143)),
        ("Trigonal", range(143, 168)),
        ("Hexagonal", range(168, 195)),
        ("Cubic", range(195, 231)),
    )

    def classify(space_group):
        return next(
            (system for system, numbers in systems_by_space_group if space_group in numbers),
            "Unknown",
        )

    labelled = df.assign(lattice_system=df['sg_encoded'].apply(classify))

    plt.figure(figsize=(10, 6))
    sns.boxplot(data=labelled, x='lattice_system', y='cell_anisotropy')
    plt.yscale('log')
    plt.title("Anisotropy by Lattice System (Log Scale)")
    plt.xlabel("Lattice System")
    plt.ylabel("Cell Anisotropy")
    plt.grid(True, axis='y')
    plt.tight_layout()
    plt.show()

In [None]:
def plot_element_hist(atoms, ax):
    """Bar chart of the number of atoms of each element in ``atoms``.

    There is one bar per element, in order of first appearance in
    ``atoms.get_chemical_symbols()``. Each bar is annotated with "<element> = <count>" and
    coloured from the tab10 colormap (cycling via ``i % cmap.N``). A colour legend sits to
    the right of the axes, and the title lists the elements.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure whose composition is plotted.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    """
    symbol_counts = Counter(atoms.get_chemical_symbols())  # insertion order = first appearance
    keys = list(symbol_counts.keys())
    values = list(symbol_counts.values())

    # plt.get_cmap works on all matplotlib versions; matplotlib.cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap('tab10')
    colors = [cmap(i % cmap.N) for i in range(len(keys))]

    ax.bar(keys, values, color=colors, edgecolor='black')

    ax.set_xlabel("Element")
    ax.set_ylabel("Count")
    ax.grid(axis='y')
    ax.tick_params(axis='x', labelrotation=45)

    for i, (key, value) in enumerate(zip(keys, values)):
        ax.text(i, value + 0.2, f"{key} = {value}", ha='center', va='bottom', fontsize=9)

    ax.set_title(", ".join(keys), fontsize=10)

    legend_handles = [mpatches.Patch(color=color, label=key) for key, color in zip(keys, colors)]
    ax.legend(handles=legend_handles, title='Elements', bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

def _plot_cell_parameter_bars(ax, labels, values, cmap_name, text_offset, top_margin,
                              xlabel, ylabel, cbar_label):
    """Draw one colour-coded bar per cell parameter, with value labels and a colorbar.

    Shared implementation of ``plot_unitcell_lengths`` and ``plot_unitcell_angles``.
    Bar colour encodes the value, normalised between the smallest and largest value.
    """
    # plt.get_cmap works on all matplotlib versions; matplotlib.cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap(cmap_name)
    norm = plt.Normalize(min(values), max(values))

    for i, (label, value) in enumerate(zip(labels, values)):
        bar = patches.Rectangle((i - 0.4, 0), 0.8, value, facecolor=cmap(norm(value)),
                                edgecolor='black', linewidth=1.2)
        ax.add_patch(bar)
        ax.text(i, value + text_offset, f"{label} = {value:.2f}", ha='center', va='bottom',
                fontsize=9, color='black')

    # Patches do not autoscale the axes, so set the limits explicitly.
    ax.set_xlim(-0.5, len(labels) - 0.5)
    ax.set_ylim(0, max(values) + top_margin)
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(axis='y')

    title = ", ".join(f"{label} = {value:.2f}" for label, value in zip(labels, values))
    ax.set_title(title, fontsize=11)

    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # Attach to ax's own figure rather than pyplot's "current" figure.
    cbar = ax.figure.colorbar(sm, ax=ax, pad=0.02)
    cbar.set_label(cbar_label, fontsize=10)


def plot_unitcell_lengths(atoms, ax):
    """Bar chart of the unit-cell lengths a, b, c (Å) from ``atoms.cell.lengths()``, viridis-coloured."""
    _plot_cell_parameter_bars(
        ax, ['a', 'b', 'c'], atoms.cell.lengths(), 'viridis',
        text_offset=0.1, top_margin=1,
        xlabel="Axis", ylabel="Length (Å)", cbar_label='Unit Cell Length (Å)',
    )


def plot_unitcell_angles(atoms, ax):
    """Bar chart of the unit-cell angles α, β, γ (degrees) from ``atoms.cell.angles()``, plasma-coloured."""
    _plot_cell_parameter_bars(
        ax, ['α', 'β', 'γ'], atoms.cell.angles(), 'plasma',
        text_offset=0.8, top_margin=5,
        xlabel="Angle", ylabel="Degrees", cbar_label='Angle (Degrees)',
    )

In [10]:
def _periodic_pair_distances(positions, cell, r_max):
    """Return every distance <= r_max between atoms of a periodic cell, periodic images included.

    All ordered pairs (i, j) are counted over all lattice translations. Only an atom paired
    with itself at zero translation is excluded. ``cell`` must be non-singular.
    """
    import itertools

    # Wrap into the cell so every fractional separation lies within [-1, 1].
    frac = positions @ np.linalg.inv(cell)
    frac -= np.floor(frac)
    wrapped = frac @ cell

    # Spacing between adjacent lattice planes normal to each axis. A translation of k cells
    # along an axis moves an image by at least (|k| - 1) spacings, so ceil(r_max / spacing)
    # shifts per direction suffice. The +1 is a margin for rounding in the wrap above.
    volume = abs(np.linalg.det(cell))
    spacings = [volume / np.linalg.norm(np.cross(cell[(a + 1) % 3], cell[(a + 2) % 3]))
                for a in range(3)]
    reps = [int(np.ceil(r_max / s)) + 1 for s in spacings]

    separations = wrapped[None, :, :] - wrapped[:, None, :]  # [i, j] = r_j - r_i
    not_self = ~np.eye(len(wrapped), dtype=bool)
    found = []
    for shift in itertools.product(*(range(-k, k + 1) for k in reps)):
        d = np.linalg.norm(separations + np.asarray(shift, dtype=float) @ cell, axis=-1)
        keep = d <= r_max
        if not any(shift):
            keep &= not_self
        found.append(d[keep])
    return np.concatenate(found)


def plot_rdf(atoms, ax, r_max=10.0, bins=100):
    """Plot the radial distribution function of ``atoms`` up to ``r_max`` (Å).

    Fully periodic structure with a non-degenerate cell:
        All periodic images within ``r_max`` are counted. The ordered-pair histogram is
        divided by its ideal-gas expectation ``N * rho * shell_volume`` (rho = N / V), so
        the curve is g(r) and tends to 1 at large r.
    Any other structure (non/partially periodic, or zero-volume cell):
        No number density exists. Instead, a density-normalised histogram of the distinct
        pairwise distances (minimum image along periodic axes) is plotted and labelled as
        such. With fewer than two atoms, a flat zero curve is drawn.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to analyse.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    r_max : float
        Largest distance considered (Å).
    bins : int
        Number of equal-width bins on [0, r_max].
    """
    positions = np.asarray(atoms.get_positions(), dtype=float)
    n_atoms = len(positions)
    edges = np.linspace(0.0, r_max, int(bins) + 1)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])

    cell = np.asarray(atoms.get_cell(), dtype=float)
    volume = abs(np.linalg.det(cell))

    if n_atoms > 0 and np.all(atoms.get_pbc()) and volume > 1e-9:
        distances = _periodic_pair_distances(positions, cell, r_max)
        counts, _ = np.histogram(distances, bins=edges)
        shell_volumes = 4.0 / 3.0 * np.pi * (edges[1:] ** 3 - edges[:-1] ** 3)
        number_density = n_atoms / volume
        values = counts / (n_atoms * number_density * shell_volumes)
        ylabel = "g(r)"
    else:
        if n_atoms > 1:
            all_distances = np.asarray(atoms.get_all_distances(mic=True), dtype=float)
            distances = all_distances[np.triu_indices(n_atoms, k=1)]
            distances = distances[distances <= r_max]
        else:
            distances = np.empty(0)
        if distances.size:
            values, _ = np.histogram(distances, bins=edges, density=True)
        else:
            # np.histogram(density=True) on no data divides by zero and yields NaNs.
            values = np.zeros(len(bin_centers))
        ylabel = "Pair-distance density"

    ax.plot(bin_centers, values, color='blue')
    ax.set_title("Radial Distribution Function", fontsize=12)
    ax.set_xlabel("Distance r (Å)")
    ax.set_ylabel(ylabel)
    ax.grid(True)

In [None]:
def visualize_all(cif_dir, n):
    """Load the first ``n`` CIFs from ``cif_dir`` and show one grid figure per view or analysis.

    Figures, in order: crystal structures, 2x2x2 supercells, top-down view, bond-length
    histograms, element counts, unit-cell lengths, unit-cell angles, radial distribution
    functions.
    """
    structures, titles = load_all_structures(cif_dir, n)

    # (panel function, keyword arguments for show_grid_plot), drawn top to bottom.
    figure_specs = [
        (plot_structure, {"main_title": "Crystal Structures"}),
        (plot_supercell, {"supercell": (2, 2, 2), "main_title": "Supercells (2x2x2)"}),
        (plot_topdown, {"main_title": "Top-down View"}),
        (plot_bond_length_hist, {"main_title": "Bond Length Distributions",
                                 "x_label": "Bond Length (Å)", "y_label": "Count"}),
        (plot_element_hist, {"main_title": "Elemental Compositions", "y_label": "Count"}),
        (plot_unitcell_lengths, {"main_title": "Unit Cell Lengths", "y_label": "Length (Å)"}),
        (plot_unitcell_angles, {"main_title": "Unit Cell Angles", "y_label": "Degrees"}),
        (plot_rdf, {"main_title": "Radial Distribution Functions",
                    "x_label": "Distance r (Å)", "y_label": "g(r)"}),
    ]
    for panel_func, grid_kwargs in figure_specs:
        show_grid_plot(structures, titles, panel_func, **grid_kwargs)

In [12]:
CIF_DIR = r"data/cifs"  # directory scanned for *.cif files
N_STRUCTURES = 9        # number of CIFs (in sorted file-name order) to visualize
visualize_all(CIF_DIR, N_STRUCTURES)

Loading CIFs:   0%|          | 0/9 [00:00<?, ?it/s]

Loading CIFs: 100%|██████████| 9/9 [00:02<00:00,  3.22it/s]
