In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import time

In [None]:
def train_discriminator(gen, disc, mapping_network, inputs, latent_dim, criterion, device, features=14, mixing_prob=0.0):
  """ Computes the discriminator loss of the GAN model for one batch.

      Samples latent noise, maps it to style vectors, generates a batch of fake
      images and scores both the real and the fake batch with the
      discriminator. No optimizer is stepped here: the caller is expected to
      call backward() on the returned loss and step the discriminator optimizer.

      Args:
          gen: The generator of the model currently being trained, called as gen(w, None)
          disc: The discriminator of the model currently being trained
          mapping_network: The mapping network that turns noise z into style vectors w
          inputs: The real images found in our data set; inputs.size(0) is the
              batch size. They are used as-is (not moved to `device`).
          latent_dim: The latent dimension of our noise vector z
          criterion: The loss function used for our model, called as criterion(logits, labels)
          device: The device on which the noise and the labels are created
          features: Number of copies of w stacked along dim 1 before calling the generator
          mixing_prob: Probability of mixing the styles of two latent vectors for
              this batch. The default of 0.0 keeps style mixing disabled.

      Returns:
          A tuple (realLabels, fakeLabels, discLoss, discPredReal): the all-ones
          and all-zeros label tensors of shape (batch,), the averaged
          discriminator loss, and the discriminator output on the real images.
    """

  batchSize = inputs.size(0)

  # Real images are expected to already be on the right device
  real_images = inputs

  # The generator is not updated from this loss, so its forward pass (and the
  # mapping network's) does not need to be tracked by autograd.
  with torch.no_grad():
    # One random draw decides whether we implement style mixing or not
    if torch.rand(()).item() < mixing_prob:
      # Index of the layer at which we switch from the first style to the second
      cross = int(torch.rand(()).item() * features)
      # Randomly samples 2 latent vectors
      z1 = torch.randn(batchSize, latent_dim).to(device)
      z2 = torch.randn(batchSize, latent_dim).to(device)
      w1 = mapping_network(z1)[None, :, :].expand(cross, -1, -1)
      w2 = mapping_network(z2)[None, :, :].expand(features - cross, -1, -1)

      # Mixes the styles generated
      w = torch.cat((w1, w2), dim=0)
    else:
      z = torch.randn(batchSize, latent_dim).to(device)

      # Does not mix styles: the same w is repeated `features` times
      w = mapping_network(z)
      w = w[None, :, :].expand(features, -1, -1)

    # (features, batch, dim) -> (batch, features, dim)
    w = torch.transpose(w, 0, 1)

    # Create fake images with generator
    fake_images = gen(w, None)

  # Get discriminator loss on real images.
  # The output is squeezed on dim 1, so it is assumed to be (batch, 1).
  discPredReal = disc(real_images)
  # Labels share the dtype of the logits so the criterion never sees a mismatch
  realLabels = torch.ones(batchSize, device=device, dtype=discPredReal.dtype)
  realLoss = criterion(discPredReal.squeeze(1), realLabels)

  # Get discriminator loss on fake images
  discPredFake = disc(fake_images)
  fakeLabels = torch.zeros(batchSize, device=device, dtype=discPredFake.dtype)
  fakeLoss = criterion(discPredFake.squeeze(1), fakeLabels)

  # Take average of losses
  discLoss = 0.5 * (realLoss + fakeLoss)

  return realLabels, fakeLabels, discLoss, discPredReal

def train_generator(gen, disc, mapping_network, inputs, criterion, device, features=14, latent_dim=512, mixing_prob=0.0):
  """ Computes the generator loss of the GAN model for one batch.

    Samples fresh latent noise, generates a batch of fake images and scores
    them with the discriminator against "real" labels. No optimizer is stepped
    here: the caller is expected to call backward() on the returned loss and
    step the generator and mapping-network optimizers.

    Args:
        gen: The generator of the model currently being trained, called as gen(w, None)
        disc: The discriminator of the model currently being trained
        mapping_network: The mapping network that turns noise z into style vectors w
        inputs: The current batch; only inputs.size(0) is used, as the batch size
        criterion: The loss function used for our model, called as criterion(logits, labels)
        device: The device on which the noise and the labels are created
        features: Number of copies of w stacked along dim 1 before calling the generator
        latent_dim: The latent dimension of our noise vector z (defaults to 512)
        mixing_prob: Probability of mixing the styles of two latent vectors for
            this batch. The default of 0.0 keeps style mixing disabled.

    Returns:
        A tuple (genLoss, discPredFake, fake_images): the generator loss, the
        discriminator output on the fake images, and the fake images themselves.
    """

  batchSize = inputs.size(0)

  # One random draw decides whether we implement style mixing or not
  if torch.rand(()).item() < mixing_prob:
    # Index of the layer at which we switch from the first style to the second
    cross = int(torch.rand(()).item() * features)
    # Randomly samples 2 latent vectors
    z1 = torch.randn(batchSize, latent_dim).to(device)
    z2 = torch.randn(batchSize, latent_dim).to(device)
    w1 = mapping_network(z1)[None, :, :].expand(cross, -1, -1)
    w2 = mapping_network(z2)[None, :, :].expand(features - cross, -1, -1)

    # Mixes the styles generated
    w = torch.cat((w1, w2), dim=0)
  else:
    z = torch.randn(batchSize, latent_dim).to(device)

    # Does not mix styles: the same w is repeated `features` times
    w = mapping_network(z)
    w = w[None, :, :].expand(features, -1, -1)

  # (features, batch, dim) -> (batch, features, dim)
  w = torch.transpose(w, 0, 1)

  # Create fake images with generator
  fake_images = gen(w, None)

  # Get discriminator loss on fake images with real labels instead of fake labels.
  # The output is squeezed on dim 1, so it is assumed to be (batch, 1).
  discPredFake = disc(fake_images)
  realLabels = torch.ones(batchSize, device=device, dtype=discPredFake.dtype)

  genLoss = criterion(discPredFake.squeeze(1), realLabels)

  return genLoss, discPredFake, fake_images

  
def train_network(generator, discriminator, mapping, interval, device, trainLoader, num_epochs, lrGAN = 1e-5, lrMap = 1e-5, latent_dim = 512, cuda=True, features=14, preview_batch_size=None):
    """ Trains the model entirely (both the generator and discriminator) and stores results.

    Args:
        generator: The generator of the model currently being trained
        discriminator: The discriminator of the model currently being trained
        mapping: The mapping network of the model currently being trained
        interval: The iteration interval at which we print our model's loss values
        device: The device used for the noise, the labels and (when `cuda` is
            True) the models and image batches
        trainLoader: The training data, iterated as (imgs, _) batches
        num_epochs: The number of epochs to run our model for
        lrGAN: The learning rate of our generator and discriminator
        lrMap: The learning rate of our mapping network
        latent_dim: The latent dimension of our noise vector
        cuda: When True, the three networks and every image batch are moved to
            `device`; when False they are left where they are
        features: Number of copies of w stacked for the generator
        preview_batch_size: Number of fixed preview samples generated every 10
            epochs. Defaults to trainLoader.batch_size when available, else 1.

    Returns:
        A tuple (train_gen_loss, train_disc_loss, train_disc_real_acc,
        train_disc_fake_acc, newImgs): four per-iteration metric lists and the
        list of preview image batches (moved to the CPU) generated during training.
    """

    # Set manual seed for reproducible results
    torch.manual_seed(1)

    ########## SENDING TO DEVICE ############
    # Done before the optimizers are built and before the fixed preview w is
    # computed, so that parameters and inputs live on the same device.
    if cuda:
        generator = generator.to(device)
        discriminator = discriminator.to(device)
        mapping = mapping.to(device)

    # Criterion and optimizers
    criterion = nn.BCEWithLogitsLoss() # necessary if we use our loss function and not theirs
    gen_optimizer = optim.Adam(generator.parameters(), lr=lrGAN, weight_decay=1e-3)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=lrGAN, weight_decay=1e-3)
    map_optimizer = optim.Adam(mapping.parameters(), lr=lrMap, weight_decay=1e-3)

    # Training metrics
    train_gen_loss = []
    train_disc_loss = []
    train_disc_real_acc = []
    train_disc_fake_acc = []
    newImgs = []

    if preview_batch_size is None:
        preview_batch_size = getattr(trainLoader, "batch_size", None) or 1

    # Constant noise and w used for the periodic preview images.
    # NOTE(review): the (7, 2) noise shape is kept from the original code;
    # confirm it matches what the generator expects for its noise argument.
    inputNoise = torch.randn([7, 2]).to(device)
    with torch.no_grad():
        inputW = mapping(torch.randn(preview_batch_size, latent_dim).to(device))
    inputW = inputW[None, :, :].expand(features, -1, -1)
    inputW = torch.transpose(inputW, 0, 1)

    start_time = time.time()

    # Train loop
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        mapping.train()

        for iteration, (imgs, _) in enumerate(trainLoader, start=1):

            if cuda:
                imgs = imgs.to(device)

            # Train Discriminator
            disc_optimizer.zero_grad()
            realLabels, fakeLabels, discLoss, discPredReal = train_discriminator(
                generator, discriminator, mapping, imgs, latent_dim, criterion, device, features=features)

            # The generator step below builds its own graph, so this one can be freed
            discLoss.backward()
            disc_optimizer.step()

            # Train Generator and Mapping Network
            gen_optimizer.zero_grad()
            map_optimizer.zero_grad()
            genLoss, discPredFake, fake_images = train_generator(
                generator, discriminator, mapping, imgs, criterion, device, features=features)
            genLoss.backward()
            gen_optimizer.step()
            map_optimizer.step()

            # Calculate accuracy of model: a logit above 0 counts as "real".
            # Predictions are flattened to the labels' shape (batch,) so the
            # comparison is element-wise rather than a broadcast matrix.
            predReal = (discPredReal.detach().reshape(-1) > 0.0).float()
            predFake = (discPredFake.detach().reshape(-1) > 0.0).float()
            accReal = (predReal == realLabels.reshape(-1).float()).float().mean() * 100
            accFake = (predFake == fakeLabels.reshape(-1).float()).float().mean() * 100

            # Store discriminator loss and accuracies, and generator loss
            # TODO: maybe clip gradients with torch.nn.utils.clip_grad_norm_(..., max_norm=1.0)
            train_disc_loss.append(discLoss.item())
            train_gen_loss.append(genLoss.item())
            train_disc_real_acc.append(accReal.item())
            train_disc_fake_acc.append(accFake.item())

            # If at iteration interval, print current results of model
            if iteration % interval == 0:
                print("Epoch: %03d/%03d | Batch %03d/%03d | Gen/Disc Loss: %.4f/%.4f"
                        % (epoch + 1, num_epochs, iteration, len(trainLoader),
                            genLoss.item(), discLoss.item()))

        # Store images created by generator every 10 epochs
        if epoch % 10 == 0:
            # Use constant input w vector and noise to see how our generator improves on same data
            with torch.no_grad():
                tempImg = generator(inputW, inputNoise)
            newImgs.append(tempImg.detach().cpu())
            # imshow is not defined in this function; it is expected to be
            # provided elsewhere in the notebook.
            imshow(tempImg, title=None)

        print("Time elapsed: %.2f min" % ((time.time() - start_time) / 60))

    print('Finished Training')
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("Total time elapsed: {:.2f} seconds".format(elapsed_time))

    return train_gen_loss, train_disc_loss, train_disc_real_acc, train_disc_fake_acc, newImgs

In [None]:
from importnb import Notebook

with Notebook():
    from flight_gan_model import Generator, Discriminator

# Build fresh networks for this training run
gen_net = Generator()
disc_net = Discriminator()

# Run configuration
PRINT_INTERVAL = 100  # print the losses every this many iterations
NUM_EPOCHS = 30
DEVICE = 'cuda:0'

# NOTE(review): train_loader is not defined in the cells shown here; it is
# assumed to come from an earlier cell. Confirm on a fresh kernel.
# Binding the results keeps the metric lists available for plotting and stops
# the notebook from echoing them as the cell output.
gen_losses, disc_losses, disc_real_acc, disc_fake_acc, preview_imgs = train_network(
    gen_net.SynthesisNet, disc_net, gen_net.MappingNet,
    PRINT_INTERVAL, DEVICE, train_loader, NUM_EPOCHS,
)

tensor([[-3.2102]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-3.1839]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-3.1544]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-3.1275]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-3.0992]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-3.0704]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-3.0421]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-3.0156]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-2.9879]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-2.9601]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-2.9334]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-2.9056]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-2.8773]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-2.8498]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-2.8245]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([[-2.7974]], device='cuda:0', grad_fn=<AddmmBackward>)
tensor([

KeyboardInterrupt: 