Skip to content

Commit

Permalink
faster, cleaner, transforms.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoForte committed Sep 22, 2021
1 parent 5323fa0 commit 8736764
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 54 deletions.
25 changes: 13 additions & 12 deletions demo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Our libs
from networks.transforms import trimap_transform, groupnorm_normalise_image
from networks.transforms import trimap_transform, normalise_image
from networks.models import build_model
from dataloader import PredDataset

Expand All @@ -11,10 +11,14 @@
import cv2
import numpy as np
import torch
import time

def np_to_torch(x, permute=True):
    """Convert a numpy array to a batched float CUDA tensor.

    With ``permute=True`` the array is assumed HWC and transposed to CHW
    before the leading batch dimension is added; with ``permute=False`` the
    array is used with its existing axis order.
    """
    tensor = torch.from_numpy(x).float()
    if permute:
        tensor = tensor.permute(2, 0, 1)
    return tensor[None, :, :, :].cuda()

def np_to_torch(x):
    # Converts an HWC numpy array to a 1xCxHxW float CUDA tensor.
    # NOTE(review): this re-definition shadows the earlier
    # `np_to_torch(x, permute=True)` above and drops the `permute` flag —
    # looks like a merge/diff artifact; confirm which definition is intended,
    # since callers passing `permute=False` would fail against this one.
    return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cuda()


def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
Expand All @@ -36,8 +40,9 @@ def predict_fba_folder(model, args):
image_np = item_dict['image']
trimap_np = item_dict['trimap']

st = time.time()
fg, bg, alpha = pred(image_np, trimap_np, model)

print("Time taken for prediction: ", time.time() - st)
cv2.imwrite(os.path.join(
save_dir, item_dict['name'][:-4] + '_fg.png'), fg[:, :, ::-1] * 255)
cv2.imwrite(os.path.join(
Expand Down Expand Up @@ -65,9 +70,9 @@ def pred(image_np: np.ndarray, trimap_np: np.ndarray, model) -> np.ndarray:
trimap_torch = np_to_torch(trimap_scale_np)

trimap_transformed_torch = np_to_torch(
trimap_transform(trimap_scale_np))
image_transformed_torch = groupnorm_normalise_image(
image_torch.clone(), format='nchw')
trimap_transform(trimap_scale_np), permute=False)
image_transformed_torch = normalise_image(
image_torch.clone())

output = model(
image_torch,
Expand Down Expand Up @@ -101,12 +106,8 @@ def pred(image_np: np.ndarray, trimap_np: np.ndarray, model) -> np.ndarray:
'--output_dir',
default='./examples/predictions',
help="")
parser.add_argument(
'--custom_groupnorm',
default=False,
help="Useful for conversion to TRTorch")

args = parser.parse_args()
model = build_model(args.weights, custom_groupnorm=args.custom_groupnorm)
model = build_model(args.weights)
model.eval().cuda()
predict_fba_folder(model, args)
53 changes: 11 additions & 42 deletions networks/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,51 +7,20 @@
def dt(a):
    """Exact L2 distance transform of `a` (expected to hold 0/1 values).

    The input is scaled to a 0/255 uint8 mask before being handed to OpenCV;
    maskSize 0 selects the precise distance computation.
    """
    mask_u8 = (a * 255).astype(np.uint8)
    return cv2.distanceTransform(mask_u8, cv2.DIST_L2, 0)


def trimap_transform(trimap, L=320):
    """Encode a two-channel trimap as six distance-based feature maps.

    For each of the two trimap channels (presumably foreground / background
    click masks — confirm against the dataloader), the squared L2 distance
    transform is passed through Gaussian falloffs at three scales
    (0.02*L, 0.08*L, 0.16*L).

    Parameters
    ----------
    trimap : (H, W, 2) array with 0/1 values per channel.
    L : reference length for the Gaussian scales (default 320).

    Returns
    -------
    (6, H, W) float array, channel-first.
    """
    h, w = trimap.shape[0], trimap.shape[1]
    clicks = []
    for k in range(2):
        if np.count_nonzero(trimap[:, :, k]) > 0:
            dt_mask = -dt(1 - trimap[:, :, k])**2
            for scale in (0.02, 0.08, 0.16):
                clicks.append(np.exp(dt_mask / (2 * ((scale * L)**2))))
        else:
            # Keep the output a fixed 6 channels even when a trimap channel is
            # empty; dropping channels here would change the tensor shape the
            # downstream model expects (and an all-empty trimap would yield an
            # unusable 0-d array).
            clicks.extend(np.zeros((h, w)) for _ in range(3))
    return np.array(clicks)


# For RGB !
group_norm_std = [0.229, 0.224, 0.225]
group_norm_mean = [0.485, 0.456, 0.406]


def groupnorm_normalise_image(img, format='nhwc'):
    '''
    Normalise an RGB image with values in [0, 1] using the ImageNet
    per-channel statistics above. The image is modified in place and also
    returned. `format` selects where the channel axis sits: 'nhwc'
    (channel-last) or anything else for channel-first ('nchw').
    '''
    channel_last = format == 'nhwc'
    for c, (mean, std) in enumerate(zip(group_norm_mean, group_norm_std)):
        if channel_last:
            img[..., c] = (img[..., c] - mean) / std
        else:
            img[..., c, :, :] = (img[..., c, :, :] - mean) / std
    return img
# ImageNet per-channel RGB statistics, shaped (1, C, 1, 1) so they broadcast
# over NCHW image batches.
# NOTE(review): these are moved to the GPU at import time, so importing this
# module fails on machines without CUDA — confirm that is acceptable.
imagenet_norm_std = torch.from_numpy(np.array([0.229, 0.224, 0.225])).float().cuda()[None, :, None, None]
imagenet_norm_mean = torch.from_numpy(np.array([0.485, 0.456, 0.406])).float().cuda()[None, :, None, None]


def groupnorm_denormalise_image(img, format='nhwc'):
    '''
    Undo the ImageNet normalisation, returning RGB values in range 0,1.

    For 'nhwc' input the tensor is denormalised in place and returned; for
    channel-first input a fresh CUDA tensor is allocated and returned, leaving
    the input untouched.
    '''
    if format == 'nhwc':
        for c in range(3):
            img[:, :, :, c] = img[:, :, :, c] * group_norm_std[c] + group_norm_mean[c]
        return img
    denorm = torch.zeros_like(img).cuda()
    for c in range(3):
        denorm[:, c, :, :] = img[:, c, :, :] * group_norm_std[c] + group_norm_mean[c]
    return denorm
def normalise_image(image, mean=imagenet_norm_mean, std=imagenet_norm_std):
    """Normalise a batched image tensor with per-channel mean and std.

    Defaults to the module-level ImageNet statistics, which broadcast over
    NCHW batches. Returns a new tensor; the input is not modified.
    """
    centred = image - mean
    return centred / std

0 comments on commit 8736764

Please sign in to comment.