In [1]:
cd ../../apps/

e:\kaggle\pytorch-book\apps


In [2]:
import time
from pathlib import Path
from random import randint

import torch
import torch as np  # legacy alias: `np` is PyTorch here, NOT NumPy
from matplotlib import pyplot as plt
from torchvision.utils import save_image

from models.CSA import CSA
from opt.dataset import init_dataset
from tools.file import mkdir
from tools.toml import load_option
from utils.torch_loader import Loader


def array2image(x):
    """Convert a channel-first tensor into a channel-last ``uint8`` image array.

    Args:
        x: Tensor with exactly three dimensions (the transpose below permutes
            three axes), presumably laid out as (C, H, W) with values in
            [0, 1] -- TODO confirm the value range with callers.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with the axes reordered to
        (dim1, dim2, dim0), i.e. (H, W, C) for a (C, H, W) input. Scaled
        values are truncated toward zero by the integer cast.
    """
    # Multiply out of place. The previous `x *= 255` rescaled the caller's
    # tensor on every call and raised on leaf tensors that require grad.
    scaled = x.detach() * 255
    return scaled.cpu().numpy().astype('uint8').transpose((1, 2, 0))

def mask_op(mask, device=None):
    """Extract a single mask from a batch and return it as a byte tensor.

    Only ``mask[0][0]`` is used (the first entry along each of the first two
    dimensions -- presumably first sample, first channel of an (N, C, H, W)
    batch; TODO confirm with the mask loader). Two singleton dimensions are
    then re-inserted in front, so a 4-D input yields shape (1, 1, H, W).

    Args:
        mask: Tensor with at least two dimensions.
        device: Device to place the result on. Defaults to CUDA when it is
            available and to the CPU otherwise, so the helper no longer
            crashes on machines without a GPU.

    Returns:
        A ``uint8`` tensor on ``device`` with two leading singleton dimensions.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Slice before transferring so only the selected mask is copied to the device.
    single = mask[0][0].unsqueeze(0).unsqueeze(1)
    return single.to(device).byte()

## 模型定义

In [3]:
# Hyper-parameter settings
## Model parameters
beta = 0.2   # forwarded to the CSA constructor
alpha = 1    # only appears in the run name and in log / image file names

## Fixed schedule parameters
save_epoch_freq = 1   # checkpoint whenever epoch % save_epoch_freq == 0
display_freq = 200    # dump a visual comparison every `display_freq` steps
epochs = 1000         # upper bound of the epoch loop

# Run name, derived from the model parameters above.
model_name = 'CSA-crop-{}-{}'.format(alpha, beta)

In [4]:
# Build the CSA model from two TOML option files plus the run name.
# NOTE(review): `opt.update(base_opt)` lets values from base.toml overwrite
# same-named keys from train-new.toml. If the training file is meant to
# override the base file, the merge order is inverted -- confirm.
base_opt = load_option('../options/base.toml')
opt = load_option('../options/train-new.toml')
opt.update(base_opt)
opt.update({'name': model_name}) # set the model name
model = CSA(beta, **opt)

# Directory under the model's save_dir that receives the sample images.
image_save_dir = model.save_dir / 'images'
mkdir(image_save_dir)

initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
---------- Networks initialized -------------
UnetGeneratorCSA(
  (model): UnetSkipConnectionBlock_3(
    (model): Sequential(
      (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): UnetSkipConnectionBlock_3(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(3, 3), dilation=(2, 2))
          (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (3): LeakyReLU(negative_slope=0.2, inplace=True)
          (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (6): UnetSkipConnectionBlock_3(
            (model): Sequential(
              (0): LeakyReLU(negat

In [5]:
# Build the data loaders. This rebinds the notebook-global `opt`, which until
# now held the model options, to whatever init_dataset returns.
# NOTE(review): the meaning of the argument 200 is not visible here -- confirm
# against init_dataset.
opt = init_dataset(200)
loader = Loader(**opt)
trainset = loader.trainset # training set
maskset = loader.maskset # mask dataset

{'E:/kaggle/datasets/building/╓╨╛░┤σ┬Σ╖τ├▓': 0, 'E:/kaggle/datasets/building/中景村落风貌': 809, 'E:/kaggle/datasets/building/航拍总图': 281, 'E:/kaggle/datasets/building/近景建筑风貌': 583, 'E:/kaggle/datasets/building/远景村落风貌': 1349}


In [6]:
# Training phase
start_epoch = 0
total_steps = 0   # advances by model.batch_size on every iteration
print_freq = 100  # log the current losses whenever total_steps is a multiple of this
iter_start_time = time.time()  # wall-clock start of the whole run
for epoch in range(start_epoch, epochs):
    epoch_start_time = time.time()
    epoch_iter = 0
    # zip() stops at the shorter of the two loaders, so one "epoch" covers
    # min(len(trainset), len(maskset)) batches.
    for batch, mask in zip(trainset, maskset):
        image = batch[0]
        mask = mask_op(mask)
        total_steps += model.batch_size
        epoch_iter += model.batch_size
        # set_input not only sets the input data with the mask, it also sets
        # the latent mask.
        model.set_input(image, mask)
        model.set_gt_latent()
        model.optimize_parameters()
        if total_steps % display_freq == 0:
            # real_A = input, real_B = ground truth, fake_B = output
            real_A, real_B, fake_B = model.get_current_visuals()
            # (x + 1) / 2 presumably maps images from [-1, 1] to [0, 1] -- confirm.
            pic = (torch.cat([real_A, real_B, fake_B], dim=0) + 1) / 2.0
            image_name = f"epoch{epoch}-{total_steps}-{alpha}.png"
            # `nrow` is the grid keyword torchvision defines (images per row);
            # the previous `ncol=1` is not a make_grid parameter. nrow=1 stacks
            # the images in a single column.
            save_image(pic, image_save_dir / image_name, nrow=1)
        if total_steps % print_freq == 0:
            errors = model.get_current_errors()
            print(
                f"Epoch/total_steps/alpha-beta: {epoch}/{total_steps}/{alpha}-{beta}", dict(errors))
    if epoch % save_epoch_freq == 0:
        print(f'保存模型 Epoch {epoch}, iters {total_steps} 在 {model.save_dir}')
        model.save(epoch)
    print(
        f'Epoch/Epochs {epoch}/{epochs-1} 花费时间：{time.time() - epoch_start_time}s')
    model.update_learning_rate()

Epoch/total_steps/alpha-beta: 0/100/1-0.2 {'G_GAN': 5.518022537231445, 'G_L1': 55.588680267333984, 'D': 1.1141009330749512, 'F': 0.07483334094285965}
Epoch/total_steps/alpha-beta: 0/200/1-0.2 {'G_GAN': 5.759289741516113, 'G_L1': 55.08604049682617, 'D': 0.6345841288566589, 'F': 0.04530204087495804}
Epoch/total_steps/alpha-beta: 0/300/1-0.2 {'G_GAN': 6.937740325927734, 'G_L1': 59.467960357666016, 'D': 0.5065065026283264, 'F': 0.03301296383142471}
Epoch/total_steps/alpha-beta: 0/400/1-0.2 {'G_GAN': 6.193876266479492, 'G_L1': 38.149436950683594, 'D': 1.1900465488433838, 'F': 0.05299185961484909}
Epoch/total_steps/alpha-beta: 0/500/1-0.2 {'G_GAN': 7.556830406188965, 'G_L1': 39.45917892456055, 'D': 0.2415945678949356, 'F': 0.022815663367509842}
Epoch/total_steps/alpha-beta: 0/600/1-0.2 {'G_GAN': 6.890033721923828, 'G_L1': 65.72925567626953, 'D': 0.5194847583770752, 'F': 0.026784880086779594}
Epoch/total_steps/alpha-beta: 0/700/1-0.2 {'G_GAN': 7.203421592712402, 'G_L1': 55.98662567138672, 'D'

KeyboardInterrupt: 