In [None]:
'''
@ author: haijun xiong
@ date  : 2021/10/13
'''
import os

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator for single-channel MNIST-sized images.

    Input:  tensor of shape (N, 1, 28, 28).
    Output: tensor of shape (N, 1) holding sigmoid probabilities that each
            image is real.
    """

    def __init__(self):
        super().__init__()
        # Spatial size: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4,
        # which is why the first Linear layer takes 64 * 4 * 4 features.
        # The attribute name ``conv1`` and the layer order determine the
        # state_dict keys, so both are kept for checkpoint compatibility.
        layers = [
            nn.Conv2d(1, 32, kernel_size=5, stride=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(start_dim=1),
            nn.Linear(64 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        ]
        self.conv1 = nn.Sequential(*layers)

    def forward(self, x):
        """Return D(x): one probability of "real" per image in ``x``."""
        return self.conv1(x)

# 定义生成器
class Generator(nn.Module):
    def __init__(self, z_dimension):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features=z_dimension, out_features=1024, bias=True),
            nn.BatchNorm1d(1024),
            nn.ReLU(True),
            nn.Linear(in_features=1024, out_features=128 * 7 * 7, bias=True),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(True),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1 , kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )
    def forward(self, inputs):
        return self.conv(self.fc(inputs).view(-1, 128, 7, 7))

In [None]:
def getdataset(batch_size=256, root='./data', download=True, shuffle=True):
    """Build a DataLoader over the MNIST training split.

    Images are converted with ``ToTensor`` (values in [0, 1]) and then
    normalised with mean 0.5 / std 0.5, i.e. mapped to [-1, 1], the same
    range as the generator's tanh output.

    Args:
        batch_size: number of images per batch.
        root: directory where MNIST is stored (and downloaded to).
        download: fetch the dataset into ``root`` if it is not present yet.
        shuffle: reshuffle the training set at every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Normalize is documented to take per-channel sequences; MNIST has one channel.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    trainset = datasets.MNIST(root=root, train=True, download=download, transform=transform)
    return torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=shuffle)

In [None]:
class Model:
    """Adversarial trainer for the ``Generator`` / ``Discriminator`` pair.

    Args:
        train_data: iterable of ``(images, labels)`` batches (e.g. the loader
            returned by ``getdataset``); the labels are ignored.
        lr: Adam learning rate, shared by both networks.
        epoch: number of passes over ``train_data`` performed by :meth:`fit`.
        z_dimension: length of the latent vectors fed to the generator.
        device: device holding both networks and every tensor they see.
            Defaults to ``'cuda'``; pass ``'cpu'`` to run without a GPU.
    """

    def __init__(self, train_data, lr, epoch, z_dimension, device='cuda'):
        self.train_loader = train_data
        self.lr = lr
        self.epoch = epoch
        self.z_dimension = z_dimension
        self.device = torch.device(device)
        self.D = Discriminator().float().to(self.device)
        self.G = Generator(self.z_dimension).float().to(self.device)
        # The discriminator ends in a Sigmoid, so BCELoss is applied to probabilities.
        self.loss = nn.BCELoss().to(self.device)
        # beta1 = 0.5 rather than Adam's default 0.9.
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=self.lr, betas=(0.5, 0.999))
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=self.lr, betas=(0.5, 0.999))

    def Var(self, x):
        """Return tensor ``x`` moved to the model's device.

        Replaces the deprecated ``autograd.Variable`` wrapper, which has been
        a no-op since PyTorch 0.4.
        """
        return x.to(self.device)

    def fit(self):
        """Train for ``self.epoch`` epochs, then save both networks.

        Each batch gets one discriminator update followed by one generator
        update. After each epoch this prints the epoch means of both losses and
        of D(x) / D(G(z)), and writes a sample grid to
        ``./result/synthetic_images_<epoch>.png``.

        Raises:
            ValueError: if ``train_data`` yields no batches in an epoch.
        """
        for epoch in range(self.epoch):
            self.D.train()
            self.G.train()
            history = []  # one (D_loss, G_loss, D(x), D(G(z))) tuple per batch
            for inputs, _ in tqdm(self.train_loader):
                history.append(self._train_step(self.Var(inputs)))
            if not history:
                raise ValueError('train_data yielded no batches')
            d_loss, g_loss, d_real, d_fake = (sum(col) / len(history) for col in zip(*history))
            print('Epoch {}, d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(epoch, d_loss, g_loss, d_real, d_fake))
            self.generate_synthetic_images('./result/synthetic_images_{}.png'.format(epoch))
        self.save()

    def _train_step(self, real_img):
        """Run one discriminator update and then one generator update.

        Args:
            real_img: batch of real images already on ``self.device``.

        Returns:
            Python floats ``(D_loss, G_loss, mean D(x), mean D(G(z)))``; both
            D means are measured during the discriminator step.
        """
        batch = real_img.size(0)
        real_labels = torch.ones((batch, 1), device=self.device)
        fake_labels = torch.zeros((batch, 1), device=self.device)

        # Discriminator step: push D(x) towards 1 and D(G(z)) towards 0.
        self.D_optimizer.zero_grad()
        real_outputs = self.D(real_img)
        D_real_loss = self.loss(real_outputs, real_labels)
        z = torch.randn((batch, self.z_dimension), device=self.device)
        # detach(): the D loss must not backpropagate into G, otherwise its
        # gradient is accumulated into G's .grad and mixed into the G update.
        fake_img = self.G(z).detach()
        fake_outputs = self.D(fake_img)
        D_fake_loss = self.loss(fake_outputs, fake_labels)
        D_loss = (D_real_loss + D_fake_loss) / 2
        D_loss.backward()
        self.D_optimizer.step()

        # Generator step (non-saturating loss): push D(G(z)) towards 1.
        # Clear G's grads here, right before the step that uses them.
        self.G_optimizer.zero_grad()
        z = torch.randn((batch, self.z_dimension), device=self.device)
        G_loss = self.loss(self.D(self.G(z)), real_labels)
        G_loss.backward()
        self.G_optimizer.step()

        return (D_loss.item(), G_loss.item(),
                real_outputs.mean().item(), fake_outputs.mean().item())

    def save(self, g_path='./generator.pth', d_path='./discriminator.pth'):
        """Write the generator and discriminator ``state_dict``s to disk."""
        torch.save(self.G.state_dict(), g_path)
        torch.save(self.D.state_dict(), d_path)

    def generate_synthetic_images(self, ro, num_img=100, nrow=10):
        """Sample ``num_img`` images from G and save them as a grid at path ``ro``.

        G runs in eval mode (BatchNorm uses running statistics) without
        gradient tracking, and its previous train/eval mode is restored
        afterwards. The parent directory of ``ro`` is created if missing.

        Args:
            ro: output image path.
            num_img: number of images to sample.
            nrow: images per row in the saved grid.
        """
        was_training = self.G.training
        self.G.eval()
        try:
            with torch.no_grad():
                z = torch.randn((num_img, self.z_dimension), device=self.device)
                fake_img = self.G(z)
        finally:
            self.G.train(was_training)
        # G ends in tanh ([-1, 1]) but save_image maps [0, 1] to [0, 255] and
        # clamps; without this rescale every negative pixel would be black.
        fake_img = fake_img.add(1).div(2)
        out_dir = os.path.dirname(ro)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        save_image(fake_img, ro, nrow=nrow)

In [None]:
# Training configuration. Seeding torch before building the loader and the
# model makes weight initialisation, data shuffling and latent sampling
# repeatable across Restart & Run All (cuDNN kernels may still add small
# run-to-run differences on GPU).
SEED = 0
BATCH_SIZE = 256
LEARNING_RATE = 1e-4
NUM_EPOCHS = 200
Z_DIMENSION = 100

torch.manual_seed(SEED)
train_loader = getdataset(batch_size=BATCH_SIZE)
model = Model(train_loader, lr=LEARNING_RATE, epoch=NUM_EPOCHS, z_dimension=Z_DIMENSION)

In [None]:
# Train the GAN: one sample grid per epoch under ./result/, then the final
# weights are written to ./generator.pth and ./discriminator.pth.
model.fit();

In [None]:
# Sample a final 10x10 grid of digits from the trained generator.
model.generate_synthetic_images(ro='./result/synthetic_images.png')