In [None]:
# WGAN attempt --> changed the loss within the separate loss functions + added weight clipping

Import libraries

In [3]:
# Import all libraries regarding torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.stats import gaussian_kde, lognorm
import os
from torch.utils.data import Dataset, DataLoader
import time
import shutil

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision
import torchvision.datasets as datasets

Read Input data 

In [4]:
# Reading from RSA_input.csv
df = pd.read_csv('RSA_input.csv') # add csv file to the correct location
grain_R = df["grain_R"]
grain_asp = df["aspect"]
print(len(grain_R)) # length 30218
print(len(grain_asp)) # length 30218

# Combined into numpy shape (30218, 2)
grainsData = np.column_stack((grain_R, grain_asp))
print(grainsData.shape)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

30218
30218
(30218, 2)


Set Hyperparameters

In [30]:
# Hyperparameters
learning_rate = 5e-5
batch_size = 64
critic_iterations = 5
weight_clip = 0.01
num_epochs = 100

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = -1.

# Initialize generator and discriminator
latent_Gaussian_dimension = 100  # Dimension of the input noise vector
number_of_grain_features = 2  # Dimension of the real data
#real_data_dim = 30218  # Dimension of the real data
number_of_reduced_grains = 1000  # Dimension of the generated data

Create NN models

In [25]:
# Generator
class Generator(nn.Module):
    def __init__(self, latent_Gaussian_dimension, number_of_grain_features):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            # change width and depth of the network here
            nn.Linear(latent_Gaussian_dimension, 75),
            nn.ReLU(),
            nn.Linear(75, 50),
            nn.ReLU(),
            nn.Linear(50, 25),
            nn.ReLU(),
            nn.Linear(25, number_of_grain_features)
        )

    def forward(self, x):
        return self.model(x)
    
# Discriminator
class Discriminator(nn.Module):
    def __init__(self, number_of_grain_features):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            # change width and depth of the network here
            nn.Linear(number_of_grain_features, 25),
            nn.ReLU(),
            nn.Linear(25, 50),
            nn.ReLU(),
            nn.Linear(50, 75),
            nn.ReLU(),
            nn.Linear(75, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
    
def generator_loss(discriminator, fake_grains): # This has been changed to the WGAN loss

    # Ensure that the target tensor has the same device as the input fake_grains
    target = torch.full((fake_grains.size(0),), real_label, device=fake_grains.device)

    # Forward pass through the discriminator with the fake grains
    output = discriminator(fake_grains).reshape(-1)

    # generator loss = -(avgerage critic score on fake images)
    loss = -torch.mean(output)

    return loss


def discriminator_loss(discriminator, real_grains, fake_grains):

    '''
    d_loss_real: The mean of the binary cross-entropy losses computed on the real_grains.
    D_real: Mean output of the discriminator for real_grains. This is useful for tracking convergence.
    d_loss_fake: The mean of the binary cross-entropy losses computed on the fake_grains.
    D_fake: Mean output of the discriminator for fake_grains. This is useful for tracking convergence.

    '''

    device = fake_grains.device  # Get the device of the input tensors
    # Transfer input tensors to the same device as the discriminator
    real_grains = real_grains.to(device)
    fake_grains = fake_grains.to(device)

    # Create the target labels for real and fake grains
    real_target = torch.full((real_grains.size(0),), real_label, device=device)
    fake_target = torch.full((fake_grains.size(0),), fake_label, device=device)

    # Compute the discriminator outputs for real and fake grains --> critic_real & critic_fake
    real_output = discriminator(real_grains).view(-1)
    fake_output = discriminator(fake_grains).view(-1)

    # Compute the W-losses
    # Discrumunator loss =  (avegrage critic score on real images) - (average critic score on fake images)
    loss_critic = -(torch.mean(real_output) - torch.mean(fake_output))
    d_loss_real = torch.mean(real_output)
    d_loss_fake = torch.mean(fake_output)

    # Compute the mean discriminator outputs for real and fake grains
    D_real = real_output.mean().item()
    D_fake = fake_output.mean().item()

    return d_loss_real, D_real, d_loss_fake, D_fake, loss_critic

Testing Shapes and Loss value

In [23]:
# Test generator loss
def test_generator_loss():
    # Create a generator
    netG = Generator(latent_Gaussian_dimension=2, number_of_grain_features=2)
    netD = Discriminator(number_of_grain_features=2)
    # Create fake grains
    noise = torch.randn(100, 2)
    fake_grains = netG(noise)

    # Compute the generator loss 
    loss = generator_loss(netD, fake_grains)
    print(loss)

test_generator_loss()

tensor(-0.5048, grad_fn=<NegBackward0>)


In [31]:
# Test discriminator loss
def test_discriminator_loss():
    # Create a generator
    netG = Generator(latent_Gaussian_dimension=2, number_of_grain_features=2)
    netD = Discriminator(number_of_grain_features=2)
    # Create real and fake grains
    real_grains = torch.randn(100, 2)
    noise = torch.randn(100, 2)
    fake_grains = netG(noise)
    # Compute the discriminator loss
    d_loss_real, D_real, d_loss_fake, D_fake, loss_critic = discriminator_loss(netD, real_grains, fake_grains)
    print(d_loss_real)
    print(D_real)
    print(d_loss_fake)
    print(D_fake)
    print(loss_critic)

test_discriminator_loss()

tensor(0.5169, grad_fn=<MeanBackward0>)
0.5168965458869934
tensor(0.5161, grad_fn=<MeanBackward0>)
0.5161435008049011
tensor(-0.0008, grad_fn=<NegBackward0>)


Train the Models

In [32]:
# initialize gen and discriminator_loss
generator = Generator(latent_Gaussian_dimension, number_of_grain_features).to(device)
discriminator = Discriminator(number_of_grain_features).to(device)

#initialize optimizer
generator_optimizer = optim.RMSprop(generator.parameters(), lr=0.00005)
discriminator_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.00005)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create DataLoader
dataloader = DataLoader(grainsData, batch_size=batch_size, shuffle=True)

# initialize tensorboard plotting

# train gen & critic in loop
for epoch in range(num_epochs + 1):
    for batch_index, real_grains in enumerate(dataloader):
        real_grains = real_grains.to(torch.float32)

        # Train discriminator = max E(discriminator(real) - E(dicriminator(fake)))
        for _ in range(critic_iterations):
            noise = torch.randn(batch_size, latent_Gaussian_dimension, device=device, dtype=torch.float32)
            fake_grains = generator(noise)
            d_loss_real, D_real, d_loss_fake, D_fake, d_loss = discriminator_loss(discriminator, real_grains, fake_grains)
            discriminator.zero_grad()

            # Update discriminator
            d_loss.backward(retain_graph=True)
            discriminator_optimizer.step()

            # clip critic weigts between -0.01 and 0.01
            for p in discriminator.parameters():
                p.data.clamp_(-weight_clip, weight_clip)
        
        # Train Generator: max E(discriminator(fake)) <--> min -E(discriminator(fake))
        # i.e. we want discriminator to think that the fake data is real
        
        # Generate fake grains
        noise = torch.randn(batch_size, latent_Gaussian_dimension, device=device)
        fake_grains = generator(noise)
        
        # Compute generator loss
        g_loss = generator_loss(discriminator, fake_grains)

        # Update generator
        generator_optimizer.zero_grad()
        g_loss.backward()
        generator_optimizer.step()

    ############################
    # Print training progress
    ###########################
    
    if epoch % 5 == 0:
        print(f"Epoch {epoch}")
        print(f"Discriminator Loss: D_real={D_real:.4f}, D_fake={D_fake:.4f}, d_loss = {d_loss:.4f}")
        print(f"Generator Loss: {g_loss:.4f}")
        



Epoch 0
Discriminator Loss: D_real=0.5002, D_fake=0.4998, d_loss = -0.0003
Generator Loss: -0.4998
Epoch 5
Discriminator Loss: D_real=0.5001, D_fake=0.5001, d_loss = -0.0000
Generator Loss: -0.5001
Epoch 10
Discriminator Loss: D_real=0.5002, D_fake=0.5002, d_loss = -0.0000
Generator Loss: -0.5001
Epoch 15
Discriminator Loss: D_real=0.5001, D_fake=0.5001, d_loss = 0.0000
Generator Loss: -0.5002
Epoch 20
Discriminator Loss: D_real=0.5004, D_fake=0.5004, d_loss = -0.0000
Generator Loss: -0.5004
Epoch 25
Discriminator Loss: D_real=0.5005, D_fake=0.5005, d_loss = -0.0000
Generator Loss: -0.5005
Epoch 30
Discriminator Loss: D_real=0.5005, D_fake=0.5005, d_loss = -0.0000
Generator Loss: -0.5005
Epoch 35
Discriminator Loss: D_real=0.5005, D_fake=0.5005, d_loss = -0.0000
Generator Loss: -0.5005
