In [1]:
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import torch.autograd
from torch.autograd import Variable
import os

In [2]:
#创建数据集文件夹
if not os.path.exists('./gan_img'):
    os.mkdir('./gan_img')

In [3]:
## 用于把生成的图片还原
def to_img(x):
    out = 0.5*(x+1)
    out = out.clamp(0,1)         # Clamp函数可以把超过max或者低于min的数切割到max或min
    out = out.view(-1, 1, 28, 28)
    return out

In [18]:
batch_size = 128 #一批128个
num_epoch = 100 #总共100批
z_dimension = 100 #噪音维度

In [5]:
#图形的处理过程
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,),std=(0.5,))
])

In [6]:
# mnist dataset mnist数据集下载
mnist = datasets.MNIST(
    root='./data', train=True, transform = img_transform, download = True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:03<00:00, 2716828.71it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1495317.79it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 720014.67it/s] 


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 91884.17it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [7]:
# data loader 数据载入(批次读取)
dataloader = torch.utils.data.DataLoader(
    dataset = mnist, batch_size = batch_size, shuffle = True
)

#### 查看数据 ####
a = iter(dataloader)
print(next(a)[0].shape)
print(next(a)[1])   # 标签

torch.Size([128, 1, 28, 28])
tensor([0, 9, 2, 5, 2, 8, 3, 7, 4, 5, 7, 4, 1, 2, 7, 1, 7, 9, 2, 2, 2, 4, 9, 5,
        4, 5, 0, 0, 2, 4, 9, 1, 7, 0, 6, 3, 4, 3, 5, 9, 6, 6, 0, 4, 3, 2, 7, 4,
        1, 0, 9, 9, 6, 7, 0, 0, 1, 9, 5, 2, 9, 9, 5, 1, 9, 5, 7, 1, 5, 2, 9, 3,
        9, 7, 7, 3, 8, 0, 1, 6, 4, 7, 2, 7, 3, 7, 1, 4, 9, 7, 0, 0, 3, 8, 0, 7,
        1, 0, 4, 6, 6, 3, 2, 8, 3, 1, 3, 0, 9, 7, 8, 7, 3, 3, 6, 7, 5, 4, 9, 7,
        1, 8, 8, 0, 8, 7, 1, 2])


In [8]:
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784,256), #输入特征数为784，输出为256
            nn.LeakyReLU(0.2),  #进行非线性映射
            nn.Linear(256,256), #进行一个线性映射
            nn.LeakyReLU(0.2),
            nn.Linear(256,1),
            nn.Sigmoid()        #也是一个激活函数，二分类问题中，
            # 二分类用：sigmoid函数，可以把实数映射到[0,1]，作为概率值
            # 多分类用：softmax函数
        )
    def forward(self, x):
        x = self.dis(x)
        return x.squeeze(-1)


In [9]:
class generator(nn.Module):
    
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256), #用线性变换将输入映射到256维
            nn.ReLU(True),       #relu激活
            nn.Linear(256, 256), #线性变换
            nn.ReLU(True),       #relu激活
            nn.Linear(256, 784), #线性变换
            nn.Tanh()            #Tanh激活使得生成数据分布在[-1,1]之间
        )
 
    def forward(self, x):
        x = self.gen(x)
        return x.squeeze(-1)


In [11]:
#创建对象
D = discriminator()
G = generator()
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
    print()




In [12]:
criterion = nn.BCELoss() #是单目标二分类交叉熵函数
d_optimizer = torch.optim.Adam(D.parameters(),lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(),lr=0.0003)

In [19]:
# 加载预训练模型
G = torch.load('./generator.pth')
D = torch.load('./discriminator.pth')
# 进入训练##判别器的判断过程

for epoch in range(num_epoch): 
    for i,(img, _) in enumerate(dataloader):
        num_img = img.size(0)  # batch_size
        # 训练判别器
        img = img.view(num_img, -1)  # 将图片展开为28*28=784，(batch_size,784)
        real_img = Variable(img).cuda()  # 将tensor变成Variable放入计算图中
        real_label = Variable(torch.ones(num_img)).cuda()  # 定义真实的图片label为1
        fake_label = Variable(torch.zeros(num_img)).cuda()  # 定义假的图片的label为0

        # 计算真实图片的损失
        real_out = D(real_img)  # 将真实图片放入判别器中,返回模型认为图片是1的概率(batch_size,)
        d_loss_real = criterion(real_out, real_label)  # 得到真实图片的loss
        real_scores = real_out  # 得到真实图片的判别值，输出的值越接近1越好(batch_size,)

        # 计算假的图片的损失
        z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 随机生成一些噪声
        fake_img = G(z)  # 随机噪声放入生成网络中，生成一张假的图片
        fake_out = D(fake_img)  # 判别器判断假的图片
        d_loss_fake = criterion(fake_out, fake_label)  # 得到假的图片的loss
        fake_scores = fake_out  # 得到假图片的判别值，对于判别器来说，假图片的损失越接近0越好(batch_size,)

        # 损失函数和优化
        d_loss = d_loss_real + d_loss_fake # 损失包括判真损失和判假损失
        d_optimizer.zero_grad()  # 在反向传播之前，先将梯度归0
        d_loss.backward()  # 将误差反向传播
        d_optimizer.step()  # 只更新了判别器的更新参数

        # 训练生成器
        if i%2 == 1:     # 判别器更新2轮，生成器更新1轮
            z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 得到随机噪声
            fake_img = G(z) # 随机噪声输入到生成器中，得到一副假的图片
            output = D(fake_img)  # 经过判别器得到的结果
            g_loss = criterion(output, real_label)  # 得到的假的图片与真实的图片的label的loss
    
            # bp and optimize
            g_optimizer.zero_grad()  # 梯度归0
            g_loss.backward()  # 进行反向传播
            g_optimizer.step()  # 只更新了生成网络的参数

        #打印中间的损失
        if (i+1)%100 == 0:
            print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} ''D real: {:.6f},D fake: {:.6f}'.format(
                epoch,num_epoch,d_loss.item(),g_loss.item(),
                real_scores.data.mean(),fake_scores.data.mean()  # 打印的是真实图片的损失均值
            ))

        if epoch == 0:
            real_images=to_img(real_img.cpu().data)
            save_image(real_images, './gan_img/real_images.png')
        if epoch%10 == 0:
            fake_images = to_img(fake_img.cpu().data)
            save_image(fake_images, './gan_img/fake_images-{}.png'.format(epoch+1))

Epoch[0/100],d_loss:0.650117,g_loss:2.183088 D real: 0.806861,D fake: 0.167287
Epoch[0/100],d_loss:0.504957,g_loss:2.128167 D real: 0.855262,D fake: 0.174525
Epoch[0/100],d_loss:0.536496,g_loss:2.130408 D real: 0.836221,D fake: 0.178357
Epoch[0/100],d_loss:0.753064,g_loss:2.202585 D real: 0.803427,D fake: 0.207039
Epoch[1/100],d_loss:0.535368,g_loss:2.226351 D real: 0.827991,D fake: 0.157820
Epoch[1/100],d_loss:0.657019,g_loss:2.312122 D real: 0.802359,D fake: 0.172030
Epoch[1/100],d_loss:0.478254,g_loss:2.212487 D real: 0.837762,D fake: 0.149263
Epoch[1/100],d_loss:0.569643,g_loss:2.368406 D real: 0.836110,D fake: 0.168625
Epoch[2/100],d_loss:0.524208,g_loss:2.168599 D real: 0.860115,D fake: 0.168562
Epoch[2/100],d_loss:0.627839,g_loss:2.310429 D real: 0.820354,D fake: 0.189048
Epoch[2/100],d_loss:0.651681,g_loss:2.042801 D real: 0.799388,D fake: 0.163930
Epoch[2/100],d_loss:0.710351,g_loss:2.121243 D real: 0.806579,D fake: 0.191685
Epoch[3/100],d_loss:0.529927,g_loss:2.076138 D real:

In [17]:
# torch.save(G.state_dict(),'./generator.pth')
# torch.save(D.state_dict(),'./discriminator.pth')
torch.save(G,'./generator.pth')
torch.save(D,'./discriminator.pth')