In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [2]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
image_channels = 100
batch_size = 8
epochs = 50
learning_rate = 0.0002

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [4]:
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [5]:
# Modelo Gerador
class Generator(nn.Module):
    def __init__(self, image_channels):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_channels, 256),
            nn.ReLU(0.2),
            nn.Linear(256, 512),
            nn.ReLU(0.2),
            nn.Linear(512, 1024),
            nn.ReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    
    def forward(self, z):
        img = self.model(z)
        return img

# Modelo
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(0.2),
            nn.Linear(512, 256),
            nn.ReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        img = img.view(img.size(0), -1)
        x = self.model(img)
        return x

In [6]:
print(device)

cpu


In [7]:
generator = Generator(image_channels).to(device)
discriminator = Discriminator().to(device)
print(generator)
print(discriminator)

Generator(
  (model): Sequential(
    (0): Linear(in_features=100, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=1024, out_features=784, bias=True)
    (7): Tanh()
  )
)
Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)


In [8]:
adversarial_loss = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate)

In [None]:
history_loss_g = []
history_loss_d = []
for epoch in range(epochs):
    loss_g = 0
    loss_d = 0
    for i, (real_imgs, _) in enumerate(train_loader):
        real_imgs = real_imgs.to(device)
        optimizer_d.zero_grad()

        ### treinando o discrimiative (achar nome em português)

        # treinando em imagens verdadeiras
        real_labels = torch.ones(real_imgs.size(0), 1).to(device)
        real_output = discriminator(real_imgs)
        real_loss = adversarial_loss(real_output, real_labels)
        real_loss.backward()
        # optimizer_d.step()
        
        # treinando em imagens falsas
        noise = torch.randn(real_imgs.size(0), image_channels).to(device)
        fake_imgs = generator(noise) # gerando imagens falsas
        fake_labels = torch.zeros(real_imgs.size(0), 1).to(device)
        fake_output = discriminator(fake_imgs.detach())
        fake_loss = adversarial_loss(fake_output, fake_labels)
        fake_loss.backward()
        optimizer_d.step()

        loss_d += real_loss + fake_loss
        
        ### treinando o gerador
        optimizer_g.zero_grad()
        fake_imgs = generator(noise)
        output_g = discriminator(fake_imgs)
        loss = adversarial_loss(output_g, real_labels)
        loss.backward()
        optimizer_g.step()
        loss_g += loss
                
    # printando metricas
    loss_g /= i
    loss_d /= i
    history_loss_d.append(loss_d)
    history_loss_g.append(loss_g)
    print(f"[Epoch {epoch}/{epochs}]")
    print(f"[D loss: {loss_d.item():.4f}] [G loss: {loss_g.item():.4f}]")
    
    # Save some generated images at the end of each epoch
    if (epoch+1) % 10 == 0:
        with torch.no_grad():
            sample_noise = torch.randn(16, image_channels).to(device)
            gen_img = generator(sample_noise).cpu().numpy()
            gen_img = gen_img.reshape(-1, 28, 28)
            
            fig, axs = plt.subplots(4, 4, figsize=(4, 4))
            cnt = 0
            for i in range(4):
                for j in range(4):
                    axs[i,j].imshow(gen_img[cnt], cmap='gray')
                    axs[i,j].axis('off')
                    cnt += 1
            plt.savefig(f"gan_images/epoch_{epoch}.png")
            plt.close()

# Save models
torch.save(generator.state_dict(), './models/generator.pth')
torch.save(discriminator.state_dict(), './models/discriminator.pth')

[Epoch 0/50] [Batch 7500/7500]
[D loss: 0.6627] [G loss: 2.2826]
(16, 784)
[Epoch 1/50] [Batch 7500/7500]
[D loss: 0.7798] [G loss: 2.0541]
[Epoch 2/50] [Batch 7500/7500]
[D loss: 0.8373] [G loss: 1.7418]
