In [1]:

import math
import os
import random

import numpy as np
from numpy.typing import NDArray
from sklearn.decomposition import PCA
from tabulate import tabulate

from xrdpattern.pattern import PatternDB
from xrdpattern.pattern import XrdPattern
from matplotlib import pyplot as plt

from opxrd import OpXRD
from xrdpattern.xrd import LabelType

from holytools.devtools import Profiler
# Module-level Profiler instance from holytools.
# NOTE(review): it is not referenced in any of the cells visible here — confirm it is still needed.
profiler = Profiler()


In [74]:
from IPython.display import Latex

class DatabaseAnalyser:
    """Produces comparison plots and summary tables for a list of pattern databases.

    Figures are saved into ``output_dirpath`` and shown via ``plt.show()``;
    textual results are printed or rendered through IPython's ``display``.
    """

    def __init__(self, databases : 'list[PatternDB]', output_dirpath : str):
        """
        :param databases: Databases to analyse; must contain at least one entry.
        :param output_dirpath: Directory that receives the saved figures; created if missing.
        :raises ValueError: If ``databases`` is empty.
        """
        if len(databases) == 0:
            raise ValueError('No databases provided')
        self.databases : list[PatternDB] = databases
        self.joined_db : PatternDB = PatternDB.merge(databases)
        self.output_dirpath : str = output_dirpath
        os.makedirs(self.output_dirpath, exist_ok=True)

        # Seeds Python's global ``random`` module.
        random.seed(42)

    def run_all(self):
        """Runs the enabled analyses one after another."""
        print(f'Running analysis for {len(self.databases)} databases: {[db.name for db in self.databases]}')

        for limit in (10, 50, 100):
            self.plot_in_single(limit_patterns=limit)
        # Currently disabled analyses:
        # self.plot_fourier(max_freq=2)
        # self.plot_pca_scatter()
        self.plot_effective_components()

        self.plot_histogram()
        self.show_label_fractions()
        self.print_total_counts()

    def plot_in_single(self, limit_patterns : int):
        """Plots the first ``limit_patterns`` patterns of every database in one grid figure.

        Each database gets its own subplot titled a), b), ...; every subplot shares
        the x and y axis of the first one. Pairing databases with letters via ``zip``
        means at most 26 databases are drawn.

        NOTE: the figure is always written to ``ALL_pattern_multiplot.png`` inside
        ``output_dirpath``, so repeated calls overwrite the previous file.

        :param limit_patterns: Maximum number of patterns drawn per database.
        """
        lower_alphabet = [chr(i) for i in range(97, 123)]
        explanation = [f'{letter}:{db.name}' for letter, db in zip(lower_alphabet, self.databases)]
        self.print_text(f'---> Combined pattern plot for databases {explanation} | No. patterns = {limit_patterns}')

        save_fpath = os.path.join(self.output_dirpath, 'ALL_pattern_multiplot.png')

        cols = 3
        num_plots = len(self.databases)
        rows = math.ceil(num_plots / cols)
        fig = plt.figure(dpi=600, figsize=(cols * 3, rows * 3))
        axes = []
        for i in range(num_plots):
            if axes:
                ax = fig.add_subplot(rows, cols, i + 1, sharex=axes[0], sharey=axes[0])
            else:
                ax = fig.add_subplot(rows, cols, i + 1)
            axes.append(ax)

        for letter, ax, database in zip(lower_alphabet, axes, self.databases):
            for pattern in database.patterns[:limit_patterns]:
                x, y = pattern.get_pattern_data()
                ax.plot(x, y, linewidth=0.25, alpha=0.50, linestyle='--')
            ax.set_title(f'{letter})', loc='left')

        fig.supylabel('Standardized relative intensity (a.u.)')
        fig.supxlabel(r'$2\theta$ [$^\circ$]', ha='center')

        fig.tight_layout()
        fig.savefig(save_fpath)
        plt.show()
        # Release the figure explicitly so repeated calls do not accumulate open figures.
        plt.close(fig)

    def plot_fourier(self, max_freq=5):
        """Plots the Fourier magnitude spectrum of the first 10 patterns of each database.

        One figure per database is saved as ``<db.name>_fourier.png`` in ``output_dirpath``.

        :param max_freq: Upper frequency bound passed to ``compute_fourier_transform``.
        """
        for db in self.databases:
            fig, ax = plt.subplots(figsize=(10, 4), dpi=300)
            patterns = db.patterns[:10]
            for p in patterns:
                x, y = p.get_pattern_data()
                xf, yf = self.compute_fourier_transform(x, y, max_freq)

                # The first 100 frequency bins are dropped before plotting.
                xf, yf = xf[100:], yf[100:]

                ax.plot(xf, yf, linewidth=0.75, linestyle='--', alpha=0.75)

            ax.set_title(f'{db.name} patterns Fourier transform ' + r'$F(k)=\int d(2\theta) I(2\theta) e^{-ik2\theta}$' + f' [No. patterns = {len(patterns)}]')
            ax.set_xlabel(r'k [deg$^{−1}$]')
            ax.set_ylabel('|F($k$)| (a.u.)')

            fig.savefig(os.path.join(self.output_dirpath, f'{db.name}_fourier.png'))
            plt.show()
            # One figure is created per database; close it to keep memory bounded.
            plt.close(fig)

    def plot_effective_components(self):
        """Plots the cumulative explained PCA variance over the number of components.

        A PCA is fitted per database on the intensity arrays of all its patterns and
        one curve per database is drawn. The figure is saved as
        ``ALL_effective_components.png`` in ``output_dirpath``. Databases without
        patterns are skipped because a PCA cannot be fitted on them.
        """
        self.print_text(r'Cumulative explained variance ratio $v$ over components |  $v =  \frac{\sum_i \lambda_i}{\sum^n_{j=1} \lambda_j}$')

        num_entries = XrdPattern.std_num_entries()
        # Evaluate the curve exactly over the displayed x-range [0, num_entries//2].
        component_counts = np.arange(num_entries // 2 + 1)

        fig, ax = plt.subplots()
        for db in self.databases:
            if len(db.patterns) == 0:
                print(f'[Debug]: Skipping PCA for {db.name} | No. patterns = 0')
                continue

            max_components = min(len(db.patterns), num_entries)
            standardized_intensities = [p.get_pattern_data()[1] for p in db.patterns]
            print(f'[Debug]: Performing PCA for {db.name} | No. patterns = {len(standardized_intensities)}')
            pca = PCA(n_components=max_components)
            # Only the fitted variance ratios are needed, not the transformed data.
            pca.fit(standardized_intensities)

            # cumulative[k] = variance explained by the first k components (0 for k = 0).
            cumulative = np.concatenate(([0.0], np.cumsum(pca.explained_variance_ratio_)))
            # Beyond the number of fitted components the curve stays at its final value.
            explained = cumulative[np.minimum(component_counts, len(cumulative) - 1)]

            ax.plot(component_counts, explained, label=db.name)

        ax.set_xlabel('No. components')
        ax.set_ylabel('Cumulative explained variance $V$')
        ax.set_xlim(0, num_entries // 2)
        ax.set_ylim(0.6, 1)
        ax.legend(loc='lower right')
        fig.savefig(os.path.join(self.output_dirpath, 'ALL_effective_components.png'))

        plt.show()
        plt.close(fig)

    def plot_histogram(self):
        """Shows the histograms of the merged database and saves them as ``ALL_histogram.png``."""
        self.print_text('---> Histograms')
        self.joined_db.show_histograms(save_fpath=os.path.join(self.output_dirpath, 'ALL_histogram.png'), attach_colorbar=False)

    def show_label_fractions(self):
        """Prints a table with, per database, the fraction of patterns carrying each label type.

        Rows are databases, columns are the members of ``LabelType``. For a database
        without patterns the fraction is undefined and reported as ``nan``.
        """
        self.print_text('---> Overview of label fractions per contribution')
        table_data = []
        for db in self.databases:
            patterns = db.patterns
            num_patterns = len(patterns)
            db_fractions = []
            for label_type in LabelType:
                num_with_label = sum(1 for p in patterns if p.has_label(label_type=label_type))
                # Guard against empty databases instead of raising ZeroDivisionError.
                fraction = num_with_label / num_patterns if num_patterns > 0 else float('nan')
                db_fractions.append(fraction)
            table_data.append(db_fractions)

        col_headers = [label.name for label in LabelType]
        row_headers = [db.name for db in self.databases]

        table = tabulate(table_data, headers=col_headers, showindex=row_headers, tablefmt='psql')
        print(table)

    def print_total_counts(self):
        """Prints the total number of patterns and the number of labeled patterns in the merged database."""
        self.print_text('---> Total pattern counts in opXRD')
        all_patterns = self.get_all_patterns()
        num_total = len(all_patterns)
        num_labeled = sum(1 for p in all_patterns if p.is_labeled())

        print(f'Total number of patterns = {num_total}')
        print(f'Number of labeled patterns = {num_labeled}')


    @staticmethod
    def compute_fourier_transform(x,y, max_freq : float):
        """Computes the single-sided FFT magnitude spectrum of ``y`` sampled at ``x``.

        The sample spacing is taken as ``(x[-1] - x[0]) / (N - 1)``, i.e. the samples
        are treated as equally spaced between the first and last x value.

        :param x: Sample positions; only the first and last value are used.
        :param y: Sample values.
        :param max_freq: Frequencies above this value are discarded.
        :return: Tuple ``(frequencies, magnitudes)`` restricted to the first ``N // 2``
                 FFT bins with frequency ``<= max_freq``; magnitudes are ``2/N * |FFT(y)|``.
        :raises ValueError: If fewer than two samples are given (spacing undefined).
        """
        N = len(y)
        if N < 2:
            raise ValueError('At least two samples are required to compute a Fourier transform')
        T = (x[-1] - x[0]) / (N - 1)
        yf = np.fft.fft(y)
        xf = np.fft.fftfreq(N, T)[:N // 2]

        magnitude = 2.0 / N * np.abs(yf[:N // 2])
        valid_indices = xf <= max_freq

        return xf[valid_indices], magnitude[valid_indices]

    @staticmethod
    def compute_mismatch(i1 : NDArray, i2 : NDArray) -> float:
        """Returns the relative L2 deviation ``||i1 - i2|| / ||i1||``.

        :param i1: Reference array.
        :param i2: Array compared against ``i1``; must be broadcast-compatible with it.
        :return: Norm of the difference divided by the norm of ``i1``.
        """
        # A common 1/len(i1) factor on both norms would cancel, so it is omitted.
        return float(np.linalg.norm(i1 - i2) / np.linalg.norm(i1))

    # -----------------------
    # tools

    def get_all_patterns(self) -> 'list[XrdPattern]':
        """Returns the patterns of the merged database."""
        return self.joined_db.patterns

    @staticmethod
    def print_text(msg : str):
        """Renders ``msg`` as LaTeX through IPython's rich display.

        :param msg: Text (may contain ``$...$`` math) to display.
        """
        # Imported explicitly so this does not depend on the ``display`` builtin
        # that IPython injects into interactive sessions.
        from IPython.display import display
        display(Latex(msg))

In [76]:
# Root directory holding the opXRD 'test' and 'final' project folders.
# It can be overridden through the OPXRD_DATA_ROOT environment variable so the
# notebook also runs on other machines; the default is the original local path.
opxrd_data_root = os.environ.get('OPXRD_DATA_ROOT', '/home/daniel/aimat/data/opXRD')
test_databases = OpXRD.load_project_list(root_dirpath=os.path.join(opxrd_data_root, 'test'))
opxrd_databases = OpXRD.load_project_list(root_dirpath=os.path.join(opxrd_data_root, 'final'))

- Loading databases from /home/daniel/aimat/data/opXRD/test
[20m[2024-12-19 14:58:50]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/test/USC[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(338 of 338)[39m |######################| Elapsed Time: 0:00:00 Time:  0:00:000000


[20m[2024-12-19 14:58:51]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/test/USC[0m
[20m[2024-12-19 14:58:51]: Successfully extracted 338 patterns from 338/338 xrd files[0m
[20m[2024-12-19 14:58:51]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/test/CNRS[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(210 of 210)[39m |######################| Elapsed Time: 0:00:01 Time:  0:00:010000


[20m[2024-12-19 14:58:52]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/test/CNRS[0m
[20m[2024-12-19 14:58:52]: Successfully extracted 210 patterns from 210/210 xrd files[0m
- Loading databases from /home/daniel/aimat/data/opXRD/final
[20m[2024-12-19 14:58:52]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/EMPA[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(770 of 770)[39m |######################| Elapsed Time: 0:00:01 Time:  0:00:010000


[20m[2024-12-19 14:58:53]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/EMPA[0m
[20m[2024-12-19 14:58:53]: Successfully extracted 770 patterns from 770/770 xrd files[0m
[20m[2024-12-19 14:58:53]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/LBNL/UiO_compounds[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(1348 of 1348)[39m |####################| Elapsed Time: 0:00:01 Time:  0:00:010000


[20m[2024-12-19 14:58:54]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/LBNL/UiO_compounds[0m
[20m[2024-12-19 14:58:54]: Successfully extracted 1348 patterns from 1348/1348 xrd files[0m
[20m[2024-12-19 14:58:56]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/LBNL/perovskite_precursor_solutions[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(68322 of 68322)[39m |##################| Elapsed Time: 0:01:23 Time:  0:01:230006


[20m[2024-12-19 15:00:20]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/LBNL/perovskite_precursor_solutions[0m
[20m[2024-12-19 15:00:20]: Successfully extracted 68322 patterns from 68322/68322 xrd files[0m
[20m[2024-12-19 15:00:20]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/LBNL/MnSbO_annealing[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(342 of 342)[39m |######################| Elapsed Time: 0:00:00 Time:  0:00:000000


[20m[2024-12-19 15:00:20]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/LBNL/MnSbO_annealing[0m
[20m[2024-12-19 15:00:20]: Successfully extracted 342 patterns from 342/342 xrd files[0m
[20m[2024-12-19 15:00:20]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/USC[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(338 of 338)[39m |######################| Elapsed Time: 0:00:00 Time:  0:00:000000


[20m[2024-12-19 15:00:21]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/USC[0m
[20m[2024-12-19 15:00:21]: Successfully extracted 338 patterns from 338/338 xrd files[0m
[20m[2024-12-19 15:00:22]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/INT[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(19796 of 19796)[39m |##################| Elapsed Time: 0:00:44 Time:  0:00:440003


[20m[2024-12-19 15:01:06]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/INT[0m
[20m[2024-12-19 15:01:06]: Successfully extracted 19796 patterns from 19796/19796 xrd files[0m
[20m[2024-12-19 15:01:06]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/HKUST/in_house[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(21 of 21)[39m |########################| Elapsed Time: 0:00:00 Time:  0:00:0000


[20m[2024-12-19 15:01:06]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/HKUST/in_house[0m
[20m[2024-12-19 15:01:06]: Successfully extracted 21 patterns from 21/21 xrd files[0m
[20m[2024-12-19 15:01:06]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/HKUST/accumulated[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(499 of 499)[39m |######################| Elapsed Time: 0:00:02 Time:  0:00:020000


[20m[2024-12-19 15:01:09]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/HKUST/accumulated[0m
[20m[2024-12-19 15:01:09]: Successfully extracted 499 patterns from 499/499 xrd files[0m
[20m[2024-12-19 15:01:09]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/CNRS[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(1052 of 1052)[39m |####################| Elapsed Time: 0:00:05 Time:  0:00:050000


[20m[2024-12-19 15:01:15]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/CNRS[0m
[20m[2024-12-19 15:01:15]: Successfully extracted 1052 patterns from 1052/1052 xrd files[0m
[20m[2024-12-19 15:01:15]: Loading patterns from local dirpath /home/daniel/aimat/data/opXRD/final/IKFT[0m


[38;2;0;255;0m100%[39m [38;2;0;255;0m(64 of 64)[39m |########################| Elapsed Time: 0:00:00 Time:  0:00:000:00


[20m[2024-12-19 15:01:15]: Finished loading pattern database located at /home/daniel/aimat/data/opXRD/final/IKFT[0m
[20m[2024-12-19 15:01:15]: Successfully extracted 64 patterns from 64/64 xrd files[0m


In [None]:
# Build the analyser over the databases loaded from the 'final' folder and draw
# only the cumulative explained-variance plot.
analyser = DatabaseAnalyser(
    databases=opxrd_databases,
    output_dirpath='/tmp/opxrd_analysis',
)
analyser.plot_effective_components()