In [1]:
import os.path
import time
import re

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# cuda_available = torch.cuda.is_available()
# dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

device = "cuda" if torch.cuda.is_available() else "cpu"
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
print(f"Using {device} device")

# torch.cuda.empty_cache()

# Tensor = torch.Tensor
# device = torch.device("cpu")

Using cuda device


In [3]:
import torchvision
import tqdm

In [4]:
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torchvision.transforms as transforms

FFHQ_DIR = "/home/andresokol/data/compressed"
STRUCTURE_ROOT = f"/home/andresokol/data/orientation"
RENDERED_ROOT = f"/home/andresokol/rendered/v1"
MASK_ROOT = f"/home/andresokol/data/masks"

In [5]:
class DownsampleBlock(nn.Module):
    def __init__(self):
        super().__init__()

        def _block(channels_in, channels_out):
            return [
                nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, stride=2),
                nn.ReLU(),
            ]

        self.layers = nn.Sequential(
            *_block(3, 6),
            *_block(6, 12),
            *_block(12, 24),
            *_block(24, 48),
            *_block(48, 48),
        )
#         self.layers = nn.Sequential(
#             *_block(3, 3),
#             *_block(3, 3),
#             *_block(3, 3),
#             *_block(3, 3),
#             *_block(3, 3),
#         )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)

In [6]:
# PyTorch version of ResNet
# https://github.com/pytorch/vision/blob/cddad9ca3822011548e18342f52a3e9f4724c2dd/torchvision/models/resnet.py#L88


class ResBlk(nn.Module):
    def __init__(self, channels) -> None:
        super().__init__()
        self.conv_1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.batch_norm_1 = nn.BatchNorm2d(channels)
        self.conv_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.batch_norm_2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv_1(x)
        out = self.batch_norm_1(out)
        out = self.relu(out)

        out = self.conv_2(out)
        out = self.batch_norm_2(out)

        out += x
        out = self.relu(out)

        return out

In [7]:
# E_f = nn.Sequential(
#     DownsampleBlock(),
#     *(ResBlk() for _ in range(4)),
# ).to(device)

# E_r = nn.Sequential(
#     DownsampleBlock(),
#     *(ResBlk() for _ in range(4)),
# ).to(device)

# E_ = nn.Sequential(
#     ResBlk(),
#     ResBlk(),
# ).to(device)

In [8]:
class GeneratorFake(nn.Module):
    def __init__(self):
        super().__init__()

        def _upscale_conv(channels_in, channels_out):
            return (
                nn.BatchNorm2d(channels_in),
                nn.ConvTranspose2d(channels_in, channels_out, kernel_size=3, stride=2, padding=1, output_padding=1),  # is this correct?
                nn.LeakyReLU(),
            )

        self.resblk1 = ResBlk(48)
        self.resblk2 = ResBlk(48)
        self.upscale_layers = nn.Sequential(
            *_upscale_conv(48, 48),
            *_upscale_conv(48, 24),
            *_upscale_conv(24, 12),
            *_upscale_conv(12, 6),
            *_upscale_conv(6, 3),
        )
#         self.upscale_layers = nn.Sequential(
#             *_upscale_conv(3, 3),
#             *_upscale_conv(3, 3),
#             *_upscale_conv(3, 3),
#             *_upscale_conv(3, 3),
#             *_upscale_conv(3, 3),
#         )

        self.out_conv = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        out = self.resblk1(x)
        out = self.resblk2(out)
        out = self.upscale_layers(out)
        # print(out.shape)

        # for i in range(5):
        #     out = self.upscale_layers[i](out)
        # print(out.shape)

        return torch.tanh(self.out_conv(out))


# G_f = GeneratorFake().to(device)
# G_f = nn.Sequential(
#     *(ResBlk() for _ in range(2)),
#     *(nn.ConvTranspose2d(3, 3, kernel_size=3, stride=2, padding=1) for _ in range(5)),
#     # nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1),
# )

In [9]:
# from https://github.com/NVlabs/SPADE

# Creates SPADE normalization layer based on the given configuration
# SPADE consists of two steps. First, it normalizes the activations using
# your favorite normalization method, such as Batch Norm or Instance Norm.
# Second, it applies scale and bias to the normalized output, conditioned on
# the segmentation map.
# The format of |config_text| is spade(norm)(ks), where
# (norm) specifies the type of parameter-free normalization.
#       (e.g. syncbatch, batch, instance)
# (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
# Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5.
# Also, the other arguments are
# |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
# |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
class SPADE(nn.Module):
    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search('spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        # elif param_free_norm_type == 'syncbatch':
        #     self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):

        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out


# ResNet block that uses SPADE.
# It differs from the ResNet block of pix2pixHD in that
# it takes in the segmentation map as input, learns the skip connection if necessary,
# and applies normalization first and then convolution.
# This architecture seemed like a standard architecture for unconditional or
# class-conditional GAN architecture using residual block.
# The code was inspired from https://github.com/LMescheder/GAN_stability.
class SPADEResnetBlock(nn.Module):
    def __init__(self, fin, fout):#, opt):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        # if 'spectral' in opt.norm_G:
        #     self.conv_0 = spectral_norm(self.conv_0)
        #     self.conv_1 = spectral_norm(self.conv_1)
        #     if self.learned_shortcut:
        #         self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
#         spade_config_str = opt.norm_G.replace('spectral', '')
        spade_config_str = "spadebatch3x3"
        self.norm_0 = SPADE(spade_config_str, fin, 3) #opt.semantic_nc)
        self.norm_1 = SPADE(spade_config_str, fmiddle, 3) #opt.semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, 3) #opt.semantic_nc)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))

        out = x_s + dx

        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)

In [10]:
class GeneratorReal(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.resblk1 = ResBlk()
        self.resblk2 = ResBlk()

        self.spade1 = SPADEResnetBlock(3, 3)
        self.spade2 = SPADEResnetBlock(3, 3)
        self.spade3 = SPADEResnetBlock(3, 3)
        self.spade4 = SPADEResnetBlock(3, 3)
        self.spade5 = SPADEResnetBlock(3, 3)
        self.upsample = nn.Upsample(scale_factor=2)
    
    def forward(self, x: Tensor, seg: Tensor) -> Tensor:
        out = self.resblk1(x)
        out = self.resblk2(out)
        
        out = self.spade1(out, seg)
        out = self.upsample(out)
        
        out = self.spade2(out, seg)
        out = self.upsample(out)
        
        out = self.spade3(out, seg)
        out = self.upsample(out)
        
        out = self.spade4(out, seg)
        out = self.upsample(out)
        
        out = self.spade5(out, seg)
        out = self.upsample(out)
        
        return torch.tanh(out)

# G_r = GeneratorReal().to(device)

In [11]:
class EncoderMaterial(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.layers = nn.Sequential(
            DownsampleBlock(),
            nn.AvgPool2d(),
        )
    
    def forward(self, x: Tensor, mask: Tensor):
        out = self.layers(x)
#         print(out.shape)
#         print(mask.shape)
#         print(Tensor([0]).shape)

        return torch.where(mask, out, Tensor([0]))

# E_m = EncoderMaterial().to(device)

In [12]:
# emb = torch.rand((1, 3, 16, 16))
# x = torch.rand((1, 3, 512, 512))

# mask = torch.rand((1, 1, 512, 512)) > 0.5
# mask

# G_r(emb, E_m(x)).shape
# E_m(x, mask)
# x = torch.rand((1, 3, 1, 1))
# x = x.expand(-1, -1, 4, 4)
# x

In [13]:
# VGG architecter, used for the perceptual loss using a pretrained VGG network
class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out

# Perceptual loss that uses a pretrained VGG network
class VGGLoss(nn.Module):
    transforms = transforms.Compose([
        transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5]),
        transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1, 1, 1]),
        transforms.Resize((224, 224)), # better without centercrop?
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().to(device)
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_resized = self.transforms(x)
        y_resized = self.transforms(y)
        
        x_vgg, y_vgg = self.vgg(x_resized), self.vgg(y_resized)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss

In [14]:
# https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/stargan/models.py

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.layers = nn.Sequential(
            *discriminator_block(3, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1, bias=False),
        )
        self.layers2 = nn.Sequential(
            nn.Flatten(),
            #nn.Linear(1024, 1), # 512x512
            nn.Linear(256, 1),  # 256x256
        )

    def forward(self, img: Tensor) -> Tensor:
        # combined = torch.cat((img_a, img_b), dim=1)
        # out = img
        # # print("Discriminator start", out.shape)
        #
        # for layer in self.layers:
        #     out = layer(out)
        #     # print(layer, out.shape)

        out = self.layers(img)
        out = self.layers2(out)
        
        return out

In [15]:
# class EmbeddingDiscriminator(nn.Module):
#     def __init__(self):
#         super().__init__()
#
#         self.layers = nn.Sequential(
#             nn.Flatten(),
#             nn.Linear(3 * 16 * 16),
#             nn.Linear()
#         )

In [16]:
class Options:
    img_height = 256
    img_width = 256
    learning_rate = 0.0001  # from paper
    starting_epoch = 0
    total_epochs = 300
    batch_size = 8
    cpu_count = 2


opt = Options()

In [17]:
class ImageDataset(Dataset):
    transform = transforms.Compose([
        transforms.Resize((opt.img_height, opt.img_width), transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def __init__(self, mode="train"):
        self.files = []

        for i in range(70):
            for j in range(1000):
                img_dir = f"{i:02}000"
                img_name = f"{i:02}{j:03}"
                if os.path.exists(f"{FFHQ_DIR}/{img_dir}/{img_name}.jpg"):
                    if os.path.exists(f"{RENDERED_ROOT}/{img_dir}_{img_name}_base.png"):
                        if os.path.exists(f"{RENDERED_ROOT}/{img_dir}_{img_name}_structure.png"):
                            self.files.append((img_dir, img_name))

        print(f"Read {len(self.files)} images")
        # self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))
        # if mode == "train":
        #     self.files.extend(sorted(glob.glob(os.path.join(root, "test") + "/*.*")))

    def __getitem__(self, index):
        img_dir, img_name = self.files[index % len(self.files)]
        
        photo = Image.open(f"{FFHQ_DIR}/{img_dir}/{img_name}.jpg")
        photo.putalpha(255)

        render_base = Image.open(f"{RENDERED_ROOT}/{img_dir}_{img_name}_base.png")
        render_struct = Image.open(f"{RENDERED_ROOT}/{img_dir}_{img_name}_structure.png")

        render_base = Image.alpha_composite(photo, render_base).convert("RGB")
        render_struct = Image.alpha_composite(photo, render_struct).convert("RGB")
        
        photo_structure = Image.open(f"{STRUCTURE_ROOT}/{img_dir}/{img_name}.png")
        
        mask_structure = Image.open(f"{MASK_ROOT}/{img_dir}/{img_name}.jpg")
        mask_structure = mask_structure.resize((opt.img_height, opt.img_width))
        mask_structure = np.array(mask_structure) > 128

        return {
            "photo_base": self.transform(photo.convert("RGB")),
            "photo_struct": self.transform(photo_structure),
            "mask": torch.from_numpy(mask_structure).permute((2, 0, 1)),
            "render_base": self.transform(render_base),
            "render_struct": self.transform(render_struct),
        }

    def __len__(self):
        return len(self.files)

In [18]:
from torch.utils.data import DataLoader

dataloader = DataLoader(
    ImageDataset(),
    batch_size=opt.batch_size,
    num_workers=opt.cpu_count,
    shuffle=True,
)

Read 10845 images


In [19]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

In [20]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    print(classname)
    if hasattr(m, 'weight') and m.weight is not None and classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
    elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, 'bias') and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)

In [22]:
def gan_loss(generated_scores: torch.Tensor,
             real_scores: torch.Tensor,
             for_generator: bool) -> torch.Tensor:
    if for_generator:
        return F.logsigmoid(-generated_scores).mean()
    
    return -F.logsigmoid(real_scores).mean() - F.logsigmoid(-generated_scores).mean()

In [23]:
def discriminator_accuracy(generated_scores: torch.Tensor,
                           real_scores: torch.Tensor) -> float:
    return torch.cat((generated_scores < 0, real_scores > 0)).float().mean()

In [24]:
def to_imgs(x: Tensor):
    # normalize: x -> (x - mu) / sigma
    # denormalize: x -> x * sigma + mu
    img = x.cpu().detach().permute((0, 2, 3, 1)).numpy()
    img = (img + 1) * 128
    return img.astype("uint8")

In [25]:
def image_grid_to_writer(writer, x, label, iteration):
    grid = torchvision.utils.make_grid(x, normalize=True)
    writer.add_image(label, grid, iteration)

In [26]:
class MySuperNetwork(nn.Module):
    lambda_vgg = 10
    
    def __init__(self):
        super().__init__()
        self.E_f =  nn.Sequential(
            DownsampleBlock(),
            *(ResBlk(48) for _ in range(4)),
        )
        self.E_shared = nn.Sequential(
            ResBlk(48),
            ResBlk(48),
        )
        self.G_f = GeneratorFake()
        self.D_f = Discriminator()
        
        self.vgg_loss = VGGLoss()

    def create_optimizers(self, opt):
        gan_params = list(self.E_f.parameters())
        gan_params += list(self.E_shared.parameters())
        gan_params += list(self.G_f.parameters())

        disc_params = list(self.D_f.parameters())
        
        gan_optimizer = torch.optim.Adam(gan_params, lr=opt.learning_rate)
        disc_optimizer = torch.optim.Adam(disc_params, lr=opt.learning_rate)

        return gan_optimizer, disc_optimizer

    def init_weights(self):
        self.E_f.apply(weights_init_normal)
        self.E_shared.apply(weights_init_normal)
        self.G_f.apply(weights_init_normal)
        self.D_f.apply(weights_init_normal)
    
    def _fake_branch(self, input_images):
        embedding = self.E_shared(self.E_f(input_images))
        generated = self.G_f(embedding)
        return generated
    
    def fake_branch_generator(self, input_images, ground_truth):
        generated = self._fake_branch(input_images)
        generated_scores = self.D_f(generated)

        losses = dict()
        losses["branch_F_generator_loss"] = gan_loss(generated_scores, None, for_generator=True)
        losses["VGG_F_loss"] = self.vgg_loss(generated, ground_truth) * self.lambda_vgg

        return generated, losses

    def fake_branch_discriminator(self, input_images, ground_truth):
        with torch.no_grad():
            generated = self._fake_branch(input_images)
            generated = generated.detach()
            generated.requires_grad_() # eh?

        generated_scores = self.D_f(generated)
        truth_scores = self.D_f(ground_truth)

        losses = {
            "D_F_loss": gan_loss(generated_scores, truth_scores, for_generator=False),
            "D_F_accuracy": discriminator_accuracy(generated_scores.detach(), truth_scores.detach()), # eh? detach?
        }
        return generated, losses

        
network = MySuperNetwork().to(device)

In [27]:
if opt.starting_epoch == 0:
    network.init_weights()

Conv2d
ReLU
Conv2d
ReLU
Conv2d
ReLU
Conv2d
ReLU
Conv2d
ReLU
Sequential
DownsampleBlock
Conv2d
BatchNorm2d
Conv2d
BatchNorm2d
ReLU
ResBlk
Conv2d
BatchNorm2d
Conv2d
BatchNorm2d
ReLU
ResBlk
Conv2d
BatchNorm2d
Conv2d
BatchNorm2d
ReLU
ResBlk
Conv2d
BatchNorm2d
Conv2d
BatchNorm2d
ReLU
ResBlk
Sequential
Conv2d
BatchNorm2d
Conv2d
BatchNorm2d
ReLU
ResBlk
Conv2d
BatchNorm2d
Conv2d
BatchNorm2d
ReLU
ResBlk
Sequential
Conv2d
BatchNorm2d
Conv2d
BatchNorm2d
ReLU
ResBlk
Conv2d
BatchNorm2d
Conv2d
BatchNorm2d
ReLU
ResBlk
BatchNorm2d
ConvTranspose2d
LeakyReLU
BatchNorm2d
ConvTranspose2d
LeakyReLU
BatchNorm2d
ConvTranspose2d
LeakyReLU
BatchNorm2d
ConvTranspose2d
LeakyReLU
BatchNorm2d
ConvTranspose2d
LeakyReLU
Sequential
Conv2d
GeneratorFake
Conv2d
LeakyReLU
Conv2d
InstanceNorm2d
LeakyReLU
Conv2d
InstanceNorm2d
LeakyReLU
Conv2d
InstanceNorm2d
LeakyReLU
ZeroPad2d
Conv2d
Sequential
Flatten
Linear
Sequential
Discriminator


In [28]:
gan_optimizer, disc_optimizer = network.create_optimizers(opt)
log_iter = 0

In [29]:
discriminator_freq = 16

In [30]:
# https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# https://github.com/hse-ds/iad-applied-ds/blob/master/2022/seminars/sem03/sem03-gan-task.ipynb
# https://colab.research.google.com/github/yandexdataschool/mlhep2019/blob/master/notebooks/day-6/06_GAN_faces_solution.ipynb#scrollTo=-Z7YAPGxp2Xl

losses_f = {}
losses_d_f = {}

for epoch in range(opt.starting_epoch, opt.total_epochs):
    with tqdm.auto.tqdm(total=len(dataloader), unit="batch", desc=f"Epoch {epoch}") as pbar:
        for i, batch in enumerate(dataloader):
            render_base = batch["render_base"].to(device)
            render_struct = batch["render_struct"].to(device)
            
            if i % discriminator_freq != 0:
                gan_optimizer.zero_grad()
                generated_f, losses_f = network.fake_branch_generator(render_struct, render_base)
                g_f_loss = sum(losses_f.values())
                g_f_loss.backward()
                gan_optimizer.step()
            else:
                disc_optimizer.zero_grad()
                generated_f, losses_d_f = network.fake_branch_discriminator(render_struct, render_base)
                losses_d_f["D_F_loss"].backward()
                disc_optimizer.step()
                
                if losses_d_f["D_F_accuracy"].item() > 0.9 and discriminator_freq < 32:
                    discriminator_freq *= 2
                elif losses_d_f["D_F_accuracy"].item() < 0.8 and discriminator_freq > 2:
                    discriminator_freq /= 2

            for d in [losses_f, losses_d_f]:
                for key, value in d.items():
                    writer.add_scalar(key, value.item(), log_iter)

            if i % 20 == 0:
                image_grid_to_writer(writer, generated_f, "generated fake", log_iter)
                image_grid_to_writer(writer, render_base, "true fake", log_iter)

            pbar.set_postfix(discriminator_freq=discriminator_freq)
            pbar.update(1)
            log_iter += 1

Epoch 0:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 1:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 2:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 3:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 4:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 5:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 6:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 7:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 8:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 9:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 10:   0%|          | 0/1356 [00:00<?, ?batch/s]

Epoch 11:   0%|          | 0/1356 [00:00<?, ?batch/s]

KeyboardInterrupt: 