In [None]:
import sys
import torch
from PIL import Image
import numpy as np

sys.path.append('../')
from detector.architecture import Architecture

# Load a trained model

In [None]:
# Build the network, switch it to evaluation mode,
# and then restore the trained weights (mapped onto the CPU).
WEIGHTS_PATH = '../models/run00.pth'

model = Architecture(num_outputs=5 + 10)
model.eval()
state_dict = torch.load(WEIGHTS_PATH, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)

# Get an image

In [None]:
# Read one image from disk, report its original (width, height),
# and rescale it to 640x448 before it is displayed.
image_path = '/home/dan/datasets/COCO/images/val2017/000000000885.jpg'
image = Image.open(image_path)
print(image.size)
image = image.resize((640, 448))
image

# Predict

In [None]:
# Turn the image into a float tensor with values in [0, 1].
# `convert('RGB')` guarantees exactly three channels, so grayscale,
# palette, or RGBA images do not break the permutation below.
pixels = np.array(image.convert('RGB'))/255.0  # shape [h, w, 3]
image_tensor = torch.tensor(pixels, dtype=torch.float32)

# add a batch dimension and move channels first: [1, 3, h, w]
image_tensor = image_tensor.unsqueeze(0).permute(0, 3, 1, 2)

# no gradients are needed for prediction
with torch.no_grad():
    x, features = model(image_tensor)

# the first 5 output channels are heatmaps, the next 10 are offsets
heatmaps, offsets = torch.split(x, [5, 10], dim=1)

# squash the heatmaps into (0, 1) and drop the batch dimension
heatmaps = torch.sigmoid(heatmaps)[0]

# Show masks

In [None]:
def show_mask(image, mask, color=(255, 0, 0), alpha=100):
    """
    Overlay a binary mask on an image as a semi-transparent colored layer.

    Arguments:
        image: a PIL image of any mode, it is not modified.
        mask: a two-dimensional torch tensor or numpy array
            with boolean (or 0/1) values. It can have a spatial
            size different from the image size, the overlay
            is resized to the image size before compositing.
        color: a sequence of three integers in [0, 255],
            the RGB color of the overlay (red by default).
        alpha: an integer in [0, 255], the opacity
            of the overlay where the mask is set.
    Returns:
        a new RGBA PIL image with the same size as `image`.
    """
    if isinstance(mask, torch.Tensor):
        # works also for tensors that live on a GPU or require grad
        mask = mask.detach().cpu().numpy()
    binary = np.asarray(mask).astype('uint8')  # shape [h, w], values 0 or 1

    # the overlay color where the mask is set, black elsewhere
    overlay_rgb = np.array(color, dtype='uint8') * np.expand_dims(binary, 2)

    # the overlay is fully transparent outside of the mask
    overlay_alpha = Image.fromarray(alpha * binary)
    overlay = Image.fromarray(overlay_rgb)
    overlay.putalpha(overlay_alpha)

    # `alpha_composite` needs RGBA on both sides, so convert explicitly
    # (a plain copy + `putalpha` gives LA/PA for grayscale/palette images),
    # then make the base image fully opaque
    canvas = image.convert('RGBA')
    canvas.putalpha(255)
    canvas.alpha_composite(overlay.resize(image.size))
    return canvas

In [None]:
# Take channel 0 of the 'p2' feature map for the first batch item
# and rescale it linearly so that its values span [0, 1].
feature_map = features['p2'][0, 0]
lowest, highest = feature_map.min(), feature_map.max()
mask = (feature_map - lowest)/(highest - lowest)

In [None]:
# Highlight the positions where the mask value exceeds 0.5.
binary_mask = mask > 0.5
show_mask(image, binary_mask)

# Show heatmaps

In [None]:
# Start from an opaque RGBA version of the image: `alpha_composite`
# requires RGBA on both sides, and a plain copy + `putalpha`
# yields LA/PA for grayscale/palette images.
image_copy = image.convert('RGBA')
image_copy.putalpha(255)
width, height = image.size

# a heatmap pixel is painted if its score is above this value
HEATMAP_THRESHOLD = 0.1

# heatmap channel index -> RGB color of its overlay
colors = {
    0: [255, 0, 0],  # red - top
    1: [0, 0, 255],  # blue - bottom
    2: [255, 255, 0],  # yellow - left
    3: [255, 0, 255],  # pink - right
    4: [0, 255, 0]  # green - center
}

for i, color in colors.items():

    # boolean map of the pixels where the i-th heatmap fires
    is_active = heatmaps[i].numpy() > HEATMAP_THRESHOLD

    # the layer is fully opaque on active pixels and transparent elsewhere
    opacity = Image.fromarray(255 * is_active.astype('uint8'))
    painted = np.array(color, dtype='uint8') * np.expand_dims(is_active, 2)
    layer = Image.fromarray(painted.astype('uint8'))

    # heatmaps can be smaller than the image, so upscale before compositing
    layer = layer.resize((width, height))
    layer.putalpha(opacity.resize((width, height)))

    image_copy.alpha_composite(layer)

image_copy