# Livestream topic: Classifier-based Guided Diffusion, a walkthrough of the paper and its core sampling code
Reference: "Diffusion Models Beat GANs on Image Synthesis". Accompanying code: search GitHub for guided-diffusion

From the Bilibili uploader deep_thoughts, in the series "PyTorch Source Code Tutorials and Walkthroughs of Cutting-Edge AI Algorithm Reproductions"

P_66_Classifier_Guided_Diffusion: a detailed walkthrough of the conditional diffusion model paper and its PyTorch code:

https://www.bilibili.com/video/BV1m84y1e7hP/?spm_id_from=333.788&vd_source=18e91d849da09d846f771c89a366ed40

## The UNet and Diffusion code was covered several times in earlier livestreams; the end of this session only walks through the PyTorch code for classifier-based sampling

# Part 1 Background
## Likelihood function of the multivariate Gaussian distribution

![](./img/P66_1.png)

![](./img/P66_2.png)
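
The likelihood discussed above can be checked numerically. The following is a minimal NumPy sketch, not code from the livestream: the helper name `gaussian_log_likelihood` and the toy data are assumptions made for illustration. It evaluates the log-likelihood of a data set and shows that the sample mean and the divide-by-n sample covariance score at least as well as a perturbed parameter choice.

```python
import numpy as np

def gaussian_log_likelihood(x, mu, sigma):
    """Sum of log N(x_i | mu, sigma) over the rows of x (shape: n x d)."""
    n, d = x.shape
    diff = x - mu
    # Solve sigma @ z = diff^T rather than forming the explicit inverse.
    solved = np.linalg.solve(sigma, diff.T).T
    mahalanobis = np.sum(diff * solved, axis=1)
    _, logdet = np.linalg.slogdet(sigma)
    return -0.5 * np.sum(mahalanobis + logdet + d * np.log(2.0 * np.pi))

rng = np.random.default_rng(0)
mixing = np.array([[1.0, 0.2, 0.0], [0.0, 1.0, 0.3], [0.0, 0.0, 1.0]])
x = rng.normal(size=(500, 3)) @ mixing + 1.5  # toy correlated data

# Maximum-likelihood estimates: sample mean and biased (divide-by-n) covariance.
mu_mle = x.mean(axis=0)
sigma_mle = np.cov(x, rowvar=False, bias=True)

ll_mle = gaussian_log_likelihood(x, mu_mle, sigma_mle)
ll_shifted_mean = gaussian_log_likelihood(x, mu_mle + 0.1, sigma_mle)
ll_scaled_cov = gaussian_log_likelihood(x, mu_mle, 1.1 * sigma_mle)
```

Both perturbed log-likelihoods come out lower than `ll_mle`, which is what the maximum-likelihood derivation predicts.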

## Covariance matrix of the multivariate Gaussian distribution: symmetric and positive semidefinite

![](./img/P66_3.png)
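
Both properties are easy to verify on an empirical covariance matrix. This is a small NumPy sketch under assumed toy data (the sizes and the tolerance `1e-10` are arbitrary choices, not from the livestream). It also checks the identity behind positive semidefiniteness: the quadratic form v^T Σ v is the variance of the projection v^T x, so it cannot be negative.

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(size=(200, 4))      # 200 toy samples, 4 dimensions
sigma = np.cov(samples, rowvar=False)    # unbiased (ddof=1) sample covariance

is_symmetric = np.allclose(sigma, sigma.T)

# eigvalsh assumes a symmetric matrix and returns real eigenvalues.
eigenvalues = np.linalg.eigvalsh(sigma)
is_psd = bool(np.all(eigenvalues >= -1e-10))

# v^T Sigma v equals the variance of the 1-D projection samples @ v.
v = rng.normal(size=4)
quadratic_form = float(v @ sigma @ v)
projected_variance = float(np.var(samples @ v, ddof=1))
```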

## How do we evaluate a generative model?
### Two main goals

Quality: how realistic the samples are

Diversity: how varied the samples are

### Objective metric 1: Inception Score (IS)

![](./img/P66_4.png)

In [None]:
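# Implementing IS in code
from numpy import exp, expand_dims, log, mean

def calculate_inception_score(p_yx, eps=1E-16):
    # p_yx: array of shape (num_images, num_classes), row i holds p(y|x_i)
    # calculate the marginal p(y) by averaging over images
    p_y = expand_dims(p_yx.mean(axis=0), 0)
    # kl divergence for each image
    kl_d = p_yx * (log(p_yx + eps) - log(p_y + eps))
    # sum over classes
    sum_kl_d = kl_d.sum(axis=1)
    # average over images
    avg_kl_d = mean(sum_kl_d)
    # undo the logs
    is_score = exp(avg_kl_d)
    return is_score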
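
To see how IS rewards both quality and diversity, two extreme cases are worth running. The sketch below re-declares the same computation under the hypothetical name `inception_score` so that it runs on its own; the 3-class toy predictions are assumptions for illustration. Confident predictions spread evenly over the classes give a score close to the number of classes, while uninformative predictions give a score close to 1.

```python
import numpy as np

def inception_score(p_yx, eps=1e-16):
    """exp of the mean KL(p(y|x) || p(y)); p_yx has shape (num_images, num_classes)."""
    p_y = p_yx.mean(axis=0, keepdims=True)
    kl = np.sum(p_yx * (np.log(p_yx + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

# Each image is classified with full confidence, and every class appears once.
confident_and_diverse = np.eye(3)
# Each image gets the same uniform prediction: no class information at all.
uninformative = np.full((3, 3), 1.0 / 3.0)

is_best = inception_score(confident_and_diverse)
is_worst = inception_score(uninformative)
```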

### Objective metric 2: Fréchet Inception Distance (FID)

![](./img/P66_5.png)

In [None]:
# Implementing FID in code
import numpy
from numpy import cov, iscomplexobj, trace
from scipy.linalg import sqrtm

def calculate_fid(act1, act2):
    # act1, act2: Inception activations, one row per image
    # calculate mean and covariance statistics
    mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)
    # calculate sum squared difference between means
    ssdiff = numpy.sum((mu1 - mu2)**2.0)
    # calculate sqrt of product between cov
    covmean = sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid
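
A quick sanity check of the FID formula on synthetic activations follows. This is a sketch, not the livestream's code: the function name `fid_from_activations` and the random "activations" are assumptions, and instead of `scipy.linalg.sqrtm` it uses an eigenvalue identity (the trace of the matrix square root of Σ1 Σ2 equals the sum of the square roots of its eigenvalues, which are real and nonnegative for covariance matrices) so that only NumPy is needed.

```python
import numpy as np

def fid_from_activations(act1, act2):
    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2)
    # Swapped-in technique: Tr(sqrt(S1 S2)) via eigenvalues instead of sqrtm.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2).real
    tr_covmean = np.sum(np.sqrt(np.clip(eigvals, 0.0, None)))
    return float(ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)

rng = np.random.default_rng(0)
real = rng.normal(size=(1000, 8))        # stand-in for real-image activations
fake_close = rng.normal(size=(1000, 8))  # same distribution, different draws
fake_shifted = fake_close + 2.0          # same shape, shifted mean

fid_same = fid_from_activations(real, real)
fid_close = fid_from_activations(real, fake_close)
fid_shifted = fid_from_activations(real, fake_shifted)
```

Identical sets give a distance of essentially zero, samples from the same distribution give a small value, and shifting the mean makes the distance grow, which matches the reading of FID as "lower is better".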

### Other objective metrics: Precision and Recall

![](./img/P66_6.png)

# Part 2 Paper abstract and improvements to the baseline model

![](./img/P66_7.png)

![](./img/P66_8.png)

![](./img/P66_9.png)

![](./img/P66_10.png)

![](./img/P66_11.png)

![](./img/P66_12.png)

![](./img/P66_13.png)

![](./img/P66_14.png)

![](./img/P66_15.png)

![](./img/P66_16.png)

![](./img/P66_17.png)

![](./img/P66_18.png)

# Part 3 Classifier-based conditional sampling: principle and results
## Results first

![](./img/P66_19.png)

![](./img/P66_20.png)

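### Next, we show that when y is not conditioned on, q̂ behaves exactly like q for the noising (forward) conditional distribution, the joint distribution, and the marginal distribution; it then follows that the reverse-diffusion conditional distributions are also identical.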

![](./img/P66_21.png)

![](./img/P66_22.png)

![](./img/P66_23.png)
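
The figures above carry the full derivation. As a compact reference, the key steps can be written out as follows. This is a sketch based on the definitions of q̂ used in the paper's appendix; the figures remain the authoritative version.

```latex
% Definition of the conditional noising process \hat{q}
\hat{q}(x_0) := q(x_0), \qquad
\hat{q}(x_{t+1} \mid x_t, y) := q(x_{t+1} \mid x_t), \qquad
\hat{q}(x_{1:T} \mid x_0, y) := \prod_{t=1}^{T} \hat{q}(x_t \mid x_{t-1}, y)

% Marginalising out y recovers the unconditional forward step
\hat{q}(x_{t+1} \mid x_t)
  = \int_y \hat{q}(x_{t+1} \mid x_t, y)\, \hat{q}(y \mid x_t)\, dy
  = q(x_{t+1} \mid x_t) \int_y \hat{q}(y \mid x_t)\, dy
  = q(x_{t+1} \mid x_t)

% The same argument gives the joint and the marginal
\hat{q}(x_{1:T} \mid x_0) = q(x_{1:T} \mid x_0), \qquad
\hat{q}(x_t) = q(x_t)

% Bayes' rule then gives the reverse step
\hat{q}(x_t \mid x_{t+1})
  = \frac{\hat{q}(x_{t+1} \mid x_t)\, \hat{q}(x_t)}{\hat{q}(x_{t+1})}
  = q(x_t \mid x_{t+1})
```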

### Given q̂(y|x_0), what properties does q̂(y|x_t) have? This also lays the groundwork for deriving q̂(x_t|x_{t+1}, y)

![](./img/P66_24.png)
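
The property this step relies on is that, once x_t is known, x_{t+1} carries no extra information about the label. A short worked version, written as a sketch under the assumption that q̂(x_{t+1}|x_t, y) = q̂(x_{t+1}|x_t) as in the paper's setup:

```latex
\hat{q}(y \mid x_t, x_{t+1})
  = \frac{\hat{q}(x_{t+1} \mid x_t, y)\, \hat{q}(y \mid x_t)}{\hat{q}(x_{t+1} \mid x_t)}
  = \frac{\hat{q}(x_{t+1} \mid x_t)\, \hat{q}(y \mid x_t)}{\hat{q}(x_{t+1} \mid x_t)}
  = \hat{q}(y \mid x_t)
```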

### q(x_t|x_{t+1}) has already been trained, so only the classifier q̂(y|x_t) remains to be trained. Next, let's see how to sample step by step from q̂(x_t|x_{t+1}, y)

![](./img/P66_25.png)

![](./img/P66_26.png)

![](./img/P66_27.png)

![](./img/P66_28.png)

![](./img/P66_29.png)

![](./img/P66_30.png)

![](./img/P66_31.png)

![](./img/P66_32.png)
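
The result of the derivation in the figures above can be summarised in a few lines. This is a sketch following the paper's notation, where μ and Σ are the mean and covariance predicted by the diffusion model at the current step, g is the classifier gradient evaluated at μ, and s is the guidance scale; the figures contain the complete argument.

```latex
% Conditional reverse step as a product of the unconditional step and the classifier
p_{\theta,\phi}(x_t \mid x_{t+1}, y) = Z\, p_\theta(x_t \mid x_{t+1})\, p_\phi(y \mid x_t)

% Unconditional step is Gaussian
\log p_\theta(x_t \mid x_{t+1}) = -\tfrac{1}{2} (x_t - \mu)^{T} \Sigma^{-1} (x_t - \mu) + C_1

% First-order Taylor expansion of the classifier around x_t = \mu
\log p_\phi(y \mid x_t) \approx (x_t - \mu)^{T} g + C_2, \qquad
g = \nabla_{x_t} \log p_\phi(y \mid x_t) \big|_{x_t = \mu}

% Completing the square
\log\!\big( p_\theta(x_t \mid x_{t+1})\, p_\phi(y \mid x_t) \big)
  \approx -\tfrac{1}{2} (x_t - \mu - \Sigma g)^{T} \Sigma^{-1} (x_t - \mu - \Sigma g) + C_3

% Hence, with guidance scale s
x_t \sim \mathcal{N}(\mu + s\, \Sigma\, g,\; \Sigma)
```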

# Part 4 Walkthrough of the core sampling code in Guided Diffusion

In [None]:
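import torch as th
import torch.nn.functional as F

# classifier, model, args, diffusion, logger, dist_util and NUM_CLASSES
# are defined elsewhere in the sampling script.

def cond_fn(x, t, y=None):
    # Returns classifier_scale * grad_x log p(y | x_t, t)
    assert y is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        # pick the log-probability of the target class for each sample
        selected = log_probs[range(len(logits)), y.view(-1)]
        return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
    
def model_fn(x, t, y=None):
    # Pass the label to the diffusion model only if it is class-conditional
    assert y is not None
    return model(x, t, y if args.class_cond else None)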

logger.log("sampling...")
all_images = []
all_labels = []

desp = f"scale={args.classifier_scale}"
if args.use_ddim:
    desp += "_ddim"
    
while len(all_images) * args.batch_size < args.num_samples:
    model_kwargs = {}
    classes = th.randint(
        low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
    )
    model_kwargs["y"] = classes
    sample_fn = (
        diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
    )
    sample = sample_fn(
        model_fn,
        (args.batch_size, 3, args.image_size, args.image_size),
        clip_denoised=args.clip_denoised,
        model_kwargs=model_kwargs,
        cond_fn=cond_fn,
        device=dist_util.dev(),
    )
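    # map model output from [-1, 1] to uint8 pixel values in [0, 255]
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
    # NCHW -> NHWC so each sample is laid out as an image
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous()
    # (excerpt ends here; the loop only terminates once the generated
    # samples are appended to all_images, which is not shown in this cell)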
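
What `cond_fn` returns is the gradient of log p(y|x) with respect to the input, times a scale. The sketch below illustrates this quantity without PyTorch, under a strong simplifying assumption: the classifier is a hypothetical linear model (`logits = weights @ x`), for which the gradient has a closed form. The names `toy_cond_grad` and `weights` and all sizes are made up for illustration. The closed form is compared against central finite differences, and a small step along the gradient is shown to raise the target log-probability.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max()
    return shifted - np.log(np.sum(np.exp(shifted)))

def toy_cond_grad(x, y, weights, scale=1.0):
    """scale * d/dx log p(y | x) for a linear classifier with logits = weights @ x."""
    probs = np.exp(log_softmax(weights @ x))
    # d/dx log_softmax(Wx)[y] = W[y] - sum_k p_k W[k]
    return scale * (weights[y] - probs @ weights)

rng = np.random.default_rng(0)
weights = rng.normal(size=(5, 4))  # 5 classes, 4 input dimensions (hypothetical)
x = rng.normal(size=4)
y = 2                              # target class
grad = toy_cond_grad(x, y, weights)

# Central finite differences as an independent check of the gradient.
eps = 1e-6
numeric = np.array([
    (log_softmax(weights @ (x + eps * e))[y]
     - log_softmax(weights @ (x - eps * e))[y]) / (2 * eps)
    for e in np.eye(4)
])

# Moving x a little along the gradient increases log p(y | x).
log_prob_before = log_softmax(weights @ x)[y]
log_prob_after = log_softmax(weights @ (x + 0.01 * grad))[y]
```

In the real code the classifier is a neural network that also takes the timestep, so the gradient comes from `th.autograd.grad` rather than a closed form.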

In [None]:
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
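    noise = th.randn_like(x)
    # shape (N, 1, 1, ...) so the mask broadcasts over all non-batch dims
    nonzero_mask = (
        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    )  # no noise when t == 0
    if cond_fn is not None:
        # classifier guidance: shift the predicted mean along the classifier gradient
        out["mean"] = self.condition_mean(
            cond_fn, out, x, t, model_kwargs=model_kwargs
        )
    # x_{t-1} = mean + sigma * noise, with sigma = exp(0.5 * log_variance)
    sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
    return {"sample": sample, "pred_xstart": out["pred_xstart"]}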
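
The `nonzero_mask` trick is easier to see on a tiny batch. This NumPy sketch mirrors the last lines of `p_sample` under assumed toy values (a zero mean, a constant log-variance, and three samples at different timesteps); none of these numbers come from the repository. The sample at t == 0 receives no noise, while the others receive noise scaled by the standard deviation.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 2, 4, 4))        # toy batch in (N, C, H, W) layout
t = np.array([5, 0, 1])                  # one timestep per sample
mean = np.zeros_like(x)                  # stand-in for out["mean"]
log_variance = np.full_like(x, np.log(0.25))

# Reshape (N,) -> (N, 1, 1, 1) so the mask broadcasts over C, H, W.
nonzero_mask = (t != 0).astype(x.dtype).reshape(-1, *([1] * (x.ndim - 1)))

noise = rng.normal(size=x.shape)
sample = mean + nonzero_mask * np.exp(0.5 * log_variance) * noise
```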

In [None]:
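def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    # gradient = classifier_scale * grad_x log p(y | x_t), as returned by cond_fn
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    # guided mean: mu + Sigma * gradient (elementwise, diagonal covariance)
    new_mean = (
        p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    )
    return new_mean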
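
Why adding `variance * gradient` to the mean is the right shift can be checked numerically. The sketch below assumes a diagonal Gaussian and a fixed gradient vector standing in for the scaled classifier gradient; the helper names `guided_mean` and `log_normal_diag` are hypothetical. It verifies that multiplying the Gaussian density by exp(gradient · x) yields, up to a constant that does not depend on x, a Gaussian with the shifted mean and the same variance.

```python
import numpy as np

def guided_mean(mean, variance, gradient):
    """Shift the reverse-process mean by variance * gradient (diagonal covariance)."""
    return mean + variance * gradient

def log_normal_diag(x, mean, variance):
    return -0.5 * np.sum((x - mean) ** 2 / variance + np.log(2.0 * np.pi * variance))

rng = np.random.default_rng(0)
mean = rng.normal(size=4)
variance = rng.uniform(0.1, 0.5, size=4)
gradient = rng.normal(size=4)  # stands in for scale * grad log p(y | x_t)
new_mean = guided_mean(mean, variance, gradient)

# log N(x; mean, var) + gradient . x  minus  log N(x; new_mean, var)
# should be the same constant at every evaluation point x.
points = rng.normal(size=(6, 4))
offsets = np.array([
    log_normal_diag(p, mean, variance) + gradient @ p
    - log_normal_diag(p, new_mean, variance)
    for p in points
])
expected_offset = gradient @ mean + 0.5 * np.sum(variance * gradient ** 2)
```

In the actual sampler the linear term comes from a first-order Taylor expansion of the classifier's log-probability, so the equivalence holds only approximately there.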