In [1]:
import os

import torch

import wandb

from trainer import Trainer
from models import Generator, Discriminator
from utils import get_loader


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and CUDA) so weight initialization and other torch
# random draws repeat across Restart & Run All. cuDNN kernels may still
# introduce some run-to-run nondeterminism.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters, also logged to W&B. EPOCHS, BATCH_SIZES and IMAGE_SIZES are
# per-stage schedules: entry i of each list configures training stage i.
args = {
    "Z_DIM": 512,
    "W_DIM": 512,
    "LAMBDA_GP": 10,
    "EPOCHS": [40] * 8,
    "BATCH_SIZES": [256, 256, 128, 64, 32, 16, 8, 4],
    "IMAGE_SIZES": [4, 8, 16, 32, 64, 128, 256, 512],
    "SEED": SEED,
}

# The training loop indexes all three schedules by the same stage number, so a
# length mismatch would either crash mid-run or silently skip stages.
assert len(args["EPOCHS"]) == len(args["BATCH_SIZES"]) == len(args["IMAGE_SIZES"]), (
    "EPOCHS, BATCH_SIZES and IMAGE_SIZES must have one entry per stage"
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Start the W&B run with its display name and hyperparameters attached at
# creation time, rather than patching them onto the run afterwards.
wandb.init(
    project="StyleGAN1",
    entity="donghwankim",
    name=f'ffhq/lambda_gp:{args["LAMBDA_GP"]}/z_dim:{args["Z_DIM"]}/w_dim:{args["W_DIM"]}',
    config=args,
)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin




In [3]:
gen = Generator(args["Z_DIM"], args["W_DIM"], const_channels=512)
disc = Discriminator()

# Total parameter counts (every tensor from .parameters(), no filtering on
# requires_grad): generator first, then discriminator, one per line.
print(
    sum(param.numel() for param in gen.parameters()),
    sum(param.numel() for param in disc.parameters()),
    sep="\n",
)

26179256
25430785


In [4]:
# Dataset location: set the FFHQ_ROOT environment variable to point elsewhere.
# The default is the original author-machine path.
DATASET_ROOT = os.environ.get("FFHQ_ROOT", "/home/kdhsimplepro/kdhsimplepro/AI/ffhq/")

trainer = Trainer(gen, disc, lr=4e-4)

# Stage i uses IMAGE_SIZES[i], BATCH_SIZES[i] and EPOCHS[i]. zip() would
# silently drop trailing stages if the lists disagreed, so check lengths first.
assert len(args["IMAGE_SIZES"]) == len(args["BATCH_SIZES"]) == len(args["EPOCHS"]), (
    "IMAGE_SIZES, BATCH_SIZES and EPOCHS must have one entry per stage"
)

stage_schedule = zip(args["IMAGE_SIZES"], args["BATCH_SIZES"], args["EPOCHS"])
for step, (image_size, batch_size, epochs) in enumerate(stage_schedule):
    # get_loader returns a pair; only the loader is used here.
    loader, _ = get_loader(
        image_size,
        dataset_root=DATASET_ROOT,
        batch_size=batch_size,
    )
    trainer.run(step=step, epochs=epochs, loader=loader)

# Explicitly end the W&B run started in the setup cell once every stage is done.
wandb.finish()



Image Size: 4x4

EPOCH: 1/40


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-13.4, gen_loss=20.1, gp=0.41] 


EPOCH: 2/40


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-12.4, gen_loss=16.9, gp=0.299]


EPOCH: 3/40


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-11.1, gen_loss=16.9, gp=0.291]


EPOCH: 4/40


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-9.85, gen_loss=11.3, gp=0.211]


EPOCH: 5/40


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-7.84, gen_loss=9.39, gp=0.165]


EPOCH: 6/40


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-6.42, gen_loss=7.08, gp=0.218]


EPOCH: 7/40


100%|██████████| 204/204 [05:03<00:00,  1.49s/it, disc_loss=-7.67, gen_loss=8.76, gp=0.15] 


EPOCH: 8/40


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-7.72, gen_loss=8.83, gp=0.109]


EPOCH: 9/40


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-7.1, gen_loss=6.95, gp=0.133]  


EPOCH: 10/40


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-6.73, gen_loss=6.58, gp=0.152] 


EPOCH: 11/40


100%|██████████| 204/204 [05:00<00:00,  1.47s/it, disc_loss=-5.79, gen_loss=5.87, gp=0.123] 


EPOCH: 12/40


100%|██████████| 204/204 [04:59<00:00,  1.47s/it, disc_loss=-4.98, gen_loss=5.5, gp=0.105]  


EPOCH: 13/40


100%|██████████| 204/204 [04:59<00:00,  1.47s/it, disc_loss=-5.66, gen_loss=7.01, gp=0.0777]


EPOCH: 14/40


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-3.64, gen_loss=3.94, gp=0.0741]


EPOCH: 15/40


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-5.35, gen_loss=8, gp=0.0879]   


EPOCH: 16/40


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-4.68, gen_loss=5.73, gp=0.0773]


EPOCH: 17/40


100%|██████████| 204/204 [05:00<00:00,  1.47s/it, disc_loss=-4.34, gen_loss=5.99, gp=0.0603]


EPOCH: 18/40


100%|██████████| 204/204 [05:00<00:00,  1.48s/it, disc_loss=-4.18, gen_loss=7.03, gp=0.0357]


EPOCH: 19/40


100%|██████████| 204/204 [05:00<00:00,  1.47s/it, disc_loss=-2.64, gen_loss=2.82, gp=0.0628]


EPOCH: 20/40


100%|██████████| 204/204 [05:00<00:00,  1.47s/it, disc_loss=-2.63, gen_loss=3.66, gp=0.0344]


EPOCH: 21/40


100%|██████████| 204/204 [04:59<00:00,  1.47s/it, disc_loss=-2.25, gen_loss=3.7, gp=0.0251] 


EPOCH: 22/40


 46%|████▌     | 94/204 [02:18<02:42,  1.48s/it, disc_loss=-2.71, gen_loss=4.56, gp=0.0334]