In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils as utils
import torchvision

from helpers import Logger

# Seed PyTorch's global RNG so weight init and noise sampling are repeatable.
torch.manual_seed(0)

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Available_device: {device}')

TensorBoard URL: http://localhost:6006/


Available_device: cuda


In [2]:
# Load and prepare the MNIST dataset as shuffled mini-batches of 128 images.
# ToTensor() yields pixel values in [0, 1]; Normalize((0.5,), (0.5,)) rescales
# them to [-1, 1] so real images share the value range of the generator's
# tanh output instead of being distinguishable by scale alone.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])
data = utils.data.DataLoader(
    torchvision.datasets.MNIST('./data',
                               transform=mnist_transform,
                               download=True),
    batch_size=128,
    shuffle=True)

In [3]:
class Discriminator(nn.Module):
    """
    A three hidden-layer discriminative neural network.

    Each hidden stage is Linear -> LeakyReLU(0.2) -> Dropout(0.3); the output
    stage is Linear -> sigmoid, giving one value in (0, 1) per sample.

    forward(x):
        x: tensor whose dimensions after the first flatten to 784 features
           per sample (e.g. (batch, 1, 28, 28) or (batch, 784)).
        returns: tensor of shape (batch, 1).
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        self.linear1 = nn.Linear(784, 1024)
        self.linear2 = nn.Linear(1024, 512)
        self.linear3 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 1)

    def _hidden(self, layer, x):
        """Apply one hidden stage: linear -> LeakyReLU(0.2) -> dropout(0.3)."""
        x = F.leaky_relu(layer(x), negative_slope=0.2)
        # F.dropout defaults to training=True regardless of the module's mode;
        # tie it to self.training so .eval() actually disables dropout.
        return F.dropout(x, p=0.3, training=self.training)

    def forward(self, x):
        # Collapse everything after the batch dimension into one feature axis.
        x = torch.flatten(x, start_dim=1)

        x = self._hidden(self.linear1, x)
        x = self._hidden(self.linear2, x)
        x = self._hidden(self.linear3, x)

        return torch.sigmoid(self.out(x))

In [4]:
class Generator(nn.Module):
    """
    Fully connected generator with three hidden layers.

    Maps a 100-feature input through layers of width 256, 512 and 1024
    (each followed by LeakyReLU with slope 0.2) to a 784-feature output
    squashed by tanh.
    """
    def __init__(self):
        super().__init__()
        # Hidden stack: 100 -> 256 -> 512 -> 1024
        self.linear1 = nn.Linear(100, 256)
        self.linear2 = nn.Linear(256, 512)
        self.linear3 = nn.Linear(512, 1024)
        # Output projection: 1024 -> 784
        self.out = nn.Linear(1024, 784)

    def forward(self, x):
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        # tanh bounds every output feature to [-1, 1].
        return torch.tanh(self.out(hidden))

In [5]:
# Initialize the models on the device selected earlier in the notebook
# (previously `device` was computed but never used, so training always ran on CPU).
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Logger setup
logger = Logger(model_name='GAN', data_name='MNIST')

# Training loop
num_epochs = 50
num_batches = len(data)  # loop-invariant; computed once instead of per batch
for epoch in range(num_epochs):
    for n, (real_samples, _) in enumerate(data):
        # Data for training the discriminator
        real_samples = real_samples.view(-1, 784).to(device)
        batch_size = real_samples.size(0)  # the last batch may be smaller than 128
        real_samples_labels = torch.ones((batch_size, 1), device=device)
        latent_space_samples = torch.randn((batch_size, 100), device=device)
        # detach(): the discriminator step must not backpropagate into the
        # generator; those gradients were discarded by generator.zero_grad()
        # anyway, so this only removes wasted computation.
        generated_samples = generator(latent_space_samples).detach()
        generated_samples_labels = torch.zeros((batch_size, 1), device=device)

        # Concatenate real and fake data
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat((real_samples_labels, generated_samples_labels))

        # Training the discriminator
        discriminator.zero_grad()
        output_d = discriminator(all_samples)
        loss_d = criterion(output_d, all_samples_labels)
        loss_d.backward()
        optimizer_d.step()

        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, 100), device=device)
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_d_generated = discriminator(generated_samples)

        # Reverse the labels for the generator: it is rewarded when the
        # discriminator classifies its samples as real (label 1).
        generated_samples_labels = torch.ones((batch_size, 1), device=device)

        # Training the generator
        loss_g = criterion(output_d_generated, generated_samples_labels)
        loss_g.backward()
        optimizer_g.step()

        # Logging. The logger is handed detached CPU tensors so it receives
        # the same kind of input whether training runs on CPU or GPU.
        loss_d_log = loss_d.detach().cpu()
        loss_g_log = loss_g.detach().cpu()
        logger.log(d_error=loss_d_log, g_error=loss_g_log, epoch=epoch, n_batch=n, num_batches=num_batches)
        if n % 100 == 0:
            # Prediction on real samples is for display only; skip autograd.
            with torch.no_grad():
                d_pred_real = discriminator(real_samples).cpu()
            d_pred_fake = output_d_generated.detach().cpu()
            logger.display_status(epoch, num_epochs, n, num_batches, loss_d_log, loss_g_log, d_pred_real, d_pred_fake)

            images = generated_samples.detach().cpu().view(batch_size, 1, 28, 28)
            logger.log_images(images, num_images=batch_size, epoch=epoch, n_batch=n, num_batches=num_batches)

    # Save the model parameters
    logger.save_models(generator, discriminator, epoch)

# Close the logger
logger.close()

print("Training Finished.")

AttributeError: 'Logger' object has no attribute 'set_tensorboard_url'