In [1]:
import os
# Switch the working directory so the relative 'ckpts_and_data/...' paths used in later cells resolve.
# NOTE(review): this is an absolute, machine-specific path (presumably the repository root) —
# adjust it before running on another machine.
os.chdir("/home/yhding/Repo/Imp_NeuAvatar/")
# The notebook needs roughly 20 GiB of GPU memory.
# The device index is exported here, before `import torch` below, so it is already in place
# when CUDA is initialised. '13' is specific to the original author's machine.
os.environ['CUDA_VISIBLE_DEVICES']='13'

import torch
import numpy as np

from nha.data.real import RealDataModule, tracking_results_2_data_batch
from nha.models.nha_optimizer import NHAOptimizer
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from nha.util.general import *
from nha.util.render import create_intrinsics_matrix

import matplotlib.pyplot as plt
from pathlib import Path


In [2]:
## Configuration for generating frames where the driver's exp/pose are transferred to the target

# trained avatar checkpoints (driver, target)
driving_ckpt, target_ckpt = (
    "ckpts_and_data/nha/person_0000.ckpt",
    "ckpts_and_data/nha/person_0004.ckpt",
)

# tracking results that belong to the two avatars (driver, target)
driving_tracking_results, target_tracking_results = (
    "ckpts_and_data/tracking/person_0000.npz",
    "ckpts_and_data/tracking/person_0004.npz",
)

# alignment: index of the neutral reference frame in each sequence
neutral_driving_frame = neutral_target_frame = 0

# directory the encoded videos are written to
outpath = "."




In [3]:
def _reset_directory(directory: Path) -> None:
    """Create ``directory`` if it is missing and delete everything currently inside it."""
    import shutil  # local import: only this helper needs it

    directory.mkdir(parents=True, exist_ok=True)
    for entry in directory.iterdir():
        # Symlinks are unlinked rather than followed, so nothing outside `directory` is touched.
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)
        else:
            entry.unlink()


@torch.no_grad()
def reenact_avatar(target_model: NHAOptimizer, driving_model: NHAOptimizer, target_tracking_results: dict,
                   driving_tracking_results: dict, outpath: Path,
                   neutral_driving_frame=0, neutral_target_frame=0,
                   batch_size=3, plot=False):
    """Render the driving sequence and its reenactment by the target avatar, frame by frame.

    For every frame id stored in ``driving_tracking_results["frame"]`` two images are written:

    * ``/tmp/scene_reenactment_drive/<frame id, 4 digits>.png`` -- first three channels of
      ``driving_model.forward(...)``, clamped to [0, 1];
    * ``/tmp/scene_reenactment_pred/<frame id, 4 digits>.png`` -- output of
      ``target_model.predict_reenaction(...)`` for the same batch, with the camera
      intrinsics/extrinsics replaced by those of the target's neutral frame.

    Both temporary directories are emptied first, so leftovers from an earlier run cannot end
    up in a later video. Encoding the images into videos is NOT done here.

    :param target_model: avatar model that gets reenacted
    :param driving_model: avatar model providing the driving sequence
    :param target_tracking_results: tracking results of the target, indexable by key
    :param driving_tracking_results: tracking results of the driver, indexable by key; must
        contain a numpy array under ``"frame"``
    :param outpath: output directory; it is only created here, nothing is written into it
    :param neutral_driving_frame: frame used to build the driver's base parameters
    :param neutral_target_frame: frame used to build the target's base parameters and cameras
    :param batch_size: number of frames rendered per forward pass
    :param plot: if True, show driving and predicted image side by side for every frame
    :return: None
    """
    base_drive_sample = dict_2_device(tracking_results_2_data_batch(driving_tracking_results, [neutral_driving_frame]),
                                      driving_model.device)
    base_target_sample = dict_2_device(tracking_results_2_data_batch(target_tracking_results, [neutral_target_frame]),
                                       target_model.device)

    base_drive_params = driving_model._create_flame_param_batch(base_drive_sample)
    base_target_params = target_model._create_flame_param_batch(base_target_sample)

    tmp_dir_pred = Path("/tmp/scene_reenactment_pred")
    tmp_dir_drive = Path("/tmp/scene_reenactment_drive")
    os.makedirs(outpath, exist_ok=True)
    # Pure-Python cleanup instead of shelling out to `rm -r <dir>/*`: portable, no shell
    # interpolation, and silent when the directory is already empty.
    _reset_directory(tmp_dir_drive)
    _reset_directory(tmp_dir_pred)

    frame_ids = torch.from_numpy(driving_tracking_results["frame"])
    for idcs in tqdm(torch.split(frame_ids, batch_size)):
        # NOTE(review): the batch lives on the target model's device but is also fed to the
        # driving model, so both models are presumably expected on the same device -- confirm.
        batch = dict_2_device(tracking_results_2_data_batch(driving_tracking_results, idcs.tolist()),
                              target_model.device)

        rgb_driving = driving_model.forward(batch, symmetric_rgb_range=False)[:, :3].clamp(0, 1)

        # change camera parameters: render the reenactment from the target's neutral-frame camera
        batch["cam_intrinsic"] = base_target_sample["cam_intrinsic"].expand_as(batch["cam_intrinsic"])
        batch["cam_extrinsic"] = base_target_sample["cam_extrinsic"].expand_as(batch["cam_extrinsic"])

        rgb_target = target_model.predict_reenaction(batch, driving_model=driving_model,
                                                     base_target_params=base_target_params,
                                                     base_driving_params=base_drive_params)

        for frame_idx, pred, drive in zip(batch["frame"], rgb_target, rgb_driving):
            img_name = f"{frame_idx.cpu().item():04d}.png"
            save_torch_img(pred, tmp_dir_pred / img_name)
            save_torch_img(drive, tmp_dir_drive / img_name)

            if plot:
                fig, axes = plt.subplots(ncols=2)
                axes[0].imshow(drive.cpu().permute(1, 2, 0))
                axes[1].imshow(pred.cpu().permute(1, 2, 0))
                plt.show()
                # Close this specific figure so figures do not pile up over hundreds of frames.
                plt.close(fig)

In [4]:
# Restore both avatars from their checkpoints (driver first, then target); each one is
# moved to the GPU with .cuda() and put into evaluation mode with .eval().
driving_model, target_model = (
    NHAOptimizer.load_from_checkpoint(ckpt_path).cuda().eval()
    for ckpt_path in (driving_ckpt, target_ckpt)
)


  self._edges_packed = torch.stack([u // V, u % V], dim=1)


In [6]:
# np.load on an .npz archive returns a lazily-reading NpzFile that keeps the file open.
# Using it as a context manager closes both archives as soon as the reenactment is done.
with np.load(target_tracking_results) as target_tracking, \
        np.load(driving_tracking_results) as driving_tracking:
    reenact_avatar(target_model, driving_model,
                   target_tracking_results=target_tracking,
                   driving_tracking_results=driving_tracking,
                   outpath=outpath,
                   neutral_driving_frame=neutral_driving_frame,
                   neutral_target_frame=neutral_target_frame)

100%|████████████████████████████████████████████████████████████████████████████████████████████████| 499/499 [05:35<00:00,  1.49it/s]
ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/home/yhding/miniconda3/envs/avatar --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  7.100 /  5.  7.100
  libswresample   3.  7.100 /  3.  7.100
Unrecogni

In [8]:
import subprocess

# Directories the per-frame PNGs were written to by reenact_avatar().
tmp_dir_pred = Path("/tmp/scene_reenactment_pred")
tmp_dir_drive = Path("/tmp/scene_reenactment_drive")


def _run_ffmpeg(*args):
    """Run ffmpeg with ``args`` (overwriting existing outputs) and raise if it fails.

    The arguments are handed over as a list, so no shell is involved: paths containing
    spaces or shell metacharacters are passed through unchanged, and the '*.png' pattern
    reaches ffmpeg literally, where `-pattern_type glob` expands it.

    :raises subprocess.CalledProcessError: if ffmpeg exits with a non-zero status
    """
    subprocess.run(["ffmpeg", "-y", *[str(arg) for arg in args]], check=True)


# H.264 encoder settings shared by both image-sequence encodes.
_x264_flags = ["-c:v", "libx264", "-profile:v", "high", "-level:v", "4.0", "-pix_fmt", "yuv420p"]

pred_video = f"{outpath}/Reenactment_pred.mp4"
drive_video = f"{outpath}/Reenactment_drive.mp4"
combined_video = f"{outpath}/Reenactment_combined.mp4"

# Reenacted (predicted) frames -> video.
_run_ffmpeg("-pattern_type", "glob", "-i", f"{tmp_dir_pred}/*.png",
            *_x264_flags, "-codec:a", "aac", pred_video)

# Driving frames -> video (this encode additionally sets -crf 22, as before).
_run_ffmpeg("-pattern_type", "glob", "-i", f"{tmp_dir_drive}/*.png",
            *_x264_flags, "-crf", "22", "-codec:a", "aac", drive_video)

# Driving video on the left, reenactment on the right.
_run_ffmpeg("-i", drive_video, "-i", pred_video,
            "-filter_complex", "hstack=inputs=2", combined_video)

ffmpeg version 4.2.2 Copyright (c) 2000-2019 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/home/yhding/miniconda3/envs/avatar --cc=/tmp/build/80754af9/ffmpeg_1587154242452/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --enable-avresample --enable-gmp --enable-hardcoded-tables --enable-libfreetype --enable-libvpx --enable-pthreads --enable-libopus --enable-postproc --enable-pic --enable-pthreads --enable-shared --enable-static --enable-version3 --enable-zlib --enable-libmp3lame --disable-nonfree --enable-gpl --enable-gnutls --disable-openssl --enable-libopenh264 --enable-libx264
  libavutil      56. 31.100 / 56. 31.100
  libavcodec     58. 54.100 / 58. 54.100
  libavformat    58. 29.100 / 58. 29.100
  libavdevice    58.  8.100 / 58.  8.100
  libavfilter     7. 57.100 /  7. 57.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  5.100 /  5.  5.100
  libswresample   3.  5.100 /  3.  5.100
  libpostproc  

frame= 1496 fps=365 q=-1.0 Lsize=    3375kB time=00:00:59.72 bitrate= 463.0kbits/s speed=14.6x    
video:3357kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: 0.547366%
[libx264 @ 0x55aa72d89a00] frame I:6     Avg QP:17.04  size: 20554
[libx264 @ 0x55aa72d89a00] frame P:377   Avg QP:23.51  size:  4844
[libx264 @ 0x55aa72d89a00] frame B:1113  Avg QP:29.39  size:  1336
[libx264 @ 0x55aa72d89a00] consecutive B-frames:  0.7%  0.1%  0.2% 98.9%
[libx264 @ 0x55aa72d89a00] mb I  I16..4: 19.7% 69.9% 10.4%
[libx264 @ 0x55aa72d89a00] mb P  I16..4:  0.3%  1.1%  0.2%  P16..4: 23.5% 10.9%  4.6%  0.0%  0.0%    skip:59.4%
[libx264 @ 0x55aa72d89a00] mb B  I16..4:  0.0%  0.0%  0.0%  B16..8: 22.7%  3.9%  0.5%  direct: 0.3%  skip:72.6%  L0:42.5% L1:55.4% BI: 2.1%
[libx264 @ 0x55aa72d89a00] 8x8 transform intra:67.5% inter:66.6%
[libx264 @ 0x55aa72d89a00] coded y,uvDC,uvAC intra: 36.0% 42.7% 22.4% inter: 4.1% 4.2% 0.1%
[libx264 @ 0x55aa72d89a00] i16 v,h,dc,p: 66% 13% 15%  5%
[l

0

In [9]:
from IPython.display import HTML
from base64 import b64encode

def play(filename):
    """Return an ``HTML`` object that embeds the video file ``filename`` as an inline player.

    The whole file is read and base64-encoded into a ``data:video/mp4`` URI, so the player
    does not depend on the file staying on disk; the embedded markup grows with the file size.

    :param filename: path of the video file to embed
    :return: ``HTML`` object wrapping a ``<video>`` tag (width 1000, controls, autoplay, loop)
    """
    # Read inside a `with` block so the file handle is closed deterministically.
    with open(filename, 'rb') as video_file:
        video = video_file.read()
    src = 'data:video/mp4;base64,' + b64encode(video).decode()
    return HTML('<video width=1000 controls autoplay loop><source src="%s" type="video/mp4"></video>' % src)

# Build the path from `outpath` (the directory the combined video is encoded into) so the
# player keeps working when `outpath` is changed from its default '.'.
play(f"{outpath}/Reenactment_combined.mp4")
