The first line of the following picture is the input frame, the second line is the label, the third line is the final prediction, the fourth line is backbone output, and now the diff is added, but the image quality is significantly decreased (see line 3), I do not know whether there is a problem with the function of the denoising part, please help to see.
if compute_loss:
B, T_in, c, h, w = frames_in.shape
device = self.device
backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt,
compute_loss=compute_loss, **kwargs)
residual = frames_gt - backbone_output
global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))
pre_frag = frames_in
pre_mu = None
pred_ress = []
diff_loss = 0.
t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
for frag_idx in range(T_out // T_in):
mu = backbone_output[:, frag_idx * T_in : (frag_idx + 1) * T_in]
res = residual[:, frag_idx * T_in : (frag_idx + 1) * T_in]
cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)
res_pred, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
diff_loss += noise_loss
frag_pred = res_pred + mu
pre_frag = frag_pred
pre_mu = mu
alpha = torch.tensor(0.5)
loss = (1 - alpha) * backbone_loss + alpha * diff_loss / 3.
#backbone_output = self.unnormalize(backbone_output)
return backbone_output, loss
else:
pred, mu, y = self.sample(frames_in=frames_in, T_out=T_out)
loss = None
backbone_loss = None
diff_loss = None
# return pred, mu, y, loss
return pred, mu
def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
b, _, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))
# offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
if offset_noise_strength > 0.:
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
# noise sample
x = self.predict_v(x_start=x_start, t=t, noise=noise)
model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)
if self.objective == 'pred_noise':
target = noise
elif self.objective == 'pred_x0':
target = x_start
elif self.objective == 'pred_v':
v = self.predict_v(x_start, t, noise)
target = v
else:
raise ValueError(f'unknown objective {self.objective}')
loss = F.mse_loss(model_out, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')
loss = loss * extract(self.loss_weight, t, loss.shape)
return model_out, loss.mean()
Hi, Dear author
Thank you so much for your open source work. I have the following questions when running the code, I hope you can take some time to answer them. I'm 6 frames predict 6 frames.
The first line of the following picture is the input frame, the second line is the label, the third line is the final prediction, the fourth line is backbone output, and now the diff is added, but the image quality is significantly decreased (see line 3), I do not know whether there is a problem with the function of the denoising part, please help to see.
Only add the following functions to diffcast.py
` def predict(self, frames_in, compute_loss=False, frames_gt=None, **kwargs):
T_out = default(kwargs.get('T_out'), 6)
`