##### 构建数据集

In [9]:
import numpy as np
from mmcv.transforms import to_tensor
from torch.utils.data import random_split 
from torchvision.datasets import MNIST
from mmengine.dataset import BaseDataset

class MnistDataset(BaseDataset):
    def __init__(self,data_root,pipeline,test_mode=False):
        if test_mode:#如果是训练模式
            mnist_full=MNIST(data_root,train=True,download=True)
            self.mnist_dataset,_=random_split(mnist_full,[55000,5000])
        else:
            self.mnist_dataset=MNIST(data_root,train=False,download=True)
        super().__init__(data_root=data_root,pipeline=pipeline,test_mode=test_mode)

    @staticmethod
    def totensor(img):
        if len(img.shape)<3:#单通道图像
            img=np.expand_dims(img,axis=-1)
        img=np.ascontiguousarray(img.transpose(2,0,1))#H,W,C->C,H,W;并使用内存连续的张量
        return to_tensor(img)
    
    def load_data_list(self):
        return [dict(inputs=self.totensor(np.array(x[0]))) for x in self.mnist_dataset]

In [10]:
dataset=MnistDataset('./data/',[])

In [11]:
import os
import torch
from mmengine.runner import Runner

NUM_WORKERS=int(os.cpu_count()/2)
BATCH_SIZE=256 if torch.cuda.is_available() else 64

train_dataloader=dict(
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dataset
)

train_dataloader=Runner.build_dataloader(dataloader=train_dataloader)

In [12]:
import torch.nn as nn

#生成式网络
class Generator(nn.Module):
    def __init__(self, noise_size, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.noise_size = noise_size

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(noise_size, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

#判别式网络
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [13]:
generator = Generator(100, (1, 28, 28))
discriminator = Discriminator((1, 28, 28))

In [14]:
from mmengine.model import ImgDataPreprocessor
data_preprocessor=ImgDataPreprocessor(mean=([127.5]),std=([127.5]))

In [22]:
#GAN模型
import torch.nn.functional as F
from mmengine.model import BaseModel

def set_requires_grad(net,requires_grad=False):
    """锁定网络权重梯度"""
    if not isinstance(net,list):
        net=[net]
    for i in net:
        if i is not None:
            for param in i.parameters():
                param.requires_grad=requires_grad

class GAN(BaseModel):
    def __init__(self,generator,discriminator,noise_size,data_preprocessor):
        super().__init__(data_preprocessor=data_preprocessor)

        assert generator.noise_size==noise_size
        self.generator=generator
        self.discriminator=discriminator
        self.noise_size=noise_size

    def train_step(self,data,optim_wrapper):
        inputs_dict=data_preprocessor(data,True)
        #训练判别器
        disc_optimizer_wrapper=optim_wrapper['discriminator']
        with disc_optimizer_wrapper.optim_context(self.discriminator):
            log_vars=self.train_discriminator(inputs_dict,disc_optimizer_wrapper)

        #训练生成器
        set_requires_grad(self.discriminator,False)
        gen_optimzer_wrapper=optim_wrapper['generator']
        with gen_optimzer_wrapper.optim_context(self.generator):
            log_vars_gen=self.train_generator(inputs_dict,gen_optimzer_wrapper)

        set_requires_grad(self.discriminator,True)
        log_vars.update(log_vars_gen)

        return log_vars

    def forward(self,batch_inputs,data_samples=None,mode=None):
        return self.generator(batch_inputs)

    def train_discriminator(self,inputs,optimizer_wrapper):
        real_imgs=inputs['inputs']
        z=torch.randn((real_imgs.shape[0],self.noise_size)).type_as(real_imgs)
        with torch.no_grad():
            fake_imgs=self.generator(z)
        disc_pred_fake=self.discriminator(fake_imgs)
        disc_pred_real=self.discriminator(real_imgs)

        parsed_losses,log_vars=self.disc_loss(disc_pred_fake,disc_pred_real)
        optimizer_wrapper.update_params(parsed_losses)

        return log_vars

    def train_generator(self,inputs,optimizer_wrapper):
        real_imgs=inputs['inputs']
        z=torch.randn((real_imgs.shape[0],self.noise_size)).type_as(real_imgs)
        fake_imgs=self.generator(z)
        disc_pred_fake=self.discriminator(fake_imgs)
        parsed_losses,log_vars=self.gen_loss(disc_pred_fake)
        optimizer_wrapper.update_params(parsed_losses)

        return log_vars
    
    def disc_loss(self,disc_pred_fake,disc_pred_real):
        losses_dict=dict()
        losses_dict['loss_disc_fake']=F.binary_cross_entropy(disc_pred_fake,
                                    0.*torch.ones_like(disc_pred_fake))
        losses_dict['loss_disc_real']=F.binary_cross_entropy(disc_pred_real,
                                    1.*torch.ones_like(disc_pred_real))
        loss,log_vars=self.parse_losses(losses_dict)
        return loss,log_vars
    
    def gen_loss(self,disc_pred_fake):
        losses_dict=dict()
        losses_dict['loss_gen']=F.binary_cross_entropy(disc_pred_fake,
                                    1.*torch.ones_like(disc_pred_fake))
        
        loss,log_vars=self.parse_losses(losses_dict)

        return loss,log_vars


In [23]:

model = GAN(generator, discriminator, 100, data_preprocessor)


##### 构建优化器

In [24]:
from mmengine.optim import OptimWrapper,OptimWrapperDict

g_optim=torch.optim.Adam(generator.parameters(),lr=1e-4,betas=(0.5,0.999))
g_optimizer_wrapper=OptimWrapper(g_optim)

d_optim=torch.optim.Adam(discriminator.parameters(),lr=1e-4,betas=(0.5,0.999))
d_optimizer_wrapper=OptimWrapper(d_optim)

optimizer_wrapper=OptimWrapperDict(generator=g_optimizer_wrapper,discriminator=d_optimizer_wrapper)

In [25]:
#训练
from mmengine.runner import Runner
train_cfg=dict(by_epoch=True,max_epochs=220)
runner=Runner(
    model=model,
    work_dir='./work_dir/gan',
    train_cfg=train_cfg,
    train_dataloader=train_dataloader,
    optim_wrapper=optimizer_wrapper
)


runner.train()

08/10 06:48:27 - mmengine - [4m[97mINFO[0m - 
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:39:04) [GCC 10.3.0]
    CUDA available: True
    numpy_random_seed: 802111463
    GPU 0,1: GeForce GTX 1080
    CUDA_HOME: /usr/local/cuda
    NVCC: Cuda compilation tools, release 10.2, V10.2.8
    GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
    PyTorch: 1.12.0+cu102
    PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.2
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm

GAN(
  (data_preprocessor): ImgDataPreprocessor()
  (generator): Generator(
    (model): Sequential(
      (0): Linear(in_features=100, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.2, inplace=True)
      (5): Linear(in_features=256, out_features=512, bias=True)
      (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.2, inplace=True)
      (8): Linear(in_features=512, out_features=1024, bias=True)
      (9): BatchNorm1d(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (10): LeakyReLU(negative_slope=0.2, inplace=True)
      (11): Linear(in_features=1024, out_features=784, bias=True)
      (12): Tanh()
    )
  )
  (discriminator): Discriminator(
    (model): Sequenti

In [27]:
#验证结果
z=torch.randn(64,100).cuda()
img=model(z)

from torchvision.utils import save_image
save_image(img,'work_dir/gan_result/result.jpg',normalize=True)