In [83]:
import torch
from torch import nn
from torchvision.models import resnet50
from torchvision.models.detection import faster_rcnn
import torch.nn.functional as F
import sys
sys.path.append('core')

import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder

In [68]:
# Use the GPU when one is available; otherwise fall back to the CPU so that
# Restart & Run All still works (slowly) on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [69]:
# ImageNet-pretrained ResNet-50 backbone, moved to DEVICE.
# NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of
# `weights=`; kept as-is to preserve the exact weights loaded.
resnet = resnet50(pretrained=True)
resnet = resnet.to(DEVICE)



In [77]:
# Inspect ResNet-50's top-level submodules; FeatureExtractor below keeps all
# of them except the last two.
for name in (child_name for child_name, _ in resnet.named_children()):
    print(name)

conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc


In [70]:
class FeatureExtractor(nn.Module):
    """Truncate a backbone network by dropping its last two top-level children.

    For the torchvision ResNet listed above, the dropped children are
    ``avgpool`` and ``fc``, so the module returns the final convolutional
    feature map instead of class logits.

    Args:
        model: module whose top-level children are run in order. The kept
            children are shared with ``model``, not copied.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        backbone_children = list(model.children())
        # Attribute name kept as `feature` so state_dict keys stay the same.
        self.feature = nn.Sequential(*backbone_children[:-2])

    def forward(self, x):
        """Run ``x`` through the truncated backbone and return its output."""
        return self.feature(x)

In [17]:
def flow_field(img1, img2, model, iters=20):
    """Estimate optical flow from ``img1`` to ``img2`` with a RAFT-style model.

    Args:
        img1, img2: image batches passed straight to ``model``.
        model: callable invoked as ``model(img1, img2, iters=..., test_mode=True)``
            and expected to return a pair whose second element is the
            full-resolution flow.
        iters: number of refinement iterations forwarded to ``model``.
            Previously this argument was ignored and 20 was always used.

    Returns:
        The second element of the model's output. Presumably shaped
        (N, 2, H, W); for a single image pair that is (1, 2, H, W). TODO confirm.
    """
    _, flow_up = model(img1, img2, iters=iters, test_mode=True)
    return flow_up

In [None]:
def load_image(imfile):
    """Load an image file as a batch of one float tensor on ``DEVICE``.

    Args:
        imfile: path or file object accepted by ``PIL.Image.open``.

    Returns:
        Tensor of shape (1, 3, H, W), dtype float32, holding the raw 0-255
        pixel values (no normalisation), on ``DEVICE``.
    """
    # The context manager closes the underlying file handle. convert('RGB')
    # turns grayscale / palette / RGBA images into 3-channel arrays, so the
    # HWC -> CHW permute below no longer fails on 2-D (grayscale) input and
    # never yields 4 channels.
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB')).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)

In [75]:
def feature_warp(f_k: torch.Tensor, flow: torch.Tensor, scale_flow: bool = False) -> torch.Tensor:
    """Warp the feature map ``f_k`` along ``flow`` with bilinear sampling.

    The flow field is first resized to the feature resolution. Each output
    location (y, x) is then ``f_k`` bilinearly sampled at
    (y + flow_y, x + flow_x). Neighbours that fall outside ``f_k`` contribute
    zero.

    Args:
        f_k: feature map of shape (N, C, h, w).
        flow: displacement field of shape (N, 2, H, W). Channel 0 is the x
            (column) displacement and channel 1 the y (row) displacement.
        scale_flow: if True, also multiply the displacements by
            (w / W, h / H), so that a flow measured in pixels at (H, W)
            becomes a displacement in feature-map cells. The default (False)
            keeps the original behaviour of only resizing the field.
            NOTE(review): for image-resolution flow (e.g. RAFT output),
            True is probably what you want — confirm.

    Returns:
        Warped tensor with the same shape, dtype and device as ``f_k``.
    """
    _, _, h, w = f_k.shape
    flow_h, flow_w = flow.shape[-2:]
    flo = F.interpolate(flow, size=(h, w), mode='bilinear', align_corners=False)
    flo = flo.to(device=f_k.device, dtype=f_k.dtype)
    if scale_flow:
        flo = flo * flo.new_tensor([w / flow_w, h / flow_h]).view(1, 2, 1, 1)

    # Absolute sampling positions in feature-map pixel coordinates, (N, h, w).
    xs = torch.arange(w, device=f_k.device, dtype=f_k.dtype).view(1, 1, w)
    ys = torch.arange(h, device=f_k.device, dtype=f_k.dtype).view(1, h, 1)
    x = xs + flo[:, 0]
    y = ys + flo[:, 1]

    # Normalise to [-1, 1] with the align_corners=False convention, where pixel
    # index p maps to (2p + 1) / size - 1. This is exact for any size >= 1,
    # unlike a (size - 1) denominator, which divides by zero when size == 1.
    grid = torch.stack([(2 * x + 1) / w - 1, (2 * y + 1) / h - 1], dim=-1)

    # Bilinear grid_sample weights the four integer neighbours by
    # (1-dy)(1-dx), (1-dy)dx, dy(1-dx) and dy*dx, and zero-pads out-of-range
    # neighbours. The previous per-pixel loop assigned each corner the weight
    # of the diagonally opposite corner.
    return F.grid_sample(f_k, grid, mode='bilinear', padding_mode='zeros', align_corners=False)

In [76]:
# Smoke test: warp a ResNet-50-sized feature map (2048 x 7 x 7, presumably a
# 224 x 224 input at the backbone's output stride) using a random
# 224 x 224 flow field.
torch.manual_seed(0)  # make the random inputs reproducible across re-runs
f_k = torch.randn(64, 2048, 7, 7).to(DEVICE)
flow = torch.randn(64, 2, 224, 224).to(DEVICE)
f_i = feature_warp(f_k, flow)
assert f_i.shape == f_k.shape, f"unexpected warped shape {tuple(f_i.shape)}"

In [60]:
# Shape of the warped feature map (expected to match f_k).
f_i.size()

torch.Size([64, 2048, 7, 7])

In [None]:
# Build the RAFT optical-flow model. `args` was never defined in this
# notebook, so the original cell raised NameError on a fresh kernel; the
# namespace is now constructed explicitly.
# NOTE(review): the field names mirror RAFT's demo.py command-line options —
# confirm against the RAFT version in `core/`, and point `model` at your
# checkpoint.
args = argparse.Namespace(
    model='models/raft-things.pth',
    small=False,
    mixed_precision=False,
    alternate_corr=False,
)
# Wrapped in DataParallel before loading, presumably because the checkpoint's
# keys carry the `module.` prefix; then unwrapped for single-device use.
model = torch.nn.DataParallel(RAFT(args))
# map_location lets a GPU-saved checkpoint load when DEVICE is the CPU.
model.load_state_dict(torch.load(args.model, map_location=DEVICE))

model = model.module
model.to(DEVICE)
model.eval();  # trailing `;` suppresses the large module repr in the notebook

In [82]:
# Scratch check: a per-channel weight of shape (N, C, 1, 1) broadcasts over
# the spatial dims of an (N, C, H, W) feature map.
w = torch.randn(10, 2048, 1, 1)
f = torch.randn(10, 2048, 7, 7)
a = torch.mul(f, w)

In [80]:
# Broadcast result keeps the feature map's full (N, C, H, W) shape.
a.size()

torch.Size([10, 2048, 7, 7])