https://blog.csdn.net/weixin_41278720/article/details/80861284
简单网络（非卷积），训练快

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os
 
if not os.path.exists('./img'):
    os.mkdir('./img')
 
 
def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out
 
 
batch_size = 128
num_epoch = 5
z_dimension = 100
 
# Image processing
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# MNIST dataset
mnist = datasets.MNIST(
    root='./data/', train=True, transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)
 
 
# Discriminator
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid())
 
    def forward(self, x):
        x = self.dis(x)
        return x
 
 
# Generator
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256), 
            nn.ReLU(True), 
            nn.Linear(256, 784), 
            nn.Tanh())
 
    def forward(self, x):
        x = self.gen(x)
        return x
 
 
D = discriminator()
G = generator()
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
 
# Start training
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # =================train discriminator
        img = img.view(num_img, -1)
        if torch.cuda.is_available():
            real_img = Variable(img).cuda()
            real_label = Variable(torch.ones(num_img)).cuda()
            fake_label = Variable(torch.zeros(num_img)).cuda()
            
        real_img = Variable(img)
        real_label = Variable(torch.ones(num_img))
        fake_label = Variable(torch.zeros(num_img))
 
        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better
 
        # compute loss of fake_img
        if torch.cuda.is_available():
            z = Variable(torch.randn(num_img, z_dimension)).cuda()
            
        z = Variable(torch.randn(num_img, z_dimension))
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better
 
        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
 
        # ===============train generator
        # compute loss of fake_img
        if torch.cuda.is_available():
            z = Variable(torch.randn(num_img, z_dimension)).cuda()
            
        z = Variable(torch.randn(num_img, z_dimension))
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)
 
        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
 
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                      epoch, num_epoch, d_loss.data[0], g_loss.data[0],
                      real_scores.data.mean(), fake_scores.data.mean()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')
 
    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')


  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [0/5], d_loss: 0.221809, g_loss: 3.154774 D real: 0.946541, D fake: 0.148741
Epoch [0/5], d_loss: 0.043806, g_loss: 4.450689 D real: 0.993143, D fake: 0.036101
Epoch [0/5], d_loss: 0.347796, g_loss: 5.525812 D real: 0.940620, D fake: 0.204275
Epoch [0/5], d_loss: 0.059543, g_loss: 5.930358 D real: 0.976590, D fake: 0.031998


  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [1/5], d_loss: 0.352019, g_loss: 3.756926 D real: 0.936991, D fake: 0.213836
Epoch [1/5], d_loss: 0.267467, g_loss: 6.957250 D real: 0.931404, D fake: 0.067651
Epoch [1/5], d_loss: 0.110057, g_loss: 6.036000 D real: 0.970200, D fake: 0.058013
Epoch [1/5], d_loss: 0.065901, g_loss: 6.341205 D real: 0.964275, D fake: 0.020537
Epoch [2/5], d_loss: 0.261443, g_loss: 6.301679 D real: 0.942835, D fake: 0.110214
Epoch [2/5], d_loss: 0.098673, g_loss: 6.742813 D real: 0.963223, D fake: 0.039628
Epoch [2/5], d_loss: 0.157586, g_loss: 7.525417 D real: 0.923961, D fake: 0.020506
Epoch [2/5], d_loss: 0.149618, g_loss: 4.467122 D real: 0.945189, D fake: 0.026485
Epoch [3/5], d_loss: 0.373843, g_loss: 5.177253 D real: 0.850760, D fake: 0.060399
Epoch [3/5], d_loss: 1.296942, g_loss: 2.754641 D real: 0.692055, D fake: 0.324221
Epoch [3/5], d_loss: 0.792131, g_loss: 3.213443 D real: 0.740964, D fake: 0.187717
Epoch [3/5], d_loss: 1.836372, g_loss: 2.281726 D real: 0.594382, D fake: 0.438778
Epoc

卷积网络版

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import os
 
if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')
 
 
def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out
 
 
batch_size = 128
num_epoch = 5
z_dimension = 100  # noise dimension
 
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
 
mnist = datasets.MNIST('./data', transform=img_transform)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True,
                        num_workers=4)
 
 
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),  # batch, 32, 28, 28
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),  # batch, 32, 14, 14
            )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2)  # batch, 64, 7, 7
        )
        self.fc = nn.Sequential(
            nn.Linear(64*7*7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        '''
        x: batch, width, height, channel=1
        '''
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
 
 
class generator(nn.Module):
    def __init__(self, input_size, num_feature):
        super(generator, self).__init__()
        self.fc = nn.Linear(input_size, num_feature)  # batch, 3136=1x56x56
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True)
        )
        self.downsample1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),  # batch, 50, 56, 56
            nn.BatchNorm2d(50),
            nn.ReLU(True)
        )
        self.downsample2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 56, 56
            nn.BatchNorm2d(25),
            nn.ReLU(True)
        )
        self.downsample3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),  # batch, 1, 28, 28
            nn.Tanh()
        )
 
    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)
        x = self.downsample1(x)
        x = self.downsample2(x)
        x = self.downsample3(x)
        return x
 
 
D = discriminator()  # discriminator model
G = generator(z_dimension, 3136)  # generator model
 
criterion = nn.BCELoss()  # binary cross entropy
 
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
 
# train
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # =================train discriminator
        real_img = Variable(img)
        real_label = Variable(torch.ones(num_img))
        fake_label = Variable(torch.zeros(num_img))
 
        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better
 
        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension))
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better
 
        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
 
        # ===============train generator
        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension))
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)
 
        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
 
        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'
                  .format(epoch, num_epoch, d_loss.data[0], g_loss.data[0],
                          real_scores.data.mean(), fake_scores.data.mean()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './dc_img/real_images.png')
 
    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './dc_img/fake_images-{}.png'.format(epoch+1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')


  "Please ensure they have the same size.".format(target.size(), input.size()))
