In [1]:
from dataset import ContentOrientedDataset
from loss import ContentOrientedLoss
from model import DummyModel
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from collections import namedtuple
import tqdm

def display_images_with_titles(images, titles):
    """Show a row of images side by side, each under its own title.

    Args:
        images: Sequence of tensors. Each one is permuted with ``(1, 2, 0)``
            before plotting, so a channel-first layout is expected.
        titles: Sequence of title strings; must provide at least one title
            per image (extra titles are ignored).

    Raises:
        ValueError: If fewer titles than images are supplied.
    """
    num_images = len(images)
    if num_images == 0:
        # Nothing to draw; matplotlib rejects a grid with zero columns.
        return
    if len(titles) < num_images:
        raise ValueError(
            f"Expected at least {num_images} titles, got {len(titles)}"
        )

    # squeeze=False always yields a 2-D array of Axes, so a single image
    # (where plt.subplots would otherwise return a bare Axes) works too.
    _, axs = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
    for ax, image, title in zip(axs[0], images, titles):
        # detach()/cpu() make the conversion safe for tensors that live on
        # an accelerator or that are tracked by autograd.
        ax.imshow(image.detach().cpu().permute(1, 2, 0).numpy())
        ax.set_title(title)
        ax.axis('off')
    plt.show()




The dataset loads the images and the raw face and structure masks from disk. If the face coordinates and structure masks are not available, it generates them automatically when it is initialised. By "raw masks" I mean the masks before the priority criteria mentioned in the paper are applied.

In [3]:
# Build the dataset and inspect its first sample: the image plus the two
# unprocessed masks. `bpp` is returned alongside them but is not used in this
# cell (presumably a bits-per-pixel value; confirm against the dataset class).
dataset = ContentOrientedDataset(root="dataset", crop_size=256, normalize=False)
(img, unproc_face_mask, unproc_structure_mask), bpp = dataset[0]
display_images_with_titles(
    images=[img, unproc_face_mask, unproc_structure_mask],
    titles=["image", "unprocessed_face_mask", "unprocessed_structure_mask"],
)

  plt.show()


The masks are processed to conform to the priority criteria on the fly, when the loss is computed; the texture mask is generated in the same step.

In [4]:
model = DummyModel()

# Loss hyper-parameters. `namedtuple(...)` returns a *class*; instantiate it
# instead of assigning attributes onto the class itself, which would overwrite
# the generated field descriptors and leave `args` as a type rather than a
# value.
# NOTE(review): assumes ContentOrientedLoss only reads these fields and never
# assigns to `args` (namedtuple instances are immutable) - confirm.
LossArgs = namedtuple(
    "LossArgs",
    ["alpha", "beta", "delta", "epsilon", "gamma", "normalize_input_image", "gan_loss_type"],
)
args = LossArgs(
    alpha=0.01,
    beta=1,
    delta=0.0005,
    epsilon=0.3,
    gamma=0.2,
    normalize_input_image=False,
    gan_loss_type="non_saturating",
)

loss = ContentOrientedLoss(args, discriminator=model.discriminator)

# Turn the raw masks of the first sample into the processed face / structure /
# texture masks and show them next to the image.
processed_face_mask, processed_structure_mask, processed_texture_mask = loss.process_masks(
    unproc_face_mask, unproc_structure_mask
)
display_images_with_titles(
    [img, processed_face_mask, processed_structure_mask, processed_texture_mask],
    ["image", "processed_face_mask", "processed_structure_mask", "processed_texture_mask"],
)

Setting up Perceptual loss...




Loading model from: /home/ge95jud/GLeaD_HiFiC/src/loss/perceptual_similarity/weights/v0.1/vgg.pth
...[net-lin [vgg]] initialized
...Done


  plt.show()


The training loop alternates between two phases on successive batches. In the first phase only the encoder and decoder are optimized; in the second phase only the discriminator is optimized.

In [5]:
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(dataset, batch_size=4)

# Move the model to the target device *before* building the optimizers, as
# recommended by the torch.optim documentation.
model = model.to(device)
optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=1e-4)
optimizer_decoder = torch.optim.Adam(model.decoder.parameters(), lr=1e-4)
optimizer_discriminator = torch.optim.Adam(model.discriminator.parameters(), lr=1e-4)

# The phase flips after every batch: "ED" updates encoder + decoder,
# "D" updates the discriminator.
phase = "ED"
for epoch in tqdm.tqdm(range(num_epochs)):
    for [orig_imgs, face_masks, structure_masks], bpp in dataloader:
        orig_imgs, face_masks, structure_masks = orig_imgs.to(device), face_masks.to(device), structure_masks.to(device)

        # Clear every gradient buffer before the backward pass. A backward
        # call may also deposit gradients on parameters that are not stepped
        # in the current phase (e.g. if the loss does not detach between the
        # generator and discriminator paths); without this reset they would
        # leak into the next phase's update.
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        optimizer_discriminator.zero_grad()

        recon_imgs = model(orig_imgs)
        compression_loss, D_loss = loss(orig_imgs, recon_imgs, face_masks, structure_masks)
        if phase == "ED":
            compression_loss.backward()
            optimizer_encoder.step()
            optimizer_decoder.step()
            phase = "D"
        else:
            D_loss.backward()
            optimizer_discriminator.step()
            phase = "ED"


100%|██████████| 10/10 [00:11<00:00,  1.14s/it]
