In [1]:
# Work from the `apps` directory: the local packages imported next (tools, utils, app)
# and the relative '../...' option paths used below are presumably resolved from there.
# Guarded so that re-running this cell does not climb two more directories.
import os
from pathlib import Path

if Path.cwd().name != 'apps':
    os.chdir('../../apps')
print(Path.cwd())

e:\kaggle\pytorch-book\apps


In [2]:
from torchvision.utils import save_image
import torch
import time
from matplotlib import pyplot as plt 

from tools.file import mkdir
from utils.torch_loader_all import Loader
from tools.toml import load_option
from app import  init, mask_op, array2image

In [3]:
# Data-loading options. Judging by the printed dict, they provide the dataset root,
# the mask directory, fine_size and batch_size.
data_opt_path = '../origin/train-512.toml'
opt = load_option(data_opt_path)
print(opt)
# Each option is forwarded to Loader as a keyword argument.
loader = Loader(**opt)

{'root': 'E:/kaggle/datasets/tune', 'mask': 'D:/kaggle/dataset/mask/testing_mask_dataset', 'fine_size': 512, 'batch_size': 1}


In [4]:
# Hyperparameter settings
## Fixed parameters
epochs = 500           # number of training epochs
display_freq = 49      # save an input/GT/output sample image every `display_freq` steps
save_epoch_freq = 1    # save a checkpoint every `save_epoch_freq` epochs

## Model parameters
alpha, beta = 1, 0.9
# NOTE(review): only `beta` is passed to `init`; in this notebook `alpha` appears only
# in the run name, sample-image filenames and log lines — confirm that is intended.

model_name = f'CSA-512-{beta}-{alpha}'
base_opt, model_opt = (
    load_option('../options/base512.toml'),
    load_option('../options/train-new.toml'),
)
model = init(model_name, beta, model_opt, base_opt)

# Sample images written during training go into an `images` folder under the model's save_dir.
image_save_dir = model.save_dir / 'images'
mkdir(image_save_dir)

initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
---------- Networks initialized -------------
UnetGeneratorCSA(
  (model): UnetSkipConnectionBlock_3(
    (model): Sequential(
      (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): UnetSkipConnectionBlock_3(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(3, 3), dilation=(2, 2))
          (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (3): LeakyReLU(negative_slope=0.2, inplace=True)
          (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (6): UnetSkipConnectionBlock_3(
            (model): Sequential(
              (0): LeakyReLU(negat

In [5]:
# Training phase
start_epoch = 0
total_steps = 0    # samples processed so far, across all epochs
print_freq = 100   # print the current losses every `print_freq` samples
for epoch in range(start_epoch, epochs):
    epoch_start_time = time.time()
    epoch_iter = 0  # samples processed in the current epoch
    # Build fresh data iterators for this epoch.
    trainset = loader.trainset()  # training set
    maskset = loader.maskset()  # mask dataset
    # zip() stops at the shorter of the two iterables.
    for (image, _), mask in zip(trainset, maskset):
        mask = mask_op(mask)
        prev_steps = total_steps
        total_steps += model.batch_size
        epoch_iter += model.batch_size
        # set_input not only sets the input data with the mask,
        # but also sets the latent mask.
        model.set_input(image, mask)
        model.set_gt_latent()
        model.optimize_parameters()
        # Fire whenever this batch crosses a multiple of the frequency. A plain
        # `total_steps % freq == 0` test never fires when batch_size does not
        # divide freq; for batch_size == 1 the two checks are equivalent.
        if total_steps // display_freq > prev_steps // display_freq:
            # real_A = input, real_B = ground truth, fake_B = output
            real_A, real_B, fake_B = model.get_current_visuals()
            # Presumably the visuals are in [-1, 1]; (x + 1) / 2 rescales that to [0, 1].
            pic = (torch.cat([real_A, real_B, fake_B], dim=0) + 1) / 2.0
            image_name = f"epoch{epoch}-{total_steps}-{alpha}.png"
            # make_grid's layout argument is `nrow` (images per row); `ncol` is not a
            # valid argument. nrow=1 gives the single-column layout that `ncol=1` asked for.
            save_image(pic, image_save_dir / image_name, nrow=1)
        if total_steps // print_freq > prev_steps // print_freq:
            errors = model.get_current_errors()
            print(
                f"Epoch/total_steps/alpha-beta: {epoch}/{total_steps}/{alpha}-{beta}", dict(errors))
    if epoch % save_epoch_freq == 0:
        print(f'保存模型 Epoch {epoch}, iters {total_steps} 在 {model.save_dir}')
        model.save(epoch)
    print(
        f'Epoch/Epochs {epoch}/{epochs-1} 花费时间：{time.time() - epoch_start_time}s')
    model.update_learning_rate()

Epoch/total_steps/alpha-beta: 0/100/1-0.9 {'G_GAN': 5.229508399963379, 'G_L1': 56.7972297668457, 'D': 1.00749671459198, 'F': 0.10400338470935822}
Epoch/total_steps/alpha-beta: 0/200/1-0.9 {'G_GAN': 5.556301593780518, 'G_L1': 99.63152313232422, 'D': 0.6910452842712402, 'F': 0.1091005727648735}
Epoch/total_steps/alpha-beta: 0/300/1-0.9 {'G_GAN': 5.832674026489258, 'G_L1': 98.16615295410156, 'D': 0.770594596862793, 'F': 0.07101070880889893}
Epoch/total_steps/alpha-beta: 0/400/1-0.9 {'G_GAN': 5.475033760070801, 'G_L1': 58.53555679321289, 'D': 0.7877010107040405, 'F': 0.14161866903305054}
Epoch/total_steps/alpha-beta: 0/500/1-0.9 {'G_GAN': 5.6975555419921875, 'G_L1': 54.07685089111328, 'D': 0.9263743758201599, 'F': 0.024598095566034317}
Epoch/total_steps/alpha-beta: 0/600/1-0.9 {'G_GAN': 6.807770729064941, 'G_L1': 41.897850036621094, 'D': 0.4004296660423279, 'F': 0.03762740641832352}
Epoch/total_steps/alpha-beta: 0/700/1-0.9 {'G_GAN': 6.675029754638672, 'G_L1': 28.500131607055664, 'D': 0.61

KeyboardInterrupt: 