In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os

In [2]:
# Training hyper-parameters.
batch_size = 96
epoch_num = 1000
# Generator input width: 100-d Gaussian noise + 10-d (soft) one-hot class label.
z_size = 110

# Seed every RNG the notebook draws from (torch: weight init, DataLoader
# shuffling, Gaussian noise; numpy: soft-label sampling) so that
# Restart & Run All reproduces the same run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# PIL image -> float tensor, then (x - 0.5) / 0.5 so pixels land in [-1, 1],
# the same range as the generator's Tanh output.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=(0.5,), std=(0.5,))
transform = transforms.Compose([_to_tensor, _normalize])

In [4]:
# MNIST training split stored under ./data/ (downloaded if absent), reshuffled every epoch.
mnist = datasets.MNIST(root='./data/', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)

In [5]:
def to_img(x):
    """Convert [-1, 1]-normalised pixels back to a displayable image batch.

    Rescales with (x + 1) / 2, clips to [0, 1], and reshapes to
    (-1, 1, 28, 28), so the input must hold a multiple of 784 elements.
    """
    rescaled = x.add(1).mul(0.5)
    return rescaled.clamp(min=0, max=1).view(-1, 1, 28, 28)

In [6]:
def save_img(real, fake, n, out_dir='./CGAN_MNIST-img', pairs_per_row=8):
    """Save a PNG grid interleaving real and generated images.

    Each row holds ``pairs_per_row`` (real, fake) pairs. A blank 28x28 tile
    separates consecutive pairs, so a full row is ``3 * pairs_per_row - 1``
    tiles wide.

    Args:
        real: tensor of real images in [-1, 1], reshapeable to (-1, 1, 28, 28).
        fake: tensor of generated images in the same format as ``real``.
        n: epoch number, embedded in the file name ``R&F_Epoch-{n}.png``.
        out_dir: target directory; created if it does not exist.
        pairs_per_row: number of (real, fake) pairs per grid row.

    Raises:
        ValueError: if ``real`` or ``fake`` contains no images.
    """
    # detach()/cpu() so tensors that require grad or live on a GPU are accepted.
    imgs_real = to_img(real.detach().cpu())
    imgs_fake = to_img(fake.detach().cpu())
    if min(len(imgs_real), len(imgs_fake)) == 0:
        raise ValueError('save_img needs at least one (real, fake) image pair')

    # Same dtype as the images, so the grid is not silently promoted to float64.
    spacer = torch.zeros_like(imgs_real[0])
    tiles = []
    for count, (r, f) in enumerate(zip(imgs_real, imgs_fake), start=1):
        tiles.extend((r, f))
        # No spacer after the last pair of a row, keeping every row aligned.
        if count % pairs_per_row:
            tiles.append(spacer)

    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, 'R&F_Epoch-{}.png'.format(n))
    save_image(torch.stack(tiles), path, nrow=3 * pairs_per_row - 1)

In [7]:
class discriminator(nn.Module):
    """MLP discriminator producing one sigmoid probability per class.

    The training cell uses per-class targets: a soft one-hot of the true class
    for real images, and small random values for generated ones.

    Args:
        in_features: size of the flattened input image (784 = 28 * 28).
        hidden: width of the two hidden layers.
        n_classes: number of sigmoid outputs.
    """

    def __init__(self, in_features=784, hidden=256, n_classes=10):
        super().__init__()
        # The attribute name `dis` and the layer order are kept so previously
        # saved state_dicts ('dis.0.weight', ...) remain loadable.
        self.dis = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, n_classes),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map a (batch, in_features) tensor to (batch, n_classes) probabilities."""
        return self.dis(x)

In [8]:
class generator(nn.Module):
    """MLP generator mapping a conditioned noise vector to a flattened image.

    In the training cell, the input is 100-d Gaussian noise concatenated with
    a 10-d soft one-hot class label (110 features in total). The output passes
    through Tanh, so pixels lie in [-1, 1].

    Args:
        in_features: width of the (noise + label) input vector.
        hidden: width of the two hidden layers.
        out_features: size of the flattened output image (784 = 28 * 28).
        dropout: dropout probability after each hidden layer.
    """

    def __init__(self, in_features=110, hidden=256, out_features=784, dropout=0.5):
        super().__init__()
        # The attribute name `gen` and the layer order are kept so previously
        # saved state_dicts ('gen.0.weight', ...) remain loadable.
        self.gen = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a (batch, in_features) tensor to (batch, out_features) pixels in [-1, 1]."""
        return self.gen(x)

In [9]:
# Shared binary cross-entropy criterion for both players.
criterion = nn.BCELoss()
# Construction order (generator first) is preserved, so seeded weight init is unchanged.
G, D = generator(), discriminator()

# Asymmetric optimisers at the same learning rate:
# Adam for the generator, plain SGD for the discriminator.
g_optimizer = torch.optim.Adam(params=G.parameters(), lr=3e-4)
d_optimizer = torch.optim.SGD(params=D.parameters(), lr=3e-4)

In [10]:
# Directory for sample grids and loss curves. Create it up front so the
# first save does not fail with FileNotFoundError.
img_dir = './CGAN_MNIST-img'
os.makedirs(img_dir, exist_ok=True)

n_classes = 10
# The generator input is [noise | soft one-hot label]; this is the noise part.
noise_size = z_size - n_classes

loss_G = []
loss_D = []
for epoch in range(epoch_num):
    # Scale of the Gaussian noise added to real images. It decays linearly
    # over training and reaches 0 in the final epoch.
    noise_scale = 1 - (epoch + 1) / epoch_num

    for img, label in dataloader:
        num_img = img.size(0)

        # Flatten real images and add the annealed noise (vectorised).
        flat = img.view(num_img, -1)
        real_img = flat + torch.randn_like(flat) * noise_scale

        # Soft one-hot targets: the true class gets a random value in
        # [0.7, 1.0]. BCELoss targets must lie in [0, 1], so the upper bound
        # is 1.0; the former randint(7, 12) allowed 1.1.
        label_onehot = np.zeros((num_img, n_classes))
        label_onehot[np.arange(num_img), label.numpy()] = np.random.randint(7, 11) / 10
        real_label = torch.from_numpy(label_onehot).float()
        # Fake targets: independent random values in [0.0, 0.3] for every class.
        fake_label = torch.from_numpy(np.random.randint(0, 4, (num_img, n_classes)) / 10).float()

        # One conditioned generator input per batch, shared by both steps:
        # [noise (noise_size) | soft one-hot of the real batch's classes].
        z = torch.cat((torch.randn(num_img, noise_size), real_label), dim=1)
        fake_img = G(z)

        # ---- Discriminator step ----
        d_loss_real = criterion(D(real_img), real_label)
        # detach(): the discriminator loss must not backpropagate into G.
        d_loss_fake = criterion(D(fake_img.detach()), fake_label)
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        loss_D.append(d_loss.item())

        # ---- Generator step ----
        # G is rewarded when D gives each fake the target of its conditioning class.
        g_loss = criterion(D(fake_img), real_label)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        loss_G.append(g_loss.item())

    if (epoch + 1) % 50 == 0:
        print('Epoch:{}/{}, Discriminator loss:{:.6f}, Generator loss:{:.6f},'.format(
            epoch + 1, epoch_num, d_loss.item(), g_loss.item())
        )
        save_img(real_img.detach(), fake_img.detach(), epoch + 1)

    if (epoch + 1) % 100 == 0:
        fig, (ax_g, ax_d) = plt.subplots(1, 2)
        ax_g.plot(np.arange(len(loss_G)), loss_G)
        ax_g.set_title('G_loss')
        ax_d.plot(np.arange(len(loss_D)), loss_D)
        ax_d.set_title('D_loss')
        fig.savefig(os.path.join(img_dir, 'Loss_epoch{}.png'.format(epoch + 1)))
        plt.close(fig)

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')