In [1]:
import sys
sys.path.append('C:/Users/Josep/OneDrive/Desktop/Co-segmentation')
# Import necessary libraries
import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from IPython.display import HTML
import warnings
from part_swap import load_checkpoints, make_video
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch.nn.functional as F
import matplotlib.patches as mpatches

# Silence all library warnings for the rest of the session.
warnings.filterwarnings("ignore")

# Inputs for the first experiment: a still source image and the target video.
source_image_path = 'C:/Users/Josep/OneDrive/Desktop/Co-segmentation/eye-lip0.jpg'
target_video_path = 'C:/Users/Josep/OneDrive/Desktop/Co-segmentation/filth.mp4'

source_image = imageio.imread(source_image_path)
# memtest=False disables imageio's memory-size guard so long videos are read in full.
target_video = imageio.mimread(target_video_path, memtest=False)

# Bring everything to 256x256 (matching the vox-256 configs) and keep only RGB,
# dropping an alpha channel if the file has one.
frame_size = (256, 256)
source_image = resize(source_image, frame_size)[..., :3]
target_video = [resize(frame, frame_size)[..., :3] for frame in target_video]

# Function to display videos
def display(source, target, generated=None, interval=50, repeat_delay=1000):
    """Build a side-by-side animation: source | target frame | generated frame.

    Note: this name shadows IPython's built-in ``display`` in the notebook namespace.

    Args:
        source: a single image repeated as the left panel of every animation frame.
        target: sequence of frames; one animation frame is produced per element.
        generated: optional sequence of frames shown as the right panel. It is indexed
            in lockstep with ``target``, so it needs at least ``len(target)`` frames.
            All panels are concatenated horizontally, so they must share height and
            channel count.
        interval: delay between animation frames, in milliseconds.
        repeat_delay: pause before the animation loops, in milliseconds.

    Returns:
        A ``matplotlib.animation.ArtistAnimation`` (render with ``.to_html5_video()``).
    """
    fig, ax = plt.subplots(figsize=(8 + 4 * (generated is not None), 6))
    # Axis decorations are the same for every frame, so switch them off once.
    ax.axis('off')

    ims = []
    for i, frame in enumerate(target):
        cols = [source, frame]
        if generated is not None:
            cols.append(generated[i])
        im = ax.imshow(np.concatenate(cols, axis=1), animated=True)
        ims.append([im])

    ani = animation.ArtistAnimation(fig, ims, interval=interval, repeat_delay=repeat_delay)
    # Close exactly the figure created here so the inline backend does not also
    # render a static copy of it.
    plt.close(fig)
    return ani
    
# Preview the inputs. A bare HTML(...) expression is only rendered when it is the last
# expression of a cell, so show it explicitly. IPython's display() is imported under an
# alias because the `display` helper defined above shadows that name.
from IPython.display import display as ipy_display
ipy_display(HTML(display(source_image, target_video).to_html5_video()))

# Load checkpoints (10-segment model)
config_path = 'C:/Users/Josep/OneDrive/Desktop/Co-segmentation/config/vox-256-sem-10segments.yaml'
checkpoint_path = 'C:/Users/Josep/OneDrive/Desktop/Co-segmentation/checkpoints/vox-10segments.pth.tar'

# Follow the device chosen at the top of the notebook instead of hard-coding cpu=False,
# which would request the GPU path even on machines without CUDA.
# NOTE(review): make_video may need an equivalent CPU switch as well - check the
# signature of part_swap.make_video.
reconstruction_module, segmentation_module = load_checkpoints(
    config=config_path,
    checkpoint=checkpoint_path,
    blend_scale=1,
    cpu=(device.type == 'cpu'),
)

# Function to visualize segmentation
def visualize_segmentation(image, network, supervised=False, hard=True, colormap='gist_rainbow'):
    """Plot a segmentation of ``image`` next to an overlay of that segmentation on the image.

    Args:
        image: RGB image of shape (H, W, 3). It is blended directly with colors in
            [0, 1], so it is expected to already be float-scaled (e.g. by skimage
            ``resize``).
        network: segmentation model called on a (1, 3, H, W) tensor on ``device``.
            If ``supervised`` is False it must return a dict whose ``'segmentation'``
            entry has shape (1, K, h, w). If True it must expose ``mean``/``std``
            tensors and return a sequence whose first element holds (1, K, h, w)
            per-segment scores.
        supervised: resize the input to 512x512, normalize it with the network's
            mean/std, and softmax the output instead of using ``'segmentation'``.
        hard: if True, assign every pixel to exactly one segment (its argmax).
        colormap: name of a matplotlib colormap used for segments 1..K-1. Segment 0
            (presumably background) is always drawn black.
    """
    with torch.no_grad():
        inp = torch.tensor(image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
        if supervised:
            inp = F.interpolate(inp, size=(512, 512))
            inp = (inp - network.mean.to(device)) / network.std.to(device)
            mask = torch.softmax(network(inp)[0], dim=1)
            mask = F.interpolate(mask, size=image.shape[:2])
        else:
            mask = network(inp)['segmentation']
            mask = F.interpolate(mask, size=image.shape[:2], mode='bilinear')

        if hard:
            # One-hot of the argmax. Comparing against the per-pixel max instead would
            # flag every tied segment, summing several colors into one pixel and
            # pushing RGB values above 1.
            mask = torch.zeros_like(mask).scatter_(1, mask.argmax(dim=1, keepdim=True), 1.0)

    num_segments = mask.shape[1]
    mask = mask.squeeze(0).permute(1, 2, 0).cpu().numpy()  # (H, W, K)

    # One RGB color per segment: black for segment 0, colormap samples for the rest.
    cmap = plt.get_cmap(colormap)
    colors = np.zeros((num_segments, 3))
    patches = []
    for i in range(num_segments):
        if i != 0:
            colors[i] = np.array(cmap((i - 1) / (num_segments - 1)))[:3]
        patches.append(mpatches.Patch(color=colors[i], label=str(i)))

    # Per-pixel weighted sum of segment colors: (H, W, K) @ (K, 3) -> (H, W, 3).
    # Equivalent to accumulating mask[..., i] * colors[i], but always an array
    # (also when there are no segments).
    color_mask = mask @ colors

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    ax[0].imshow(color_mask)
    ax[1].imshow(0.3 * image + 0.7 * color_mask)
    ax[1].legend(handles=patches)
    ax[0].axis('off')
    ax[1].axis('off')

# Show how the loaded model partitions the source image (one color per segment).
visualize_segmentation(source_image, segmentation_module, hard=True)
plt.show()

# Generate the part-swapped video; swap_index lists the segment(s) to swap (here 2).
predictions = make_video(
    swap_index=[2],
    source_image=source_image,
    target_video=target_video,
    segmentation_module=segmentation_module,
    reconstruction_module=reconstruction_module,
)
HTML(display(source_image, target_video, predictions).to_html5_video())

100%|████████████████████████████████████████████████████████████████████████████████| 389/389 [00:17<00:00, 22.45it/s]


In [2]:
# Reload the 10-segment model. load_checkpoints is already imported at the top, and
# CUDA_LAUNCH_BLOCKING was set there before torch initialized CUDA; setting it again
# once CUDA is in use typically has no effect, so only the reload is kept.
# The cpu flag follows the device chosen at the top instead of the loader's default.
reconstruction_module, segmentation_module = load_checkpoints(
    config='C:/Users/Josep/OneDrive/Desktop/Co-segmentation/config/vox-256-sem-10segments.yaml',
    checkpoint='C:/Users/Josep/OneDrive/Desktop/Co-segmentation/checkpoints/vox-10segments.pth.tar',
    blend_scale=1,
    cpu=(device.type == 'cpu'),
)

In [3]:
# Eye swap: eye8.jpg onto ramsey.mp4, using the 10-segment model loaded above
# and swapping segments 7 and 9.
eye_dir = 'C:/Users/Josep/OneDrive/Desktop/Co-segmentation/'
source_image = imageio.imread(eye_dir + 'eye8.jpg')
target_video = imageio.mimread(eye_dir + 'ramsey.mp4', memtest=False)
source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]

predictions = make_video(
    swap_index=[7, 9],
    source_image=source_image,
    target_video=target_video,
    segmentation_module=segmentation_module,
    reconstruction_module=reconstruction_module,
)
HTML(display(source_image, target_video, predictions).to_html5_video())

100%|████████████████████████████████████████████████████████████████████████████████| 164/164 [00:06<00:00, 23.87it/s]


In [4]:
# Switch to the 5-segment model (load_checkpoints is already imported at the top).
# The cpu flag follows the device chosen at the top instead of the loader's default.
reconstruction_module, segmentation_module = load_checkpoints(
    config='C:/Users/Josep/OneDrive/Desktop/Co-segmentation/config/vox-256-sem-5segments.yaml',
    checkpoint='C:/Users/Josep/OneDrive/Desktop/Co-segmentation/checkpoints/vox-5segments.pth.tar',
    blend_scale=1,
    cpu=(device.type == 'cpu'),
)

In [5]:
# Inspect how the currently loaded model segments eye1.jpg.
source_image = resize(
    imageio.imread('C:/Users/Josep/OneDrive/Desktop/Co-segmentation/eye1.jpg'),
    (256, 256),
)[..., :3]
visualize_segmentation(source_image, segmentation_module, hard=True)
plt.show()

In [23]:
# Hair swap (original note: "wo-man"): eye-hair5.jpg onto hair.mp4, swapping segments 3-5.
# use_source_segmentation=True presumably takes the swap mask from the source image's
# segmentation rather than the target frames' - confirm in part_swap.make_video.

from part_swap import make_video

hair_dir = 'C:/Users/Josep/OneDrive/Desktop/Co-segmentation/'
source_image = imageio.imread(hair_dir + 'eye-hair5.jpg')
target_video = imageio.mimread(hair_dir + 'hair.mp4', memtest=False)
source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]

predictions = make_video(
    swap_index=[3, 4, 5],
    source_image=source_image,
    target_video=target_video,
    use_source_segmentation=True,
    segmentation_module=segmentation_module,
    reconstruction_module=reconstruction_module,
)
HTML(display(source_image, target_video, predictions).to_html5_video())

100%|████████████████████████████████████████████████████████████████████████████████| 137/137 [00:06<00:00, 21.85it/s]


In [8]:
# Reload the 5-segment model so this cell does not depend on which model was loaded
# last (load_checkpoints is already imported at the top). The cpu flag follows the
# device chosen at the top instead of the loader's default.
reconstruction_module, segmentation_module = load_checkpoints(
    config='C:/Users/Josep/OneDrive/Desktop/Co-segmentation/config/vox-256-sem-5segments.yaml',
    checkpoint='C:/Users/Josep/OneDrive/Desktop/Co-segmentation/checkpoints/vox-5segments.pth.tar',
    blend_scale=1,
    cpu=(device.type == 'cpu'),
)

# Beard swap: beard0.jpg onto Elon.mp4, swapping segment 1.
source_image = imageio.imread('C:/Users/Josep/OneDrive/Desktop/Co-segmentation/beard0.jpg')
target_video = imageio.mimread('C:/Users/Josep/OneDrive/Desktop/Co-segmentation/Elon.mp4', memtest=False)
source_image = resize(source_image, (256, 256))[..., :3]
target_video = [resize(frame, (256, 256))[..., :3] for frame in target_video]

predictions = make_video(
    swap_index=[1],
    source_image=source_image,
    target_video=target_video,
    segmentation_module=segmentation_module,
    reconstruction_module=reconstruction_module,
)
HTML(display(source_image, target_video, predictions).to_html5_video())

100%|████████████████████████████████████████████████████████████████████████████████| 251/251 [00:10<00:00, 24.79it/s]
