In [1]:
%matplotlib notebook
import os
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from PIL import Image
import torchvision as tv
from matplotlib import pyplot as plt

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Residual block
class res_block(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped in an identity skip connection.

    The channel count ``C`` and the spatial size are preserved (padding=1),
    which is what allows the input to be added back onto the branch output.
    """

    def __init__(self, C):
        super().__init__()
        self.conv1 = nn.Conv2d(C, C, 3, padding=1)
        self.conv2 = nn.Conv2d(C, C, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(C)
        self.bn2 = nn.BatchNorm2d(C)

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + x)."""
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        # Skip connection: add the untouched input before the final activation.
        return F.relu(branch + x)
    
# Non-Residual block
class non_res_block(nn.Module):
    """Conv2d -> BatchNorm2d -> activation, without a skip connection.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        k: square kernel size. The padding is (k - 1) / 2 truncated to an
            integer, so an odd ``k`` keeps the spatial size unchanged.
        mode: ``'normal'`` applies ReLU; any other value applies tanh
            (used for the output layer, whose values then lie in [-1, 1]).
    """

    def __init__(self, C_in, C_out, k=3, mode='normal'):
        super(non_res_block, self).__init__()
        # Builtin int() instead of np.int: the NumPy alias was deprecated in
        # 1.20 and removed in 1.24, where it raises AttributeError.
        self.conv = nn.Conv2d(C_in, C_out, k, padding=int((k - 1) / 2))
        self.bn = nn.BatchNorm2d(C_out)
        self.mode = mode

    def forward(self, x):
        """Apply convolution, batch norm and the mode-dependent activation."""
        h = self.conv(x)
        h = self.bn(h)
        if self.mode == 'normal':
            y = F.relu(h)
        else:
            y = torch.tanh(h)

        return y
    
# Transformation network
class TransNet(nn.Module):
    """Image transformation network assembled from ``non_res_block`` layers.

    Layout: three widening blocks (3 -> 32 -> 64 -> 128 channels), ``D``
    128-channel blocks stored in ``self.res``, then three narrowing blocks
    (128 -> 64 -> 32 -> 3). The final block is built with ``mode='last'`` and
    therefore ends in tanh. Every block uses a stride-1 convolution with
    (k - 1) / 2 padding, so the spatial size of the input is preserved.

    Note that the blocks in ``self.res`` are ``non_res_block`` instances,
    i.e. they carry no skip connection despite the attribute name.
    """

    def __init__(self, D=5):
        super().__init__()
        self.D = D

        # Channel-widening front end.
        self.up1 = non_res_block(3, 32, k=9)
        self.up2 = non_res_block(32, 64)
        self.up3 = non_res_block(64, 128)

        # D constant-width blocks in the middle of the network.
        self.res = nn.ModuleList([non_res_block(128, 128) for _ in range(self.D)])

        # Channel-narrowing back end; the last layer maps to 3 channels via tanh.
        self.dn1 = non_res_block(128, 64)
        self.dn2 = non_res_block(64, 32)
        self.dn3 = non_res_block(32, 3, k=9, mode='last')

        # Make every weight trainable explicitly.
        for weight in self.parameters():
            weight.requires_grad_(True)

    def forward(self, x):
        """Run ``x`` through the widening, middle and narrowing stages."""
        h = self.up3(self.up2(self.up1(x)))

        for idx in range(self.D):
            h = self.res[idx](h)

        return self.dn3(self.dn2(self.dn1(h)))
    
def image_loader(img_path):
    """Read an image file and return it as a ``[1, 3, 512, 512]`` tensor.

    The image is converted to RGB, resized to 512x512, turned into a tensor
    and normalized per channel with mean 0.5 and std 0.5.
    """
    preprocess = tv.transforms.Compose([
        tv.transforms.Resize([512, 512]),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
        ])
    rgb_image = Image.open(img_path).convert('RGB')
    # Add the leading batch dimension expected by the network.
    return preprocess(rgb_image).view([1, 3, 512, 512])

def imshow(image, ax=plt):
    """Show a channels-first image tensor whose values are centred on 0.

    Args:
        image: 3-D tensor laid out as (channels, height, width). Values are
            mapped with (v + 1) / 2 and clipped to [0, 1] before display.
        ax: object providing ``imshow`` and ``axis`` (a matplotlib Axes or
            the pyplot module, which is the default).

    Returns:
        Whatever ``ax.imshow`` returns.
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad,
    # which is the case for raw network output produced outside no_grad().
    image = image.detach().to('cpu').numpy()
    # Channels-first -> channels-last, the layout imshow expects.
    image = np.moveaxis(image, [0, 1, 2], [2, 0, 1])
    # Map [-1, 1] to [0, 1]; the arithmetic allocates a new array, so the
    # caller's tensor is never modified.
    image = np.clip((image + 1) / 2, 0, 1)
    h = ax.imshow(image)
    ax.axis('off')
    return h

In [4]:
# Locate the trained weights for the selected style under ./Models/.
model_dir = os.getcwd() + '/Models/'
style = 'scream'
ratio = '4e3'
p_fn = model_dir+style+'.txt'
c_fn = model_dir+style+'.pth.tar'

if os.path.isfile(c_fn):
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(c_fn, map_location=device)
    transnet = TransNet().to(device)
    transnet.load_state_dict(checkpoint)
    del checkpoint
    # NOTE(review): the network is left in training mode (no .eval() call),
    # so BatchNorm normalizes with the statistics of each incoming batch
    # rather than the stored running averages. Confirm this is intended
    # before switching, since it changes the stylized output.
    print('Trained model loaded')
else:
    print('Model not found!')

Trained model loaded


In [7]:
import cv2
import time
import os
# Stylize every frame of the input video with `transnet`, save each result as
# temp_frame/transfer<N>.jpg, and print the elapsed wall-clock time in seconds.
start=time.time()
video_path = os.getcwd()+'/video/funny.mp4'
vidcap = cv2.VideoCapture(video_path)
if not vidcap.isOpened():
    # Fail here with a clear message instead of silently producing no frames.
    raise IOError('Could not open video: ' + video_path)
os.makedirs('temp_frame', exist_ok=True)

# The preprocessing does not depend on the frame, so build it once.
# NOTE(review): Normalize([0,0,0], [1,1,1]) is an identity, so frames reach the
# network in [0, 1], whereas image_loader normalizes with mean/std 0.5.
# Confirm which input range the checkpoint was trained with.
frame_transform = tv.transforms.Compose([
    tv.transforms.Resize([512, 512]),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize([0,0,0], [1,1,1])
    ])

count = 0
try:
    success, frame = vidcap.read()
    while success:
        # OpenCV decodes frames in BGR order; convert to RGB on the array.
        # Swapping tensor channels with a tuple assignment does not work:
        # the right-hand-side slices are views, so the first write clobbers
        # a channel before it has been read and one colour channel is lost.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        content_img = frame_transform(Image.fromarray(rgb_frame)).view([1, 3, 512, 512]).to(device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            transfered_img = transnet(content_img)

        tv.utils.save_image(transfered_img.view([3, 512, 512]),'temp_frame/transfer'+str(count)+'.jpg',normalize=True)

        success, frame = vidcap.read()
        print('Read a new frame: ', success)
        count += 1
        print(count)
finally:
    # Always give the decoder/file handle back, even if a frame fails.
    vidcap.release()
end=time.time()
print(end-start)

Read a new frame:  True
1
Read a new frame:  True
2
Read a new frame:  True
3
Read a new frame:  True
4
Read a new frame:  True
5
Read a new frame:  True
6
Read a new frame:  True
7
Read a new frame:  True
8
Read a new frame:  True
9
Read a new frame:  True
10
Read a new frame:  True
11
Read a new frame:  True
12
Read a new frame:  True
13
Read a new frame:  True
14
Read a new frame:  True
15
Read a new frame:  True
16
Read a new frame:  True
17
Read a new frame:  True
18
Read a new frame:  True
19
Read a new frame:  True
20
Read a new frame:  True
21
Read a new frame:  True
22
Read a new frame:  True
23
Read a new frame:  True
24
Read a new frame:  True
25
Read a new frame:  True
26
Read a new frame:  True
27
Read a new frame:  True
28
Read a new frame:  True
29
Read a new frame:  True
30
Read a new frame:  True
31
Read a new frame:  True
32
Read a new frame:  True
33
Read a new frame:  True
34
Read a new frame:  True
35
Read a new frame:  True
36
Read a new frame:  True
37
Read a new

In [8]:
import cv2
import numpy as np
import glob
import natsort 

# Collect the stylized frames written to temp_frame/.
name = glob.glob('temp_frame/transfer*.jpg')
name.sort()
# Natural sort so that transfer10.jpg comes after transfer9.jpg, not after
# transfer1.jpg as plain string ordering would place it.
name_sorted = natsort.natsorted(name)

In [9]:
import cv2
import numpy as np
import glob

# Load the stylized frames in natural order and encode them as project.mp4
# at 30 frames per second.
img_array = []
for filename in name_sorted:
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread returns None for a missing or undecodable file
        # instead of raising, which would otherwise crash on .shape below.
        raise IOError('Could not read frame: ' + filename)
    img_array.append(img)

if not img_array:
    # Without any frame there is no frame size to create the writer with.
    raise RuntimeError('No frames found in temp_frame/; nothing to encode')

# The frame size is taken from the last frame that was loaded.
height, width, layers = img_array[-1].shape
size = (width,height)

out = cv2.VideoWriter('project.mp4',cv2.VideoWriter_fourcc(*'DIVX'), 30, size)
try:
    for frame_bgr in img_array:
        out.write(frame_bgr)
finally:
    # Finalize the container even if a write fails part-way through.
    out.release()

In [10]:
# List the .mp4 files in the working directory, in sorted order.
videoooo = glob.glob('*.mp4')
videoooo.sort()

In [11]:
# Show the collected list as the cell output.
videoooo

['project.mp4']