In [1]:
# Change the working directory to the project's `apps` folder. The relative
# option paths used later ('../origin/...', '../options/...') are written
# relative to it; presumably the project-local imports (`tools`, `utils`,
# `app`) are resolved from there as well — TODO confirm.
cd ../../apps

e:\kaggle\pytorch-book\apps


In [2]:
from torchvision.utils import save_image
import torch
import time
from matplotlib import pyplot as plt 

from tools.file import mkdir
from utils.torch_loader_all import Loader
from tools.toml import load_option
from app import  init, mask_op, array2image

In [3]:
# Hyper-parameter setup
# This cell defines the run configuration, builds the data loader and the
# model, and creates the directory that receives the preview images.

## Fixed parameters
epochs = 120  # upper bound of the epoch range iterated by the training loop
display_freq = 100  # a preview image is saved whenever total_steps is a multiple of this
save_epoch_freq = 1  # a checkpoint is saved whenever epoch is a multiple of this

## Model parameters
# NOTE(review): within this notebook `alpha` is only used to label the model
# name and the preview file names; only `beta` is passed to `init`.
alpha = 0.9
beta = 0.9

# `fine_size` selects which option files are loaded below.
fine_size = 256
# Data-loader options; the printed dict shows the keys root / mask /
# fine_size / batch_size.
opt = load_option(f'../origin/train-{fine_size}-new.toml')
print(opt)
loader = Loader(**opt)

# The model name encodes fine_size, beta and alpha.
model_name = f'CSAx-{fine_size}-{beta}-{alpha}'
base_opt = load_option(f'../options/base{fine_size}.toml')
model_opt = load_option('../options/train-new.toml')
model = init(model_name, beta, model_opt, base_opt)
# Preview images written during training go into <model.save_dir>/images.
image_save_dir = model.save_dir / 'images'
mkdir(image_save_dir)

{'root': 'E:/kaggle/datasets/tune', 'mask': 'D:/kaggle/dataset/mask/testing_mask_dataset', 'fine_size': 256, 'batch_size': 1}
initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
---------- Networks initialized -------------
UnetGeneratorCSA(
  (model): UnetSkipConnectionBlock_3(
    (model): Sequential(
      (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): UnetSkipConnectionBlock_3(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(3, 3), dilation=(2, 2))
          (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (3): LeakyReLU(negative_slope=0.2, inplace=True)
          (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_runn

In [4]:
# Training phase
#
# Iterates over the epochs; in every epoch the image and mask datasets are
# re-created and consumed pairwise. For each pair the model is fed the masked
# input and optimised once. At fixed step intervals a preview image
# (input / ground truth / output) is written to `image_save_dir` and the
# current losses are printed. A checkpoint is saved every `save_epoch_freq`
# epochs, and the learning rate is updated at the end of every epoch.
start_epoch = 0
total_steps = 0
print_freq = 100  # print the current losses every `print_freq` processed samples
for epoch in range(start_epoch, epochs):
    epoch_start_time = time.time()
    epoch_iter = 0
    # Initialise the datasets
    trainset = loader.trainset()  # training set
    maskset = loader.maskset()  # mask dataset
    # zip() stops at the shorter of the two datasets.
    for (image, _), mask in zip(trainset, maskset):
        mask = mask_op(mask)
        total_steps += model.batch_size
        epoch_iter += model.batch_size
        # it not only sets the input data with mask,
        #  but also sets the latent mask.
        model.set_input(image, mask)
        model.set_gt_latent()
        model.optimize_parameters()
        if total_steps % display_freq == 0:
            real_A, real_B, fake_B = model.get_current_visuals()
            # real_A=input, real_B=ground truth fake_b=output
            # Shift by +1 and halve before saving (maps [-1, 1] onto [0, 1];
            # assumes the visuals are in [-1, 1] — TODO confirm).
            pic = (torch.cat([real_A, real_B, fake_B], dim=0) + 1) / 2.0
            image_name = f"epoch{epoch}-{total_steps}-{alpha}.png"
            # make_grid's layout argument is `nrow` (images per row); there is
            # no `ncol` keyword. nrow=1 stacks the images in a single column.
            save_image(pic, image_save_dir/image_name, nrow=1)
        if total_steps % print_freq == 0:
            errors = model.get_current_errors()
            print(
                f"Epoch/total_steps/alpha-beta: {epoch}/{total_steps}/{alpha}-{beta}", dict(errors))
    if epoch % save_epoch_freq == 0:
        print(f'保存模型 Epoch {epoch}, iters {total_steps} 在 {model.save_dir}')
        model.save(epoch)
    print(
        f'Epoch/Epochs {epoch}/{epochs-1} 花费时间：{time.time() - epoch_start_time}s')
    model.update_learning_rate()

Epoch/total_steps/alpha-beta: 0/100/0.9-0.9 {'G_GAN': 5.583193778991699, 'G_L1': 102.3512954711914, 'D': 1.6398849487304688, 'F': 0.14054235816001892}
Epoch/total_steps/alpha-beta: 0/200/0.9-0.9 {'G_GAN': 5.735714435577393, 'G_L1': 43.281005859375, 'D': 1.3913264274597168, 'F': 0.07673273980617523}
Epoch/total_steps/alpha-beta: 0/300/0.9-0.9 {'G_GAN': 6.328305244445801, 'G_L1': 71.03681182861328, 'D': 0.7898565530776978, 'F': 0.06986531615257263}
Epoch/total_steps/alpha-beta: 0/400/0.9-0.9 {'G_GAN': 6.22367000579834, 'G_L1': 60.55272674560547, 'D': 0.8411492109298706, 'F': 0.054662372916936874}
Epoch/total_steps/alpha-beta: 0/500/0.9-0.9 {'G_GAN': 5.689874649047852, 'G_L1': 79.49015045166016, 'D': 0.7096141576766968, 'F': 0.042588841170072556}
Epoch/total_steps/alpha-beta: 0/600/0.9-0.9 {'G_GAN': 6.076451778411865, 'G_L1': 51.49847412109375, 'D': 0.6659576296806335, 'F': 0.02226349711418152}
Epoch/total_steps/alpha-beta: 0/700/0.9-0.9 {'G_GAN': 7.23018217086792, 'G_L1': 76.882865905761