In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
import torchvision.transforms as transforms
torch.__version__

'1.13.1+cpu'

In [59]:
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train = datasets.CIFAR10(root = '.', train = True, download = True, 
                       transform = transform)
dataset = torch.utils.data.DataLoader(train, batch_size = 256, shuffle=True)

Files already downloaded and verified


In [60]:
train

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: .
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=32, interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           )

In [124]:
# Criar o Gerador

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(256, 4 * 4 * 256)
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)
        
    def forward(self, x):
        x = self.fc1(x)
        x = x.view(-1, 256, 4, 4)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = torch.tanh(self.conv4(x))
        return x

In [130]:
# Criar o Discriminador
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.fc = nn.Linear(256 * 4 * 4, 1)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.view(-1, 256 * 4 * 4)
        x = self.fc(x)
        return x

In [131]:
G = Generator()
D = Discriminator()

In [132]:
loss = nn.BCEWithLogitsLoss()

# Define the optimizers
D_optimizer = optim.Adam(D.parameters(), lr=0.0002)
G_optimizer = optim.Adam(G.parameters(), lr=0.0002)

In [135]:
for epoch in range(200):
    for i, (images, _) in enumerate(dataset):
        # Update the discriminator
        D.zero_grad()

        real_images = images.reshape(-1, 1, 32, 32)
        real_labels = torch.ones(768, 1)
        real_outputs = D(real_images)
        real_loss = loss(real_outputs, real_labels)
        real_loss.backward()
        
        # Update the generator
        G.zero_grad()
        noise = torch.randn(256, 256)
        fake_images = G(noise)
        outputs = D(fake_images)
        print(f'A: {fake_images.shape}')
        print(f'B: {outputs.shape}')
        G_loss = loss(outputs, real_labels)
        G_loss.backward()
        G_optimizer.step()
        

        if i == 0:
            fig, ax = plt.subplots(1, 5, figsize=(10,5))
            for i in range(5):
                ax[i].imshow(fake_images.cpu().detach().numpy()[i].reshape(28, 28), cmap='gray')
                ax[i].xaxis.set_visible(False)
                ax[i].yaxis.set_visible(False)
            plt.show()
            
        if (i + 1) % 100 == 0:
            fig, ax = plt.subplots(1, 5, figsize=(10,5))
            for i in range(5):
                ax[i].imshow(fake_images.cpu().detach().numpy()[i].reshape(28, 28), cmap='gray')
                ax[i].xaxis.set_visible(False)
                ax[i].yaxis.set_visible(False)
            plt.show()
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}' 
                  .format(epoch, 200, i + 1, len(dataset), real_loss , G_loss))
    print(epoch)

A: torch.Size([256, 1, 1, 1])
B: torch.Size([16, 1])


ValueError: Target size (torch.Size([768, 1])) must be the same as input size (torch.Size([16, 1]))