In [1]:
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import os
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from jarvis.db.figshare import data
from jarvis.core.atoms import Atoms
from jarvis.analysis.structure.spacegroup import Spacegroup3D
import matplotlib
import seaborn as sns
from sklearn.metrics import mean_absolute_error

%matplotlib inline

# matplotlib.use('Agg')

# Load the JARVIS 'dft_3d' records through jarvis.db.figshare.data and mirror them
# in a DataFrame for tabular access.
dft_3d = data('dft_3d')
df = pd.DataFrame(dft_3d)
# Target property; it is interpolated into every dataset / prediction file name below.
prop = "mbj_bandgap"
# Directory holding the per-model prediction CSVs.
# (Fixed: the literal previously opened with a double quote and closed with a single
# quote, which made the whole cell a SyntaxError.)
base_directory = "/scratch/yll6162/CrossPropertyTL/pred/atomgpt_new"

Obtaining 3D dataset 76k ...
Reference:https://www.nature.com/articles/s41524-020-00440-1
Other versions:https://doi.org/10.6084/m9.figshare.6815699
Loading the zipfile...
Loading completed.


### Extract test ids

In [None]:
# Reference test split: keep only the target property and the material ids.
test_csv = f"/data/yll6162/alignntl_dft_3d/tl_dataset/dataset_alignn_matbert-base-cased_robo_prop_{prop}_test.csv"
test_df = pd.read_csv(test_csv, index_col=0)[[prop, 'ids']]
# train_df.to_csv(f"/scratch/yll6162/atomgpt/dataset_split/{prop}_train.csv")
# test_df.to_csv(f"/scratch/yll6162/atomgpt/dataset_split/{prop}_test.csv")
# val_df.to_csv(f"/scratch/yll6162/atomgpt/dataset_split/{prop}_val.csv")


# One prediction file per model variant, all sharing the same naming scheme.
model_stems = (
    "alignn",
    "bert-base-uncased_chemnlp",
    "bert-base-uncased_robo",
    "alignn_bert-base-uncased_chemnlp",
    "alignn_bert-base-uncased_robo",
    "alignn_matbert-base-cased_chemnlp",
    "alignn_matbert-base-cased_robo",
)
filepaths = [f"{base_directory}/{stem}_prop_{prop}_pred_otf.csv" for stem in model_stems]

# Absolute tolerance used when matching stored labels against the test split.
tolerance = 0.001
for filepath in filepaths:
    if os.path.exists(filepath):
        pred_df = pd.read_csv(filepath, index_col=0)
        print(filepath)
        # The predictions must be row-aligned with the test split before ids are copied over.
        equal_col1 = np.isclose(test_df[prop], pred_df["labels"], atol=tolerance)
        assert equal_col1.all()
        # Attach the ids positionally and write the file back in place.
        pred_df["ids_test"] = test_df["ids"].values
        pred_df.to_csv(filepath)

### Plot specific divisions: crystal system, chemical formula, etc.

In [None]:
def get_top_accurate_predictions(csv_file, num_predictions=None, sel='top'):
    """Load a prediction CSV and return its rows ranked by absolute error.

    Rows whose ``labels`` value equals 0 are dropped, and a ``deviation``
    column holding ``|labels - predictions|`` is added to the result.

    Args:
        csv_file: Path of a CSV with ``labels`` and ``predictions`` columns.
        num_predictions: Number of rows to return for ``'top'``/``'bottom'``.
            When falsy (``None`` or 0), 10% of the rows in the file is used;
            that count is taken before the zero-label filter is applied.
        sel: ``'top'`` for the smallest deviations, ``'bottom'`` for the
            largest, ``'all'`` for every remaining row in file order.

    Returns:
        A DataFrame with the selected rows and the extra ``deviation`` column.

    Raises:
        ValueError: If ``sel`` is not one of the three accepted values.
    """
    frame = pd.read_csv(csv_file)
    keep = num_predictions if num_predictions else int(0.1 * len(frame))

    # Keep this filter to look at semiconductors/insulators only; remove it to include metals.
    nonzero = frame[frame['labels'] != 0]
    scored = nonzero.assign(deviation=abs(nonzero['labels'] - nonzero['predictions']))

    if sel == 'all':
        return scored
    if sel not in ('top', 'bottom'):
        raise ValueError('Invalid selection. Please select either "top" or "bottom" or "all".')
    # 'top' ranks ascending (most accurate first), 'bottom' ranks descending.
    return scored.sort_values(by='deviation', ascending=(sel == 'top')).head(keep)

def to_grouped_bar(axis, annot_y_offset=0, annot=True):
    """Narrow and shift the bars of ``axis`` in place, optionally labelling them.

    Every patch is shrunk to 40% of its width. Patches in the first half of
    ``axis.patches`` are additionally moved right by that new width, so the
    two halves end up next to each other rather than overlapping.

    Args:
        axis: Object exposing ``patches`` (bars with get/set x and width and
            ``get_height``) and ``annotate``.
        annot_y_offset: Extra data-space height added to the annotation anchor
            of patches in the second half only.
        annot: When equal to ``True``, bars with positive height are annotated
            with their height (floats are shown with three decimals).
    """
    shrink = 0.4
    bars = axis.patches
    half = len(bars) / 2
    for idx, bar in enumerate(bars):
        # 0.0 for the first half of the patches, >= 1.0 afterwards.
        group = idx // half

        full_width = bar.get_width()
        bar.set_width(full_width * shrink)
        bar.set_x(bar.get_x() + (full_width * (0.5 - shrink)))
        if group == 0:
            bar.set_x(bar.get_x() + (full_width * shrink))

        value = bar.get_height()
        if not (annot == True and value > 0):
            continue

        # Floats get three decimals; ints and anything else use their plain repr.
        if isinstance(value, float) and not isinstance(value, int):
            text = f'{value:.3f}'
        else:
            text = f'{value}'
        anchor = (bar.get_x() + bar.get_width() / 2, value + annot_y_offset * group)
        axis.annotate(text,
                      xy=anchor,
                      xytext=(-2, 3),  # nudge 2 points left and 3 points up
                      textcoords='offset points',
                      ha='center', va='bottom')


# One 4x2 grid shared by every model. plt.subplots creates the figure itself, so
# the former standalone plt.figure() call (which only left an empty extra figure
# behind) has been dropped.
fig, axs = plt.subplots(4, 2, figsize=(20, 20))
fig.suptitle(f'{prop}:', fontsize=16)
bar_colormap = ["#FAD7AC","#6C8EBF"]
# Secondary y-axes used to overlay MAE markers on the two frequency panels.
ax_right_1 = axs[2, 0].twinx()
ax_right_2 = axs[3, 0].twinx()
# Prediction files to compare, relative to base_directory. The names follow the
# target property instead of hard-coding it, and the stray leading apostrophe that
# used to sit inside the first file name has been removed.
csv_files = [
    f"alignn_matbert-base-cased_robo_prop_{prop}_pred_otf.csv",
    f"alignn_prop_{prop}_pred_otf.csv",
]
# Legend text per model. The loop below looks entries up by the csv_files names,
# so those are registered here; the original directory-style keys are kept so any
# other lookups by those names keep working.
legend_name = {
    'alignn_robocystallographer_matbert_embed': "LLM+GNN embeddings",
    'alignn_embed': "GNN embeddings",
    csv_files[0]: "LLM+GNN embeddings",
    csv_files[1]: "GNN embeddings",
}
# Index the JARVIS records by id once, so each prediction row costs a dict lookup
# instead of a linear scan over the whole dataset. setdefault keeps the first
# record per id, matching the first-match behaviour of the former scan.
entries_by_jid = {}
for item in dft_3d:
    entries_by_jid.setdefault(item['jid'], item)

# Chemical-formula prototypes shown in the formula panels.
selected = ["A", "A2BC4", "A2BCD6", "AB", "AB2", "AB2C4", "ABC", "ABC2", "ABC3", "ABC4","A2B", "A2BC"]
max_items = len(selected)

for model_idx, model_dir in enumerate(csv_files):
    # model_dir may name a directory containing a prediction CSV or the CSV itself.
    model_path = os.path.join(base_directory, model_dir)
    if os.path.isdir(model_path):
        csv_file = next((os.path.join(model_path, f)
                         for f in os.listdir(model_path) if f.endswith('.csv')), None)
    else:
        csv_file = model_path
    if not csv_file:
        continue

    # Fall back to the raw name so a model without a legend entry still plots
    # instead of raising KeyError.
    model_label = legend_name.get(model_dir, model_dir)
    series_marker = 's' if model_idx == 0 else '^'

    top_predictions = get_top_accurate_predictions(csv_file, sel='top')
    all_predictions = get_top_accurate_predictions(csv_file, sel='all')

    # Per-model accumulators (reset for every model).
    elements_dict = {atomic_number: 0 for atomic_number in range(1, 100)}
    space_groups = []
    space_group_perf = defaultdict(list)
    space_group_perf_val = defaultdict(float)
    labels = []
    predictions = []
    densities = []
    packing_fractions = []
    prototypes = []
    prototypes_perf = defaultdict(list)
    prototypes_perf_val = defaultdict(float)
    wyckoffs = []
    wyckoffs_cell = []
    crystal_systems = []
    crystal_systems_perf = defaultdict(list)
    crystal_systems_perf_val = defaultdict(float)

    # --- Structural statistics of the most accurate predictions ---------------
    for _, row in top_predictions.iterrows():
        jid = row['ids_test'].replace('.vasp', '')
        # Named `entry` (not `data`) so the imported jarvis `data` loader is not shadowed.
        entry = entries_by_jid.get(jid)
        if entry is None:
            # No structure for this id: skip the row entirely rather than record
            # stale values left over from the previous row.
            continue

        atoms = Atoms.from_dict(entry["atoms"])
        spg = Spacegroup3D(atoms)
        space_groups.append(spg.space_group_number)
        space_group_perf[spg.space_group_number].append(row['deviation'])
        for atomic_number in atoms.atomic_numbers:
            elements_dict[atomic_number] = elements_dict.get(atomic_number, 0) + 1

        labels.append(row['labels'])
        predictions.append(row['predictions'])
        crystal_systems.append(entry['crys'])
        densities.append(atoms.density)
        prototypes.append(atoms.composition.prototype)
        packing_fractions.append(atoms.packing_fraction)

    # --- Error grouped by crystal system / prototype over all predictions -----
    for _, row in all_predictions.iterrows():
        jid = row['ids_test'].replace('.vasp', '')
        entry = entries_by_jid.get(jid)
        if entry is None:
            continue
        # The structure must be built from the record of THIS row; it used to be
        # built before the lookup and therefore came from the previous row.
        prot = Atoms.from_dict(entry["atoms"]).composition.prototype
        crystal_systems_perf[entry['crys']].append(row['deviation'])
        prototypes_perf[prot].append(row['deviation'])

    for key, deviations in space_group_perf.items():
        space_group_perf_val[key] = np.mean(deviations)

    counts = Counter(prototypes)
    selected_counts = {k: counts[k] for k in selected}
    # Rows of (prototype, count), ordered alphabetically by prototype.
    sorted_proto = np.array(sorted(selected_counts.items(), key=lambda x: x[0]))

    for key in sorted(selected):
        deviations = prototypes_perf[key]
        # A prototype absent from the test set has no MAE; NaN leaves a gap in the plot
        # without the empty-slice warning np.mean([]) would emit.
        prototypes_perf_val[key] = np.mean(deviations) if deviations else np.nan
    crystal_systems_sorted = sorted(crystal_systems)
    for key in sorted(set(crystal_systems)):
        deviations = crystal_systems_perf[key]
        crystal_systems_perf_val[key] = np.mean(deviations) if deviations else np.nan

    # --- Panel (0, 0): parity plot of the top predictions -----------------------
    axs[0, 0].scatter(labels, predictions, alpha=0.7, label=model_label)
    axs[0, 0].set_title('Labels vs. Predictions')
    axs[0, 0].set_xlabel('Labels')
    axs[0, 0].set_ylabel('Predictions')
    if labels:
        axs[0, 0].plot([min(labels), max(labels)], [min(labels), max(labels)], 'r--')
    axs[0, 0].legend()

    # --- Panel (0, 1): space-group histogram in bins of 10 ----------------------
    bin_edges = list(range(0, 231, 10))
    bin_centers = [(lo + hi) / 2 for lo, hi in zip(bin_edges[:-1], bin_edges[1:])]

    sns.histplot(space_groups, bins=bin_edges, ax=axs[0, 1], kde=False, label=model_label)
    axs[0, 1].set_title('Distribution of Space Group Numbers')
    axs[0, 1].set_xlabel('Space Group Number')
    axs[0, 1].set_ylabel('Top 10% Prediction Frequency')
    axs[0, 1].set_xticks(bin_centers)
    axs[0, 1].legend(loc='upper left')
    # The most recently added bar belongs to this model's histogram; reuse its
    # RGBA colour so every other panel matches.
    bar_color = axs[0, 1].patches[-1].get_facecolor()
    # Same colour, fully opaque, for the MAE markers.
    marker_color = tuple(bar_color[:-1]) + (1,)

    # --- Panel (2, 0): prototype frequency (bars) + MAE (right axis) ------------
    if model_idx == 0:
        axs[2, 0].grid(True, linewidth=0.3, linestyle='--')
    axs[2, 0].bar(np.arange(max_items), sorted_proto[:, 1].astype(int),
                  label=model_label, color=bar_color)
    axs[2, 0].set_title('Chemical Formula Distribution')
    axs[2, 0].set_xticks(np.arange(max_items))
    axs[2, 0].set_xticklabels(sorted_proto[:, 0], rotation=0)
    axs[2, 0].set_xlabel('Chemical Formula')
    axs[2, 0].set_ylabel(f'Top 10% Prediction Frequency')
    axs[2, 0].legend()
    for bar in axs[2, 0].patches:
        bar.set_edgecolor('black')
        bar.set_linewidth(1.5)
    ax_right_1.plot(prototypes_perf_val.keys(), prototypes_perf_val.values(),
                    label=f'{model_label} MAE', color=marker_color, marker=series_marker,
                    markersize=8, markeredgecolor='black', markeredgewidth=1)
    ax_right_1.set_ylabel('All Testset MAE')
    ax_right_1.set_ylim(0, 1.1)
    if model_idx > 0:
        # Merge the handles of both y-axes into one legend once all series exist.
        handles1, labels1 = axs[2, 0].get_legend_handles_labels()
        handles2, labels2 = ax_right_1.get_legend_handles_labels()
        axs[2, 0].legend(handles1 + handles2, labels1 + labels2, loc='upper left')

    # --- Panel (3, 1): MAE per crystal system -----------------------------------
    axs[3, 1].bar(crystal_systems_perf_val.keys(), crystal_systems_perf_val.values(),
                  label=model_label, color=bar_color)
    axs[3, 1].set_title('Crystal System Performance')
    axs[3, 1].set_xlabel('Crystal System')
    axs[3, 1].set_ylabel('All Testset MAE')
    axs[3, 1].legend(loc='upper right')
    axs[3, 1].set_ylim(0.2, 0.6)
    for bar in axs[3, 1].patches:
        bar.set_edgecolor('black')
        bar.set_linewidth(1.5)

    # --- Panel (3, 0): crystal-system frequency + MAE (right axis) --------------
    if model_idx == 0:
        axs[3, 0].grid(True, linewidth=0.3, linestyle='--')
    sns.histplot(crystal_systems_sorted, bins=np.arange(0, 8, 1), ax=axs[3, 0], kde=False,
                 label=f'{model_label} top pred freq')
    axs[3, 0].set_title('Crystal System Distribution')
    axs[3, 0].set_xlabel('Crystal System')
    axs[3, 0].set_ylabel('Top 10% Prediction Frequency')
    axs[3, 0].set_ylim(0, 55)

    ax_right_2.set_ylabel('All Testset MAE')
    ax_right_2.set_ylim(0.3, 0.8)
    ax_right_2.plot(crystal_systems_perf_val.keys(), crystal_systems_perf_val.values(),
                    label=f'{model_label} MAE', color=marker_color, marker=series_marker,
                    markersize=10, markeredgecolor='black', markeredgewidth=1)
    if model_idx > 0:
        handles1, labels1 = axs[3, 0].get_legend_handles_labels()
        handles2, labels2 = ax_right_2.get_legend_handles_labels()
        axs[3, 0].legend(handles1 + handles2, labels1 + labels2, loc='upper left')

    # --- Panel (2, 1): MAE per prototype ----------------------------------------
    if model_idx == 0:
        axs[2, 1].grid(True, linewidth=0.3, linestyle='--')
    # These bars show MAE, so they are labelled as such (previously "top pred freq").
    axs[2, 1].bar(prototypes_perf_val.keys(), prototypes_perf_val.values(),
                  label=f'{model_label} MAE', color=bar_color)
    axs[2, 1].set_title('Chemical Formula Performance')
    axs[2, 1].set_xlabel('Prototype')
    axs[2, 1].set_ylabel('All Testset MAE')
    axs[2, 1].set_ylim(0, 1.1)
    for bar in axs[2, 1].patches:
        bar.set_edgecolor('black')
        bar.set_linewidth(1.5)

# Panels whose bars are narrowed/shifted (and annotated) by to_grouped_bar,
# paired with the keyword arguments each one needs.
grouped_panels = (
    (axs[0, 1], {}),
    (axs[2, 1], {'annot_y_offset': 0.02}),
    (axs[2, 0], {}),
    (axs[3, 0], {'annot': True}),
    (axs[3, 1], {}),
)
for panel, options in grouped_panels:
    to_grouped_bar(panel, **options)

plt.tight_layout()
plt.show()

# pdf.savefig(fig)
# plt.savefig(f"{prop}_{model_dir}.png")

plt.close()