In [None]:
# bs = 32
# d_latent = 512
# imgsz = 32
# lr = 1e-3
# map lr = 1e-5
# training steps = 150k

def train_discriminator(gen, disc, inputs, latent_dim, criterion):
    """Compute the discriminator loss on one real batch and one freshly generated fake batch.

    Args:
        gen: generator module, called as ``gen(w, noise)``.
        disc: discriminator module; its output is flattened to one score per image.
        inputs: batch of real images; moved to the module-level ``device``.
        latent_dim: channel count of the sampled noise tensor
            (shape ``(batch, latent_dim, 1, 1)``).
        criterion: loss applied as ``criterion(scores, labels)``. Presumably
            ``BCEWithLogitsLoss``, since callers threshold scores at 0 (TODO confirm).

    Returns:
        ``(realLabels, fakeLabels, fake_images, discLoss, discPredReal)``, where
        ``realLabels``/``fakeLabels`` are all-ones/all-zeros targets,
        ``fake_images`` is detached from the generator graph, ``discLoss`` is the
        mean of the real and fake losses, and ``discPredReal`` holds the flattened
        scores for the real images.
    """
    batch_size = inputs.size(0)

    real_images = inputs.to(device)
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)

    noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
    # TODO(review): placeholder left by the author -- presumably the mapping
    # network's output (the style vector w); fill in before running.
    w = ___
    # Detach so the discriminator loss does not backpropagate into (and
    # accumulate gradients on) the generator / mapping parameters.
    fake_images = gen(w, noise).detach()

    # NOTE: the reference setup reportedly uses a different (Wasserstein-style)
    # loss for both terms below.
    pred_real = disc(real_images).view(-1)
    real_loss = criterion(pred_real, real_labels)

    pred_fake = disc(fake_images).view(-1)
    fake_loss = criterion(pred_fake, fake_labels)

    disc_loss = 0.5 * (real_loss + fake_loss)
    return real_labels, fake_labels, fake_images, disc_loss, pred_real

def train_generator(gen, disc, inputs, realLabels, criterion):
    """Compute the generator loss: score a fresh fake batch against all-ones targets.

    The generator is rewarded when ``disc`` labels its fakes as real.

    Args:
        gen: generator module, called as ``gen(w, noise)``.
        disc: discriminator module used to score the fakes (the same module
            that ``train_discriminator`` trains).
        inputs: real batch; only its batch size is used.
        realLabels: accepted for API compatibility but ignored. All-ones
            targets are rebuilt from ``inputs.size(0)`` so they always match
            this batch's size.
        criterion: loss applied as ``criterion(scores, labels)``.

    Returns:
        ``(genLoss, discPredFake)``. ``discPredFake`` holds the flattened
        discriminator scores for the fakes and still carries the graph.
    """
    batch_size = inputs.size(0)
    targets = torch.ones(batch_size, device=device)

    # NOTE(review): ``latent_dim`` and ``device`` are read from module scope here.
    noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
    # TODO(review): placeholder left by the author -- presumably the mapping
    # network's output (the style vector w); fill in before running.
    w = ___
    fake_images = gen(w, noise)

    # Score with the discriminator that was passed in, consistent with
    # train_discriminator, rather than an unrelated global model.
    disc_pred_fake = disc(fake_images).view(-1)
    # NOTE: the reference setup reportedly uses a different (mean-based) loss.
    gen_loss = criterion(disc_pred_fake, targets)

    return gen_loss, disc_pred_fake

def train_network(model, latent_dim, device, trainLoader, num_epochs, lrGAN, lrMap):
    """Train the GAN with one discriminator step and one generator+mapping step per batch.

    Relies on module-level names: ``generator``, ``discriminator`` and
    ``mapping`` (the networks), ``criterion``, ``interval`` (logging period in
    batches), and the helpers ``train_discriminator`` / ``train_generator``.
    Those helpers move their inputs to the module-level ``device``. Pass the
    same device here.

    Args:
        model: module put into training mode at the start of every epoch.
            NOTE(review): confirm it contains the three networks; otherwise
            call ``.train()`` on them as well.
        latent_dim: forwarded to ``train_discriminator``.
        device: device the three networks are moved to before training.
        trainLoader: iterable of ``(images, labels)`` batches; labels are ignored.
        num_epochs: number of passes over ``trainLoader``.
        lrGAN: Adam learning rate for the generator and the discriminator.
        lrMap: Adam learning rate for the mapping network.

    Returns:
        ``(train_gen_loss, train_disc_loss, train_disc_real_acc,
        train_disc_fake_acc, newImgs)``. The first four lists hold one value
        per batch (accuracies in percent). ``newImgs`` holds one CPU image grid
        of the last fake batch per epoch.
    """
    # Bind the module-level networks to new local names. Re-assigning
    # ``generator`` etc. inside this function would turn them into locals and
    # raise UnboundLocalError on first use. Move them to the device *before*
    # constructing the optimizers over their parameters.
    gen_net = generator.to(device)
    disc_net = discriminator.to(device)
    map_net = mapping.to(device)

    gen_optimizer = optim.Adam(gen_net.parameters(), lr=lrGAN, weight_decay=1e-3)
    disc_optimizer = optim.Adam(disc_net.parameters(), lr=lrGAN, weight_decay=1e-3)
    map_optimizer = optim.Adam(map_net.parameters(), lr=lrMap, weight_decay=1e-3)

    # Training metrics
    train_gen_loss = []
    train_disc_loss = []
    train_disc_real_acc = []
    train_disc_fake_acc = []
    newImgs = []

    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        fake_images = None  # stays None if the loader yields no batches

        for iteration, (imgs, _) in enumerate(trainLoader, start=1):
            # ---- discriminator update ----
            disc_optimizer.zero_grad()
            realLabels, fakeLabels, fake_images, discLoss, discPredReal = train_discriminator(
                gen_net, disc_net, imgs, latent_dim, criterion)
            discLoss.backward()
            disc_optimizer.step()

            # ---- generator + mapping update ----
            # Zero here, right before backward, so any gradient the
            # discriminator pass left on these parameters is discarded.
            gen_optimizer.zero_grad()
            map_optimizer.zero_grad()
            genLoss, discPredFake = train_generator(
                gen_net, disc_net, imgs, realLabels, criterion)
            genLoss.backward()
            gen_optimizer.step()
            map_optimizer.step()

            # Scores are thresholded at 0, i.e. treated as logits.
            predReal = (discPredReal.detach() > 0.0).float()
            predFake = (discPredFake.detach() > 0.0).float()
            accReal = (predReal == realLabels).float().mean() * 100
            accFake = (predFake == fakeLabels).float().mean() * 100

            train_disc_loss.append(discLoss.item())
            train_gen_loss.append(genLoss.item())
            train_disc_real_acc.append(accReal.item())
            train_disc_fake_acc.append(accFake.item())

            if iteration % interval == 0:
                print("Epoch: %03d/%03d | Batch %03d/%03d | Gen/Disc Loss: %.4f/%.4f"
                      % (epoch + 1, num_epochs, iteration, len(trainLoader),
                         genLoss.item(), discLoss.item()))

        if fake_images is not None:
            # Detach and copy to CPU so stored grids hold neither the autograd
            # graph nor device memory across epochs.
            newImgs.append(torchvision.utils.make_grid(
                fake_images.detach().cpu(), padding=2, normalize=True))
        print("Time elapsed: %.2f min" % ((time.time() - start_time) / 60))

    print('Finished Training')
    elapsed_time = time.time() - start_time
    print("Total time elapsed: {:.2f} seconds".format(elapsed_time))

    return train_gen_loss, train_disc_loss, train_disc_real_acc, train_disc_fake_acc, newImgs