From e3955ac28255c3b8076bd9f294833a01130682af Mon Sep 17 00:00:00 2001 From: zhanxiaohang Date: Mon, 29 Apr 2019 15:53:03 +0800 Subject: [PATCH] style --- dataset.py | 31 +++++-- demos/demo_annot.ipynb | 159 ++++++++++++++++++----------------- main.py | 2 +- models/cmp.py | 12 ++- models/modules/cmp.py | 11 ++- models/modules/decoder.py | 68 +++++++++++---- models/modules/others.py | 2 + models/modules/warp.py | 23 +++-- models/single_stage_model.py | 15 ++-- trainer.py | 98 ++++++++++++++------- 10 files changed, 266 insertions(+), 155 deletions(-) diff --git a/dataset.py b/dataset.py index d48df03..f7d2484 100644 --- a/dataset.py +++ b/dataset.py @@ -1,14 +1,16 @@ -import sys -import torch -from torch.utils.data import Dataset -import torchvision.transforms as transforms import numpy as np import io from PIL import Image + +import torch +import torchvision.transforms as transforms +from torch.utils.data import Dataset + from utils.flowlib import read_flo_file from utils import image_crop, image_resize, image_flow_crop, image_flow_resize, flow_sampler, image_flow_aug, flow_aug class ColorAugmentation(object): + def __init__(self, eig_vec=None, eig_val=None): if eig_vec == None: eig_vec = torch.Tensor([ @@ -28,6 +30,7 @@ def __call__(self, tensor): tensor = tensor + quatity.view(3, 1, 1) return tensor + def pil_loader(img_str, ch): buff = io.BytesIO(img_str) if ch == 1: @@ -37,6 +40,7 @@ def pil_loader(img_str, ch): img = img.convert('RGB') return img + def pil_loader_str(img_str, ch): if ch == 1: return Image.open(img_str) @@ -45,7 +49,9 @@ def pil_loader_str(img_str, ch): img = img.convert('RGB') return img + class ImageFlowDataset(Dataset): + def __init__(self, meta_file, config, phase): self.img_transform = transforms.Compose([ transforms.Normalize(config['data_mean'], config['data_div']) @@ -111,18 +117,22 @@ def __getitem__(self, idx): ## resize if self.short_size is not None or self.long_size is not None: - img1, img2, flow, ratio = image_flow_resize(img1, img2, flow, short_size=self.short_size, long_size=self.long_size) + img1, img2, flow, ratio = image_flow_resize( + img1, img2, flow, short_size=self.short_size, + long_size=self.long_size) ## crop if self.crop_size is not None: - img1, img2, flow, offset = image_flow_crop(img1, img2, flow, self.crop_size, self.phase) + img1, img2, flow, offset = image_flow_crop( + img1, img2, flow, self.crop_size, self.phase) ## augmentation if self.phase == 'train': # image flow aug img1, img2, flow = image_flow_aug(img1, img2, flow, flip_horizon=self.aug_flip) # flow aug - flow = flow_aug(flow, reverse=self.aug_reverse, scale=self.aug_scale, rotate=self.aug_rotate) + flow = flow_aug(flow, reverse=self.aug_reverse, + scale=self.aug_scale, rotate=self.aug_rotate) ## transform img1 = torch.from_numpy(np.array(img1).astype(np.float32).transpose((2,0,1))) @@ -131,14 +141,19 @@ def __getitem__(self, idx): img2 = self.img_transform(img2) ## sparse sampling - sparse_flow, mask = flow_sampler(flow, strategy=self.sample_strategy, bg_ratio=self.sample_bg_ratio, nms_ks=self.nms_ks, max_num_guide=self.max_num_guide) # (h,w,2), (h,w,2) + sparse_flow, mask = flow_sampler( + flow, strategy=self.sample_strategy, + bg_ratio=self.sample_bg_ratio, nms_ks=self.nms_ks, + max_num_guide=self.max_num_guide) # (h,w,2), (h,w,2) flow = torch.from_numpy(flow.transpose((2, 0, 1))) sparse_flow = torch.from_numpy(sparse_flow.transpose((2, 0, 1))) mask = torch.from_numpy(mask.transpose((2, 0, 1)).astype(np.float32)) return img1, sparse_flow, mask, flow, img2 + class ImageDataset(Dataset): + def __init__(self, meta_file, config): self.img_transform = transforms.Compose([ transforms.Normalize(config['data_mean'], config['data_div']) diff --git a/demos/demo_annot.ipynb b/demos/demo_annot.ipynb index 701e236..c3361bd 100644 --- a/demos/demo_annot.ipynb +++ b/demos/demo_annot.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 61, "metadata": {}, "outputs": [], "source": [ @@ -28,7 +28,7 @@ "import importlib\n", "importlib.reload(utils)\n", "\n", - "exp = '../experiments/semiauto_annot/resnet50_vip+mpii_liteflow'\n", + "exp = '../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow'\n", "\n", "load_iter = 42000\n", "config = \"{}/config.yaml\".format(exp)" @@ -36,86 +36,86 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 62, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "=> loading checkpoint '../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar'\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.1.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.1.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.0.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.3.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.1.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder2.8.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.0.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.1.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.1.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.0.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.1.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.3.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder8.5.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.0.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.0.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.3.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.0.downsample.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.2.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.2.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.0.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.5.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.2.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.0.downsample.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.fusion4.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder8.8.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.0.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.skipconv2.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.0.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.1.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.2.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.4.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.4.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.2.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.2.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.1.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.2.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder1.4.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.4.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.fusion2.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.5.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder2.2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.skipconv4.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.2.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.2.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.0.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.1.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder4.5.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.3.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.0.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.1.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder1.7.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.2.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.1.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder8.2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.1.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.3.bn1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder4.2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.5.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_encoder.features.5.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.2.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.0.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.0.downsample.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.0.downsample.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder4.8.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder2.5.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.2.bn2.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.0.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_encoder.features.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder1.1.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.3.bn3.num_batches_tracked\n", - "caution: missing keys from checkpoint ../experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.fusion8.1.num_batches_tracked\n" + "=> loading checkpoint '../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar'\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.1.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.1.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.0.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.3.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.1.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder2.8.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.0.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.1.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.1.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.0.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.1.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.3.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder8.5.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.0.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.0.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.3.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.0.downsample.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.2.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.2.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.0.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.5.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.2.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.0.downsample.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.fusion4.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder8.8.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.0.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.skipconv2.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.0.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.1.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.2.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.4.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.4.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.2.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.2.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.1.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.2.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder1.4.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.4.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.fusion2.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.5.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder2.2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.skipconv4.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.2.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.2.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.0.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.1.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder4.5.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.3.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.0.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.1.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder1.7.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.2.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.1.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder8.2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.1.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.3.bn1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder4.2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.5.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_encoder.features.5.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.2.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.0.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.0.downsample.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer2.0.downsample.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder4.8.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder2.5.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer4.2.bn2.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer1.0.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_encoder.features.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.decoder1.1.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.image_encoder.layer3.3.bn3.num_batches_tracked\n", + "caution: missing keys from checkpoint ../experiments/semiauto_annot_verify/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar: module.flow_decoder.fusion8.1.num_batches_tracked\n" ] } ], @@ -286,6 +286,7 @@ " self.mode_status.set_text(\"mode: add\")\n", "\n", " img = self.oriimg\n", + " self.img = img\n", " self.ax.imshow(img)\n", " self.ax.axis('off')\n", " if not self.large:\n", @@ -492,7 +493,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 68, "metadata": { "scrolled": false }, @@ -1280,7 +1281,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -1309,7 +1310,7 @@ "# silent: show information or not\n", "\n", "\n", - "test_image_idx = 4 # [0,1000)\n", + "test_image_idx = 12 # [0,1000)\n", "target = 'MPII'\n", "\n", "with open('/home/xhzhan/Share/transfer/{}/list.txt'.format(target), 'r') as f:\n", diff --git a/main.py b/main.py index 9cd4fe7..8c6cd95 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ +import multiprocessing as mp import argparse import os import yaml -import multiprocessing as mp from utils import dist_init from trainer import Trainer diff --git a/models/cmp.py b/models/cmp.py index 48b7965..3bc38fc 100644 --- a/models/cmp.py +++ b/models/cmp.py @@ -7,6 +7,7 @@ from . import SingleStageModel class CMP(SingleStageModel): + def __init__(self, params, dist_model=False): super(CMP, self).__init__(params, dist_model) model_params = params['module'] @@ -17,7 +18,8 @@ def __init__(self, params, dist_model=False): elif model_params['flow_criterion'] == 'L2': self.flow_criterion = nn.MSELoss() elif model_params['flow_criterion'] == 'DiscreteLoss': - self.flow_criterion = losses.DiscreteLoss(nbins=model_params['nbins'], fmax=model_params['fmax']) + self.flow_criterion = losses.DiscreteLoss( + nbins=model_params['nbins'], fmax=model_params['fmax']) else: raise Exception("No such flow loss: {}".format(model_params['flow_criterion'])) @@ -33,7 +35,9 @@ def eval(self, ret_loss=True): else: self.flow = cmp_output if self.flow.shape[2] != self.image_input.shape[2]: - self.flow = nn.functional.interpolate(self.flow, size=self.image_input.shape[2:4], mode="bilinear", align_corners=True) + self.flow = nn.functional.interpolate( + self.flow, size=self.image_input.shape[2:4], + mode="bilinear", align_corners=True) ret_tensors = { 'flow_tensors': [self.flow, self.flow_target], @@ -42,7 +46,9 @@ def eval(self, ret_loss=True): if ret_loss: if cmp_output.shape[2] != self.flow_target.shape[2]: - cmp_output = nn.functional.interpolate(cmp_output, size=self.flow_target.shape[2:4], mode="bilinear", align_corners=True) + cmp_output = nn.functional.interpolate( + cmp_output, size=self.flow_target.shape[2:4], + mode="bilinear", align_corners=True) loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size return ret_tensors, {'loss_flow': loss_flow} else: diff --git a/models/modules/cmp.py b/models/modules/cmp.py index 76ba004..d155c1b 100644 --- a/models/modules/cmp.py +++ b/models/modules/cmp.py @@ -3,6 +3,7 @@ import models class CMP(nn.Module): + def __init__(self, params): super(CMP, self).__init__() img_enc_dim = params['img_enc_dim'] @@ -14,9 +15,13 @@ def __init__(self, params): if self.skip_layer: assert params['flow_decoder'] == "MotionDecoderSkipLayer" - self.image_encoder = models.backbone.__dict__[params['image_encoder']](img_enc_dim, pretrained) - self.flow_encoder = models.modules.__dict__[params['sparse_encoder']](sparse_enc_dim) - self.flow_decoder = models.modules.__dict__[params['flow_decoder']](input_dim=img_enc_dim+sparse_enc_dim, output_dim=output_dim, combo=decoder_combo) + self.image_encoder = models.backbone.__dict__[params['image_encoder']]( + img_enc_dim, pretrained) + self.flow_encoder = models.modules.__dict__[params['sparse_encoder']]( + sparse_enc_dim) + self.flow_decoder = models.modules.__dict__[params['flow_decoder']]( + input_dim=img_enc_dim+sparse_enc_dim, + output_dim=output_dim, combo=decoder_combo) def forward(self, image, sparse): sparse_enc = self.flow_encoder(sparse) diff --git a/models/modules/decoder.py b/models/modules/decoder.py index 19244ea..8f1c0e3 100644 --- a/models/modules/decoder.py +++ b/models/modules/decoder.py @@ -3,6 +3,7 @@ import math class MotionDecoderPlain(nn.Module): + def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4]): super(MotionDecoderPlain, self).__init__() BN = nn.BatchNorm2d @@ -72,20 +73,28 @@ def forward(self, x): x1 = self.decoder1(x) cat_list.append(x1) if 2 in self.combo: - x2 = nn.functional.interpolate(self.decoder2(x), size=(x.size(2), x.size(3)), mode="bilinear", align_corners=True) + x2 = nn.functional.interpolate( + self.decoder2(x), size=(x.size(2), x.size(3)), + mode="bilinear", align_corners=True) cat_list.append(x2) if 4 in self.combo: - x4 = nn.functional.interpolate(self.decoder4(x), size=(x.size(2), x.size(3)), mode="bilinear", align_corners=True) + x4 = nn.functional.interpolate( + self.decoder4(x), size=(x.size(2), x.size(3)), + mode="bilinear", align_corners=True) cat_list.append(x4) if 8 in self.combo: - x8 = nn.functional.interpolate(self.decoder8(x), size=(x.size(2), x.size(3)), mode="bilinear", align_corners=True) + x8 = nn.functional.interpolate( + self.decoder8(x), size=(x.size(2), x.size(3)), + mode="bilinear", align_corners=True) cat_list.append(x8) cat = torch.cat(cat_list, dim=1) flow = self.head(cat) return flow + class MotionDecoderSkipLayer(nn.Module): + def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]): super(MotionDecoderSkipLayer, self).__init__() @@ -180,22 +189,34 @@ def forward(self, x, skip_feat): layer1, layer2, layer4 = skip_feat x1 = self.decoder1(x) - x2 = nn.functional.interpolate(self.decoder2(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True) - x4 = nn.functional.interpolate(self.decoder4(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True) - x8 = nn.functional.interpolate(self.decoder8(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True) + x2 = nn.functional.interpolate( + self.decoder2(x), size=(x1.size(2), x1.size(3)), + mode="bilinear", align_corners=True) + x4 = nn.functional.interpolate( + self.decoder4(x), size=(x1.size(2), x1.size(3)), + mode="bilinear", align_corners=True) + x8 = nn.functional.interpolate( + self.decoder8(x), size=(x1.size(2), x1.size(3)), + mode="bilinear", align_corners=True) cat = torch.cat([x1, x2, x4, x8], dim=1) f8 = self.fusion8(cat) - f8_up = nn.functional.interpolate(f8, size=(layer4.size(2), layer4.size(3)), mode="bilinear", align_corners=True) + f8_up = nn.functional.interpolate( + f8, size=(layer4.size(2), layer4.size(3)), + mode="bilinear", align_corners=True) f4 = self.fusion4(torch.cat([f8_up, self.skipconv4(layer4)], dim=1)) - f4_up = nn.functional.interpolate(f4, size=(layer2.size(2), layer2.size(3)), mode="bilinear", align_corners=True) + f4_up = nn.functional.interpolate( + f4, size=(layer2.size(2), layer2.size(3)), + mode="bilinear", align_corners=True) f2 = self.fusion2(torch.cat([f4_up, self.skipconv2(layer2)], dim=1)) flow = self.head(f2) return flow + class MotionDecoderFlowNet(nn.Module): + def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]): super(MotionDecoderFlowNet, self).__init__() global BN @@ -260,9 +281,12 @@ def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]): self.predict_flow2 = predict_flow(192 + output_dim, output_dim) self.predict_flow1 = predict_flow(67 + output_dim, output_dim) - self.upsampled_flow8_to_4 = nn.ConvTranspose2d(output_dim, output_dim, 4, 2, 1, bias=False) - self.upsampled_flow4_to_2 = nn.ConvTranspose2d(output_dim, output_dim, 4, 2, 1, bias=False) - self.upsampled_flow2_to_1 = nn.ConvTranspose2d(output_dim, output_dim, 4, 2, 1, bias=False) + self.upsampled_flow8_to_4 = nn.ConvTranspose2d( + output_dim, output_dim, 4, 2, 1, bias=False) + self.upsampled_flow4_to_2 = nn.ConvTranspose2d( + output_dim, output_dim, 4, 2, 1, bias=False) + self.upsampled_flow2_to_1 = nn.ConvTranspose2d( + output_dim, output_dim, 4, 2, 1, bias=False) self.deconv8 = deconv(256, 128) self.deconv4 = deconv(384 + output_dim, 128) @@ -286,9 +310,15 @@ def forward(self, x, skip_feat): # propagation nets x1 = self.decoder1(x) - x2 = nn.functional.interpolate(self.decoder2(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True) - x4 = nn.functional.interpolate(self.decoder4(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True) - x8 = nn.functional.interpolate(self.decoder8(x), size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True) + x2 = nn.functional.interpolate( + self.decoder2(x), size=(x1.size(2), x1.size(3)), + mode="bilinear", align_corners=True) + x4 = nn.functional.interpolate( + self.decoder4(x), size=(x1.size(2), x1.size(3)), + mode="bilinear", align_corners=True) + x8 = nn.functional.interpolate( + self.decoder8(x), size=(x1.size(2), x1.size(3)), + mode="bilinear", align_corners=True) cat = torch.cat([x1, x2, x4, x8], dim=1) feat8 = self.fusion8(cat) # 256 @@ -312,13 +342,17 @@ def forward(self, x, skip_feat): return [flow1, flow2, flow4, flow8] + def predict_flow(in_planes, out_planes): - return nn.Conv2d(in_planes, out_planes, kernel_size=3,stride=1,padding=1,bias=True) + return nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=1, padding=1, bias=True) + def deconv(in_planes, out_planes): return nn.Sequential( - nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), - nn.LeakyReLU(0.1,inplace=True) + nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, + stride=2, padding=1, bias=True), + nn.LeakyReLU(0.1, inplace=True) ) diff --git a/models/modules/others.py b/models/modules/others.py index 1289e81..591ce94 100644 --- a/models/modules/others.py +++ b/models/modules/others.py @@ -1,9 +1,11 @@ import torch.nn as nn class FixModule(nn.Module): + def __init__(self, m): super(FixModule, self).__init__() self.module = m + def forward(self, *args, **kwargs): return self.module(*args, **kwargs) diff --git a/models/modules/warp.py b/models/modules/warp.py index 7a2e851..d32dc5d 100644 --- a/models/modules/warp.py +++ b/models/modules/warp.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn -import models - class WarpingLayerBWFlow(nn.Module): + def __init__(self): super(WarpingLayerBWFlow, self).__init__() @@ -12,24 +11,34 @@ def forward(self, image, flow): flow_for_grip[:,0,:,:] = flow[:,0,:,:] / ((flow.size(3) - 1.0) / 2.0) flow_for_grip[:,1,:,:] = flow[:,1,:,:] / ((flow.size(2) - 1.0) / 2.0) - torchHorizontal = torch.linspace(-1.0, 1.0, image.size(3)).view(1, 1, 1, image.size(3)).expand(image.size(0), 1, image.size(2), image.size(3)) - torchVertical = torch.linspace(-1.0, 1.0, image.size(2)).view(1, 1, image.size(2), 1).expand(image.size(0), 1, image.size(2), image.size(3)) + torchHorizontal = torch.linspace( + -1.0, 1.0, image.size(3)).view( + 1, 1, 1, image.size(3)).expand( + image.size(0), 1, image.size(2), image.size(3)) + torchVertical = torch.linspace( + -1.0, 1.0, image.size(2)).view( + 1, 1, image.size(2), 1).expand( + image.size(0), 1, image.size(2), image.size(3)) grid = torch.cat([torchHorizontal, torchVertical], 1).cuda() grid = (grid + flow_for_grip).permute(0, 2, 3, 1) return torch.nn.functional.grid_sample(image, grid) + class WarpingLayerFWFlow(nn.Module): + def __init__(self): super(WarpingLayerFWFlow, self).__init__() self.initialized = False - + def forward(self, image, flow, ret_mask = False): n, h, w = image.size(0), image.size(2), image.size(3) if not self.initialized or n != self.meshx.shape[0] or h * w != self.meshx.shape[1]: - self.meshx = torch.arange(w).view(1, 1, w).expand(n, h, w).contiguous().view(n, -1).cuda() - self.meshy = torch.arange(h).view(1, h, 1).expand(n, h, w).contiguous().view(n, -1).cuda() + self.meshx = torch.arange(w).view(1, 1, w).expand( + n, h, w).contiguous().view(n, -1).cuda() + self.meshy = torch.arange(h).view(1, h, 1).expand( + n, h, w).contiguous().view(n, -1).cuda() self.warped_image = torch.zeros((n, 3, h, w), dtype=torch.float32).cuda() if ret_mask: self.hole_mask = torch.ones((n, 1, h, w), dtype=torch.float32).cuda() diff --git a/models/single_stage_model.py b/models/single_stage_model.py index 115dfe4..96d963a 100644 --- a/models/single_stage_model.py +++ b/models/single_stage_model.py @@ -7,6 +7,7 @@ import utils class SingleStageModel(object): + def __init__(self, params, dist_model=False): model_params = params['module'] self.model = models.modules.__dict__[params['module']['arch']](model_params) @@ -20,11 +21,13 @@ def __init__(self, params, dist_model=False): self.world_size = 1 if params['optim'] == 'SGD': - self.optim = torch.optim.SGD(self.model.parameters(), lr=params['lr'], + self.optim = torch.optim.SGD( + self.model.parameters(), lr=params['lr'], momentum=0.9, weight_decay=0.0001) elif params['optim'] == 'Adam': - self.optim = torch.optim.Adam(self.model.parameters(), - lr=params['lr'], betas=(params['beta1'], 0.999)) + self.optim = torch.optim.Adam( + self.model.parameters(), lr=params['lr'], + betas=(params['beta1'], 0.999)) else: raise Exception("No such optimizer: {}".format(params['optim'])) @@ -57,9 +60,9 @@ def save_state(self, path, Iter): path = os.path.join(path, "ckpt_iter_{}.pth.tar".format(Iter)) torch.save({ - 'step': Iter, - 'state_dict': self.model.state_dict(), - 'optimizer': self.optim.state_dict()}, path) + 'step': Iter, + 'state_dict': self.model.state_dict(), + 'optimizer': self.optim.state_dict()}, path) def switch_to(self, phase): if phase == 'train': diff --git a/trainer.py b/trainer.py index 693e985..a13e2b1 100644 --- a/trainer.py +++ b/trainer.py @@ -34,16 +34,24 @@ def __init__(self, args): try: from tensorboardX import SummaryWriter except: - raise Exception("Please switch off \"tensorboard\" in your config file if you do not want to use it, otherwise install it.") + raise Exception("Please switch off \"tensorboard\" " + "in your config file if you do not " + "want to use it, otherwise install it.") self.tb_logger = SummaryWriter('{}/events'.format(args.exp_path)) else: self.tb_logger = None if args.validate: - self.logger = utils.create_logger('global_logger', '{}/logs/log_offline_val.txt'.format(args.exp_path)) + self.logger = utils.create_logger( + 'global_logger', + '{}/logs/log_offline_val.txt'.format(args.exp_path)) elif args.extract: - self.logger = utils.create_logger('global_logger', '{}/logs/log_extract.txt'.format(args.exp_path)) + self.logger = utils.create_logger( + 'global_logger', + '{}/logs/log_extract.txt'.format(args.exp_path)) else: - self.logger = utils.create_logger('global_logger', '{}/logs/log_train.txt'.format(args.exp_path)) + self.logger = utils.create_logger( + 'global_logger', + '{}/logs/log_train.txt'.format(args.exp_path)) # create model self.model = models.__dict__[args.model['arch']](args.model, dist_model=True) @@ -51,7 +59,8 @@ def __init__(self, args): # optionally resume from a checkpoint assert not (args.load_iter is not None and args.load_path is not None) if args.load_iter is not None: - self.model.load_state("{}/checkpoints".format(args.exp_path), args.load_iter, args.resume) + self.model.load_state("{}/checkpoints".format(args.exp_path), + args.load_iter, args.resume) self.start_iter = args.load_iter else: self.start_iter = 0 @@ -61,7 +70,8 @@ def __init__(self, args): # lr scheduler if not (args.validate or args.extract): # train - self.lr_scheduler = utils.StepLRScheduler(self.model.optim, args.model['lr_steps'], + self.lr_scheduler = utils.StepLRScheduler( + self.model.optim, args.model['lr_steps'], args.model['lr_mults'], args.model['lr'], args.model['warmup_lr'], args.model['warmup_steps'], last_iter=self.start_iter-1) @@ -78,8 +88,8 @@ def __init__(self, args): if not (args.validate or args.extract): # train train_dataset = imageflow_dataset(args.data['train_source'], args.data, 'train') train_sampler = utils.DistributedGivenIterationSampler( - train_dataset, args.model['total_iter'], - args.data['batch_size'], last_iter=self.start_iter-1) + train_dataset, args.model['total_iter'], + args.data['batch_size'], last_iter=self.start_iter-1) self.train_loader = DataLoader( train_dataset, batch_size=args.data['batch_size'], shuffle=False, num_workers=args.data['workers'], pin_memory=False, sampler=train_sampler) @@ -147,7 +157,8 @@ def train(self): flow_target = flow_target.cuda() rgb_target = rgb_target.cuda() - self.model.set_input(image, torch.cat([sparse, mask], dim=1), flow_target, rgb_target) + self.model.set_input(image, torch.cat([sparse, mask], dim=1), + flow_target, rgb_target) loss_dict = self.model.step() for k in loss_dict.keys(): recorder[k].update(utils.reduce_tensors(loss_dict[k]).item()) @@ -163,23 +174,32 @@ def train(self): self.tb_logger.add_scalar('lr', curr_lr, self.curr_step) for k in recorder.keys(): if self.tb_logger is not None: - self.tb_logger.add_scalar('train_{}'.format(k), recorder[k].avg, self.curr_step + 1) - loss_str += '{}: {loss.val:.4g} ({loss.avg:.4g})\t'.format(k, loss=recorder[k]) + self.tb_logger.add_scalar('train_{}'.format(k), recorder[k].avg, + self.curr_step + 1) + loss_str += '{}: {loss.val:.4g} ({loss.avg:.4g})\t'.format( + k, loss=recorder[k]) - self.logger.info('Iter: [{0}/{1}]\t'.format(self.curr_step, len(self.train_loader)) + - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(batch_time=btime_rec) + - 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format(data_time=dtime_rec) + + self.logger.info( + 'Iter: [{0}/{1}]\t'.format(self.curr_step, len(self.train_loader)) + + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format( + batch_time=btime_rec) + + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format( + data_time=dtime_rec) + loss_str + 'NPts {num_pts.val} ({num_pts.avg:.1f})\t'.format(num_pts=npts_rec) + 'lr {lr:.2g}'.format(lr=curr_lr)) # validate - if (self.curr_step + 1) % self.args.trainer['val_freq'] == 0 or (self.curr_step + 1) == self.args.model['total_iter']: + if ((self.curr_step + 1) % self.args.trainer['val_freq'] == 0 or + (self.curr_step + 1) == self.args.model['total_iter']): self.validate('on_val') # save - if self.rank == 0 and ((self.curr_step + 1) % self.args.trainer['save_freq'] == 0 or (self.curr_step + 1) == self.args.model['total_iter']): - self.model.save_state("{}/checkpoints".format(self.args.exp_path), self.curr_step + 1) + if (self.rank == 0 and + ((self.curr_step + 1) % self.args.trainer['save_freq'] == 0 or + (self.curr_step + 1) == self.args.model['total_iter'])): + self.model.save_state("{}/checkpoints".format(self.args.exp_path), + self.curr_step + 1) def validate(self, phase): @@ -195,7 +215,9 @@ def validate(self, phase): end = time.time() all_together = [] for i, (image, sparse, mask, flow_target, rgb_target) in enumerate(self.val_loader): - if 'val_iter' in self.args.trainer and self.args.trainer['val_iter'] != -1 and i == self.args.trainer['val_iter']: + if ('val_iter' in self.args.trainer and + self.args.trainer['val_iter'] != -1 and + i == self.args.trainer['val_iter']): break assert image.shape[0] > 0 @@ -209,7 +231,8 @@ def validate(self, phase): flow_target = flow_target.cuda() rgb_target = rgb_target.cuda() - self.model.set_input(image, torch.cat([sparse, mask], dim=1), flow_target, rgb_target) + self.model.set_input(image, torch.cat([sparse, mask], dim=1), + flow_target, rgb_target) tensor_dict, loss_dict = self.model.eval() for k in loss_dict.keys(): recorder[k].update(utils.reduce_tensors(loss_dict[k]).item()) @@ -218,11 +241,19 @@ def validate(self, phase): # tb visualize if self.rank == 0: - if i >= self.args.trainer['val_disp_start_iter'] and i < self.args.trainer['val_disp_end_iter']: - all_together.append(utils.visualize_tensor(image, mask, tensor_dict['flow_tensors'], tensor_dict['common_tensors'], tensor_dict['rgb_tensors'], self.args.data['data_mean'], self.args.data['data_div'])) - if i == self.args.trainer['val_disp_end_iter'] and self.args.trainer['val_disp_end_iter'] > self.args.trainer['val_disp_start_iter']: + if (i >= self.args.trainer['val_disp_start_iter'] and + i < self.args.trainer['val_disp_end_iter']): + all_together.append(utils.visualize_tensor( + image, mask, tensor_dict['flow_tensors'], + tensor_dict['common_tensors'], tensor_dict['rgb_tensors'], + self.args.data['data_mean'], self.args.data['data_div'])) + if (i == self.args.trainer['val_disp_end_iter'] and + self.args.trainer['val_disp_end_iter'] > + self.args.trainer['val_disp_start_iter']): all_together = torch.cat(all_together, dim=2) - grid = vutils.make_grid(all_together, nrow=1, normalize=True, range=(0, 255), scale_each=False) + grid = vutils.make_grid( + all_together, nrow=1, normalize=True, + range=(0, 255), scale_each=False) if self.tb_logger is not None: self.tb_logger.add_image('Image_' + phase, grid, self.curr_step + 1) @@ -231,14 +262,19 @@ def validate(self, phase): loss_str = "" for k in recorder.keys(): if self.tb_logger is not None: - self.tb_logger.add_scalar('val_{}'.format(k), recorder[k].avg, self.curr_step + 1) - loss_str += '{}: {loss.val:.4g} ({loss.avg:.4g})\t'.format(k, loss=recorder[k]) - - self.logger.info('Validation Iter: [{0}]\t'.format(self.curr_step) + - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(batch_time=btime_rec) + - 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format(data_time=dtime_rec) + - loss_str + - 'NPts {num_pts.val} ({num_pts.avg:.1f})\t'.format(num_pts=npts_rec)) + self.tb_logger.add_scalar('val_{}'.format(k), + recorder[k].avg, self.curr_step + 1) + loss_str += '{}: {loss.val:.4g} ({loss.avg:.4g})\t'.format( + k, loss=recorder[k]) + + self.logger.info( + 'Validation Iter: [{0}]\t'.format(self.curr_step) + + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format( + batch_time=btime_rec) + + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format( + data_time=dtime_rec) + + loss_str + + 'NPts {num_pts.val} ({num_pts.avg:.1f})\t'.format(num_pts=npts_rec)) self.model.switch_to("train")