# Training StyleGAN V2 to Generate Strawberry Images

Train a StyleGAN V2-based unconditional GAN (generative adversarial network) on your own dataset.

> Authors: [同济子豪兄](https://space.bilibili.com/1900783), 杨逸飞, 2022-4-10

Further reading:

Generated images of classical vases: https://thisvesseldoesnotexist.com/#/

Images of various beetles: https://www.flickr.com/photos/coleoptera-us/albums/72157607363771409

## Enter the MMGeneration root directory

In [1]:
import os
os.chdir('mmgeneration')
os.listdir()

['.git',
 '.dev_scripts',
 '.github',
 '.gitignore',
 '.pre-commit-config.yaml',
 '.pylintrc',
 '.readthedocs.yml',
 'CITATION.cff',
 'LICENSE',
 'LICENSES.md',
 'MANIFEST.in',
 'README.md',
 'README_zh-CN.md',
 'apps',
 'configs',
 'demo',
 'docker',
 'docs',
 'mmgen',
 'model-index.yml',
 'requirements.txt',
 'requirements',
 'setup.cfg',
 'setup.py',
 'tests',
 'tools',
 'mmgen.egg-info',
 'outputs',
 'data',
 'checkpoints',
 'work_dirs',
 '.ipynb_checkpoints']
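
Running the `os.chdir('mmgeneration')` cell a second time fails, because the notebook is then already inside that folder. A minimal sketch of a re-runnable variant follows; the helper name `enter_repo` is hypothetical and not part of MMGeneration.

```python
import os

def enter_repo(repo_dir: str = "mmgeneration") -> str:
    """Change into repo_dir unless the current directory already is that folder."""
    if os.path.basename(os.getcwd()) != repo_dir:
        os.chdir(repo_dir)
    return os.getcwd()

# Only attempt the change when the folder is reachable from here.
if os.path.isdir("mmgeneration") or os.path.basename(os.getcwd()) == "mmgeneration":
    print(enter_repo())
```

Calling `enter_repo()` repeatedly leaves the working directory unchanged after the first call.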

## Download the strawberry image dataset

In [2]:
# Download the strawberry image dataset archive
# !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220322-mmgeneration/watermelon.zip -O data/watermelon.zip
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220322-mmgeneration/strawberry.zip -O data/strawberry.zip

--2022-05-05 22:53:48--  https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220322-mmgeneration/strawberry.zip
Connecting to 172.16.0.13:5848... connected.
Proxy request sent, awaiting response... 200 OK
Length: 43600183 (42M) [application/zip]
Saving to: ‘data/strawberry.zip’


2022-05-05 22:53:49 (45.1 MB/s) - ‘data/strawberry.zip’ saved [43600183/43600183]



In [3]:
# Extract into the data/strawberry directory
!unzip -o data/strawberry.zip -d data/strawberry

Archive:  data/strawberry.zip
  inflating: data/strawberry/img1.webp  
  inflating: data/strawberry/img2.jpg  
  inflating: data/strawberry/img3.jpg  
  inflating: data/strawberry/img4.jpg  
  inflating: data/strawberry/img5.jpg  
  inflating: data/strawberry/img6.jpg  
  inflating: data/strawberry/img7.webp  
  inflating: data/strawberry/img8.jpg  
  inflating: data/strawberry/img9.jpg  
  inflating: data/strawberry/img10.jpg  
  inflating: data/strawberry/img11.jpg  
  inflating: data/strawberry/img12.jpg  
  inflating: data/strawberry/img13.jpg  
  inflating: data/strawberry/img14.jpg  
  inflating: data/strawberry/img15.jpg  
  inflating: data/strawberry/img16.jpg  
  inflating: data/strawberry/img17.jpg  
  inflating: data/strawberry/img18.jpg  
  inflating: data/strawberry/img19.jpg  
  inflating: data/strawberry/img20.jpg  
  inflating: data/strawberry/img21.jpg  
  inflating: data/strawberry/img22.jpg  
  inflating: data/strawberry/img23.jpg  
  inflating: data/strawberry/img24
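
The archive mixes `.webp` and `.jpg` files, so it can be worth checking what was actually extracted before going further. The following is a small sketch using only the standard library; the helper `count_by_extension` is a name introduced here for illustration.

```python
from collections import Counter
from pathlib import Path

def count_by_extension(img_dir):
    """Return a Counter mapping lowercase file suffix to number of files."""
    return Counter(
        p.suffix.lower() for p in Path(img_dir).iterdir() if p.is_file()
    )

dataset_dir = Path("data/strawberry")
if dataset_dir.is_dir():
    counts = count_by_extension(dataset_dir)
    print(counts, "total:", sum(counts.values()))
```

The total can be compared with the image count that MMGeneration reports in its log when it loads the same folder.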

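### Compute the Inception V3 statistics used by the FID metric

No model is trained in this step. A pretrained Inception V3 network extracts features from the real images, and the result is written to `work_dirs/inception_pkl/strawberry.pkl`.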

In [4]:
!python tools/utils/inception_stat.py \
        --imgsdir data/strawberry \
        --pklname strawberry.pkl \
        --size 256 \
        --flip \
        --num-samples -1

2022-05-05 22:54:32,513 - mmgen - INFO - dataset_name: <class 'mmgen.datasets.unconditional_image_dataset.UnconditionalImageDataset'>, total 892 images in imgs_root: data/strawberry
Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /home/featurize/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100%|███████████████████████████████████████| 91.2M/91.2M [00:00<00:00, 115MB/s]
2022-05-05 22:54:40,039 - mmgen - INFO - Use all samples in subset
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 36/36, 19.3 task/s, elapsed: 2s, ETA:     0s
2022-05-05 22:54:42,058 - mmgen - INFO - Extract 892 features
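
For reference, FID (Fréchet Inception Distance) compares Gaussian fits of the Inception features of real and generated images. Writing $(\mu_r, \Sigma_r)$ for the mean and covariance of the real-image features and $(\mu_g, \Sigma_g)$ for those of the generated images (notation introduced here, not taken from the notebook), the standard definition is:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)
$$

The cached pkl file presumably covers the real-image side of this formula, so it does not have to be recomputed at every evaluation. Its exact contents are not shown in this notebook.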


### Set up the config files

In [2]:
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220322-mmgeneration/stylegan2_c2_ffhq_256_b4x8_800k_strawberry.py -O configs/styleganv2/stylegan2_c2_ffhq_256_b4x8_800k_strawberry.py

--2022-05-06 09:57:35--  https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220322-mmgeneration/stylegan2_c2_ffhq_256_b4x8_800k_strawberry.py
Connecting to 172.16.0.13:5848... connected.
Proxy request sent, awaiting response... 200 OK
Length: 2220 (2.2K) [binary/octet-stream]
Saving to: ‘configs/styleganv2/stylegan2_c2_ffhq_256_b4x8_800k_strawberry.py’


2022-05-06 09:57:36 (56.1 MB/s) - ‘configs/styleganv2/stylegan2_c2_ffhq_256_b4x8_800k_strawberry.py’ saved [2220/2220]



In [3]:
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220322-mmgeneration/unconditional_imgs_flip_256x256.py -O configs/_base_/datasets/unconditional_imgs_flip_256x256.py

--2022-05-06 09:57:37--  https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220322-mmgeneration/unconditional_imgs_flip_256x256.py
Connecting to 172.16.0.13:5848... connected.
Proxy request sent, awaiting response... 200 OK
Length: 797 [binary/octet-stream]
Saving to: ‘configs/_base_/datasets/unconditional_imgs_flip_256x256.py’


2022-05-06 09:57:37 (20.6 MB/s) - ‘configs/_base_/datasets/unconditional_imgs_flip_256x256.py’ saved [797/797]



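## Train the model (takes about half an hour)

During training, outputs are written under `work_dirs/experiments/experiments_name`, where `experiments_name` is the folder passed via `--work-dir` (here `stylegan2_c2_ffhq_256_b4x8_800k`):

`training_samples` holds sample images generated at different iteration counts during training.

`ckpt/experiments_name` holds the model weight files saved at different iteration counts.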
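
Once training is under way, the two folders described above can be checked from Python. This is a minimal sketch: the helper `list_outputs` is hypothetical, and the `.pth` checkpoint extension and the image extensions are assumptions about how the files are saved.

```python
from pathlib import Path

def list_outputs(work_dir):
    """Return sorted sample-image and checkpoint paths found under work_dir."""
    root = Path(work_dir)
    if not root.is_dir():
        return [], []
    image_suffixes = {".png", ".jpg", ".jpeg"}  # assumed sample image formats
    samples_dir = root / "training_samples"
    samples = sorted(
        p for p in samples_dir.iterdir() if p.suffix.lower() in image_suffixes
    ) if samples_dir.is_dir() else []
    # Search recursively, since checkpoints sit in a nested ckpt/ folder.
    checkpoints = sorted(root.rglob("*.pth"))
    return samples, checkpoints

samples, checkpoints = list_outputs(
    "work_dirs/experiments/stylegan2_c2_ffhq_256_b4x8_800k"
)
print(len(samples), "sample images;", len(checkpoints), "checkpoints")
```

Both lists are empty until the first visualization and checkpoint intervals have been reached.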

In [6]:
!bash tools/dist_train.sh configs/styleganv2/stylegan2_c2_ffhq_256_b4x8_800k_strawberry.py 1 --work-dir work_dirs/experiments/stylegan2_c2_ffhq_256_b4x8_800k

and will be removed in future. Use torchrun.
Note that --use_env is set by default in torchrun.
If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See 
https://pytorch.org/docs/stable/distributed.html#launch-utility for 
further instructions

  f'Setting OMP_NUM_THREADS environment variable for each process '
  f'Setting MKL_NUM_THREADS environment variable for each process '
2022-05-06 09:59:28,086 - mmgen - INFO - Environment info:
------------------------------------------------------------
sys.platform: linux
Python: 3.7.10 (default, Jun  4 2021, 14:48:32) [GCC 7.5.0]
CUDA available: True
CUDA_HOME: /usr/local/cuda
NVCC: Build cuda_11.2.r11.2/compiler.29618528_0
GPU 0: NVIDIA RTX A4000
GCC: gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
PyTorch: 1.10.1+cu113
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for 

In [7]:
import time
time.localtime()

time.struct_time(tm_year=2022, tm_mon=5, tm_mday=6, tm_hour=10, tm_min=37, tm_sec=46, tm_wday=4, tm_yday=126, tm_isdst=0)
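
The two timestamps in the output above give a rough figure for the training time. The sketch below subtracts the first training log line from the `time.localtime()` result, assuming that cell was run right after training finished.

```python
from datetime import datetime

# Timestamps copied from the notebook output above.
first_log = datetime(2022, 5, 6, 9, 59, 28)   # "Environment info" log line
finished = datetime(2022, 5, 6, 10, 37, 46)   # time.localtime() after training

elapsed = finished - first_log
minutes, seconds = divmod(int(elapsed.total_seconds()), 60)
print(f"{minutes} min {seconds} s")  # prints "38 min 18 s"
```

That is in line with the "about half an hour" estimate in the section heading for a single NVIDIA RTX A4000.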