
question about training input #52

Open
D222097 opened this issue Jan 8, 2024 · 0 comments

D222097 commented Jan 8, 2024

Nice work!
I'm wondering why the inputs can be set this way during training:

| image_GT (?) | inpaint_image | inpaint_mask | ref_imgs |
|---|---|---|---|
| img | masked_img | msk | ref_img |

In this work, I found that the inputs are the GT image (with noise added), masked_img, mask, and ref_img. As shown below, the input x_start to the UNet is the concatenation of z (the encoded GT), z_inpaint (the encoded masked_img), and mask_resize (the downsampled mask):

z_new = torch.cat((z, z_inpaint, mask_resize), dim=1)  # x_start: 4 + 4 + 1 = 9 channels

def p_losses(self, x_start, cond, t, noise=None):
    if self.first_stage_key == 'inpaint':
        # x_start=x_start[:,:4,:,:]
        # noise and q_sample act only on the first 4 channels (the GT latent z)
        noise = default(noise, lambda: torch.randn_like(x_start[:, :4, :, :]))
        x_noisy = self.q_sample(x_start=x_start[:, :4, :, :], t=t, noise=noise)
        # the conditioning channels (z_inpaint, mask_resize) are re-attached unnoised
        x_noisy = torch.cat((x_noisy, x_start[:, 4:, :, :]), dim=1)
    ...
    model_output = self.apply_model(x_noisy, t, cond)
    ...

    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        target = noise
    else:
        raise NotImplementedError()

    loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
    ...
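To make sure I'm reading the channel layout correctly, here is a minimal standalone sketch of what I understand the snippet to do (the shapes are my assumption of a typical 64x64 latent; the `0.1 *` scaling is just a stand-in for `q_sample` at a large timestep, not the real schedule):

```python
import torch

# Toy shapes: 4-channel GT latent z, 4-channel masked-image latent, 1-channel mask.
z         = torch.randn(2, 4, 64, 64)   # encoded GT (the only part that gets noised)
z_inpaint = torch.randn(2, 4, 64, 64)   # encoded masked image (clean conditioning)
mask      = torch.rand(2, 1, 64, 64)    # downsampled mask (clean conditioning)

x_start = torch.cat((z, z_inpaint, mask), dim=1)   # 9 channels, as in the snippet
noise   = torch.randn_like(x_start[:, :4])         # noise matches only z's shape

# Mimic the p_losses logic: noise the first 4 channels, keep the rest untouched.
x_noisy = 0.1 * x_start[:, :4] + noise             # stand-in for q_sample at large t
x_noisy = torch.cat((x_noisy, x_start[:, 4:]), dim=1)

assert x_noisy.shape == (2, 9, 64, 64)
# The conditioning channels pass through to the UNet unchanged:
assert torch.equal(x_noisy[:, 4:], x_start[:, 4:])
```

So only the GT latent is corrupted; the masked image and mask stay clean as conditioning.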

I am curious why the GT image can be fed into the UNet directly. Even though noise has been added to it, it is still visible to the UNet.

Using the images above as an example: during training the input is the car image and the expected output is also the car image. But at inference time, the input is an image unrelated to the car (an arbitrary object or just background), yet the expected output is still the car image.

This is a little strange. On the one hand, the model needs the GT to be optimized, but in other generative models the GT is usually used as a target rather than as a direct input to the model. On the other hand, a diffusion model usually predicts Gaussian noise rather than pixels, so there seems to be no other way for the GT to constrain the model. I don't understand how the model learns; I'd be grateful if anyone could give me advice.
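My current (possibly wrong) understanding is that with eps-parameterization the GT never leaks trivially: the regression target is the noise, and x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps mixes x_0 and eps together, so predicting eps from x_t is equivalent to separating out x_0. A toy check, using a standard DDPM linear beta schedule (the schedule values here are my assumption, not from this repo):

```python
import torch

# Standard DDPM-style linear beta schedule (assumed values, 1000 steps).
betas = torch.linspace(1e-4, 2e-2, 1000)
a_bar = torch.cumprod(1.0 - betas, dim=0)

x0  = torch.randn(4, 64, 64)            # stands in for the GT latent
eps = torch.randn_like(x0)

t = 500
# Forward process: x_t is a mix of GT signal and noise.
x_t = a_bar[t].sqrt() * x0 + (1.0 - a_bar[t]).sqrt() * eps

# A "cheating" model that just echoes its noised input still gets a
# nonzero loss, because x_t != eps; driving the loss to zero forces the
# model to actually separate signal (x_0) from noise (eps).
cheat_loss  = torch.mean((x_t - eps) ** 2)
oracle_loss = torch.mean((eps - eps) ** 2)
assert cheat_loss.item() > oracle_loss.item()

# And at the last timestep almost no GT signal survives in the input:
assert a_bar[-1].sqrt().item() < 0.01   # less than 1% of x_0 remains
```

So even though a noised GT enters the UNet, the loss only rewards the model for recovering the noise component, which is what it needs at sampling time when the GT is absent. I may be misreading the training setup, though.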
