Minimal working example for sampling from pre-trained BigGAN? #8

Closed
qilimk opened this issue Mar 27, 2019 · 4 comments

qilimk commented Mar 27, 2019

Hi ajbrock,
I am so excited that you released the PyTorch version of BigGAN, and I am trying to generate some samples. Could you provide a minimal working example for sampling from a pre-trained BigGAN? @airalcorn2 and I wrote some sampling code, but the results look bad.
Here is our sampling code:

import functools
import numpy as np
import torch
import utils

from PIL import Image

parser = utils.prepare_parser()
parser = utils.add_sample_parser(parser)
config = vars(parser.parse_args())

# update config (see train.py for explanation)
config["resolution"] = utils.imsize_dict[config["dataset"]]
config["n_classes"] = utils.nclass_dict[config["dataset"]]
config["G_activation"] = utils.activation_dict[config["G_nl"]]
config["D_activation"] = utils.activation_dict[config["D_nl"]]
config = utils.update_config_roots(config)
config["skip_init"] = True
config["no_optim"] = True
device = "cuda:7"

# Seed RNG
utils.seed_rng(config["seed"])

# Setup cudnn.benchmark for free speed
torch.backends.cudnn.benchmark = True

# Import the model--this line allows us to dynamically select different files.
model = __import__(config["model"])
experiment_name = utils.name_from_config(config)
G = model.Generator(**config).to(device)
utils.count_parameters(G)

# Load weights
G.load_state_dict(torch.load("/mnt/raid/qi/biggan_weighs/G_optim.pth"), strict=False)

# Update batch size setting used for G
G_batch_size = max(config["G_batch_size"], config["batch_size"])
(z_, y_) = utils.prepare_z_y(
    G_batch_size,
    G.dim_z,
    config["n_classes"],
    device=device,
    fp16=config["G_fp16"],
    z_var=config["z_var"],
)

G.eval()

# Sample function
sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)

with torch.no_grad():
    z_.sample_()
    y_.sample_()
    image_tensors = G(z_, G.shared(y_))


for i in range(len(image_tensors)):
    image_array = image_tensors[i].permute(1, 2, 0).detach().cpu().numpy()
    image_array = np.uint8(255 * (1 + image_array) / 2)
    # Use an f-string so each sample is written to its own file.
    Image.fromarray(image_array).save(f"./test_images/{i}.png")

Here is one of our results.
[image: sample output]

Thanks a lot.

ajbrock (repository owner) commented Mar 27, 2019

Hi Milan,

Please see the sample script in the scripts folder.

As for your snippet, the line G.load_state_dict(torch.load("/mnt/raid/qi/biggan_weighs/G_optim.pth"), strict=False) is incorrect: you appear to be loading G with the state dict of G's optimizer. Use G_ema.pth or G.pth, and consider setting strict=True to help catch mistakes like this. I have it set to False in there for a variety of reasons, but it's generally better for strict to be True.
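
One quick way to catch this kind of mix-up before loading is to inspect the checkpoint's top-level keys. A state dict from a torch.optim optimizer's state_dict() has exactly the keys "state" and "param_groups", while a module state dict is keyed by parameter and buffer names such as "linear.weight". The sketch below works on plain dictionaries so it runs without PyTorch; the helpers looks_like_optimizer_state and compare_keys are hypothetical (not part of BigGAN-PyTorch), and it assumes G_optim.pth was saved directly from the optimizer's state_dict().

```python
def looks_like_optimizer_state(state_dict):
    """Return True if the dict has the top-level layout of Optimizer.state_dict()."""
    return set(state_dict.keys()) == {"state", "param_groups"}


def compare_keys(model_keys, checkpoint_keys):
    """Mimic the key check that load_state_dict(strict=True) performs.

    Returns (missing, unexpected): keys the model expects but the checkpoint
    lacks, and keys the checkpoint has but the model does not.
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Toy stand-ins for real checkpoints (hypothetical contents).
optim_sd = {"state": {}, "param_groups": [{"lr": 1e-4, "params": [0, 1]}]}
model_keys = ["linear.weight", "linear.bias", "output_layer.2.weight"]

print(looks_like_optimizer_state(optim_sd))  # True
missing, unexpected = compare_keys(model_keys, optim_sd.keys())
print(missing)     # ['linear.bias', 'linear.weight', 'output_layer.2.weight']
print(unexpected)  # ['param_groups', 'state']
```

With real files you would pass torch.load(path, map_location="cpu") and G.state_dict().keys() instead of the toy dicts. When every Generator key is "missing", strict=False makes PyTorch skip them silently, so G keeps its freshly initialized weights, which is consistent with the noisy samples in the first post. In recent PyTorch versions load_state_dict also returns a named tuple with missing_keys and unexpected_keys, which gives the same information.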

qilimk (issue author) commented Mar 27, 2019

Thanks for the quick reply. I tried
G.load_state_dict(torch.load("/mnt/raid/qi/biggan_weighs/G.pth"), strict=True)
but got the following error:

RuntimeError: Error(s) in loading state_dict for Generator:
	Unexpected key(s) in state_dict: "shared.weight", "blocks.0.0.bn1.gain.u0", "blocks.0.0.bn1.gain.sv0", "blocks.0.0.bn1.bias.u0", "blocks.0.0.bn1.bias.sv0", "blocks.0.0.bn2.gain.u0", "blocks.0.0.bn2.gain.sv0", "blocks.0.0.bn2.bias.u0", "blocks.0.0.bn2.bias.sv0", "blocks.1.0.bn1.gain.u0", "blocks.1.0.bn1.gain.sv0", "blocks.1.0.bn1.bias.u0", "blocks.1.0.bn1.bias.sv0", "blocks.1.0.bn2.gain.u0", "blocks.1.0.bn2.gain.sv0", "blocks.1.0.bn2.bias.u0", "blocks.1.0.bn2.bias.sv0", "blocks.2.0.bn1.gain.u0", "blocks.2.0.bn1.gain.sv0", "blocks.2.0.bn1.bias.u0", "blocks.2.0.bn1.bias.sv0", "blocks.2.0.bn2.gain.u0", "blocks.2.0.bn2.gain.sv0", "blocks.2.0.bn2.bias.u0", "blocks.2.0.bn2.bias.sv0", "blocks.3.0.bn1.gain.u0", "blocks.3.0.bn1.gain.sv0", "blocks.3.0.bn1.bias.u0", "blocks.3.0.bn1.bias.sv0", "blocks.3.0.bn2.gain.u0", "blocks.3.0.bn2.gain.sv0", "blocks.3.0.bn2.bias.u0", "blocks.3.0.bn2.bias.sv0", "blocks.4.0.bn1.gain.u0", "blocks.4.0.bn1.gain.sv0", "blocks.4.0.bn1.bias.u0", "blocks.4.0.bn1.bias.sv0", "blocks.4.0.bn2.gain.u0", "blocks.4.0.bn2.gain.sv0", "blocks.4.0.bn2.bias.u0", "blocks.4.0.bn2.bias.sv0".
	size mismatch for linear.weight: copying a param with shape torch.Size([24576, 20]) from checkpoint, the shape in current model is torch.Size([16384, 128]).
	size mismatch for linear.bias: copying a param with shape torch.Size([24576]) from checkpoint, the shape in current model is torch.Size([16384]).
	size mismatch for linear.u0: copying a param with shape torch.Size([1, 24576]) from checkpoint, the shape in current model is torch.Size([1, 16384]).
	size mismatch for blocks.0.0.conv1.weight: copying a param with shape torch.Size([1536, 1536, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for blocks.0.0.conv1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for blocks.0.0.conv1.u0: copying a param with shape torch.Size([1, 1536]) from checkpoint, the shape in current model is torch.Size([1, 1024]).
	size mismatch for blocks.0.0.conv2.weight: copying a param with shape torch.Size([1536, 1536, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for blocks.0.0.conv2.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for blocks.0.0.conv2.u0: copying a param with shape torch.Size([1, 1536]) from checkpoint, the shape in current model is torch.Size([1, 1024]).
	size mismatch for blocks.0.0.conv_sc.weight: copying a param with shape torch.Size([1536, 1536, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
	size mismatch for blocks.0.0.conv_sc.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for blocks.0.0.conv_sc.u0: copying a param with shape torch.Size([1, 1536]) from checkpoint, the shape in current model is torch.Size([1, 1024]).
	size mismatch for blocks.0.0.bn1.stored_mean: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for blocks.0.0.bn1.stored_var: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for blocks.0.0.bn1.gain.weight: copying a param with shape torch.Size([1536, 148]) from checkpoint, the shape in current model is torch.Size([1000, 1024]).
	size mismatch for blocks.0.0.bn1.bias.weight: copying a param with shape torch.Size([1536, 148]) from checkpoint, the shape in current model is torch.Size([1000, 1024]).
	size mismatch for blocks.0.0.bn2.stored_mean: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for blocks.0.0.bn2.stored_var: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for blocks.0.0.bn2.gain.weight: copying a param with shape torch.Size([1536, 148]) from checkpoint, the shape in current model is torch.Size([1000, 1024]).
	size mismatch for blocks.0.0.bn2.bias.weight: copying a param with shape torch.Size([1536, 148]) from checkpoint, the shape in current model is torch.Size([1000, 1024]).
	size mismatch for blocks.1.0.conv1.weight: copying a param with shape torch.Size([768, 1536, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 1024, 3, 3]).
	size mismatch for blocks.1.0.conv1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for blocks.1.0.conv1.u0: copying a param with shape torch.Size([1, 768]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for blocks.1.0.conv2.weight: copying a param with shape torch.Size([768, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for blocks.1.0.conv2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for blocks.1.0.conv2.u0: copying a param with shape torch.Size([1, 768]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for blocks.1.0.conv_sc.weight: copying a param with shape torch.Size([768, 1536, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 1024, 1, 1]).
	size mismatch for blocks.1.0.conv_sc.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for blocks.1.0.conv_sc.u0: copying a param with shape torch.Size([1, 768]) from checkpoint, the shape in current model is torch.Size([1, 512]).
	size mismatch for blocks.1.0.bn1.stored_mean: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for blocks.1.0.bn1.stored_var: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for blocks.1.0.bn1.gain.weight: copying a param with shape torch.Size([1536, 148]) from checkpoint, the shape in current model is torch.Size([1000, 1024]).
	size mismatch for blocks.1.0.bn1.bias.weight: copying a param with shape torch.Size([1536, 148]) from checkpoint, the shape in current model is torch.Size([1000, 1024]).
	size mismatch for blocks.1.0.bn2.stored_mean: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for blocks.1.0.bn2.stored_var: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for blocks.1.0.bn2.gain.weight: copying a param with shape torch.Size([768, 148]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
	size mismatch for blocks.1.0.bn2.bias.weight: copying a param with shape torch.Size([768, 148]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
	size mismatch for blocks.2.0.conv1.weight: copying a param with shape torch.Size([384, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
	size mismatch for blocks.2.0.conv1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for blocks.2.0.conv1.u0: copying a param with shape torch.Size([1, 384]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for blocks.2.0.conv2.weight: copying a param with shape torch.Size([384, 384, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for blocks.2.0.conv2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for blocks.2.0.conv2.u0: copying a param with shape torch.Size([1, 384]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for blocks.2.0.conv_sc.weight: copying a param with shape torch.Size([384, 768, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for blocks.2.0.conv_sc.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for blocks.2.0.conv_sc.u0: copying a param with shape torch.Size([1, 384]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for blocks.2.0.bn1.stored_mean: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for blocks.2.0.bn1.stored_var: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for blocks.2.0.bn1.gain.weight: copying a param with shape torch.Size([768, 148]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
	size mismatch for blocks.2.0.bn1.bias.weight: copying a param with shape torch.Size([768, 148]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
	size mismatch for blocks.2.0.bn2.stored_mean: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for blocks.2.0.bn2.stored_var: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for blocks.2.0.bn2.gain.weight: copying a param with shape torch.Size([384, 148]) from checkpoint, the shape in current model is torch.Size([1000, 256]).
	size mismatch for blocks.2.0.bn2.bias.weight: copying a param with shape torch.Size([384, 148]) from checkpoint, the shape in current model is torch.Size([1000, 256]).
	size mismatch for blocks.3.0.conv1.weight: copying a param with shape torch.Size([192, 384, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).
	size mismatch for blocks.3.0.conv1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for blocks.3.0.conv1.u0: copying a param with shape torch.Size([1, 192]) from checkpoint, the shape in current model is torch.Size([1, 128]).
	size mismatch for blocks.3.0.conv2.weight: copying a param with shape torch.Size([192, 192, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for blocks.3.0.conv2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for blocks.3.0.conv2.u0: copying a param with shape torch.Size([1, 192]) from checkpoint, the shape in current model is torch.Size([1, 128]).
	size mismatch for blocks.3.0.conv_sc.weight: copying a param with shape torch.Size([192, 384, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).
	size mismatch for blocks.3.0.conv_sc.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for blocks.3.0.conv_sc.u0: copying a param with shape torch.Size([1, 192]) from checkpoint, the shape in current model is torch.Size([1, 128]).
	size mismatch for blocks.3.0.bn1.stored_mean: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for blocks.3.0.bn1.stored_var: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for blocks.3.0.bn1.gain.weight: copying a param with shape torch.Size([384, 148]) from checkpoint, the shape in current model is torch.Size([1000, 256]).
	size mismatch for blocks.3.0.bn1.bias.weight: copying a param with shape torch.Size([384, 148]) from checkpoint, the shape in current model is torch.Size([1000, 256]).
	size mismatch for blocks.3.0.bn2.stored_mean: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for blocks.3.0.bn2.stored_var: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for blocks.3.0.bn2.gain.weight: copying a param with shape torch.Size([192, 148]) from checkpoint, the shape in current model is torch.Size([1000, 128]).
	size mismatch for blocks.3.0.bn2.bias.weight: copying a param with shape torch.Size([192, 148]) from checkpoint, the shape in current model is torch.Size([1000, 128]).
	size mismatch for blocks.3.1.theta.weight: copying a param with shape torch.Size([24, 192, 1, 1]) from checkpoint, the shape in current model is torch.Size([16, 128, 1, 1]).
	size mismatch for blocks.3.1.theta.u0: copying a param with shape torch.Size([1, 24]) from checkpoint, the shape in current model is torch.Size([1, 16]).
	size mismatch for blocks.3.1.phi.weight: copying a param with shape torch.Size([24, 192, 1, 1]) from checkpoint, the shape in current model is torch.Size([16, 128, 1, 1]).
	size mismatch for blocks.3.1.phi.u0: copying a param with shape torch.Size([1, 24]) from checkpoint, the shape in current model is torch.Size([1, 16]).
	size mismatch for blocks.3.1.g.weight: copying a param with shape torch.Size([96, 192, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 128, 1, 1]).
	size mismatch for blocks.3.1.g.u0: copying a param with shape torch.Size([1, 96]) from checkpoint, the shape in current model is torch.Size([1, 64]).
	size mismatch for blocks.3.1.o.weight: copying a param with shape torch.Size([192, 96, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 1, 1]).
	size mismatch for blocks.3.1.o.u0: copying a param with shape torch.Size([1, 192]) from checkpoint, the shape in current model is torch.Size([1, 128]).
	size mismatch for blocks.4.0.conv1.weight: copying a param with shape torch.Size([96, 192, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 128, 3, 3]).
	size mismatch for blocks.4.0.conv1.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for blocks.4.0.conv1.u0: copying a param with shape torch.Size([1, 96]) from checkpoint, the shape in current model is torch.Size([1, 64]).
	size mismatch for blocks.4.0.conv2.weight: copying a param with shape torch.Size([96, 96, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for blocks.4.0.conv2.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for blocks.4.0.conv2.u0: copying a param with shape torch.Size([1, 96]) from checkpoint, the shape in current model is torch.Size([1, 64]).
	size mismatch for blocks.4.0.conv_sc.weight: copying a param with shape torch.Size([96, 192, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 128, 1, 1]).
	size mismatch for blocks.4.0.conv_sc.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for blocks.4.0.conv_sc.u0: copying a param with shape torch.Size([1, 96]) from checkpoint, the shape in current model is torch.Size([1, 64]).
	size mismatch for blocks.4.0.bn1.stored_mean: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for blocks.4.0.bn1.stored_var: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for blocks.4.0.bn1.gain.weight: copying a param with shape torch.Size([192, 148]) from checkpoint, the shape in current model is torch.Size([1000, 128]).
	size mismatch for blocks.4.0.bn1.bias.weight: copying a param with shape torch.Size([192, 148]) from checkpoint, the shape in current model is torch.Size([1000, 128]).
	size mismatch for blocks.4.0.bn2.stored_mean: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for blocks.4.0.bn2.stored_var: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for blocks.4.0.bn2.gain.weight: copying a param with shape torch.Size([96, 148]) from checkpoint, the shape in current model is torch.Size([1000, 64]).
	size mismatch for blocks.4.0.bn2.bias.weight: copying a param with shape torch.Size([96, 148]) from checkpoint, the shape in current model is torch.Size([1000, 64]).
	size mismatch for output_layer.0.gain: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for output_layer.0.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for output_layer.0.stored_mean: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for output_layer.0.stored_var: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for output_layer.2.weight: copying a param with shape torch.Size([3, 96, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 64, 3, 3]).

ajbrock (repository owner) commented Mar 27, 2019

You need to create a Generator with the same hyperparameters as the model in the pretrained checkpoint. Again, please look at the bash script I linked.
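
To see which settings differ, the mismatched shapes in the error log can be re-derived by hand. The sketch below uses a hypothetical helper, g_shapes, that predicts two of those shapes (as (out, in) tuples) for the 128x128 Generator. It assumes the sizing logic of BigGAN.py as we read it: five generator blocks whose input widths are G_ch times 16, 16, 8, 4 and 2; with hier enabled, z is split into one chunk for the first linear layer plus one per block; with G_shared enabled, each conditional BatchNorm gain is a linear layer over the shared class embedding concatenated with a z chunk, and otherwise it is a per-class embedding table. The defaults are meant to mirror the command-line defaults the first snippet ran with.

```python
def g_shapes(G_ch=64, dim_z=128, bottom_width=4, n_classes=1000,
             G_shared=False, shared_dim=0, hier=False):
    """Predict two parameter shapes (out, in) of a 128x128 BigGAN Generator.

    A hand check based on our reading of BigGAN.py, not a substitute for
    actually building the model.
    """
    # Input-channel multipliers of the five generator blocks at 128x128.
    in_mult = [16, 16, 8, 4, 2]
    if hier:
        # One z chunk feeds the first linear layer, one goes to each block.
        num_slots = len(in_mult) + 1
        z_chunk = dim_z // num_slots
    else:
        z_chunk = 0
    linear_in = z_chunk if hier else dim_z
    # A non-positive shared_dim falls back to dim_z.
    shared_dim = shared_dim if shared_dim > 0 else dim_z
    first_ch = in_mult[0] * G_ch
    shapes = {"linear.weight": (first_ch * bottom_width ** 2, linear_in)}
    if G_shared:
        # BN gain: linear map from [shared class embedding, z chunk] to channels.
        shapes["blocks.0.0.bn1.gain.weight"] = (first_ch, shared_dim + z_chunk)
    else:
        # BN gain: one embedding row of width first_ch per class.
        shapes["blocks.0.0.bn1.gain.weight"] = (n_classes, first_ch)
    return shapes


# Command-line defaults, as used by the first snippet in this thread.
print(g_shapes())
# {'linear.weight': (16384, 128), 'blocks.0.0.bn1.gain.weight': (1000, 1024)}

# Settings from sample_BigGAN_bs256x8.sh.
print(g_shapes(G_ch=96, dim_z=120, G_shared=True, shared_dim=128, hier=True))
# {'linear.weight': (24576, 20), 'blocks.0.0.bn1.gain.weight': (1536, 148)}
```

The second call reproduces the checkpoint shapes torch.Size([24576, 20]) and torch.Size([1536, 148]) from the log: 20 = 120 // 6 is the per-slot z chunk and 148 = 128 + 20 is the shared embedding plus that chunk. The first call reproduces the "current model" shapes, so the fix is to build G with G_ch=96, dim_z=120, shared_dim=128, hier and G_shared enabled.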

ajbrock closed this as completed on Mar 27, 2019

airalcorn2 commented Mar 27, 2019

Thanks for pointing us to that shell script @ajbrock. In case it helps someone else, here is a minimal working example:

import torch
import torchvision
import utils

parser = utils.prepare_parser()
parser = utils.add_sample_parser(parser)
config = vars(parser.parse_args())

# See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs256x8.sh.
config["resolution"] = utils.imsize_dict["I128_hdf5"]
config["n_classes"] = utils.nclass_dict["I128_hdf5"]
config["G_activation"] = utils.activation_dict["inplace_relu"]
config["D_activation"] = utils.activation_dict["inplace_relu"]
config["G_attn"] = "64"
config["D_attn"] = "64"
config["G_ch"] = 96
config["D_ch"] = 96
config["hier"] = True
config["dim_z"] = 120
config["shared_dim"] = 128
config["G_shared"] = True
config = utils.update_config_roots(config)
config["skip_init"] = True
config["no_optim"] = True
config["device"] = "cuda"

# Seed RNG.
utils.seed_rng(config["seed"])

# Set up cudnn.benchmark for free speed.
torch.backends.cudnn.benchmark = True

# Import the model.
model = __import__(config["model"])
G = model.Generator(**config).to(config["device"])
utils.count_parameters(G)

# Load weights.
weights_path = "/mnt/raid/qi/biggan_weights/G_ema.pth"  # Change this.
G.load_state_dict(torch.load(weights_path))

# Update batch size setting used for G.
G_batch_size = max(config["G_batch_size"], config["batch_size"])
(z_, y_) = utils.prepare_z_y(
    G_batch_size,
    G.dim_z,
    config["n_classes"],
    device=config["device"],
    fp16=config["G_fp16"],
    z_var=config["z_var"],
)

G.eval()

out_path = "/home/michael/test_imgs/random_samples.jpg"  # Change this.
with torch.no_grad():
    z_.sample_()
    y_.sample_()
    image_tensors = G(z_, G.shared(y_))
    torchvision.utils.save_image(
        image_tensors,
        out_path,
        nrow=int(G_batch_size ** 0.5),
        normalize=True,
    )

that produces the following output:

[image: random_samples]
