Merge pull request #398 from axinc-ai/pytorch-inpainting-with-partial-conv

Implement pytorch-inpainting-with-partial-conv
kyakuno committed Mar 15, 2021
2 parents 210cc0a + 63a1e7e commit 685d966
Showing 26 changed files with 269 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -137,6 +137,7 @@ The collection of pre-trained, state-of-the-art models.
| [<img src="image_manipulation/u2net_portrait/your_portrait_results/GalGadot.jpg" width=128px>](image_manipulation/u2net_portrait/) | [u2net_portrait](/image_manipulation/u2net_portrait/) | [U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection](https://github.com/NathanUA/U-2-Net) | Pytorch | 1.2.2 and later |
| [<img src="image_manipulation/style2paints/output.png" width=128px>](image_manipulation/style2paints/) | [style2paints](/image_manipulation/style2paints/) | [Style2Paints](https://github.com/lllyasviel/style2paints) | TensorFlow | 1.2.6 and later |
| [<img src="image_manipulation/deep_white_balance/output_AWB.png" width=128px>](image_manipulation/deep_white_balance/) | [deep_white_balance](/image_manipulation/deep_white_balance/) | [Deep White-Balance Editing, CVPR 2020 (Oral)](https://github.com/mahmoudnafifi/Deep_White_Balance) | PyTorch | 1.2.6 and later |
| [<img src="image_manipulation/pytorch-inpainting-with-partial-conv/result.png" width=128px>](image_manipulation/pytorch-inpainting-with-partial-conv/) | [inpainting-with-partial-conv](/image_manipulation/pytorch-inpainting-with-partial-conv/) | [pytorch-inpainting-with-partial-conv](https://github.com/naoto0804/pytorch-inpainting-with-partial-conv) | PyTorch | 1.2.6 and later |


## Image segmentation
21 changes: 21 additions & 0 deletions image_manipulation/pytorch-inpainting-with-partial-conv/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 Naoto Inoue

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
61 changes: 61 additions & 0 deletions image_manipulation/pytorch-inpainting-with-partial-conv/README.md
@@ -0,0 +1,61 @@
# pytorch-inpainting-with-partial-conv

## Input

![Input](input.png)

(Image from Places2 dataset http://places2.csail.mit.edu/download.html)

Shape : (n, 3, 256, 256)

## Output

![Output](result.png)

Left to right: input, mask, image generated by the network, ground truth

Shape : (n, 3, 256, 256)
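
The shapes above can be illustrated with a minimal NumPy sketch of the pre- and post-processing. It is not the script itself: `preprocess` and `denormalize` are hypothetical helper names, and the sketch assumes that ImageNet normalization means scaling to 0..1, subtracting the mean, and dividing by the standard deviation.

```python
import numpy as np

# ImageNet statistics, as used by the inference script.
NORM_MEAN = np.array([0.485, 0.456, 0.406])
NORM_STD = np.array([0.229, 0.224, 0.225])


def preprocess(rgb_uint8, mask_uint8):
    """Hypothetical helper: (H, W, 3) uint8 image and mask -> batched inputs."""
    # Assumed normalization: scale to 0..1, then standardize per channel.
    img = (rgb_uint8 / 255.0 - NORM_MEAN) / NORM_STD
    img = img.transpose(2, 0, 1)                   # HWC -> CHW
    mask = mask_uint8.transpose(2, 0, 1) / 255.0   # 1 = keep, 0 = hole
    masked = img * mask                            # zero out the holes
    return masked[np.newaxis], mask[np.newaxis]    # add the batch axis


def denormalize(chw):
    """Hypothetical helper: invert the normalization, CHW -> HWC in 0..255."""
    return (chw.transpose(1, 2, 0) * NORM_STD + NORM_MEAN) * 255.0


rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(256, 256, 3)).astype(np.uint8)
mask = np.full((256, 256, 3), 255, dtype=np.uint8)
mask[64:128, 64:128, :] = 0                        # a square hole

image, mask_in = preprocess(rgb, mask)
print(image.shape, mask_in.shape)
```

Because holes are zeroed in normalized space, they show up as the mean color rather than black once the input is denormalized for display.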

## Usage
The ONNX and prototxt files are downloaded automatically on the first run,
so an Internet connection is required at that time.

To run inference on the sample image:
```bash
$ python3 pytorch-inpainting-with-partial-conv.py
```

To specify an input image, pass its path with the `--input` option.
Use the `--savepath` option to change the name of the output file.
```bash
$ python3 pytorch-inpainting-with-partial-conv.py --input IMAGE_PATH --savepath SAVE_IMAGE_PATH
```

Use the `--mask-index` option to specify the index of the mask file.
(By default, a mask is selected at random.)

```bash
$ python3 pytorch-inpainting-with-partial-conv.py --mask-index 12
```
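
The selection rule can be sketched as follows. This is a simplified illustration with a hypothetical helper `pick_mask`, not the script's own function; it assumes the index wraps around the number of available masks and that the mask list is non-empty.

```python
import random


def pick_mask(mask_paths, mask_index=None):
    """Hypothetical helper: choose a mask path by index, or at random."""
    if not mask_paths:
        raise FileNotFoundError("no mask files found")
    if mask_index is not None:
        # An out-of-range index wraps around instead of raising an error.
        return mask_paths[mask_index % len(mask_paths)]
    return random.choice(mask_paths)


paths = ["masks/000000.jpg", "masks/000001.jpg", "masks/000002.jpg"]
print(pick_mask(paths, 4))   # index 4 wraps to position 1
print(pick_mask(paths))      # random choice
```

Since `glob.glob` returns paths in arbitrary order, sorting the list first would make a given index refer to the same file on every run.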

To regenerate the mask files, run `generate_masks.py`.
Use the `--N` option to specify the number of files to generate.
```bash
$ python3 generate_masks.py --N 16
```
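
As a rough illustration of what a generated mask looks like, the sketch below draws a random walk on a small canvas and reports the fraction of masked pixels. `random_walk_mask` is a hypothetical, simplified stand-in for the generator, using a 64x64 canvas and a fixed seed purely for the example.

```python
import random

import numpy as np

STEPS = [(0, 1), (0, -1), (1, 0), (-1, 0)]


def random_walk_mask(size, length, seed=None):
    """Hypothetical helper: 1 = valid pixel, 0 = hole carved by the walk."""
    rng = random.Random(seed)
    mask = np.ones((size, size), dtype=np.uint8)
    x = rng.randrange(size)
    y = rng.randrange(size)
    for _ in range(length):
        dx, dy = rng.choice(STEPS)
        # Clamp to the canvas so the walk never leaves the image.
        x = min(max(x + dx, 0), size - 1)
        y = min(max(y + dy, 0), size - 1)
        mask[x, y] = 0
    return mask


mask = random_walk_mask(size=64, length=64 * 64, seed=0)
hole_ratio = 1.0 - mask.mean()
print(f"hole ratio: {hole_ratio:.2%}")
```

The walk revisits many pixels, so a walk of `size ** 2` steps covers only part of the canvas.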

## Reference

[pytorch-inpainting-with-partial-conv](https://github.com/naoto0804/pytorch-inpainting-with-partial-conv)

## Framework

PyTorch

## Model Format

ONNX opset=11

## Netron

[partialconv.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/pytorch-inpainting-with-partial-conv/partialconv.onnx.prototxt)
45 changes: 45 additions & 0 deletions image_manipulation/pytorch-inpainting-with-partial-conv/generate_masks.py
@@ -0,0 +1,45 @@
import argparse
import numpy as np
import random
from PIL import Image

action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]


def random_walk(canvas, ini_x, ini_y, length):
    x = ini_x
    y = ini_y
    img_size = canvas.shape[-1]
    x_list = []
    y_list = []
    for i in range(length):
        r = random.randint(0, len(action_list) - 1)
        x = np.clip(x + action_list[r][0], a_min=0, a_max=img_size - 1)
        y = np.clip(y + action_list[r][1], a_min=0, a_max=img_size - 1)
        x_list.append(x)
        y_list.append(y)
    canvas[np.array(x_list), np.array(y_list)] = 0
    return canvas


if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--image_size', type=int, default=512)
    parser.add_argument('--N', type=int, default=10000)
    parser.add_argument('--save_dir', type=str, default='masks')
    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    for i in range(args.N):
        canvas = np.ones((args.image_size, args.image_size)).astype("i")
        ini_x = random.randint(0, args.image_size - 1)
        ini_y = random.randint(0, args.image_size - 1)
        mask = random_walk(canvas, ini_x, ini_y, args.image_size ** 2)
        print("save:", i, np.sum(mask))

        img = Image.fromarray(mask * 255).convert('1')
        img.save('{:s}/{:06d}.jpg'.format(args.save_dir, i))
140 changes: 140 additions & 0 deletions image_manipulation/pytorch-inpainting-with-partial-conv/pytorch-inpainting-with-partial-conv.py
@@ -0,0 +1,140 @@
import os
import sys
import time
import glob
import random

import numpy as np
import cv2
from PIL import Image

import ailia

# import original modules
sys.path.append('../../util')
from utils import get_base_parser, update_parser, get_savepath # noqa: E402
from model_utils import check_and_download_models # noqa: E402
from image_utils import normalize_image # noqa: E402
from detector_utils import load_image # noqa: E402

# logger
from logging import getLogger # noqa: E402

logger = getLogger(__name__)

# ======================
# Parameters
# ======================
WEIGHT_PATH = 'partialconv.onnx'
MODEL_PATH = 'partialconv.onnx.prototxt'
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/pytorch-inpainting-with-partial-conv/'

IMAGE_PATH = 'Places365_test_00000146.jpg'
SAVE_IMAGE_PATH = 'result.png'
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256

NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# ======================
# Argument Parser Config
# ======================
parser = get_base_parser('pytorch-inpainting-with-partial-conv model', IMAGE_PATH, SAVE_IMAGE_PATH)
parser.add_argument(
    '-mi', '--mask-index', type=int, metavar='INDEX', default=None,
    help='Mask index. If not specified, it will be randomly selected.'
)
args = update_parser(parser)


# ======================
# Main functions
# ======================

def postprocess(x):
    mean = np.array(NORM_MEAN)
    std = np.array(NORM_STD)
    x = x.transpose(1, 2, 0)  # CHW -> HWC
    x = x * std + mean
    x = x * 255
    x = x[:, :, ::-1]  # RGB -> BGR
    return x


def recognize_from_image(net):
    mask_paths = glob.glob('masks/*.jpg')
    N_mask = len(mask_paths)

    # input image loop
    for image_path in args.input:
        logger.info(image_path)

        # prepare ground truth
        gt_img = load_image(image_path)
        gt_img = cv2.cvtColor(gt_img, cv2.COLOR_BGRA2RGB)
        gt_img = np.array(Image.fromarray(gt_img).resize((IMAGE_WIDTH, IMAGE_HEIGHT), Image.BILINEAR))
        gt_img = normalize_image(gt_img, 'ImageNet')
        gt_img = gt_img.transpose((2, 0, 1))  # channel first

        # prepare mask
        if args.mask_index is not None:
            mask_path = mask_paths[args.mask_index % N_mask]
        else:
            mask_path = mask_paths[random.randint(0, N_mask - 1)]
        mask = load_image(mask_path)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGRA2RGB)
        mask = np.array(Image.fromarray(mask).resize((IMAGE_WIDTH, IMAGE_HEIGHT), Image.BILINEAR))
        mask = mask.transpose((2, 0, 1)) / 255  # channel first

        # prepare input data
        img = gt_img * mask
        img = np.expand_dims(img, axis=0)
        mask = np.expand_dims(mask, axis=0)
        gt_img = np.expand_dims(gt_img, axis=0)

        logger.debug(f'input image shape: {img.shape}')

        # inference
        logger.info('Start inference...')
        if args.benchmark:
            logger.info('BENCHMARK mode')
            total_time = 0
            for i in range(args.benchmark_count):
                start = int(round(time.time() * 1000))
                output = net.predict({'image': img, 'mask': mask})
                end = int(round(time.time() * 1000))
                logger.info(f'\tailia processing time {end - start} ms')
                if i != 0:
                    total_time = total_time + (end - start)
            logger.info(f'\taverage time {total_time / (args.benchmark_count - 1)} ms')
        else:
            output = net.predict({'image': img, 'mask': mask})

        output, _ = output

        img = postprocess(img[0])
        mask = mask[0].transpose(1, 2, 0) * 255
        output = postprocess(output[0])
        gt_img = postprocess(gt_img[0])
        res_img = np.hstack((img, mask, output, gt_img))

        savepath = get_savepath(args.savepath, image_path, ext='.png')
        logger.info(f'saved at : {savepath}')
        cv2.imwrite(savepath, res_img)

    logger.info('Script finished successfully.')


def main():
    # model files check and download
    check_and_download_models(WEIGHT_PATH, MODEL_PATH, REMOTE_PATH)

    # net initialize
    net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id)

    recognize_from_image(net)


if __name__ == '__main__':
    main()
1 change: 1 addition & 0 deletions scripts/download_all_models.sh
@@ -58,6 +58,7 @@ cd ../../image_classification/vgg16; python3 vgg16.py ${OPTION}
cd ../../image_manipulation/dewarpnet; python3 dewarpnet.py ${OPTION}
cd ../../image_manipulation/illnet; python3 illnet.py ${OPTION}
cd ../../image_manipulation/noise2noise; python3 noise2noise.py ${OPTION}
cd ../../image_manipulation/pytorch-inpainting-with-partial-conv; python3 pytorch-inpainting-with-partial-conv.py ${OPTION}
cd ../../image_manipulation/colorization; python3 colorization.py ${OPTION}
cd ../../image_manipulation/u2net_portrait; python3 u2net_portrait.py ${OPTION}
cd ../../image_manipulation/style2paints; python3 style2paints.py ${OPTION}