# 晶体结构生成推理

基于Flow Matching模型的晶体结构生成和优化流程：
1. 读取A_sample.csv
2. 多进程并行生成和评估
3. 迭代重新生成，直到满足Rwp要求、达到每个样本的最大尝试次数或超过最大运行时间
4. 实时更新submission.csv

In [None]:
import os
import sys
import time
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from multiprocessing import Process, Queue, Manager, Value
from threading import Lock
from tqdm.auto import tqdm
from rich.console import Console
from rich.table import Table
from rich.live import Live
from rich.layout import Layout
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
import warnings
warnings.filterwarnings('ignore')

# 导入自定义模块
from inference_utils import (
    generation_worker, simulation_worker, postprocess_worker,
    submission_manager, load_sample_data, check_termination
)

# Shared rich console used for all status output in this notebook
console = Console()

## 配置参数

In [None]:
# Base configuration read by the cells below
CONFIG = {
    # File paths
    'sample_path': 'data/A_sample/A_sample.csv',  # input sample file
    'submission_path': 'submission.csv',  # output submission file
    'checkpoint_path': 'checkpoints/best_model.ckpt',  # model checkpoint

    # Process configuration
    # os.cpu_count() returns None when the CPU count cannot be determined,
    # so fall back to 1 before subtracting to avoid a TypeError.
    'num_cpu_workers': max(1, (os.cpu_count() or 1) - 2),  # number of CPU simulation processes
    'batch_size': 32,  # GPU generation batch size

    # Termination conditions
    'max_hours': 5.0,  # maximum run time (hours)
    'max_attempts_per_sample': 10,  # maximum number of generations per sample
    'target_rwp': 0.5,  # target Rwp threshold
    'rwp_threshold': 5.0,  # Rwp above which a sample is regenerated

    # Timeouts
    'simulation_timeout': 120,  # PXRD simulation timeout (seconds)

    # GPU settings
    'generation_gpu': 0,  # GPU used for generation
    'postprocess_gpu': 0,  # GPU used for post-processing (may be the same one)
}

# Show the effective configuration in a titled panel
console.print(Panel.fit(f"配置参数:\n{CONFIG}", title="推理配置"))

## 初始化数据

In [None]:
# Load the sample data; if the sample file is missing, synthesise a small
# random test set and write it to CONFIG['sample_path'] instead.
console.print("[bold blue]加载样本数据...[/bold blue]")

try:
    sample_ids, pxrd_data, comp_data = load_sample_data(CONFIG['sample_path'])
except FileNotFoundError:
    console.print("[yellow]⚠ 样本文件不存在，创建测试数据[/yellow]")

    # Create the directory that CONFIG['sample_path'] actually points into,
    # so a customised path does not make the to_csv() below fail.
    sample_dir = os.path.dirname(CONFIG['sample_path'])
    if sample_dir:
        os.makedirs(sample_dir, exist_ok=True)

    # Test sample ids
    num_samples = 10
    sample_ids = [f"sample_{i:04d}" for i in range(num_samples)]

    # Random PXRD vectors (length 11501) and zero-padded composition vectors (length 60)
    pxrd_data = {}
    comp_data = {}
    for sid in sample_ids:
        pxrd_data[sid] = np.random.randn(11501)
        num_atoms = np.random.randint(5, 20)
        comp = np.random.randint(1, 90, size=num_atoms)
        comp = np.pad(comp, (0, 60 - len(comp)), constant_values=0)
        comp_data[sid] = comp

    # Persist the test data. .tolist() converts to plain Python numbers first:
    # str(list(ndarray)) would embed NumPy scalar reprs such as
    # "np.float64(0.1)" on NumPy >= 2, which is not a parseable number list.
    test_df = pd.DataFrame({
        'id': sample_ids,
        'pxrd': [str(pxrd_data[sid].tolist()) for sid in sample_ids],
        'comp': [str(comp_data[sid].tolist()) for sid in sample_ids],
    })
    test_df.to_csv(CONFIG['sample_path'], index=False)
    console.print(f"[green]✓ 创建 {num_samples} 个测试样本[/green]")
else:
    num_samples = len(sample_ids)
    console.print(f"[green]✓ 成功加载 {num_samples} 个样本[/green]")

# Initialise submission.csv with one placeholder row per sample:
# an empty CIF string and the sentinel Rwp value 9999.
placeholder_columns = {
    'id': sample_ids,
    'cif': ['' for _ in range(num_samples)],
    'rwp': [9999 for _ in range(num_samples)],
}
initial_submission = pd.DataFrame(placeholder_columns)
initial_submission.to_csv(CONFIG['submission_path'], index=False)
console.print("[green]✓ 初始化 submission.csv[/green]")

## 主控制流程

In [None]:
def _send_sentinels(target_queue, count):
    """Put up to ``count`` ``None`` sentinels on ``target_queue`` without blocking forever."""
    from queue import Full

    for _ in range(count):
        try:
            target_queue.put(None, timeout=1)
        except Full:
            # The queue is saturated and nobody is draining it; further
            # attempts would only delay shutdown.
            break


def _stop_processes(processes, join_timeout=5):
    """Join each process, terminating any still alive after ``join_timeout`` seconds."""
    for proc in processes:
        proc.join(timeout=join_timeout)
        if proc.is_alive():
            proc.terminate()
            # Reap the terminated child so it does not linger as a zombie.
            proc.join(timeout=join_timeout)


def _normalize_rwp(value, failed=9999.0):
    """Return ``value`` as a float, or ``failed`` if it is missing, non-numeric or NaN."""
    try:
        rwp = float(value)
    except (TypeError, ValueError):
        return failed
    # NaN is the only float that differs from itself.
    return failed if rwp != rwp else rwp


def run_inference():
    """
    Run the main inference pipeline.

    Starts one generation process, ``CONFIG['num_cpu_workers']`` simulation
    processes, one post-processing process and one submission-manager
    process, then loops over the Rwp results: a sample whose Rwp exceeds
    ``CONFIG['rwp_threshold']`` is re-queued for generation until
    ``CONFIG['max_attempts_per_sample']`` is reached; every other result is
    forwarded to post-processing and the sample is counted as completed.

    The loop ends when all samples are completed, all best Rwp values are
    below ``CONFIG['target_rwp']``, ``check_termination`` reports that the
    time budget is used up, or the shared terminate flag is set.

    Workers are shut down by putting ``None`` on every queue.
    NOTE(review): this assumes the workers in ``inference_utils`` treat
    ``None`` as a stop sentinel — confirm against that module.

    Returns:
        dict: a plain-dict snapshot of the shared best-structure mapping
        taken after all workers have stopped. (A manager proxy would stop
        working once the manager goes out of scope.)
    """
    from queue import Empty

    # Shared resources
    manager = Manager()
    # The task queue is unbounded on purpose: it is filled before any consumer
    # exists and is also fed from the main loop, so a bounded queue could block
    # forever (more than 100 samples, or a full pipeline waiting on this loop).
    # Its size is naturally limited by samples x max attempts.
    task_queue = Queue()  # generation tasks
    result_queue = Queue(maxsize=100)  # generation results
    rwp_queue = Queue(maxsize=100)  # Rwp evaluations
    postprocess_queue = Queue(maxsize=100)  # post-processing input
    final_queue = Queue(maxsize=100)  # final results
    all_queues = (task_queue, result_queue, rwp_queue, postprocess_queue, final_queue)

    # Shared state
    best_structures = manager.dict()  # best structure per sample
    lock = manager.Lock()  # file-write lock
    terminate_flag = Value('i', 0)  # termination flag

    start_time = time.time()

    # Seed the task queue with the first attempt for every sample
    for sid in sample_ids:
        task_queue.put({
            'sample_id': sid,
            'pxrd': pxrd_data[sid],
            'comp': comp_data[sid],
            'attempt': 1,
        })

    processes = []

    # Per-sample statistics
    sample_attempts = {sid: 0 for sid in sample_ids}
    sample_best_rwp = {sid: 9999 for sid in sample_ids}
    completed_samples = set()

    try:
        # 1. Generation process (one GPU process)
        p = Process(target=generation_worker, args=(
            task_queue, result_queue,
            CONFIG['checkpoint_path'], CONFIG['batch_size'],
            CONFIG['generation_gpu']
        ))
        p.start()
        processes.append(p)
        console.print("[green]✓ 启动生成进程[/green]")

        # 2. Simulation processes (several CPU processes)
        for i in range(CONFIG['num_cpu_workers']):
            p = Process(target=simulation_worker, args=(
                result_queue, rwp_queue, i, CONFIG['simulation_timeout']
            ))
            p.start()
            processes.append(p)
        console.print(f"[green]✓ 启动 {CONFIG['num_cpu_workers']} 个仿真进程[/green]")

        # 3. Post-processing process (one GPU process)
        p = Process(target=postprocess_worker, args=(
            postprocess_queue, final_queue, CONFIG['postprocess_gpu']
        ))
        p.start()
        processes.append(p)
        console.print("[green]✓ 启动后处理进程[/green]")

        # 4. Submission manager process
        p = Process(target=submission_manager, args=(
            final_queue, CONFIG['submission_path'],
            best_structures, lock, terminate_flag
        ))
        p.start()
        processes.append(p)
        console.print("[green]✓ 启动提交管理进程[/green]")

        # Main loop: monitor progress and schedule retries
        console.print("\n[bold cyan]开始推理流程...[/bold cyan]\n")

        progress = Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console
        )

        with progress:
            task_progress = progress.add_task(
                "[cyan]处理样本...", total=num_samples
            )

            while not terminate_flag.value:
                # Time-based termination
                if check_termination(start_time, CONFIG['max_hours']):
                    console.print("[yellow]⚠ 达到最大运行时间限制[/yellow]")
                    terminate_flag.value = 1
                    break

                # get() already blocks for up to one second when idle, so no
                # additional sleep is needed to keep CPU usage low.
                try:
                    rwp_result = rwp_queue.get(timeout=1)
                except Empty:
                    rwp_result = None

                if rwp_result is not None:
                    sid = rwp_result.get('sample_id') if isinstance(rwp_result, dict) else None
                    if sid not in sample_attempts:
                        # Skip malformed results instead of crashing the loop.
                        console.print("[yellow]⚠ 忽略无法识别的Rwp结果[/yellow]")
                    else:
                        # A missing / non-numeric / NaN Rwp counts as a failed attempt.
                        rwp = _normalize_rwp(rwp_result.get('rwp', 9999))
                        attempt = rwp_result.get('generation_attempt', 1)

                        # Update statistics; max() keeps the count monotonic
                        # if results arrive out of order.
                        sample_attempts[sid] = max(sample_attempts[sid], attempt)
                        if rwp < sample_best_rwp[sid]:
                            sample_best_rwp[sid] = rwp

                        needs_retry = (
                            rwp > CONFIG['rwp_threshold']
                            and attempt < CONFIG['max_attempts_per_sample']
                        )
                        if needs_retry:
                            task_queue.put({
                                'sample_id': sid,
                                'pxrd': pxrd_data[sid],
                                'comp': comp_data[sid],
                                'attempt': attempt + 1,
                            })
                        else:
                            # Either Rwp is within the threshold or the attempt
                            # budget is exhausted: post-process this result.
                            postprocess_queue.put(rwp_result)
                            if sid not in completed_samples:
                                completed_samples.add(sid)
                                progress.update(task_progress, advance=1)

                        # Periodic / noteworthy status line
                        if attempt % 5 == 0 or rwp < 1.0:
                            console.print(
                                f"[dim]样本 {sid}: 尝试 {attempt}, "
                                f"Rwp = {rwp:.3f} (最佳: {sample_best_rwp[sid]:.3f})[/dim]"
                            )

                # All samples handed to post-processing?
                if len(completed_samples) == num_samples:
                    console.print("[green]✓ 所有样本处理完成[/green]")
                    terminate_flag.value = 1
                    break

                # All samples already below the target Rwp?
                if all(rwp < CONFIG['target_rwp'] for rwp in sample_best_rwp.values()):
                    console.print(f"[green]✓ 所有样本Rwp < {CONFIG['target_rwp']}[/green]")
                    terminate_flag.value = 1
                    break
    finally:
        # Always shut the workers down, including on errors and Ctrl-C.
        terminate_flag.value = 1
        console.print("\n[yellow]正在关闭进程...[/yellow]")

        # One sentinel per potential consumer: there may be more simulation
        # workers than the five sentinels that were sent previously.
        sentinel_count = max(5, CONFIG['num_cpu_workers'])
        for q in all_queues:
            _send_sentinels(q, sentinel_count)

        _stop_processes(processes, join_timeout=5)

        # Unconsumed items must not keep this process waiting on the queues'
        # feeder threads when it eventually exits.
        for q in all_queues:
            q.cancel_join_thread()

    # Snapshot the shared dict while the manager is still alive.
    final_structures = dict(best_structures)
    manager.shutdown()

    # Summary statistics
    elapsed_time = (time.time() - start_time) / 60  # minutes
    # Exclude the 9999 placeholder of samples that never produced a result.
    valid_best = [rwp for rwp in sample_best_rwp.values() if rwp < 9999]
    mean_best = f"{np.mean(valid_best):.3f}" if valid_best else "N/A"
    target = CONFIG['target_rwp']

    table = Table(title="推理结果统计")
    table.add_column("指标", style="cyan")
    table.add_column("值", style="green")

    table.add_row("总运行时间", f"{elapsed_time:.1f} 分钟")
    table.add_row("处理样本数", str(len(completed_samples)))
    table.add_row("平均尝试次数", f"{np.mean(list(sample_attempts.values())):.1f}")
    table.add_row("平均最佳Rwp", mean_best)
    table.add_row(f"Rwp < {target} 样本数", str(sum(1 for rwp in valid_best if rwp < target)))

    console.print(table)
    console.print(f"\n[bold green]✓ 推理完成！结果已保存到 {CONFIG['submission_path']}[/bold green]")

    return final_structures

## 执行推理

In [None]:
# Report GPU 0 when CUDA is available; otherwise switch both pipeline
# stages to the CPU.
if not torch.cuda.is_available():
    console.print("[red]✗ GPU不可用，将使用CPU（速度较慢）[/red]")
    for device_key in ('generation_gpu', 'postprocess_gpu'):
        CONFIG[device_key] = 'cpu'
else:
    gpu_name = torch.cuda.get_device_name(0)
    console.print(f"[green]✓ GPU可用: {gpu_name}[/green]")
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    console.print(f"[dim]  显存: {gpu_memory_gb:.1f} GB[/dim]")

# Check that the model checkpoint exists; if not, warn and write a
# placeholder checkpoint so the rest of the notebook can be exercised.
if not os.path.exists(CONFIG['checkpoint_path']):
    console.print(f"[yellow]⚠ 模型文件不存在: {CONFIG['checkpoint_path']}[/yellow]")
    console.print("[yellow]  请先训练模型或指定正确的模型路径[/yellow]")
    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when the checkpoint path actually has one.
    checkpoint_dir = os.path.dirname(CONFIG['checkpoint_path'])
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({'state_dict': {}, 'config': {'network': 'transformer', 'flow': 'cfm'}},
               CONFIG['checkpoint_path'])
    console.print("[dim]  已创建测试用占位模型[/dim]")

In [None]:
# Run inference, then print the head of the submission file and a summary of
# the Rwp values that are below the 9999 placeholder.
if __name__ == "__main__":
    try:
        best_structures = run_inference()

        # Show the final result
        final_df = pd.read_csv(CONFIG['submission_path'])
        console.print("\n[bold]最终提交文件前5行:[/bold]")
        console.print(final_df.head())

        # Rwp distribution over rows that received a real value
        rwp_values = final_df['rwp'].values
        valid_rwp = rwp_values[rwp_values < 9999]
        valid_count = len(valid_rwp)

        if valid_count > 0:
            console.print("\n[bold]Rwp统计:[/bold]")
            summary_stats = (
                ("最小值", valid_rwp.min),
                ("最大值", valid_rwp.max),
                ("平均值", valid_rwp.mean),
                ("中位数", lambda: np.median(valid_rwp)),
            )
            for label, compute in summary_stats:
                console.print(f"  {label}: {compute():.3f}")
            for limit in (0.5, 1.0, 5.0):
                console.print(f"  < {limit}: {sum(valid_rwp < limit)} / {valid_count}")

    except KeyboardInterrupt:
        console.print("\n[yellow]⚠ 用户中断[/yellow]")
    except Exception as e:
        console.print(f"\n[red]✗ 错误: {e}[/red]")
        import traceback
        traceback.print_exc()

## 结果验证

In [None]:
# Validate the submission.csv format
def validate_submission(submission_path: str, expected_ids=None):
    """
    Check that a submission file has the expected columns and sample ids.

    Args:
        submission_path: path of the CSV file to validate.
        expected_ids: ids that must be present in the ``id`` column. When
            omitted, the notebook-level ``sample_ids`` is used.

    Returns:
        bool: ``False`` if a required column or an expected id is missing
        (the first problem found is printed), ``True`` otherwise, in which
        case a short summary is printed.
    """
    df = pd.read_csv(submission_path)

    # Required columns
    required_columns = ['id', 'cif', 'rwp']
    for col in required_columns:
        if col not in df.columns:
            console.print(f"[red]✗ 缺少列: {col}[/red]")
            return False

    # Id completeness. Compare as strings: pandas parses numeric-looking ids
    # as integers, which would never equal the original string ids.
    if expected_ids is None:
        expected_ids = sample_ids
    missing_ids = {str(sid) for sid in expected_ids} - set(df['id'].astype(str))
    if missing_ids:
        console.print(f"[red]✗ 缺少样本: {missing_ids}[/red]")
        return False

    # CIF presence (simple check). Empty cells are read back as NaN floats,
    # so only non-blank strings count; calling len() on NaN would raise.
    valid_cifs = sum(
        1 for cif in df['cif'] if isinstance(cif, str) and cif.strip()
    )

    console.print("[green]✓ 格式验证通过[/green]")
    console.print(f"  - 样本数: {len(df)}")
    console.print(f"  - 有效CIF: {valid_cifs}/{len(df)}")
    console.print(f"  - 平均Rwp: {df['rwp'][df['rwp'] < 9999].mean():.3f}")

    return True

# Validate the submission file, or report that it has not been written yet
if not os.path.exists(CONFIG['submission_path']):
    console.print("[yellow]提交文件尚未生成[/yellow]")
else:
    validate_submission(CONFIG['submission_path'])