# Training Loop


# In[ ]:
def _plot_loss_curves(gen_losses, disc_losses, step_bins=20):
    """Plot generator/discriminator losses averaged over bins of ``step_bins`` steps.

    Only the longest prefix whose length is a multiple of ``step_bins`` is
    plotted (a trailing partial bin is dropped). Nothing is drawn until at
    least one full bin exists, because ``view(-1, step_bins)`` cannot reshape
    an empty tensor.
    """
    num_examples = (len(gen_losses) // step_bins) * step_bins
    if num_examples == 0:
        return
    bin_indices = range(num_examples // step_bins)
    plt.plot(
        bin_indices,
        torch.Tensor(gen_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Generator Loss",
    )
    plt.plot(
        bin_indices,
        torch.Tensor(disc_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Discriminator Loss",
    )
    plt.legend()
    plt.show()


# Global optimisation-step counter. It drives the `display_step` schedule and
# the wandb `step` axis. It is deliberately separate from `current_epoch`, so
# the epoch loop variable is never mutated inside the batch loop. Keeping it
# global also means wandb steps never go backwards between epochs.
current_step = 0
# Per-step loss history, consumed by the binned loss-curve plot.
generator_losses = []
discriminator_losses = []

for current_epoch in range(total_epochs):
    # Running totals for this epoch
    epoch_generator_loss = 0.0
    epoch_discriminator_loss = 0.0
    num_batches = 0

    # Iterate over batches in the training dataloader
    for real_images, labels, sketches, _ in tqdm(train_data_loader):
        # Move data to device; both networks consume integer class labels,
        # so cast once per batch instead of at every call site.
        real_images = real_images.to(DEVICE)
        sketches = sketches.to(DEVICE)
        class_labels = labels.to(DEVICE, dtype=torch.long)

        # Forward pass through the generator (sketch + label -> image)
        generated_images = generator(sketches, class_labels)

        # Discriminator update: real -> 1, fake -> 0. The fake batch is
        # detached so this backward pass does not reach the generator.
        disc_real_output = discriminator(real_images, class_labels).reshape(-1)
        disc_fake_output = discriminator(generated_images.detach(), class_labels).reshape(-1)
        disc_real_loss = criterion(disc_real_output, torch.ones_like(disc_real_output))
        disc_fake_loss = criterion(disc_fake_output, torch.zeros_like(disc_fake_output))
        discriminator_loss = (disc_real_loss + disc_fake_loss) / 2

        discriminator.zero_grad()
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # Generator update: fool the (just updated) discriminator, plus an
        # L1 reconstruction term weighted by 100.
        gen_fake_output = discriminator(generated_images, class_labels).reshape(-1)
        gen_loss_adversarial = criterion(gen_fake_output, torch.ones_like(gen_fake_output))
        gen_loss_reconstruction = L1_LOSS(generated_images, real_images) * 100
        generator_loss = gen_loss_adversarial + gen_loss_reconstruction

        generator.zero_grad()
        generator_loss.backward()
        generator_optimizer.step()

        # Accumulate losses for the epoch average and the per-step history
        gen_loss_value = generator_loss.item()
        disc_loss_value = discriminator_loss.item()
        epoch_generator_loss += gen_loss_value
        epoch_discriminator_loss += disc_loss_value
        generator_losses.append(gen_loss_value)
        discriminator_losses.append(disc_loss_value)
        num_batches += 1

        # Periodic logging and visualization, every `display_step` steps
        if current_step % display_step == 0 and current_step > 0:
            # Running mean of this epoch's losses so far
            mean_generator_loss = epoch_generator_loss / num_batches
            mean_discriminator_loss = epoch_discriminator_loss / num_batches
            wandb.log({"generator_loss_per_step": mean_generator_loss, "discriminator_loss_per_step": mean_discriminator_loss}, step=current_step)

            # Detached copy for plotting and metrics; these only need values,
            # not the autograd graph.
            fake_images = generated_images.detach()
            plot_images_from_tensor(fake_images, name="fake_images")
            plot_images_from_tensor(real_images, name="real_images")

            # Inception Score and FID, computed without building a graph
            with torch.no_grad():
                inception_score_val = inception_score(fake_images.to("cpu"))
                fid_score_val = calculate_fid(real_images, fake_images, device='cpu')
            wandb.log({"inception_score": inception_score_val, "FID score": fid_score_val}, step=current_step)

            # Plot losses over time, averaged in bins of 20 steps
            _plot_loss_curves(generator_losses, discriminator_losses, step_bins=20)
        elif current_step == 0:
            print("Training has started, let it continue...")
        current_step += 1

    if num_batches == 0:
        raise ValueError("train_data_loader yielded no batches; cannot compute epoch losses")

    # Calculate and log average epoch losses (step axis stays monotonic)
    avg_epoch_generator_loss = epoch_generator_loss / num_batches
    avg_epoch_discriminator_loss = epoch_discriminator_loss / num_batches
    wandb.log({"generator_loss_per_epoch": avg_epoch_generator_loss, "discriminator_loss_per_epoch": avg_epoch_discriminator_loss}, step=current_step)

# Finish Weights & Biases run
wandb.finish()


# In[ ]:
# Persist only the generator's parameters/buffers (its state_dict), not the module object.
generator_state = generator.state_dict()
torch.save(generator_state, 'generator.pth')