# Colab-AnimeInterp

Original-repo: [lisiyao21/AnimeInterp](https://github.com/lisiyao21/AnimeInterp)

My fork: [styler00dollar/Colab-AnimeInterp](https://github.com/styler00dollar/Colab-AnimeInterp)

Don't run this on the CPU unless you are fine with roughly 75 seconds per image.

In [None]:
# Print the NVIDIA driver/GPU status of the current runtime.
!nvidia-smi

In [None]:
#@title Mount Google Drive
# Mount Google Drive at /content/drive; the final cells copy the result to
# /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')
print('Google Drive connected.')

In [None]:
#@title install
# Everything below lives under /content.
%cd /content/
# Fetch the upstream sources; later cells overwrite several files in this clone.
!git clone https://github.com/lisiyao21/AnimeInterp
# Store the checkpoint at the exact path the inference cell loads it from.
!wget "https://www.dropbox.com/s/oc8juclx1775qib/anime_interp_full.ckpt?dl=1" -O /content/anime_interp_full.ckpt
# NOTE(review): this pipes a remote script straight into `sh`; presumably it
# installs cupy for Colab — confirm the script is trusted before running.
!curl https://colab.chainer.org/install | sh -

In [None]:
#@title get video (download it or copy it into /content)
%cd /content
# either copy video from drive
#!cp /path/ /path/

# or get one with wget / youtube-dl
# wget
#!wget URL

# youtube-dl
# Remove a stale download, then install the latest youtube-dl binary.
!sudo rm -rf test.mp4
!wget -O - https://yt-dl.org/latest/youtube-dl | sudo tee /usr/local/bin/youtube-dl > /dev/null
!sudo chmod a+x /usr/local/bin/youtube-dl
video_path = "/content/test.mp4"
# NOTE(review): the extraction and encoding cells read /content/test.mkv, not
# video_path; presumably youtube-dl merges the streams into an .mkv container
# here — confirm the resulting file name after downloading.
!youtube-dl "https://www.youtube.com/watch?v=dQw4w9WgXcQ" --output {video_path}

In [None]:
# extract data
# adjust rescale value if needed, or remove it
# Frames are written to /content/data with zero-padded five-digit names
# (%05d.png); the inference cell relies on that naming scheme.
!mkdir /content/data
input_path = "/content/test.mkv" #@param
%shell ffmpeg -i {input_path} -vf scale=848:480:flags=lanczos "/content/data/%05d.png"

``cupy`` is much faster; I don't recommend running without ``cupy``.

In [None]:
#@title init.py (adding imports)
%%writefile /content/AnimeInterp/models/__init__.py
# Package exports: both model classes are importable from `models`.
# AnimeInterp pulls in the cupy-based softsplat module; AnimeInterpNoCupy is
# presumably the variant that avoids cupy (see AnimeInterp_no_cupy.py).
from .AnimeInterp_no_cupy import AnimeInterpNoCupy
from .AnimeInterp import AnimeInterp

__all__ = [ 'AnimeInterpNoCupy', 'AnimeInterp' ]

In [None]:
#@title utils.py (F.grid_sample fix)
%%writefile /content/AnimeInterp/models/rfr_model/utils.py
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate


class InputPadder:
    """Replicate-pad tensors so that height and width become multiples of 8.

    The extra rows are all appended at the bottom; the extra columns are
    split between the left and the right edge (the right edge receives the
    odd column, if any).
    """
    def __init__(self, dims):
        # Only the last two entries of ``dims`` (height, width) are used.
        height, width = dims[-2:]
        self.ht, self.wd = height, width
        extra_rows = -height % 8
        extra_cols = -width % 8
        left = extra_cols // 2
        # Layout expected by F.pad: [left, right, top, bottom].
        self._pad = [left, extra_cols - left, 0, extra_rows]

    def pad(self, *inputs):
        """Return a list with every input padded by the stored amounts."""
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode='replicate'))
        return padded

    def unpad(self, x):
        """Crop away the padding that ``pad`` added to the last two dims."""
        left, right, top, bottom = self._pad
        ht, wd = x.shape[-2:]
        return x[..., top:ht - bottom, left:wd - right]

def forward_interpolate(flow):
    """Forward-project a flow field and resample it on the pixel grid.

    ``flow`` is indexed as ``flow[0]`` (x displacement) and ``flow[1]``
    (y displacement), each a 2-D map.  Every pixel is moved by its own
    displacement; the displacements found at the landing positions are then
    interpolated (cubic, 0 where no data is available) back onto the regular
    grid.  Returns a float tensor stacked as ``[flow_x, flow_y]``.
    """
    flow_np = flow.detach().cpu().numpy()
    dx, dy = flow_np[0], flow_np[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # Landing position of every source pixel, flattened to 1-D.
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)

    # Keep only the samples that land strictly inside the image.
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    landing = (x1[inside], y1[inside])

    resampled = []
    for component in (dx, dy):
        values = component.reshape(-1)[inside]
        resampled.append(interpolate.griddata(
            landing, values, (x0, y0), method='cubic', fill_value=0))

    return torch.from_numpy(np.stack(resampled, axis=0)).float()


def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates

    Args:
        img: tensor to sample from; its last two dims are (H, W).
        coords: sampling positions in pixel units, with (x, y) stored in the
            last dimension.
        mode: interpolation mode forwarded to ``F.grid_sample``.
        mask: when true, also return a float mask that is 1 where the
            normalised position lies strictly inside (-1, 1) on both axes.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is true.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)
    # Map pixel coordinates to the [-1, 1] range grid_sample expects
    # (align_corners=True: -1 and 1 address the centres of the border pixels).
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # Honour the caller's ``mode`` instead of always sampling bilinearly.
    img = F.grid_sample(img, grid, align_corners=True, mode=mode)
    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()
    return img



def coords_grid(batch, ht, wd):
    """Return a (batch, 2, ht, wd) float grid of pixel coordinates.

    Channel 0 holds the x (column) index and channel 1 the y (row) index of
    each position.
    """
    rows, cols = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    grid = torch.stack((cols, rows), dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)


def upflow8(flow, mode='bilinear'):
    """Upsample a flow field 8x spatially and multiply its values by 8."""
    height, width = flow.shape[2], flow.shape[3]
    enlarged = F.interpolate(flow, size=(8 * height, 8 * width), mode=mode, align_corners=True)
    return 8 * enlarged


In [None]:
#@title rfr_new.py (F.grid_sample fix)
%%writefile /content/AnimeInterp/models/rfr_model/rfr_new.py
##################################################
#  RFR is implemented based on RAFT optical flow #
##################################################

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .update import BasicUpdateBlock, SmallUpdateBlock
from .extractor import BasicEncoder, SmallEncoder
from .corr import CorrBlock
from .utils import bilinear_sampler, coords_grid, upflow8

try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # dummy autocast for PyTorch < 1.6, where torch.cuda.amp is missing:
    # a no-op context manager accepting the same ``enabled`` argument.
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass

def backwarp(img, flow):
    """Warp ``img`` backwards along ``flow`` with bilinear sampling.

    Args:
        img: tensor of shape (N, C, H, W).
        flow: tensor indexed as (N, 2, H, W); channel 0 is the x displacement
            and channel 1 the y displacement, both in pixels.

    Returns:
        Tensor shaped like ``img`` whose pixel (y, x) is sampled from ``img``
        at (y + v, x + u).
    """
    _, _, H, W = img.size()

    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]

    # Build the pixel grid on the same device as the flow so the function
    # works on CPU and on any GPU, not only the default CUDA device.
    gridX = torch.arange(W, device=flow.device, dtype=torch.float32).view(1, 1, W)
    gridY = torch.arange(H, device=flow.device, dtype=torch.float32).view(1, H, 1)
    x = gridX.expand_as(u) + u
    y = gridY.expand_as(v) + v
    # range -1 to 1
    x = 2*(x/(W-1) - 0.5)
    y = 2*(y/(H-1) - 0.5)
    # stacking X and Y
    grid = torch.stack((x,y), dim=3)
    # Sample pixels using bilinear interpolation.
    imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=True, mode="bilinear")

    return imgOut
class ErrorAttention(nn.Module):
    """A three-layer network for predicting mask

    The input is re-injected before the last layer: ``conv3`` receives the
    32 channels produced by ``conv2`` concatenated with the raw input.

    Args:
        input: number of input channels.
        output: number of output channels.
    """
    def __init__(self, input, output):
        super(ErrorAttention, self).__init__()
        self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        # Width follows ``input`` (38 for the 6-channel case) instead of being
        # hard-coded, so any input channel count works.
        self.conv3 = nn.Conv2d(32 + input, output, 3, padding=1)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()

    def forward(self, x1):
        x = self.prelu1(self.conv1(x1))
        # Skip connection: append the original input along the channel axis.
        x = self.prelu2(torch.cat([self.conv2(x), x1], dim=1))
        x = self.conv3(x)
        return x

class RFR(nn.Module):
    """Recurrent optical-flow estimator (implemented on top of RAFT).

    ``forward`` resizes both images to the nearest lower multiple of 8,
    runs the iterative estimator in ``forward_pred`` and rescales the
    predicted flow back to the input resolution.
    """
    def __init__(self, args):
        """Build the sub-networks.

        Args:
            args: namespace-like object.  ``corr_levels``, ``corr_radius`` and
                ``dropout`` are overwritten here; ``mixed_precision`` is read
                by ``forward_pred``; ``not_use_rfr_mask`` and
                ``requires_sq_flow`` are optional flags read by ``forward``.
        """
        super(RFR, self).__init__()
        self.attention2 = ErrorAttention(6, 1)
        self.hidden_dim = hdim = 128
        self.context_dim = 128
        args.corr_levels = 4
        args.corr_radius = 4
        args.dropout = 0
        self.args = args

        # feature network and update block; the feature network doubles as
        # the context network in forward_pred
        self.fnet = BasicEncoder(output_dim=256, norm_fn='none', dropout=args.dropout)
        # self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put every BatchNorm2d sub-module into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8).to(img.device)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        # 9 weights (softmax-normalised) for each of the 8x8 sub-pixels of
        # every coarse cell.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # 3x3 neighbourhood of the (8x scaled) coarse flow around each cell.
        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """Estimate the flow from ``image1`` to ``image2``.

        Args:
            image1, image2: image tensors of shape (N, C, H, W).
            iters: number of refinement iterations.
            flow_init: optional initial flow; it is resized to the 1/8 grid
                before being handed to ``forward_pred``.
            upsample, test_mode: forwarded to ``forward_pred``.

        Returns:
            With ``args.requires_sq_flow`` set: the list of per-iteration
            flows in training mode, otherwise ``([last_flow], f12_init)``.
            Without it: ``(f12, f12_init, error21)`` where ``f12`` is the
            final flow at input resolution and ``error21`` is the warping
            error map on the 1/8 grid (zeros when it is not computed).
        """
        H, W = image1.size()[2:4]
        H8 = H // 8 * 8
        W8 = W // 8 * 8

        # Default error map, created on the input's device rather than on a
        # hard-coded CUDA device.  It is replaced below when the warping
        # error is actually computed, and keeps the return statement valid
        # when ``not_use_rfr_mask`` skips that computation.
        error21 = torch.zeros(image1.size()[0], 1, H // 8, W // 8, device=image1.device)

        if flow_init is not None:
            flow_init_resize = F.interpolate(flow_init, size=(H8//8, W8//8), mode='nearest')

            # Rescale the displacement values to the resized grid.
            flow_init_resize[:, :1] = flow_init_resize[:, :1].clone() * (W8 // 8 *1.0) / flow_init.size()[3]
            flow_init_resize[:, 1:] = flow_init_resize[:, 1:].clone() * (H8 // 8*1.0) / flow_init.size()[2]

            if not getattr(self.args, 'not_use_rfr_mask', False):
                im18 = F.interpolate(image1, size=(H8//8, W8//8), mode='bilinear')
                im28 = F.interpolate(image2, size=(H8//8, W8//8), mode='bilinear')

                warp21 = backwarp(im28, flow_init_resize)
                error21 = torch.sum(torch.abs(warp21 - im18), dim=1, keepdim=True)
                # print('errormin', error21.min(), error21.max())
                # NOTE(review): the masked initial flow is computed but never
                # used (forward_pred receives flow_init_resize); kept as-is so
                # behaviour matches the existing checkpoints — confirm intent.
                f12init = torch.exp(- self.attention2(torch.cat([im18, error21, flow_init_resize], dim=1)) ** 2) * flow_init_resize
        else:
            flow_init_resize = None
            # print('None inital flow!')

        image1 = F.interpolate(image1, size=(H8, W8), mode='bilinear')
        image2 = F.interpolate(image2, size=(H8, W8), mode='bilinear')

        f12s, f12, f12_init = self.forward_pred(image1, image2, iters, flow_init_resize, upsample, test_mode)

        if getattr(self.args, 'requires_sq_flow', False):
            for ii in range(len(f12s)):
                f12s[ii] = F.interpolate(f12s[ii], size=(H, W), mode='bilinear')
                f12s[ii][:, :1] = f12s[ii][:, :1].clone() / (1.0*W8) * W
                f12s[ii][:, 1:] = f12s[ii][:, 1:].clone() / (1.0*H8) * H
            if self.training:
                return f12s
            else:
                return [f12s[-1]], f12_init
        else:
            # Convert displacements from the (H8, W8) grid back to (H, W).
            f12[:, :1] = f12[:, :1].clone() / (1.0*W8) * W
            f12[:, 1:] = f12[:, 1:].clone() / (1.0*H8) * H

            f12 = F.interpolate(f12, size=(H, W), mode='bilinear')
            # print('wo!!')
            return f12, f12_init, error21

    def forward_pred(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames

        ``upsample`` and ``test_mode`` are accepted but not read here.

        Returns:
            ``(flow_predictions, flow_up, flow_init)``: the upsampled flow
            after every iteration, the last of them, and the ``flow_init``
            argument unchanged.
        """
        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network (the feature network is reused for this)
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.fnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            # NOTE(review): flow_init is added a second time on the first
            # iteration (it was already added above); left unchanged so the
            # results stay consistent with the trained weights — confirm.
            if itr == 0 and flow_init is not None:
                coords1 = coords1 + flow_init
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        return flow_predictions, flow_up, flow_init


In [None]:
#@title softsplat.py (fixing cupy)
%%writefile /content/AnimeInterp/models/softsplat.py
#!/usr/bin/env python
########################
# copy from soft splat #
########################
import torch

import cupy
import re

# CUDA source of the forward splatting kernel.  SIZE_k / VALUE_k / OFFSET_k
# are placeholders that cupy_kernel() expands before compilation.
# For every element of `input` (indexed with the sizes of `output`), the
# target position is the pixel position plus the flow; the value is
# atomically added to the four surrounding output pixels that lie inside the
# image, weighted by the bilinear weights.
kernel_Softsplat_updateOutput = '''
	extern "C" __global__ void kernel_Softsplat_updateOutput(
		const int n,
		const float* input,
		const float* flow,
		float* output
	) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
		const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
		const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output)                  ) % SIZE_1(output);
		const int intY = ( intIndex / SIZE_3(output)                                   ) % SIZE_2(output);
		const int intX = ( intIndex                                                    ) % SIZE_3(output);

		float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
		float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

		int intNorthwestX = (int) (floor(fltOutputX));
		int intNorthwestY = (int) (floor(fltOutputY));
		int intNortheastX = intNorthwestX + 1;
		int intNortheastY = intNorthwestY;
		int intSouthwestX = intNorthwestX;
		int intSouthwestY = intNorthwestY + 1;
		int intSoutheastX = intNorthwestX + 1;
		int intSoutheastY = intNorthwestY + 1;

		float fltNorthwest = ((float) (intSoutheastX) - fltOutputX   ) * ((float) (intSoutheastY) - fltOutputY   );
		float fltNortheast = (fltOutputX    - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY   );
		float fltSouthwest = ((float) (intNortheastX) - fltOutputX   ) * (fltOutputY    - (float) (intNortheastY));
		float fltSoutheast = (fltOutputX    - (float) (intNorthwestX)) * (fltOutputY    - (float) (intNorthwestY));

		if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) {
			atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest);
		}

		if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) {
			atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast);
		}

		if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) {
			atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest);
		}

		if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) {
			atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast);
		}
	} }
'''

# CUDA source of the backward kernel for the gradient w.r.t. `input`.
# For every element of `gradInput`, the gradient is gathered from the four
# in-bounds `gradOutput` pixels around the splat target, using the same
# bilinear weights as the forward kernel.  The `gradFlow` argument is part of
# the signature but is not accessed by this kernel.
kernel_Softsplat_updateGradInput = '''
	extern "C" __global__ void kernel_Softsplat_updateGradInput(
		const int n,
		const float* input,
		const float* flow,
		const float* gradOutput,
		float* gradInput,
		float* gradFlow
	) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
		const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput);
		const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput)                     ) % SIZE_1(gradInput);
		const int intY = ( intIndex / SIZE_3(gradInput)                                         ) % SIZE_2(gradInput);
		const int intX = ( intIndex                                                             ) % SIZE_3(gradInput);

		float fltGradInput = 0.0;

		float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
		float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

		int intNorthwestX = (int) (floor(fltOutputX));
		int intNorthwestY = (int) (floor(fltOutputY));
		int intNortheastX = intNorthwestX + 1;
		int intNortheastY = intNorthwestY;
		int intSouthwestX = intNorthwestX;
		int intSouthwestY = intNorthwestY + 1;
		int intSoutheastX = intNorthwestX + 1;
		int intSoutheastY = intNorthwestY + 1;

		float fltNorthwest = ((float) (intSoutheastX) - fltOutputX   ) * ((float) (intSoutheastY) - fltOutputY   );
		float fltNortheast = (fltOutputX    - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY   );
		float fltSouthwest = ((float) (intNortheastX) - fltOutputX   ) * (fltOutputY    - (float) (intNortheastY));
		float fltSoutheast = (fltOutputX    - (float) (intNorthwestX)) * (fltOutputY    - (float) (intNorthwestY));

		if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
			fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
		}

		if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
			fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
		}

		if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
			fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
		}

		if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
			fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
		}

		gradInput[intIndex] = fltGradInput;
	} }
'''

# CUDA source of the backward kernel for the gradient w.r.t. `flow`.
# For every element of `gradFlow` (channel 0 = x displacement, channel 1 =
# y displacement), the bilinear weights are differentiated with respect to
# that displacement and the products input * gradOutput * d(weight) are
# summed over all channels of `gradOutput` and the four in-bounds
# neighbours.  The `gradInput` argument is not accessed by this kernel.
kernel_Softsplat_updateGradFlow = '''
	extern "C" __global__ void kernel_Softsplat_updateGradFlow(
		const int n,
		const float* input,
		const float* flow,
		const float* gradOutput,
		float* gradInput,
		float* gradFlow
	) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
		float fltGradFlow = 0.0;

		const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow);
		const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow)                    ) % SIZE_1(gradFlow);
		const int intY = ( intIndex / SIZE_3(gradFlow)                                       ) % SIZE_2(gradFlow);
		const int intX = ( intIndex                                                          ) % SIZE_3(gradFlow);

		float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
		float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

		int intNorthwestX = (int) (floor(fltOutputX));
		int intNorthwestY = (int) (floor(fltOutputY));
		int intNortheastX = intNorthwestX + 1;
		int intNortheastY = intNorthwestY;
		int intSouthwestX = intNorthwestX;
		int intSouthwestY = intNorthwestY + 1;
		int intSoutheastX = intNorthwestX + 1;
		int intSoutheastY = intNorthwestY + 1;

		float fltNorthwest = 0.0;
		float fltNortheast = 0.0;
		float fltSouthwest = 0.0;
		float fltSoutheast = 0.0;

		if (intC == 0) {
			fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY   );
			fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY   );
			fltSouthwest = ((float) (-1.0)) * (fltOutputY    - (float) (intNortheastY));
			fltSoutheast = ((float) (+1.0)) * (fltOutputY    - (float) (intNorthwestY));

		} else if (intC == 1) {
			fltNorthwest = ((float) (intSoutheastX) - fltOutputX   ) * ((float) (-1.0));
			fltNortheast = (fltOutputX    - (float) (intSouthwestX)) * ((float) (-1.0));
			fltSouthwest = ((float) (intNortheastX) - fltOutputX   ) * ((float) (+1.0));
			fltSoutheast = (fltOutputX    - (float) (intNorthwestX)) * ((float) (+1.0));

		}

		for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) {
			float fltInput = VALUE_4(input, intN, intChannel, intY, intX);

			if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
				fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest;
			}

			if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
				fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast;
			}

			if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
				fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest;
			}

			if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
				fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast;
			}
		}

		gradFlow[intIndex] = fltGradFlow;
	} }
'''

def cupy_kernel(strFunction, objVariables):
	"""Specialise a kernel template for concrete tensors.

	The template is the module-level string named ``strFunction``.  Three
	kinds of placeholders are expanded, in this order:

	* ``SIZE_k(t)``           -> size of dimension ``k`` of tensor ``t``
	* ``OFFSET_n(t, i0, ..)`` -> flat element offset built from the first
	  ``n`` strides of ``t``
	* ``VALUE_n(t, i0, ..)``  -> ``t[<that offset>]``

	Args:
		strFunction: name of the module-level template string.
		objVariables: maps the tensor names used in the template to objects
			providing ``size()`` and ``stride()``.

	Returns:
		The kernel source with all placeholders replaced.
	"""
	strKernel = globals()[strFunction]

	def strFlatIndex(objMatch):
		# Shared by OFFSET_n and VALUE_n: returns the tensor name and the
		# "index * stride" sum addressing the requested element.
		intArgs = int(objMatch.group(2))
		strArgs = objMatch.group(4).split(',')

		strTensor = strArgs[0]
		intStrides = objVariables[strTensor].stride()
		strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]

		return strTensor, str.join('+', strIndex)
	# end

	# Raw strings keep the regex escapes (\( and \)) from being parsed as
	# (invalid) Python string escapes.
	while True:
		objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

		if objMatch is None:
			break
		# end

		intArg = int(objMatch.group(2))

		strTensor = objMatch.group(4)
		intSizes = objVariables[strTensor].size()

		strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
	# end

	while True:
		objMatch = re.search(r'(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)

		if objMatch is None:
			break
		# end

		strTensor, strOffset = strFlatIndex(objMatch)

		strKernel = strKernel.replace(objMatch.group(0), '(' + strOffset + ')')
	# end

	while True:
		objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)

		if objMatch is None:
			break
		# end

		strTensor, strOffset = strFlatIndex(objMatch)

		strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + strOffset + ']')
	# end

	return strKernel
# end

@cupy.memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
	"""Compile ``strKernel`` and return its function named ``strFunction``.

	The result is memoized per device by the ``cupy.memoize`` decorator.
	"""
	objModule = cupy.cuda.compile_with_cache(strKernel)
	return objModule.get_function(strFunction)
# end

class _FunctionSoftsplat(torch.autograd.Function):
	"""Autograd function that forward-splats ``input`` along ``flow``.

	Both directions are implemented by the CUDA kernels defined in this
	module; CPU tensors raise ``NotImplementedError``.  Note that the first
	parameter of ``forward``/``backward`` is named ``self`` but is used as the
	autograd context (``save_for_backward``, ``saved_tensors``,
	``needs_input_grad``).
	"""
	@staticmethod
	def forward(self, input, flow):
		"""Splat ``input`` (N, C, H, W) along ``flow`` (N, 2, H, W).

		Both tensors must be contiguous.  Returns a tensor shaped like
		``input``.
		"""
		self.save_for_backward(input, flow)

		intSamples = input.shape[0]
		intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3]
		intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3]

		assert(intFlowDepth == 2)
		assert(intInputHeight == intFlowHeight)
		assert(intInputWidth == intFlowWidth)

		assert(input.is_contiguous() == True)
		assert(flow.is_contiguous() == True)

		# The kernel accumulates with atomicAdd, so the output starts at zero.
		output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])

		if input.is_cuda == True:
			n = output.nelement()
			# Launch with 512 threads per block and enough blocks to cover n.
			cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', {
				'input': input,
				'flow': flow,
				'output': output
			}))(
				grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
				block=tuple([ 512, 1, 1 ]),
				args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ]
			)

		elif input.is_cuda == False:
			raise NotImplementedError()

		# end

		return output
	# end

	@staticmethod
	def backward(self, gradOutput):
		"""Return ``(gradInput, gradFlow)``; an entry is ``None`` when the
		corresponding input does not need a gradient."""
		input, flow = self.saved_tensors

		intSamples = input.shape[0]
		intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3]
		intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3]

		assert(intFlowDepth == 2)
		assert(intInputHeight == intFlowHeight)
		assert(intInputWidth == intFlowWidth)

		# assert(gradOutput.is_contiguous() == True)

		gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) if self.needs_input_grad[0] == True else None
		gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ]) if self.needs_input_grad[1] == True else None

		if input.is_cuda == True:
			if gradInput is not None:
				n = gradInput.nelement()
				# The gradFlow pointer is passed as None: this kernel does not
				# access it.
				cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', {
					'input': input,
					'flow': flow,
					'gradOutput': gradOutput,
					'gradInput': gradInput,
					'gradFlow': gradFlow
				}))(
					grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
					block=tuple([ 512, 1, 1 ]),
					args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ]
				)
			# end

			if gradFlow is not None:
				n = gradFlow.nelement()
				# The gradInput pointer is passed as None: this kernel does not
				# access it.
				cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', {
					'input': input,
					'flow': flow,
					'gradOutput': gradOutput,
					'gradInput': gradInput,
					'gradFlow': gradFlow
				}))(
					grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
					block=tuple([ 512, 1, 1 ]),
					args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ]
				)
			# end

		elif input.is_cuda == False:
			raise NotImplementedError()

		# end

		return gradInput, gradFlow
	# end
# end

def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType):
	"""Forward-splat ``tenInput`` along ``tenFlow`` with the given mode.

	Args:
		tenInput: tensor to splat, shape (N, C, H, W).
		tenFlow: flow tensor, shape (N, 2, H, W).
		tenMetric: single-channel weighting tensor or ``None``; it is used by
			the 'linear' and 'softmax' modes.
		strType: one of
			'summation' - plain accumulation;
			'average'   - accumulation divided by the splatted count;
			'linear'    - weighted by ``tenMetric``, then normalised;
			'softmax'   - weighted by ``exp(tenMetric)``, then normalised;
			'seperate'  - no channel is appended here; the splatted result is
			split into all-but-last channels and the last channel (plus a
			small epsilon), returned as a tuple without dividing.

	Returns:
		The splatted tensor, or a ``(values, normaliser)`` tuple for
		'seperate'.
	"""
	assert(tenMetric is None or tenMetric.shape[1] == 1)
	# 'seperate' is handled below, so it must pass validation as well;
	# without it that branch could never be reached.
	assert(strType in ['summation', 'average', 'linear', 'softmax', 'seperate'])

	# The normalising modes splat an extra channel that accumulates the weights.
	if strType == 'average':
		tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1)

	elif strType == 'linear':
		tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1)

	elif strType == 'softmax':
		tenInput = torch.cat([ tenInput * tenMetric.exp(), tenMetric.exp() ], 1)

	# end

	tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow)

	if strType == 'seperate':
		return tenOutput[:, :-1, :, :], tenOutput[:, -1:, :, :] + 0.0000001
	elif strType != 'summation':
		# Divide by the accumulated weights; the epsilon avoids division by zero.
		tenOutput = tenOutput[:, :-1, :, :] / (tenOutput[:, -1:, :, :] + 0.0000001)

	# end

	return tenOutput
# end

class ModuleSoftsplat(torch.nn.Module):
	"""``nn.Module`` wrapper around ``FunctionSoftsplat`` with a fixed mode."""
	def __init__(self, strType):
		super().__init__()

		# Splatting mode handed to FunctionSoftsplat on every call.
		self.strType = strType
	# end

	def forward(self, tenInput, tenFlow, tenMetric=None):
		strMode = self.strType
		return FunctionSoftsplat(tenInput, tenFlow, tenMetric, strMode)
	# end
# end

In [None]:
#@title AnimeInterp.py (removing flow init)
%%writefile /content/AnimeInterp/models/AnimeInterp.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
import argparse

from .rfr_model.rfr_new import RFR as RFR
from .softsplat import ModuleSoftsplat as ForwardWarp
from .GridNet import GridNet




class FeatureExtractor(nn.Module):
    """Three-level convolutional feature pyramid.

    Six 3x3 convolutions, each followed by its own PReLU.  The third and the
    fifth convolution use stride 2, so the three returned feature maps have
    32, 64 and 96 channels at successively halved resolution.  The ``path``
    argument is accepted but not used.
    """
    def __init__(self, path='./network-default.pytorch'):
        super().__init__()
        # (in_channels, out_channels, stride) of conv1 .. conv6
        layer_specs = [
            (3, 32, 1),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
            (64, 96, 2),
            (96, 96, 1),
        ]
        for idx, (c_in, c_out, stride) in enumerate(layer_specs, start=1):
            setattr(self, f'conv{idx}', nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1))
            setattr(self, f'prelu{idx}', nn.PReLU())

    def forward(self, x):
        """Return the feature maps of the three pyramid levels."""
        level1 = self.prelu2(self.conv2(self.prelu1(self.conv1(x))))
        level2 = self.prelu4(self.conv4(self.prelu3(self.conv3(level1))))
        level3 = self.prelu6(self.conv6(self.prelu5(self.conv5(level2))))

        return level1, level2, level3


class AnimeInterp(nn.Module):
    """Frame interpolation model.

    Bidirectional flow from ``RFR`` is scaled to the target time, the two
    input frames and their feature pyramids are forward-warped (softmax
    splatting module in 'summation' mode, then normalised by the splatted
    count) and a ``GridNet`` synthesises the intermediate frame.
    """
    def __init__(self, path='models/raft_model/models/rfr_sintel_latest.pth-no-zip', args=None):
        """Build the sub-networks and optionally load flow-network weights.

        Args:
            path: checkpoint for the flow network, or ``None`` to skip
                loading.
            args: accepted for interface compatibility; it is not used, a
                fresh namespace is always built for ``RFR``.
        """
        super(AnimeInterp, self).__init__()

        args = argparse.Namespace()
        args.small = False
        args.mixed_precision = False
        # args.requires_sq_flow = False

        self.flownet = RFR(args)
        self.feat_ext = FeatureExtractor()
        self.fwarp = ForwardWarp('summation')
        self.synnet = GridNet(6, 64, 128, 96*2, 3)

        if path is not None:
            dict1 = torch.load(path)
            dict2 = dict()
            for key in dict1:
                # Drop the 7-character 'module.' prefix, but only when it is
                # really there, so prefix-free checkpoints are not mangled.
                stripped = key[7:] if key.startswith('module.') else key
                dict2[stripped] = dict1[key]
            self.flownet.load_state_dict(dict2, strict=False)

    def dflow(self, flo, target):
        """Resize ``flo`` to the spatial size of ``target`` and rescale its
        x/y values by the width/height ratio."""
        tmp = F.interpolate(flo, target.size()[2:4])
        tmp[:, :1] = tmp[:, :1].clone() * tmp.size()[3] / flo.size()[3]
        tmp[:, 1:] = tmp[:, 1:].clone() * tmp.size()[2] / flo.size()[2]

        return tmp

    def forward(self, I1, I2, t):
        """Interpolate the frame at time ``t`` between ``I1`` and ``I2``.

        Returns:
            ``(It_warp, F12, F21, F12in, F21in)``: the synthesised frame, the
            two flows and the third-from-last values returned by the flow
            network for each direction.
        """
        # I1 = I1[:, [2, 1, 0]]
        # I2 = I2[:, [2, 1, 0]]

        # extract features
        I1o = (I1 - 0.5) / 0.5
        I2o = (I2 - 0.5) / 0.5

        feat11, feat12, feat13 = self.feat_ext(I1o)
        feat21, feat22, feat23 = self.feat_ext(I2o)

        # calculate motion

        # with torch.no_grad():
        # self.flownet.eval()
        F12, F12in, err12, = self.flownet(I1o, I2o, iters=12, test_mode=False, flow_init=None)
        F21, F21in, err21, = self.flownet(I2o, I1o, iters=12, test_mode=False, flow_init=None)

        F1t = t * F12
        F2t = (1-t) * F21

        # Flow resized to each pyramid level.
        F1td = self.dflow(F1t, feat11)
        F2td = self.dflow(F2t, feat21)

        F1tdd = self.dflow(F1t, feat12)
        F2tdd = self.dflow(F2t, feat22)

        F1tddd = self.dflow(F1t, feat13)
        F2tddd = self.dflow(F2t, feat23)

        # warping
        # All-ones tensors, created on the input's device instead of a
        # hard-coded CUDA device; splatting them counts the contributions
        # each target pixel receives.
        one0 = torch.ones(I1.size(), requires_grad=True, device=I1.device)
        one1 = torch.ones(feat11.size(), requires_grad=True, device=I1.device)
        one2 = torch.ones(feat12.size(), requires_grad=True, device=I1.device)
        one3 = torch.ones(feat13.size(), requires_grad=True, device=I1.device)

        I1t = self.fwarp(I1, F1t)
        feat1t1 = self.fwarp(feat11, F1td)
        feat1t2 = self.fwarp(feat12, F1tdd)
        feat1t3 = self.fwarp(feat13, F1tddd)

        I2t = self.fwarp(I2, F2t)
        feat2t1 = self.fwarp(feat21, F2td)
        feat2t2 = self.fwarp(feat22, F2tdd)
        feat2t3 = self.fwarp(feat23, F2tddd)

        norm1 = self.fwarp(one0, F1t.clone())
        norm1t1 = self.fwarp(one1, F1td.clone())
        norm1t2 = self.fwarp(one2, F1tdd.clone())
        norm1t3 = self.fwarp(one3, F1tddd.clone())

        norm2 = self.fwarp(one0, F2t.clone())
        norm2t1 = self.fwarp(one1, F2td.clone())
        norm2t2 = self.fwarp(one2, F2tdd.clone())
        norm2t3 = self.fwarp(one3, F2tddd.clone())

        # normalize
        # Note: normalize in this way benefit training than the original "linear"
        # Only positions that received at least some contribution are divided.
        warped_and_norm = (
            (I1t, norm1), (I2t, norm2),
            (feat1t1, norm1t1), (feat2t1, norm2t1),
            (feat1t2, norm1t2), (feat2t2, norm2t2),
            (feat1t3, norm1t3), (feat2t3, norm2t3),
        )
        for warped, norm in warped_and_norm:
            warped[norm > 0] = warped.clone()[norm > 0] / norm[norm > 0]

        # synthesis
        It_warp = self.synnet(torch.cat([I1t, I2t], dim=1), torch.cat([feat1t1, feat2t1], dim=1), torch.cat([feat1t2, feat2t2], dim=1), torch.cat([feat1t3, feat2t3], dim=1))
        return It_warp, F12, F21, F12in, F21in

In [None]:
#@title gpu inference
%cd /content/AnimeInterp
from types import FrameType
from PIL import Image
import models
import argparse
import torch
import torchvision.transforms as TF
import torch.nn as nn
import os
import numpy as np
import cv2
import warnings
import numpy
from tqdm import tqdm
import glob
warnings.filterwarnings("ignore")

frames_dir = "/content/data" #@param
# Every PNG below frames_dir (searched recursively), in sorted order.
files = sorted(glob.glob(frames_dir + '/**/*.png', recursive=True))
# Each file is used as the first frame of a pair, so the last one (which has
# no successor) is dropped.
del files[-1]

# https://github.com/lisiyao21/AnimeInterp/blob/49b1ea2ee0d6637292adbb157f0ba6b0e8cadb0d/datas/AniTriplet.py#L34
def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0):
  """Load an image file as an RGB PIL image.

  Args:
    path: image file to open.
    cropArea: optional box passed to ``Image.crop`` (applied after resizing).
    resizeDim: optional size passed to ``Image.resize``.
    frameFlip: when truthy, mirror the image left-to-right.

  Returns:
    The processed image converted to 'RGB'.
  """
  with open(path, 'rb') as f:
    img = Image.open(f)
    if resizeDim is not None:
      # Image.ANTIALIAS was an alias of LANCZOS and no longer exists in
      # Pillow >= 10; LANCZOS is the same filter under its current name.
      img = img.resize(resizeDim, Image.LANCZOS)
    if cropArea is not None:
      img = img.crop(cropArea)
    if frameFlip:
      img = img.transpose(Image.FLIP_LEFT_RIGHT)
    # convert() runs inside the ``with`` so the pixel data is read while the
    # file is still open.
    return img.convert('RGB')

# https://github.com/lisiyao21/AnimeInterp/issues/8
# All Normalize transforms below use mean 0 and std 1, i.e. they leave the
# values unchanged.  `trans` and `revtrans` are defined but not referenced by
# the inference loop in this cell.
normalize1 = TF.Normalize([0., 0., 0.], [1.0, 1.0, 1.0])
normalize2 = TF.Normalize([0, 0, 0], [1, 1, 1])
trans = TF.Compose([TF.ToTensor(), normalize1, normalize2, ])
revmean = [-x for x in [0., 0., 0.]]
revstd = [1.0 / x for x in [1, 1, 1]]
revnormalize1 = TF.Normalize([0.0, 0.0, 0.0], revstd)
revnormalize2 = TF.Normalize(revmean, [1.0, 1.0, 1.0])
revNormalize = TF.Compose([revnormalize1, revnormalize2])
revtrans = TF.Compose([revnormalize1, revnormalize2, TF.ToPILImage()])
to_img = TF.ToPILImage()

# Passing None as the path skips the separate flow-network checkpoint; all
# weights come from the full checkpoint loaded below.
#model = getattr(models, 'AnimeInterpNoCupy')(None).cuda()
model = getattr(models, 'AnimeInterp')(None).cuda()
# presumably wrapped in DataParallel so the parameter names match the
# checkpoint's keys — verify against the checkpoint
model = nn.DataParallel(model)
dict1 = torch.load("/content/anime_interp_full.ckpt")
model.load_state_dict(dict1['model_state_dict'], strict=False)
model.eval()

# Write one interpolated frame (t = 0.5) for every pair of consecutive frames.
# input_frame is the number used in the file names of the current pair.
input_frame = 1
for f in tqdm(files):
  with torch.no_grad():
    filename_frame_1 = f
    # NOTE(review): the second frame is located by counter, not taken from
    # `files`; this assumes frames_dir contains exactly 00001.png, 00002.png,
    # ... without gaps or extra PNGs (e.g. "_0.5" outputs of an earlier run)
    # — confirm before re-running this cell.
    filename_frame_2 = os.path.join(frames_dir, f'{input_frame+1:0>5d}.png')
    output_frame_file_path = os.path.join(frames_dir, f"{input_frame:0>5d}_0.5.png")
    frame1 = _pil_loader(filename_frame_1)
    frame2 = _pil_loader(filename_frame_2)
    transform1 = TF.Compose([TF.ToTensor()])
    # Add the batch dimension expected by the model.
    frame1 = transform1(frame1).unsqueeze(0)
    frame2 = transform1(frame2).unsqueeze(0)
    outputs = model(frame1.cuda(), frame2.cuda(), 0.5)
    It_warp = outputs[0]
    # Clamp to [0, 1] before converting the first batch element to an image.
    to_img(revNormalize(It_warp.cpu()[0]).clamp(0.0, 1.0)).save(output_frame_file_path)
    input_frame += 1

In [None]:
# img -> video with ffmpeg
# customize the ffmpeg command if needed
# this is a very simple ffmpeg command, currently only creating video without sound
%cd /content/data
import cv2
video = cv2.VideoCapture("/content/test.mkv");
# One frame was added between every pair of source frames, so the output is
# encoded at twice the source frame rate.
fps = 2*video.get(cv2.CAP_PROP_FPS)
# The glob matches both the extracted frames and the "_0.5" frames;
# presumably they are consumed in name order — verify the frame order in the
# resulting video.
%shell ffmpeg -y -r {fps} -f image2 -pattern_type glob -i '*.png' -crf 18 "/content/output.mp4"

In [None]:
# copy video back
# Requires Google Drive to be mounted at /content/drive.
!cp /content/output.mp4 /content/drive/MyDrive/output.mp4

In [None]:
# delete output if needed
# Removes all extracted and interpolated frames plus the encoded video, then
# recreates an empty frame directory for the next run.
%cd /content/
!sudo rm -rf /content/data
!sudo rm -rf /content/output.mp4
!mkdir /content/data