In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, Subset

from torchvision.transforms import v2
import torchvision.utils as vutils
from torchvision import datasets

from pytorchsummary import summary

The github repository og the GAIN author is found here: https://github.com/jsyoon0823/GAIN/blob/master/gain.py

The networks (both discriminator and generator) have a very simple structure

In [None]:
batch_size = 128

transform = v2.Compose([
    v2.ToImage(), 
    v2.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(lambda x: 2*x - 1)] 
    )

train_data = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
test_data = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)

# Getting a subset of the data for speed purposes
num_samples = 60000
indices = np.random.choice(len(train_data), num_samples, replace=False)
train_subset = Subset(train_data, indices)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
from torchvision.transforms import RandomAffine
RandomAffine(degrees=(-10, 10), translate=(0, 0.1))

dummydata = train_data

In [None]:
# Discriminator class
class DiscriminatorConv(nn.Module):
    def __init__(self):
        super(DiscriminatorConv, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1), 
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1), 
            # nn.BatchNorm2d(64),
            # nn.LeakyReLU(0.2, inplace=True),

            nn.Flatten(), 
            nn.Linear(1024, 1), 
            
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.network(x)
    

# Generator class
class GeneratorConv(nn.Module):
    def __init__(self, z_dimension):
        super(GeneratorConv, self).__init__()

        self.network = nn.Sequential(
            nn.Linear(z_dimension, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Unflatten(1, (128, 7, 7)),  # 7 because it divides 28 (mnist resolution)

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),

            nn.Tanh()  
        )

    def forward(self, z):
        return self.network(z)

class GAN(nn.Module):
    pass
    def __init__(self, generator, discriminator, latent_dimension, lossfunction):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.lossfunction = lossfunction
        self.latent_dimension = latent_dimension #?


    def forward_discriminator(self):
        pass

    def forward_generator(self):
        pass

    def forward(self):
        self.forward_discriminator()
        self.forward_generator()

    def save_parameters(self):
        torch.save(self.generator.state_dict(), "generate_1.pth")
        torch.save(self.discriminator.state_dict(), "discriminate_1.pth")

    def print_layer_summary(self):
        summary(model=self.generator, input_size=(1, self.latent_dimension))
        summary(model=self.discriminator, input_size=(1, 28, 28))

    def plot_training_loss(self):
        plt.plot(discriminate_loss[0:])
        plt.plot(generate_loss[0:])
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend(["Discriminator", "Generator"])
        plt.show()

In [None]:
lr_discriminator = 2e-4
lr_generator = 4e-4
noise_dimension = 64

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

lossfunction = nn.BCELoss()

generate = GeneratorConv(noise_dimension).to(device)
discriminate = DiscriminatorConv().to(device)

optimizer_gen = optim.Adam(generate.parameters(), lr=lr_generator, betas=(0.5, 0.999))
optimizer_disc = optim.Adam(discriminate.parameters(), lr=lr_discriminator, betas=(0.5, 0.999))

noise_general = torch.randn(noise_dimension, noise_dimension, device=device)

discriminate_loss, generate_loss = [], []
epoch_current = 0

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(discriminate), count_parameters(generate)


In [None]:
K = 1
epochs = 30

total_batches = len(train_loader)
real_label_noise = 0.95 + 0.05 * torch.rand(batch_size, 1, device=device) # I think I switched the labels.. and it works and does not work the other way around.. ???
fake_label_noise = 0.05 * torch.rand(batch_size, 1, device=device)

for epoch in range(epoch_current, epochs):
    start_time = time.time()

    # Initialize cumulative loss values for the epoch
    running_loss_disc = 0.0
    running_loss_gen = 0.0

    # Updating the Generator 
    for i, batch in enumerate(train_loader):

        #############################################################################################################################
        ### Update the discriminator distribution. This van also be done (K=1) times. K is a hyperparameter in (Goodfellow, 2014) ###
        optimizer_disc.zero_grad()
        noise = torch.randn(batch_size, noise_dimension, device=device)

        real_images = batch[0].to(device)
        fake_images = generate(noise).detach() # detach() so that gradients are not passed back to the generator when updating the discriminator.
        real_images_disc = discriminate(real_images)
        fake_images_disc = discriminate(fake_images)

        # Computing the loss function (As in Algorithm 1 in Goodfellow)
        loss_data = lossfunction(real_images_disc + 1e-8, fake_label_noise)        # E log   D(  x)
        loss_generated = lossfunction(fake_images_disc + 1e-8, real_label_noise)   # E log(1-D(G(z))
        loss_discriminator = loss_data + loss_generated

        # Keeping track of the loss and performing backward propagation
        running_loss_disc += loss_discriminator.item()
        loss_discriminator.backward()
        optimizer_disc.step()

        ##############################
        ### Updating the generator ###
        optimizer_gen.zero_grad()

        noise = torch.randn(batch_size, noise_dimension, device=device) # Does this one need to be different than the noise of beginning of loop?
        fake_images = generate(noise)
        fake_images_disc = discriminate(fake_images)
        loss_generator = lossfunction(fake_images_disc, torch.zeros(batch_size, 1, device=device))

        loss_generator.backward()
        optimizer_gen.step()

        running_loss_gen += loss_generator.item()

    avg_loss_disc = running_loss_disc / (K * total_batches)
    avg_loss_gen = running_loss_gen / total_batches

    discriminate_loss.append(avg_loss_disc)
    generate_loss.append(avg_loss_gen)
    
    # Display the epoch progress
    progress = (epoch + 1) / epochs * 100
    print(f"Epoch [{epoch + 1}/{epochs}] - {progress:.2f}% complete | "
          f"Discriminator Loss: {avg_loss_disc:.4f} | Generator Loss: {avg_loss_gen:.4f} | "
          f"Time: {time.time() - start_time:.2f}s")
    
    # Show every few epochs how the fixed noise is generates into new images and save it
    if (epoch+1) % 2 == 0:
        fake_images = generate(noise_general)
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title(f"Fake Images at Epoch {epoch + 1}")
        plt.imshow(vutils.make_grid(fake_images[:64], padding=2, normalize=True).cpu().numpy().transpose((1, 2, 0)))
        plt.savefig(f"{os.getcwd()}/images/mnist_gam_Fake Images at Epoch {epoch + 1}.png")

    epoch_current += 1

In [None]:

images = batch[0][:64] 
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    img = images[i].squeeze().numpy()  # Remove extra dimension and convert to numpy array
    ax.imshow(img, cmap='gray')
    ax.axis('off')  # Hide axes

plt.tight_layout()
plt.show()

In [None]:

        # all_images = torch.cat([real_images, fake_images], axis=0)
        # all_images_disc = discriminate(all_images)
        # labels = torch.cat([real_label_noise, fake_label_noise])
        # loss_discriminator = lossfunction(all_images_disc, labels)

In [None]:
plt.plot(discriminate_loss[0:])
plt.plot(generate_loss[0:])
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(["Discriminator", "Generator"])
plt.show()

Lets use MNIST on VAE's.

In [28]:
class VAEMNIST(nn.Module):
    def __init__(self, latent_dim=64, first_channel=16):
        super(VAEMNIST, self).__init__()
        
        self.first_channel = first_channel
        self.latent_dim = latent_dim
        
        # Calculate sizes for proper dimensionality tracking
        # Input: 100x100
        # After 2 stride-2 convolutions: 100 -> 50 -> 25
        self.dim = 7

        # Encoder: convolutional layers for grid-like data
        self.encoder = nn.Sequential(
            nn.Conv2d(1, self.first_channel, kernel_size=3, stride=2, padding=1),  # 100x100 -> 50x50
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.first_channel, 2*self.first_channel, kernel_size=3, stride=2, padding=1),  # 50x50 -> 25x25
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        self.flatten_size = 2*self.first_channel * self.dim**2
        
        self.fc1 = nn.Linear(self.flatten_size, self.latent_dim)  # Mean in latent space
        self.fc2 = nn.Linear(self.flatten_size, self.latent_dim)  # Log variance        
        self.fc3 = nn.Linear(self.latent_dim, 2 * self.first_channel * self.dim**2)
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2*self.first_channel, self.first_channel, kernel_size=3, stride=2, padding=1, output_padding=1),  # 25x25 -> 50x50
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(self.first_channel, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 50x50 -> 100x100
            nn.Sigmoid()
        )
    
    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten: 32 x 2*first_channel*25*25
        mu = self.fc1(x)           # Mean
        logvar = self.fc2(x)       # Log variance
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)  # Standard deviation
        eps = torch.randn_like(std)  # Random noise
        return mu + eps * std  # Reparameterization trick
    
    def decode(self, z):
        x = self.fc3(z)  # (B, latent_dim) -> (B, 2*first_channel*25*25)
        # Reshape for the decoder: batch_size x channels x height x width
        x = x.view(-1, 2*self.first_channel, self.dim, self.dim)
        return self.decoder(x)
    
    def forward(self, x):
        mu, logvar = self.encode(x)  # Get mean and log variance
        z = self.reparameterize(mu, logvar)  # Sample from the latent space
        reconstructed_x = self.decode(z)  # Decode back to original space
        return reconstructed_x, mu, logvar
    
    def loss_function(self, recon_x, x, mu, logvar, beta=1.0):
        MSE = F.mse_loss(recon_x, x, reduction='sum')  
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return MSE + beta * KLD

    def summary(self):
        pass

device = "cpu"
vae = VAEMNIST(latent_dim=64, first_channel=4).to(device) 

# For testing purposes
with torch.no_grad():
    tester = torch.randn(batch_size, 1, 28, 28)
    lol, wut = vae.encode(tester)
    yuh = vae.reparameterize(lol,wut)
    vae.decode(yuh).shape
    print(vae.encoder(torch.randn(batch_size, 1, 28, 28)).shape)
    lol.shape, wut.shape, vae.decode(yuh).shape
    test = vae.encoder(torch.randn(batch_size, 1, 28, 28))
    test.shape, test.view(tt.size(0), -1).shape

torch.Size([128, 8, 7, 7])


NameError: name 'tt' is not defined

In [None]:


batch_size = 128


train_data = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True)
test_data = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False)

# Getting a subset of the data for speed purposes
num_samples = 60000
indices = np.random.choice(len(train_data), num_samples, replace=False)
train_subset = Subset(train_data, indices)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, Subset, TensorDataset

from torchvision.transforms import v2
import torchvision.utils as vutils
from torchvision import datasets



train_dataset = TensorDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

optimizer = optim.Adam(vae.parameters(), lr=1e-4)
 
num_epochs = 2000
print_every = 10

def lr_lambda(epoch):
    start_lr = 0.01
    end_lr = 0.0001
    total_steps = 100
    return (end_lr - start_lr) * (epoch / total_steps) + start_lr  # Linear decay

# Define LambdaLR scheduler
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

for epoch in range(num_epochs):
    vae.train()  # Training mode
    running_loss = 0.0
    for batch_idx, (data, ) in enumerate(train_loader):
        # Move data to the appropriate device
        data = data.to(device)
        data = data.unsqueeze(1)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = vae.loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    if (epoch + 1) % print_every == 0:
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch [{epoch+1}/{num_epochs} *1e-6], Loss: {epoch_loss:.4f}")

AttributeError: 'MNIST' object has no attribute 'size'