In [1]:
import os
import gc
import math
import torch
import pandas as pd
from tqdm.auto import tqdm
from datasets import load_from_disk, Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from peft import PeftModel
import numpy as np
import traceback
from collections import Counter
import re

# Attempt to import RDKit
# RDKit is an optional dependency. RDKIT_AVAILABLE records whether the import
# succeeded; the helper functions further down consult it to choose between
# RDKit-based processing and plain string handling.
try:
    from rdkit import Chem
    from rdkit import RDLogger
    from rdkit.DataStructs import TanimotoSimilarity
    from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect # ECFP
    # Turn off RDKit's own log output for every 'rdApp.*' channel.
    RDLogger.DisableLog('rdApp.*')
    RDKIT_AVAILABLE = True
    print("RDKit found. Advanced SMILES processing, similarity, and property calculation will be used.")
except ImportError:
    RDKIT_AVAILABLE = False
    print("RDKit not found. SMILES processing will be string-based, no fingerprint or advanced properties.")
    # NOTE(review): the PyPI distribution is, as far as I know, now published
    # as `rdkit`; confirm whether the `rdkit-pypi` hint below is still current.
    print("For better accuracy and features, please install RDKit: pip install rdkit-pypi")

# --- Configuration ---
# Base checkpoint and the directory holding the PEFT adapter; the tokenizer is
# also loaded from the adapter directory.
BASE_MODEL_NAME = "google/flan-t5-base"
ADAPTER_MODEL_PATH = "./best_model_multi_eval_v3_correct_loss/"
# Dataset saved with `datasets`; its 'test' split supplies ground-truth products.
DATA_PATH = './data/data'
OUTPUT_CSV_PATH_SINGLE_FORMATTED = "./single_reactant_predictions_comprehensive_scores.csv"

# Reactant to predict for. When falsy, the reactant is taken from the test
# split at index LOAD_FROM_TEST_SET_INDEX instead.
INPUT_REACTANT_SMILES = "C=CC(C)C.[OH]" # REPLACE THIS or set to None
LOAD_FROM_TEST_SET_INDEX = 0

# Generation settings. MAX_LENGTH_GENERATION is used both as the tokenizer
# truncation length for the input and as `max_length` for generation.
NUM_RETURN_SEQUENCES = 5
MAX_LENGTH_GENERATION = 256
# Beam width is twice the number of sequences returned.
NUM_BEAMS = NUM_RETURN_SEQUENCES*2
# TEMPERATURE / TOP_K / TOP_P are only forwarded to generate() when DO_SAMPLE is True.
TEMPERATURE = 1.0
TOP_K = 50
TOP_P = 0.95
DO_SAMPLE = False

# Scoring weights for comprehensive score
# The three weights sum to 1.0, so the comprehensive score stays within [0, 1]
# as long as each component score does.
MODEL_SCORE_WEIGHT = 0.4    # weight of the model's own ranking
VALIDITY_SCORE_WEIGHT = 0.3 # weight of chemical validity
BALANCE_SCORE_WEIGHT = 0.3  # weight of atom balance

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Minimum Tanimoto similarity for the fingerprint fallback match. It is
# compared with `>=`, so a value of 1 only accepts identical fingerprints.
FINGERPRINT_SIMILARITY_THRESHOLD = 1

# --- Atom Balancing Functions ---
def get_atom_counts_from_smiles(smiles: str) -> Counter:
    """
    Count the atoms of a single SMILES string, keyed by element symbol.

    When RDKit is available the string is parsed first with sanitization and,
    if that yields no molecule, once more without it (useful for radicals).
    Each atom returned by ``mol.GetAtoms()`` contributes one count under its
    symbol.  If RDKit is missing or neither parse produces a molecule, the
    regex-based fallback is used instead.

    NOTE(review): the RDKit path presumably does not include implicit or
    bracket hydrogens (they are not separate atoms in the parsed molecule),
    while the regex fallback does count bracket hydrogens - confirm whether
    that asymmetry matters for the balance score.

    Args:
        smiles: A SMILES string; anything empty or non-string gives an empty Counter.

    Returns:
        Counter mapping element symbol -> number of atoms.
    """
    cleaned = smiles.strip() if isinstance(smiles, str) else ""
    if not cleaned:
        return Counter()

    if RDKIT_AVAILABLE:
        # Strict parse first, relaxed parse second; the first molecule wins.
        for sanitize_flag in (True, False):
            try:
                parsed = Chem.MolFromSmiles(cleaned, sanitize=sanitize_flag)
                if parsed:
                    return Counter(atom.GetSymbol() for atom in parsed.GetAtoms())
            except Exception:
                continue

    # Neither RDKit attempt produced a molecule (or RDKit is unavailable).
    return _regex_atom_count_fallback(cleaned)

# Substituent shorthands that may show up in hand-written, SMILES-like strings.
# Every entry is deliberately collapsed to ONE carbon: this is a coarse
# approximation kept from the original implementation, not a real expansion.
# NOTE(review): 'Ac', 'Pr' and 'Ts' are also real element symbols; they are
# still mapped to carbon here, as before.
_ATOM_ABBREVIATIONS = {
    'Me': 'C',  # Methyl - simplified to carbon
    'Et': 'C',  # Ethyl - simplified to carbon
    'Ph': 'C',  # Phenyl - simplified to carbon
    'Ac': 'C',  # Acetyl - simplified to carbon
    'Bn': 'C',  # Benzyl - simplified to carbon
    'Bu': 'C',  # Butyl - simplified to carbon
    'Pr': 'C',  # Propyl - simplified to carbon
    'Tf': 'C',  # Trifluoromethyl - simplified (should include F but simplified here)
    'Ts': 'C',  # Tosyl - simplified to carbon
    'Boc': 'C', # tert-Butoxycarbonyl - simplified to carbon
    'Cbz': 'C', # Carboxybenzyl - simplified to carbon
}

# Longest shorthand first so that e.g. 'Boc' is tried before any shorter token.
_ABBREVIATION_ALTERNATION = '|'.join(
    re.escape(token) for token in sorted(_ATOM_ABBREVIATIONS, key=len, reverse=True)
)

# One match per atom.  Two alternatives:
#   * a bracket atom  [isotope? symbol chirality? Hn? anything-else]
#   * a bare atom from the SMILES organic subset (two-letter Cl/Br before the
#     single letters, then the lowercase aromatic symbols) or a shorthand.
_SMILES_ATOM_RE = re.compile(
    r'\[\d*'
    r'(?P<bracket>' + _ABBREVIATION_ALTERNATION + r'|[A-Z][a-z]?|se|as|[bcnops])'
    r'@{0,2}'
    r'(?:H(?P<hcount>\d*))?'
    r'[^\]]*\]'
    r'|(?P<bare>' + _ABBREVIATION_ALTERNATION + r'|Cl|Br|[BCNOPSFI]|[bcnops])'
)


def _regex_atom_count_fallback(smiles: str) -> Counter:
    """
    Count atoms in a SMILES-like string without RDKit.

    The string is tokenised atom by atom.  Bracket atoms contribute their
    element plus any hydrogens written inside the bracket (``[CH3]`` -> one C
    and three H, ``[OH]`` -> one O and one H).  Bare atoms are read from the
    SMILES organic subset, so ``Cc1ccccc1`` is seven carbons rather than an
    unknown element ``Cc``; aromatic (lowercase) symbols are counted under
    their capitalised element.  Known shorthands such as ``Me`` or ``Boc`` are
    mapped through ``_expand_atom_abbreviation``.  Bonds, ring-closure
    digits, branches, charges and anything unrecognised are ignored.

    Args:
        smiles: The string to scan; empty or non-string input gives an empty Counter.

    Returns:
        Counter mapping element symbol -> count.  Hydrogens appear only when
        written explicitly inside brackets.
    """
    atom_counts = Counter()
    if not smiles or not isinstance(smiles, str):
        return atom_counts

    for match in _SMILES_ATOM_RE.finditer(smiles):
        symbol = match.group('bracket') or match.group('bare')
        if symbol.islower():
            # Aromatic atom: 'c' -> 'C', 'se' -> 'Se'.
            symbol = symbol.capitalize()
        atom_counts[_expand_atom_abbreviation(symbol)] += 1

        # 'hcount' is None when no H was written, '' for a lone 'H' (one
        # hydrogen) and digits for 'H2', 'H3', ...
        h_text = match.group('hcount')
        if h_text is not None:
            h_num = int(h_text) if h_text else 1
            if h_num > 0:  # avoid storing a zero entry for '[CH0]'
                atom_counts['H'] += h_num

    return atom_counts


def _expand_atom_abbreviation(atom: str) -> str:
    """
    Map a known substituent shorthand to the element it is approximated by.

    Args:
        atom: An element symbol or shorthand such as ``'Me'``.

    Returns:
        ``'C'`` for every shorthand in the table; otherwise *atom* unchanged.
    """
    return _ATOM_ABBREVIATIONS.get(atom, atom)

def parse_multi_component_smiles(multi_smiles: str) -> list:
    """
    Split a dot-separated SMILES string into its individual molecules.

    Each piece is whitespace-stripped and empty pieces are dropped, so
    ``"C . .O"`` yields ``["C", "O"]``.  Empty or non-string input yields an
    empty list.
    """
    if not isinstance(multi_smiles, str) or not multi_smiles:
        return []

    fragments = []
    for piece in multi_smiles.split('.'):
        trimmed = piece.strip()
        if trimmed:
            fragments.append(trimmed)
    return fragments

def check_atom_balance(reactant_smiles: str, product_smiles: str) -> dict:
    """
    Compare the atom inventory of the reactant side with the product side.

    Both arguments may contain several dot-separated molecules; the atoms of
    all molecules on one side are pooled before the comparison.

    Args:
        reactant_smiles: SMILES of the reactant side.
        product_smiles: SMILES of the product side.

    Returns:
        dict with the keys
            'is_balanced'          True when no element is missing or surplus,
            'reactant_atoms'       Counter of pooled reactant atoms,
            'product_atoms'        Counter of pooled product atoms,
            'missing_in_products'  Counter of atoms the products lack,
            'extra_in_products'    Counter of atoms the products have in excess,
            'balance_score'        1 - imbalance / reactant atom total, floored
                                   at 0.0; 0.0 when no reactant atoms were found.
    """
    def _pooled_atoms(multi_smiles):
        # Sum the per-molecule atom counts of one side of the reaction.
        pooled = Counter()
        for fragment in parse_multi_component_smiles(multi_smiles):
            pooled.update(get_atom_counts_from_smiles(fragment))
        return pooled

    reactant_atoms = _pooled_atoms(reactant_smiles)
    product_atoms = _pooled_atoms(product_smiles)

    # Counter subtraction keeps only strictly positive differences, so each
    # result lists exactly the elements that are short on the other side.
    missing_in_products = reactant_atoms - product_atoms
    extra_in_products = product_atoms - reactant_atoms

    is_balanced = not missing_in_products and not extra_in_products

    reactant_total = sum(reactant_atoms.values())
    imbalance_total = sum(missing_in_products.values()) + sum(extra_in_products.values())

    balance_score = 0.0
    if reactant_total != 0:
        balance_score = max(0.0, 1.0 - (imbalance_total / reactant_total))

    return {
        'is_balanced': is_balanced,
        'reactant_atoms': reactant_atoms,
        'product_atoms': product_atoms,
        'missing_in_products': missing_in_products,
        'extra_in_products': extra_in_products,
        'balance_score': balance_score
    }

# --- Helper Functions ---
def canonicalize_smiles_rdkit(smiles: str, sanitize=True) -> str:
    """
    Return RDKit's canonical SMILES for *smiles*, or the stripped input.

    The string is parsed with the requested *sanitize* setting.  If that does
    not yield a molecule and *sanitize* is true, a second attempt parses
    without sanitization and then applies a reduced set of sanitization steps
    (radicals, aromaticity, hybridization, ring symmetrization) with errors
    caught, so that chemically odd but syntactically valid strings can still
    be canonicalized.

    Args:
        smiles: SMILES string to canonicalize.
        sanitize: Whether the first parse attempt uses full sanitization.

    Returns:
        The canonical SMILES when a molecule could be built; otherwise the
        whitespace-stripped input.  Non-string input gives ``""``.  When RDKit
        is unavailable the stripped input is returned unchanged.
    """
    if not isinstance(smiles, str):
        return ""
    if not RDKIT_AVAILABLE or not smiles:
        return smiles.strip()

    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=sanitize)
    except Exception:
        mol = None

    if mol is None and sanitize:
        # MolFromSmiles reports a failed parse/sanitization by returning None
        # rather than raising, so the relaxed retry has to be triggered here
        # and not only from an exception handler.
        try:
            mol = Chem.MolFromSmiles(smiles, sanitize=False)
            if mol is not None:
                Chem.SanitizeMol(
                    mol,
                    Chem.SanitizeFlags.SANITIZE_FINDRADICALS
                    | Chem.SanitizeFlags.SANITIZE_SETAROMATICITY
                    | Chem.SanitizeFlags.SANITIZE_SETHYBRIDIZATION
                    | Chem.SanitizeFlags.SANITIZE_SYMMRINGS,
                    catchErrors=True,
                )
        except Exception:
            mol = None

    if mol is None:
        return smiles.strip()

    try:
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        # Writing SMILES for a partially sanitized molecule can fail; fall
        # back to the raw text rather than propagating.
        return smiles.strip()

def standardize_generated_smiles(smiles_str: str) -> str:
    """
    Put a (possibly multi-component) SMILES string into a comparable form.

    Without RDKit the input is only stripped.  With RDKit a single molecule
    is canonicalized directly; a dot-separated string is split, each
    non-blank component is canonicalized, and the results are sorted and
    re-joined with '.' so that component order does not matter.
    Empty or non-string input gives ``""``.
    """
    is_text = isinstance(smiles_str, str)
    if not RDKIT_AVAILABLE:
        return str(smiles_str).strip() if is_text else ""
    if not smiles_str or not is_text:
        return ""
    if '.' not in smiles_str:
        return canonicalize_smiles_rdkit(smiles_str)

    canonical_parts = []
    for fragment in smiles_str.split('.'):
        if fragment.strip():
            canonical_parts.append(canonicalize_smiles_rdkit(fragment))
    canonical_parts.sort()
    # Drop components whose canonical form came back empty.
    return '.'.join(part for part in canonical_parts if part)

def score_chemical_validity(smiles_to_score: str) -> float:
    """
    Grade how well RDKit can parse a SMILES string (ideally the raw model output).

    Returns:
        1.0 if the string parses with full sanitization,
        0.5 if it only parses with sanitization switched off,
        0.1 if it does not parse either way,
        0.0 if RDKit is unavailable, the input is empty or not a string,
            or RDKit raises while parsing.
    """
    scorable = RDKIT_AVAILABLE and smiles_to_score and isinstance(smiles_to_score, str)
    if not scorable:
        return 0.0

    try:
        # Strict check first: a fully sanitizable molecule is fully valid.
        if Chem.MolFromSmiles(smiles_to_score, sanitize=True):
            return 1.0
        # Syntactically readable, but rejected by the strict chemistry checks.
        if Chem.MolFromSmiles(smiles_to_score, sanitize=False):
            return 0.5
    except Exception:
        return 0.0
    # Not readable even without sanitization.
    return 0.1

def robust_smiles_match(raw_predicted_smiles: str, standardized_true_products_set: set) -> bool:
    """
    Decide whether a predicted SMILES matches any known true product.

    The prediction is standardized with ``standardize_generated_smiles`` (the
    same multi-component standardization the rest of this file uses when
    comparing against the truth set) and looked up directly.  If that fails
    and RDKit is available, Morgan fingerprints (radius 2, 2048 bits) of the
    prediction and of each true product are compared, and a Tanimoto
    similarity of at least FINGERPRINT_SIMILARITY_THRESHOLD counts as a match.

    Args:
        raw_predicted_smiles: Predicted SMILES, raw or already standardized.
        standardized_true_products_set: Standardized SMILES of the true products.

    Returns:
        True on a direct or fingerprint match, otherwise False.  Empty or
        non-string predictions, and any RDKit error during the fingerprint
        comparison, give False.
    """
    if not raw_predicted_smiles or not isinstance(raw_predicted_smiles, str):
        return False

    standardized_prediction = standardize_generated_smiles(raw_predicted_smiles)
    if not standardized_prediction:  # standardization produced nothing usable
        return False

    # 1. Direct match of standardized forms
    if standardized_prediction in standardized_true_products_set:
        return True

    # 2. Fingerprint similarity (only possible with RDKit)
    if not RDKIT_AVAILABLE:
        return False

    try:
        # Prefer the standardized form; fall back to the raw text if the
        # standardized string cannot be parsed.
        pred_mol = (Chem.MolFromSmiles(standardized_prediction)
                    or Chem.MolFromSmiles(raw_predicted_smiles))
        if not pred_mol:
            return False

        pred_fp = GetMorganFingerprintAsBitVect(pred_mol, 2, nBits=2048)

        for true_s_standardized in standardized_true_products_set:
            true_mol = Chem.MolFromSmiles(true_s_standardized)
            if not true_mol:
                continue
            true_fp = GetMorganFingerprintAsBitVect(true_mol, 2, nBits=2048)
            if TanimotoSimilarity(pred_fp, true_fp) >= FINGERPRINT_SIMILARITY_THRESHOLD:
                return True
    except Exception:
        # Best effort: a fingerprinting failure is treated as "no match".
        return False
    return False

def group_true_products_for_eval_assuming_standardized_input(dataset: Dataset) -> dict:
    """
    Build a reactant -> set-of-products lookup from reactant/product pairs.

    Each example is read through ``example.get('reactant')`` and
    ``example.get('product')``; both values are converted to ``str`` and
    stripped but not otherwise standardized (the name signals that the
    dataset is expected to be standardized already).  Examples where either
    value ends up empty are skipped.

    Returns:
        dict mapping each reactant string to the set of its product strings.
    """
    grouped = {}
    # The progress bar is constructed but disabled, so nothing is displayed.
    progress = tqdm(dataset, desc="Grouping pre-standardized true products", disable=True)
    for record in progress:
        reactant_key = str(record.get('reactant', '')).strip()
        product_value = str(record.get('product', '')).strip()
        if reactant_key and product_value:
            grouped.setdefault(reactant_key, set()).add(product_value)
    return grouped

# --- Main Prediction Script for Single Reactant ---
def predict_for_single_reactant_formatted():
    """
    Predict products for one reactant, score and rerank them, and save a CSV.

    Everything is driven by the module-level configuration constants:

    1. Load the tokenizer from ADAPTER_MODEL_PATH, the 4-bit quantized base
       model BASE_MODEL_NAME, and attach the PEFT adapter.
    2. Pick the reactant: INPUT_REACTANT_SMILES if set, otherwise the
       'reactant' field of test example LOAD_FROM_TEST_SET_INDEX from
       DATA_PATH.  When the dataset can be loaded, a reactant -> true-products
       map is built for evaluation.
    3. Generate NUM_RETURN_SEQUENCES candidates and report top-k hits in the
       model's own order.
    4. Compute a weighted comprehensive score (model rank, chemical validity,
       atom balance) per candidate, rerank, and report top-k hits again.
    5. Write one row per candidate to OUTPUT_CSV_PATH_SINGLE_FORMATTED.

    Returns:
        None.  Loading failures print a message and return early.

    Side effects:
        Prints progress and results, writes the CSV file, and may rebind the
        global INPUT_REACTANT_SMILES when the reactant comes from the test set.
    """
    global INPUT_REACTANT_SMILES

    # 1. Load Tokenizer and Model
    print(f"Loading tokenizer from {ADAPTER_MODEL_PATH}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_PATH, use_fast=True)
        # Make sure pad/eos/bos tokens exist; they are used for the model config below.
        if tokenizer.pad_token is None: tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        if tokenizer.eos_token is None: tokenizer.add_special_tokens({'eos_token': '</s>'})
        if tokenizer.bos_token is None: tokenizer.add_special_tokens({'bos_token': '<s>'})

    except Exception as e: print(f"Error loading tokenizer: {e}. Exiting."); return

    print(f"Loading base model '{BASE_MODEL_NAME}' and adapter from '{ADAPTER_MODEL_PATH}'...")
    try:
        # 4-bit NF4 quantization with bfloat16 compute.
        quantization_config_bnb = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16,
        )
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            BASE_MODEL_NAME, quantization_config=quantization_config_bnb, device_map={"":DEVICE}
        )
        # Resize the embedding table to the tokenizer's size before attaching
        # the adapter (presumably to match the vocabulary used in training - confirm).
        base_model.resize_token_embeddings(len(tokenizer))
        model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_PATH)
        model.eval(); model.to(DEVICE)
        print("Model loaded successfully.")
    except Exception as e: print(f"Error loading model: {e}"); traceback.print_exc(); return

    # Align the model config with the tokenizer. The decoder start token is
    # set to the pad token id (the usual T5 convention - confirm for this checkpoint).
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.decoder_start_token_id = tokenizer.pad_token_id

    # 2. Prepare Input Reactant and True Products (if available)
    true_products_lookup_map = {}
    if not INPUT_REACTANT_SMILES:
        # No reactant configured: take it from the test split; here a loading
        # failure is fatal because there would be nothing to predict for.
        print(f"INPUT_REACTANT_SMILES not set. Loading from test_data_raw_pairs[{LOAD_FROM_TEST_SET_INDEX}]...")
        try:
            raw_dataset = load_from_disk(DATA_PATH)
            test_data_raw_pairs = raw_dataset['test']
            if LOAD_FROM_TEST_SET_INDEX < len(test_data_raw_pairs):
                INPUT_REACTANT_SMILES = str(test_data_raw_pairs[LOAD_FROM_TEST_SET_INDEX].get('reactant', '')).strip()
                print(f"Using reactant from test set: {INPUT_REACTANT_SMILES}")
                true_products_lookup_map = group_true_products_for_eval_assuming_standardized_input(test_data_raw_pairs)
            else:
                print(f"Error: Index {LOAD_FROM_TEST_SET_INDEX} out of bounds."); return
        except Exception as e: print(f"Error loading test data: {e}"); traceback.print_exc(); return
    elif os.path.exists(DATA_PATH):
         # Reactant given explicitly: ground truth is optional, so a loading
         # failure only prints a message and prediction continues without it.
         try:
            raw_dataset = load_from_disk(DATA_PATH)
            test_data_raw_pairs = raw_dataset['test']
            true_products_lookup_map = group_true_products_for_eval_assuming_standardized_input(test_data_raw_pairs)
            print("Loaded true products map for potential matching.")
         except Exception as e:
            print(f"Could not load true products map from {DATA_PATH}: {e}.")

    if not INPUT_REACTANT_SMILES: print("No input reactant SMILES. Exiting."); return

    print(f"\n--- Predicting for Reactant: {INPUT_REACTANT_SMILES} ---")
    # The lookup key is the reactant string exactly as given; it is not
    # standardized here, so it must match the dataset's form to find products.
    reactant_for_lookup = INPUT_REACTANT_SMILES
    true_standardized_products_set = true_products_lookup_map.get(reactant_for_lookup, set())
    all_true_products_str = '; '.join(sorted(list(true_standardized_products_set))) if true_standardized_products_set else "N/A (not in test set or no products defined)"

    # 3. Generate Predictions
    # MAX_LENGTH_GENERATION doubles as the truncation length for the input.
    inputs = tokenizer(INPUT_REACTANT_SMILES, return_tensors="pt", max_length=MAX_LENGTH_GENERATION, truncation=True)
    input_ids = inputs.input_ids.to(DEVICE)
    attention_mask = inputs.attention_mask.to(DEVICE)

    predictions_data = []
    # Top-k hit flags, once for the model's own order and once after reranking.
    original_ranking_matches = {"top1": False, "top3": False, "top5": False, "top10": False}
    reranked_matches = {"top1": False, "top3": False, "top5": False, "top10": False}

    try:
        with torch.no_grad():
            # Sampling parameters are only forwarded when DO_SAMPLE is True;
            # otherwise neutral values (1.0 / None) are passed.
            outputs = model.generate(
                input_ids=input_ids, attention_mask=attention_mask,
                max_length=MAX_LENGTH_GENERATION, num_return_sequences=NUM_RETURN_SEQUENCES,
                num_beams=NUM_BEAMS, do_sample=DO_SAMPLE,
                temperature=TEMPERATURE if DO_SAMPLE else 1.0,
                top_k=TOP_K if DO_SAMPLE else None, top_p=TOP_P if DO_SAMPLE else None,
                early_stopping=True, eos_token_id=model.config.eos_token_id,
                pad_token_id=model.config.pad_token_id,
                decoder_start_token_id=model.config.decoder_start_token_id,
                output_scores=True, return_dict_in_generate=True
            )

        raw_predicted_smiles_batch = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

        print("\n--- Original Model Ranking ---")

        # First evaluate with original model ranking
        # Candidates are de-duplicated by their standardized form. The printed
        # rank is the index in the raw batch, so skipped duplicates leave gaps.
        original_order_predictions = []
        seen_original = set()
        for i, pred_smiles_raw in enumerate(raw_predicted_smiles_batch):
            standardized_pred = standardize_generated_smiles(pred_smiles_raw)
            if standardized_pred and standardized_pred not in seen_original:
                original_order_predictions.append(standardized_pred)
                seen_original.add(standardized_pred)
                print(f"  Rank {i+1}: {standardized_pred}")

        # Check original ranking matches
        # Slices shorter than k are fine: top-10 simply inspects every candidate
        # when fewer than ten were generated.
        if true_standardized_products_set:
            if original_order_predictions and robust_smiles_match(original_order_predictions[0], true_standardized_products_set):
                original_ranking_matches["top1"] = True
            if any(robust_smiles_match(p, true_standardized_products_set) for p in original_order_predictions[:3]):
                original_ranking_matches["top3"] = True
            if any(robust_smiles_match(p, true_standardized_products_set) for p in original_order_predictions[:5]):
                original_ranking_matches["top5"] = True
            if any(robust_smiles_match(p, true_standardized_products_set) for p in original_order_predictions[:10]):
                original_ranking_matches["top10"] = True

        # 4. Now calculate comprehensive scores
        # Unlike the list above, every raw candidate gets a row here, duplicates included.
        for i, pred_smiles_raw in enumerate(raw_predicted_smiles_batch):
            pred_smiles_standardized = standardize_generated_smiles(pred_smiles_raw)

            # Model score: higher rank (lower index) gets higher score
            # Linear in the rank: 1.0 for the first candidate, decreasing by
            # 1/NUM_RETURN_SEQUENCES per position.
            model_score = 1.0 - (i / NUM_RETURN_SEQUENCES)

            # Validity score
            # Scored on the raw decoded text, not on the standardized form.
            validity_score = score_chemical_validity(pred_smiles_raw)

            # Balance score
            # Stays 0.0 when standardization produced an empty string.
            balance_score = 0.0
            if pred_smiles_standardized:
                balance_result = check_atom_balance(INPUT_REACTANT_SMILES, pred_smiles_standardized)
                balance_score = balance_result['balance_score']

            # Comprehensive score
            comprehensive_score = (model_score * MODEL_SCORE_WEIGHT +
                                 validity_score * VALIDITY_SCORE_WEIGHT +
                                 balance_score * BALANCE_SCORE_WEIGHT)

            matches_ground_truth = False
            matching_true_product_for_row = "N/A"
            max_tanimoto_to_true = 0.0

            if true_standardized_products_set:
                # The per-row flag uses exact string membership only; the
                # fingerprint fallback of robust_smiles_match is not applied here.
                if pred_smiles_standardized in true_standardized_products_set:
                    matches_ground_truth = True
                    matching_true_product_for_row = pred_smiles_standardized

                # Informational: best Tanimoto similarity to any true product.
                if RDKIT_AVAILABLE and pred_smiles_standardized:
                    try:
                        pred_mol = Chem.MolFromSmiles(pred_smiles_standardized)
                        if pred_mol:
                            pred_fp = GetMorganFingerprintAsBitVect(pred_mol, 2, nBits=2048)
                            for true_std_smi in true_standardized_products_set:
                                true_mol = Chem.MolFromSmiles(true_std_smi)
                                if true_mol:
                                    true_fp = GetMorganFingerprintAsBitVect(true_mol, 2, nBits=2048)
                                    similarity = TanimotoSimilarity(pred_fp, true_fp)
                                    max_tanimoto_to_true = max(max_tanimoto_to_true, similarity)
                    except Exception:
                        # -1.0 is a sentinel meaning "could not be computed";
                        # it is turned into "N/A" when the row is stored.
                        max_tanimoto_to_true = -1.0

            predictions_data.append({
                "reactant": INPUT_REACTANT_SMILES,
                "predicted_product_raw": pred_smiles_raw,
                "predicted_product": pred_smiles_standardized,
                "model_score": model_score,
                "validity_score": validity_score,
                "balance_score": balance_score,
                "comprehensive_score": comprehensive_score,
                "max_tanimoto_to_true": max_tanimoto_to_true if max_tanimoto_to_true >= 0 else "N/A",
                "original_rank": i + 1,
                "matches_ground_truth": matches_ground_truth,
                "matching_true_product": matching_true_product_for_row,
                "all_true_products_for_reactant": all_true_products_str
            })

        # Sort by comprehensive score
        # list.sort is stable, so candidates with equal scores keep their
        # original model order.
        predictions_data.sort(key=lambda x: x["comprehensive_score"], reverse=True)

        print("\n--- Comprehensive Score Reranking ---")

        # Evaluate reranked results
        # reranked_position is assigned to every row; the de-duplicated list
        # below is what the top-k checks use.
        reranked_predictions = []
        seen_reranked = set()
        for rerank_idx, item in enumerate(predictions_data):
            item["reranked_position"] = rerank_idx + 1

            if item['predicted_product'] and item['predicted_product'] not in seen_reranked:
                reranked_predictions.append(item['predicted_product'])
                seen_reranked.add(item['predicted_product'])

            print(f"  Reranked {item['reranked_position']} (Orig Rank {item['original_rank']}):")
            print(f"    Prediction: {item['predicted_product']}")
            print(f"    Model Score: {item['model_score']:.3f}, Validity: {item['validity_score']:.3f}, Balance: {item['balance_score']:.3f}")
            print(f"    Comprehensive Score: {item['comprehensive_score']:.4f}")
            print(f"    Matches GT: {item['matches_ground_truth']}")
            # The value is the string "N/A" when similarity could not be computed.
            if item['max_tanimoto_to_true'] != "N/A":
                print(f"    Max Tanimoto: {item['max_tanimoto_to_true']:.3f}")

        # Check reranked matches
        if true_standardized_products_set:
            if reranked_predictions and robust_smiles_match(reranked_predictions[0], true_standardized_products_set):
                reranked_matches["top1"] = True
            if any(robust_smiles_match(p, true_standardized_products_set) for p in reranked_predictions[:3]):
                reranked_matches["top3"] = True
            if any(robust_smiles_match(p, true_standardized_products_set) for p in reranked_predictions[:5]):
                reranked_matches["top5"] = True
            if any(robust_smiles_match(p, true_standardized_products_set) for p in reranked_predictions[:10]):
                reranked_matches["top10"] = True

    except Exception as e_gen:
        print(f"Error during generation: {e_gen}"); traceback.print_exc()
        # Fill with error rows
        # These placeholder rows are appended after any rows that were already
        # collected before the exception; predictions_data is not cleared.
        for rank_idx in range(NUM_RETURN_SEQUENCES):
             predictions_data.append({
                "reactant": INPUT_REACTANT_SMILES,
                "predicted_product_raw": "ERROR",
                "predicted_product": "ERROR",
                "model_score": 0.0,
                "validity_score": 0.0,
                "balance_score": 0.0,
                "comprehensive_score": 0.0,
                "max_tanimoto_to_true": "N/A",
                "original_rank": rank_idx + 1,
                "reranked_position": rank_idx + 1,
                "matches_ground_truth": False,
                "matching_true_product": "N/A",
                "all_true_products_for_reactant": all_true_products_str
            })

    # Print comparison results
    if true_standardized_products_set:
        print(f"\n--- Accuracy Comparison ---")
        print(f"True Products: {all_true_products_str}")
        print(f"\nOriginal Model Ranking:")
        print(f"  Top-1: {'✓' if original_ranking_matches['top1'] else '✗'}")
        print(f"  Top-3: {'✓' if original_ranking_matches['top3'] else '✗'}")
        print(f"  Top-5: {'✓' if original_ranking_matches['top5'] else '✗'}")
        print(f"  Top-10: {'✓' if original_ranking_matches['top10'] else '✗'}")

        print(f"\nComprehensive Score Reranking:")
        print(f"  Top-1: {'✓' if reranked_matches['top1'] else '✗'}")
        print(f"  Top-3: {'✓' if reranked_matches['top3'] else '✗'}")
        print(f"  Top-5: {'✓' if reranked_matches['top5'] else '✗'}")
        print(f"  Top-10: {'✓' if reranked_matches['top10'] else '✗'}")

        print(f"\nImprovement:")
        # "+" = reranking gained a hit, "-" = reranking lost one, "=" = unchanged.
        for k in ["top1", "top3", "top5", "top10"]:
            improvement = "+" if reranked_matches[k] and not original_ranking_matches[k] else "-" if original_ranking_matches[k] and not reranked_matches[k] else "="
            print(f"  {k.upper()}: {improvement}")
    else:
        print(f"\n--- No Ground Truth Available for Comparison ---")

    # 5. Save to CSV
    if predictions_data:
        column_order = [
            "reactant", "predicted_product",
            "model_score", "validity_score", "balance_score", "comprehensive_score",
            "max_tanimoto_to_true", "original_rank", "reranked_position",
            "matches_ground_truth", "predicted_product_raw", "matching_true_product",
            "all_true_products_for_reactant"
        ]
        df_results = pd.DataFrame(predictions_data)

        # Ensure all desired columns are present
        # A column missing from every row is created and filled with "N/A".
        for col in column_order:
            if col not in df_results.columns:
                df_results[col] = "N/A"

        df_results = df_results[column_order]

        try:
            df_results.to_csv(OUTPUT_CSV_PATH_SINGLE_FORMATTED, index=False, float_format='%.4f')
            print(f"\nDetailed predictions saved to {OUTPUT_CSV_PATH_SINGLE_FORMATTED}")
        except Exception as e_csv:
            print(f"Error saving to CSV: {e_csv}")
    else:
        print("No results generated.")

    # Drop references to the model and tensors, then let Python and (on CUDA)
    # the PyTorch caching allocator release what they can.
    del model, base_model, tokenizer, inputs, input_ids, attention_mask
    # `outputs` only exists if generate() returned successfully.
    if 'outputs' in locals(): del outputs
    gc.collect()
    if DEVICE.type == 'cuda': torch.cuda.empty_cache()

# Run the single-reactant prediction when this file is executed as a script
# (or as a notebook cell, where __name__ is also "__main__").
if __name__ == "__main__":
    predict_for_single_reactant_formatted()

RDKit found. Advanced SMILES processing, similarity, and property calculation will be used.
Using device: cuda
Loading tokenizer from ./best_model_multi_eval_v3_correct_loss/...
Loading base model 'google/flan-t5-base' and adapter from './best_model_multi_eval_v3_correct_loss/'...
Model loaded successfully.
Loaded true products map for potential matching.

--- Predicting for Reactant: C=CC(C)C.[OH] ---

--- Original Model Ranking ---
  Rank 1: [CH2]C(O)C(C)C
  Rank 2: CC(C)[CH]CO
  Rank 3: C=CC(C)C.O
  Rank 4: )(C)C.O
  Rank 5: CC[CH]C(C)C.O

--- Comprehensive Score Reranking ---
  Reranked 1 (Orig Rank 1):
    Prediction: [CH2]C(O)C(C)C
    Model Score: 1.000, Validity: 1.000, Balance: 1.000
    Comprehensive Score: 1.0000
    Matches GT: True
    Max Tanimoto: 1.000
  Reranked 2 (Orig Rank 2):
    Prediction: CC(C)[CH]CO
    Model Score: 0.800, Validity: 1.000, Balance: 1.000
    Comprehensive Score: 0.9200
    Matches GT: True
    Max Tanimoto: 1.000
  Reranked 3 (Orig Rank 3):
    