In [8]:
import torch
import torch.nn as nn
from torch.nn.functional import relu
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchsummary import summary
import warnings
warnings.filterwarnings('ignore')

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using gpu: %s ' % torch.cuda.is_available())

Using gpu: True 


In [38]:
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

trainset = datasets.MNIST(root='./data/', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

In [39]:
class Generator(nn.Module):
    def __init__(self, input_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256 * 7 * 7)
        self.convT1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.convT2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.convT3 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = relu(self.fc1(x))
        x = x.view(-1, 256, 7, 7)
        x = relu(self.convT1(x))
        x = relu(self.convT2(x))
        x = self.convT3(x)
        x = self.tanh(x)
        return x

generator = Generator(input_dim=100).to(device)
summary(generator, (100,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 12544]       1,266,944
   ConvTranspose2d-2          [-1, 128, 14, 14]         524,416
   ConvTranspose2d-3           [-1, 64, 28, 28]         131,136
   ConvTranspose2d-4            [-1, 1, 28, 28]             577
              Tanh-5            [-1, 1, 28, 28]               0
Total params: 1,923,073
Trainable params: 1,923,073
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.68
Params size (MB): 7.34
Estimated Total Size (MB): 8.02
----------------------------------------------------------------


In [51]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(128 * 3 * 3, 1)
        self.lrelu = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = self.lrelu(self.conv1(x))
        x = self.lrelu(self.conv2(x))
        x = self.lrelu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

discriminator = Discriminator().to(device)
summary(discriminator, (1, 28, 28))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 14, 14]             544
         LeakyReLU-2           [-1, 32, 14, 14]               0
            Conv2d-3             [-1, 64, 7, 7]          32,832
         LeakyReLU-4             [-1, 64, 7, 7]               0
            Conv2d-5            [-1, 128, 3, 3]         131,200
         LeakyReLU-6            [-1, 128, 3, 3]               0
            Linear-7                    [-1, 1]           1,153
Total params: 165,729
Trainable params: 165,729
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.16
Params size (MB): 0.63
Estimated Total Size (MB): 0.80
----------------------------------------------------------------


In [52]:
g_opt = torch.optim.Adam(lr=1e-4, betas=(0.5, 0.999), params=generator.parameters())
d_opt = torch.optim.Adam(lr=1e-4, betas=(0.5, 0.999), params=discriminator.parameters())

In [53]:
nb_epochs = 50
batch_size = 64

In [54]:
criterion = nn.BCEWithLogitsLoss()

In [55]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


In [56]:
generator.apply(weights_init)
discriminator.apply(weights_init)

Discriminator(
  (conv1): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (fc1): Linear(in_features=1152, out_features=1, bias=True)
  (lrelu): LeakyReLU(negative_slope=0.2)
  (dropout): Dropout(p=0.3, inplace=False)
)

In [57]:
def train(train_loader, nb_epochs, generator, discriminator, input_dim, device):
    g_loss_h= []
    d_loss_h = []
    for epoch in (range(nb_epochs)):
        total_d_loss = 0.0
        total_g_loss = 0.0

        for batch_idx, (real, _) in enumerate(train_loader):
            real = real.to(device)
            m = real.size(0)

            d_opt.zero_grad()
            real_labels = torch.ones(m).to(device)
            fake_labels = torch.zeros(m).to(device)
            
            outputs_real = discriminator(real).view(-1)
            d_loss_real = criterion(outputs_real, real_labels)

            noise = torch.randn(m, input_dim, device=device)
            fake_images = generator(noise)
            outputs_fake = discriminator(fake_images.detach()).view(-1)
            d_loss_fake = criterion(outputs_fake, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_opt.step()
            total_d_loss += d_loss.item()

            g_opt.zero_grad()
            g_labels = torch.ones(m, device=device)
            outputs = discriminator(fake_images).view(-1)
            g_loss = criterion(outputs, g_labels)
            g_loss.backward()
            g_opt.step()
            total_g_loss += g_loss.item()

        avg_d_loss = total_d_loss / len(train_loader)
        avg_g_loss = total_g_loss / len(train_loader)

        g_loss_h.append(avg_g_loss)
        d_loss_h.append(avg_d_loss)
        print(f"Epoch [{epoch+1}/{nb_epochs}], Discriminator Loss: {avg_d_loss:.2f}, Generator Loss: {avg_g_loss:.2f}")

In [58]:
train(trainloader, nb_epochs, generator, discriminator, 100, device)

Epoch [1/50], Discriminator Loss: 1.34, Generator Loss: 0.76
Epoch [2/50], Discriminator Loss: 1.32, Generator Loss: 0.79
Epoch [3/50], Discriminator Loss: 1.29, Generator Loss: 0.81
Epoch [4/50], Discriminator Loss: 1.30, Generator Loss: 0.81
Epoch [5/50], Discriminator Loss: 1.30, Generator Loss: 0.82
Epoch [6/50], Discriminator Loss: 1.32, Generator Loss: 0.79
Epoch [7/50], Discriminator Loss: 1.30, Generator Loss: 0.79
Epoch [8/50], Discriminator Loss: 1.30, Generator Loss: 0.81
Epoch [9/50], Discriminator Loss: 1.29, Generator Loss: 0.82
Epoch [10/50], Discriminator Loss: 1.28, Generator Loss: 0.83
Epoch [11/50], Discriminator Loss: 1.27, Generator Loss: 0.84
Epoch [12/50], Discriminator Loss: 1.27, Generator Loss: 0.85
Epoch [13/50], Discriminator Loss: 1.27, Generator Loss: 0.85
Epoch [14/50], Discriminator Loss: 1.27, Generator Loss: 0.84
Epoch [15/50], Discriminator Loss: 1.28, Generator Loss: 0.85
Epoch [16/50], Discriminator Loss: 1.28, Generator Loss: 0.85
Epoch [17/50], Di

KeyboardInterrupt: 