In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os
import time

In [2]:
batch_size=128
epoch_num=10
z_size=100
input_size=3136

In [3]:
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],std=[0.5])
])

In [4]:
mnist=datasets.MNIST('./data/',train=True,transform=transform,download=True)
dataloader=torch.utils.data.DataLoader(mnist,batch_size=batch_size,shuffle=True)

In [5]:
def to_img(x):
    out=0.5*(x+1)
    out=out.clamp(0,1)
    out=out.view(-1,1,28,28)
    return out

In [6]:
def save_img(real,fake,n):
    imgr=to_img(real)
    imgr=imgr.numpy()
    imgf=to_img(fake)
    imgf=imgf.numpy()
    empty=[]
    count=0
    for i,j in zip(imgr,imgf):
        empty.append(i)
        empty.append(j)
        empty.append(np.zeros([1,28,28]))
        count+=1
        if count%8==0:
            empty.pop()
        if count==96:
            break
    empty=np.array(empty)
    empty=torch.from_numpy(empty)
    save_image(empty,'./DCGAN-MNIST-img/R&F_Epoch-{}.png'.format(n),nrow=23)

In [7]:
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()
        self.dis=nn.Sequential(
            nn.Conv2d(1,32,5,stride=1,padding=2),
            nn.LeakyReLU(0.2,True),
            nn.MaxPool2d((2,2)),
 
            nn.Conv2d(32,64,5,stride=1,padding=2),
            nn.LeakyReLU(0.2,True),
            nn.MaxPool2d((2,2))
        )
        self.fc=nn.Sequential(
            nn.Linear(7 * 7 * 64, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        x=self.dis(x)
        x=x.view(x.size(0),-1)
        x=self.fc(x)
        return x


In [8]:
class generator(nn.Module):
    def __init__(self,input_size,num_feature):
        super(generator,self).__init__()
        self.fc=nn.Linear(input_size,num_feature) #1*56*56
        self.br=nn.Sequential(
            nn.BatchNorm2d(1),
            nn.LeakyReLU(0.2,True)
        )
        self.gen=nn.Sequential(
            nn.Conv2d(1,50,3,stride=1,padding=1),
            nn.BatchNorm2d(50),
            nn.LeakyReLU(0.2,True),
 
            nn.Conv2d(50,25,3,stride=1,padding=1),
            nn.BatchNorm2d(25),
            nn.LeakyReLU(0.2,True),
 
            nn.Conv2d(25,1,2,stride=2),
            nn.Tanh()
        )
    def forward(self, x):
        x=self.fc(x)
        x=x.view(x.size(0),1,56,56)
        x=self.br(x)
        x=self.gen(x)
        return x


In [9]:
G=generator(z_size,input_size)
D=discriminator()
criterion = nn.BCELoss()
g_optimizer=torch.optim.Adam(G.parameters(),lr=0.0002)
d_optimizer=torch.optim.SGD(D.parameters(),lr=0.0002)

In [10]:
loss_G=[]
loss_D=[]

for epoch in range(epoch_num):
    start2 = time.time()
    start1 = time.time()
    for i,(img,label) in enumerate(dataloader):
        
        num_img=img.size(0)
        
        #img=img.view(num_img,-1)
        '''
        #噪声
        for simg in img:
            simg+=torch.randn([784])*(1-(epoch+1)/epoch_num)
        '''
        for s in img:
            s+=torch.randn([28,28])*(1-(epoch+1)/epoch_num)
        real_img=Variable(img)

        real_label=Variable(torch.ones(num_img))
        fake_label=Variable(torch.zeros(num_img))
        
        real_out=D(real_img)
        d_loss_real=criterion(real_out,real_label)
        real_scores=d_loss_real

        z=Variable(torch.randn((num_img,z_size)))
        z=G(z)
        fake_out=D(z)
        d_loss_fake=criterion(fake_out,fake_label)
        fake_scores=d_loss_fake

        d_loss=fake_scores+real_scores
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        loss_D.append(d_loss.item())

        z=Variable(torch.randn((num_img,z_size)))
        fake_img=G(z)
        output=D(fake_img)

        g_loss=criterion(output,real_label)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        loss_G.append(g_loss.item())

        tm = time.time() - start1
        print("\rEpoch: {:d} batch: {:d}/{:d} | {:.2%}, Time:{:.2f}"\
            .format(epoch+1, i+1, len(dataloader), (i+1)*1.0/len(dataloader),tm), end='')

    
    if (epoch+1)%1==0:
        tm = time.time() - start2
        start2 = time.time()
        print('\rEpoch:{}/{}, Discriminator loss:{:.6f}, Generator loss:{:.6f}, Time:{:.2f}\n'.format(
            epoch+1,epoch_num,d_loss.item(),g_loss.item(),tm)
        )
        real_img=real_img.data
        fake_img=fake_img.data
        save_img(real_img,fake_img,epoch+1)
    
    if (epoch+1)%1==0 :
        plt.subplot(2,1,1)
        plt.plot(np.arange(len(loss_G)),loss_G)
        plt.title('G_loss')
        plt.subplot(2,1,2)
        plt.plot(np.arange(len(loss_D)),loss_D)
        plt.title('D_loss')
        plt.savefig('./DCGAN-MNIST-img/Loss_epoch{}.png'.format(epoch+1))
        plt.close('all')

torch.save(G.state_dict(),'./generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


ValueError: Target and input must have the same number of elements. target nelement (128) != input nelement (96)