**Config**

In [24]:
import torch
import os

# edit the config
# Single place to choose the GPU; the ordinal used to be repeated in two calls
# and could drift apart when edited.
gpu_index = 2
if torch.cuda.is_available() and gpu_index < torch.cuda.device_count():
    device = torch.device(f'cuda:{gpu_index}')
    torch.cuda.set_device(device)  # also make it the current CUDA device
else:
    # The requested GPU does not exist on this machine: fall back to CPU rather
    # than failing inside torch.cuda.set_device before anything else can run.
    # NOTE(review): assumes the `demo` helpers accept a CPU device — confirm.
    device = torch.device('cpu')
    print(f'CUDA device {gpu_index} is not available; using CPU instead.')
# Dataset tag. The only value this cell branches on is 'ted' (see `pixel` below);
# the presets listed in the original note were ['vox', 'taichi', 'ted', 'mgif'].
dataset_name = 'lsa'
# Alternative source images / driving folders — uncomment one pair to switch.
# source_image_path = '/DataSet/lsa64_cut/test/057_001_005/057_001_005_000.png'
# source_image_path = '/DataSet/LSA64/256/test/057_001_005/057_001_005_000.png'
# source_image_path = '/DataSet/PHOENIX-2014-T/features/fullFrame-128x128px/test/11August_2009_Tuesday_tagesschau-4355/images0001.png'
# source_image_path = '/DataSet/WLASL2000_128x128/test/03089/0001.png'
# source_image_path = '/disk1/dataset/PHOENIX-2014-T-release-v3/PHOENIX-2014-T/features/fullFrame-128x128px/test/01May_2010_Saturday_tagesschau-7195/images0001.png'
# # csl
# source_image_path = '/disk1/dataset/CSL-Daily_128x128px/test/S000054_P0008_T00/000017.jpg'
# driving_folder_path = '/disk1/dataset/CSL-Daily_128x128px/test/S000153_P0000_T00'
# driving_folder_path = '/disk1/dataset/CSL-Daily_128x128px/test/S000020_P0004_T00'
# driving_folder_path = '/disk1/dataset/CSL-Daily_128x128px/test/S005971_P0006_T00'
# driving_folder_path = '/disk1/dataset/CSL-Daily_128x128px/test/S007121_P0007_T00'

# p14t
source_image_path = '/disk1/dataset/PHOENIX-2014-T-release-v3/PHOENIX-2014-T/features/fullFrame-128x128px/test/05May_2011_Thursday_heute-3747/images0008.png'
driving_folder_path = '/disk1/dataset/PHOENIX-2014-T-release-v3/PHOENIX-2014-T/features/fullFrame-128x128px/test/29April_2010_Thursday_heute-8626'
# driving_folder_path = '/disk1/dataset/PHOENIX-2014-T-release-v3/PHOENIX-2014-T/features/fullFrame-128x128px/test/01April_2010_Thursday_tagesschau-4330'
# driving_folder_path = '/disk1/dataset/PHOENIX-2014-T-release-v3/PHOENIX-2014-T/features/fullFrame-128x128px/test/01December_2011_Thursday_tagesschau-3473'
# driving_folder_path = '/disk1/dataset/PHOENIX-2014-T-release-v3/PHOENIX-2014-T/features/fullFrame-128x128px/test/02December_2010_Thursday_tagesschau-3631'

# The output video is named after the driving folder. normpath strips a trailing
# '/' first: basename('/a/b/') alone is '' and would silently produce './.mp4'.
dir_name = os.path.basename(os.path.normpath(driving_folder_path))
output_video_path = './' + dir_name + '.mp4'
# NOTE(review): a 'wlasl' config is paired with a 'phoenix' checkpoint here —
# confirm this combination is intended.
config_path = 'config/wlasl.yaml'
checkpoint_path = '/disk1/tongkai/mraa/log/phoenix.pth.tar'
predict_mode = 'standard'  # one of ['standard', 'relative', 'avd']
# When animating a face in 'relative' mode, find_best_frame=True can give a
# better quality result.
find_best_frame = False

# Square side length used when the source image and driving frames are resized.
# 128 matches the 128x128px folders selected above; 'ted' uses 384*384. The
# original note states that vox, taichi and mgif are 256*256 — set that value
# here by hand if one of those datasets is used.
pixel = 384 if dataset_name == 'ted' else 128


**Make Driving Video**

In [25]:
# Disabled cell, kept for reference: it would append every '.png' file found in
# driving_folder_path (in sorted file-name order) to a video writer at 30 fps.
# NOTE(review): `driving_video_path` is not assigned in any cell visible in this
# notebook (the Config cell defines driving_folder_path and output_video_path
# only) — define it before re-enabling this cell.
# import os
# import imageio

# folder_path = driving_folder_path
# image_files = [os.path.join(folder_path, file) for file in sorted(os.listdir(folder_path)) if file.endswith('.png')]
# writer = imageio.get_writer(driving_video_path, fps=30)

# for image_file in image_files:
#     image = imageio.imread(image_file)
#     writer.append_data(image)

# writer.close()

**Read image and video**

In [26]:
import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from IPython.display import HTML
import warnings
warnings.filterwarnings("ignore")

# Source frame: read, resize to (pixel, pixel) and keep at most the first three
# channels of the last axis (this drops an alpha channel when one is present).
source_image = resize(imageio.imread(source_image_path), (pixel, pixel))[..., :3]

# Driving frames are the .png/.jpg files of the folder, ordered by file name.
frame_extensions = ('.png', '.jpg')
image_files = [
    os.path.join(driving_folder_path, name)
    for name in sorted(os.listdir(driving_folder_path))
    if name.endswith(frame_extensions)
]
if not image_files:
    # An empty frame list almost certainly means a wrong folder; fail here with
    # a clear message instead of passing an empty video to later cells.
    raise FileNotFoundError(
        f'No .png or .jpg frames found in driving folder: {driving_folder_path}'
    )
fps = 30  # frame rate used when the generated video is written

# Read and resize each frame in a single pass so the full-resolution originals
# are not all kept in memory at once.
driving_video = [
    resize(imageio.imread(frame_path), (pixel, pixel))[..., :3]
    for frame_path in image_files
]

def display(source, driving, generated=None):
    """Build a side-by-side animation of source, driving and generated frames.

    Each animation frame is the horizontal concatenation (axis=1) of `source`,
    the i-th element of `driving` and, when `generated` is given, the i-th
    element of `generated`.

    Args:
        source: image array shown unchanged in every frame.
        driving: sequence of image arrays; its length sets the frame count.
        generated: optional sequence indexed in step with `driving`.

    Returns:
        A matplotlib ``ArtistAnimation``. The figure is closed before returning,
        so the animation is not shown by this function itself.
    """
    has_generated = generated is not None
    # Widen the canvas by 4 inches when a third panel is drawn.
    fig = plt.figure(figsize=(8 + 4 * has_generated, 6))

    frames = []
    for idx, driving_frame in enumerate(driving):
        panels = [source, driving_frame]
        if has_generated:
            panels.append(generated[idx])
        artist = plt.imshow(np.concatenate(panels, axis=1), animated=True)
        plt.axis('off')
        # ArtistAnimation expects one list of artists per frame.
        frames.append([artist])

    ani = animation.ArtistAnimation(fig, frames, interval=50, repeat_delay=1000)
    plt.close()
    return ani

**Create a model and load checkpoints**

In [27]:
from demo import load_checkpoints

# load_checkpoints receives the config file, the checkpoint file and the target
# device; its three return values are the objects passed to make_animation later.
inpainting, dense_motion_network, bg_predictor = load_checkpoints(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    device=device,
)

**Perform image animation**

In [28]:
from skimage import img_as_ubyte
from demo import make_animation

# Animate the source image with the driving frames using the loaded models.
predictions = make_animation(
    source_image,
    driving_video,
    inpainting,
    dense_motion_network,
    bg_predictor,
    device=device,
    mode=predict_mode,
)

# save resulting video (each predicted frame is converted with img_as_ubyte first)
output_frames = [img_as_ubyte(frame) for frame in predictions]
imageio.mimsave(output_video_path, output_frames, fps=fps)

# HTML(display(source_image, driving_video, predictions).to_html5_video())

  0%|          | 0/62 [00:00<?, ?it/s]

100%|██████████| 62/62 [00:00<00:00, 83.52it/s]


In [29]:
# Disabled cell, kept for reference: single-image inference with
# demo.inference_img on one source image and one driving image, saving the
# prediction as an image named '<driving folder>_<driving file>'.
# NOTE(review): if re-enabled, it rebinds names set by earlier cells (device,
# source_image_path, config_path, checkpoint_path, source_image and the three
# model handles) and selects CUDA device 5 — re-run the Config cell afterwards.
# from demo import inference_img
# import imageio
# from skimage.transform import resize
# import numpy as np
# import torch


# torch.cuda.set_device(5)
# device = torch.device('cuda:5')

# source_image_path = '/disk1/dataset/WLASL2000_128x128/test/69494/0040.png'
# driving_image_path = '/disk1/dataset/WLASL2000_128x128/test/01248/0006.png'
# config_path = 'config/wlasl.yaml'
# checkpoint_path = 'log/wlasl.pth.tar'
# out_img_path = driving_image_path.split('/')[-2]+'_'+driving_image_path.split('/')[-1]


# source_image = imageio.imread(source_image_path)
# source_image = resize(source_image, (128, 128))[..., :3]

# driving_image = imageio.imread(driving_image_path)
# driving_image = resize(driving_image, (128, 128))[..., :3]

# from demo import load_checkpoints
# inpainting, dense_motion_network, bg_predictor = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)

# out = inference_img(source_image, driving_image, inpainting, dense_motion_network, bg_predictor, device)
# prediction = out['prediction'][0].cpu().numpy().transpose(1, 2, 0)

# imageio.imsave(out_img_path, (255 * prediction).astype(np.uint8))