In [None]:
import ml_collections
import torch
from torch import multiprocessing as mp
import accelerate
import utils
from datasets import get_dataset
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from absl import logging
import builtins
import einops
import libs.autoencoder
import libs.clip
import numpy as np
from torchvision.utils import save_image
from tqdm import tqdm
import os


def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    """Return the Stable Diffusion "scaled linear" beta schedule.

    The square roots of the betas are spaced linearly between
    ``sqrt(linear_start)`` and ``sqrt(linear_end)`` and then squared, so the
    first and last betas equal ``linear_start`` and ``linear_end``.

    Args:
        linear_start: beta at the first timestep.
        linear_end: beta at the last timestep.
        n_timestep: number of diffusion timesteps (length of the result).

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(n_timestep,)``.
    """
    sqrt_betas = torch.linspace(
        linear_start ** 0.5,
        linear_end ** 0.5,
        n_timestep,
        dtype=torch.float64,
    )
    return sqrt_betas.pow(2).numpy()

In [None]:
def _resolve_prompt(config):
    """Return the text prompt that conditions generation.

    ``config.input_text`` takes precedence over ``config.input_file``. The file
    is read as UTF-8 and stripped of leading/trailing whitespace.

    Raises:
        ValueError: if neither ``input_text`` nor ``input_file`` is set.
    """
    if config.input_text:
        # Use the text supplied directly.
        prompt = config.input_text
        logging.info(f'使用提供的文本: {prompt}')
    elif config.input_file:
        # Read the prompt from a text file.
        with open(config.input_file, 'r', encoding='utf-8') as f:
            prompt = f.read().strip()
        logging.info(f'从文件读取文本: {prompt}')
    else:
        raise ValueError("必须提供 input_text 或 input_file 参数")
    return prompt


def evaluate(config):
    """Generate one text-conditioned image and save it to disk.

    Steps: encode the prompt with a BioLinkBERT text encoder, sample a latent
    with DPM-Solver (with classifier-free guidance unless
    ``config.sample.scale == 0``), decode it with the autoencoder and save the
    first sample as an image file.

    Args:
        config: an ``ml_collections`` config. Fields read here: ``benchmark``
            (optional), ``seed``, ``dataset``, ``input_text``, ``input_file``,
            ``nnet``, ``nnet_path``, ``autoencoder``, ``sample.scale``,
            ``sample.sample_steps``, ``z_shape``, ``output_path`` and
            ``output_filename`` (optional). ``mixed_precision`` is written
            before the config is frozen.

    Raises:
        ValueError: if neither ``input_text`` nor ``input_file`` is set.
    """
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # set_start_method() raises RuntimeError once a start method has been set,
    # which made a second call in the same process (e.g. re-running the
    # notebook cell) crash. Only set it when it is not already 'spawn'.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)
    if accelerator.is_main_process:
        utils.set_logger(log_level='info')
    else:
        utils.set_logger(log_level='error')
        # Silence print() on non-main processes. Accept keyword arguments so
        # calls like print(..., flush=True) do not raise TypeError.
        builtins.print = lambda *args, **kwargs: None

    dataset = get_dataset(**config.dataset)

    # Resolve the single text prompt.
    prompt = _resolve_prompt(config)
    print(f"处理文本: {prompt}")

    # Initialize the text encoder (a BioLinkBERT embedder, kept under the name `clip`).
    clip = libs.clip.BertEmbedder(version='michiyasunaga/BioLinkBERT-base', mask=True)
    clip.eval()
    clip.to(device)

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        # NOTE(review): a bare string is passed, although an earlier comment
        # said a list is passed; confirm BertEmbedder.encode returns a batch
        # dimension of 1 for a single string.
        context, attn_mask = clip.encode(prompt)
    # attn_mask is currently unused; masking of padded positions is disabled:
    # context = context * attn_mask.unsqueeze(-1).to(context.device)

    # Load the diffusion network.
    nnet = utils.get_nnet(**config.nnet)
    nnet = accelerator.prepare(nnet)
    logging.info(f'从 {config.nnet_path} 加载模型')
    # torch.load unpickles the checkpoint; only load trusted files.
    accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.eval()

    # Classifier-free guidance setup. Done once here instead of on every
    # solver step; the unconditional context is only touched when CFG is used.
    use_cfg = config.sample.scale != 0
    if use_cfg:
        empty_context = torch.tensor(dataset.empty_context, device=device)
    else:
        print("config.sample.scale == 0, 不使用CFG")
        empty_context = None

    def cfg_nnet(x, timesteps, context):
        """Run nnet, mixing in the unconditional prediction when CFG is on."""
        _cond = nnet(x, timesteps, context=context)
        if not use_cfg:
            return _cond
        _empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
        _uncond = nnet(x, timesteps, context=_empty_context)
        return _cond + config.sample.scale * (_cond - _uncond)

    # Load the autoencoder used to decode latents back to images.
    autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)

    logging.info(config.sample)
    logging.info(f'mixed_precision={config.mixed_precision}')
    logging.info(f'N={N}')

    # Make sure the output directory exists.
    os.makedirs(config.output_path, exist_ok=True)

    # Generate a single image.
    logging.info("开始生成图像...")

    # Initial Gaussian noise with batch size 1.
    z_init = torch.randn(1, *config.z_shape, device=device)
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

    def model_fn(x, t_continuous):
        # Scale the solver's continuous time (sampled over [1/N, 1]) to the
        # network's discrete timestep range.
        t = t_continuous * N
        return cfg_nnet(x, t, context=context)

    # Sample with DPM-Solver and decode the latent, without building a graph.
    with torch.no_grad():
        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        z = dpm_solver.sample(z_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
        samples = dataset.unpreprocess(decode(z))

    # Save the generated image. Only the main process writes, so several
    # accelerate processes do not race on the same output file.
    sample = samples[0]
    output_filename = config.get('output_filename', 'generated_image.png')
    output_path = os.path.join(config.output_path, output_filename)
    if accelerator.is_main_process:
        save_image(sample, output_path)
        logging.info(f"图像已保存到: {output_path}")
        print(f"生成完成！图像保存在: {output_path}")

In [None]:
from absl import flags
from absl import app
from ml_collections import config_flags

FLAGS = flags.FLAGS

# Path to the ml_collections config file. lock_config=False so main() can add
# new fields (nnet_path, output_path, input_text, ...) to the loaded config.
config_flags.DEFINE_config_file("config", None, "训练配置文件", lock_config=False)
flags.mark_flag_as_required("config")

# Model checkpoint and output location.
flags.DEFINE_string(name="nnet_path", default=None, help="要评估的神经网络模型路径")
flags.DEFINE_string(name="output_path", default=None, help="输出图像的路径")

# Prompt source: direct text or a text file (main() requires one of them).
flags.DEFINE_string(name="input_text", default=None, help="输入的文本提示（直接指定文本）")
flags.DEFINE_string(name="input_file", default=None, help="输入文本文件的路径（从文件读取文本）")

flags.DEFINE_string(name="output_filename", default="generated_image.png", help="输出图像的文件名")


def main(argv):
    """absl entry point: copy command-line flags into the config and run evaluate().

    Args:
        argv: leftover positional command-line arguments passed by absl (unused).

    Raises:
        ValueError: if no prompt source, ``--nnet_path`` or ``--output_path``
            was given.
    """
    del argv  # Unused.

    # Validate all required inputs up front. Without this, a missing
    # --nnet_path / --output_path overwrote the config with None and only
    # failed inside evaluate(), after the text encoder and model were loaded.
    if not FLAGS.input_text and not FLAGS.input_file:
        raise ValueError("必须提供 --input_text 或 --input_file 参数之一")
    if not FLAGS.nnet_path:
        raise ValueError("必须提供 --nnet_path 参数")
    if not FLAGS.output_path:
        raise ValueError("必须提供 --output_path 参数")

    if FLAGS.input_text and FLAGS.input_file:
        logging.warning("同时提供了 input_text 和 input_file，将优先使用 input_text")

    config = FLAGS.config
    config.nnet_path = FLAGS.nnet_path
    config.output_path = FLAGS.output_path
    config.input_text = FLAGS.input_text
    config.input_file = FLAGS.input_file
    config.output_filename = FLAGS.output_filename

    evaluate(config)


# Script entry point: absl parses sys.argv into FLAGS, then calls main().
# NOTE(review): inside a Jupyter kernel `__name__` is also "__main__", so this
# runs there too and absl will parse the kernel's own argv; presumably this
# notebook is meant to be exported and run as a script — confirm.
if __name__ == "__main__":
    app.run(main)