Minimal working example for sampling from pre-trained BigGAN? #8
Hi Milan, please see the sample script included in the scripts folder. As to your snippet, the line
Thanks for replying so fast. I tried
You need to create a Generator with the same hyperparameters as the model that produced the pretrained checkpoint. Again, please look at the bash script I linked.
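In practice, "the same parameters" means every parameter name and tensor shape in the Generator must match the checkpoint. A generator built with the wrong width (for example, the default `G_ch` instead of 96) produces shape mismatches. Below is a minimal plain-Python sketch for spotting these before loading. `diff_state_dict_shapes` is a hypothetical helper and is not part of BigGAN-PyTorch, and the toy shapes are made up for illustration.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dict_shapes(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> Dict[str, List]:
    """Compare parameter names/shapes of a model against a checkpoint.

    Both arguments map parameter names to shapes (tuples of ints).
    Returns keys missing from the checkpoint, keys the model does not
    expect, and keys present in both but with different shapes.
    """
    missing = sorted(k for k in model_shapes if k not in ckpt_shapes)
    unexpected = sorted(k for k in ckpt_shapes if k not in model_shapes)
    mismatched = sorted(
        (k, model_shapes[k], ckpt_shapes[k])
        for k in model_shapes
        if k in ckpt_shapes and model_shapes[k] != ckpt_shapes[k]
    )
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


if __name__ == "__main__":
    # Toy, hypothetical shapes: a model built with width 64 versus a
    # checkpoint saved from a model built with width 96.
    model_shapes = {"layer.weight": (64, 3), "layer.bias": (64,)}
    ckpt_shapes = {"layer.weight": (96, 3), "layer.bias": (96,), "extra.u": (1, 96)}
    print(diff_state_dict_shapes(model_shapes, ckpt_shapes))
```

With PyTorch, the two inputs could be built as `{k: tuple(v.shape) for k, v in G.state_dict().items()}` and the same expression over `torch.load(weights_path).items()`. `load_state_dict` (with its default `strict=True`) already raises an error that lists missing keys, unexpected keys and size mismatches. This helper only summarizes the same information before you attempt the load.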
Thanks for pointing us to that shell script @ajbrock. In case it helps someone else, here is a minimal working example:

```python
import torch
import torchvision

import utils  # BigGAN-PyTorch's utils.py; run this from the repository root.

parser = utils.prepare_parser()
parser = utils.add_sample_parser(parser)
config = vars(parser.parse_args())

# Match the architecture of the pretrained checkpoint.
# See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs256x8.sh.
config["resolution"] = utils.imsize_dict["I128_hdf5"]
config["n_classes"] = utils.nclass_dict["I128_hdf5"]
config["G_activation"] = utils.activation_dict["inplace_relu"]
config["D_activation"] = utils.activation_dict["inplace_relu"]
config["G_attn"] = "64"
config["D_attn"] = "64"
config["G_ch"] = 96
config["D_ch"] = 96
config["hier"] = True
config["dim_z"] = 120
config["shared_dim"] = 128
config["G_shared"] = True
config = utils.update_config_roots(config)
config["skip_init"] = True
config["no_optim"] = True
config["device"] = "cuda"

# Seed RNG.
utils.seed_rng(config["seed"])

# Set up cudnn.benchmark for free speed.
torch.backends.cudnn.benchmark = True

# Import the model.
model = __import__(config["model"])
G = model.Generator(**config).to(config["device"])
utils.count_parameters(G)

# Load weights.
weights_path = "/mnt/raid/qi/biggan_weights/G_ema.pth"  # Change this.
G.load_state_dict(torch.load(weights_path))

# Update batch size setting used for G.
G_batch_size = max(config["G_batch_size"], config["batch_size"])
(z_, y_) = utils.prepare_z_y(
    G_batch_size,
    G.dim_z,
    config["n_classes"],
    device=config["device"],
    fp16=config["G_fp16"],
    z_var=config["z_var"],
)

G.eval()
out_path = "/home/michael/test_imgs/random_samples.jpg"  # Change this.
with torch.no_grad():
    # Draw fresh latents z and class labels y in place.
    z_.sample_()
    y_.sample_()
    # With G_shared, labels go through the shared class embedding first.
    image_tensors = G(z_, G.shared(y_))
    # Save all samples as one roughly square grid.
    torchvision.utils.save_image(
        image_tensors,
        out_path,
        nrow=int(G_batch_size ** 0.5),
        normalize=True,
    )
```

That produces the following output:

[Image: grid of generated samples]
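The config overrides above are hand translations of flags in the linked shell script. Another option is to pass those flags straight to the parser. Below is a minimal sketch under that assumption. `flags_to_argv` and `SCRIPT_FLAGS` are hypothetical names. The flag spellings (`--dataset`, `--G_ch`, `--G_nl`, `--G_shared`, `--hier`, and so on) are assumed to match both the linked script and `utils.prepare_parser()`, and have not been checked here.

```python
from typing import Dict, List, Union

FlagValue = Union[bool, int, float, str]


def flags_to_argv(flags: Dict[str, FlagValue]) -> List[str]:
    """Turn {"G_ch": 96, "hier": True} into ["--G_ch", "96", "--hier"].

    True booleans become bare switches (argparse "store_true" style),
    False booleans are dropped, and everything else becomes "--name value".
    """
    argv: List[str] = []
    for name, value in flags.items():
        if isinstance(value, bool):  # check bool before the generic case
            if value:
                argv.append(f"--{name}")
        else:
            argv.extend([f"--{name}", str(value)])
    return argv


# Assumed subset of the flags in scripts/sample_BigGAN_bs256x8.sh that
# correspond to the overrides in the snippet above.
SCRIPT_FLAGS: Dict[str, FlagValue] = {
    "dataset": "I128_hdf5",
    "G_attn": "64",
    "D_attn": "64",
    "G_ch": 96,
    "D_ch": 96,
    "G_nl": "inplace_relu",
    "D_nl": "inplace_relu",
    "G_shared": True,
    "hier": True,
    "dim_z": 120,
    "shared_dim": 128,
}

if __name__ == "__main__":
    print(flags_to_argv(SCRIPT_FLAGS))
```

The list could then replace the bare call, as in `config = vars(parser.parse_args(flags_to_argv(SCRIPT_FLAGS)))`, because argparse's `parse_args` accepts an explicit argument list. The `resolution`, `n_classes`, and activation entries would still need to be looked up afterwards from `config["dataset"]`, `config["G_nl"]`, and `config["D_nl"]`, just as the snippet does with the hard-coded keys.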
Hi ajbrock,
I am so excited that you released the PyTorch version of BigGAN. I am trying to generate some samples. Could you provide a minimal working example for sampling from a pre-trained BigGAN? @airalcorn2 and I wrote some code for sampling, but the results look bad.
Here is our sample code.
Here is one of our results.
Thanks a lot.