In [24]:
"""
MMP-Based Data Augmentation for Drug Discovery
A comprehensive class for expanding molecular datasets using Matched Molecular Pairs
"""

import os
import pandas as pd
import numpy as np
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem, inchi
import subprocess
import tempfile
import base64
import warnings
from tqdm import tqdm
from typing import List, Dict, Optional, Union
from collections import defaultdict
from itertools import combinations
from scipy.stats import pearsonr
import time


# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.*')
warnings.filterwarnings("ignore", category=UserWarning, module="rdkit")
warnings.filterwarnings("ignore", category=FutureWarning, module="rdkit")
warnings.filterwarnings("ignore", category=DeprecationWarning, module="rdkit")


class MMPDataAugmentor:
    """
    Complete pipeline for MMP-based data augmentation.
    
    This class handles:
    1. Molecule fragmentation
    2. MMP identification
    3. Data augmentation through matched pairs
    4. Chemical structure generation
    5. Quality control and filtering
    
    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe containing molecules and properties
    smiles_col : str
        Name of the column containing SMILES strings
    target_cols : Union[str, List[str]]
        Name(s) of target column(s) to augment
    mmpa_dir : str, default='./mmpa'
        Path to MMPA scripts directory
    symmetric : bool, default=True
        Generate symmetric MMPs (A->B and B->A)
    max_heavy : int, default=15
        Maximum heavy atom change allowed
    max_ratio : float, default=0.35
        Maximum ratio of change relative to molecule size
    min_common : int, default=4
        Minimum common MMPs required between scaffold pairs
    pearson_thresh : float, default=0.3
        Minimum Pearson correlation for scaffold pairing
    crmsd_thresh : float, default=0.8
        Maximum cRMSD for scaffold pairing
    std_threshold : float, default=0.8
        Maximum standard deviation for keeping augmented data
    verbose : bool, default=True
        Print progress messages
    """
    
    def __init__(
        self,
        df: pd.DataFrame,
        smiles_col: str = "SMILES",
        target_cols: Union[str, List[str]] = "Y",
        mmpa_dir: str = '../mmpa',
        symmetric: bool = True,
        max_heavy: int = 15,
        max_ratio: float = 0.35,
        min_common: int = 4,
        pearson_thresh: float = 0.3,
        crmsd_thresh: float = 0.8,
        std_threshold: float = 0.8,
        verbose: bool = True
    ):
        self.df_original = df.copy()
        self.smiles_col = smiles_col
        self.target_cols = [target_cols] if isinstance(target_cols, str) else target_cols
        self.mmpa_dir = mmpa_dir
        self.symmetric = symmetric
        self.max_heavy = max_heavy
        self.max_ratio = max_ratio
        self.min_common = min_common
        self.pearson_thresh = pearson_thresh
        self.crmsd_thresh = crmsd_thresh
        self.std_threshold = std_threshold
        self.verbose = verbose
        
        # Initialize storage for results
        self.augmented_df = None
        self.statistics = {}
        
        # Validate inputs
        self._validate_inputs()
    
    def _validate_inputs(self):
        """Validate input parameters and data"""
        if self.smiles_col not in self.df_original.columns:
            raise ValueError(f"SMILES column '{self.smiles_col}' not found in dataframe")
        
        for col in self.target_cols:
            if col not in self.df_original.columns:
                raise ValueError(f"Target column '{col}' not found in dataframe")
        
        if not os.path.exists(self.mmpa_dir):
            raise ValueError(f"MMPA directory '{self.mmpa_dir}' not found")
    
    def _log(self, message: str):
        """Print message if verbose mode is enabled"""
        if self.verbose:
            print(f"[MMPAugmentor] {message}")
    
    @staticmethod
    def _encode_string(s: str) -> str:
        """Encode string to base64"""
        return base64.urlsafe_b64encode(s.encode()).decode()
    
    @staticmethod
    def _smiles_to_inchikey(smiles: str) -> Optional[str]:
        """Convert SMILES to InChIKey"""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        try:
            return inchi.MolToInchiKey(mol)
        except:
            return None
    
    def _fragment_molecules(self, output_csv: str):
        """
        Fragment molecules using MMPA rfrag.py
        
        Parameters
        ----------
        output_csv : str
            Path to output CSV file
        """
        self._log("0) Generating molecular fragments...")
        
        # Prepare data for fragmentation
        df_prep = self.df_original[[self.smiles_col] + self.target_cols].copy()
        df_prep['ID'] = [self._smiles_to_inchikey(smi) for smi in df_prep[self.smiles_col]]
        
        # Create temporary directory
        with tempfile.TemporaryDirectory() as tmp:
            smi_path = os.path.join(tmp, 'input.smi')
            frag_path = os.path.join(tmp, 'fragmented.txt')
            mmps_path = os.path.join(tmp, 'mmps.csv')
            smirks_path = os.path.join(tmp, 'smirks.txt')
            cansmirks_path = os.path.join(tmp, 'cansmirks.txt')
            
            # Write SMILES file
            df_prep[[self.smiles_col, 'ID']].to_csv(
                smi_path, index=False, sep=' ', header=False
            )
            
            # Fragment molecules
            with open(frag_path, 'w') as out:
                subprocess.run(
                    ['python', f'{self.mmpa_dir}/rfrag.py'],
                    stdin=open(smi_path),
                    stdout=out
                )
            
            self._log("1) Indexing fragments and generating MMPs...")
            
            # Build MMPs
            cmd = ['python', f'{self.mmpa_dir}/indexing.py']
            if self.symmetric:
                cmd.append('-s')
            if self.max_heavy:
                cmd.extend(['-m', str(self.max_heavy)])
            if self.max_ratio:
                cmd.extend(['-r', str(self.max_ratio)])
            
            with open(mmps_path, 'w') as out:
                subprocess.run(cmd, stdin=open(frag_path), stdout=out)
            
            # Read and process MMPs
            with open(mmps_path) as f:
                lines = [line.strip() for line in f if line.strip()]
            
            splits = [line.split(',') for line in lines]
            df_mmps = pd.DataFrame(
                splits,
                columns=['L_SMILES', 'R_SMILES', 'L_ID', 'R_ID', 'SMIRKS', 'CORE']
            )
            
            # Map target values
            for target_col in self.target_cols:
                y_map = df_prep.set_index('ID')[target_col].to_dict()
                df_mmps[f'L_{target_col}'] = df_mmps['L_ID'].map(y_map)
                df_mmps[f'R_{target_col}'] = df_mmps['R_ID'].map(y_map)
                df_mmps[f'Delta_{target_col}'] = df_mmps[f'R_{target_col}'] - df_mmps[f'L_{target_col}']
            
            # Filter valid SMIRKS
            df_mmps = df_mmps[df_mmps['SMIRKS'].apply(
                lambda x: isinstance(x, str) and '>>' in x
            )]
            
            self._log("2) Canonicalizing SMIRKS...")
            
            # Canonicalize SMIRKS
            df_mmps['__row'] = range(len(df_mmps))
            df_mmps[['SMIRKS', '__row']].to_csv(
                smirks_path, index=False, sep=' ', header=False
            )
            
            with open(cansmirks_path, 'w') as out:
                subprocess.run(
                    ['python', f'{self.mmpa_dir}/cansmirk.py'],
                    stdin=open(smirks_path),
                    stdout=out
                )
            
            canon_df = pd.read_csv(
                cansmirks_path, sep=' ', names=['Canonical_SMIRKS', 'index']
            )
            
            df_mmps = df_mmps.merge(
                canon_df, left_on='__row', right_on='index'
            ).drop(columns=['__row', 'index'])
            
            # Split canonical SMIRKS
            df_mmps[['L_sub', 'R_sub']] = df_mmps['Canonical_SMIRKS'].str.split(
                '>>', expand=True
            )
            
            # Add encoded IDs
            df_mmps['L_sub_ID'] = [self._encode_string(k) for k in df_mmps['L_sub']]
            df_mmps['R_sub_ID'] = [self._encode_string(k) for k in df_mmps['R_sub']]
            df_mmps['SMIRKS_ID'] = [self._encode_string(k) for k in df_mmps['Canonical_SMIRKS']]
            df_mmps['CORE_ID'] = [self._encode_string(k) for k in df_mmps['CORE']]
            
            df_mmps = df_mmps.drop_duplicates()
            
            # Save to output
            df_mmps.to_csv(output_csv, index=False)
        
        return df_mmps
    
    def _augment_data(self, df_mmps: pd.DataFrame) -> pd.DataFrame:
        """
        Augment data using matched molecular pairs
        
        Parameters
        ----------
        df_mmps : pd.DataFrame
            DataFrame containing MMPs
            
        Returns
        -------
        pd.DataFrame
            Augmented dataset
        """
        self._log("3) Computing pairwise scaffold correlations...")
        
        # Group by CORE
        series = {core: group for core, group in df_mmps.groupby("CORE")}
        
        # Compute pairwise scores
        pair_scores = []
        series_items = list(series.items())
        total_combinations = len(series_items) * (len(series_items) - 1) // 2
        
        for (core1, df1), (core2, df2) in tqdm(
            combinations(series_items, 2),
            total=total_combinations,
            desc="Computing correlations",
            disable=not self.verbose
        ):
            subs1 = set(df1["L_sub"])
            subs2 = set(df2["L_sub"])
            common = subs1 & subs2
            
            if len(common) < self.min_common:
                continue
            
            merged = pd.merge(
                df1, df2,
                left_on=["L_sub", "R_sub"],
                right_on=["L_sub", "R_sub"],
                suffixes=('_1', '_2')
            )
            
            if len(merged) < self.min_common:
                continue
            
            # Use first target column for filtering
            target_col = self.target_cols[0]
            y1 = merged[f'Delta_{target_col}_1'].values
            y2 = merged[f'Delta_{target_col}_2'].values
            crmsd = np.sqrt(np.mean((y1 - y2) ** 2))
            
            try:
                corr = pearsonr(y1, y2)[0]
            except:
                corr = np.nan
            
            pair_scores.append((core1, core2, crmsd, corr, len(merged)))
        
        # Filter pairs
        filtered_pairs = [
            (s1, s2) for s1, s2, rmsd, corr, n in pair_scores
            if rmsd <= self.crmsd_thresh and (not np.isnan(corr) and corr >= self.pearson_thresh)
        ]
        
        self._log(f"4) Found {len(filtered_pairs)} valid scaffold pairs. Generating augmented data...")
        
        # Augment data
        augmented_entries = []
        
        for s1, s2 in tqdm(filtered_pairs, desc="Augmenting", disable=not self.verbose):
            df1 = series[s1]
            df2 = series[s2]
            
            # Create dictionaries for fast lookup
            df1_dict = defaultdict(list)
            for _, row in df1.iterrows():
                df1_dict[row["L_sub"]].append(row.to_dict())
            
            df2_dict = defaultdict(list)
            for _, row in df2.iterrows():
                df2_dict[row["L_sub"]].append(row.to_dict())
            
            # Generate augmented entries
            for target_col in self.target_cols:
                tf1 = df1[["L_sub", "R_sub", f"Delta_{target_col}"]].to_dict("records")
                tf2 = df2[["L_sub", "R_sub", f"Delta_{target_col}"]].to_dict("records")
                
                for entry in tf1:
                    l_sub = entry["L_sub"]
                    for base in df2_dict.get(l_sub, []):
                        r_sub = entry["R_sub"]
                        delta = entry[f"Delta_{target_col}"]
                        new_y = base[f"L_{target_col}"] + delta
                        smirks_new = l_sub + ">>" + r_sub
                        
                        aug_entry = {
                            "CORE": s1,
                            "L_sub": l_sub,
                            "R_sub": r_sub,
                            f"L_{target_col}": base[f"L_{target_col}"],
                            f"R_{target_col}": new_y,
                            f"Delta_{target_col}": delta,
                            "AUG": True,
                            "L_SMILES": base.get("L_SMILES"),
                            "L_ID": base.get("L_ID"),
                            "L_sub_ID": base.get("L_sub_ID"),
                            "R_sub_ID": self._encode_string(r_sub),
                            "SMIRKS": smirks_new,
                            "SMIRKS_ID": self._encode_string(smirks_new),
                            "CORE_ID": self._encode_string(s1)
                        }
                        augmented_entries.append(aug_entry)
                
                for entry in tf2:
                    l_sub = entry["L_sub"]
                    for base in df1_dict.get(l_sub, []):
                        r_sub = entry["R_sub"]
                        delta = entry[f"Delta_{target_col}"]
                        new_y = base[f"L_{target_col}"] + delta
                        smirks_new = l_sub + ">>" + r_sub
                        
                        aug_entry = {
                            "CORE": s2,
                            "L_sub": l_sub,
                            "R_sub": r_sub,
                            f"L_{target_col}": base[f"L_{target_col}"],
                            f"R_{target_col}": new_y,
                            f"Delta_{target_col}": delta,
                            "AUG": True,
                            "L_SMILES": base.get("L_SMILES"),
                            "L_ID": base.get("L_ID"),
                            "L_sub_ID": base.get("L_sub_ID"),
                            "R_sub_ID": self._encode_string(r_sub),
                            "SMIRKS": smirks_new,
                            "SMIRKS_ID": self._encode_string(smirks_new),
                            "CORE_ID": self._encode_string(s2)
                        }
                        augmented_entries.append(aug_entry)
        
        augmented_df = pd.DataFrame(augmented_entries)
        
        # Mark original data
        df_mmps["AUG"] = False
        
        # Combine original and augmented
        combined = pd.concat([df_mmps, augmented_df], ignore_index=True)
        
        return combined
    
    def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Apply chemical transformations to generate R_SMILES
        
        Parameters
        ----------
        df : pd.DataFrame
            DataFrame with augmented data
            
        Returns
        -------
        pd.DataFrame
            DataFrame with applied transformations
        """
        self._log("5) Applying chemical transformations...")
        
        df = df.drop_duplicates()
        mask = df["AUG"] == True
        indices = df[mask].index
        
        rxn_cache = {}
        heavy_cache = {}
        core_cache = {}
        
        new_rows = []
        failure_tracker = {"total": 0, "empty_prodsets": 0}
        
        for idx in tqdm(indices, desc="Transforming", disable=not self.verbose):
            row = df.loc[idx]
            l_smiles = row["L_SMILES"]
            smirks = row["SMIRKS"]
            core = row["CORE"]
            
            r_smiles_list = self._fast_apply_transformation(
                smirks, l_smiles, rxn_cache, core_cache, heavy_cache, core,
                failure_tracker
            )
            
            if r_smiles_list:
                for r_smiles in r_smiles_list:
                    mol = Chem.MolFromSmiles(r_smiles)
                    if mol:
                        r_id = Chem.InchiToInchiKey(Chem.MolToInchi(mol))
                        new_row = row.copy()
                        new_row["R_SMILES"] = r_smiles
                        new_row["R_ID"] = r_id
                        new_rows.append(new_row)
        
        df_non_aug = df[~mask]
        df_aug_expanded = pd.DataFrame(new_rows)
        df_final = pd.concat([df_non_aug, df_aug_expanded], ignore_index=True)
        
        # Report failure rate
        total = failure_tracker["total"]
        failed = failure_tracker["empty_prodsets"]
        if total > 0:
            self._log(
                f"⚠️  Empty product sets in {failed}/{total} "
                f"({100*failed/total:.2f}%) transformations"
            )
        
        return df_final
    
    @staticmethod
    def _fast_apply_transformation(transformation, l_smiles, rxn_cache, 
                                   core_cache, heavy_cache, core_smarts,
                                   failure_tracker=None):
        """Apply SMIRKS transformation with caching"""
        if pd.isna(transformation) or pd.isna(l_smiles):
            return None
        
        # Cache reaction
        if transformation not in rxn_cache:
            try:
                rxn = AllChem.ReactionFromSmarts(transformation)
                left_smi, right_smi = transformation.split(">>")
                left_mol = Chem.MolFromSmarts(left_smi)
                right_mol = Chem.MolFromSmarts(right_smi)
                delta_heavy = right_mol.GetNumHeavyAtoms() - left_mol.GetNumHeavyAtoms()
                rxn_cache[transformation] = (rxn, delta_heavy)
            except:
                return None
        else:
            rxn, delta_heavy = rxn_cache[transformation]
        
        # Cache molecule
        if l_smiles not in heavy_cache:
            mol_l = Chem.MolFromSmiles(l_smiles)
            if mol_l is None:
                return None
            n_heavy_l = mol_l.GetNumHeavyAtoms()
            heavy_cache[l_smiles] = (mol_l, n_heavy_l)
        else:
            mol_l, n_heavy_l = heavy_cache[l_smiles]
        
        # Cache core
        if core_smarts not in core_cache:
            core_mol = Chem.MolFromSmarts(core_smarts)
            if core_mol is None:
                return None
            core_cache[core_smarts] = core_mol
        else:
            core_mol = core_cache[core_smarts]
        
        # Run reaction
        try:
            products = rxn.RunReactants((mol_l,))
        except:
            return None
        
        if failure_tracker is not None:
            failure_tracker["total"] += 1
            if not products:
                failure_tracker["empty_prodsets"] += 1
        
        all_products = []
        for prod_set in products:
            for prod in prod_set:
                if prod is None:
                    continue
                try:
                    if not prod.HasSubstructMatch(core_mol):
                        continue
                except:
                    continue
                n_heavy_r = prod.GetNumHeavyAtoms()
                if n_heavy_r - n_heavy_l != delta_heavy:
                    continue
                all_products.append(Chem.MolToSmiles(prod, isomericSmiles=True))
        
        return all_products if all_products else None
    
    def _prepare_output(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Prepare final output with proper formatting and statistics
        
        Parameters
        ----------
        df : pd.DataFrame
            Processed dataframe
            
        Returns
        -------
        pd.DataFrame
            Final output dataframe
        """
        self._log("6) Preparing output...")
        
        # Extract L and R entries
        output_rows = []
        
        for target_col in self.target_cols:
            l_df = df[["L_SMILES", f"L_{target_col}", "AUG"]].copy()
            l_df.columns = ["SMILES", target_col, "AUG"]
            
            r_df = df[["R_SMILES", f"R_{target_col}", "AUG"]].copy()
            r_df.columns = ["SMILES", target_col, "AUG"]
            
            combined = pd.concat([l_df, r_df])
            output_rows.append(combined)
        
        # Combine all targets
        clean_df = pd.concat(output_rows, axis=0)
        
        # Remove invalid SMILES
        total_before = len(clean_df)
        clean_df = clean_df.dropna(subset=["SMILES"])
        clean_df = clean_df[clean_df["SMILES"].apply(
            lambda x: Chem.MolFromSmiles(x) is not None
        )]
        total_after = len(clean_df)
        
        fail_pct = 100 * (total_before - total_after) / total_before
        self._log(
            f"Invalid SMILES removed: {total_before - total_after}/{total_before} "
            f"({fail_pct:.2f}%)"
        )
        
        # Standardize SMILES
        self._log("7) Standardizing SMILES...")
        tqdm.pandas(desc="Standardizing", disable=not self.verbose)
        clean_df["SMILES"] = clean_df["SMILES"].progress_apply(
            lambda x: Chem.MolToSmiles(Chem.MolFromSmiles(x), isomericSmiles=True)
        )
        
        # Generate InChIKeys
        self._log("8) Generating InChIKeys...")
        tqdm.pandas(desc="InChIKeys", disable=not self.verbose)
        clean_df['InChIKey'] = clean_df["SMILES"].progress_apply(self._smiles_to_inchikey)
        
        # Separate experimental and predicted
        clean_df_exp = clean_df[clean_df["AUG"] != True].copy()
        clean_df_pred = clean_df[clean_df["AUG"] == True].copy()
        
        # Aggregate predicted values
        output_dfs = []
        for target_col in self.target_cols:
            grouped_pred = clean_df_pred.groupby("InChIKey")[target_col].agg(
                ["median", "std", "mean", "count"]
            ).reset_index()
            grouped_pred.columns = [
                "InChIKey", 
                f"{target_col}_median",
                f"{target_col}_std", 
                f"{target_col}_mean",
                f"{target_col}_count"
            ]
            grouped_pred["AUG"] = True
            
            # Filter by std threshold
            grouped_pred = grouped_pred[
                grouped_pred[f"{target_col}_std"] <= self.std_threshold
            ]
            
            # Add SMILES
            inchikey_to_smiles = clean_df_pred.drop_duplicates("InChIKey").set_index(
                "InChIKey"
            )["SMILES"].to_dict()
            grouped_pred["SMILES"] = grouped_pred["InChIKey"].map(inchikey_to_smiles)
            
            # Experimental data
            exp_out = clean_df_exp[['InChIKey', "SMILES", target_col]].copy()
            exp_out[f"{target_col}_median"] = exp_out[target_col]
            exp_out[f"{target_col}_std"] = None
            exp_out[f"{target_col}_mean"] = exp_out[target_col]
            exp_out[f"{target_col}_count"] = 1
            exp_out["AUG"] = False
            exp_out = exp_out.drop(columns=[target_col])
            
            # Remove predicted entries that exist in experimental
            known_keys = set(exp_out['InChIKey'])
            grouped_pred = grouped_pred[~grouped_pred['InChIKey'].isin(known_keys)]
            
            # Combine
            combined = pd.concat([
                exp_out.drop_duplicates("InChIKey"),
                grouped_pred.drop_duplicates("InChIKey")
            ], ignore_index=True)
            
            output_dfs.append(combined)
        
        # Merge all targets
        final_df = output_dfs[0]
        for df_target in output_dfs[1:]:
            final_df = final_df.merge(df_target, on=["InChIKey", "SMILES", "AUG"], how="outer")
        
        # Calculate statistics
        self.statistics = {
            "n_original": len(clean_df_exp.drop_duplicates("InChIKey")),
            "n_augmented": len(grouped_pred.drop_duplicates("InChIKey")),
            "n_total": len(final_df),
            "augmentation_ratio": len(grouped_pred) / len(clean_df_exp) if len(clean_df_exp) > 0 else 0
        }
        
        self._log(f"✅ Augmentation complete!")
        self._log(f"   Original compounds: {self.statistics['n_original']}")
        self._log(f"   Augmented compounds: {self.statistics['n_augmented']}")
        self._log(f"   Total compounds: {self.statistics['n_total']}")
        self._log(f"   Augmentation ratio: {self.statistics['augmentation_ratio']:.2f}x")
        
        return final_df
    
    def run(self, output_path: Optional[str] = None) -> pd.DataFrame:
        """
        Run the complete augmentation pipeline
        
        Parameters
        ----------
        output_path : str, optional
            Path to save the output CSV/parquet file
            
        Returns
        -------
        pd.DataFrame
            Augmented dataset with statistics
        """
        start_time = time.time()
        
        # Create temporary output file
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.csv', delete=False
        ) as tmp_file:
            tmp_output = tmp_file.name
        
        try:
            # Step 1: Fragment and generate MMPs
            df_mmps = self._fragment_molecules(tmp_output)
            
            # Step 2: Augment data
            df_augmented = self._augment_data(df_mmps)
            
            # Step 3: Apply transformations
            df_transformed = self._apply_transformations(df_augmented)
            
            # Step 4: Prepare output
            self.augmented_df = self._prepare_output(df_transformed)
            
            # Save if path provided
            if output_path:
                if output_path.endswith('.parquet'):
                    self.augmented_df.to_parquet(output_path, index=False)
                else:
                    self.augmented_df.to_csv(output_path, index=False)
                self._log(f"Output saved to: {output_path}")
            
        finally:
            # Cleanup temporary file
            if os.path.exists(tmp_output):
                os.unlink(tmp_output)
        
        elapsed = time.time() - start_time
        self._log(f"⏱️  Total time: {elapsed:.2f} seconds")
        
        # Get InChIKeys from augmented_df
        augmented_inchikeys = set(self.augmented_df['InChIKey'].dropna())

        # Get InChIKeys from original
        self.df_original['InChIKey'] = self.df_original[self.smiles_col].apply(self._smiles_to_inchikey)
        original_inchikeys = set(self.df_original['InChIKey'].dropna())

        # Find missing compounds
        missing_inchikeys = original_inchikeys - augmented_inchikeys

        if len(missing_inchikeys) > 0:
            self._log(f"Adding {len(missing_inchikeys)} missing original compounds...")

            # Get missing compounds from original
            missing_df = self.df_original[self.df_original['InChIKey'].isin(missing_inchikeys)].copy()

            # Format to match augmented_df structure
            missing_formatted = pd.DataFrame()
            missing_formatted['InChIKey'] = missing_df['InChIKey']
            missing_formatted['SMILES'] = missing_df[self.smiles_col]
            missing_formatted['AUG'] = False

            # Add target columns
            for target_col in self.target_cols:
                missing_formatted[f'{target_col}_median'] = missing_df[target_col]
                missing_formatted[f'{target_col}_mean'] = missing_df[target_col]
                missing_formatted[f'{target_col}_std'] = None
                missing_formatted[f'{target_col}_count'] = 1

            # Concatenate
            self.augmented_df = pd.concat([self.augmented_df, missing_formatted], ignore_index=True)

            # Update statistics
            self.statistics['n_original'] = len(original_inchikeys)        
        
        return self.augmented_df
    
    def get_statistics(self) -> Dict:
        """Return augmentation statistics"""
        return self.statistics
    
    def get_experimental_only(self) -> pd.DataFrame:
        """Return only experimental compounds"""
        if self.augmented_df is None:
            raise ValueError("Run augmentation first using .run()")
        return self.augmented_df[self.augmented_df["AUG"] == False].copy()
    
    def get_augmented_only(self) -> pd.DataFrame:
        """Return only augmented compounds"""
        if self.augmented_df is None:
            raise ValueError("Run augmentation first using .run()")
        return self.augmented_df[self.augmented_df["AUG"] == True].copy()



# Data

In [25]:
df = pd.read_parquet("../data/exp_subset/LogBB.parquet")

In [26]:
print("Size of input dataset: ", len(df))

# Create augmentor with relaxed thresholds for the example
augmentor = MMPDataAugmentor(
    df=df,
    smiles_col="SMILES",
    target_cols=["LogBB"],
    max_heavy=15,          # Allow smaller changes
    max_ratio=0.4,        # More permissive ratio
    min_common=3,         # Lower threshold for common MMPs
    pearson_thresh=0.3,   # Lower correlation threshold
    crmsd_thresh=1.,     # Higher cRMSD tolerance
    std_threshold=1.0,    # Higher std tolerance
    verbose=True
)

# Run augmentation
augmented_df = augmentor.run(output_path="augmented_data.csv")


Size of input dataset:  540
[MMPAugmentor] 0) Generating molecular fragments...
[MMPAugmentor] 1) Indexing fragments and generating MMPs...
[MMPAugmentor] 2) Canonicalizing SMIRKS...




[MMPAugmentor] 3) Computing pairwise scaffold correlations...


  corr = pearsonr(y1, y2)[0]
  corr = pearsonr(y1, y2)[0]
Computing correlations: 100%|██████████| 1734453/1734453 [00:24<00:00, 72048.89it/s]


[MMPAugmentor] 4) Found 43 valid scaffold pairs. Generating augmented data...


Augmenting: 100%|██████████| 43/43 [00:03<00:00, 12.35it/s] 


[MMPAugmentor] 5) Applying chemical transformations...


Transforming: 100%|██████████| 4921/4921 [00:07<00:00, 675.50it/s] 


[MMPAugmentor] ⚠️  Empty product sets in 35/4921 (0.71%) transformations
[MMPAugmentor] 6) Preparing output...
[MMPAugmentor] Invalid SMILES removed: 0/121540 (0.00%)
[MMPAugmentor] 7) Standardizing SMILES...


Standardizing: 100%|██████████| 121540/121540 [00:22<00:00, 5370.18it/s]


[MMPAugmentor] 8) Generating InChIKeys...


InChIKeys: 100%|██████████| 121540/121540 [00:54<00:00, 2250.62it/s]
  combined = pd.concat([


[MMPAugmentor] ✅ Augmentation complete!
[MMPAugmentor]    Original compounds: 459
[MMPAugmentor]    Augmented compounds: 501
[MMPAugmentor]    Total compounds: 960
[MMPAugmentor]    Augmentation ratio: 0.00x
[MMPAugmentor] Output saved to: augmented_data.csv
[MMPAugmentor] ⏱️  Total time: 179.32 seconds
[MMPAugmentor] Adding 81 missing original compounds...


  self.augmented_df = pd.concat([self.augmented_df, missing_formatted], ignore_index=True)


In [33]:
augmented_df

Unnamed: 0,InChIKey,SMILES,LogBB_median,LogBB_std,LogBB_mean,LogBB_count,AUG
0,DOMXUEMWDBAQBQ-UHFFFAOYSA-N,CN(CC=CC#CC(C)(C)C)Cc1cccc2ccccc12,0.090000,,0.090000,1,False
1,AQHHHDLHHXJYJD-UHFFFAOYSA-N,CC(C)NCC(O)COc1cccc2ccccc12,0.640000,,0.640000,1,False
2,NJXRHIKVLFLBME-UHFFFAOYSA-N,O=C(NC1C2CC3CC1CC(O)(C3)C2)c1sc(OCCO)nc1C1CC1,-0.849485,,-0.849485,1,False
3,ACMDXKZEHRFMAQ-UHFFFAOYSA-N,COCCOc1nc(C2CC2)c(C(=O)NC2C3CC4CC2CC(O)(C4)C3)s1,-0.071334,,-0.071334,1,False
4,RPLJDCCDRTXZTG-UHFFFAOYSA-N,O=C(NC1C2CC3CC1CC(O)(C3)C2)c1sc(OC2CCOCC2)nc1C...,-0.071334,,-0.071334,1,False
...,...,...,...,...,...,...,...
1036,VOKSWYLNZZRQPF-UHFFFAOYSA-N,CC(C)=CCN1CCC2(C)C(C)C1Cc1ccc(O)cc21,0.543750,,0.543750,1,False
1037,ZULQKQGLUORZDC-UHFFFAOYSA-N,N#Cc1cc(ccc1)N1CCN(CCN2CCC(CC2)C(F)(F)F)C1=O,0.806180,,0.806180,1,False
1038,LAVIXOKYHRWFDZ-UHFFFAOYSA-N,CC1CCC(C)N1c1ccc([n][n]1)-c1c[n]cc2ccccc21,0.000000,,0.000000,1,False
1039,PZAIVSQWZAHADG-UHFFFAOYSA-N,CC(C)(C(=O)Nc1ccc(c(Cl)c1)N1CCC2(CN(CC2)CC2CC2...,0.361728,,0.361728,1,False
