In [13]:
import warnings
from pathlib import Path

import molli
import numpy as np
from scipy.stats._axis_nan_policy import SmallSampleWarning

# Dark background for molli's visualisation output (matches a dark notebook theme).
molli.visual.configure(bgcolor="#1F1F1F")

# Silence SciPy's SmallSampleWarning; presumably emitted by scipy.stats.variation
# when a catalyst has too few conformers (e.g. a single one) — confirm.
warnings.filterwarnings("ignore", category=SmallSampleWarning)


class Catalyst:
    """Class to represent a catalyst and featurize it.

    Attributes:
        name (str): Name of the catalyst.
        xyz_file (Path): Path to the xyz file of the catalyst.
        catalyst (molli.Molecule): Loaded catalyst molecule.
        steric_descriptors (dict): Calculated steric descriptors.
        electronic_descriptors (dict): Calculated electronic descriptors.
        descriptors (dict): Combined steric and electronic descriptors.
    """

    def __init__(self, name: str, xyz_file: Path, calculate_descriptors: bool = True):
        """Initialize the Catalyst object.

        Args:
            name (str): Name of the catalyst.
            xyz_file (Path): Path to the xyz file of the catalyst.
            calculate_descriptors (bool): Whether to calculate descriptors. Defaults to True.
        """
        self.name = name
        self.xyz_file = xyz_file

        self.catalyst = molli.load(xyz_file, parser="openbabel")
        self.catalyst.name = name

        # Initialize dictionaries to store calculated descriptors
        self.steric_descriptors: dict[str, float] = {}
        self.electronic_descriptors: dict[str, float] = {}

        # Set up required attributes
        self._identify_ligands()
        self._setup_morfeus()
        self._setup_xtb()
        self._identify_ipso_carbon()

        if calculate_descriptors:
            # Calculate steric descriptors
            self.buried_volume()
            self.buried_volume_ipso()
            self.quadrant_buried_volume_ligand()
            self.solid_angle()
            self.dispersion()
            self.pyramidalization()
            self.solvent_accessible_surface_area()
            self.sterimol()
            self.distances()

            # Calculate electronic descriptors
            self.homo_lumo()
            self.charges()
            self.fukui()
            self.global_electrophilicity()
            self.global_nucleophilicity()
            self.tolman_electronic_parameter()

        # Concatenate
        self.descriptors: dict[str, float] = (
            {"ligand_1_name": self.name} | self.steric_descriptors | self.electronic_descriptors
        )

    def __str__(self):
        """Returns the catalyst name."""
        return f"Catalyst: {self.name}"

    def __repr__(self):
        """Returns the catalyst name."""
        return self.__str__()

    def _identify_ligands(self):
        """Identify the ligands of the catalyst.

        Raises:
            AssertionError: If expected ligands are not found.
        """
        # Determine the index of Pd
        self.metal = self.catalyst.get_atom(molli.Element.Pd)
        self.metal_index = self.catalyst.get_atom_index(self.metal)

        # Find all ligating atoms to Pd
        ligands = [x for x in self.catalyst.connected_atoms(self.metal)]
        self.ligand_indices = [self.catalyst.get_atom_index(x) for x in ligands]
        assert len(ligands) == 4, "Catalyst should have 4 ligands."

        for ligand in ligands:
            if ligand.element == molli.Element.C:
                self.aryl_carbon = ligand
                self.aryl_carbon_index = self.catalyst.get_atom_index(ligand)
            if ligand.element == molli.Element.O:
                self.carboxylic_oxygen = ligand
                self.carboxylic_oxygen_index = self.catalyst.get_atom_index(ligand)
            if ligand.element == molli.Element.N:
                self.amine_nitrogen = ligand
                self.amine_nitrogen_index = self.catalyst.get_atom_index(ligand)
            if ligand.element == molli.Element.P:
                self.phosphine = ligand
                self.phosphine_index = self.catalyst.get_atom_index(ligand)

        assert self.aryl_carbon, "Aryl carbon not found."
        assert self.carboxylic_oxygen, "Carboxylic oxygen not found."
        assert self.amine_nitrogen, "Amine nitrogen not found."
        assert self.phosphine, "Phosphine not found."

        # Identify amine protons
        self.amine_protons = [
            x for x in self.catalyst.connected_atoms(self.amine_nitrogen) if x.element == molli.Element.H
        ]
        self.amine_proton_indices = [self.catalyst.get_atom_index(x) for x in self.amine_protons]
        assert len(self.amine_protons) == 2, "Primary amine should have 2 protons."

        # Identify the amine carbon
        self.amine_carbon = [
            x for x in self.catalyst.connected_atoms(self.amine_nitrogen) if x.element == molli.Element.C
        ][0]
        self.amine_carbon_index = self.catalyst.get_atom_index(self.amine_carbon)
        assert self.amine_carbon, "Amine carbon not found."

        # Identify groups on phosphine
        self.phosphine_groups = [
            x for x in self.catalyst.connected_atoms(self.phosphine) if x.element != molli.Element.Pd
        ]
        self.phosphine_group_indices = [self.catalyst.get_atom_index(x) for x in self.phosphine_groups]
        assert len(self.phosphine_groups) == 3, "Phosphine should have 2 groups."

        # Identify the carbonyl carbon
        carb_carbon = [x for x in self.catalyst.connected_atoms(self.carboxylic_oxygen) if x.element == molli.Element.C]
        self.carbonyl_carbon = carb_carbon[0]
        self.carbonyl_carbon_index = self.catalyst.get_atom_index(self.carbonyl_carbon)

        # Identify the carbonyl oxygen
        carb_oxygens = [x for x in self.catalyst.connected_atoms(carb_carbon[0]) if x.element == molli.Element.O]
        self.carbonyl_oxygen = [x for x in carb_oxygens if x != self.carboxylic_oxygen][0]
        self.carbonyl_oxygen_index = self.catalyst.get_atom_index(self.carbonyl_oxygen)
        assert self.carbonyl_oxygen, "Carbonyl oxygen not found."

    def _setup_morfeus(self):
        """Read xyz file for catalyst and setup Morfeus xyz elements and coordinates.

        Some calculations require the ligand with and without the metal, so we also
        extract these from the full catalyst.
        """
        import tempfile

        from morfeus import read_xyz

        # Full catalyst
        self.morfeus_elements, self.morfeus_coords = read_xyz(self.xyz_file)

        # Ligand with metal
        pd_ligand_mol = molli.Molecule(self.catalyst)
        pd_ligand_atoms = pd_ligand_mol.atoms[:76]

        for atom in pd_ligand_atoms:
            if atom.element != molli.Element.Pd:
                pd_ligand_mol.del_atom(atom)

        ligand_metal = pd_ligand_mol.get_atom(molli.Element.Pd)
        self.ligand_metal_index = pd_ligand_mol.get_atom_index(ligand_metal)

        with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
            pd_ligand_mol.dump_xyz(f)
            temp_file_name = f.name

        self.Pd_ligand_elements, self.Pd_ligand_coords = read_xyz(temp_file_name)

        # Ligand without metal
        self.ligand_mol = molli.Molecule(self.catalyst)
        ligand_atoms = self.ligand_mol.atoms[:76]

        for atom in ligand_atoms:
            self.ligand_mol.del_atom(atom)

        with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
            self.ligand_mol.dump_xyz(f)
            temp_file_name = f.name

        phosphine = self.ligand_mol.get_atom(molli.Element.P)
        self.ligand_phosphine_index = self.ligand_mol.get_atom_index(phosphine)
        self.ligand_phosphine_groups = [
            self.ligand_mol.get_atom_index(x)
            for x in self.ligand_mol.connected_atoms(phosphine)
            if x.element != molli.Element.P
        ]

        self.ligand_elements, self.ligand_coords = read_xyz(temp_file_name)

    @staticmethod
    def _in_between(text_file: str | Path, start_keyword: str, end_keyword: str) -> list:
        """Returns text between two strings in a file.

        Takes a list of text and returns the list containing entries between start_keyword
        and end_keyword (inclusive).

        Args:
            text_file (str | Path): Path to the text file to parse.
            start_keyword (str): String to start text saving.
            end_keyword (str): String to end text saving.

        Returns:
            list: List of each line from the parsed text.
        """
        output = []
        parsing = False
        with open(text_file) as f:
            for line in f:
                if start_keyword in line:
                    parsing = True
                if parsing:
                    output.append(line)
                if end_keyword in line:
                    parsing = False
        return output

    def _setup_xtb(self):
        """Import output files from xTB calculations."""
        import json

        # Read the PTB json file
        with open(self.xyz_file.parent / "ptb.xtbout.json") as file:
            self.ptb_json = json.load(file)

    def _get_fukui_values(self, atomic_index: int) -> tuple[float, float, float]:
        """Extracts Fukui function values from the vfukui output file.

        Args:
            atomic_index (int): Index of the atom for which to extract the Fukui function values.
                This is 0-indexed.

        Raises:
            ValueError: If the Fukui function block is not found in the vfukui.out file.

        Returns:
            tuple(float, float, float): Tuple of f(+), f(-), and f(0) values.
        """
        FUKUI_START = "#        f(+)     f(-)     f(0)"
        FUKUI_END = "Property Printout"

        lines = self._in_between(self.xyz_file.parent / "vfukui.out", FUKUI_START, FUKUI_END)

        if not lines:
            raise ValueError("Fukui function block not found in vfukui.out")

        cleaned = []
        for line in lines[1:-2]:
            cleaned.append(line.strip().split())

        f_plus = round(float(cleaned[atomic_index][1]), 5)
        f_minus = round(float(cleaned[atomic_index][2]), 5)
        f_zero = round(float(cleaned[atomic_index][3]), 5)

        return f_plus, f_minus, f_zero

    def _identify_ipso_carbon(self):
        """Identify the ipso carbon of buchwald-type phosphine ligands."""
        phosphine = self.ligand_mol.get_atom(molli.Element.P)
        groups = self.ligand_mol.connected_atoms(phosphine)

        ipso_carbon = []
        for atoms in groups:
            beta_groups = []
            for atom in self.ligand_mol.connected_atoms(atoms):
                if atom.atype == molli.AtomType.Aromatic:
                    connected_atoms = [x for x in self.ligand_mol.connected_atoms(atom)]
                    if any(atom.element != molli.Element.C for atom in connected_atoms):
                        continue
                    else:
                        beta_groups.append(atom)

            for atom in beta_groups:
                for atom_2 in self.ligand_mol.connected_atoms(atom):
                    connected_atoms = [x for x in self.ligand_mol.connected_atoms(atom_2)]
                    if any(atom_3.element != molli.Element.C for atom_3 in connected_atoms):
                        continue
                    elif any(atom_3.atype != molli.AtomType.Aromatic for atom_3 in connected_atoms):
                        continue
                    else:
                        ipso_carbon.append(atom_2)

        if len(ipso_carbon) == 0:
            self.ipso_carbon = None
            self.ipso_carbon_index = None
        if len(ipso_carbon) > 1:
            ipso_carbon_0_dist = self.ligand_mol.distance(phosphine, ipso_carbon[0])
            ipso_carbon_1_dist = self.ligand_mol.distance(phosphine, ipso_carbon[1])

            if ipso_carbon_0_dist < ipso_carbon_1_dist:
                self.ipso_carbon = ipso_carbon[0]
                self.ipso_carbon_index = self.ligand_mol.get_atom_index(ipso_carbon[0])
            else:
                self.ipso_carbon = ipso_carbon[1]
                self.ipso_carbon_index = self.ligand_mol.get_atom_index(ipso_carbon[1])

        if len(ipso_carbon) == 1:
            self.ipso_carbon = ipso_carbon[0]
            self.ipso_carbon_index = self.ligand_mol.get_atom_index(ipso_carbon[0])

    ##############################
    # Steric Descriptors
    ##############################

    def buried_volume(self):
        """Calculate the buried volume of the ligand.

        Updates:
            steric_descriptors (dict): Adds 'buried_volume_{radius}A' from 2 to 5 Å in 0.5 Å increments.
        """
        from morfeus import BuriedVolume

        radii = [2.0, 3.5, 5.0]

        for radius in radii:
            bv = BuriedVolume(
                self.Pd_ligand_elements,
                self.Pd_ligand_coords,
                self.ligand_metal_index + 1,
                radius=radius,
            )

            self.steric_descriptors[f"buried_volume_{radius}A"] = bv.buried_volume

    def buried_volume_ipso(self):
        """Calculate the buried volume of the ipso carbon.

        Updates:
            steric_descriptors (dict): Adds 'buried_volume_ipso_{radius}A' from 2 to 5 Å in 0.5 Å increments.
        """
        from morfeus import BuriedVolume

        radii = [2.5, 3.5, 4.5]

        for radius in radii:
            if self.ipso_carbon:
                bv = BuriedVolume(
                    self.ligand_elements,
                    self.ligand_coords,
                    self.ipso_carbon_index + 1,
                    radius=radius,
                )

                self.steric_descriptors[f"buried_volume_ipso_{radius}A"] = bv.buried_volume
            else:
                self.steric_descriptors[f"buried_volume_ipso_{radius}A"] = np.NaN

    def quadrant_buried_volume_ligand(self):
        """Calculate the quadrant buried volume of the ligand.

        This follows the same procedure as done for the Kraken database.
        """
        from morfeus import BuriedVolume

        quadrant_buried_volumes = []
        max_buried_volume_deltas = []

        quadrant_total_volumes = []
        max_total_volume_deltas = []

        for index in self.ligand_phosphine_groups:
            qbv = BuriedVolume(
                self.ligand_elements,
                self.ligand_coords,
                self.ligand_phosphine_index + 1,
                xz_plane_atoms=[self.ligand_phosphine_index + 1, index + 1],
            )

            qbv.octant_analysis()

            buried_volumes = qbv.quadrants["buried_volume"].values()
            quadrant_buried_volumes += [*buried_volumes]
            max_buried_volume_deltas.append(max(buried_volumes) - min(buried_volumes))

            qtv = BuriedVolume(
                self.ligand_elements,
                self.ligand_coords,
                self.ligand_phosphine_index + 1,
                xz_plane_atoms=[self.ligand_phosphine_index + 1, index + 1],
                radius=10,
            )

            qtv.octant_analysis()

            total_volumes = qtv.quadrants["buried_volume"].values()
            quadrant_total_volumes += [*total_volumes]
            max_total_volume_deltas.append(max(total_volumes) - min(total_volumes))

        self.steric_descriptors["quadrant_buried_volume_ligand_max"] = max(quadrant_buried_volumes)
        self.steric_descriptors["quadrant_buried_volume_ligand_min"] = min(quadrant_buried_volumes)
        self.steric_descriptors["quadrant_buried_volume_ligand_range_max"] = max(max_buried_volume_deltas)

        self.steric_descriptors["quadrant_total_volume_ligand_max"] = max(quadrant_total_volumes)
        self.steric_descriptors["quadrant_total_volume_ligand_min"] = min(quadrant_total_volumes)
        self.steric_descriptors["quadrant_total_volume_ligand_range_max"] = max(max_total_volume_deltas)

    def solid_angle(self):
        """Calculate the solid angle of the ligand.

        Updates:
            steric_descriptors (dict): Adds 'solid_angle' and 'cone_angle'.
        """
        from morfeus import SolidAngle

        solid_angle = SolidAngle(
            self.Pd_ligand_elements,
            self.Pd_ligand_coords,
            self.ligand_metal_index + 1,
        )

        self.steric_descriptors["solid_angle"] = solid_angle.solid_angle
        self.steric_descriptors["cone_angle"] = solid_angle.cone_angle

    def dispersion(self):
        """Calculate the dispersion (P_int) of the ligand.

        Updates:
            steric_descriptors (dict): Adds 'P_int_ligand'.
        """
        from morfeus import Dispersion

        disp = Dispersion(self.ligand_elements, self.ligand_coords)

        self.steric_descriptors["P_int_ligand"] = disp.p_int
        self.steric_descriptors["P_int_phosphine"] = disp.atom_p_int[self.ligand_phosphine_index + 1]

        if self.ipso_carbon:
            self.steric_descriptors["P_int_ipso_carbon"] = disp.atom_p_int[self.ipso_carbon_index + 1]
        else:
            self.steric_descriptors["P_int_ipso_carbon"] = np.NaN

    def pyramidalization(self):
        """Calculate the pyramidalization of the ligand.

        Updates:
            steric_descriptors (dict): Adds 'pyramidalization_P' and 'pyramidalization_alpha'.
        """
        from morfeus import Pyramidalization

        pyramidalization = Pyramidalization(
            self.ligand_coords,
            self.ligand_phosphine_index + 1,
            elements=self.ligand_elements,
            method="connectivity",
        )

        self.steric_descriptors["pyramidalization_P"] = pyramidalization.P
        self.steric_descriptors["pyramidalization_alpha"] = pyramidalization.alpha

    def solvent_accessible_surface_area(self):
        """Calculate the solvent accessible surface area of the catalyst.

        Updates:
            steric_descriptors (dict): Adds 'sasa_ligand_area' and 'sasa_ligand_volume'.
        """
        from morfeus import SASA

        sasa = SASA(self.ligand_elements, self.ligand_coords)

        self.steric_descriptors["sasa_ligand_area"] = sasa.area
        self.steric_descriptors["sasa_ligand_volume"] = sasa.volume
        self.steric_descriptors["sasa_phosphine_area"] = sasa.atom_areas[self.ligand_phosphine_index + 1]

    def sterimol(self):
        """Calculate Sterimol and buried Sterimol values for the phosphine ligand.

        Updates:
            steric_descriptors (dict): Adds 'sterimol_L', 'sterimol_B1_value',
                                      'sterimol_B5_value', 'buried_sterimol_L',
                                      'buried_sterimol_B1_value', and 'buried_sterimol_B5_value'.
        """
        from morfeus import Sterimol

        sterimol = Sterimol(
            self.morfeus_elements,
            self.morfeus_coords,
            self.metal_index + 1,
            self.phosphine_index + 1,
        )

        self.steric_descriptors["sterimol_ligand_L"] = sterimol.L_value
        self.steric_descriptors["sterimol_ligand_B1"] = sterimol.B_1_value
        self.steric_descriptors["sterimol_ligand_B5"] = sterimol.B_5_value

        sterimol.bury(method="delete")

        self.steric_descriptors["buried_sterimol_ligand_L"] = sterimol.L_value
        self.steric_descriptors["buried_sterimol_ligand_B1"] = sterimol.B_1_value
        self.steric_descriptors["buried_sterimol_ligand_B5"] = sterimol.B_5_value

    def distances(self):
        """Calculate various atomic distances.

        Updates:
            steric_descriptors (dict): Adds various distance attributes.
        """
        self.steric_descriptors["distance_Pd_P"] = self.catalyst.distance(self.metal, self.phosphine)
        self.steric_descriptors["distance_Pd_N"] = self.catalyst.distance(self.metal, self.amine_nitrogen)
        self.steric_descriptors["distance_Pd_C"] = self.catalyst.distance(self.metal, self.aryl_carbon)
        self.steric_descriptors["distance_Pd_O"] = self.catalyst.distance(self.metal, self.carboxylic_oxygen)

        # Distances for the amine ligand
        self.steric_descriptors["distance_avg_amine_N_H"] = round(
            sum([self.catalyst.distance(self.amine_nitrogen, x) for x in self.amine_protons]) / 2, 5
        )
        self.steric_descriptors["distance_amine_N_C"] = self.catalyst.distance(self.amine_nitrogen, self.amine_carbon)

        # Distances for carboxylate
        self.steric_descriptors["distance_carb_C_OH"] = self.catalyst.distance(
            self.carbonyl_carbon, self.carboxylic_oxygen
        )
        self.steric_descriptors["distance_carb_C_O"] = self.catalyst.distance(
            self.carbonyl_carbon, self.carbonyl_oxygen
        )
        self.steric_descriptors["distance_carb_OH_O"] = self.catalyst.distance(
            self.carboxylic_oxygen, self.carbonyl_oxygen
        )

    ##############################
    # Electronic Descriptors
    ##############################

    def homo_lumo(self):
        """Get the HOMO-LUMO gap from xTB-PTB.

        Updates:
            electronic_descriptors (dict): Adds 'homo_lumo'.
        """
        self.electronic_descriptors["homo_lumo"] = self.ptb_json["HOMO-LUMO gap / eV"]

    def charges(self):
        """Get the partial charges of the ligands and selected substituents.

        Updates:
            electronic_descriptors (dict): Adds various partial charge attributes.
        """
        charges = self.ptb_json["partial charges"]

        self.electronic_descriptors["partial_charge_metal"] = charges[self.metal_index]
        self.electronic_descriptors["partial_charge_phosphine"] = charges[self.phosphine_index]
        self.electronic_descriptors["partial_charge_amine_nitrogen"] = charges[self.amine_nitrogen_index]
        self.electronic_descriptors["partial_charge_carbon"] = charges[self.aryl_carbon_index]
        self.electronic_descriptors["partial_charge_carboxylic_oxygen"] = charges[self.carboxylic_oxygen_index]
        self.electronic_descriptors["partial_charge_carbonyl_oxygen"] = charges[self.carbonyl_oxygen_index]
        self.electronic_descriptors["partial_charge_avg_amine_proton"] = round(
            sum([charges[x] for x in self.amine_proton_indices]) / 2, 5
        )
        self.electronic_descriptors["partial_charge_amine_carbon"] = charges[self.amine_carbon_index]
        self.electronic_descriptors["partial_charge_carbonyl_carbon"] = charges[self.carbonyl_carbon_index]

        if self.ipso_carbon:
            self.electronic_descriptors["partial_charge_ipso_carbon"] = charges[self.ipso_carbon_index + 76]
        else:
            self.electronic_descriptors["partial_charge_ipso_carbon"] = np.NaN

    def fukui(self):
        """Get the Fukui function values of the ligands and selected substituents.

        Updates:
            electronic_descriptors (dict): Adds various Fukui function attributes.
        """
        fukui_values = self._get_fukui_values(self.metal_index)
        self.electronic_descriptors["fukui_f_plus_metal"] = fukui_values[0]
        self.electronic_descriptors["fukui_f_minus_metal"] = fukui_values[1]
        self.electronic_descriptors["fukui_f_zero_metal"] = fukui_values[2]

        fukui_values = self._get_fukui_values(self.amine_nitrogen_index)
        self.electronic_descriptors["fukui_f_plus_amine_nitrogen"] = fukui_values[0]
        self.electronic_descriptors["fukui_f_minus_amine_nitrogen"] = fukui_values[1]
        self.electronic_descriptors["fukui_f_zero_amine_nitrogen"] = fukui_values[2]

        fukui_values = self._get_fukui_values(self.aryl_carbon_index)
        self.electronic_descriptors["fukui_f_plus_aryl_carbon"] = fukui_values[0]
        self.electronic_descriptors["fukui_f_minus_aryl_carbon"] = fukui_values[1]
        self.electronic_descriptors["fukui_f_zero_aryl_carbon"] = fukui_values[2]

        fukui_values = self._get_fukui_values(self.carboxylic_oxygen_index)
        self.electronic_descriptors["fukui_f_plus_carb_oh"] = fukui_values[0]
        self.electronic_descriptors["fukui_f_minus_carb_oh"] = fukui_values[1]
        self.electronic_descriptors["fukui_f_zero_carb_oh"] = fukui_values[2]

        fukui_values = self._get_fukui_values(self.carbonyl_oxygen_index)
        self.electronic_descriptors["fukui_f_plus_carb_o"] = fukui_values[0]
        self.electronic_descriptors["fukui_f_minus_carb_o"] = fukui_values[1]
        self.electronic_descriptors["fukui_f_zero_carb_o"] = fukui_values[2]

        fukui_values_0 = self._get_fukui_values(self.amine_proton_indices[0])
        fukui_values_1 = self._get_fukui_values(self.amine_proton_indices[1])
        self.electronic_descriptors["fukui_f_plus_avg_amine_proton"] = round(
            ((fukui_values_0[0] + fukui_values_1[0]) / 2), 5
        )
        self.electronic_descriptors["fukui_f_minus_avg_amine_proton"] = round(
            ((fukui_values_0[1] + fukui_values_1[1]) / 2), 5
        )
        self.electronic_descriptors["fukui_f_zero_avg_amine_proton"] = round(
            ((fukui_values_0[2] + fukui_values_1[2]) / 2), 5
        )

        if self.ipso_carbon:
            fukui_values = self._get_fukui_values(self.ipso_carbon_index + 76)
            self.electronic_descriptors["fukui_f_plus_ipso_carbon"] = fukui_values[0]
            self.electronic_descriptors["fukui_f_minus_ipso_carbon"] = fukui_values[1]
            self.electronic_descriptors["fukui_f_zero_ipso_carbon"] = fukui_values[2]
        else:
            self.electronic_descriptors["fukui_f_plus_ipso_carbon"] = np.NaN
            self.electronic_descriptors["fukui_f_minus_ipso_carbon"] = np.NaN
            self.electronic_descriptors["fukui_f_zero_ipso_carbon"] = np.NaN

    def global_electrophilicity(self):
        """Calculate the global electrophilicity index.

        Updates:
            electronic_descriptors (dict): Adds 'global_electrophilicity'.
        """
        VOMEGA_START = "Calculation of global electrophilicity index (IP+EA)²/(8·(IP-EA))"
        VOMEGA_END = "          |                Property Printout                |"

        lines = self._in_between(self.xyz_file.parent / "vomega.out", VOMEGA_START, VOMEGA_END)

        self.electronic_descriptors["global_electrophilicity"] = float(lines[1].split()[-1])

    def global_nucleophilicity(self):
        """Calculate the global nucleophilicity index.

        Updates:
            electronic_descriptors (dict): Adds 'global_nucleophilicity'.
        """
        VIPEA_START = "empirical IP shift (eV)"
        VIPEA_END = "delta SCC IP (eV):"

        lines = self._in_between(self.xyz_file.parent / "vipea.out", VIPEA_START, VIPEA_END)

        self.electronic_descriptors["global_nucleophilicity"] = -float(lines[1].split()[-1])

    def tolman_electronic_parameter(self):
        """Calculate the Tolman electronic parameter.

        Updates:
            electronic_descriptors (dict): Adds 'tolman_electronic_parameter'.
        """
        import json

        hess_json = self.xyz_file.parents[2] / "tolman" / "hess.xtbout.json"
        assert hess_json.exists(), "Hessian file not found."

        with open(hess_json) as file:
            freqs = json.load(file)["vibrational frequencies / rcm"]

        closest_freq = min(freqs, key=lambda x: abs(x - 2143))

        self.electronic_descriptors["tolman_electronic_parameter"] = closest_freq


In [14]:
import pandas as pd


class CatalystConformers:
    """For a given catalyst, find all conformers and calculate Boltzmann weighted descriptors.

    Attributes:
        name (str): Name of the catalyst.
        split_dir (Path): Path to the directory containing the conformer splits.
        crest_outfile (Path): Path to the CREST output file.
        conformers (list): List of Catalyst objects representing each conformer.
        boltzmann_weights (list): List of Boltzmann weights for each conformer.
        boltzmann_weighted_descriptors (dict): Dictionary of Boltzmann-weighted descriptors.
    """

    def __init__(self, name: str, split_dir: Path, crest_outfile: Path, calculate_descriptors: bool = True):
        """Initialize the CatalystConformers object.

        Args:
            name (str): Name of the catalyst.
            split_dir (Path): Path to the directory containing the conformer splits.
            crest_outfile (Path): Path to the CREST output file.
            calculate_descriptors (bool): Whether to calculate descriptors. Defaults to True.
        """
        self.name = name
        self.split_dir = split_dir
        self.crest_outfile = crest_outfile
        self.calculate_descriptors = calculate_descriptors

        self._find_all_conformers()
        self._parse_crest_outfile()

        assert len(self.conformers) == len(self.boltzmann_weights), "Number of conformers and weights do not match."

        self._calculate_boltz_descriptors()
        self._calculate_co_of_var()
        self._calculate_min_max_descriptors()

    def __str__(self):
        """Returns the catalyst name."""
        return f"CatalystConformers: {self.name}"

    def __repr__(self):
        """Returns the catalyst name."""
        return self.__str__()

    def _find_all_conformers(self):
        """Find all conformers in the split directory."""
        conformers_xyzs = {
            int("".join(filter(str.isdigit, folder.parent.name))): folder
            for folder in self.split_dir.rglob("opt.xtbopt.xyz")
        }
        conformers_xyzs = {k: v for k, v in sorted(conformers_xyzs.items())}
        self.conformers = [
            Catalyst(f"{self.name}_conformer-{i:03}", xyz, calculate_descriptors=self.calculate_descriptors)
            for i, xyz in conformers_xyzs.items()
        ]

    def _parse_crest_outfile(self):
        """Parse the CREST output file to extract the energies and rankings of the conformers."""
        START = "       Erel/kcal        Etot weight/tot  conformer     set   degen     origin"
        END = "T /K"

        if len(self.conformers) == 1:
            self.boltzmann_weights = [1.0]
            return

        lines = self._in_between(self.crest_outfile, START, END)
        cleaned = []
        for line in lines[1:]:
            try:
                cleaned.append(line.strip().split()[4])
            except IndexError:
                pass
        self.boltzmann_weights = [float(x) for x in cleaned]

    @staticmethod
    def _in_between(text_file: str | Path, start_keyword: str, end_keyword: str) -> list:
        """Returns text between two strings in a file.

        Takes a list of text and returns the list containing entries between start_keyword
        and end_keyword (inclusive).

        Args:
            text_file (str | Path): Path to the text file to parse.
            start_keyword (str): String to start text saving.
            end_keyword (str): String to end text saving.

        Returns:
            list: List of each line from the parsed text.
        """
        output = []
        parsing = False
        with open(text_file) as f:
            for line in f:
                if start_keyword in line:
                    parsing = True
                if parsing:
                    output.append(line)
                if end_keyword in line:
                    parsing = False
        return output

    def _calculate_boltz_descriptors(self):
        self.descriptors = pd.DataFrame([x.descriptors for x in self.conformers])

        df = self.descriptors.drop(columns=["ligand_1_name"])
        weighted_df = df.mul(self.boltzmann_weights, axis=0).sum() / sum(self.boltzmann_weights)
        weighted_df = weighted_df.to_frame().T
        weighted_df["num_conformers"] = len(self.conformers)

        self.boltzmann_weighted_descriptors = {"ligand_1_name": self.name} | weighted_df.to_dict(orient="records")[0]

    def _calculate_co_of_var(self):
        from scipy.stats import variation

        self.variations = self.descriptors.drop(columns=["ligand_1_name"]).apply(variation, axis=0)

    def _calculate_min_max_descriptors(self):
        self.min_descriptors = self.descriptors.drop(columns=["ligand_1_name"]).min(axis=0).to_frame().T
        self.max_descriptors = self.descriptors.drop(columns=["ligand_1_name"]).max(axis=0).to_frame().T

        # Add min or max to the title of each descriptor
        self.min_descriptors.columns = [f"min_{x}" for x in self.min_descriptors.columns]
        self.max_descriptors.columns = [f"max_{x}" for x in self.max_descriptors.columns]

        self.min_max_descriptors = pd.concat([self.min_descriptors, self.max_descriptors], axis=1)
        self.full_descriptors = pd.concat(
            [
                pd.DataFrame(self.boltzmann_weighted_descriptors, index=[0]),
                self.min_max_descriptors,
            ],
            axis=1,
        )

In [15]:
class CatalystLibrary:
    """Class to represent a library of catalysts and their conformers.

    Catalysts are discovered as ``<root_dir>/<name>/<split_dir_name>`` directories and
    processed in parallel into ``CatalystConformers`` objects stored in
    ``self.catalysts``. A catalyst whose processing raised an exception is kept in
    ``self.catalysts`` as ``None`` (the error is printed). The descriptor properties
    and ``get_catalyst`` skip those entries.
    """

    def __init__(
        self,
        root_dir: Path,
        split_dir_name: str = "featurize_ensemble-long",
        crest_dir_name: str = "crest-long",
        n_workers: int = 22,
        calculate_descriptors: bool = True,
    ):
        """Initialize the CatalystLibrary object.

        Args:
            root_dir (Path): Path to the directory containing all catalyst directories.
            split_dir_name (str): Name of the directory containing the conformer splits.
            crest_dir_name (str): Name of the directory containing the CREST output files.
            n_workers (int): Number of workers to use for parallel processing. Defaults to 22.
            calculate_descriptors (bool): Whether to calculate descriptors for each conformer. Defaults to True.
        """
        self.root_dir = root_dir
        self.split_dir_name = split_dir_name
        self.crest_dir_name = crest_dir_name
        self.n_workers = n_workers
        self.calculate_descriptors = calculate_descriptors

        self._find_all_catalysts()

    def _process_catalyst(self, name, split_dir, crest_outfile):
        """Process a single catalyst.

        Returns:
            CatalystConformers | None: The processed catalyst, or ``None`` if construction
            raised. The error is printed so that one bad catalyst does not abort the whole
            parallel run.
        """
        try:
            return CatalystConformers(name, split_dir, crest_outfile, calculate_descriptors=self.calculate_descriptors)
        except Exception as e:
            print(f"Error for {name}: {e}")
            return None

    def _find_all_catalysts(self):
        """Find all catalysts in the root directory and process them in parallel.

        Sets ``self.catalysts`` to one entry per discovered split directory, in sorted
        path order. The entry is ``None`` for catalysts that failed to process.
        """
        from tqdm.contrib.concurrent import process_map

        split_dirs = sorted(self.root_dir.glob(f"*/{self.split_dir_name}"))
        # Both lists are derived 1:1 from split_dirs, so they are always aligned with it.
        crest_outfiles = [x.parent / f"{self.crest_dir_name}/screen/crest.out" for x in split_dirs]
        names = [x.parent.name for x in split_dirs]

        self.catalysts = process_map(
            self._process_catalyst,
            names,
            split_dirs,
            crest_outfiles,
            max_workers=self.n_workers,
        )

    def _loaded_catalysts(self):
        """Return the successfully processed catalysts (skipping ``None`` failures)."""
        return [x for x in self.catalysts if x is not None]

    @property
    def boltzmann_weighted_descriptors(self):
        """Get the Boltzmann-weighted descriptors for all successfully processed catalysts."""
        return pd.DataFrame([x.boltzmann_weighted_descriptors for x in self._loaded_catalysts()])

    @property
    def full_descriptors(self):
        """Get the full descriptors for all successfully processed catalysts."""
        return pd.concat([x.full_descriptors for x in self._loaded_catalysts()])

    @property
    def descriptors(self):
        """Get the per-conformer descriptors for all successfully processed catalysts."""
        return pd.concat([x.descriptors for x in self._loaded_catalysts()])

    def __len__(self):
        """Return the number of entries in the library, including failed (``None``) ones."""
        return len(self.catalysts)

    def get_catalyst(self, name: str):
        """Get a single catalyst from the library.

        Args:
            name (str): Name of the catalyst to look up.

        Returns:
            CatalystConformers: The first successfully processed catalyst with that name.

        Raises:
            ValueError: If no successfully processed catalyst has that name.
        """
        for catalyst in self._loaded_catalysts():
            if catalyst.name == name:
                return catalyst
        raise ValueError(f"Catalyst {name} not found in library.")

In [None]:
import pickle

root_dir = Path("data/dft/qsar/benzylamine/struc_gen")
library = CatalystLibrary(root_dir, n_workers=22, calculate_descriptors=True)

# Serialize the whole library object (including every processed catalyst) with pickle.
library_pickle_path = Path("data/dft/qsar/benzylamine/catalyst_library.pkl")
with library_pickle_path.open("wb") as handle:
    pickle.dump(library, handle)

  0%|          | 0/125 [00:00<?, ?it/s]

In [None]:
descriptor_df = library.full_descriptors
data = pd.read_csv("data/ligand-qsar/alkylamine-hte-ligand-data.tsv", sep="\t")

# Inner join (the merge default): only ligands present in both tables end up in the output.
merged = descriptor_df.merge(
    data[["ligand_1_name", "product_1_yield", "buchwald-type"]],
    on="ligand_1_name",
)

merged.to_csv("data/ligand-qsar/raw/alkylamine-ligand-modeling-unprocessed.tsv", sep="\t", index=False)