In [1]:
import os

import torch

import wandb

from trainer import Trainer
from models import Generator, Discriminator
from utils import get_loader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def train_step(
    gen_state_dict=None,
    disc_state_dict=None,
    z_dim=512,
    w_dim=512,
    step=0,
    epochs=50,
    lr=0.0002,
    betas=(0.5, 0.99),
    alpha=1e-7,
    could_gp=True,
    image_size=4,
    dataset_root="/home/kdhsimplepro/kdhsimplepro/AI/ffhq/",
    batch_size=4,
    save_dir=".",
):
    """Train one progressive-growing stage and checkpoint both networks.

    Builds a ``Trainer``, runs it for ``epochs`` over the loader returned by
    ``get_loader``, then saves the generator and discriminator ``state_dict``s
    to ``<save_dir>/gen_state_dict_step{step}.pt`` and
    ``<save_dir>/disc_state_dict_step{step}.pt``.

    Args:
        gen_state_dict, disc_state_dict: Forwarded to ``Trainer`` as its first
            two positional arguments. The step-0 call passes ``None``
            (presumably "start from fresh weights" -- confirm in ``Trainer``).
        z_dim, w_dim: Latent sizes, forwarded positionally to ``Trainer``.
        step: Stage index. It is forwarded to ``Trainer`` and embedded in the
            checkpoint file names. NOTE(review): nothing here checks that it
            agrees with ``image_size``; the caller must keep them in sync.
        epochs: Number of epochs passed to ``Trainer.run``.
        lr, betas, alpha, could_gp: Forwarded to ``Trainer`` as keywords.
            ``could_gp`` presumably toggles the gradient penalty and ``alpha``
            presumably seeds the fade-in weight -- TODO confirm in ``Trainer``.
        image_size, dataset_root, batch_size: Forwarded to ``get_loader``.
            Only the first value it returns (the loader) is used.
            NOTE(review): the ``dataset_root`` default is a machine-specific
            absolute path.
        save_dir: Directory for the checkpoints, created if missing. Defaults
            to the current directory, where checkpoints were always written.

    Returns:
        ``(gen_path, disc_path)``: the two checkpoint paths that were written,
        so the next stage can load them without rebuilding the file names.
    """
    # Create the output directory before training so an invalid path fails
    # immediately rather than after the (long) training run.
    os.makedirs(save_dir, exist_ok=True)

    trainer = Trainer(
        gen_state_dict,
        disc_state_dict,
        z_dim,
        w_dim,
        lr=lr,
        betas=betas,
        alpha=alpha,
        step=step,
        could_gp=could_gp,
    )

    # get_loader returns a pair; only the loader is needed here.
    loader, _ = get_loader(
        image_size=image_size,
        dataset_root=dataset_root,
        batch_size=batch_size,
    )

    trainer.run(epochs=epochs, loader=loader)

    gen_path = os.path.join(save_dir, f"gen_state_dict_step{step}.pt")
    disc_path = os.path.join(save_dir, f"disc_state_dict_step{step}.pt")
    torch.save(trainer.gen.state_dict(), gen_path)
    torch.save(trainer.disc.state_dict(), disc_path)
    return gen_path, disc_path

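### Step 0 — 4×4 resolution

Train the first progressive-growing stage from scratch (no generator or discriminator checkpoint is loaded).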

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Progressive-growing schedule. Entry i of EPOCHS / BATCH_SIZES / IMAGE_SIZES
# is used when STEP == i; IMAGE_SIZES doubles from 4 up to 512.
args = {
    "Z_DIM": 512,
    "W_DIM": 512,
    # NOTE(review): LAMBDA_GP is logged to wandb (and used in the run name) but
    # is never passed to train_step/Trainer here -- confirm the trainer's
    # actual gradient-penalty weight matches this value.
    "LAMBDA_GP": 10,
    "EPOCHS": [20, 20, 20, 50, 50, 75, 75, 100],
    "BATCH_SIZES": [256, 256, 128, 64, 32, 16, 8, 4],
    "IMAGE_SIZES": [4, 8, 16, 32, 64, 128, 256, 512],
    "STEP": 0,
}

# Set the run name and config when the run is created instead of mutating
# wandb.run afterwards. The former bare `wandb.save()` call was dropped: with
# no arguments it is a deprecated no-op that only emits a warning.
wandb.init(
    project="StyleGAN1",
    entity="donghwankim",
    name=f'ffhq/LAMBDA_GP:{args["LAMBDA_GP"]}/z_dim:{args["Z_DIM"]}/w_dim:{args["W_DIM"]}/step:{args["STEP"]}',
    config=args,
)

[34m[1mwandb[0m: Currently logged in as: [33mdonghwankim[0m. Use [1m`wandb login --relogin`[0m to force relogin




In [4]:
# Train the stage selected by args["STEP"] (0 -> 4x4 in this run) from scratch;
# every stage-dependent value is looked up with that same index.
# NOTE(review): lr=0.0008 overrides train_step's default of 0.0002 and is not
# recorded in the wandb config (args has no LR entry); betas, alpha and
# could_gp are likewise untracked.
train_step(
    # Stage schedule, indexed by the current step.
    step=args["STEP"],
    epochs=args["EPOCHS"][args["STEP"]],
    image_size=args["IMAGE_SIZES"][args["STEP"]],
    batch_size=args["BATCH_SIZES"][args["STEP"]],
    # Model sizes.
    z_dim=args["Z_DIM"],
    w_dim=args["W_DIM"],
    # No checkpoint to resume from.
    gen_state_dict=None,
    disc_state_dict=None,
    # Optimizer / trainer options and data location.
    lr=0.0008,
    betas=(0.5, 0.99),
    alpha=1e-7,
    could_gp=True,
    dataset_root="/home/kdhsimplepro/kdhsimplepro/AI/ffhq/",
)



Image Size: 4x4

EPOCH: 1/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-8.01, gen_loss=10.6, gp=0.201]


EPOCH: 2/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-8.15, gen_loss=9.06, gp=0.174] 


EPOCH: 3/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-4.46, gen_loss=3.92, gp=0.153] 


EPOCH: 4/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-4.14, gen_loss=4.25, gp=0.0896]


EPOCH: 5/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-4.63, gen_loss=2.98, gp=0.0922] 


EPOCH: 6/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-3.93, gen_loss=3.61, gp=0.056] 


EPOCH: 7/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-3.79, gen_loss=2.77, gp=0.0523]


EPOCH: 8/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-2.76, gen_loss=3.31, gp=0.0418]


EPOCH: 9/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-2.86, gen_loss=3.01, gp=0.0272]


EPOCH: 10/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-2.53, gen_loss=3.99, gp=0.0282]


EPOCH: 11/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-1.31, gen_loss=1.67, gp=0.0263]


EPOCH: 12/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-1.42, gen_loss=1.98, gp=0.0186]


EPOCH: 13/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-1.21, gen_loss=2.14, gp=0.0186]


EPOCH: 14/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-.908, gen_loss=1.84, gp=0.0103]


EPOCH: 15/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-1.02, gen_loss=1.4, gp=0.0205]  


EPOCH: 16/20


100%|██████████| 204/204 [05:04<00:00,  1.49s/it, disc_loss=-.716, gen_loss=0.678, gp=0.0129]


EPOCH: 17/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-.596, gen_loss=0.726, gp=0.0294]


EPOCH: 18/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-.352, gen_loss=0.566, gp=0.00843]


EPOCH: 19/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-.586, gen_loss=0.521, gp=0.0125] 


EPOCH: 20/20


100%|██████████| 204/204 [05:03<00:00,  1.49s/it, disc_loss=-.449, gen_loss=0.492, gp=0.00684]
