Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
L1aoXingyu committed Sep 5, 2017
1 parent 8f5b779 commit 765f849
Showing 1 changed file with 14 additions and 24 deletions.
38 changes: 14 additions & 24 deletions 09-Generative Adversarial network/simple_Gan.py
Expand Up @@ -25,18 +25,15 @@ def to_img(x):

# Image processing
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5))])
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# MNIST dataset
mnist = datasets.MNIST(root='./data/',
train=True,
transform=img_transform,
download=True)
mnist = datasets.MNIST(
root='./data/', train=True, transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)
dataloader = torch.utils.data.DataLoader(
dataset=mnist, batch_size=batch_size, shuffle=True)


# Discriminator
Expand All @@ -47,10 +44,7 @@ def __init__(self):
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())

def forward(self, x):
x = self.dis(x)
Expand All @@ -64,11 +58,7 @@ def __init__(self):
self.gen = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(True),
nn.Linear(256, 256),
nn.ReLU(True),
nn.Linear(256, 784),
nn.Tanh()
)
nn.Linear(256, 256), nn.ReLU(True), nn.Linear(256, 784), nn.Tanh())

def forward(self, x):
x = self.gen(x)
Expand Down Expand Up @@ -125,17 +115,17 @@ def forward(self, x):
g_loss.backward()
g_optimizer.step()

if (i+1) % 100 == 0:
if (i + 1) % 100 == 0:
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
'D real: {:.6f}, D fake: {:.6f}'
.format(epoch, num_epoch, d_loss.data[0], g_loss.data[0],
real_scores.data.mean(), fake_scores.data.mean()))
'D real: {:.6f}, D fake: {:.6f}'.format(
epoch, num_epoch, d_loss.data[0], g_loss.data[0],
real_scores.data.mean(), fake_scores.data.mean()))
if epoch == 0:
real_images = to_img(real_img.cpu().data)
save_image(real_images, './img/real_images.png')

fake_images = to_img(fake_img.cpu().data)
save_image(fake_images, './img/fake_images-{}.png'.format(epoch+1))
save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

0 comments on commit 765f849

Please sign in to comment.