In [1]:
import os
import gc
import math
import torch
import pandas as pd
from tqdm.auto import tqdm
from datasets import load_from_disk, Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from peft import PeftModel
import numpy as np
import traceback
from collections import Counter
import re

# Attempt to import RDKit
try:
    from rdkit import Chem
    from rdkit import RDLogger
    from rdkit.DataStructs import TanimotoSimilarity
    from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect # ECFP
    RDLogger.DisableLog('rdApp.*')
    RDKIT_AVAILABLE = True
    print("RDKit found. Advanced SMILES processing, similarity, and property calculation will be used.")
except ImportError:
    RDKIT_AVAILABLE = False
    print("RDKit not found. SMILES processing will be string-based, no fingerprint or advanced properties.")
    print("For better accuracy and features, please install RDKit: pip install rdkit-pypi")

# --- Configuration ---
BASE_MODEL_NAME = "google/flan-t5-base"
ADAPTER_MODEL_PATH = "./best_model_multi_eval_v3_correct_loss/"
DATA_PATH = './data/data'
OUTPUT_CSV_PATH = "./comprehensive_reaction_prediction_results.csv"

# Number of candidate sequences returned per reactant. All of them are
# re-ranked by the comprehensive score below before top-1/3/5 evaluation.
NUM_RETURN_SEQUENCES = 20
# Maximum generated length; also used as the input truncation length.
MAX_LENGTH_GENERATION = 256
# Beam width must be >= NUM_RETURN_SEQUENCES; keep a margin of 5 extra beams.
NUM_BEAMS = NUM_RETURN_SEQUENCES + 5
# Sampling parameters; they only take effect when DO_SAMPLE is True.
TEMPERATURE = 1.0
TOP_K = 50
TOP_P = 0.95
DO_SAMPLE = False

# Weights of the comprehensive re-ranking score (they sum to 1.0)
MODEL_SCORE_WEIGHT = 0.3    # weight of the model's own ranking
VALIDITY_SCORE_WEIGHT = 0.1 # weight of RDKit chemical validity
BALANCE_SCORE_WEIGHT = 0.6  # weight of reactant/product atom balance

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Minimum Tanimoto similarity of Morgan fingerprints at which a prediction that
# differs from every true product by canonical SMILES is still accepted as a
# match. At 1 only identical fingerprint bit vectors match.
# NOTE(review): chirality-blind Morgan bits can coincide for distinct molecules
# (e.g. stereoisomers), so such predictions may be counted as correct — confirm.
FINGERPRINT_SIMILARITY_THRESHOLD = 1

# --- Atom Balancing Functions ---
def get_atom_counts_from_smiles(smiles: str) -> Counter:
    """Count atoms by element symbol in one SMILES string.

    RDKit is tried first: a sanitized parse, then an unsanitized parse so that
    radicals and other valence oddities still yield their atoms. If RDKit is not
    installed, or neither parse produces a molecule, the regex-based
    ``_regex_atom_count_fallback`` is used on the stripped text.

    Args:
        smiles: SMILES text. Non-strings and blank strings count as empty.

    Returns:
        Counter mapping element symbol -> number of atoms. On the RDKit path
        only atoms of the molecular graph are counted, so implicit hydrogens
        are not included; the regex fallback may count explicitly written
        bracket hydrogens.
    """
    if not (smiles and isinstance(smiles, str)):
        return Counter()

    cleaned = smiles.strip()
    if not cleaned:
        return Counter()

    if RDKIT_AVAILABLE:
        # Strict parse first, permissive parse second.
        for do_sanitize in (True, False):
            try:
                parsed = Chem.MolFromSmiles(cleaned, sanitize=do_sanitize)
                if parsed:
                    return Counter(a.GetSymbol() for a in parsed.GetAtoms())
            except Exception:
                continue

    return _regex_atom_count_fallback(cleaned)

# Tokenizer for the regex fallback. The first alternative is a complete bracket
# atom ("[13CH3+:1]", "[nH]", "[C@@H]"); the second is an unbracketed atom of the
# SMILES organic subset (including aromatic lower-case atoms) or a shorthand
# group label understood by _expand_atom_abbreviation. Multi-letter alternatives
# come first so "Cl" stays chlorine, while "Cc" is read as C + aromatic c.
# No nested quantifiers, so malformed input cannot trigger heavy backtracking.
_SMILES_ATOM_TOKEN_RE = re.compile(
    r"""
    \[
        \d*                                             # isotope
        (?P<bracket_symbol>[A-Z][a-z]?|se|as|te|[bcnops]|\*)
        (?:@+(?:TH|AL|SP|TB|OH)?\d*)?                   # chirality
        (?:H(?P<h_count>\d*))?                          # attached hydrogens
        (?:[+-]+\d*)?                                   # charge
        (?::\d+)?                                       # atom-map class
    \]
    |
    (?P<bare_symbol>Boc|Cbz|Cl|Br|Me|Et|Ph|Ac|Bn|Bu|Pr|Tf|Ts|[BCNOPSFI]|[bcnops])
    """,
    re.VERBOSE,
)


def _regex_atom_count_fallback(smiles: str) -> Counter:
    """Count atoms in a SMILES string without RDKit, using a regex tokenizer.

    Used when RDKit is unavailable or cannot parse ``smiles``. Understands:
      * bracket atoms with isotope, chirality, explicit H count, charge and
        atom-map class ("[CH3]" -> C + 3 H, "[nH]" -> N + 1 H);
      * the unbracketed organic subset, with aromatic lower-case atoms
        normalized to their element ("c" -> "C");
      * shorthand labels outside brackets ("Me", "Ph", "Boc", ...), mapped via
        _expand_atom_abbreviation. Inside brackets symbols are taken literally,
        since e.g. [Pr] and [Ac] are real elements.

    Hydrogens are counted only when written explicitly ([H] atoms or the H
    count of a bracket atom); implicit hydrogens are not inferred.

    Args:
        smiles: SMILES (or SMILES-like) text.

    Returns:
        Counter mapping element symbol -> atom count. Bonds, ring digits,
        branches, dots and unrecognized letters are ignored.
    """
    atom_counts = Counter()
    for match in _SMILES_ATOM_TOKEN_RE.finditer(smiles):
        bracket_symbol = match.group("bracket_symbol")
        if bracket_symbol is not None:
            # A bare "H" inside brackets means exactly one hydrogen.
            h_count = match.group("h_count")
            if h_count is not None:
                atom_counts["H"] += int(h_count) if h_count else 1
            if bracket_symbol == "*":
                continue  # wildcard atom: no element to count
            symbol = bracket_symbol
        else:
            symbol = _expand_atom_abbreviation(match.group("bare_symbol"))
        # Aromatic atoms are written in lower case ("c", "se"); normalize.
        atom_counts[symbol[0].upper() + symbol[1:]] += 1
    return atom_counts

# Shorthand group labels that can appear in SMILES-like text, each collapsed to
# a single carbon. This is a deliberate simplification: the group's other atoms
# (e.g. the O of Ac, or the F/S/O of Tf) are not represented.
_ATOM_ABBREVIATIONS = {
    'Me': 'C',   # methyl
    'Et': 'C',   # ethyl
    'Ph': 'C',   # phenyl
    'Ac': 'C',   # acetyl
    'Bn': 'C',   # benzyl
    'Bu': 'C',   # butyl
    'Pr': 'C',   # propyl
    'Tf': 'C',   # triflyl (trifluoromethanesulfonyl)
    'Ts': 'C',   # tosyl
    'Boc': 'C',  # tert-butoxycarbonyl
    'Cbz': 'C',  # benzyloxycarbonyl
}


def _expand_atom_abbreviation(atom: str) -> str:
    """Map a shorthand group label to an element symbol; other input is returned unchanged.

    >>> _expand_atom_abbreviation('Ph')
    'C'
    >>> _expand_atom_abbreviation('Cl')
    'Cl'
    """
    return _ATOM_ABBREVIATIONS.get(atom, atom)

def parse_multi_component_smiles(multi_smiles: str) -> list:
    """Split a '.'-disconnected SMILES into its stripped, non-empty components.

    Example: "CCO.O" -> ["CCO", "O"]. Non-string or empty input gives [].
    """
    if not (multi_smiles and isinstance(multi_smiles, str)):
        return []
    stripped_parts = (part.strip() for part in multi_smiles.split('.'))
    return [part for part in stripped_parts if part]

def check_atom_balance(reactant_smiles: str, product_smiles: str) -> dict:
    """Compare element counts between the reactant and product sides of a reaction.

    Every '.'-separated component on each side is counted with
    get_atom_counts_from_smiles. Any reactant-side species absent from the
    products therefore shows up as missing atoms, and lowers the score.

    Args:
        reactant_smiles: SMILES of the reactant side (may be multi-component).
        product_smiles: SMILES of the product side (may be multi-component).

    Returns:
        dict with keys:
            'is_balanced': True when both sides have identical element counts.
            'reactant_atoms': Counter of reactant-side atoms.
            'product_atoms': Counter of product-side atoms.
            'missing_in_products': atoms present in excess on the reactant side.
            'extra_in_products': atoms present in excess on the product side.
            'balance_score': 1 - total_imbalance / total_reactant_atoms, clipped
                to [0.0, 1.0]; 0.0 when the reactant side has no atoms.
    """
    def _total_atoms(multi_smiles):
        # Sum the per-component atom counts of one side of the reaction.
        totals = Counter()
        for component in parse_multi_component_smiles(multi_smiles):
            totals.update(get_atom_counts_from_smiles(component))
        return totals

    reactant_atoms = _total_atoms(reactant_smiles)
    product_atoms = _total_atoms(product_smiles)

    # Counter subtraction already keeps only strictly positive differences.
    missing_in_products = reactant_atoms - product_atoms
    extra_in_products = product_atoms - reactant_atoms

    total_reactant_atoms = sum(reactant_atoms.values())
    total_imbalance = sum(missing_in_products.values()) + sum(extra_in_products.values())
    if total_reactant_atoms:
        balance_score = max(0.0, 1.0 - (total_imbalance / total_reactant_atoms))
    else:
        balance_score = 0.0

    return {
        'is_balanced': not missing_in_products and not extra_in_products,
        'reactant_atoms': reactant_atoms,
        'product_atoms': product_atoms,
        'missing_in_products': missing_in_products,
        'extra_in_products': extra_in_products,
        'balance_score': balance_score
    }

# --- Helper Functions ---
def canonicalize_smiles_rdkit(smiles: str, sanitize=True) -> str:
    """Return RDKit's canonical SMILES for ``smiles``, with best-effort fallbacks.

    Args:
        smiles: SMILES text (may be multi-component, '.'-separated).
        sanitize: Parse with full RDKit sanitization. When True and full
            sanitization fails, the text is re-parsed unsanitized and only a
            subset of sanitization steps (find radicals, set aromaticity, set
            hybridization, symmetrize SSSR) is applied, errors ignored, before
            canonicalizing.

    Returns:
        The canonical SMILES. Returns the stripped input if RDKit is unavailable
        or the text cannot be parsed at all. Returns "" for non-string or blank
        input.
    """
    if not isinstance(smiles, str):
        return ""
    stripped = smiles.strip()
    if not stripped or not RDKIT_AVAILABLE:
        return stripped

    try:
        mol = Chem.MolFromSmiles(stripped, sanitize=sanitize)
    except Exception:
        mol = None
    if mol is not None:
        try:
            return Chem.MolToSmiles(mol, canonical=True)
        except Exception:
            pass

    if not sanitize:
        return stripped

    # MolFromSmiles signals a sanitization failure by returning None rather than
    # raising, so the partial-sanitization fallback must be triggered here
    # explicitly (it was previously only reachable from an exception handler).
    try:
        mol = Chem.MolFromSmiles(stripped, sanitize=False)
        if mol is not None:
            Chem.SanitizeMol(
                mol,
                Chem.SanitizeFlags.SANITIZE_FINDRADICALS
                | Chem.SanitizeFlags.SANITIZE_SETAROMATICITY
                | Chem.SanitizeFlags.SANITIZE_SETHYBRIDIZATION
                | Chem.SanitizeFlags.SANITIZE_SYMMRINGS,
                catchErrors=True,
            )
            return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        pass
    return stripped

def score_chemical_validity(smiles_to_score: str) -> float:
    """Grade how chemically valid a SMILES string is according to RDKit.

    Intended for the raw (un-canonicalized) model output.

    Returns:
        1.0 if it parses with full sanitization; 0.5 if it parses only without
        sanitization; 0.1 if it does not parse at all. Returns 0.0 when RDKit is
        unavailable, the input is empty or not a string, or RDKit raises.
    """
    if not (RDKIT_AVAILABLE and smiles_to_score and isinstance(smiles_to_score, str)):
        return 0.0
    try:
        # Stricter parse first; the first one that yields a molecule decides.
        for sanitize_flag, grade in ((True, 1.0), (False, 0.5)):
            if Chem.MolFromSmiles(smiles_to_score, sanitize=sanitize_flag):
                return grade
        return 0.1
    except Exception:
        return 0.0

def robust_smiles_match(raw_predicted_smiles: str, standardized_true_products_set: set) -> bool:
    """Return True if a predicted SMILES matches any of the true products.

    The prediction is canonicalized with canonicalize_smiles_rdkit and looked
    up in ``standardized_true_products_set``, which must already hold SMILES in
    the same canonical form. If that lookup fails and RDKit is available,
    radius-2, 2048-bit Morgan fingerprints are compared. A Tanimoto similarity
    >= FINGERPRINT_SIMILARITY_THRESHOLD to any parsable true product also counts
    as a match. Any exception in the fingerprint step yields False.

    NOTE(review): these fingerprints are built without chirality, so with the
    threshold at 1 distinct stereoisomers may still be reported as a match —
    confirm this is intended for the accuracy metric.
    """
    if not (raw_predicted_smiles and isinstance(raw_predicted_smiles, str)):
        return False

    canonical_pred = canonicalize_smiles_rdkit(raw_predicted_smiles)
    if not canonical_pred:
        return False

    # 1. Exact match on canonical SMILES.
    if canonical_pred in standardized_true_products_set:
        return True
    if not RDKIT_AVAILABLE:
        return False

    # 2. Fingerprint-similarity fallback.
    try:
        # Prefer the canonical form; fall back to the raw text if it won't parse.
        pred_mol = Chem.MolFromSmiles(canonical_pred) or Chem.MolFromSmiles(raw_predicted_smiles)
        if not pred_mol:
            return False
        pred_fp = GetMorganFingerprintAsBitVect(pred_mol, 2, nBits=2048)

        def _fingerprint_matches(true_smiles):
            true_mol = Chem.MolFromSmiles(true_smiles)
            if not true_mol:
                return False
            true_fp = GetMorganFingerprintAsBitVect(true_mol, 2, nBits=2048)
            return TanimotoSimilarity(pred_fp, true_fp) >= FINGERPRINT_SIMILARITY_THRESHOLD

        return any(_fingerprint_matches(s) for s in standardized_true_products_set)
    except Exception:
        return False

def group_true_products_for_eval_assuming_standardized_input(dataset: Dataset) -> dict:
    """Map each reactant SMILES in ``dataset`` to the set of its product SMILES.

    Both fields are assumed to be already standardized; only surrounding
    whitespace is stripped, so callers must look up with identically prepared
    reactant strings. Rows whose reactant or product is empty after stripping
    are skipped.

    NOTE(review): values go through str(), so a field holding None becomes the
    literal key/value "None" rather than being skipped.

    Args:
        dataset: iterable of row mappings with 'reactant' and 'product' keys.

    Returns:
        dict of reactant SMILES -> set of product SMILES.
    """
    grouped = {}
    for row in tqdm(dataset, desc="Grouping pre-standardized true products", disable=True):
        reactant = str(row.get('reactant', '')).strip()
        product = str(row.get('product', '')).strip()
        if reactant and product:
            grouped.setdefault(reactant, set()).add(product)
    return grouped

# --- Main Prediction Script ---
# Top-N cut-offs reported for both rankings.
_TOP_N_CUTOFFS = (1, 3, 5)

# Number of re-ranked candidates written to the CSV per reactant.
_CSV_TOP_K = 5

# CSV columns holding per-candidate lists; flattened to "; "-joined strings.
_CSV_LIST_COLUMNS = (
    "generated_smiles_raw_list", "generated_smiles_standardized_list",
    "model_scores_list", "validity_scores_list", "balance_scores_list",
    "comprehensive_scores_list",
)

# Placeholder candidate used for every CSV slot when generation fails.
_ERROR_PREDICTION = {
    'raw_smiles': "ERROR_GENERATING",
    'standardized_smiles': "ERROR_GENERATING",
    'model_score': 0.0,
    'validity_score': 0.0,
    'balance_score': 0.0,
    'comprehensive_score': 0.0,
}
_NO_HITS = {f"top{n}": False for n in _TOP_N_CUTOFFS}


def _load_model_and_tokenizer():
    """Load the adapter's tokenizer and the 4-bit base model wrapped with the PEFT adapter.

    Returns:
        (tokenizer, model) on success, or None after printing the error when
        model loading fails. Tokenizer loading errors propagate to the caller.
    """
    print(f"Loading tokenizer from {ADAPTER_MODEL_PATH}...")
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_PATH, use_fast=True)
    print(f"Loading base model '{BASE_MODEL_NAME}' and adapter from '{ADAPTER_MODEL_PATH}'...")
    try:
        quantization_config_bnb = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16,
        )
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            BASE_MODEL_NAME, quantization_config=quantization_config_bnb, device_map={"": DEVICE}
        )
        # Size the embeddings to the adapter tokenizer's vocabulary
        # (presumably extended during fine-tuning — confirm).
        base_model.resize_token_embeddings(len(tokenizer))
        model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_PATH)
        model.eval()
        model.to(DEVICE)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return None

    # Special-token ids used by generate(); decoding starts from the pad token.
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.decoder_start_token_id = tokenizer.pad_token_id
    return tokenizer, model


def _load_test_data():
    """Load the 'test' split and build its reactant -> true-products lookup.

    Returns:
        (test_split, lookup) on success, or None after printing the error.
    """
    print(f"Loading and preparing test data from '{DATA_PATH}'...")
    try:
        raw_dataset = load_from_disk(DATA_PATH)
        test_data_raw_pairs = raw_dataset['test']
        true_products_lookup_map = group_true_products_for_eval_assuming_standardized_input(test_data_raw_pairs)
        print(f"Source test pairs: {len(test_data_raw_pairs)}. Unique (pre-standardized) reactants with true products: {len(true_products_lookup_map)}")
    except Exception as e:
        print(f"Error loading or processing data: {e}")
        traceback.print_exc()
        return None
    return test_data_raw_pairs, true_products_lookup_map


def _generate_candidates(model, tokenizer, reactant_smiles):
    """Generate NUM_RETURN_SEQUENCES decoded candidates for one reactant, in generate()'s output order."""
    inputs = tokenizer(reactant_smiles, return_tensors="pt", max_length=MAX_LENGTH_GENERATION, truncation=True)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs.input_ids.to(DEVICE),
            attention_mask=inputs.attention_mask.to(DEVICE),
            max_length=MAX_LENGTH_GENERATION,
            num_return_sequences=NUM_RETURN_SEQUENCES,
            num_beams=NUM_BEAMS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE if DO_SAMPLE else 1.0,
            top_k=TOP_K if DO_SAMPLE else None,
            top_p=TOP_P if DO_SAMPLE else None,
            early_stopping=True,
            eos_token_id=model.config.eos_token_id,
            pad_token_id=model.config.pad_token_id,
            decoder_start_token_id=model.config.decoder_start_token_id,
            return_dict_in_generate=True,
        )
    return tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)


def _unique_in_order(smiles_seq):
    """Drop empty strings and later duplicates, keeping first-occurrence order."""
    return list(dict.fromkeys(s for s in smiles_seq if s))


def _top_n_hits(ranked_predictions, true_products):
    """Return {"top1": bool, "top3": bool, "top5": bool} for a de-duplicated ranked list.

    "topN" is True when any of the first N predictions robust-matches a true
    product. The rank of the first hit decides every cut-off, so each
    prediction is matched at most once.
    """
    first_hit = None
    if true_products:
        first_hit = next(
            (rank for rank, pred in enumerate(ranked_predictions[:max(_TOP_N_CUTOFFS)])
             if robust_smiles_match(pred, true_products)),
            None,
        )
    return {f"top{n}": first_hit is not None and first_hit < n for n in _TOP_N_CUTOFFS}


def _score_candidates(reactant_smiles, raw_candidates, standardized_candidates):
    """Score each candidate and return score dicts sorted best-first by comprehensive score.

    model_score decreases linearly with the candidate's position in the
    generate() output. validity_score comes from the raw text. balance_score
    compares the reactant with the canonical candidate, and is 0.0 when the
    candidate canonicalizes to "". The sort is stable, so ties keep generation
    order.
    """
    scored = []
    for rank, (raw, standardized) in enumerate(zip(raw_candidates, standardized_candidates)):
        model_score = 1.0 - (rank / NUM_RETURN_SEQUENCES)
        validity_score = score_chemical_validity(raw)
        balance_score = 0.0
        if standardized:
            balance_score = check_atom_balance(reactant_smiles, standardized)['balance_score']
        comprehensive_score = (model_score * MODEL_SCORE_WEIGHT +
                               validity_score * VALIDITY_SCORE_WEIGHT +
                               balance_score * BALANCE_SCORE_WEIGHT)
        scored.append({
            'raw_smiles': raw,
            'standardized_smiles': standardized,
            'model_score': model_score,
            'validity_score': validity_score,
            'balance_score': balance_score,
            'comprehensive_score': comprehensive_score,
        })
    scored.sort(key=lambda p: p['comprehensive_score'], reverse=True)
    return scored


def _result_row(reactant_smiles, true_products_str, predictions, hits):
    """Build one CSV row from the given score dicts and the re-ranked top-N hits."""
    return {
        "input_reactant_smiles": reactant_smiles,
        "true_products_standardized_list_str": true_products_str,
        "generated_smiles_raw_list": [p['raw_smiles'] for p in predictions],
        "generated_smiles_standardized_list": [p['standardized_smiles'] for p in predictions],
        "model_scores_list": [p['model_score'] for p in predictions],
        "validity_scores_list": [p['validity_score'] for p in predictions],
        "balance_scores_list": [p['balance_score'] for p in predictions],
        "comprehensive_scores_list": [p['comprehensive_score'] for p in predictions],
        **{f"top{n}_match": hits[f"top{n}"] for n in _TOP_N_CUTOFFS},
    }


def _print_accuracy_block(title, num_reactants, accuracies):
    """Print one metrics section; ``accuracies`` maps N -> top-N accuracy in percent."""
    print(f"\n--- {title} ---")
    print(f"Total Unique Input Reactants Processed: {num_reactants}")
    for n in _TOP_N_CUTOFFS:
        print(f"Top-{n} Accuracy: {accuracies[n]:.2f}%")


def predict_and_evaluate():
    """Predict products for each unique test reactant, re-rank the candidates, report accuracy and write a CSV.

    For every unique reactant string of the test split, NUM_RETURN_SEQUENCES
    candidates are generated and canonicalized. Top-1/3/5 accuracy is computed
    twice over de-duplicated canonical candidates:
      * in the model's own output order, and
      * after re-ranking all candidates by the comprehensive score, a weighted
        sum of model rank, RDKit validity and atom balance.

    Both sets of accuracies and their difference are printed. Per-reactant rows
    are written to OUTPUT_CSV_PATH, holding the CSV_TOP_K best re-ranked raw
    candidates; these are not de-duplicated, so they may repeat a molecule.

    A failure for one reactant, including tokenization, is reported and stored
    as an error row. It then contributes no hits to either ranking. Returns None,
    and returns early if model or data loading fails.
    """
    loaded = _load_model_and_tokenizer()
    if loaded is None:
        return
    tokenizer, model = loaded

    data = _load_test_data()
    if data is None:
        return
    test_data_raw_pairs, true_products_lookup_map = data

    results_list = []
    predicted_reactants_from_input = set()
    # Hit counters for the model's original ranking.
    model_only_results = {f"top{n}": 0 for n in _TOP_N_CUTOFFS}

    for raw_example in tqdm(test_data_raw_pairs, desc="Predicting Reactions"):
        original_reactant_smiles = str(raw_example.get('reactant', '')).strip()
        if not original_reactant_smiles or original_reactant_smiles in predicted_reactants_from_input:
            continue
        predicted_reactants_from_input.add(original_reactant_smiles)

        true_standardized_products_set = true_products_lookup_map.get(original_reactant_smiles, set())
        true_products_str = '; '.join(sorted(true_standardized_products_set))

        try:
            raw_predicted_smiles_batch = _generate_candidates(model, tokenizer, original_reactant_smiles)
            standardized_batch = [canonicalize_smiles_rdkit(s) for s in raw_predicted_smiles_batch]

            # Baseline: de-duplicated candidates in generation order.
            model_only_hits = _top_n_hits(_unique_in_order(standardized_batch), true_standardized_products_set)

            # Re-ranked: evaluate over ALL scored candidates, not only the CSV rows.
            predictions_data = _score_candidates(
                original_reactant_smiles, raw_predicted_smiles_batch, standardized_batch
            )
            reranked_hits = _top_n_hits(
                _unique_in_order(p['standardized_smiles'] for p in predictions_data),
                true_standardized_products_set,
            )
            current_result = _result_row(original_reactant_smiles, true_products_str,
                                         predictions_data[:_CSV_TOP_K], reranked_hits)
        except Exception as e_gen:
            print(f"Error during generation for reactant '{original_reactant_smiles}': {e_gen}")
            current_result = _result_row(original_reactant_smiles, true_products_str,
                                         [_ERROR_PREDICTION] * _CSV_TOP_K, _NO_HITS)
        else:
            # Commit baseline hits only after the whole pipeline succeeded, so both
            # rankings are credited on exactly the same reactants.
            for key, hit in model_only_hits.items():
                model_only_results[key] += hit

        results_list.append(current_result)
        gc.collect()
        if DEVICE.type == 'cuda':
            torch.cuda.empty_cache()

    num_total_unique_reactants_predicted = len(predicted_reactants_from_input)
    if num_total_unique_reactants_predicted == 0:
        print("No reactants were processed. Check data or logic.")
        return

    def _as_percent(count):
        return count / num_total_unique_reactants_predicted * 100

    model_only_acc = {n: _as_percent(model_only_results[f"top{n}"]) for n in _TOP_N_CUTOFFS}
    reranked_acc = {n: _as_percent(sum(r[f"top{n}_match"] for r in results_list)) for n in _TOP_N_CUTOFFS}

    _print_accuracy_block("Model Score Only (Original Ranking) Metrics",
                          num_total_unique_reactants_predicted, model_only_acc)
    _print_accuracy_block(f"Comprehensive Score (Re-ranked from Top-{NUM_RETURN_SEQUENCES}) Metrics",
                          num_total_unique_reactants_predicted, reranked_acc)

    print("\n--- Improvement from Re-ranking ---")
    for n in _TOP_N_CUTOFFS:
        print(f"Top-{n} Improvement: {reranked_acc[n] - model_only_acc[n]:+.2f}%")

    df_results = pd.DataFrame(results_list)
    for col in _CSV_LIST_COLUMNS:
        df_results[col] = df_results[col].apply(lambda x: '; '.join(map(str, x)) if isinstance(x, list) else x)

    try:
        df_results.to_csv(OUTPUT_CSV_PATH, index=False)
        print(f"\nResults saved to {OUTPUT_CSV_PATH}")
    except Exception as e_csv:
        print(f"Error saving results to CSV: {e_csv}")

    # Cleanup
    del model, tokenizer
    gc.collect()
    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()


if __name__ == "__main__":
    predict_and_evaluate()

RDKit found. Advanced SMILES processing, similarity, and property calculation will be used.
Using device: cuda
Loading tokenizer from ./best_model_multi_eval_v3_correct_loss/...
Loading base model 'google/flan-t5-base' and adapter from './best_model_multi_eval_v3_correct_loss/'...
Model loaded successfully.
Loading and preparing test data from './data/data'...
Source test pairs: 508. Unique (pre-standardized) reactants with true products: 429


Predicting Reactions:   0%|          | 0/508 [00:00<?, ?it/s]


--- Model Score Only (Original Ranking) Metrics ---
Total Unique Input Reactants Processed: 429
Top-1 Accuracy: 72.03%
Top-3 Accuracy: 78.79%
Top-5 Accuracy: 79.95%

--- Comprehensive Score (Re-ranked from Top-15) Metrics ---
Total Unique Input Reactants Processed: 429
Top-1 Accuracy: 74.13%
Top-3 Accuracy: 80.89%
Top-5 Accuracy: 82.05%

--- Improvement from Re-ranking ---
Top-1 Improvement: +2.10%
Top-3 Improvement: +2.10%
Top-5 Improvement: +2.10%

Results saved to ./comprehensive_reaction_prediction_results.csv
