# In [None]:
import torch
from networks import networks
from utils import masking, show, renormalize, compositions, imutil
import matplotlib.pyplot as plt
import numpy as np
import cv2
from copy import deepcopy

def tensor2image(im_tensor):
    '''
    Convert the first image of a batch into an [h, w, C] numpy array for display.

    im_tensor - shape: [N, C, h, w] (only index 0 is used; intended for N = 1);
                expected value range: [-1, 1], which is mapped linearly to [0, 1].
                Values outside [-1, 1] are mapped by the same formula and are
                NOT clipped.

    Returns a new numpy array; the input tensor is never modified.
    '''
    # detach() drops autograd history; cpu() is a no-op for CPU tensors.
    im_array = im_tensor[0].detach().cpu().numpy()
    im_array = np.transpose(im_array, (1, 2, 0))
    # For CPU tensors .numpy() shares memory with the tensor, so the mapping is
    # computed out of place: in-place updates would corrupt the caller's tensor
    # and would fail outright for integer dtypes.
    return (im_array + 1) / 2

def show_image(im):
    '''
    Draw ``im`` with matplotlib and display the figure.

    im - an array accepted by plt.imshow (e.g. an [h, w, 3] RGB image).
    '''
    plt.imshow(im)  # draws into the current figure/axes
    plt.show()
    
def _read_rgb(path):
    '''
    Read an image file with OpenCV and return it as an RGB array.

    Raises FileNotFoundError when the file cannot be read, because cv2.imread
    signals failure by returning None instead of raising.
    '''
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        raise FileNotFoundError('could not read image: {}'.format(path))
    # OpenCV loads channels in BGR order; the rest of this script works in RGB
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)


def _rgb_to_model_tensor(img_rgb, device):
    '''
    Convert an 8-bit RGB image [h, w, 3] with values in [0, 255] into a float32
    tensor of shape [1, 3, h, w] with values in [-1, 1], placed on ``device``.
    '''
    # Fixed affine map [0, 255] -> [-1, 1] (the inverse of tensor2image's
    # [-1, 1] -> [0, 1] up to the 255 scale). A data-dependent factor such as
    # max(img) / 2 mis-scales images that do not reach 255, overshoots 1 for
    # images that do (int(127.5) == 127), and divides by zero for black images.
    normalized = img_rgb.astype(np.float32) / 127.5 - 1.0
    chw = np.ascontiguousarray(np.transpose(normalized, (2, 0, 1)))
    return torch.from_numpy(chw).unsqueeze(0).to(device)


if __name__ == '__main__':
    # Check for CUDA before reading data or building the networks so the
    # script fails fast when it cannot run.
    if not torch.cuda.is_available():
        raise RuntimeError('cuda is not available!')
    device = torch.device('cuda')

    # Read the fake (stitched) image
    fake_img_path = '/home/uss00067/Datasets/FDC/video_1/angry/level_1/024/000/3/overlaid.jpg'
    # Read the gt image
    real_img_path = '/home/uss00067/Datasets/FDC/video_1/angry/level_1/024/000/3/015.jpg'
    fake_img_rgb = _read_rgb(fake_img_path)
    real_img_rgb = _read_rgb(real_img_path)

    print('fake img shape: ', fake_img_rgb.shape)
    print('real img shape: ', real_img_rgb.shape)

    gtype = 'stylegan'
    domain = 'ffhq'
    # proggan: celebahq, livingroom, church
    # stylegan: ffhq, church, car, horse
    nets = networks.define_nets(gtype, domain)

    compositer = compositions.get_compositer(domain)(nets)

    # Fixed seed so the same sample indices are drawn on every run
    rng = np.random.RandomState(0)
    indices = rng.choice(compositer.total_samples, len(compositer.ordered_labels))

    with torch.no_grad():
        # Composition over original data (with batch_size = 1)
        composite_data = compositer(indices)
        input_data = composite_data.composite_image
        output_data = composite_data.inverted_RGBM

        show_image(tensor2image(input_data))
        show_image(tensor2image(output_data))

        # Original mask value
        # 0.5: missing pixel region
        # 1: contents extraction region
        composite_mask = composite_data.composite_mask[0].unsqueeze(0)
        show_image(tensor2image(composite_mask))

        # [0, 255] RGB image -> [1, 3, h, w] tensor in [-1, 1] on the GPU
        fake_img_rgb_tensor = _rgb_to_model_tensor(fake_img_rgb, device)

        # Mask with all values being 1 (no 0.5: no missing pixels), shape [1, 1, h, w]
        height, width = fake_img_rgb.shape[:2]
        fake_img_mask_tensor = torch.ones((1, 1, height, width), device=device)

        # Reconstruct the image based on the fake image tensor and fake mask tensor
        # (batch size of rec_data: 1).
        # NOTE(review): the image is fed at its on-disk resolution; confirm the
        # inverter accepts sizes other than the generator's output size.
        rec_data = compositer.nets_RGBM.invert(fake_img_rgb_tensor, fake_img_mask_tensor)
        rec_data_image = tensor2image(rec_data)
        show_image(fake_img_rgb)  # Show the stitched image
        show_image(rec_data_image)  # Show the reconstructed image
        show_image(real_img_rgb)  # Show the gt image