Sampling #13

Open
ChristophJud opened this issue Jan 30, 2020 · 10 comments

Comments

@ChristophJud

ChristophJud commented Jan 30, 2020

Hi, thank you for sharing your code about the LSA model. Having trained the full model, I'm wondering how I can sample from the LSA? Can I just draw a uniform random vector as z and put it into the CPD estimator?

@DavideA
Contributor

DavideA commented Jan 30, 2020

Hi Christoph,

you can sample from the model by sampling a latent vector from the prior (with ancestral sampling) and then decoding it. A snippet should look like this:

from typing import Tuple

import torch


def generate_images(model: LSACIFAR10, how_many: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # LSACIFAR10 is the CIFAR-10 model class defined in this repository.
    # Start from all-zero codes and fill them in one dimension at a time.
    z_samples = torch.zeros(how_many, model.code_length, dtype=torch.float).to('cuda')
    model.eval()
    with torch.no_grad():
        # Ancestral sampling: each code dimension is drawn conditioned on the ones before it.
        for i in range(0, model.code_length):
            z_dist = model.estimator(z_samples)
            probs = torch.softmax(z_dist[:, :, i]).data
            z_samples[:, i] = torch.multinomial(probs, 1).float().squeeze(1) / model.cpd_channels

        # Decode the sampled latent vectors into images.
        x_samples = model.decoder(z_samples)
    return z_samples, x_samples

In principle, you can also sample from video models, such as the one trained on ShanghaiTech, but the sampling of latent vectors will need two nested loops, since the code has a temporal axis as well (see the sketch below).
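A rough sketch of what those two loops could look like; the time_steps attribute and the (batch, cpd_channels, code_length, time) layout of the estimator's output are assumptions here, not the repository's exact API:

from typing import Tuple

import torch


def generate_videos(model, how_many: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Assumed latent layout: (batch, code_length, time_steps). Both the feature
    # axis and the temporal axis are sampled autoregressively, hence two loops.
    z_samples = torch.zeros(how_many, model.code_length, model.time_steps,
                            dtype=torch.float).to('cuda')
    model.eval()
    with torch.no_grad():
        for t in range(model.time_steps):
            for i in range(model.code_length):
                z_dist = model.estimator(z_samples)
                probs = torch.softmax(z_dist[:, :, i, t], dim=1)
                z_samples[:, i, t] = torch.multinomial(probs, 1).float().squeeze(1) / model.cpd_channels

        x_samples = model.decoder(z_samples)
    return z_samples, x_samples

Whether the temporal or the feature axis comes first in the ordering has to match the estimator's masking, so the two loops may need to be swapped.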

Let me know how it goes.

D

@ChristophJud
Author

ChristophJud commented Feb 5, 2020

Hi Davide,
thank you for your sampling example. So, I started training on MNIST, which worked quite well. After 200 epochs the samples (first row) and reconstructions (second row) look like:

sample_00190

Then I ran the sampling example. I had to adjust
probs = torch.softmax(z_dist[:, :, i]).data
to
probs = torch.softmax(z_dist[:, :, i], dim=1).data

Here are the samples:
mnist_samples

The variation of the samples looks a bit underestimated. It seems as if the "memorization" works well but the density estimate is too smooth. Have I over-regularized with lambda=1?
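For context, lambda weights the autoregressive log-likelihood term against the reconstruction term in the objective (L = L_REC + lambda * L_LLK in the paper). A rough sketch of that trade-off; the quantization of z into cpd_channels bins and the function name here are assumptions, not the repository's exact implementation:

import torch
import torch.nn.functional as F


def lsa_objective_sketch(x, x_r, z, z_dist, lam=1.0, cpd_channels=100):
    # Reconstruction term: how well the decoder reproduces the input.
    rec_loss = F.mse_loss(x_r, x)

    # Autoregressive term: negative log-likelihood of the (quantized) code
    # under the estimator's per-bin distribution over cpd_channels bins.
    z_bins = torch.clamp((z * cpd_channels).long(), 0, cpd_channels - 1)
    llk_loss = F.cross_entropy(z_dist, z_bins)

    # Lambda trades reconstruction quality against how strongly the latent
    # density is regularized; too large a value over-regularizes the code.
    return rec_loss + lam * llk_loss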

Second:
I've tried to train on CIFAR-10, incorporating all the hints I found in the issues (and the paper/supplement). After 200 epochs, the model did not learn that much:

sample_00199

Interestingly, the reconstructions are almost grayscale.

And the samples:
cifar10_samples

Have you had similar issues during development?

I'm looking forward to your hints and suggestions!

Christoph

@ChristophJud
Author

OK, I should have read your supplemental material more carefully. With the right lambda and step size I now get:

sample_00199

and the samples:

cifar10_samples

Really cool though :)
C

@DavideA
Contributor

DavideA commented Feb 5, 2020

Great!!!

It looked like an over-regularization issue :)

Best,
D

@ChristophJud
Author

Yap :-P

How would you generalize the LSA_cifar10 model to 256x256 inputs? Would you just add down-/up-sampling layers, or would you also add residual blocks?
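E.g., something along these lines for the first option (purely a sketch with made-up channel counts, not code from this repo):

import torch.nn as nn

# A 256x256 encoder stem that halves the spatial resolution six times
# (256 -> 4) before the fully connected part; the decoder would mirror it
# with transposed convolutions.
encoder_stem_256 = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(),    # 256 -> 128
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(),   # 128 -> 64
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(),  # 64 -> 32
    nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(), # 32 -> 16
    nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(), # 16 -> 8
    nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(), # 8 -> 4
)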

Best,
C

@ahugj

ahugj commented Aug 17, 2020

@DavideA Hi, thank you for sharing your code about the LSA model. May I get the full training code?

@ahugj

ahugj commented Aug 17, 2020

@ChristophJud Can you share your code?

@ChristophJud
Author

Hey casgj,
do you mean something like:

for epoch in range(max_epochs):
    for i, (x, y) in enumerate(loader):
        opt.zero_grad()

        x = x.to('cuda')

        # forward pass
        x_r, z, z_dist = model(x)

        # evaluate loss
        loss = criterion(x, x_r, z, z_dist)

        # calculate gradients
        loss.backward()

        # update step
        opt.step()
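For the loop above to run, the surrounding setup would look roughly like this. The constructor arguments for LSACIFAR10 and LSALoss are guesses, so check the repository and the paper's supplement for the actual values:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Assumed setup for the training loop above; hyper-parameters are placeholders.
train_set = datasets.CIFAR10('data', train=True, download=True,
                             transform=transforms.ToTensor())
loader = DataLoader(train_set, batch_size=256, shuffle=True)

model = LSACIFAR10(input_shape=(3, 32, 32), code_length=64, cpd_channels=100).to('cuda')
criterion = LSALoss(cpd_channels=100, lam=1)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
max_epochs = 200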

@ahugj

ahugj commented Aug 18, 2020

@ChristophJud Yes, my training code doesn't work; the accuracy and recall are always 0.009%.

@ChristophJud
Author

I'd recommend checking all the hyper-parameters in the supplemental material.
