In [1]:
# Re-import edited local modules before each cell executes, so changes to the
# project's .py files (dataloaders, models, loss, utils) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import random
from argparse import ArgumentParser

import numpy as np
import torch
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import DataLoader

from dataloaders.images_dataset import ImagesDataset
from dataloaders.mask_generator import MaskGenerator
from loss.loss_compute import LossCompute
from models.pconv_unet import PConvUNet
from models.vgg16_extractor import VGG16Extractor
from utils.preprocessing import Preprocessor

In [3]:
# Dataset locations, relative to the notebook's working directory.
# NOTE(review): MASK_DIR climbs three levels ("../../../") while TRAIN_DIR and
# VALID_DIR climb two ("../../"); confirm all three really point into the same
# "Repos/image-inpainting/dataset" tree.
TRAIN_DIR = "../../Repos/image-inpainting/dataset/train_0"
VALID_DIR = "../../Repos/image-inpainting/dataset/test"
MASK_DIR = "../../../Repos/image-inpainting/dataset/irregular_mask/irregular_mask/disocclusion_img_mask/"

# Spatial size handed to MaskGenerator and ImagesDataset.
HEIGHT, WIDTH = 256, 256
INVERT_MASK = False   # forwarded as MaskGenerator(invert_mask=...)
NUM_WORKERS = 0       # DataLoader worker processes (0 = load in the main process)
BS = 2                # batch size
LR = 0.0002           # Adam learning rate

# Seed every RNG the run may draw from (weight init, DataLoader shuffling and,
# presumably, random mask selection) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Weight of each loss term. The dict reaches LossCompute.loss_total through
# hparams.loss_factors. The trailing "alt" values are alternative weights kept
# from an earlier setup (presumably the original PConv paper) — confirm.
LOSS_FACTORS = dict(
    loss_hole=6.0,
    loss_valid=1.0,
    loss_perceptual=1.0,   # alt: 0.05
    loss_style_out=120.0,
    loss_style_comp=120.0,
    loss_tv=10.0,          # alt: 0.1
)

In [5]:
class HParams(object):
    """Plain attribute container for the hyper-parameters read by ImageInpaintingSystem.

    Every attribute defaults to the matching module-level configuration constant.
    Any of them can be overridden by keyword, e.g. ``HParams(batch_size=8)``.

    Raises:
        TypeError: if an override names an attribute that HParams does not define.
    """

    def __init__(self, **overrides):
        # Data locations.
        self.train_dir = TRAIN_DIR
        self.valid_dir = VALID_DIR
        self.mask_dir = MASK_DIR

        # Image / mask geometry.
        self.height = HEIGHT
        self.width = WIDTH
        self.invert_mask = INVERT_MASK

        # Data loading.
        self.num_workers = NUM_WORKERS
        self.batch_size = BS

        # Optimisation.
        self.learning_rate = LR
        # Copy the dict: editing one instance's loss weights must not mutate the
        # shared module-level LOSS_FACTORS (or any other HParams instance).
        self.loss_factors = dict(LOSS_FACTORS)

        for name, value in overrides.items():
            # Reject typos instead of silently creating an unused attribute.
            if not hasattr(self, name):
                raise TypeError("HParams got an unexpected hyper-parameter: {!r}".format(name))
            setattr(self, name, value)

hparams = HParams()

In [6]:
class ImageInpaintingSystem(pl.LightningModule):
    """Lightning system that trains a PConvUNet image-inpainting generator.

    A frozen VGG16Extractor is handed to LossCompute, which builds the loss from
    the weights in ``hparams.loss_factors``.

    Args:
        hparams: hyper-parameter container (see HParams). Provides the data
            directories, height/width, invert_mask, batch_size, num_workers,
            learning_rate and loss_factors.
        device: device string passed to the VGG extractor, LossCompute and
            Preprocessor. Defaults to "cuda", the previously hard-coded value.
    """

    # Per-term losses averaged and logged by validation_end. They must be keys of
    # the per-term dict returned by the LossCompute loss function (validation_step
    # merges that dict into its output).
    _LOSS_KEYS = (
        "loss_hole",
        "loss_valid",
        "loss_perceptual",
        "loss_style_out",
        "loss_style_comp",
        "loss_tv",
    )

    def __init__(self, hparams, device="cuda"):
        super().__init__()
        self.hparams = hparams
        self.pConvUNet = PConvUNet()

        # The VGG network is only used to compute the loss. Freeze it so no
        # gradients are computed for its weights.
        self.vgg16extractor = VGG16Extractor().to(device)
        for param in self.vgg16extractor.parameters():
            param.requires_grad = False
        self.lossCompute = LossCompute(self.vgg16extractor, device=device)

        self.preprocess = Preprocessor(device)

    def forward(self, masked_img_tensor, mask_tensor):
        """Run the PConvUNet on a (normalized) masked image and its mask."""
        return self.pConvUNet(masked_img_tensor, mask_tensor)

    def _shared_step(self, batch):
        """Preprocess a ``(masked_img, mask, image)`` batch, run the generator and score it.

        Returns:
            tuple ``(img_tensor, masked_img_tensor, output, loss, dict_losses)``.
            ``loss`` is the total loss and ``dict_losses`` the per-term breakdown,
            both as returned by the callable built by ``LossCompute.loss_total``.
        """
        masked_img, mask, image = batch

        img_tensor = self.preprocess.normalize(image.type(torch.float))
        # NOTE(review): transpose(1, 3) swaps axes 1 and 3 (for NHWC input this gives
        # NCWH, i.e. H and W are exchanged too). Confirm Preprocessor.normalize uses
        # the same layout for the images; permute(0, 3, 1, 2) is the NHWC -> NCHW form.
        mask_tensor = mask.type(torch.float).transpose(1, 3)
        masked_img_tensor = self.preprocess.normalize(masked_img.type(torch.float))

        ls_fn = self.lossCompute.loss_total(mask_tensor, self.hparams.loss_factors)
        output = self.forward(masked_img_tensor, mask_tensor)
        loss, dict_losses = ls_fn(img_tensor, output)
        return img_tensor, masked_img_tensor, output, loss, dict_losses

    def training_step(self, batch, batch_nb):
        """Compute the training loss for one batch and log its per-term breakdown."""
        _, _, _, loss, dict_losses = self._shared_step(batch)

        dict_losses_train = {key: value.item() for key, value in dict_losses.items()}
        self.logger.experiment.add_scalars('loss/train', dict_losses_train, self.global_step)
        # Log and display graph-free values; only 'loss' needs the autograd graph.
        self.logger.experiment.add_scalars('loss/overview', {'train_loss': loss.item()}, self.global_step)

        return {'loss': loss, 'progress_bar': {'train_loss': loss.detach()}}

    def validation_step(self, batch, batch_nb):
        """Score one validation batch. On the first batch, also log a preview grid.

        Returns:
            dict with 'val_loss', 'psnr' and every per-term loss from the loss function.
        """
        img_tensor, masked_img_tensor, output, loss, dict_losses = self._shared_step(batch)
        psnr = self.lossCompute.PSNR(img_tensor, output)

        if batch_nb == 0:
            self._log_preview(masked_img_tensor, output, img_tensor)

        return {'val_loss': loss.mean(), 'psnr': psnr.mean(), **dict_losses}

    def _log_preview(self, masked_img_tensor, output, img_tensor):
        """Log an HWC image with one row per sample: masked input | prediction | ground truth."""
        def to_display(tensor):
            # Undo the normalization and clamp to [0, 1] for TensorBoard.
            return np.clip(self.preprocess.unnormalize(tensor).detach().cpu().numpy(), 0, 1)

        masked = to_display(masked_img_tensor)
        predicted = to_display(output)
        # The ground truth goes through the same unnormalize/clip path as the other
        # two columns, so all three panels share one value range and dtype.
        target = to_display(img_tensor)

        rows = [
            np.concatenate((masked[i], predicted[i], target[i]), axis=1)
            for i in range(len(predicted))
        ]
        grid = np.concatenate(rows)
        # Pass global_step so each preview is recorded at the training step it
        # was taken at, rather than all of them being written without a step.
        self.logger.experiment.add_image('images', grid, self.global_step, dataformats='HWC')

    def validation_end(self, outputs):
        """Average the validation metrics over all batches and log them to TensorBoard.

        Returns:
            ``{'progress_bar': {...}}`` holding the mean PSNR and loss. The inner dict
            is empty if there were no validation batches.
        """
        if not outputs:
            # torch.stack([]) raises, so bail out on an empty validation run.
            return {'progress_bar': {}}

        def mean_of(key):
            return torch.stack([x[key] for x in outputs]).mean()

        avg_loss = mean_of('val_loss')
        avg_psnr = mean_of('psnr')
        valid_dict = {key: mean_of(key) for key in self._LOSS_KEYS}

        self.logger.experiment.add_scalars('loss/valid', valid_dict, self.global_step)
        self.logger.experiment.add_scalars('loss/overview', {'valid_loss': avg_loss}, self.global_step)

        tqdm_dict = {'valid_psnr': avg_psnr, 'valid_loss': avg_loss}
        return {'progress_bar': tqdm_dict}

    def configure_optimizers(self):
        """Adam over all module parameters at ``hparams.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def _build_dataloader(self, data_dir, shuffle):
        """Build a DataLoader over an ImagesDataset for ``data_dir`` with a fresh MaskGenerator."""
        hp = self.hparams
        mask_generator = MaskGenerator(hp.mask_dir, hp.height, hp.width, invert_mask=hp.invert_mask)
        dataset = ImagesDataset(data_dir, hp.height, hp.width, mask_generator)
        return DataLoader(dataset, batch_size=hp.batch_size, shuffle=shuffle, num_workers=hp.num_workers)

    @pl.data_loader
    def train_dataloader(self):
        """Shuffled loader over ``hparams.train_dir``."""
        return self._build_dataloader(self.hparams.train_dir, shuffle=True)

    @pl.data_loader
    def val_dataloader(self):
        """Unshuffled loader over ``hparams.valid_dir``."""
        return self._build_dataloader(self.hparams.valid_dir, shuffle=False)

In [7]:
from pytorch_lightning import Trainer

In [8]:
# Build the Lightning system from the hyper-parameter container.
model = ImageInpaintingSystem(hparams=hparams)

In [9]:
# ?Trainer

In [None]:
# Partial training run: each epoch uses 10% of the training batches and
# validates every 5% of that epoch.
trainer_kwargs = dict(
    gpus=1,                          # train on a single GPU
    train_percent_check=0.1,         # fraction of training batches used per epoch
    val_check_interval=0.05,         # validate after every 5% of the training epoch
    use_amp=False,                   # no mixed precision
    default_save_path='test_logs2',  # root directory for logs / checkpoints
)
trainer = Trainer(**trainer_kwargs)

trainer.fit(model)

gpu available: True, used: True
VISIBLE GPUS: 0
55116 masks found: ../../../Repos/image-inpainting/dataset/irregular_mask/irregular_mask/disocclusion_img_mask/
55116 masks found: ../../../Repos/image-inpainting/dataset/irregular_mask/irregular_mask/disocclusion_img_mask/
                               Name           Type Params
0                         pConvUNet      PConvUNet   32 M
1                pConvUNet.encoder1   PConvEncoder    9 K
2          pConvUNet.encoder1.pconv  PartialConv2d    9 K
3      pConvUNet.encoder1.batchnorm    BatchNorm2d  128  
4     pConvUNet.encoder1.activation           ReLU    0  
..                              ...            ...    ...
99   vgg16extractor.max_pooling3.12         Conv2d  590 K
100  vgg16extractor.max_pooling3.13           ReLU    0  
101  vgg16extractor.max_pooling3.14         Conv2d  590 K
102  vgg16extractor.max_pooling3.15           ReLU    0  
103  vgg16extractor.max_pooling3.16      MaxPool2d    0  

[104 rows x 3 columns]


  dilated_mask = torch.tensor(dilated_mask> 0, dtype=torch.float, requires_grad=False).to(self.device)
  8%|▊         | 657/7828 [06:32<1:12:42,  1.64it/s, batch_nb=655, epoch=0, gpu=0, loss=10.194, train_loss=13.8, v_nb=0, valid_loss=12.3, valid_psnr=7.64]

In [None]:
# trainer = Trainer(gpus=1, fast_dev_run=True)
# trainer.fit(model)