# Install

In [None]:
# Clone the project repository (the next cell changes the working directory into it).
!git clone https://github.com/Tismoney/gan-compression.git

In [None]:
import os
# Run the rest of the notebook from inside the cloned repository: later cells
# use paths relative to its root (requirements.txt, scripts/, opts/, ...).
os.chdir('gan-compression')
print(os.path.abspath(os.curdir))  # show the new working directory

In [None]:
# Install the repository's Python dependencies, then the latest torchprofile
# straight from its GitHub repository.
!pip install -r requirements.txt
!pip install --upgrade git+https://github.com/mit-han-lab/torchprofile.git

In [None]:
import pickle
import time
import tqdm

import numpy as np
import torch

# Pretrained Models
Download the original model and our compressed model for the fomm dataset.

In [None]:
!python3 scripts/download_fomm.py 
print('Download the pretrained models successfully!!!')

# Models

In [None]:
from models import create_model

Generate the options used to load the models.

In [None]:
# Generate test options for the full generator (ngf=96, weights from
# pretrained/fomm/full/latest_net_G.pth). input_nc is 6 because animate()
# stacks the source image and a driving frame along the channel axis.
# The following Python lines read options back from opts/opt_full.pkl —
# presumably the file this script writes; confirm in get_test_opt.py.
!python get_test_opt.py --dataroot database/fomm_100k \
--results_dir results-pretrained/2pix2pix/fomm \
--ngf 96 --netG mobile_resnet_9blocks \
--restore_G_path pretrained/fomm/full/latest_net_G.pth \
--real_stat_path real_stat/fomm_B.npz \
--input_nc 6 --output_nc 3 \
--use_motion \
--need_profile --num_test 0

# Read back the pickled test options (presumably produced by the
# get_test_opt.py call above) and build the full-size model from them.
# The pickle is a locally generated file, not untrusted input.
with open('opts/opt_full.pkl', 'rb') as opt_file:
    opt = pickle.load(opt_file)

full_model = create_model(opt, verbose=False)
full_model.setup(opt, verbose=False)

In [None]:
# Generate test options for the distilled (compressed) generator: same setup
# as the full model but ngf=48 and the latest_net_G_distilled.pth weights.
# NOTE(review): the distilled checkpoint is read from the `full/` directory —
# confirm that is where scripts/download_fomm.py places it.
!python get_test_opt.py --dataroot database/fomm_100k \
--results_dir results-pretrained/2pix2pix/fomm \
--ngf 48 --netG mobile_resnet_9blocks \
--restore_G_path pretrained/fomm/full/latest_net_G_distilled.pth \
--real_stat_path real_stat/fomm_B.npz \
--input_nc 6 --output_nc 3 \
--use_motion \
--need_profile --num_test 0

# Read back the pickled options for the distilled generator and build it.
# NOTE(review): this reads the same file as the full-model cell; it relies on
# get_test_opt.py having just overwritten opts/opt_full.pkl with the ngf=48
# options — confirm in get_test_opt.py.
with open('opts/opt_full.pkl', 'rb') as opt_file:
    opt = pickle.load(opt_file)

distill_model = create_model(opt, verbose=False)
distill_model.setup(opt, verbose=False)

In [None]:
from IPython.display import display
from PIL import Image
import torchvision.transforms as transforms

from utils.util import save_image, tensor2im

import imageio
from skimage.transform import resize
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import tqdm

In [None]:
# Preprocessing shared by the source image and every driving frame:
# resize to 256x256, convert to a tensor in [0, 1], then normalize each of the
# three channels with mean 0.5 / std 0.5, i.e. map [0, 1] to [-1, 1].
transform_list = [transforms.Resize((256, 256)), transforms.ToTensor()]
transform_list.append(transforms.Normalize((0.5,) * 3, (0.5,) * 3))
transform = transforms.Compose(transform_list)

Choose the source image and the driving video.

In [None]:
# Load the appearance source image and the driving video frames as RGB PIL images.
# convert('RGB') discards an alpha channel (as the previous `[..., :3]` slice
# did). It also expands grayscale and palette ('P') images, which are 2-D arrays:
# for those, slicing with `[..., :3]` would have kept the first 3 pixel
# columns instead of 3 channels.
source_image = Image.open('./content/imgs/got-03.png').convert('RGB')
driving_video = [
    Image.fromarray(np.asarray(frame)).convert('RGB')
    for frame in imageio.mimread('./content/vids/10.mp4')
]

In [None]:
def _module_device(model):
    """Return the device of *model*'s first parameter, or the CPU if it has none."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cpu')


@torch.no_grad()
def animate(source_image, driving_video, full_model, compressed):
    """Run the full and compressed generators on every driving frame.

    For each driving frame, the transformed source image and the transformed
    frame are concatenated along the channel axis and fed to both generators.
    Each result entry is the horizontal concatenation (``np.hstack``) of four
    panels: source | driving frame | full-model output | compressed output.

    Args:
        source_image: PIL image providing the appearance.
        driving_video: sequence of PIL images (the driving frames).
        full_model: the original generator network (e.g. ``full_model.netG``).
        compressed: the compressed/distilled generator network.

    Returns:
        A list with one stacked numpy image per driving frame.
    """
    source = transform(source_image)
    source_panel = tensor2im(source)  # converted once, reused for every frame

    # Inputs are built on the CPU; move them to each generator's device so
    # GPU-resident models do not fail with a device-mismatch error.
    # `.to()` is a no-op when the model already lives on the CPU.
    full_device = _module_device(full_model)
    compressed_device = _module_device(compressed)

    result = []
    for frame in tqdm(driving_video):
        drive = transform(frame)

        # Channel-wise stack (source channels + driving channels), plus a batch dim.
        stacked_input = torch.cat([source, drive], dim=0).unsqueeze(0)

        # squeeze(0) drops only the batch dimension, never another size-1 dim.
        output_full_model = full_model(stacked_input.to(full_device)).squeeze(0).cpu()
        output_compressed = compressed(stacked_input.to(compressed_device)).squeeze(0).cpu()

        result.append(np.hstack([
            source_panel,
            tensor2im(drive),
            tensor2im(output_full_model),
            tensor2im(output_compressed),
        ]))
    return result

In [None]:
# Compare the full generator against the distilled one on every driving frame.
result = animate(
    source_image,
    driving_video,
    full_model=full_model.netG,
    compressed=distill_model.netG,
)

In [None]:
#Save new video
writer = imageio.get_writer('test.mp4', fps=20)
for frame in result:
    writer.append_data(frame)
writer.close()

In [None]:
# Preview the first comparison strip (source | driving | full | compressed);
# as the cell's last expression, the PIL image is rendered by the notebook.
first_strip = result[0]
Image.fromarray(first_strip)