In [None]:
# bs = 32
# d_latent = 512
# imgsz = 32
# lr = 1e-3
# map lr = 1e-5
# training steps = 150k

def train_discriminator(gen, disc, map, inputs, latent_dim, criterion, device):
  """Compute the discriminator loss for one batch of real and generated images.

  Despite its name this function does not step an optimizer: it only runs the
  forward passes and returns the loss, leaving ``backward()`` and the
  optimizer step to the caller.

  Args:
      gen: The generator, called as ``gen(w, noise)``.
      disc: The discriminator; assumed to return one logit per image.
      map: The mapping network that turns a latent vector ``z`` into ``w``.
          Note that this parameter shadows the ``map`` builtin here.
      inputs: A batch of real images; dimension 0 is the batch size.
      latent_dim: The size of the latent vector ``z`` fed to ``map``.
      criterion: The loss function, called as ``criterion(pred, labels)``.
      device: The device the images, labels and noise are placed on.

  Returns:
      A tuple ``(realLabels, fakeLabels, discLoss, discPredReal)``:
      a tensor of ones and a tensor of zeros (one entry per image), the
      scalar discriminator loss, and the flattened discriminator outputs
      for the real images.
  """
  batchSize = inputs.size(0)

  # Real images and their "real" (1) labels
  real_images = inputs.to(device)
  realLabels = torch.ones(batchSize, device=device)
  fakeLabels = torch.zeros(batchSize, device=device)

  # Per-resolution noise pairs for resolutions 4, 8, ..., 256. The first
  # entry of the 4x4 level is None.
  # NOTE(review): the 7 levels must match what the generator expects.
  inputNoise = []
  for level in range(7):
    res = 4 * 2 ** level
    if level == 0:
      noiseOne = None
    else:
      noiseOne = torch.randn(batchSize, 1, res, res, device=device)
    noiseTwo = torch.randn(batchSize, 1, res, res, device=device)
    inputNoise.append((noiseOne, noiseTwo))

  # Only the discriminator is updated from this loss, so the fake batch is
  # produced without building an autograd graph through map/gen.
  with torch.no_grad():
    # assumes the mapping network takes z of shape (batch, latent_dim) — TODO confirm
    z = torch.randn(batchSize, latent_dim, device=device)
    w = map(z)
    fake_images = gen(w, inputNoise)

  # Flatten predictions so they line up with the (batch,) label tensors
  discPredReal = disc(real_images).view(-1)
  discPredFake = disc(fake_images).view(-1)

  realLoss = criterion(discPredReal, realLabels)
  fakeLoss = criterion(discPredFake, fakeLabels)

  # NOTE: a hinge-loss alternative would be
  #   F.relu(1 - discPredReal).mean() + F.relu(1 + discPredFake).mean()

  # Take average of losses
  discLoss = 0.5 * (realLoss + fakeLoss)

  return realLabels, fakeLabels, discLoss, discPredReal

def train_generator(gen, disc, map, inputs, criterion, device, latent_dim=512):
  """Compute the generator loss for one freshly generated batch.

  Despite its name this function does not step an optimizer: it only runs the
  forward passes and returns the loss, leaving ``backward()`` and the
  optimizer steps to the caller.

  Args:
      gen: The generator, called as ``gen(w, noise)``.
      disc: The discriminator; its output is flattened to one value per image.
      map: The mapping network that turns a latent vector ``z`` into ``w``.
          Note that this parameter shadows the ``map`` builtin here.
      inputs: A batch tensor used only for its batch size (dimension 0).
      criterion: The loss function, called as ``criterion(pred, labels)``.
      device: The device the labels and noise are created on.
      latent_dim: The size of the latent vector ``z`` fed to ``map``.

  Returns:
      A tuple ``(genLoss, discPredFake, fake_images)``: the scalar generator
      loss, the flattened discriminator outputs for the generated images,
      and the generated images themselves (still attached to the graph).
  """
  batchSize = inputs.size(0)

  # The generator is scored against "real" (1) labels: it is rewarded when
  # the discriminator classifies its images as real.
  realLabels = torch.ones(batchSize, device=device)

  # Per-resolution noise pairs for resolutions 4, 8, ..., 256. The first
  # entry of the 4x4 level is None.
  # NOTE(review): the 7 levels must match what the generator expects.
  inputNoise = []
  for level in range(7):
    res = 4 * 2 ** level
    if level == 0:
      noiseOne = None
    else:
      noiseOne = torch.randn(batchSize, 1, res, res, device=device)
    noiseTwo = torch.randn(batchSize, 1, res, res, device=device)
    inputNoise.append((noiseOne, noiseTwo))

  # Get w from mapping network
  # assumes the mapping network takes z of shape (batch, latent_dim) — TODO confirm
  z = torch.randn(batchSize, latent_dim, device=device)
  w = map(z)

  # Create fake images with generator (graph kept: gen and map are trained)
  fake_images = gen(w, inputNoise)

  # Get discriminator loss on fake images with real labels instead of fake labels
  discPredFake = disc(fake_images).view(-1)
  genLoss = criterion(discPredFake, realLabels)

  # NOTE: a non-saturating alternative would be  -discPredFake.mean()

  return genLoss, discPredFake, fake_images

  
def train_network(generator, discriminator, mapping, latent_dim, device, trainLoader, num_epochs, lrGAN=1e-3, lrMap=1e-5, interval=100):
    """Train the generator, discriminator and mapping network and collect metrics.

    Each batch runs one discriminator update followed by one joint update of
    the generator and the mapping network.

    Args:
        generator: The generator of the model being trained.
        discriminator: The discriminator of the model being trained.
        mapping: The mapping network of the model being trained.
        latent_dim: The latent dimension of the noise vector. It is a
            required positional argument, because Python does not allow the
            required parameters that follow it to come after a default.
        device: The device the models and data are placed on.
        trainLoader: Iterable yielding ``(images, _)`` batches.
        num_epochs: The number of epochs to train for.
        lrGAN: The learning rate of the generator and discriminator.
        lrMap: The learning rate of the mapping network.
        interval: Print the losses every ``interval`` iterations; a falsy
            value disables the printing.

    Returns:
        A tuple ``(train_gen_loss, train_disc_loss, train_disc_real_acc,
        train_disc_fake_acc, newImgs)``: four per-iteration lists of floats
        and a list of image grids built from the last generated batch of
        every 100th epoch (starting with the first).
    """
    # Set manual seed for reproducible results
    torch.manual_seed(1)

    # Place the models on the same device as the labels and noise, and do so
    # before the optimizers are built from their parameters.
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    mapping = mapping.to(device)

    # Criterion and optimizers
    criterion = nn.BCEWithLogitsLoss()  # the models output raw logits
    gen_optimizer = optim.Adam(generator.parameters(), lr=lrGAN, weight_decay=1e-3)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=lrGAN, weight_decay=1e-3)
    map_optimizer = optim.Adam(mapping.parameters(), lr=lrMap, weight_decay=1e-3)

    # Training metrics
    train_gen_loss = []
    train_disc_loss = []
    train_disc_real_acc = []
    train_disc_fake_acc = []
    newImgs = []

    start_time = time.time()

    # Train loop
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        mapping.train()
        fake_images = None  # stays None if the loader yields no batch

        for iteration, (imgs, _) in enumerate(trainLoader, start=1):

            # Train Discriminator
            disc_optimizer.zero_grad()
            realLabels, fakeLabels, discLoss, discPredReal = train_discriminator(
                generator, discriminator, mapping, imgs, latent_dim, criterion, device)
            discLoss.backward()
            disc_optimizer.step()

            # Train Generator and Mapping Network
            gen_optimizer.zero_grad()
            map_optimizer.zero_grad()
            genLoss, discPredFake, fake_images = train_generator(
                generator, discriminator, mapping, imgs, criterion, device,
                latent_dim=latent_dim)
            genLoss.backward()
            gen_optimizer.step()
            map_optimizer.step()

            # Accuracy of the discriminator: a logit above 0 means "real".
            # Predictions are flattened so they compare element-wise with
            # the (batch,) label tensors instead of broadcasting.
            predReal = (discPredReal.detach().view(-1) > 0.0).float()
            predFake = (discPredFake.detach().view(-1) > 0.0).float()
            accReal = (predReal == realLabels).float().mean() * 100
            accFake = (predFake == fakeLabels).float().mean() * 100

            # Store discriminator loss and accuracies, and generator loss
            train_disc_loss.append(discLoss.item())
            train_gen_loss.append(genLoss.item())
            train_disc_real_acc.append(accReal.item())
            train_disc_fake_acc.append(accFake.item())

            # If at iteration interval, print current results of model
            if interval and iteration % interval == 0:
                print("Epoch: %03d/%03d | Batch %03d/%03d | Gen/Disc Loss: %.4f/%.4f"
                      % (epoch + 1, num_epochs, iteration, len(trainLoader),
                         genLoss.item(), discLoss.item()))

        # Store images created by generator every 100 epochs. They are
        # detached and moved to the CPU so the stored grids hold neither the
        # autograd graph nor device memory.
        if epoch % 100 == 0 and fake_images is not None:
            newImgs.append(torchvision.utils.make_grid(
                fake_images.detach().cpu(), padding=2, normalize=True))
        print("Time elapsed: %.2f min" % ((time.time() - start_time) / 60))

    print('Finished Training')
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("Total time elapsed: {:.2f} seconds".format(elapsed_time))

    return train_gen_loss, train_disc_loss, train_disc_real_acc, train_disc_fake_acc, newImgs