In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

import os.path

In [2]:
batch_size = 100
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

trainset = torchvision.datasets.MNIST('../data',train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

In [3]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.features = nn.Sequential(
            nn.Linear(196,256),
            nn.ReLU(),
            nn.Linear(256,512),
            nn.ReLU(),
            nn.Linear(512,1024),
            nn.ReLU(),
            nn.Linear(1024,784),
            nn.Sigmoid()
        )
        
    def forward(self,x):
        x = x.view(batch_size, 196)
        x = self.features(x)
        return x
    
Gnr = Generator()

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            nn.Linear(784,1024),
            nn.ReLU(),
            nn.Linear(1024,512),
            nn.ReLU(),
            nn.Linear(512,256),
            nn.ReLU(),
            nn.Linear(256,196),
            nn.ReLU(),
            nn.Linear(196,1),
            nn.Sigmoid()
        )
    
    def forward(self,x):
        x = x.view(batch_size, 784)
        x = self.features(x)
        x = x.view(batch_size, -1)
        return x
    
Dsc = Discriminator()

In [5]:
criterion = nn.BCELoss()
Gnr_optimizer = optim.Adam(Gnr.parameters(), lr=0.0001)
Dsc_optimizer = optim.Adam(Dsc.parameters(), lr=0.0001)

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Gnr.to(device)
Dsc.to(device)

cuda:0


Discriminator(
  (features): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=196, bias=True)
    (7): ReLU()
    (8): Linear(in_features=196, out_features=1, bias=True)
    (9): Sigmoid()
  )
)

In [7]:
def noiseInput():
    return torch.randn(batch_size, 196)

test_noise = noiseInput().to(device)
epoch_len = 500

In [None]:
if os.path.isfile('./save/gen.pth'):
    Gnr.load_state_dict(torch.load('./save/gen.pth'))
if os.path.isfile('./save/dsc.pth'):
    Dsc.load_state_dict(torch.load('./save/dsc.pth'))
    
for epoch in range(epoch_len):
    print(epoch)
    for i, data in enumerate(trainloader):
        inputs, _ = data
        inputs = inputs.to(device)
        
        Dsc.zero_grad()
        dsc_real_result = Dsc(inputs)
        dsc_real_error = criterion(dsc_real_result, torch.ones(batch_size).to(device))

        noise = noiseInput().to(device)
        gnr_fake = Gnr(noise)
        dsc_fake_result = Dsc(gnr_fake)
        dsc_fake_error = criterion(dsc_fake_result,torch.zeros(batch_size).to(device))
        
        loss = dsc_real_error + dsc_fake_error
        loss.backward()
        Dsc_optimizer.step()
        
        Gnr.zero_grad()
        noise = noiseInput().to(device)
        gnr_fake = Gnr(noise)
        dsc_gnr_fake_result = Dsc(gnr_fake)
        gnr_error = criterion(dsc_gnr_fake_result, torch.ones(batch_size).to(device))
        gnr_error.backward()
        Gnr_optimizer.step()
        
        if epoch%5 == 0 and i % 1000 == 0:
            plt.figure(epoch, figsize=(10,1))
            gnr_fake = Gnr(test_noise)
            
            fake = gnr_fake.data.cpu()
            fake_img = fake.numpy().reshape(-1,28,28)
            
            for n in range(10):
                plt.subplot(1,10,n+1)
                plt.imshow(fake_img[n], cmap='Greys')
            plt.savefig('./result/result_r'+str(epoch)+'_'+str(int(i/1000))+'.jpg')
            
    torch.save(Gnr.state_dict(),'./save/gen.pth')
    torch.save(Dsc.state_dict(),'./save/dsc.pth')
    

0


  "Please ensure they have the same size.".format(target.size(), input.size()))


1
2
3
4
5
