<a href="https://colab.research.google.com/github/ThomasDougherty/stylish-people/blob/main/people_style_transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Arguments
## Parameters for the style transfer — edit these before running the notebook.

In [14]:
# COCO class names to segment; each must be one of detectron2's thing_classes.
objects = ["person"]
# False: stylize only the selected objects; True: stylize everything except them.
inverse = False
# fast_neural_style checkpoint, expected inside the downloaded saved_models.zip.
style_pth = "saved_models/mosaic.pth"
# Output file name without the ".mp4" extension; None derives it from the upload name.
save_name = "crook_person_only"
# NOTE(review): not read by any visible cell in this notebook.
is_video = True

In [15]:
#@title Select video file
from google.colab import files
# Choose the video on your computer when the upload widget appears.
files_list = files.upload()
if not files_list:
  raise ValueError("No file was uploaded; re-run this cell and choose a video.")
# The upload result is keyed by file name. As before, the last uploaded file is
# used when several are selected, but now the choice is reported.
file_name = list(files_list)[-1]
if len(files_list) > 1:
  print("Several files uploaded; using " + file_name)

Saving t_vid_0.mov to t_vid_0.mov


In [16]:
#@title Imports and download Detectron2
# import some common 
import json
import os
import re
import random
import subprocess

import cv2
from google.colab.patches import cv2_imshow
import numpy as np
from tqdm import tqdm

!pip install pyyaml==5.1
# install detectron2: (Colab has CUDA 10.1 + torch 1.7)
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
import torch
assert torch.__version__.startswith("1.7")
import torchvision
from torchvision import transforms

!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.7/index.html &> /dev/null
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog



In [17]:
#@title Download neural-style-transfer pretrained weights
import zipfile

def unzip(source_filename, dest_dir):
    """Extract every member of the zip archive `source_filename` into `dest_dir`."""
    archive = zipfile.ZipFile(source_filename)
    try:
        archive.extractall(path=dest_dir)
    finally:
        # Always release the file handle, even if extraction fails part-way.
        archive.close()

# The download helper has moved between torch releases; try the known locations.
try:
    from torch.utils.model_zoo import _download_url_to_file
except ImportError:
    try:
        from torch.hub import download_url_to_file as _download_url_to_file
    except ImportError:
        from torch.hub import _download_url_to_file

# Skip the (~25 MB) download and the extraction when an earlier run already did
# them, so re-running the notebook does not re-fetch the weights.
weights_zip = 'saved_models.zip'
if not os.path.exists(weights_zip):
    _download_url_to_file('https://www.dropbox.com/s/lrvwfehqdcxoza8/saved_models.zip?dl=1', weights_zip, None, True)
if not os.path.exists(style_pth):
    unzip(weights_zip, '.')

HBox(children=(FloatProgress(value=0.0, max=25020748.0), HTML(value='')))




In [18]:
#@title Create directories
# Working directories for every pipeline stage (all relative to the notebook cwd).
temp_dir = "temp_imgs/"
org_dir = "temp_imgs/org_imgs/"
style_dir = "temp_imgs/style_imgs/"
masks_dir = "temp_imgs/masks_imgs/"
final_dir = "temp_imgs/final_imgs/"

# makedirs also creates temp_dir as the parent of each sub-directory, and
# exist_ok=True keeps the cell re-runnable (plain `mkdir` errors on existing dirs).
for directory in (org_dir, style_dir, masks_dir, final_dir):
    os.makedirs(directory, exist_ok=True)

In [19]:
#@title Read frames from video
# `vidcap` is deliberately left open: its frame rate is queried again when the
# output video is written.
vidcap = cv2.VideoCapture(file_name)
if not vidcap.isOpened():
    raise IOError("Could not open video file {!r}".format(file_name))

# Frame file names in playback order. They are appended in decode order, so no
# sort is needed (a lexical sort would misorder frames past frame_99999).
base_names = []
frame_num = 0
success, image = vidcap.read()
while success:
    base = "frame_{0:05}.png".format(frame_num)
    cv2.imwrite(org_dir + base, image)  # save frame as PNG
    base_names.append(base)
    frame_num += 1
    success, image = vidcap.read()

if not base_names:
    raise ValueError("No frames could be read from {!r}".format(file_name))

In [20]:
#@title fast_neural_style: transformer_net
class TransformerNet(torch.nn.Module):
    """Image transformation network of the fast_neural_style model.

    Layout: three conv + instance-norm stages (the 2nd and 3rd use stride 2),
    five 128-channel residual blocks, two upsample-conv + instance-norm stages,
    and a final 9x9 convolution back to 3 channels.

    The attribute names (conv1, in1, ..., res1..res5, deconv1..deconv3) are the
    state_dict keys of the pretrained checkpoints, so they must not change.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
        # Residual trunk, registered as res1 ... res5
        for idx in range(1, 6):
            setattr(self, "res{}".format(idx), ResidualBlock(128))
        # Decoder
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
        # Shared non-linearity
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        features = X
        for conv, norm in ((self.conv1, self.in1), (self.conv2, self.in2), (self.conv3, self.in3)):
            features = self.relu(norm(conv(features)))
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            features = block(features)
        for deconv, norm in ((self.deconv1, self.in4), (self.deconv2, self.in5)):
            features = self.relu(norm(deconv(features)))
        # The last convolution has no normalisation or activation.
        return self.deconv3(features)


class ConvLayer(torch.nn.Module):
    """2-D convolution preceded by reflection padding of ``kernel_size // 2``.

    With an odd kernel and stride 1 the spatial size is preserved; larger
    strides downsample accordingly.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad_width = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(pad_width)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))


class ResidualBlock(torch.nn.Module):
    """Two 3x3 conv + instance-norm layers wrapped by an identity skip connection.

    introduced in: https://arxiv.org/abs/1512.03385
    recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return ConvLayer(channels, channels, kernel_size=3, stride=1)

        self.conv1 = conv3x3()
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = conv3x3()
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        branch = self.relu(self.in1(self.conv1(x)))
        branch = self.in2(self.conv2(branch))
        # No activation is applied after adding the skip connection.
        return branch + x


class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded convolution.

    Upsampling and then convolving gives better results than ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        # A falsy `upsample` (None or 0) leaves the input at its original size.
        if not self.upsample:
            scaled = x
        else:
            scaled = torch.nn.functional.interpolate(x, mode='nearest', scale_factor=self.upsample)
        return self.conv2d(self.reflection_pad(scaled))

In [None]:
#@title Engine
def create_mask(outputs, num_class_list, img):
  """Union the instance masks whose predicted class is in `num_class_list`.

  Args:
    outputs: detectron2 predictor output; reads
      outputs["instances"].pred_classes and .pred_masks.
    num_class_list: class indices to keep.
    img: the frame the prediction was made on; only its height/width are used.

  Returns:
    A float array of shape img.shape[:2] holding 1 where any kept instance
    covers the pixel and 0 elsewhere.
  """
  instances = outputs["instances"]
  classes = instances.pred_classes.detach().cpu().numpy()
  masks = instances.pred_masks.detach().cpu().numpy()
  union = np.zeros(img.shape[:2])
  for cls_id, inst_mask in zip(classes, masks):
    if cls_id in num_class_list:
      union[inst_mask > 0] = 1
  return union

# --- Style-transfer network setup -------------------------------------------
# Use the first GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ToTensor converts an HxWxC uint8 array to a CxHxW float tensor in [0, 1];
# scale back to [0, 255], the range this network works in.
content_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255))
])

style_model = TransformerNet()
state_dict = torch.load(style_pth, map_location=device)
# remove saved deprecated running_* keys in InstanceNorm from the checkpoint
for k in list(state_dict.keys()):
    if re.search(r'in\d+\.running_(mean|var)$', k):
        del state_dict[k]
style_model.load_state_dict(state_dict)
style_model.to(device)
style_model.eval()


def stylize_frame(frame_bgr):
    """Stylize one OpenCV frame (BGR, uint8) and return a BGR uint8 image of the same size."""
    # OpenCV decodes frames as BGR, while the fast_neural_style network works on
    # RGB images: convert on the way in and back on the way out, otherwise the
    # style's colour palette comes out with red and blue swapped.
    content_image = content_transform(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    content_image = content_image.unsqueeze(0).to(device)
    with torch.no_grad():
        output = style_model(content_image)
    # The network output is unbounded float; clamp to valid pixel values before
    # the uint8 conversion so out-of-range values saturate instead of overflowing.
    stylized = output[0].clamp(0, 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()
    # Two stride-2 convolutions followed by two 2x upsamplings can make the
    # output a few pixels larger than the input when a side is not a multiple
    # of 4; resize so it lines up with the frame and its mask.
    height, width = frame_bgr.shape[:2]
    if stylized.shape[:2] != (height, width):
        stylized = cv2.resize(stylized, (width, height))
    return cv2.cvtColor(stylized, cv2.COLOR_RGB2BGR)


# --- Detectron2 setup ---------------------------------------------------------
cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
cfg.MODEL.DEVICE = device.type  # same kind of device as the style model
predictor = DefaultPredictor(cfg)

# Map the requested object names (e.g. "person") to the model's class indices.
class_list = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
unknown = [ob for ob in objects if ob not in class_list]
if unknown:
    raise ValueError("Unknown object class(es) {}; valid names are: {}".format(
        unknown, ", ".join(class_list)))
num_class_list = [class_list.index(ob) for ob in objects]

# --- Per-frame segmentation mask and stylized frame --------------------------
for base in tqdm(base_names):
    img = cv2.imread(org_dir + base)
    outputs = predictor(img)
    binary_mask = create_mask(outputs, num_class_list, img)
    # Store the mask as an explicit 8-bit image holding 0/1 values.
    cv2.imwrite(masks_dir + base, binary_mask.astype(np.uint8))
    cv2.imwrite(style_dir + base, stylize_frame(img))

  6%|▌         | 33/579 [00:28<07:44,  1.18it/s]

In [None]:
# Combine each original frame with its stylized version using the object mask.
for base in tqdm(base_names):
    frame_img = cv2.imread(org_dir + base)
    style_img = cv2.imread(style_dir + base)
    # Masks are single-channel images where non-zero marks the selected objects.
    mask_img = cv2.imread(masks_dir + base, cv2.IMREAD_GRAYSCALE)
    # The stylized frame can differ in size from the original when a side is
    # not a multiple of 4; align it before compositing.
    if style_img.shape[:2] != frame_img.shape[:2]:
        style_img = cv2.resize(style_img, (frame_img.shape[1], frame_img.shape[0]))
    # inverse=False: stylize the objects; inverse=True: stylize the background.
    use_style = mask_img > 0
    if inverse:
        use_style = ~use_style
    # Pick stylized pixels where `use_style` is set and original pixels elsewhere
    # (no uint8 addition, so no risk of overflow).
    final_img = np.where(use_style[:, :, None], style_img, frame_img)
    cv2.imwrite(final_dir + base, final_img)

In [None]:
# Take the output size from an actual composited frame instead of whatever
# `img` was left holding by an earlier loop.
first_frame = cv2.imread(final_dir + base_names[0])
height, width, layers = first_frame.shape
size = (width, height)

# Keep the source's exact (possibly fractional, e.g. 29.97) frame rate so the
# video stays in sync with the original audio track.
fps = vidcap.get(cv2.CAP_PROP_FPS)
if not fps > 0:
    fps = 30.0  # NOTE(review): arbitrary fallback for containers that report no frame rate
out = cv2.VideoWriter(temp_dir + 'vid_nosound.mp4',  # Provide a file to write the video to
                      cv2.VideoWriter_fourcc(*'mp4v'),  # MPEG-4 fourcc commonly used with .mp4 in OpenCV
                      fps,
                      size)
if not out.isOpened():
    raise IOError("Could not open video writer for " + temp_dir + "vid_nosound.mp4")
try:
    for base in tqdm(base_names):
        img = cv2.imread(final_dir + base)
        out.write(img)
finally:
    out.release()

In [None]:
# Re-encode the composited PNG frames as H.264.
#  * `-framerate` is an input option: placed before `-i` it makes ffmpeg read the
#    image sequence at the source rate (image sequences otherwise default to
#    25 fps, and `-vf fps=` would then duplicate/drop frames instead).
#  * `-y` overwrites the file written by the OpenCV cell instead of stopping at
#    ffmpeg's "overwrite?" prompt.
#  * `-pix_fmt yuv420p` keeps the stream playable and compatible with the
#    `-profile:v main` re-encode done later (PNG input would otherwise give 4:4:4).
# An argument list (no shell) avoids any quoting problems in the paths.
command = ["ffmpeg", "-y",
           "-framerate", str(fps),
           "-i", final_dir + "frame_%05d.png",
           "-c:v", "libx264",
           "-pix_fmt", "yuv420p",
           temp_dir + "vid_nosound.mp4"]
subprocess.run(command, check=True)


In [None]:
# All ffmpeg calls use argument lists (no shell), so uploaded file names with
# spaces or shell metacharacters are passed through safely; `-y` avoids the
# interactive overwrite prompt on re-runs.
video_only = temp_dir + "vid_nosound.mp4"
audio_path = temp_dir + "audio.wav"
muxed_path = temp_dir + "vid_b_conv.mp4"

# Extract the soundtrack of the uploaded file. This fails for clips that have no
# audio stream; in that case the silent video is exported directly.
has_audio = subprocess.call(
  ["ffmpeg", "-y", "-i", file_name, "-ab", "160k", "-ac", "2", "-ar", "44100", "-vn", audio_path]) == 0

if has_audio:
  subprocess.run(["ffmpeg", "-y", "-i", video_only, "-i", audio_path,
                  "-c:v", "copy", "-c:a", "aac", muxed_path], check=True)
  final_input = muxed_path
else:
  print("No audio track found in " + file_name + "; exporting a silent video.")
  final_input = video_only

if save_name is None:
  # Drop the upload's extension, e.g. "clip.mov" -> "clip_style.mp4".
  save_name = os.path.splitext(file_name)[0] + "_style"

# Final web-friendly H.264 encode. yuv420p is required by the main profile.
subprocess.run(["ffmpeg", "-y", "-i", final_input,
                "-vcodec", "libx264", "-profile:v", "main", "-level", "3.1",
                "-preset", "medium", "-crf", "23", "-x264-params", "ref=4",
                "-pix_fmt", "yuv420p",
                "-acodec", "copy", "-movflags", "+faststart",
                save_name + ".mp4"], check=True)



In [None]:
! rm -r sample_data
! rm -r saved_models
! rm -r temp_imgs
! rm -r saved_models.zip