Skip to content

About the loss of image quality after adding diffusion #4

@Spring-lovely

Description

@Spring-lovely

Hi, Dear author
Thank you so much for your open source work. I have the following questions when running the code, I hope you can take some time to answer them. I'm 6 frames predict 6 frames.

1715245073046

The first line of the following picture is the input frame, the second line is the label, the third line is the final prediction, the fourth line is backbone output, and now the diff is added, but the image quality is significantly decreased (see line 3), I do not know whether there is a problem with the function of the denoising part, please help to see.

Only add the following functions to diffcast.py
` def predict(self, frames_in, compute_loss=False, frames_gt=None, **kwargs):
T_out = default(kwargs.get('T_out'), 6)

    if compute_loss:
        B, T_in, c, h, w = frames_in.shape
        device = self.device

        backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt,
                                                                   compute_loss=compute_loss, **kwargs)

        residual = frames_gt - backbone_output
        global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))

        pre_frag = frames_in
        pre_mu = None
        pred_ress = []
        diff_loss = 0.
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        for frag_idx in range(T_out // T_in):
            mu = backbone_output[:, frag_idx * T_in : (frag_idx + 1) * T_in]
            res = residual[:, frag_idx * T_in : (frag_idx + 1) * T_in]

            cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)
            res_pred, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
                                                 idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
            diff_loss += noise_loss

            frag_pred = res_pred + mu
            pre_frag = frag_pred
            pre_mu = mu

        alpha = torch.tensor(0.5)
        loss = (1 - alpha) * backbone_loss + alpha * diff_loss / 3.

        #backbone_output = self.unnormalize(backbone_output)
        return backbone_output, loss
    else:
        pred, mu, y = self.sample(frames_in=frames_in, T_out=T_out)
        loss = None
        backbone_loss = None
        diff_loss = None

        # return pred, mu, y, loss
        return pred, mu

def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
    b, _, c, h, w = x_start.shape

    noise = default(noise, lambda: torch.randn_like(x_start))

    # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
    offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

    if offset_noise_strength > 0.:
        offset_noise = torch.randn(x_start.shape[:2], device = self.device)
        noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

    # noise sample
    x = self.predict_v(x_start=x_start, t=t, noise=noise)
    model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)

    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    elif self.objective == 'pred_v':
        v = self.predict_v(x_start, t, noise)
        target = v
    else:
        raise ValueError(f'unknown objective {self.objective}')

    loss = F.mse_loss(model_out, target, reduction = 'none')
    loss = reduce(loss, 'b ... -> b', 'mean')

    loss = loss * extract(self.loss_weight, t, loss.shape)
    return model_out, loss.mean()

`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions