In [1]:
import torch

import wandb

from trainer import Trainer
from models import Generator, Discriminator
from utils import get_loader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def train_step(
    gen_state_dict=None,
    disc_state_dict=None,
    z_dim=512,
    w_dim=512,
    step=0,
    epochs=50,
    lr=0.0002,
    betas=(0.5, 0.99),
    alpha=1e-7,
    could_gp=True,
    image_size=4,
    dataset_root="/home/kdhsimplepro/kdhsimplepro/AI/ffhq/",
    batch_size=4,
    save_dir=".",
):
    """Run one training stage and save the generator/discriminator weights.

    Builds a ``Trainer``, obtains a data loader from ``get_loader``, runs
    training for ``epochs`` epochs, then writes both ``state_dict``s to
    ``<save_dir>/gen_state_dict_step<step>.pt`` and
    ``<save_dir>/disc_state_dict_step<step>.pt``.

    Args:
        gen_state_dict: Forwarded to ``Trainer`` as its first positional
            argument (``None`` by default).
        disc_state_dict: Forwarded to ``Trainer`` as its second positional
            argument (``None`` by default).
        z_dim: Forwarded to ``Trainer`` (third positional argument).
        w_dim: Forwarded to ``Trainer`` (fourth positional argument).
        step: Stage index; forwarded to ``Trainer`` and used in the
            checkpoint file names.
        epochs: Number of epochs passed to ``trainer.run``.
        lr: Forwarded to ``Trainer``.
        betas: Forwarded to ``Trainer``.
        alpha: Forwarded to ``Trainer``.
        could_gp: Forwarded to ``Trainer``.
        image_size: Forwarded to ``get_loader``.
        dataset_root: Forwarded to ``get_loader``.
        batch_size: Forwarded to ``get_loader``.
        save_dir: Directory the two checkpoint files are written to. It is
            created if it does not exist. Defaults to the current working
            directory, which matches the previous hard-coded location.

    Returns:
        None. The results are the checkpoint files written to ``save_dir``.
    """
    # Local import so the function does not depend on an extra notebook-level
    # import cell having been run.
    import os

    trainer = Trainer(
        gen_state_dict,
        disc_state_dict,
        z_dim,
        w_dim,
        lr=lr,
        betas=betas,
        alpha=alpha,
        step=step,
        could_gp=could_gp
    )

    # get_loader returns a pair; only the first element is used here.
    loader, _ = get_loader(
        image_size=image_size,
        dataset_root=dataset_root,
        batch_size=batch_size
    )

    trainer.run(epochs=epochs, loader=loader)

    # Make sure the target directory exists before writing the checkpoints.
    os.makedirs(save_dir, exist_ok=True)
    gen_path = os.path.join(save_dir, f"gen_state_dict_step{step}.pt")
    disc_path = os.path.join(save_dir, f"disc_state_dict_step{step}.pt")

    torch.save(trainer.gen.state_dict(), gen_path)
    torch.save(trainer.disc.state_dict(), disc_path)

### Step 0

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-stage schedule: the EPOCHS / BATCH_SIZES / IMAGE_SIZES lists are indexed
# by "STEP" when the training call is made.
# NOTE(review): "LAMBDA_GP" is only recorded in the run name/config in this
# notebook; it is not passed to train_step - confirm the Trainer uses the
# same value.
args = {
    "Z_DIM": 512,
    "W_DIM": 512,
    "LAMBDA_GP": 10,
    "EPOCHS": [20, 20, 20, 50, 50, 75, 75, 100],
    "BATCH_SIZES": [256, 256, 128, 64, 32, 16, 8, 4],
    "IMAGE_SIZES": [4, 8, 16, 32, 64, 128, 256, 512],
    "STEP": 0,
}

run_name = f'ffhq/LAMBDA_GP:{args["LAMBDA_GP"]}/z_dim:{args["Z_DIM"]}/w_dim:{args["W_DIM"]}/step:{args["STEP"]}'

# Name the run and attach the config at init time instead of mutating
# wandb.run afterwards; the former argument-less wandb.save() call was a
# legacy no-op and has been dropped.
wandb.init(
    project="StyleGAN1",
    entity="donghwankim",
    name=run_name,
    config=args,
)

[34m[1mwandb[0m: Currently logged in as: [33mdonghwankim[0m. Use [1m`wandb login --relogin`[0m to force relogin




In [4]:
# Index of the current stage; selects this stage's entry from each schedule list.
stage_idx = args["STEP"]

train_step(
    # No checkpoints are passed for this first stage (both state dicts are None).
    gen_state_dict=None,
    disc_state_dict=None,
    # Model dimensions and stage index.
    z_dim=args["Z_DIM"],
    w_dim=args["W_DIM"],
    step=stage_idx,
    # Optimisation settings.
    epochs=args["EPOCHS"][stage_idx],
    lr=0.0008,
    betas=(0.5, 0.99),
    alpha=1e-7,
    could_gp=True,
    # Data settings.
    image_size=args["IMAGE_SIZES"][stage_idx],
    dataset_root="/home/kdhsimplepro/kdhsimplepro/AI/ffhq/",
    batch_size=args["BATCH_SIZES"][stage_idx],
)



Image Size: 4x4

EPOCH: 1/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-8.01, gen_loss=10.6, gp=0.201]


EPOCH: 2/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-8.15, gen_loss=9.06, gp=0.174] 


EPOCH: 3/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-4.46, gen_loss=3.92, gp=0.153] 


EPOCH: 4/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-4.14, gen_loss=4.25, gp=0.0896]


EPOCH: 5/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-4.63, gen_loss=2.98, gp=0.0922] 


EPOCH: 6/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-3.93, gen_loss=3.61, gp=0.056] 


EPOCH: 7/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-3.79, gen_loss=2.77, gp=0.0523]


EPOCH: 8/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-2.76, gen_loss=3.31, gp=0.0418]


EPOCH: 9/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-2.86, gen_loss=3.01, gp=0.0272]


EPOCH: 10/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-2.53, gen_loss=3.99, gp=0.0282]


EPOCH: 11/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-1.31, gen_loss=1.67, gp=0.0263]


EPOCH: 12/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-1.42, gen_loss=1.98, gp=0.0186]


EPOCH: 13/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-1.21, gen_loss=2.14, gp=0.0186]


EPOCH: 14/20


100%|██████████| 204/204 [05:01<00:00,  1.48s/it, disc_loss=-.908, gen_loss=1.84, gp=0.0103]


EPOCH: 15/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-1.02, gen_loss=1.4, gp=0.0205]  


EPOCH: 16/20


100%|██████████| 204/204 [05:04<00:00,  1.49s/it, disc_loss=-.716, gen_loss=0.678, gp=0.0129]


EPOCH: 17/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-.596, gen_loss=0.726, gp=0.0294]


EPOCH: 18/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-.352, gen_loss=0.566, gp=0.00843]


EPOCH: 19/20


100%|██████████| 204/204 [05:02<00:00,  1.48s/it, disc_loss=-.586, gen_loss=0.521, gp=0.0125] 


EPOCH: 20/20


100%|██████████| 204/204 [05:03<00:00,  1.49s/it, disc_loss=-.449, gen_loss=0.492, gp=0.00684]


### Step 1

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-stage schedule: the EPOCHS / BATCH_SIZES / IMAGE_SIZES lists are indexed
# by "STEP" in the train_step call below.
# NOTE(review): "LAMBDA_GP" is only recorded in the run name/config in this
# notebook; it is not passed to train_step - confirm the Trainer uses the
# same value.
args = {
    "Z_DIM": 512,
    "W_DIM": 512,
    "LAMBDA_GP": 10,
    "EPOCHS": [20, 20, 20, 50, 50, 75, 75, 100],
    "BATCH_SIZES": [256, 256, 128, 64, 32, 16, 8, 4],
    "IMAGE_SIZES": [4, 8, 16, 32, 64, 128, 256, 512],
    "STEP": 1,
}

run_name = f'ffhq/LAMBDA_GP:{args["LAMBDA_GP"]}/z_dim:{args["Z_DIM"]}/w_dim:{args["W_DIM"]}/step:{args["STEP"]}'

# Name the run and attach the config at init time instead of mutating
# wandb.run afterwards; the former argument-less wandb.save() call was a
# legacy no-op and has been dropped.
wandb.init(
    project="StyleGAN1",
    entity="donghwankim",
    name=run_name,
    config=args,
)

# Checkpoints written by the previous stage are the starting weights here.
prev_step = args["STEP"] - 1

train_step(
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    gen_state_dict=torch.load(f"./gen_state_dict_step{prev_step}.pt", map_location=DEVICE),
    disc_state_dict=torch.load(f"./disc_state_dict_step{prev_step}.pt", map_location=DEVICE),
    z_dim=args["Z_DIM"],
    w_dim=args["W_DIM"],
    step=args["STEP"],
    epochs=args["EPOCHS"][args["STEP"]],
    lr=0.0008,
    betas=(0.5, 0.99),
    alpha=1e-7,
    could_gp=True,
    image_size=args["IMAGE_SIZES"][args["STEP"]],
    dataset_root="/home/kdhsimplepro/kdhsimplepro/AI/ffhq/",
    batch_size=args["BATCH_SIZES"][args["STEP"]]
)

[34m[1mwandb[0m: Currently logged in as: [33mdonghwankim[0m. Use [1m`wandb login --relogin`[0m to force relogin






Image Size: 8x8

EPOCH: 1/20


100%|██████████| 204/204 [05:43<00:00,  1.68s/it, disc_loss=-.502, gen_loss=1.04, gp=0.0492] 


EPOCH: 2/20


100%|██████████| 204/204 [05:43<00:00,  1.68s/it, disc_loss=-2.13, gen_loss=2.92, gp=0.0234] 


EPOCH: 3/20


100%|██████████| 204/204 [05:46<00:00,  1.70s/it, disc_loss=-.988, gen_loss=1.38, gp=0.0305] 


EPOCH: 4/20


100%|██████████| 204/204 [05:44<00:00,  1.69s/it, disc_loss=-.876, gen_loss=1.5, gp=0.02]    


EPOCH: 5/20


100%|██████████| 204/204 [05:41<00:00,  1.67s/it, disc_loss=-.634, gen_loss=1.28, gp=0.03]     


EPOCH: 6/20


100%|██████████| 204/204 [05:43<00:00,  1.68s/it, disc_loss=-.599, gen_loss=1.78, gp=0.00928] 


EPOCH: 7/20


100%|██████████| 204/204 [05:45<00:00,  1.69s/it, disc_loss=0.007, gen_loss=0.348, gp=0.0409] 


EPOCH: 8/20


100%|██████████| 204/204 [05:43<00:00,  1.68s/it, disc_loss=-.224, gen_loss=0.978, gp=0.0148]  


EPOCH: 9/20


100%|██████████| 204/204 [05:42<00:00,  1.68s/it, disc_loss=0.706, gen_loss=0.107, gp=0.0245] 


EPOCH: 10/20


100%|██████████| 204/204 [05:41<00:00,  1.67s/it, disc_loss=-1.57, gen_loss=1.92, gp=0.023]    


EPOCH: 11/20


100%|██████████| 204/204 [05:41<00:00,  1.68s/it, disc_loss=-.0284, gen_loss=-.0159, gp=0.0223] 


EPOCH: 12/20


100%|██████████| 204/204 [05:41<00:00,  1.67s/it, disc_loss=-.402, gen_loss=0.961, gp=0.0252]  


EPOCH: 13/20


100%|██████████| 204/204 [05:40<00:00,  1.67s/it, disc_loss=-.0461, gen_loss=0.0518, gp=0.0143]  


EPOCH: 14/20


100%|██████████| 204/204 [05:40<00:00,  1.67s/it, disc_loss=-.517, gen_loss=0.937, gp=0.00958]  


EPOCH: 15/20


100%|██████████| 204/204 [05:40<00:00,  1.67s/it, disc_loss=0.197, gen_loss=-.0115, gp=0.0137]  


EPOCH: 16/20


100%|██████████| 204/204 [05:40<00:00,  1.67s/it, disc_loss=-.151, gen_loss=0.507, gp=0.0134]   


EPOCH: 17/20


100%|██████████| 204/204 [05:40<00:00,  1.67s/it, disc_loss=0.0243, gen_loss=0.188, gp=0.00806] 


EPOCH: 18/20


100%|██████████| 204/204 [05:40<00:00,  1.67s/it, disc_loss=0.0647, gen_loss=1.09, gp=0.00829] 


EPOCH: 19/20


100%|██████████| 204/204 [05:40<00:00,  1.67s/it, disc_loss=-.277, gen_loss=0.335, gp=0.00926]    


EPOCH: 20/20


100%|██████████| 204/204 [05:40<00:00,  1.67s/it, disc_loss=0.095, gen_loss=-.00478, gp=0.00639] 


### Step 2

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-stage schedule: the EPOCHS / BATCH_SIZES / IMAGE_SIZES lists are indexed
# by "STEP" in the train_step call below.
# NOTE(review): "LAMBDA_GP" is only recorded in the run name/config in this
# notebook; it is not passed to train_step - confirm the Trainer uses the
# same value.
args = {
    "Z_DIM": 512,
    "W_DIM": 512,
    "LAMBDA_GP": 10,
    "EPOCHS": [20, 20, 20, 50, 50, 75, 75, 100],
    "BATCH_SIZES": [256, 256, 128, 64, 32, 16, 8, 4],
    "IMAGE_SIZES": [4, 8, 16, 32, 64, 128, 256, 512],
    "STEP": 2,
}

run_name = f'ffhq/LAMBDA_GP:{args["LAMBDA_GP"]}/z_dim:{args["Z_DIM"]}/w_dim:{args["W_DIM"]}/step:{args["STEP"]}'

# Name the run and attach the config at init time instead of mutating
# wandb.run afterwards; the former argument-less wandb.save() call was a
# legacy no-op and has been dropped.
wandb.init(
    project="StyleGAN1",
    entity="donghwankim",
    name=run_name,
    config=args,
)

# Checkpoints written by the previous stage are the starting weights here.
prev_step = args["STEP"] - 1

train_step(
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    gen_state_dict=torch.load(f"./gen_state_dict_step{prev_step}.pt", map_location=DEVICE),
    disc_state_dict=torch.load(f"./disc_state_dict_step{prev_step}.pt", map_location=DEVICE),
    z_dim=args["Z_DIM"],
    w_dim=args["W_DIM"],
    step=args["STEP"],
    epochs=args["EPOCHS"][args["STEP"]],
    lr=0.0008,
    betas=(0.5, 0.99),
    alpha=1e-7,
    could_gp=True,
    image_size=args["IMAGE_SIZES"][args["STEP"]],
    dataset_root="/home/kdhsimplepro/kdhsimplepro/AI/ffhq/",
    batch_size=args["BATCH_SIZES"][args["STEP"]]
)

[34m[1mwandb[0m: Currently logged in as: [33mdonghwankim[0m. Use [1m`wandb login --relogin`[0m to force relogin






Image Size: 16x16

EPOCH: 1/20


100%|██████████| 407/407 [07:33<00:00,  1.11s/it, disc_loss=-.575, gen_loss=0.162, gp=0.0177]  


EPOCH: 2/20


 35%|███▌      | 143/407 [02:38<04:52,  1.11s/it, disc_loss=-.437, gen_loss=2.51, gp=0.0109]  

### Step 3

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-stage schedule: the EPOCHS / BATCH_SIZES / IMAGE_SIZES lists are indexed
# by "STEP" in the train_step call below.
# NOTE(review): "LAMBDA_GP" is only recorded in the run name/config in this
# notebook; it is not passed to train_step - confirm the Trainer uses the
# same value.
args = {
    "Z_DIM": 512,
    "W_DIM": 512,
    "LAMBDA_GP": 10,
    "EPOCHS": [20, 20, 20, 50, 50, 75, 75, 100],
    "BATCH_SIZES": [256, 256, 128, 64, 32, 16, 8, 4],
    "IMAGE_SIZES": [4, 8, 16, 32, 64, 128, 256, 512],
    "STEP": 3,
}

run_name = f'ffhq/LAMBDA_GP:{args["LAMBDA_GP"]}/z_dim:{args["Z_DIM"]}/w_dim:{args["W_DIM"]}/step:{args["STEP"]}'

# Name the run and attach the config at init time instead of mutating
# wandb.run afterwards; the former argument-less wandb.save() call was a
# legacy no-op and has been dropped.
wandb.init(
    project="StyleGAN1",
    entity="donghwankim",
    name=run_name,
    config=args,
)

# Checkpoints written by the previous stage are the starting weights here.
prev_step = args["STEP"] - 1

train_step(
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    gen_state_dict=torch.load(f"./gen_state_dict_step{prev_step}.pt", map_location=DEVICE),
    disc_state_dict=torch.load(f"./disc_state_dict_step{prev_step}.pt", map_location=DEVICE),
    z_dim=args["Z_DIM"],
    w_dim=args["W_DIM"],
    step=args["STEP"],
    epochs=args["EPOCHS"][args["STEP"]],
    lr=0.0008,
    betas=(0.5, 0.99),
    alpha=1e-7,
    could_gp=True,
    image_size=args["IMAGE_SIZES"][args["STEP"]],
    dataset_root="/home/kdhsimplepro/kdhsimplepro/AI/ffhq/",
    batch_size=args["BATCH_SIZES"][args["STEP"]]
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdonghwankim[0m. Use [1m`wandb login --relogin`[0m to force relogin






Image Size: 32x32

EPOCH: 1/50


100%|██████████| 813/813 [12:50<00:00,  1.06it/s, disc_loss=-2.01, gen_loss=2, gp=0.00949]     


EPOCH: 2/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.805, gen_loss=2.64, gp=0.0115]    


EPOCH: 3/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.0182, gen_loss=-1.99, gp=0.00459] 


EPOCH: 4/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.844, gen_loss=-2.5, gp=0.0168]     


EPOCH: 5/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-1.51, gen_loss=1.5, gp=0.0158]    


EPOCH: 6/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-1.88, gen_loss=9.54, gp=0.00532]   


EPOCH: 7/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.548, gen_loss=-.0274, gp=0.018]   


EPOCH: 8/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-1.07, gen_loss=2.42, gp=0.00387]   


EPOCH: 9/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.694, gen_loss=8.96, gp=0.0214]    


EPOCH: 10/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.0484, gen_loss=-.889, gp=0.00362] 


EPOCH: 11/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.339, gen_loss=-1.21, gp=0.00382]  


EPOCH: 12/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.777, gen_loss=-2.17, gp=0.0281]   


EPOCH: 13/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.693, gen_loss=6.47, gp=0.00717]    


EPOCH: 14/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.644, gen_loss=-1.89, gp=0.00215]  


EPOCH: 15/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.0744, gen_loss=8.08, gp=0.0197]   


EPOCH: 16/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.528, gen_loss=3.44, gp=0.004]     


EPOCH: 17/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.744, gen_loss=-2.93, gp=0.0123]   


EPOCH: 18/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.258, gen_loss=17.1, gp=0.0106]     


EPOCH: 19/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.0194, gen_loss=1.09, gp=0.00796]  


EPOCH: 20/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.518, gen_loss=-5.65, gp=0.0218]    


EPOCH: 21/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.985, gen_loss=-2.41, gp=0.0111]   


EPOCH: 22/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-1.09, gen_loss=5.97, gp=0.0168]    


EPOCH: 23/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.651, gen_loss=-1.48, gp=0.0111]    


EPOCH: 24/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.494, gen_loss=-2.73, gp=0.00767]   


EPOCH: 25/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.398, gen_loss=-3.18, gp=0.00475]  


EPOCH: 26/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.83, gen_loss=1.01, gp=0.0138]     


EPOCH: 27/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-1.06, gen_loss=2.11, gp=0.00832]   


EPOCH: 28/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.463, gen_loss=-4.74, gp=0.0106]   


EPOCH: 29/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-1.03, gen_loss=-6.35, gp=0.00321]  


EPOCH: 30/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-1.08, gen_loss=-2.14, gp=0.0307]    


EPOCH: 31/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.679, gen_loss=-1.27, gp=0.0039]   


EPOCH: 32/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.0826, gen_loss=5.52, gp=0.0033]   


EPOCH: 33/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.917, gen_loss=3.33, gp=0.00743]   


EPOCH: 34/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.342, gen_loss=0.279, gp=0.0146]    


EPOCH: 35/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.241, gen_loss=-.256, gp=0.00505]   


EPOCH: 36/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.201, gen_loss=-.00537, gp=0.00177]


EPOCH: 37/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.36, gen_loss=1.91, gp=0.0171]     


EPOCH: 38/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.908, gen_loss=-5.09, gp=0.00369]   


EPOCH: 39/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.37, gen_loss=-.541, gp=0.00386]    


EPOCH: 40/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.761, gen_loss=-.513, gp=0.00555]  


EPOCH: 41/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.187, gen_loss=-2.25, gp=0.00313]   


EPOCH: 42/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.612, gen_loss=7.61, gp=0.00447]   


EPOCH: 43/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.679, gen_loss=0.696, gp=0.00326]  


EPOCH: 44/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.0455, gen_loss=-4.57, gp=0.00467]  


EPOCH: 45/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.114, gen_loss=-.383, gp=0.00405]  


EPOCH: 46/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.0651, gen_loss=5.21, gp=0.0109]   


EPOCH: 47/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.332, gen_loss=3.65, gp=0.0025]     


EPOCH: 48/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.0739, gen_loss=-.588, gp=0.00218]  


EPOCH: 49/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=0.252, gen_loss=0.49, gp=0.00247]   


EPOCH: 50/50


100%|██████████| 813/813 [12:49<00:00,  1.06it/s, disc_loss=-.13, gen_loss=7.34, gp=0.00626]    
