style
zhanxiaohang committed Apr 29, 2019
1 parent d9389b5 commit e3955ac
Showing 10 changed files with 266 additions and 155 deletions.
31 changes: 23 additions & 8 deletions dataset.py
@@ -1,14 +1,16 @@
import sys
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
import io
from PIL import Image

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

from utils.flowlib import read_flo_file
from utils import image_crop, image_resize, image_flow_crop, image_flow_resize, flow_sampler, image_flow_aug, flow_aug

class ColorAugmentation(object):

def __init__(self, eig_vec=None, eig_val=None):
if eig_vec == None:
eig_vec = torch.Tensor([
@@ -28,6 +30,7 @@ def __call__(self, tensor):
tensor = tensor + quatity.view(3, 1, 1)
return tensor


def pil_loader(img_str, ch):
buff = io.BytesIO(img_str)
if ch == 1:
@@ -37,6 +40,7 @@ def pil_loader(img_str, ch):
img = img.convert('RGB')
return img


def pil_loader_str(img_str, ch):
if ch == 1:
return Image.open(img_str)
@@ -45,7 +49,9 @@ def pil_loader_str(img_str, ch):
img = img.convert('RGB')
return img


class ImageFlowDataset(Dataset):

def __init__(self, meta_file, config, phase):
self.img_transform = transforms.Compose([
transforms.Normalize(config['data_mean'], config['data_div'])
@@ -111,18 +117,22 @@ def __getitem__(self, idx):

## resize
if self.short_size is not None or self.long_size is not None:
img1, img2, flow, ratio = image_flow_resize(img1, img2, flow, short_size=self.short_size, long_size=self.long_size)
img1, img2, flow, ratio = image_flow_resize(
img1, img2, flow, short_size=self.short_size,
long_size=self.long_size)

## crop
if self.crop_size is not None:
img1, img2, flow, offset = image_flow_crop(img1, img2, flow, self.crop_size, self.phase)
img1, img2, flow, offset = image_flow_crop(
img1, img2, flow, self.crop_size, self.phase)

## augmentation
if self.phase == 'train':
# image flow aug
img1, img2, flow = image_flow_aug(img1, img2, flow, flip_horizon=self.aug_flip)
# flow aug
flow = flow_aug(flow, reverse=self.aug_reverse, scale=self.aug_scale, rotate=self.aug_rotate)
flow = flow_aug(flow, reverse=self.aug_reverse,
scale=self.aug_scale, rotate=self.aug_rotate)

## transform
img1 = torch.from_numpy(np.array(img1).astype(np.float32).transpose((2,0,1)))
@@ -131,14 +141,19 @@ def __getitem__(self, idx):
img2 = self.img_transform(img2)

## sparse sampling
sparse_flow, mask = flow_sampler(flow, strategy=self.sample_strategy, bg_ratio=self.sample_bg_ratio, nms_ks=self.nms_ks, max_num_guide=self.max_num_guide) # (h,w,2), (h,w,2)
sparse_flow, mask = flow_sampler(
flow, strategy=self.sample_strategy,
bg_ratio=self.sample_bg_ratio, nms_ks=self.nms_ks,
max_num_guide=self.max_num_guide) # (h,w,2), (h,w,2)

flow = torch.from_numpy(flow.transpose((2, 0, 1)))
sparse_flow = torch.from_numpy(sparse_flow.transpose((2, 0, 1)))
mask = torch.from_numpy(mask.transpose((2, 0, 1)).astype(np.float32))
return img1, sparse_flow, mask, flow, img2


class ImageDataset(Dataset):

def __init__(self, meta_file, config):
self.img_transform = transforms.Compose([
transforms.Normalize(config['data_mean'], config['data_div'])
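Note: the __getitem__ hunk above converts HWC numpy arrays into CHW float tensors before normalization and sparse sampling. A minimal standalone sketch of that conversion follows (not part of the commit; the mean/std values and array shapes are placeholders for whatever config['data_mean'] / config['data_div'] hold):

# --- illustrative sketch, not part of the commit ---
import numpy as np
import torch
import torchvision.transforms as transforms

# Placeholder statistics; the dataset reads these from config['data_mean'] / config['data_div'].
img_transform = transforms.Compose([
    transforms.Normalize(mean=[123.675, 116.28, 103.53],
                         std=[58.395, 57.12, 57.375])
])

img = np.zeros((256, 448, 3), dtype=np.uint8)      # HWC uint8 image
flow = np.zeros((256, 448, 2), dtype=np.float32)   # HWC flow field (u, v)

# HWC -> CHW float tensor, then channel-wise normalization, mirroring the hunk above
img_t = img_transform(torch.from_numpy(img.astype(np.float32).transpose((2, 0, 1))))
flow_t = torch.from_numpy(flow.transpose((2, 0, 1)))  # (2, H, W)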
159 changes: 80 additions & 79 deletions demos/demo_annot.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion main.py
@@ -1,7 +1,7 @@
import multiprocessing as mp
import argparse
import os
import yaml
import multiprocessing as mp

from utils import dist_init
from trainer import Trainer
12 changes: 9 additions & 3 deletions models/cmp.py
@@ -7,6 +7,7 @@
from . import SingleStageModel

class CMP(SingleStageModel):

def __init__(self, params, dist_model=False):
super(CMP, self).__init__(params, dist_model)
model_params = params['module']
@@ -17,7 +18,8 @@ def __init__(self, params, dist_model=False):
elif model_params['flow_criterion'] == 'L2':
self.flow_criterion = nn.MSELoss()
elif model_params['flow_criterion'] == 'DiscreteLoss':
self.flow_criterion = losses.DiscreteLoss(nbins=model_params['nbins'], fmax=model_params['fmax'])
self.flow_criterion = losses.DiscreteLoss(
nbins=model_params['nbins'], fmax=model_params['fmax'])
else:
raise Exception("No such flow loss: {}".format(model_params['flow_criterion']))

@@ -33,7 +35,9 @@ def eval(self, ret_loss=True):
else:
self.flow = cmp_output
if self.flow.shape[2] != self.image_input.shape[2]:
self.flow = nn.functional.interpolate(self.flow, size=self.image_input.shape[2:4], mode="bilinear", align_corners=True)
self.flow = nn.functional.interpolate(
self.flow, size=self.image_input.shape[2:4],
mode="bilinear", align_corners=True)

ret_tensors = {
'flow_tensors': [self.flow, self.flow_target],
@@ -42,7 +46,9 @@ def eval(self, ret_loss=True):

if ret_loss:
if cmp_output.shape[2] != self.flow_target.shape[2]:
cmp_output = nn.functional.interpolate(cmp_output, size=self.flow_target.shape[2:4], mode="bilinear", align_corners=True)
cmp_output = nn.functional.interpolate(
cmp_output, size=self.flow_target.shape[2:4],
mode="bilinear", align_corners=True)
loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size
return ret_tensors, {'loss_flow': loss_flow}
else:
11 changes: 8 additions & 3 deletions models/modules/cmp.py
@@ -3,6 +3,7 @@
import models

class CMP(nn.Module):

def __init__(self, params):
super(CMP, self).__init__()
img_enc_dim = params['img_enc_dim']
@@ -14,9 +15,13 @@ def __init__(self, params):
if self.skip_layer:
assert params['flow_decoder'] == "MotionDecoderSkipLayer"

self.image_encoder = models.backbone.__dict__[params['image_encoder']](img_enc_dim, pretrained)
self.flow_encoder = models.modules.__dict__[params['sparse_encoder']](sparse_enc_dim)
self.flow_decoder = models.modules.__dict__[params['flow_decoder']](input_dim=img_enc_dim+sparse_enc_dim, output_dim=output_dim, combo=decoder_combo)
self.image_encoder = models.backbone.__dict__[params['image_encoder']](
img_enc_dim, pretrained)
self.flow_encoder = models.modules.__dict__[params['sparse_encoder']](
sparse_enc_dim)
self.flow_decoder = models.modules.__dict__[params['flow_decoder']](
input_dim=img_enc_dim+sparse_enc_dim,
output_dim=output_dim, combo=decoder_combo)

def forward(self, image, sparse):
sparse_enc = self.flow_encoder(sparse)
68 changes: 51 additions & 17 deletions models/modules/decoder.py
@@ -3,6 +3,7 @@
import math

class MotionDecoderPlain(nn.Module):

def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4]):
super(MotionDecoderPlain, self).__init__()
BN = nn.BatchNorm2d
@@ -72,20 +73,28 @@ def forward(self, x):
x1 = self.decoder1(x)
cat_list.append(x1)
if 2 in self.combo:
x2 = nn.functional.interpolate(self.decoder2(x), size=(x.size(2), x.size(3)), mode="bilinear", align_corners=True)
x2 = nn.functional.interpolate(
self.decoder2(x), size=(x.size(2), x.size(3)),
mode="bilinear", align_corners=True)
cat_list.append(x2)
if 4 in self.combo:
x4 = nn.functional.interpolate(self.decoder4(x), size=(x.size(2), x.size(3)), mode="bilinear", align_corners=True)
x4 = nn.functional.interpolate(
self.decoder4(x), size=(x.size(2), x.size(3)),
mode="bilinear", align_corners=True)
cat_list.append(x4)
if 8 in self.combo:
x8 = nn.functional.interpolate(self.decoder8(x), size=(x.size(2), x.size(3)), mode="bilinear", align_corners=True)
x8 = nn.functional.interpolate(
self.decoder8(x), size=(x.size(2), x.size(3)),
mode="bilinear", align_corners=True)
cat_list.append(x8)

cat = torch.cat(cat_list, dim=1)
flow = self.head(cat)
return flow


class MotionDecoderSkipLayer(nn.Module):

def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]):
super(MotionDecoderSkipLayer, self).__init__()

@@ -180,22 +189,34 @@ def forward(self, x, skip_feat):
layer1, layer2, layer4 = skip_feat

x1 = self.decoder1(x)
x2 = nn.functional.interpolate(self.decoder2(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True)
x4 = nn.functional.interpolate(self.decoder4(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True)
x8 = nn.functional.interpolate(self.decoder8(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True)
x2 = nn.functional.interpolate(
self.decoder2(x), size=(x1.size(2), x1.size(3)),
mode="bilinear", align_corners=True)
x4 = nn.functional.interpolate(
self.decoder4(x), size=(x1.size(2), x1.size(3)),
mode="bilinear", align_corners=True)
x8 = nn.functional.interpolate(
self.decoder8(x), size=(x1.size(2), x1.size(3)),
mode="bilinear", align_corners=True)
cat = torch.cat([x1, x2, x4, x8], dim=1)
f8 = self.fusion8(cat)

f8_up = nn.functional.interpolate(f8, size=(layer4.size(2), layer4.size(3)), mode="bilinear", align_corners=True)
f8_up = nn.functional.interpolate(
f8, size=(layer4.size(2), layer4.size(3)),
mode="bilinear", align_corners=True)
f4 = self.fusion4(torch.cat([f8_up, self.skipconv4(layer4)], dim=1))

f4_up = nn.functional.interpolate(f4, size=(layer2.size(2), layer2.size(3)), mode="bilinear", align_corners=True)
f4_up = nn.functional.interpolate(
f4, size=(layer2.size(2), layer2.size(3)),
mode="bilinear", align_corners=True)
f2 = self.fusion2(torch.cat([f4_up, self.skipconv2(layer2)], dim=1))

flow = self.head(f2)
return flow


class MotionDecoderFlowNet(nn.Module):

def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]):
super(MotionDecoderFlowNet, self).__init__()
global BN
@@ -260,9 +281,12 @@ def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]):
self.predict_flow2 = predict_flow(192 + output_dim, output_dim)
self.predict_flow1 = predict_flow(67 + output_dim, output_dim)

self.upsampled_flow8_to_4 = nn.ConvTranspose2d(output_dim, output_dim, 4, 2, 1, bias=False)
self.upsampled_flow4_to_2 = nn.ConvTranspose2d(output_dim, output_dim, 4, 2, 1, bias=False)
self.upsampled_flow2_to_1 = nn.ConvTranspose2d(output_dim, output_dim, 4, 2, 1, bias=False)
self.upsampled_flow8_to_4 = nn.ConvTranspose2d(
output_dim, output_dim, 4, 2, 1, bias=False)
self.upsampled_flow4_to_2 = nn.ConvTranspose2d(
output_dim, output_dim, 4, 2, 1, bias=False)
self.upsampled_flow2_to_1 = nn.ConvTranspose2d(
output_dim, output_dim, 4, 2, 1, bias=False)

self.deconv8 = deconv(256, 128)
self.deconv4 = deconv(384 + output_dim, 128)
@@ -286,9 +310,15 @@ def forward(self, x, skip_feat):

# propagation nets
x1 = self.decoder1(x)
x2 = nn.functional.interpolate(self.decoder2(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True)
x4 = nn.functional.interpolate(self.decoder4(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True)
x8 = nn.functional.interpolate(self.decoder8(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True)
x2 = nn.functional.interpolate(
self.decoder2(x), size=(x1.size(2), x1.size(3)),
mode="bilinear", align_corners=True)
x4 = nn.functional.interpolate(
self.decoder4(x), size=(x1.size(2), x1.size(3)),
mode="bilinear", align_corners=True)
x8 = nn.functional.interpolate(
self.decoder8(x), size=(x1.size(2), x1.size(3)),
mode="bilinear", align_corners=True)
cat = torch.cat([x1, x2, x4, x8], dim=1)
feat8 = self.fusion8(cat) # 256

@@ -312,13 +342,17 @@ def forward(self, x, skip_feat):

return [flow1, flow2, flow4, flow8]


def predict_flow(in_planes, out_planes):
return nn.Conv2d(in_planes, out_planes, kernel_size=3,stride=1,padding=1,bias=True)
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
stride=1, padding=1, bias=True)


def deconv(in_planes, out_planes):
return nn.Sequential(
nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
nn.LeakyReLU(0.1,inplace=True)
nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4,
stride=2, padding=1, bias=True),
nn.LeakyReLU(0.1, inplace=True)
)
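
Note: the decoders above all follow the same decode-then-fuse pattern: parallel branches at different strides, each branch bilinearly resized back to the stride-1 resolution, concatenated, and passed to a prediction head. A compact sketch of that pattern (not part of the commit; the class name, channel sizes, and layer choices are illustrative only):

# --- illustrative sketch, not part of the commit ---
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMultiRateDecoder(nn.Module):
    """Run branches at strides 1/2/4, resize back, concatenate, predict."""
    def __init__(self, in_dim=32, out_dim=2):
        super().__init__()
        self.b1 = nn.Conv2d(in_dim, 16, 3, stride=1, padding=1)
        self.b2 = nn.Conv2d(in_dim, 16, 3, stride=2, padding=1)
        self.b4 = nn.Conv2d(in_dim, 16, 3, stride=4, padding=1)
        self.head = nn.Conv2d(16 * 3, out_dim, 3, padding=1)

    def forward(self, x):
        size = (x.size(2), x.size(3))
        feats = [self.b1(x)]
        for branch in (self.b2, self.b4):
            # resize each coarse branch back to the stride-1 resolution
            feats.append(F.interpolate(branch(x), size=size,
                                       mode="bilinear", align_corners=True))
        return self.head(torch.cat(feats, dim=1))

# usage: TinyMultiRateDecoder()(torch.randn(1, 32, 64, 64))  # -> (1, 2, 64, 64)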


2 changes: 2 additions & 0 deletions models/modules/others.py
@@ -1,9 +1,11 @@
import torch.nn as nn

class FixModule(nn.Module):

def __init__(self, m):
super(FixModule, self).__init__()
self.module = m

def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)

23 changes: 16 additions & 7 deletions models/modules/warp.py
@@ -1,9 +1,8 @@
import torch
import torch.nn as nn

import models

class WarpingLayerBWFlow(nn.Module):

def __init__(self):
super(WarpingLayerBWFlow, self).__init__()

@@ -12,24 +11,34 @@ def forward(self, image, flow):
flow_for_grip[:,0,:,:] = flow[:,0,:,:] / ((flow.size(3) - 1.0) / 2.0)
flow_for_grip[:,1,:,:] = flow[:,1,:,:] / ((flow.size(2) - 1.0) / 2.0)

torchHorizontal = torch.linspace(-1.0, 1.0, image.size(3)).view(1, 1, 1, image.size(3)).expand(image.size(0), 1, image.size(2), image.size(3))
torchVertical = torch.linspace(-1.0, 1.0, image.size(2)).view(1, 1, image.size(2), 1).expand(image.size(0), 1, image.size(2), image.size(3))
torchHorizontal = torch.linspace(
-1.0, 1.0, image.size(3)).view(
1, 1, 1, image.size(3)).expand(
image.size(0), 1, image.size(2), image.size(3))
torchVertical = torch.linspace(
-1.0, 1.0, image.size(2)).view(
1, 1, image.size(2), 1).expand(
image.size(0), 1, image.size(2), image.size(3))
grid = torch.cat([torchHorizontal, torchVertical], 1).cuda()

grid = (grid + flow_for_grip).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(image, grid)


class WarpingLayerFWFlow(nn.Module):

def __init__(self):
super(WarpingLayerFWFlow, self).__init__()
self.initialized = False

def forward(self, image, flow, ret_mask = False):
n, h, w = image.size(0), image.size(2), image.size(3)

if not self.initialized or n != self.meshx.shape[0] or h * w != self.meshx.shape[1]:
self.meshx = torch.arange(w).view(1, 1, w).expand(n, h, w).contiguous().view(n, -1).cuda()
self.meshy = torch.arange(h).view(1, h, 1).expand(n, h, w).contiguous().view(n, -1).cuda()
self.meshx = torch.arange(w).view(1, 1, w).expand(
n, h, w).contiguous().view(n, -1).cuda()
self.meshy = torch.arange(h).view(1, h, 1).expand(
n, h, w).contiguous().view(n, -1).cuda()
self.warped_image = torch.zeros((n, 3, h, w), dtype=torch.float32).cuda()
if ret_mask:
self.hole_mask = torch.ones((n, 1, h, w), dtype=torch.float32).cuda()
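Note: WarpingLayerBWFlow above does backward warping with grid_sample: per-pixel displacements are rescaled from pixels to the [-1, 1] coordinate range grid_sample expects and added to an identity sampling grid. A self-contained sketch of that idea (not part of the commit; function name is illustrative, runs on CPU, align_corners=True to match the (size - 1) / 2 normalization used above):

# --- illustrative sketch, not part of the commit ---
import torch
import torch.nn.functional as F

def backward_warp(image, flow):
    """Warp image (N, C, H, W) by a backward flow (N, 2, H, W) given in pixels."""
    n, _, h, w = image.shape
    # identity grid in normalized coordinates: channel 0 = x (width), channel 1 = y (height)
    xs = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, 1, h, w)
    ys = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, 1, h, w)
    grid = torch.cat([xs, ys], dim=1)
    # rescale pixel displacements to the [-1, 1] range grid_sample expects
    norm_flow = torch.stack([flow[:, 0] / ((w - 1.0) / 2.0),
                             flow[:, 1] / ((h - 1.0) / 2.0)], dim=1)
    return F.grid_sample(image, (grid + norm_flow).permute(0, 2, 3, 1),
                         align_corners=True)

# usage: backward_warp(torch.randn(1, 3, 64, 64), torch.zeros(1, 2, 64, 64))  # zero flow = identity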
