-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9a77e5f
commit 53e81c4
Showing
9 changed files
with
717 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,23 @@ | ||
# Learning-to-See-Moving-Objects-in-the-Dark | ||
# Learning-to-See-Moving-Objects-in-the-Dark | ||
|
||
## Demo Video | ||
https://youtu.be/GZu30-a8N0M | ||
|
||
## Demo Version | ||
|
||
### Download trained model | ||
```Shell | ||
python download_model.py | ||
``` | ||
|
||
### Download example data | ||
```Shell | ||
python download_dataset.py | ||
``` | ||
|
||
### Run demo | ||
```Shell | ||
python test.py | ||
``` | ||
|
||
### More details and complete dataset will be updated soon. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#!/usr/bin/env python | ||
|
||
# ---------------------------------------------------------------- | ||
# Configurations for Training and Testing Process | ||
# Written by Haiyang Jiang | ||
# Mar 1st 2019 | ||
# ---------------------------------------------------------------- | ||
|
||
# file lists ================================================================ | ||
FILE_LIST = 'file_list' | ||
VALID_LIST = 'valid_list' | ||
TEST_LIST = 'test_list' | ||
CUSOMIZED_LIST = 'customized_list' | ||
|
||
# network.py ================================================================ | ||
DEBUG = False | ||
|
||
|
||
# train.py ================================================================ | ||
EXP_NAME = '16_bit_HE_to_HE_gt' | ||
CHECKPOINT_DIR = './1_checkpoint/' + EXP_NAME + '/' | ||
RESULT_DIR = './2_result/' + EXP_NAME + '/' | ||
LOGS_DIR = RESULT_DIR | ||
TRAIN_LOG_DIR = 'train' | ||
VAL_LOG_DIR = 'val' | ||
# training settings | ||
ALL_FRAME = 200 | ||
SAVE_FRAMES = list(range(0, ALL_FRAME, 32)) | ||
CROP_FRAME = 16 | ||
CROP_HEIGHT = 256 | ||
CROP_WIDTH = 256 | ||
|
||
SAVE_FREQ = 5 | ||
MAX_EPOCH = 50 | ||
|
||
FRAME_FREQ = 1 | ||
GROUP_NUM = 4 | ||
|
||
INIT_LR = 1e-4 | ||
DECAY_LR = 1e-5 | ||
DECAY_EPOCH = 30 | ||
|
||
# test.py ================================================================ | ||
TEST_CROP_FRAME = 32 | ||
TEST_CROP_HEIGHT = 512 | ||
TEST_CROP_WIDTH= 512 | ||
|
||
MAX_FRAME = 800 | ||
|
||
OVERLAP = 0.01 | ||
OUT_MAX = 255.0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
Cam1_gain10 ./0_data/Cam1/gain10.npy ./0_data/Cam1/gain10.npy | ||
Cam2_gain15 ./0_data/Cam2/gain15.npy ./0_data/Cam2/gain15.npy | ||
Outdoor ./0_data/Outdoor/gain10_15fps_2.npy ./0_data/Outdoor/gain10_15fps_2.npy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import requests | ||
import os | ||
|
||
def download_file_from_google_drive(id, destination): | ||
URL = "https://docs.google.com/uc?export=download" | ||
|
||
session = requests.Session() | ||
|
||
response = session.get(URL, params = { 'id' : id }, stream = True) | ||
token = get_confirm_token(response) | ||
|
||
if token: | ||
params = { 'id' : id, 'confirm' : token } | ||
response = session.get(URL, params = params, stream = True) | ||
|
||
save_response_content(response, destination) | ||
|
||
def get_confirm_token(response): | ||
for key, value in response.cookies.items(): | ||
if key.startswith('download_warning'): | ||
return value | ||
|
||
return None | ||
|
||
def save_response_content(response, destination): | ||
CHUNK_SIZE = 32768 | ||
|
||
with open(destination, "wb") as f: | ||
for chunk in response.iter_content(CHUNK_SIZE): | ||
if chunk: # filter out keep-alive new chunks | ||
f.write(chunk) | ||
|
||
|
||
if not os.path.isdir('0_data'): | ||
os.mkdir('0_data') | ||
|
||
print('Dowloading Camera 1 Example data... ') | ||
download_file_from_google_drive('1qES6teQUprs-cgL-nACzFighjm_Dho0L', '0_data/Cam1.zip') | ||
|
||
print('Dowloading Camera 2 Example data... ') | ||
download_file_from_google_drive('1wiz5XPricNj-sFwqfj7tVI1IomlyO6xM', '0_data/Cam2.zip') | ||
|
||
print('Dowloading Outdoor Example data... ') | ||
download_file_from_google_drive('1UFah1QWltNDzUlsQ4HCMnuiyAAUh-9EE', '0_data/Outdoor.zip') | ||
|
||
os.system('unzip 0_data/Cam1.zip -d 0_data') | ||
os.system('unzip 0_data/Cam2.zip -d 0_data') | ||
os.system('unzip 0_data/Outdoor.zip -d 0_data') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import os | ||
import requests | ||
|
||
from config import CHECKPOINT_DIR | ||
|
||
def download_file_from_google_drive(id, destination): | ||
URL = "https://docs.google.com/uc?export=download" | ||
|
||
session = requests.Session() | ||
|
||
response = session.get(URL, params = { 'id' : id }, stream = True) | ||
token = get_confirm_token(response) | ||
|
||
if token: | ||
params = { 'id' : id, 'confirm' : token } | ||
response = session.get(URL, params = params, stream = True) | ||
|
||
save_response_content(response, destination) | ||
|
||
def get_confirm_token(response): | ||
for key, value in response.cookies.items(): | ||
if key.startswith('download_warning'): | ||
return value | ||
|
||
return None | ||
|
||
def save_response_content(response, destination): | ||
CHUNK_SIZE = 32768 | ||
|
||
with open(destination, "wb") as f: | ||
for chunk in response.iter_content(CHUNK_SIZE): | ||
if chunk: # filter out keep-alive new chunks | ||
f.write(chunk) | ||
|
||
|
||
if not os.path.isdir(CHECKPOINT_DIR): | ||
os.makedirs(CHECKPOINT_DIR) | ||
|
||
print('Dowloading Trained Model (63Mb)...') | ||
download_file_from_google_drive('1yXeEh2zbP4NQ9ogOO-r7GO9pdrV7yXzr', CHECKPOINT_DIR + '/checkpoint') | ||
download_file_from_google_drive('1yl3mMkvXBZf19XoM38lmDyUuJkYt-Rgb', CHECKPOINT_DIR + '/model.ckpt.index') | ||
download_file_from_google_drive('1YQP0zzbkGH-EaqU3eMIX6MpWB7dD3l6l', CHECKPOINT_DIR + '/model.ckpt.meta') | ||
download_file_from_google_drive('1YbiBNm2iIRuSm4Jb3xSVJ5UrJspIE9cw', CHECKPOINT_DIR + '/model.ckpt.data-00000-of-00001') | ||
print('Done.') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
#!/usr/bin/env python | ||
|
||
# ---------------------------------------------------------------- | ||
# 3D-Convolution and 2D-Pooling UNet | ||
# Written by Haiyang Jiang | ||
# Mar 1st 2019 | ||
# ---------------------------------------------------------------- | ||
|
||
|
||
import tensorflow as tf | ||
import tensorflow.contrib.slim as slim | ||
|
||
from config import DEBUG | ||
|
||
|
||
# leaky ReLU | ||
def lrelu(x): | ||
return tf.maximum(x * 0.2, x) | ||
|
||
|
||
def upsample_and_concat(x1, x2, output_channels, in_channels): | ||
pool_size = 2 | ||
deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02)) | ||
deconv = tf.nn.conv2d_transpose(x1[0], deconv_filter, tf.shape(x2[0]), strides=[1, pool_size, pool_size, 1]) | ||
|
||
deconv_output = tf.concat([deconv, x2[0]], -1) | ||
deconv_output.set_shape([None, None, None, output_channels * 2]) | ||
|
||
return tf.expand_dims(deconv_output, axis=0) | ||
|
||
|
||
# 3D-Conv-2D-Pool UNet | ||
def network(input, depth=3, channel=32, prefix=''): | ||
depth = min(max(depth, 2), 4) | ||
|
||
conv1 = slim.conv3d(input, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv1_1') | ||
conv1 = slim.conv3d(conv1, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv1_2') | ||
pool1 = tf.expand_dims(slim.max_pool2d(conv1[0], [2, 2], padding='SAME'), axis=0) | ||
|
||
conv2 = slim.conv3d(pool1, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv2_1') | ||
conv2 = slim.conv3d(conv2, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv2_2') | ||
pool2 = tf.expand_dims(slim.max_pool2d(conv2[0], [2, 2], padding='SAME'), axis=0) | ||
|
||
conv3 = slim.conv3d(pool2, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv3_1') | ||
conv3 = slim.conv3d(conv3, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv3_2') | ||
if depth == 2: | ||
up8 = upsample_and_concat(conv3, conv2, channel * 2, channel * 4) | ||
else: | ||
pool3 = tf.expand_dims(slim.max_pool2d(conv3[0], [2, 2], padding='SAME'), axis=0) | ||
|
||
conv4 = slim.conv3d(pool3, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv4_1') | ||
conv4 = slim.conv3d(conv4, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv4_2') | ||
if depth == 3: | ||
up7 = upsample_and_concat(conv4, conv3, channel * 4, channel * 8) | ||
else: | ||
pool4 = tf.expand_dims(slim.max_pool2d(conv4[0], [2, 2], padding='SAME'), axis=0) | ||
|
||
conv5 = slim.conv3d(pool4, channel * 16, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv5_1') | ||
conv5 = slim.conv3d(conv5, channel * 16, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv5_2') | ||
|
||
up6 = upsample_and_concat(conv5, conv4, channel * 8, channel * 16) | ||
conv6 = slim.conv3d(up6, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv6_1') | ||
conv6 = slim.conv3d(conv6, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv6_2') | ||
|
||
up7 = upsample_and_concat(conv6, conv3, channel * 4, channel * 8) | ||
conv7 = slim.conv3d(up7, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv7_1') | ||
conv7 = slim.conv3d(conv7, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv7_2') | ||
|
||
up8 = upsample_and_concat(conv7, conv2, channel * 2, channel * 4) | ||
conv8 = slim.conv3d(up8, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv8_1') | ||
conv8 = slim.conv3d(conv8, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv8_2') | ||
|
||
up9 = upsample_and_concat(conv8, conv1, channel, channel * 2) | ||
conv9 = slim.conv3d(up9, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv9_1') | ||
conv9 = slim.conv3d(conv9, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv9_2') | ||
|
||
conv10 = slim.conv3d(conv9, 12, [1, 1, 1], rate=1, activation_fn=None, scope=prefix + 'g_conv10') | ||
|
||
out = tf.concat([tf.expand_dims(tf.depth_to_space(conv10[:, i, :, :, :], 2), axis=1) for i in range(conv10.shape[1])], axis=1) | ||
if DEBUG: | ||
print '[DEBUG] (network.py) conv10.shape, out.shape:', conv10.shape, out.shape | ||
|
||
return out | ||
|
||
|
||
# test function for network | ||
def main(): | ||
sess = tf.Session() | ||
in_image = tf.placeholder(tf.float32, [None, 16, None, None, 4]) | ||
gt_image = tf.placeholder(tf.float32, [None, 16, None, None, 3]) | ||
out_image = [] | ||
out_image += network(in_image), | ||
out_image += network(in_image, 3, prefix='1'), | ||
out_image += network(in_image, 2, prefix='2'), | ||
out_image += network(in_image, 3, 16, prefix='3'), | ||
from skvideo.io import vread, vwrite | ||
import numpy as np | ||
vid = np.load('./0_data/raw/test_data/gt_input/001_00_0001.npy') | ||
vid = np.expand_dims(np.float32(np.minimum((vid[:16, :256, :256, :] / vid.mean() / 5), 1.0)), axis=0) | ||
sess.run(tf.global_variables_initializer()) | ||
for i, out in enumerate(out_image): | ||
print out.shape | ||
output = sess.run(out, feed_dict={in_image: vid}) | ||
output = (np.minimum(output, 1.0) * 255).astype('uint8') | ||
print output[0].shape | ||
vwrite(str(i) + '.mp4', output[0]) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
ffmpeg==1.4 | ||
numpy==1.15.2 | ||
opencv-python==3.4.3.18 | ||
requests==2.19.1 | ||
scikit-learn==0.19.2 | ||
scikit-video==1.1.11 | ||
scipy==1.1.0 | ||
tensorboard==1.9.0 | ||
tensorflow==1.9.0 |
Oops, something went wrong.