In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem, RDLogger
from rdkit.Chem.Draw import rdMolDraw2D
from captum.attr import IntegratedGradients
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_from_disk
from tqdm import tqdm
import traceback
import os
from typing import List, Dict, Tuple, Set, Optional
import warnings
from PIL import Image, ImageDraw, ImageFont
import io
import gc
from collections import defaultdict
import re
warnings.filterwarnings('ignore')
RDLogger.DisableLog('rdApp.*')

# 检查版本兼容性
try:
    from transformers.cache_utils import EncoderDecoderCache
    TRANSFORMERS_NEW_CACHE = True
except ImportError:
    TRANSFORMERS_NEW_CACHE = False

try:
    from rdkit import Chem
    RDKIT_AVAILABLE = True
except ImportError:
    RDKIT_AVAILABLE = False
    print("Warning: RDKit not available.")

# 统一配色方案
COLORS = {
    'primary': '#2E86AB', 'light_blue': '#A8DADC', 'dark_blue': '#1D3557',
    'background': '#F1FAEE', 'success': '#00B04F', 'error': '#EB5757',
    'correct_highlight': '#FF6B35', 'neutral': '#828282'
}

# 设置绘图风格
plt.rcParams.update({
    'font.family': 'Times New Roman', 'font.size': 16, 'axes.linewidth': 1.5,
    'figure.dpi': 600, 'savefig.dpi': 600, 'axes.grid': False,
    'axes.titlesize': 18, 'axes.labelsize': 16
})

def canonicalize_smiles_rdkit(smiles: str, sanitize=True) -> str:
    if not RDKIT_AVAILABLE or not smiles:
        return str(smiles).strip() if smiles else ""
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=sanitize)
        if mol:
            result = Chem.MolToSmiles(mol, canonical=True)
            del mol
            return result
        return smiles.strip()
    except Exception:
        return smiles.strip()

def group_data_mappings(dataset) -> Tuple[dict, dict]:
    """同时创建产品和反应类型映射"""
    products_map = {}
    types_map = {}
    for example in tqdm(dataset, desc="Grouping data"):
        reactant = str(example.get('reactant', '')).strip()
        if reactant:
            product = str(example.get('product', '')).strip()
            if product:
                if reactant not in products_map:
                    products_map[reactant] = set()
                products_map[reactant].add(product)
            types_map[reactant] = str(example.get('type', 'Unknown')).strip()
    return products_map, types_map

def clear_memory():
    """统一的内存清理函数"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()

def get_font(size: int):
    """简化的字体获取"""
    paths = ["/System/Library/Fonts/Times.ttc", "C:/Windows/Fonts/times.ttf",
             "/usr/share/fonts/truetype/liberation/LiberationSerif-Regular.ttf"]
    for path in paths:
        try:
            return ImageFont.truetype(path, size)
        except:
            continue
    return ImageFont.load_default()

def format_smiles_multiline(smiles: str, max_chars: int = 70) -> str:
    """格式化SMILES为多行"""
    if len(smiles) <= max_chars:
        return smiles
    first = smiles[:max_chars]
    second = smiles[max_chars:]
    if len(second) > max_chars:
        second = second[:max_chars-3] + "..."
    return f"{first}\n{second}"

class ChemicalReactionExplainer:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

        # 设置模型配置
        for attr in ['bos_token_id', 'eos_token_id', 'pad_token_id']:
            setattr(self.model.config, attr, getattr(tokenizer, attr))
        self.model.config.decoder_start_token_id = tokenizer.pad_token_id

        self.output_dir = "./explainability_results"
        os.makedirs(self.output_dir, exist_ok=True)

        # 预计算特殊token（移除"."，因为它需要在global attribution中显示）
        self.special_tokens = {
            tokenizer.pad_token, tokenizer.eos_token, tokenizer.bos_token,
            tokenizer.unk_token, '<pad>', '</s>', '<s>', '<unk>'
        }

    def sanitize_smiles(self, smiles: str) -> Tuple[Optional[str], bool]:
        if not smiles or smiles.strip() == "":
            return None, False
        try:
            canonical = canonicalize_smiles_rdkit(smiles.strip())
            if canonical:
                mol = Chem.MolFromSmiles(canonical)
                if mol is not None:
                    del mol
                    return canonical, True
            return None, False
        except Exception:
            return None, False

    def beam_search_predict(self, reactant_smiles: str, num_beams: int = 3, max_length: int = 256) -> List[Dict]:
        inputs = self.tokenizer(reactant_smiles, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        generate_kwargs = {
            'num_beams': 30, 'num_return_sequences': num_beams, 'max_length': max_length,
            'do_sample': False, 'return_dict_in_generate': True, 'output_scores': True,
            'early_stopping': True, 'use_cache': False,
            'eos_token_id': self.model.config.eos_token_id,
            'pad_token_id': self.model.config.pad_token_id,
            'decoder_start_token_id': self.model.config.decoder_start_token_id
        }

        if not TRANSFORMERS_NEW_CACHE:
            generate_kwargs['past_key_values'] = None

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generate_kwargs)

        predictions = []
        for i in range(min(num_beams, len(outputs.sequences))):
            pred_text = self.tokenizer.decode(outputs.sequences[i], skip_special_tokens=True)
            sanitized, is_valid = self.sanitize_smiles(pred_text)
            predictions.append({
                'raw_smiles': pred_text, 'sanitized_smiles': sanitized,
                'is_valid': is_valid, 'rank': i + 1
            })

        del outputs, inputs
        clear_memory()
        return predictions

    def filter_tokens(self, tokens: List[str]) -> List[str]:
        """过滤特殊token，保留"."用于global attribution显示"""
        return [token for token in tokens if token not in self.special_tokens and token is not None]

    def parse_token_to_mol(self, token: str):
        """解析token为分子对象"""
        if not token or token in self.special_tokens or token == '.':
            return None

        # 尝试多种解析方法
        parse_attempts = [
            token,  # 直接解析
            re.sub(r'[0-9()=#+\-@/\\]', '', token),  # 移除特殊字符
            f'[{token}]' if len(token) <= 3 and token.isalpha() else None  # 单原子
        ]

        for attempt in parse_attempts:
            if attempt:
                try:
                    mol = Chem.MolFromSmiles(attempt)
                    if mol and mol.GetNumAtoms() > 0:
                        return mol
                except:
                    continue
        return None

    def map_tokens_to_atoms(self, smiles: str, tokens: List[str], attributions: np.ndarray) -> Dict[int, float]:
        """映射token到原子"""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return {}

        atom_contributions = {}
        assigned_atoms = set()

        # 过滤并配对token和attribution，但排除"."token
        filtered_items = [(token, attr) for token, attr in zip(tokens, attributions)
                         if token not in self.special_tokens and token != '.']

        print(f"\n{'='*50}")
        print(f"TOKEN TO ATOM MAPPING")
        print(f"{'='*50}")
        print(f"SMILES: {smiles}")
        print(f"Number of atoms in molecule: {mol.GetNumAtoms()}")
        print(f"Atom symbols: {[mol.GetAtoms()[i].GetSymbol() for i in range(mol.GetNumAtoms())]}")
        print(f"\nFiltered tokens and their attribution scores:")
        for i, (token, attr) in enumerate(filtered_items):
            print(f"  Token[{i}]: '{token}' -> Attribution: {attr:.6f}")

        # 第一轮：子结构匹配
        print(f"\n--- ROUND 1: Substructure Matching ---")
        successful_matches = set()
        for i, (token, attribution) in enumerate(filtered_items):
            token_mol = self.parse_token_to_mol(token)
            if token_mol is not None:
                try:
                    matches = mol.GetSubstructMatches(token_mol)
                    print(f"Token '{token}' (attr: {attribution:.6f}) -> Found {len(matches)} potential matches: {matches}")

                    for match in matches:
                        if all(atom_idx not in assigned_atoms for atom_idx in match):
                            for atom_idx in match:
                                atom_contributions[atom_idx] = attribution
                                assigned_atoms.add(atom_idx)
                            successful_matches.add(i)
                            atom_symbols = [mol.GetAtoms()[idx].GetSymbol() for idx in match]
                            print(f"  ✓ MATCHED: Token '{token}' -> Atoms {match} (symbols: {atom_symbols}) with attribution {attribution:.6f}")
                            break
                    else:
                        print(f"  ✗ No unassigned atoms available for token '{token}'")
                    del token_mol
                except Exception as e:
                    print(f"  ✗ Error matching '{token}': {e}")
                    if token_mol:
                        del token_mol
            else:
                print(f"Token '{token}' could not be parsed to molecule")

        # 第二轮：顺序分配未匹配的token
        unmatched_items = [filtered_items[i] for i in range(len(filtered_items)) if i not in successful_matches]

        if unmatched_items:
            print(f"\n--- ROUND 2: Sequential Assignment ---")
            print(f"Unmatched tokens: {[item[0] for item in unmatched_items]}")

            next_atom = 0
            for token, attribution in unmatched_items:
                # 跳到下一个未分配的原子
                while next_atom in assigned_atoms and next_atom < mol.GetNumAtoms():
                    next_atom += 1

                token_mol = self.parse_token_to_mol(token)
                atom_count = token_mol.GetNumAtoms() if token_mol else 1

                assigned_for_token = []
                for _ in range(atom_count):
                    if next_atom < mol.GetNumAtoms():
                        atom_contributions[next_atom] = attribution
                        assigned_atoms.add(next_atom)
                        assigned_for_token.append(next_atom)
                        next_atom += 1
                        while next_atom in assigned_atoms and next_atom < mol.GetNumAtoms():
                            next_atom += 1

                if assigned_for_token:
                    atom_symbols = [mol.GetAtoms()[idx].GetSymbol() for idx in assigned_for_token]
                    print(f"  ✓ SEQUENTIAL: Token '{token}' -> Atoms {assigned_for_token} (symbols: {atom_symbols}) with attribution {attribution:.6f}")
                else:
                    print(f"  ✗ No atoms available for token '{token}'")

                if token_mol:
                    del token_mol

        # 打印最终的原子贡献
        print(f"\n--- FINAL ATOM CONTRIBUTIONS ---")
        print(f"Total assigned atoms: {len(assigned_atoms)} / {mol.GetNumAtoms()}")

        if atom_contributions:
            print("Atom contributions:")
            for atom_idx in range(mol.GetNumAtoms()):
                contribution = atom_contributions.get(atom_idx, 0.0)
                atom_symbol = mol.GetAtoms()[atom_idx].GetSymbol()
                status = "✓" if atom_idx in assigned_atoms else "✗"
                print(f"  Atom[{atom_idx}] ({atom_symbol}): {contribution:.6f} {status}")
        else:
            print("No atom contributions found!")

        unassigned = set(range(mol.GetNumAtoms())) - assigned_atoms
        if unassigned:
            unassigned_symbols = [mol.GetAtoms()[idx].GetSymbol() for idx in unassigned]
            print(f"\n⚠️  UNASSIGNED ATOMS: {sorted(unassigned)} (symbols: {unassigned_symbols})")

        print(f"{'='*50}\n")

        del mol
        return atom_contributions

    def compute_attribution_scores(self, input_text: str, target_text: str) -> Tuple[np.ndarray, List[str], List[str]]:
        """计算归因分数"""
        print("Computing attributions...")

        # 编码输入
        input_encoding = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=256)
        target_encoding = self.tokenizer(target_text, return_tensors="pt", padding=True, truncation=True, max_length=256)

        input_ids = input_encoding['input_ids'].to(self.device)
        target_ids = target_encoding['input_ids'].to(self.device)

        input_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
        target_tokens = self.tokenizer.convert_ids_to_tokens(target_ids[0])

        # 准备decoder输入
        decoder_input_ids = target_ids.clone()
        if decoder_input_ids[0, 0] != self.tokenizer.bos_token_id:
            decoder_input_ids = torch.cat([
                torch.tensor([[self.tokenizer.bos_token_id]], device=self.device), decoder_input_ids
            ], dim=1)

        # 获取embeddings
        input_embeds = self.model.get_encoder().embed_tokens(input_ids)
        baseline_embeds = self.model.get_encoder().embed_tokens(
            torch.full_like(input_ids, self.tokenizer.pad_token_id)
        )

        max_target_len = min(len(target_tokens), decoder_input_ids.shape[1] - 1)
        max_input_len = min(len(input_tokens), input_embeds.shape[1])

        # 计算归因矩阵
        attribution_matrix = []
        for target_pos in range(max_target_len):
            def forward_func(input_embeds_var):
                decoder_embeds = self.model.get_decoder().embed_tokens(decoder_input_ids)
                outputs = self.model(
                    inputs_embeds=input_embeds_var,
                    decoder_inputs_embeds=decoder_embeds,
                    return_dict=True, use_cache=False
                )
                if target_pos < outputs.logits.shape[1]:
                    return outputs.logits[:, target_pos, :].max(dim=-1)[0]
                return torch.zeros(outputs.logits.shape[0], device=outputs.logits.device)

            try:
                ig = IntegratedGradients(forward_func)
                attributions = ig.attribute(
                    inputs=input_embeds, baselines=baseline_embeds,
                    n_steps=50, internal_batch_size=1
                )
                token_attributions = attributions.sum(dim=-1).squeeze(0)[:max_input_len]
                attribution_matrix.append(token_attributions.detach().cpu().numpy())
                del attributions, token_attributions
            except Exception as e:
                print(f"Warning: Failed at position {target_pos}: {e}")
                attribution_matrix.append(np.zeros(max_input_len))
            clear_memory()

        attribution_matrix = np.array(attribution_matrix)
        attribution_matrix = np.abs(attribution_matrix)

        # 零化'.'token的归因
        dot_positions = [i for i, token in enumerate(input_tokens[:max_input_len]) if token == '.']
        if dot_positions:
            attribution_matrix[:, dot_positions] = 0.0

        # 归一化
        if attribution_matrix.max() > 0:
            attribution_matrix = attribution_matrix / attribution_matrix.max()

        # 清理
        del input_encoding, target_encoding, input_ids, target_ids, input_embeds, baseline_embeds, decoder_input_ids
        clear_memory()

        return attribution_matrix, input_tokens, target_tokens

    def create_sample_directory(self, sample_name: str) -> str:
        """创建样本目录，使用reactant名称而非序号"""
        # 将SMILES中的特殊字符替换为下划线
        safe_name = re.sub(r'[<>:"/\\|?*]', '_', sample_name)[:50]  # 限制长度
        sample_dir = os.path.join(self.output_dir, f"reactant_{safe_name}")
        os.makedirs(sample_dir, exist_ok=True)
        return sample_dir

    def create_visualizations_from_attributions(self, attributions: np.ndarray, input_tokens: List[str],
                                              target_tokens: List[str], reactant_smiles: str) -> np.ndarray:
        """创建可视化图表"""
        # 过滤token
        input_display = self.filter_tokens(input_tokens)
        target_display = self.filter_tokens(target_tokens)

        input_mask = [i for i, token in enumerate(input_tokens) if token in input_display]
        target_mask = [i for i, token in enumerate(target_tokens) if token in target_display]

        if input_mask and target_mask:
            attr_display = attributions[np.ix_(target_mask, input_mask)]
        else:
            attr_display = attributions

        sample_dir = self.create_sample_directory(reactant_smiles)

        # 1. Token Attribution Heatmap
        self._create_heatmap(attr_display, input_display, target_display, sample_dir)

        # 2. Global Attribution Plot
        global_attr = self._create_global_plot(attributions, input_tokens, input_display, input_mask, sample_dir)

        return global_attr

    def _create_heatmap(self, attributions: np.ndarray, input_tokens: List[str],
                       target_tokens: List[str], sample_dir: str):
        """创建热力图"""
        fig, ax = plt.subplots(figsize=(max(16, len(input_tokens) * 0.8),
                                       max(12, len(target_tokens) * 0.8)))

        # 归一化
        attr_norm = attributions
        if attributions.max() > attributions.min():
            attr_norm = (attributions - attributions.min()) / (attributions.max() - attributions.min())

        # 创建colormap
        from matplotlib.colors import LinearSegmentedColormap
        colors = ['#F1FAEE', COLORS['light_blue'], COLORS['primary'], COLORS['dark_blue']]
        cmap = LinearSegmentedColormap.from_list('custom_blues', colors, N=100)

        im = ax.imshow(attr_norm, cmap=cmap, aspect='auto', interpolation='nearest')

        # 设置标签
        ax.set_xticks(range(len(input_tokens)))
        ax.set_xticklabels(input_tokens, rotation=45, ha='right', fontsize=35, color='black')
        ax.set_yticks(range(len(target_tokens)))
        ax.set_yticklabels(target_tokens, fontsize=35, color='black')
        ax.tick_params(axis='both', which='major', labelsize=35, length=8, width=1.5)

        # 添加数值
        for i in range(len(target_tokens)):
            for j in range(len(input_tokens)):
                if j < attributions.shape[1] and i < attributions.shape[0]:
                    text_color = 'white' if attr_norm[i, j] > 0.6 else COLORS['dark_blue']
                    ax.text(j, i, f'{attributions[i, j]:.2f}',
                           ha="center", va="center", color=text_color, fontsize=45, fontweight='bold')

        ax.set_xlabel('Input Tokens (Reactant)', fontweight='bold', fontsize=36, color=COLORS['dark_blue'])
        ax.set_ylabel('Output Tokens (Product)', fontweight='bold', fontsize=36, color=COLORS['dark_blue'])
        ax.set_title('Token-to-Token Attribution Heatmap',
                    fontweight='bold', pad=24, fontsize=36, color=COLORS['dark_blue'])

        cbar = plt.colorbar(im, ax=ax, shrink=0.8)
        cbar.set_label('Attribution Score', fontweight='bold', fontsize=36, color=COLORS['dark_blue'])
        cbar.ax.tick_params(labelsize=26, width=1.5, length=6, colors=COLORS['dark_blue'])

        plt.tight_layout()
        plt.savefig(os.path.join(sample_dir, "token_attribution_heatmap.png"), dpi=600, bbox_inches='tight', facecolor='white')
        plt.close(fig)
        del fig, ax, attr_norm, im
        clear_memory()

    def _create_global_plot(self, attributions: np.ndarray, input_tokens: List[str],
                           input_display: List[str], input_mask: List[int], sample_dir: str) -> np.ndarray:
        """创建全局归因图，包含"."token并设置其score为0"""
        global_attr_full = np.mean(np.abs(attributions), axis=0)

        # 零化'.'token的归因
        dot_positions = [i for i, token in enumerate(input_tokens) if token == '.']
        if dot_positions:
            global_attr_full[dot_positions] = 0.0

        # 创建包含"."token的显示列表
        input_display_with_dots = []
        global_attr_with_dots = []

        for i, token in enumerate(input_tokens):
            if token == '.' or token in input_display:
                input_display_with_dots.append(token)
                if i < len(global_attr_full):
                    if token == '.':
                        global_attr_with_dots.append(0.0)  # 确保"."的score为0
                    else:
                        global_attr_with_dots.append(global_attr_full[i])
                else:
                    global_attr_with_dots.append(0.0)

        global_attr_with_dots = np.array(global_attr_with_dots)

        print(f"\n--- GLOBAL ATTRIBUTION SCORES (INCLUDING DOTS) ---")
        print("Input tokens and their global attribution scores:")
        for i, (token, score) in enumerate(zip(input_display_with_dots, global_attr_with_dots)):
            print(f"  Token[{i}]: '{token}' -> Global Attribution: {score:.6f}")
        print(f"{'='*40}\n")

        fig, ax = plt.subplots(figsize=(max(14, len(input_display_with_dots) * 0.9), 7))

        if len(global_attr_with_dots) > 0 and global_attr_with_dots.max() > 0:
            normalized_values = global_attr_with_dots / global_attr_with_dots.max()
            bar_colors = []
            for i, (token, val) in enumerate(zip(input_display_with_dots, normalized_values)):
                if token == '.':
                    bar_colors.append(COLORS['neutral'])  # "."用灰色
                else:
                    bar_colors.append(plt.cm.Blues(0.3 + 0.7 * val))
        else:
            bar_colors = []
            for token in input_display_with_dots:
                if token == '.':
                    bar_colors.append(COLORS['neutral'])
                else:
                    bar_colors.append(COLORS['light_blue'])

        bars = ax.bar(range(len(input_display_with_dots)), global_attr_with_dots,
                     color=bar_colors, alpha=0.9, edgecolor=COLORS['dark_blue'], linewidth=1.5)

        # 添加数值标签
        for bar, attr, token in zip(bars, global_attr_with_dots, input_display_with_dots):
            height = bar.get_height()
            if token == '.':
                # 为"."添加特殊标记
                ax.text(bar.get_x() + bar.get_width()/2., height + max(height * 0.03, 0.01),
                       '0.000\n(separator)', ha='center', va='bottom', fontweight='bold',
                       fontsize=16, color=COLORS['neutral'])
            else:
                ax.text(bar.get_x() + bar.get_width()/2., height + max(height * 0.03, 0.01),
                       f'{attr:.3f}', ha='center', va='bottom', fontweight='bold',
                       fontsize=18, color=COLORS['dark_blue'])

        ax.set_xlabel('Input Tokens (Reactant)', fontweight='bold', fontsize=28, color=COLORS['dark_blue'])
        ax.set_ylabel('Global Attribution Score', fontweight='bold', fontsize=28, color=COLORS['dark_blue'])
        ax.set_title('Global Token Attribution (Including Separators)', fontweight='bold', pad=24, fontsize=30, color=COLORS['dark_blue'])
        ax.set_xticks(range(len(input_display_with_dots)))
        ax.set_xticklabels(input_display_with_dots, rotation=45, ha='right', fontsize=26, color='black')
        ax.tick_params(axis='y', labelsize=20, width=1.5, length=8, colors=COLORS['dark_blue'])

        for spine in ['bottom', 'left']:
            ax.spines[spine].set_color(COLORS['dark_blue'])
        for spine in ['top', 'right']:
            ax.spines[spine].set_visible(False)

        plt.tight_layout()
        plt.savefig(os.path.join(sample_dir, "global_attribution.png"), dpi=600, bbox_inches='tight', facecolor='white')
        plt.close(fig)
        del fig, ax, bars
        clear_memory()

        return global_attr_full

    def create_molecule_image(self, smiles: str, size=(700, 500)) -> Image.Image:
        """创建分子图像"""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return self._create_error_image(smiles, size)

        try:
            drawer = rdMolDraw2D.MolDraw2DCairo(size[0], size[1])
            drawer.SetFontSize(12)
            drawer.drawOptions().addAtomIndices = False
            drawer.DrawMolecule(mol)
            drawer.FinishDrawing()
            png_data = drawer.GetDrawingText()
            img = Image.open(io.BytesIO(png_data))
            del mol, drawer, png_data
            return img
        except Exception as e:
            print(f"Warning: Could not create molecule image: {e}")
            if mol:
                del mol
            return self._create_error_image(smiles, size)

    def _create_error_image(self, smiles: str, size=(600, 400)) -> Image.Image:
        """创建错误SMILES图像"""
        img = Image.new('RGB', size, 'white')
        draw = ImageDraw.Draw(img)

        # 绘制边框
        draw.rectangle([1, 1, size[0]-2, size[1]-2], outline=COLORS['error'], width=3)

        # 添加文字
        center_x, center_y = size[0] // 2, size[1] // 2

        font_large = get_font(28)
        title = "INVALID SMILES"
        title_bbox = draw.textbbox((0, 0), title, font=font_large)
        title_width = title_bbox[2] - title_bbox[0]
        draw.text((center_x - title_width//2, center_y - 80), title, fill=COLORS['error'], font=font_large)

        font_small = get_font(24)
        formatted_smiles = format_smiles_multiline(smiles, 30)
        lines = formatted_smiles.split('\n')

        start_y = center_y - len(lines) * 15 + 30
        for i, line in enumerate(lines):
            line_bbox = draw.textbbox((0, 0), line, font=font_small)
            line_width = line_bbox[2] - line_bbox[0]
            draw.text((center_x - line_width//2, start_y + i * 30), line, fill=COLORS['dark_blue'], font=font_small)

        return img

    def create_highlighted_molecule_image(self, smiles: str, global_attributions: np.ndarray,
                                        input_tokens: List[str], size=(800, 600)) -> Image.Image:
        """创建高亮分子图像"""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return self._create_error_image(smiles, size)

        try:
            # 过滤掉"."token，因为它不参与分子结构映射
            input_display = self.filter_tokens(input_tokens)

            # 创建原始token到过滤后token的映射（排除"."）
            token_to_filtered_map = {}
            filtered_index = 0
            valid_attributions = []

            for i, token in enumerate(input_tokens):
                if token in input_display:  # 已经排除了"."
                    token_to_filtered_map[i] = filtered_index
                    if i < len(global_attributions):
                        valid_attributions.append(global_attributions[i])
                    else:
                        valid_attributions.append(0.0)
                    filtered_index += 1

            valid_attributions = np.array(valid_attributions)

            # 映射token到原子
            atom_contributions = self.map_tokens_to_atoms(smiles, input_display, valid_attributions)

            # 创建原子高亮和化学键高亮
            atom_highlights = {}
            bond_highlights = {}
            num_atoms = mol.GetNumAtoms()

            print(f"\n--- ATOM AND BOND HIGHLIGHTING DETAILS ---")
            if atom_contributions:
                max_contribution = max(atom_contributions.values()) or 1.0
                print(f"Maximum contribution value: {max_contribution:.6f}")
                print(f"Atom highlighting details:")

                # 计算原子颜色
                for atom_idx in range(num_atoms):
                    contribution = atom_contributions.get(atom_idx, 0.0)
                    normalized_contribution = contribution / max_contribution
                    intensity = 0.2 + 0.8 * normalized_contribution

                    from matplotlib.colors import LinearSegmentedColormap
                    cmap = LinearSegmentedColormap.from_list('custom_blues',
                                                           [COLORS['background'], COLORS['dark_blue']])
                    rgba = cmap(intensity)
                    rgb = tuple(rgba[:3])
                    atom_highlights[atom_idx] = rgb

                    atom_symbol = mol.GetAtoms()[atom_idx].GetSymbol()
                    print(f"  Atom[{atom_idx}] ({atom_symbol}): contribution={contribution:.6f}, "
                          f"normalized={normalized_contribution:.3f}, intensity={intensity:.3f}, "
                          f"RGB=({rgb[0]:.3f}, {rgb[1]:.3f}, {rgb[2]:.3f})")

                # 计算化学键颜色（取连接原子的平均归因值）
                print(f"\nBond highlighting details:")
                bond_list = []
                for bond in mol.GetBonds():
                    bond_idx = bond.GetIdx()
                    atom1_idx = bond.GetBeginAtomIdx()
                    atom2_idx = bond.GetEndAtomIdx()

                    # 获取两个原子的归因值
                    contrib1 = atom_contributions.get(atom1_idx, 0.0)
                    contrib2 = atom_contributions.get(atom2_idx, 0.0)

                    # 计算平均值作为化学键的归因值
                    bond_contribution = (contrib1 + contrib2) / 2.0
                    normalized_bond_contribution = bond_contribution / max_contribution
                    bond_intensity = 0.2 + 0.8 * normalized_bond_contribution

                    rgba = cmap(bond_intensity)
                    rgb = tuple(rgba[:3])
                    bond_highlights[bond_idx] = rgb
                    bond_list.append(bond_idx)

                    atom1_symbol = mol.GetAtoms()[atom1_idx].GetSymbol()
                    atom2_symbol = mol.GetAtoms()[atom2_idx].GetSymbol()
                    print(f"  Bond[{bond_idx}] ({atom1_symbol}{atom1_idx}-{atom2_symbol}{atom2_idx}): "
                          f"atom1_contrib={contrib1:.6f}, atom2_contrib={contrib2:.6f}, "
                          f"avg_contrib={bond_contribution:.6f}, intensity={bond_intensity:.3f}, "
                          f"RGB=({rgb[0]:.3f}, {rgb[1]:.3f}, {rgb[2]:.3f})")

            else:
                print("No atom contributions found - using default highlighting")
                for atom_idx in range(num_atoms):
                    atom_highlights[atom_idx] = (0.8, 0.9, 1.0)

                bond_list = []
                for bond in mol.GetBonds():
                    bond_idx = bond.GetIdx()
                    bond_highlights[bond_idx] = (0.8, 0.9, 1.0)
                    bond_list.append(bond_idx)

            print(f"{'='*40}\n")

            # 绘制分子
            drawer = rdMolDraw2D.MolDraw2DCairo(size[0], size[1])
            drawer.SetFontSize(10)
            drawer.drawOptions().addAtomIndices = False

            # 同时高亮原子和化学键
            if atom_highlights and bond_highlights:
                drawer.DrawMolecule(mol,
                                  highlightAtoms=list(atom_highlights.keys()),
                                  highlightAtomColors=atom_highlights,
                                  highlightBonds=bond_list,
                                  highlightBondColors=bond_highlights)
            elif atom_highlights:
                drawer.DrawMolecule(mol, highlightAtoms=list(atom_highlights.keys()),
                                  highlightAtomColors=atom_highlights)
            else:
                drawer.DrawMolecule(mol)

            drawer.FinishDrawing()
            png_data = drawer.GetDrawingText()
            img = Image.open(io.BytesIO(png_data))

            # 添加统计信息
            new_img = Image.new('RGB', (size[0], size[1] + 60), 'white')
            new_img.paste(img, (0, 0))

            del mol, drawer, png_data, img
            return new_img

        except Exception as e:
            print(f"Warning: Could not create highlighted molecule: {e}")
            if mol:
                del mol
            return self.create_molecule_image(smiles, size)

    def _create_prediction_image(self, pred: Dict, true_products: Set[str], size=(550, 320)) -> Image.Image:
        """创建单个预测图像"""
        if pred['is_valid'] and pred['sanitized_smiles']:
            mol_img = self.create_molecule_image(pred['sanitized_smiles'], (550, 180))
            img = Image.new('RGB', size, 'white')
            img.paste(mol_img, (0, 0))

            draw = ImageDraw.Draw(img)
            font = get_font(26)
            font_small = get_font(22)

            is_correct = pred['sanitized_smiles'] in true_products
            status = f"Rank {pred['rank']}: {'CORRECT' if is_correct else 'Valid'}"
            color = COLORS['correct_highlight'] if is_correct else COLORS['primary']

            draw.text((8, 190), status, fill=color, font=font)

            formatted_smiles = format_smiles_multiline(pred['sanitized_smiles'], 35)
            for j, line in enumerate(formatted_smiles.split('\n')):
                text = f"{'SMILES: ' if j == 0 else '        '}{line}"
                draw.text((8, 220 + j * 25), text, fill=COLORS['dark_blue'], font=font_small)

            del mol_img
        else:
            error_img = self._create_error_image(pred['raw_smiles'], (550, 180))
            img = Image.new('RGB', size, 'white')
            img.paste(error_img, (0, 0))

            draw = ImageDraw.Draw(img)
            font = get_font(26)
            font_small = get_font(22)

            draw.text((8, 190), f"Rank {pred['rank']}: Invalid", fill=COLORS['error'], font=font)

            formatted_smiles = format_smiles_multiline(pred['raw_smiles'], 35)
            for j, line in enumerate(formatted_smiles.split('\n')):
                text = f"{'SMILES: ' if j == 0 else '        '}{line}"
                draw.text((8, 220 + j * 25), text, fill=COLORS['error'], font=font_small)

            del error_img

        return img

    def create_reaction_comparison_plot(self, reactant_smiles: str, predictions: List[Dict],
                                      true_products: Set[str], global_attributions: np.ndarray,
                                      input_tokens: List[str], reaction_type: str = "Unknown"):
        """创建反应对比图"""
        print(f"Creating comparison plot for reactant: {reactant_smiles[:50]}...")

        sample_dir = self.create_sample_directory(reactant_smiles)

        fig = plt.figure(figsize=(28, 16))
        gs = fig.add_gridspec(3, 3, height_ratios=[0.5, 4, 0.3], width_ratios=[1, 1, 1],
                             hspace=0.12, wspace=0.08, left=0.03, right=0.97, top=0.95, bottom=0.05)

        fig.suptitle('Chemical Reaction Prediction Analysis',
                    fontweight='bold', fontsize=32, y=0.98, color=COLORS['dark_blue'])

        # 第一列：Reactant
        ax_reactant_title = fig.add_subplot(gs[0, 0])
        formatted_reactant = format_smiles_multiline(reactant_smiles, 40)
        ax_reactant_title.text(0.5, 0.5, f'Input Reactant\n{formatted_reactant}',
                              ha='center', va='center', transform=ax_reactant_title.transAxes,
                              fontsize=22, fontweight='bold',
                              bbox=dict(boxstyle="round,pad=0.3", facecolor=COLORS['light_blue'], alpha=0.8))
        ax_reactant_title.axis('off')

        ax_reactant_mol = fig.add_subplot(gs[1, 0])
        highlighted_img = self.create_highlighted_molecule_image(
            reactant_smiles, global_attributions, input_tokens, size=(700, 560)
        )
        ax_reactant_mol.imshow(highlighted_img)
        ax_reactant_mol.set_title('Token-to-Token Highlighted Reactant\n(Darker Blue = Higher Attribution)',
                                 fontweight='bold', fontsize=26, pad=12, color=COLORS['dark_blue'])
        ax_reactant_mol.axis('off')

        # 第二列：Predictions
        ax_pred_title = fig.add_subplot(gs[0, 1])
        ax_pred_title.text(0.5, 0.5, 'Model Predictions (Top 3)', ha='center', va='center',
                          transform=ax_pred_title.transAxes, fontsize=22, fontweight='bold',
                          bbox=dict(boxstyle="round,pad=0.3", facecolor=COLORS['background'], alpha=0.8))
        ax_pred_title.axis('off')

        ax_pred_mol = fig.add_subplot(gs[1, 1])

        if predictions:
            pred_images = [self._create_prediction_image(pred, true_products) for pred in predictions[:3]]
            if pred_images:
                combined_img = self._combine_images_vertically(pred_images)
                ax_pred_mol.imshow(combined_img)
                del pred_images, combined_img
        else:
            ax_pred_mol.text(0.5, 0.5, 'No predictions available', ha='center', va='center',
                            transform=ax_pred_mol.transAxes, fontsize=20, fontweight='bold',
                            color=COLORS['error'])

        ax_pred_mol.set_title('Predicted Products', fontweight='bold', fontsize=26, pad=12, color=COLORS['dark_blue'])
        ax_pred_mol.axis('off')

        # 第三列：Ground Truth
        ax_true_title = fig.add_subplot(gs[0, 2])
        ax_true_title.text(0.5, 0.5, f'Ground Truth Products\n({len(true_products)} total)',
                          ha='center', va='center', transform=ax_true_title.transAxes,
                          fontsize=22, fontweight='bold',
                          bbox=dict(boxstyle="round,pad=0.3", facecolor=COLORS['success'], alpha=0.3))
        ax_true_title.axis('off')

        ax_true_mol = fig.add_subplot(gs[1, 2])

        if true_products:
            true_images = []
            for i, product in enumerate(sorted(list(true_products))[:3]):
                mol_img = self.create_molecule_image(product, (550, 180))
                img = Image.new('RGB', (550, 320), 'white')
                img.paste(mol_img, (0, 0))

                draw = ImageDraw.Draw(img)
                font = get_font(26)
                font_small = get_font(22)

                draw.text((8, 190), f"True Product {i+1}", fill=COLORS['success'], font=font)

                formatted_smiles = format_smiles_multiline(product, 35)
                for j, line in enumerate(formatted_smiles.split('\n')):
                    text = f"{'SMILES: ' if j == 0 else '        '}{line}"
                    draw.text((8, 220 + j * 25), text, fill=COLORS['dark_blue'], font=font_small)

                true_images.append(img)
                del mol_img

            if true_images:
                combined_img = self._combine_images_vertically(true_images)
                ax_true_mol.imshow(combined_img)
                del true_images, combined_img
        else:
            ax_true_mol.text(0.5, 0.5, 'No ground truth\nproducts available',
                            ha='center', va='center', transform=ax_true_mol.transAxes,
                            fontsize=20, fontweight='bold', color=COLORS['neutral'])

        ax_true_mol.set_title('Ground Truth Products', fontweight='bold', fontsize=26, pad=12, color=COLORS['dark_blue'])
        ax_true_mol.axis('off')

        # 底部统计
        ax_stats = fig.add_subplot(gs[2, :])
        valid_count = sum(1 for p in predictions if p['is_valid'])
        correct_count = sum(1 for p in predictions if p['is_valid'] and p['sanitized_smiles'] in true_products)

        stats_text = (f"Performance: Valid ({valid_count}/{len(predictions)}) | "
                     f"Correct ({correct_count}/{len(predictions)}) | "
                     f"Reaction Type: {reaction_type}")

        ax_stats.text(0.5, 0.5, stats_text, ha='center', va='center',
                     transform=ax_stats.transAxes, fontsize=24, fontweight='bold',
                     bbox=dict(boxstyle="round,pad=0.4", facecolor=COLORS['background'], alpha=0.8),
                     color=COLORS['dark_blue'])
        ax_stats.axis('off')

        plt.tight_layout()
        save_path = os.path.join(sample_dir, "reaction_comparison.png")
        plt.savefig(save_path, dpi=600, bbox_inches='tight', facecolor='white', pad_inches=0.03)
        plt.close(fig)

        del fig, highlighted_img
        clear_memory()
        print(f"Saved: {save_path}")

    def _combine_images_vertically(self, images: List[Image.Image]) -> Image.Image:
        """垂直组合图像"""
        total_height = sum(img.height for img in images)
        max_width = max(img.width for img in images)
        combined = Image.new('RGB', (max_width, total_height), 'white')

        y_offset = 0
        for img in images:
            combined.paste(img, (0, y_offset))
            y_offset += img.height
            del img

        return combined

    def analyze_reaction(self, reactant_smiles: str, true_products: Set[str],
                        reaction_type: str = "Unknown") -> Dict:
        """分析反应"""
        print(f"\n{'='*60}")
        print(f"Analyzing reactant: {reactant_smiles}")
        print(f"Reaction type: {reaction_type}")
        print(f"{'='*60}")

        # 预测
        predictions = self.beam_search_predict(reactant_smiles)
        best_pred = next((p for p in predictions if p['is_valid']), None)
        target_smiles = best_pred['sanitized_smiles'] if best_pred else (predictions[0]['raw_smiles'] if predictions else reactant_smiles)

        print(f"Using target: {target_smiles}")

        # 计算归因
        attributions, input_tokens, target_tokens = self.compute_attribution_scores(reactant_smiles, target_smiles)

        # 创建可视化
        global_attr = self.create_visualizations_from_attributions(
            attributions, input_tokens, target_tokens, reactant_smiles
        )

        # 创建对比图
        self.create_reaction_comparison_plot(
            reactant_smiles, predictions, true_products, global_attr, input_tokens, reaction_type
        )

        clear_memory()

        return {
            'reactant': reactant_smiles, 'predictions': predictions, 'true_products': true_products,
            'best_prediction': best_pred, 'attributions': attributions,
            'global_attributions': global_attr, 'input_tokens': input_tokens,
            'target_tokens': target_tokens, 'reaction_type': reaction_type, 'analysis_complete': True
        }

def main_analysis(model, tokenizer, device, data_path, dataset_split, reactant_smiles):
    """主分析函数 - 手动输入reactant进行分析"""
    print("Initializing Chemical Reaction Explainer...")
    explainer = ChemicalReactionExplainer(model, tokenizer, device)

    print(f"Loading data from '{data_path}' using {dataset_split} split...")
    try:
        dataset = load_from_disk(data_path)[dataset_split]
        products_map, types_map = group_data_mappings(dataset)
        print(f"Loaded {len(dataset)} samples from {dataset_split} split, {len(products_map)} unique reactants")
    except Exception as e:
        print(f"Error loading data: {e}")
        return

    # 检查输入的reactant是否存在于数据集中
    if reactant_smiles not in products_map:
        print(f"Warning: Reactant '{reactant_smiles}' not found in the {dataset_split} dataset.")
        print("Available reactants (first 10):")
        for i, reactant in enumerate(list(products_map.keys())[:10]):
            print(f"  {i+1}: {reactant}")
        print(f"Total available reactants: {len(products_map)}")

        # 选择是否继续分析
        user_choice = input("Continue analysis anyway? (y/n): ").strip().lower()
        if user_choice != 'y':
            return

        # 使用空的产品集合
        true_products = set()
        reaction_type = "Unknown"
    else:
        true_products = products_map[reactant_smiles]
        reaction_type = types_map.get(reactant_smiles, "Unknown")
        print(f"Found reactant in dataset with {len(true_products)} product(s) and reaction type: {reaction_type}")

    try:
        print(f"\nStarting analysis for reactant: {reactant_smiles}")
        result = explainer.analyze_reaction(reactant_smiles, true_products, reaction_type)
        print(f"✓ Analysis completed successfully")

        print(f"\n{'='*60}")
        print(f"Analysis Summary:")
        print(f"Reactant: {reactant_smiles}")
        print(f"Reaction Type: {reaction_type}")
        print(f"True Products: {len(true_products)}")
        print(f"Valid Predictions: {sum(1 for p in result['predictions'] if p['is_valid'])}")
        print(f"Results saved in: {explainer.output_dir}")

        return result

    except Exception as e:
        print(f"✗ Error during analysis: {e}")
        traceback.print_exc()
        return None
    finally:
        clear_memory()

# 使用示例
if __name__ == "__main__":
    BASE_MODEL_NAME = "google/flan-t5-base"
    ADAPTER_MODEL_PATH = "./best_model_multi_eval_v3_correct_loss/"
    DATA_PATH = './data/data'
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # 加载模型
    from transformers import BitsAndBytesConfig
    from peft import PeftModel

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
    )

    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        BASE_MODEL_NAME, quantization_config=quantization_config, device_map={"": DEVICE}
    )

    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_PATH)
    base_model.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_PATH)

    # 配置模型
    for attr in ['bos_token_id', 'eos_token_id', 'pad_token_id', 'decoder_start_token_id']:
        setattr(model.config, attr, getattr(tokenizer, attr.replace('decoder_start_', 'pad_')))

    model.eval()

    # 手动配置参数
    DATASET_SPLIT = "test"  # 可选: "train", "test", "validation"
    REACTANT_SMILES = "C=CC(C)C.[OH]"  # 手动输入要分析的反应物SMILES

    print(f"Configuration:")
    print(f"  Dataset Split: {DATASET_SPLIT}")
    print(f"  Reactant SMILES: {REACTANT_SMILES}")
    print(f"  Device: {DEVICE}")

    # 运行分析
    result = main_analysis(model, tokenizer, DEVICE, DATA_PATH, DATASET_SPLIT, REACTANT_SMILES)

    if result:
        print("Analysis completed successfully!")
    else:
        print("Analysis failed!")

Configuration:
  Dataset Split: test
  Reactant SMILES: C=CC(C)C.[OH]
  Device: cuda
Initializing Chemical Reaction Explainer...
Loading data from './data/data' using test split...


Grouping data: 100%|██████████| 508/508 [00:00<00:00, 28178.36it/s]

Loaded 508 samples from test split, 429 unique reactants
Found reactant in dataset with 2 product(s) and reaction type: addition

Starting analysis for reactant: C=CC(C)C.[OH]

Analyzing reactant: C=CC(C)C.[OH]
Reaction type: addition





Using target: [CH2]C(O)C(C)C
Computing attributions...

--- GLOBAL ATTRIBUTION SCORES (INCLUDING DOTS) ---
Input tokens and their global attribution scores:
  Token[0]: 'C=CC' -> Global Attribution: 0.627311
  Token[1]: '(C)C' -> Global Attribution: 0.297719
  Token[2]: '.' -> Global Attribution: 0.000000
  Token[3]: '[OH]' -> Global Attribution: 0.122509


--- GLOBAL ATTRIBUTION SCORES (INCLUDING DOTS) ---
Input tokens and their global attribution scores:
  Token[0]: 'C=CC' -> Global Attribution: 0.627311
  Token[1]: '(C)C' -> Global Attribution: 0.297719
  Token[2]: '.' -> Global Attribution: 0.000000
  Token[3]: '[OH]' -> Global Attribution: 0.122509

Creating comparison plot for reactant: C=CC(C)C.[OH]...

TOKEN TO ATOM MAPPING
SMILES: C=CC(C)C.[OH]
Number of atoms in molecule: 6
Atom symbols: ['C', 'C', 'C', 'C', 'C', 'O']

Filtered tokens and their attribution scores:
  Token[0]: 'C=CC' -> Attribution: 0.627311
  Token[1]: '(C)C' -> Attribution: 0.297719
  Token[2]: '[OH]' -> Att