In [1]:
# --- Setup base dir & env, then import deps (condensed) ---

# Stdlib
import os, sys
from pathlib import Path

def get_basedir(up: int = 2) -> Path:
    """Return dir `up` levels above running .py/.ipynb."""
    try:
        p = Path(__file__).resolve()        # .py
    except NameError:
        try:
            import ipynbname                # notebook
            p = Path(ipynbname.path()).resolve()
        except Exception:
            p = (Path.cwd() / "_dummy").resolve()    # fallback
    for _ in range(up): p = p.parent
    return p

# only initialize BASE_DIR once
BASE_DIR = globals().get("BASE_DIR")
if not isinstance(BASE_DIR, Path) or not BASE_DIR.exists():
    BASE_DIR = get_basedir()

sys.path[:0] = [
    str(BASE_DIR),
    str(BASE_DIR / "packages"),
    str(BASE_DIR / "packages" / "DiffBIR"),
    str(BASE_DIR / "packages" / "MST" / "simulation" / "train_code"),
]
os.chdir(BASE_DIR)
print(f"Set BASE = {BASE_DIR}")
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "1"})

# Third-party
import numpy as np, torch, scipy.io as scio, pytorch_lightning as pl
from argparse import ArgumentParser, Namespace
from typing import Optional, Tuple, Set, List, Dict
from torch import nn
from torch.nn import functional as F
from PIL import Image
from omegaconf import OmegaConf
from skimage.metrics import structural_similarity as ssim
from accelerate.utils import set_seed
from hyperopt import hp, fmin, tpe, Trials
import matplotlib.pyplot as plt

# Project
from packages.DiffBIR.utils.common import (
    instantiate_from_config, load_file_from_url, count_vram_usage,
    wavelet_decomposition, wavelet_reconstruction, wavelet_decomposition_msi,
)
from packages.DiffBIR.utils.inference import InferenceLoop
from packages.DiffBIR.utils.helpers import MSI_Pipeline
from packages.DiffBIR.utils.cond_fn import MeasMSEGuidance, Guidance
from packages.DiffBIR.model.gaussian_diffusion import Diffusion
from packages.DiffBIR.model.cldm import ControlLDM
from packages.MST.simulation.train_code.utils import *
from packages.MST.simulation.train_code.architecture import *


Set BASE = /home/newdisk/btsun/project/PSR-SCI
use sdp attention as default
keep default attention mode
use sdp attention as default
keep default attention mode


In [2]:
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=2):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
        self.max_pool = nn.AdaptiveMaxPool2d((2, 2))

        self.fc1 = nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False)
        self.fc2 = nn.Conv2d(in_channels // reduction_ratio, in_channels, 2, bias=False)

        self.SiLU = nn.SiLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.SiLU(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.SiLU(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class DoubleConvWoBN(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.SiLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.SiLU(inplace=True),
        )
        self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.double_conv(x) + self.res_conv(x)


class ChannelEncoder(nn.Module):
    def __init__(self):
        super(ChannelEncoder, self).__init__()
        self.conv1 = DoubleConvWoBN(in_channels=28, out_channels=21)
        self.conv2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear"),
            DoubleConvWoBN(in_channels=21, out_channels=9),
        )
        self.conv3 = DoubleConvWoBN(in_channels=9, out_channels=3)
        self.conv_out = DoubleConvWoBN(in_channels=3, out_channels=3)
        self.conv_res = nn.Sequential(
            DoubleConvWoBN(in_channels=28, out_channels=3),
            nn.Upsample(scale_factor=2, mode="bilinear"),
        )

        self.ca1 = ChannelAttention(28, 2)
        self.ca2 = ChannelAttention(21, 2)
        self.ca3 = ChannelAttention(9, 2)
        self.ca_res = ChannelAttention(28, 2)

    def forward(self, x):

        res = self.conv_res(x * self.ca_res(x))

        x = x * self.ca1(x)
        x = self.conv1(x)

        x = x * self.ca2(x)
        x = self.conv2(x)

        x = x * self.ca3(x)
        x = self.conv3(x)

        x = self.conv_out(x + res)
        return x


class ChannelDecoder(nn.Module):
    def __init__(self):
        super(ChannelDecoder, self).__init__()
        self.conv1 = DoubleConvWoBN(in_channels=3, out_channels=9)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=9, out_channels=21, kernel_size=2, stride=2),
            nn.SiLU(inplace=True),
            DoubleConvWoBN(in_channels=21, out_channels=21),

        )
        self.conv3 = DoubleConvWoBN(in_channels=21, out_channels=28)
        self.conv_out = DoubleConvWoBN(in_channels=28, out_channels=28)
        self.conv_res = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=2, stride=2),
            DoubleConvWoBN(in_channels=3, out_channels=28),
        )

        self.ca3 = ChannelAttention(28, 2)
        self.ca2 = ChannelAttention(21, 2)
        self.ca1 = ChannelAttention(9, 2)

        self.ca_res = ChannelAttention(28, 2)

    def forward(self, x):

        res = self.conv_res(x)
        res = res * self.ca_res(res)

        x = self.conv1(x)
        x = x * self.ca1(x)

        x = self.conv2(x)
        x = x * self.ca2(x)

        x = self.conv3(x)
        x = x * self.ca3(x)

        x = self.conv_out(x + res)

        return x


class ChannelVAE(nn.Module):
    def __init__(self):
        super(ChannelVAE, self).__init__()
        self.encoder = ChannelEncoder()
        self.decoder = ChannelDecoder()

    def forward(self, x):
        en = self.encoder(x)
        return self.decoder(en)

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
from typing import Tuple

# 删除已经存在的类定义以避免冲突
if "MeasMSEGuidance" in globals():
    del MeasMSEGuidance


class MeasMSEGuidance(Guidance):
    def load_guidance(self, target: torch.Tensor, masks: torch.Tensor, max_val_channel, min_val_channel, inputs_msi_lf, decoder: torch.nn.Module) -> None:
        self.target = target
        self.mask3d_batch = masks
        self.max_val_channel, self.min_val_channel, self.inputs_msi_lf = max_val_channel, min_val_channel, inputs_msi_lf
        self.decoder = decoder
        self.rgb_target = None
        self.bias = 0
        self.lambda_reg = 0.005

    def load_bias(self, bias: torch.Tensor):
        self.bias = bias

    def load_rgb_target(self, rgb_target: torch.Tensor):
        self.rgb_target = rgb_target

    def _forward(self, target: torch.Tensor, pred_x0: torch.Tensor, t: int, visual:bool=False) -> Tuple[torch.Tensor, float]:
        # Ensure the directory exists
        with torch.enable_grad():
            pred_x0.requires_grad_(True)
            pred_x0 = (pred_x0) / 2 + self.bias
            # Clamp pred_x0 to [0, 1] range
            pred_x0_clamped = torch.clamp(pred_x0, 0, 1)

            def shift(inputs, step=2):
                [bs, nC, row, col] = inputs.shape
                output = torch.zeros(bs, nC, row, col + (nC - 1) * step).cuda().float()
                for i in range(nC):
                    output[:, i, :, step * i : step * i + col] = inputs[:, i, :, :]
                return output

            def gen_meas_torch(data_batch, mask3d_batch):
                temp = shift(mask3d_batch * data_batch, 2)
                meas = torch.sum(temp, 1)
                return meas

            # Calculate meas and loss
            loss = 0
            msi = self.decoder((pred_x0_clamped) * (self.max_val_channel - self.min_val_channel) + self.min_val_channel) + self.inputs_msi_lf

            penalty_msi_low = torch.relu(-msi)  # penalize values below 0
            msi_clamped = torch.clamp(msi, 0, 10)

            meas = gen_meas_torch(msi_clamped, self.mask3d_batch)

            meas = meas[:, :, 64:-64]
            target = target[:, :, 64:-64]

            target_mean = target.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # Calculate mean across height and width
            meas_mean = meas.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # Calculate mean of meas

            meas = meas * (target_mean / meas_mean * 0.5 + 1) /1.5

            # Calculate the regularization term for out-of-bound values
            penalty_low = torch.relu(-pred_x0)  # penalize values below 0
            penalty_high = torch.relu(pred_x0 - 1)  # penalize values above 1
            regularization = penalty_low.mean((1, 2, 3)).sum() + penalty_high.mean((1, 2, 3)).sum() + penalty_msi_low.mean((1, 2)).sum()

            # Add regularization to the loss
            loss += (meas - target).abs().mean((1, 2)).sum()
            loss += self.lambda_reg * regularization  # lambda_reg is a weighting factor for the regularization term

            if self.rgb_target is not None and self.rgb_subscale > 0:
                loss += (pred_x0_clamped[:, :, :, 128:-128] - self.rgb_target[:, :, :, 128:-128]).pow(2).mean((1, 2, 3)).sum() * self.rgb_subscale

        scale = self.scale
        g = -torch.autograd.grad(loss, pred_x0)[0] * scale

        if visual and t % 4 == 1:
            visual_dir = "visual/"
            os.makedirs(visual_dir, exist_ok=True)
            with torch.no_grad():
                # Prepare numpy arrays for pred_x0, meas, target, and difference (meas - target)
                pred_x0_np = np.transpose(pred_x0_clamped.detach().cpu().numpy()[0], (1, 2, 0))  # HWC format
                meas_np = meas.detach().cpu().numpy()[0]
                target_np = target.detach().cpu().numpy()[0]
                diff_np = (meas - target).detach().cpu().numpy()[0]

                # Set up subplots
                fig, axs = plt.subplots(1, 4, figsize=(16, 4))
                images = [(pred_x0_np, "pred_x0", None), (meas_np, "meas", "gray"), (target_np, "target", "gray"), (diff_np, "meas - target", "coolwarm")]  # pred_x0 image (no colorbar)  # meas image (grayscale)  # target image (grayscale)  # Difference image with colormap

                # Display images
                for ax, (img, title, cmap) in zip(axs, images):
                    im = ax.imshow(img, cmap=cmap, vmin=0 if title in ["meas", "target"] else -0.4, vmax=10 if title in ["meas", "target"] else 0.4)
                    ax.set_title(title)
                    ax.axis("off")

                # Share a single colorbar between meas and target
                cax = fig.add_axes([0.35, 0.1, 0.3, 0.03])  # Position for shared colorbar
                fig.colorbar(axs[1].get_images()[0], cax=cax, orientation="horizontal", label="Intensity (0-10)")

                # Save the combined figure
                plt.tight_layout()
                plt.savefig(os.path.join(visual_dir, f"{t}_combined.png"), bbox_inches="tight", pad_inches=0)
                plt.close()

        return g, loss.item()

In [4]:
MODELS = {
    # stage_1 model weights
    "bsrnet": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth",
    # the following checkpoint is up-to-date, but we use the old version in our paper
    # "swinir_face": "https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth",
    "swinir_face": "https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt",
    "scunet_psnr": "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth",
    "swinir_general": "https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt",
    # stage_2 model weights
    "sd_v21": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt",
    "v1_face": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth",
    "v1_general": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth",
    "v2": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth"
}


def load_model_from_url(url: str) -> Dict[str, torch.Tensor]:
    sd_path = load_file_from_url(url, model_dir=str(BASE_DIR)+"/packages/DiffBIR/weights")
    sd = torch.load(sd_path, map_location="cpu", weights_only=False)
    if "state_dict" in sd:
        sd = sd["state_dict"]
    if list(sd.keys())[0].startswith("module"):
        sd = {k[len("module."):]: v for k, v in sd.items()}
    return sd


class InferenceLoop_NoPre:

    def __init__(self, args: Namespace) -> "InferenceLoop":
        self.args = args
        self.loop_ctx = {}
        self.pipeline: MSI_Pipeline = None
        self.init_stage2_model()

    @count_vram_usage
    def init_stage2_model(self) -> None:
        # load uent, vae, clip
        self.cldm: ControlLDM = instantiate_from_config(OmegaConf.load(str(BASE_DIR)+"/packages/DiffBIR/configs/inference/cldm.yaml"))
        sd = load_model_from_url(MODELS["sd_v21"])
        unused = self.cldm.load_pretrained_sd(sd)
        print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
        # load controlnet
        self.cldm.load_controlnet_from_ckpt(torch.load(self.args.ckpt, map_location="cpu"))
        print(f"strictly load controlnet weight {self.args.ckpt}")
        if self.args.vae != None:
            self.cldm.load_vae_from_ckpt(torch.load(self.args.vae, map_location="cpu"))
            print(f"strictly load vae weight {self.args.vae}")
        self.cldm.eval().cuda()
        # load diffusion
        self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load(str(BASE_DIR)+"/packages/DiffBIR/configs/inference/diffusion.yaml"))
        self.diffusion.cuda()

    @torch.no_grad()
    def run(self, images: torch.tensor) -> torch.tensor:
        # We don't support batch processing since input images may have different size

        return self.pipeline.run_stage2(
            images, self.args.steps, 1.0, self.args.tiled,
            self.args.tile_size, self.args.tile_stride,
            self.args.pos_prompt, self.args.neg_prompt, self.args.cfg_scale,
            self.args.better_start
        )


@torch.no_grad()
def preprocess_data(input_meas: torch.tensor, input_mask: torch.tensor, DiffSCI_Pipeline: InferenceLoop_NoPre) -> Tuple[torch.tensor, torch.tensor]:
    """
    Preprocess input data by applying necessary transformations and normalization.

    Args:
        input_meas (torch.tensor): Input measurement data.
        input_mask (torch.tensor): Input mask data.
        model (ControlLDM): Model instance.

    Returns:
        Tuple[torch.tensor, torch.tensor]: Tuple containing preprocessed RGB images and normalization coefficients.
    """

    n_samples = input_meas.shape[0]

    with torch.no_grad():
        MSI_IMAGE = DiffSCI_Pipeline.MSI_model(input_meas, input_mask)
        inputs_msi_hf, inputs_msi_lf = wavelet_decomposition_msi(MSI_IMAGE, 3)

        MSI_images_encoded = DiffSCI_Pipeline.encoder(inputs_msi_hf)

        RANGE_MAX = 0.85
        RANGE_MIN = 0.15

        range_channel = torch.tensor([MSI_images_encoded[i].max() - MSI_images_encoded[i].min() for i in range(n_samples)]).cuda()
        max_val_channel = torch.tensor([MSI_images_encoded[i].max() + range_channel[i] / (RANGE_MAX - RANGE_MIN)*(1-RANGE_MAX) for i in range(n_samples)]).cuda().view(n_samples, 1, 1, 1)
        min_val_channel = torch.tensor([MSI_images_encoded[i].min() - range_channel[i] / (RANGE_MAX - RANGE_MIN)*(RANGE_MIN) for i in range(n_samples)]).cuda().view(n_samples, 1, 1, 1)

        normalized_images = (MSI_images_encoded - min_val_channel) / (max_val_channel - min_val_channel)

        return normalized_images, max_val_channel, min_val_channel, inputs_msi_lf

@torch.no_grad()
def process_diffusion(
    DiffSCI_Pipeline: InferenceLoop_NoPre,
    normalized_images: torch.tensor,
    max_val_channel: torch.tensor,
    min_val_channel: torch.tensor,
    inputs_msi_lf: torch.tensor,
    steps: int,
    upscale: int,
    cfg_scale: float,
    cond_fn: Optional[MeasMSEGuidance],
    tiled: bool,
    tile_size: int,
    tile_stride: int,
    better_start: bool = False,
    pos_prompt: str = "",
    neg_prompt: str = "low quality, blurry, low-resolution, noisy, unsharp, weird textures",
) -> Tuple[torch.tensor, torch.tensor]:
    """
    Apply Diffusion model on preprocessed data to generate restoration results.

    Args:
        model (ControlLDM): Model.
        normalized_images (torch.tensor): Preprocessed normalized images.
        max_val_channel (torch.tensor): Maximum values for each channel.
        min_val_channel (torch.tensor): Minimum values for each channel.
        steps (int): Sampling steps.
        strength (float): Control strength.
        color_fix_type (str): Type of color correction for samples.
        cond_fn (Guidance | None): Guidance function.
        tiled (bool): If True, a patch-based sampling strategy will be used.
        tile_size (int): Size of patch.
        tile_stride (int): Stride of sliding patch.

    Returns:
        Tuple[torch.tensor, torch.tensor]: Tuple containing restored images and diffusion outputs.
    """

    if upscale > 1.0:
        normalized_images_up = F.interpolate(normalized_images, size=(normalized_images.shape[-2] * upscale, normalized_images.shape[-1] * upscale), mode="bicubic", antialias=True)
    else:
        normalized_images_up = normalized_images

    DiffSCI_Pipeline.pipeline = MSI_Pipeline(DiffSCI_Pipeline.cldm, DiffSCI_Pipeline.diffusion, cond_fn, DiffSCI_Pipeline.args.device)

    diffusion_output = DiffSCI_Pipeline.pipeline.run_stage2(
        clean=normalized_images_up, steps=steps, strength=1.0, upscale=upscale, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
        pos_prompt=pos_prompt, neg_prompt=neg_prompt, cfg_scale=cfg_scale, better_start=better_start
    )

    if upscale > 1.0:
        diffusion_output = F.interpolate(
            diffusion_output,
            size=(normalized_images.shape[-2], normalized_images.shape[-1]),
            mode="bicubic", antialias=True
        )

    diffusion_outputs = diffusion_output.contiguous().clamp(0, 1)

    restored_images = DiffSCI_Pipeline.decoder(diffusion_outputs * (max_val_channel - min_val_channel) + min_val_channel) + inputs_msi_lf

    return restored_images, diffusion_outputs

In [5]:
def parse_args() -> Namespace:
    parser = ArgumentParser()
    # model parameters
    parser.add_argument("--ckpt", type=str, default=str(BASE_DIR)+"/weights/controlnet_sample0160000.pt")
    parser.add_argument("--vae", type=str, default=str(BASE_DIR)+"/weights/vae_sample0012000.pt")
    parser.add_argument("--channel_vae", type=str, default=str(BASE_DIR)+"/weights/model_SeVAE_hf3_endecoder_c21_bu2_c9_DConvWoBN_resca_silu_2024-09-05_psnr49.5199.pt")
    # sampling parameters
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--better_start", type=bool, default=True)
    parser.add_argument("--upscale", type=int, default=1.0)
    parser.add_argument("--tiled", type=bool, default=False)
    parser.add_argument("--tile_size", type=int, default=512)
    parser.add_argument("--tile_stride", type=int, default=256)
    parser.add_argument("--pos_prompt", type=str, default="")
    parser.add_argument("--neg_prompt", type=str, default="")
    parser.add_argument("--cfg_scale", type=float, default=1.0)
    # input parameters
    parser.add_argument("--n_samples", type=int, default=1)
    # guidance parameters
    parser.add_argument("--guidance", type=bool, default=True)
    parser.add_argument("--g_scale", type=float, default=1)
    parser.add_argument("--g_t_start", type=int, default=400)
    parser.add_argument("--g_t_stop", type=int, default=-1)
    parser.add_argument("--g_space", type=str, default="rgb")
    parser.add_argument("--g_repeat", type=int, default=1)
    # output parameters
    # common parameters
    parser.add_argument("--seed", type=int, default=231)
    parser.add_argument("--output", type=str, default="./results/")
    parser.add_argument("--num_evals", type=int, default=300)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    return parser.parse_known_args()[0]

args = parse_args()
args.device = torch.device(args.device)
set_seed(args.seed)

PSRSCI_Pipeline = InferenceLoop_NoPre(args=args)

# 加载模型
SeVAE_model = ChannelVAE()
SeVAE_model = torch.load(args.channel_vae, map_location="cpu" ,weights_only=False)
SeVAE_model.eval().cuda()
PSRSCI_Pipeline.encoder = SeVAE_model.encoder
PSRSCI_Pipeline.decoder = SeVAE_model.decoder

Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is None and using 5 heads.
Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is None and using 5 heads.
Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is None and using 10 heads.
Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is None and using 10 heads.
Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.
Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up SDPCrossAttention (sdp). Query dim is 1280, context

In [10]:
from thop import profile, clever_format
import torch
import torch.nn.functional as F

# 创建测试输入
input_image = torch.zeros((1, 3, 1024, 1024), dtype=torch.float32, device=args.device)
pos_prompt = [""]

print("="*50)
print("开始计算FLOPs...")
print("="*50)

# ==================== 1. VAE编码 FLOPs ====================
print("\n[1/3] 计算VAE编码的FLOPs...")
with torch.no_grad():
    # 准备VAE编码输入
    vae_input = input_image * 2 - 1  # 归一化到[-1, 1]

    # 计算VAE编码的FLOPs
    flops_vae_encode, params_vae_encode = profile(
        PSRSCI_Pipeline.cldm.vae.encoder,
        inputs=(vae_input,),
        verbose=False
    )

    # 实际执行一次获取输出shape
    z_encoded = PSRSCI_Pipeline.cldm.vae_encode(input_image, sample=False)

flops_vae_encode_format, params_vae_encode_format = clever_format([flops_vae_encode, params_vae_encode], "%.3f")
print(f"VAE Encode - FLOPs: {flops_vae_encode_format}, Params: {params_vae_encode_format}")
print(f"VAE Encode - Latent shape: {z_encoded.shape}")

# ==================== 2. Diffusion FLOPs (10步然后乘以10) ====================
print("\n[2/3] 计算Diffusion的FLOPs (10步*10)...")

# 准备condition
with torch.no_grad():
    c_txt = PSRSCI_Pipeline.cldm.clip.encode(pos_prompt)
    c_img = z_encoded
    cond = {"c_txt": c_txt, "c_img": c_img}

    # 准备diffusion输入
    bs, _, h, w = z_encoded.shape
    x_noisy = torch.randn((bs, 4, h, w), dtype=torch.float32, device=args.device)
    t = torch.tensor([500], dtype=torch.long, device=args.device)

    # 设置control scales
    PSRSCI_Pipeline.cldm.control_scales = [1.0] * 13

    # 计算单步diffusion的FLOPs
    # 先计算controlnet
    flops_controlnet, params_controlnet = profile(
        PSRSCI_Pipeline.cldm.controlnet,
        inputs=(x_noisy, c_img, t, c_txt),
        verbose=False
    )

    # 执行controlnet获取control输出
    control = PSRSCI_Pipeline.cldm.controlnet(x=x_noisy, hint=c_img, timesteps=t, context=c_txt)
    control = [c * scale for c, scale in zip(control, PSRSCI_Pipeline.cldm.control_scales)]

    # 计算unet
    flops_unet, params_unet = profile(
        PSRSCI_Pipeline.cldm.unet,
        inputs=(x_noisy, t, c_txt, control, False),
        verbose=False
    )

    # 单步diffusion总FLOPs
    flops_diffusion_single = flops_controlnet + flops_unet

    # 100步的总FLOPs (通过10步乘以10估算)
    flops_diffusion_total = flops_diffusion_single * 10 * 10
    params_diffusion = params_controlnet + params_unet

flops_diffusion_format, params_diffusion_format = clever_format([flops_diffusion_total, params_diffusion], "%.3f")
flops_single_format = clever_format([flops_diffusion_single], "%.3f")[0]

print(f"Diffusion (单步) - FLOPs: {flops_single_format}")
print(f"Diffusion (100步) - FLOPs: {flops_diffusion_format}, Params: {params_diffusion_format}")
print(f"  - ControlNet FLOPs: {clever_format([flops_controlnet * 100], '%.3f')[0]}")
print(f"  - UNet FLOPs: {clever_format([flops_unet * 100], '%.3f')[0]}")

# ==================== 3. VAE解码 FLOPs ====================
print("\n[3/3] 计算VAE解码的FLOPs...")
with torch.no_grad():
    # 准备解码输入
    z_to_decode = z_encoded / PSRSCI_Pipeline.cldm.scale_factor

    # 计算VAE解码的FLOPs
    flops_vae_decode, params_vae_decode = profile(
        PSRSCI_Pipeline.cldm.vae.decoder,
        inputs=(z_to_decode,),
        verbose=False
    )

flops_vae_decode_format, params_vae_decode_format = clever_format([flops_vae_decode, params_vae_decode], "%.3f")
print(f"VAE Decode - FLOPs: {flops_vae_decode_format}, Params: {params_vae_decode_format}")

# ==================== 总结 ====================
print("\n" + "="*50)
print("FLOPs 统计总结")
print("="*50)

total_flops = flops_vae_encode + flops_diffusion_total + flops_vae_decode
total_params = params_vae_encode + params_diffusion + params_vae_decode

total_flops_format, total_params_format = clever_format([total_flops, total_params], "%.3f")

print(f"\n输入图像尺寸: 1x3x1024x1024")
print(f"Diffusion步数: 100步")
print(f"\n各模块FLOPs:")
print(f"  1. VAE编码:       {flops_vae_encode_format:>15} ({flops_vae_encode/total_flops*100:.2f}%)")
print(f"  2. Diffusion:     {flops_diffusion_format:>15} ({flops_diffusion_total/total_flops*100:.2f}%)")
print(f"  3. VAE解码:       {flops_vae_decode_format:>15} ({flops_vae_decode/total_flops*100:.2f}%)")
print(f"\n总计:")
print(f"  总FLOPs:         {total_flops_format}")
print(f"  总参数量:        {total_params_format}")
print(f"\n每步Diffusion的FLOPs: {flops_single_format}")

print("\n" + "="*50)

开始计算FLOPs...

[1/3] 计算VAE编码的FLOPs...
VAE Encode - FLOPs: 2.165T, Params: 34.147M
VAE Encode - Latent shape: torch.Size([1, 4, 128, 128])

[2/3] 计算Diffusion的FLOPs (10步*10)...
Diffusion (单步) - FLOPs: 1
Diffusion (100步) - FLOPs: 178.577T, Params: 1.229G
  - ControlNet FLOPs: 4
  - UNet FLOPs: 1

[3/3] 计算VAE解码的FLOPs...
VAE Decode - FLOPs: 4.960T, Params: 49.467M

FLOPs 统计总结

输入图像尺寸: 1x3x1024x1024
Diffusion步数: 100步

各模块FLOPs:
  1. VAE编码:                2.165T (1.17%)
  2. Diffusion:            178.577T (96.16%)
  3. VAE解码:                4.960T (2.67%)

总计:
  总FLOPs:         185.702T
  总参数量:        1.313G

每步Diffusion的FLOPs: 1



In [11]:
from torch.utils.flop_counter import FlopCounterMode
import torch
import torch.nn.functional as F

# 创建测试输入
input_image = torch.zeros((1, 3, 1024, 1024), dtype=torch.float32, device=args.device)
pos_prompt = [""]

print("="*50)
print("开始计算FLOPs...")
print("="*50)

# ==================== 1. VAE编码 FLOPs ====================
print("\n[1/3] 计算VAE编码的FLOPs...")
with torch.no_grad():
    # 准备VAE编码输入
    vae_input = input_image * 2 - 1  # 归一化到[-1, 1]

    # 使用FlopCounterMode计算VAE编码的FLOPs
    flop_counter_vae_encode = FlopCounterMode(PSRSCI_Pipeline.cldm.vae.encoder, display=False)
    with flop_counter_vae_encode:
        z_encoded = PSRSCI_Pipeline.cldm.vae_encode(input_image, sample=False)

    flops_vae_encode = flop_counter_vae_encode.get_total_flops()
    flops_vae_encode_giga = flops_vae_encode / 1e9

print(f"VAE Encode - FLOPs: {flops_vae_encode_giga:.3f} GFLOPs")
print(f"VAE Encode - Latent shape: {z_encoded.shape}")

# ==================== 2. Diffusion FLOPs (10步然后乘以10) ====================
print("\n[2/3] 计算Diffusion的FLOPs (10步*10)...")

# 准备condition
with torch.no_grad():
    c_txt = PSRSCI_Pipeline.cldm.clip.encode(pos_prompt)
    c_img = z_encoded
    cond = {"c_txt": c_txt, "c_img": c_img}

    # 准备diffusion输入
    bs, _, h, w = z_encoded.shape

    # 设置control scales
    PSRSCI_Pipeline.cldm.control_scales = [1.0] * 13

    # 计算10步diffusion的FLOPs
    flop_counter_controlnet = FlopCounterMode(PSRSCI_Pipeline.cldm.controlnet, display=False)
    flop_counter_unet = FlopCounterMode(PSRSCI_Pipeline.cldm.unet, display=False)

    flops_controlnet_total = 0
    flops_unet_total = 0

    # 运行10步
    for step in range(10):
        x_noisy = torch.randn((bs, 4, h, w), dtype=torch.float32, device=args.device)
        t = torch.tensor([step * 100], dtype=torch.long, device=args.device)

        # 计算controlnet
        with flop_counter_controlnet:
            control = PSRSCI_Pipeline.cldm.controlnet(x=x_noisy, hint=c_img, timesteps=t, context=c_txt)

        control = [c * scale for c, scale in zip(control, PSRSCI_Pipeline.cldm.control_scales)]

        # 计算unet
        with flop_counter_unet:
            eps = PSRSCI_Pipeline.cldm.unet(x=x_noisy, timesteps=t, context=c_txt, control=control, only_mid_control=False)

    flops_controlnet_10steps = flop_counter_controlnet.get_total_flops()
    flops_unet_10steps = flop_counter_unet.get_total_flops()

    # 估算100步的FLOPs
    flops_controlnet_100steps = flops_controlnet_10steps * 10
    flops_unet_100steps = flops_unet_10steps * 10
    flops_diffusion_total = flops_controlnet_100steps + flops_unet_100steps

    flops_controlnet_giga = flops_controlnet_100steps / 1e9
    flops_unet_giga = flops_unet_100steps / 1e9
    flops_diffusion_giga = flops_diffusion_total / 1e9
    flops_single_step_giga = (flops_controlnet_10steps + flops_unet_10steps) / 10 / 1e9

print(f"Diffusion (单步) - FLOPs: {flops_single_step_giga:.3f} GFLOPs")
print(f"Diffusion (100步) - FLOPs: {flops_diffusion_giga:.3f} GFLOPs")
print(f"  - ControlNet FLOPs: {flops_controlnet_giga:.3f} GFLOPs")
print(f"  - UNet FLOPs: {flops_unet_giga:.3f} GFLOPs")

# ==================== 3. VAE解码 FLOPs ====================
print("\n[3/3] 计算VAE解码的FLOPs...")
with torch.no_grad():
    # 准备解码输入 - 使用随机latent而不是之前的结果
    z_to_decode = torch.randn_like(z_encoded) / PSRSCI_Pipeline.cldm.scale_factor

    # 使用FlopCounterMode计算VAE解码的FLOPs
    flop_counter_vae_decode = FlopCounterMode(PSRSCI_Pipeline.cldm.vae.decoder, display=False)
    with flop_counter_vae_decode:
        decoded_image = PSRSCI_Pipeline.cldm.vae_decode(z_to_decode)

    flops_vae_decode = flop_counter_vae_decode.get_total_flops()
    flops_vae_decode_giga = flops_vae_decode / 1e9

print(f"VAE Decode - FLOPs: {flops_vae_decode_giga:.3f} GFLOPs")

# ==================== 总结 ====================
print("\n" + "="*50)
print("FLOPs 统计总结")
print("="*50)

total_flops = flops_vae_encode + flops_diffusion_total + flops_vae_decode
total_flops_giga = total_flops / 1e9

print(f"\n输入图像尺寸: 1x3x1024x1024")
print(f"Diffusion步数: 100步")
print(f"\n各模块FLOPs:")
print(f"  1. VAE编码:       {flops_vae_encode_giga:>10.3f} GFLOPs ({flops_vae_encode/total_flops*100:>6.2f}%)")
print(f"  2. Diffusion:     {flops_diffusion_giga:>10.3f} GFLOPs ({flops_diffusion_total/total_flops*100:>6.2f}%)")
print(f"     - ControlNet:  {flops_controlnet_giga:>10.3f} GFLOPs ({flops_controlnet_100steps/total_flops*100:>6.2f}%)")
print(f"     - UNet:        {flops_unet_giga:>10.3f} GFLOPs ({flops_unet_100steps/total_flops*100:>6.2f}%)")
print(f"  3. VAE解码:       {flops_vae_decode_giga:>10.3f} GFLOPs ({flops_vae_decode/total_flops*100:>6.2f}%)")
print(f"\n总计:")
print(f"  总FLOPs:         {total_flops_giga:>10.3f} GFLOPs")
print(f"\n每步Diffusion的FLOPs: {flops_single_step_giga:.3f} GFLOPs")

print("\n" + "="*50)

# 可选：保存结果到字典
flops_results = {
    "vae_encode_gflops": flops_vae_encode_giga,
    "diffusion_total_gflops": flops_diffusion_giga,
    "controlnet_gflops": flops_controlnet_giga,
    "unet_gflops": flops_unet_giga,
    "vae_decode_gflops": flops_vae_decode_giga,
    "total_gflops": total_flops_giga,
    "single_step_gflops": flops_single_step_giga,
}

print("\n结果字典:")
for key, value in flops_results.items():
    print(f"  {key}: {value:.3f}")

开始计算FLOPs...

[1/3] 计算VAE编码的FLOPs...
VAE Encode - FLOPs: 4878.951 GFLOPs
VAE Encode - Latent shape: torch.Size([1, 4, 128, 128])

[2/3] 计算Diffusion的FLOPs (10步*10)...


  flop_counter_vae_encode = FlopCounterMode(PSRSCI_Pipeline.cldm.vae.encoder, display=False)
  flop_counter_controlnet = FlopCounterMode(PSRSCI_Pipeline.cldm.controlnet, display=False)
  flop_counter_unet = FlopCounterMode(PSRSCI_Pipeline.cldm.unet, display=False)


Diffusion (单步) - FLOPs: 633.405 GFLOPs
Diffusion (100步) - FLOPs: 63340.536 GFLOPs
  - ControlNet FLOPs: 16590.633 GFLOPs
  - UNet FLOPs: 46749.902 GFLOPs

[3/3] 计算VAE解码的FLOPs...
VAE Decode - FLOPs: 10470.393 GFLOPs

FLOPs 统计总结

输入图像尺寸: 1x3x1024x1024
Diffusion步数: 100步

各模块FLOPs:
  1. VAE编码:         4878.951 GFLOPs (  6.20%)
  2. Diffusion:      63340.536 GFLOPs ( 80.49%)
     - ControlNet:   16590.633 GFLOPs ( 21.08%)
     - UNet:         46749.902 GFLOPs ( 59.41%)
  3. VAE解码:        10470.393 GFLOPs ( 13.31%)

总计:
  总FLOPs:          78689.879 GFLOPs

每步Diffusion的FLOPs: 633.405 GFLOPs


结果字典:
  vae_encode_gflops: 4878.951
  diffusion_total_gflops: 63340.536
  controlnet_gflops: 16590.633
  unet_gflops: 46749.902
  vae_decode_gflops: 10470.393
  total_gflops: 78689.879
  single_step_gflops: 633.405


  flop_counter_vae_decode = FlopCounterMode(PSRSCI_Pipeline.cldm.vae.decoder, display=False)


各模块FLOPs:
  1. VAE编码:                2.165T (1.17%)
  2. Diffusion:            178.577T (96.16%)
  3. VAE解码:                4.960T (2.67%)

总计:
  总FLOPs:         185.702T
  总参数量:        1.313G