In [33]:
import torch
import os
from networks.networks_lgan import Generator, Discriminator
from networks.networks_seq2seq import Seq2SeqAE
from dataset.data_utils import n_parts_map

In [38]:
# --- Latent GAN: restore generator and discriminator from a saved checkpoint ---
gan_checkpoint = torch.load(os.path.join("chkt_dir", "lgan", "ckpt_epoch90000.pth"))
n_dim = 128 # dimension of noise vector
h_dim = 2048 # dimension of MLP hidden layer
# dimension of shape code.
# TODO(review): original author asks why this differs from hidden_size (chair latent dim).
# The generation cell reshapes the 1024-d code to (2, 1, 512) before handing it to the
# decoder, so it is presumably the flattened decoder hidden state -- confirm.
z_dim = 1024
netG = Generator(n_dim, h_dim, z_dim).cuda()
netG.load_state_dict(gan_checkpoint['netG_state_dict'])
netD = Discriminator(h_dim, z_dim).cuda()
netD.load_state_dict(gan_checkpoint['netD_state_dict'])

# --- Seq2Seq autoencoder: restore weights; only its decoder is moved to the GPU ---
max_n_parts = 9
en_z_dim = 128 # part latent dim
hidden_size = 256 # chair latent dim
boxparam_size = 6 # dimension for part box parameters
# per-part feature = part latent code followed by box parameters
part_feat_size = en_z_dim + boxparam_size
en_input_size = part_feat_size + n_parts_map(max_n_parts) + 1 # seq2seq input size
de_input_size = part_feat_size
n_layer = 2 # NOTE(review): not referenced by any other statement in the visible cells
max_length = 10 # max seq length; read as a global by infer_decoder
netSeq2Seq = Seq2SeqAE(en_input_size, de_input_size, hidden_size)
seq2seq_checkpoint = torch.load(os.path.join("chkt_dir", "seq2seq", "ckpt_epoch2000.pth"))
netSeq2Seq.load_state_dict(seq2seq_checkpoint["model_state_dict"])
netDecoder = netSeq2Seq.decoder.cuda()
def infer_decoder(decoder, decoder_hidden, length=None, *, max_steps=None, box_size=None):
    """Run the decoder step by step and collect the per-step outputs for a series of parts.

    Decoding starts from ``decoder.init_input``; every later step is fed the
    previous step's own ``output_seq`` (detached), i.e. the decoder runs on its
    own predictions rather than on a target sequence.

    Args:
        decoder: callable as ``decoder(input, hidden)`` returning the 4-tuple
            ``(decoder_output, hidden, output_seq, stop_sign)`` and exposing an
            ``init_input`` tensor.
        decoder_hidden: initial hidden state. Either a tensor or a tuple/list of
            tensors; the initial input is placed on the device of this state
            (of its first element for a tuple/list).
        length: if given, decoding stops once ``length`` steps have been
            produced and the stop sign is ignored. A value larger than the step
            cap, or smaller than 1, never matches a step index, so the loop
            simply runs for the full step cap.
        max_steps: upper bound on the number of decoding steps. Defaults to the
            notebook-level ``max_length``.
        box_size: number of trailing entries of each output that hold box
            parameters. Defaults to the notebook-level ``boxparam_size``.

    Returns:
        dict with two views of the stacked outputs (steps stacked along dim 0):
            ``"boxparams"``: the last ``box_size`` entries of the final dimension,
            ``"vecs"``: everything before them.
    """
    # Fall back to the notebook globals so existing call sites behave as before.
    if max_steps is None:
        max_steps = max_length
    if box_size is None:
        box_size = boxparam_size

    # Follow the hidden state's device instead of hard-coding CUDA, so the same
    # function also works for CPU-only inference.
    hidden_ref = decoder_hidden[0] if isinstance(decoder_hidden, (tuple, list)) else decoder_hidden
    decoder_input = decoder.init_input.detach().repeat(1, 1, 1).to(hidden_ref.device)

    decoder_outputs = []
    for di in range(max_steps):
        _, decoder_hidden, output_seq, stop_sign = decoder(decoder_input, decoder_hidden)
        decoder_outputs.append(output_seq)
        if length is not None:
            # An explicit length overrides the learned stop sign.
            if di == length - 1:
                break
        elif torch.sigmoid(stop_sign[0, 0]) > 0.5:
            # The decoder itself signals that the sequence is complete.
            break
        # Feed the decoder's own prediction back in as the next input.
        decoder_input = output_seq.detach().unsqueeze(0)

    decoder_outputs = torch.stack(decoder_outputs, dim=0)
    return {
        "boxparams": decoder_outputs[:, :, -box_size:],
        "vecs": decoder_outputs[:, :, :-box_size],
    }

In [93]:
def test_pqnet_generation():
    """Smoke-test the generation path: noise -> netG -> netD score -> part decoder.

    Draws a single noise vector of size ``n_dim``, maps it through ``netG``,
    scores the result with ``netD``, decodes it into a part sequence via
    ``infer_decoder`` and prints the shape of the decoded part codes.
    Nothing is returned.
    """
    z = torch.randn(n_dim).cuda()
    with torch.no_grad():
        shape_code = netG(z)
        realism = netD(shape_code)  # evaluated but not inspected further
        # The generated code is reshaped to (2, 1, 512) to serve as the
        # decoder's initial hidden state.
        init_hidden = shape_code.view(2, 1, 512)
        decoded = infer_decoder(netDecoder, init_hidden)
        box_params = decoded['boxparams']  # [n_parts, 1, 6] per the original notes
        part_codes = decoded['vecs']  # [n_parts, 1, 128] per the original notes
        print(part_codes.shape)

test_pqnet_generation()

torch.Size([5, 1, 128])


In [None]:
sample_size = 64
learning_rate = 1e-4

# Noise batch that is optimised directly, initialised from N(0,1).
# Move to the GPU first and enable gradients afterwards: calling .cuda() on a
# tensor that already requires grad returns a non-leaf copy, and
# torch.optim.Adam rejects it with "can't optimize a non-leaf Tensor".
adv_noise = torch.randn(sample_size, n_dim).cuda().requires_grad_(True)  # N(0,1)
# uses weight decay to enforce gaussian prior for generator input
# https://stats.stackexchange.com/questions/163388/why-is-the-l2-regularization-equivalent-to-gaussian-prior
weight_decay_param = 2e-1
# NOTE(review): weight_decay_param is defined but not passed to Adam below.
# If the Gaussian prior is meant to come from the optimiser, add
# weight_decay=weight_decay_param; left unchanged in case a later cell applies
# the penalty explicitly -- confirm.
adv_adam = torch.optim.Adam([adv_noise], lr=learning_rate, betas=(0.5, 0.9))

