In [1]:
# To infer at test time. 
# Take impulse as input from user 
# and predict object corresponding to it

In [2]:
import model_lib
import numpy as np
import warnings
warnings.filterwarnings('ignore', '.*output shape of zoom.*')
import pickle
import importlib
importlib.reload(model_lib)
import os
import time
import random

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [3]:
# config to train
# TODO: check Config is correct
class ProposalConfig():
    NAME = "InSegm"
    GPU_COUNT = 1
    # online training
    IMAGES_PER_GPU = 1
    STEPS_PER_EPOCH = 100
    NUM_WORKERS = 1
    PIN_MEMORY = True
    DATA_ORDER = "cw_ins"
    VALIDATION_STEPS = 20
    # including gt
    NUM_CLASSES = 81
    
    # only flips
    MEAN_PIXEL = np.array([0.485, 0.456, 0.406],dtype=np.float32).reshape(1,1,-1)
    STD_PIXEL = np.array([0.229, 0.224, 0.225],dtype=np.float32).reshape(1,1,-1)
    CLASS_NAMES = [
        'BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]
    WIDTH = 224
    HEIGHT = 224
    CROP_SIZE = 224
    def __init__(self):
        self.BATCH_SIZE = self.IMAGES_PER_GPU * self.GPU_COUNT
        self.IMAGE_SHAPE = (self.WIDTH, self.HEIGHT,3)

    def display(self):
        """Display Configuration values."""
        print("\nConfigurations:")
        for a in dir(self):
            if not a.startswith("__") and not callable(getattr(self, a)):
                print("{:30} {}".format(a, getattr(self, a)))
        print("\n")

In [4]:
image_dir = "/media/Data1/interns/aravind/test2017/"
model_dir = "./models/"
config = ProposalConfig()


In [5]:
import torch
import torch.nn.functional as F
from PIL import Image
net = model_lib.MultiHGModel()

pretrained_dict = torch.load(model_dir+"multi_impulse.pt")
net_dict = net.state_dict()


pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net_dict}
net_dict.update(pretrained_dict) 
net.load_state_dict(net_dict)

# net.vgg.load_state_dict(torch.load(model_dir+"model_vgg_class_only.pt").vgg)
# net.classifier.load_state_dict(torch.load(model_dir+"model_vgg_class_only.pt").classifier)
net = net.cuda()

In [42]:
clicks = []
def show_and_quit(np_images):
    mode = {2:'L',3:'RGB'}
    objs = []
    for img in np_images:
        f = Image.fromarray(img.astype(np.uint8),mode[img.ndim])
        objs.append(f)
        f.show()        
    time.sleep
    [f.close() for f in objs]

def detect_on_loc(event,img_obj):
    global clicks
    clicks.append([event.x,event.y])
    if len(clicks) != 3:
        return 0
    impulse = np.zeros((4,) + img_obj.size).astype(np.float32)
    for event in clicks:
        q, p = event[0], event[1]
        w,h = img_obj.size
        l = 8
        n = 4
        for i in range(4):
            L = (l // 2) * (3**i)
            impulse[i][max(p - L, 0):min(p + L, w), max(q - L, 0):min(q + L, h)] = 1
    clicks = []
    with torch.no_grad():
        image = np.array(img_obj).astype(np.float32)
        batch_images = image/256
        batch_images -= config.MEAN_PIXEL
        batch_images /= config.STD_PIXEL

        batch_images = torch.from_numpy(np.moveaxis(np.expand_dims(batch_images,0),-1,1)).cuda()
        batch_impulses = torch.from_numpy(np.expand_dims(impulse,0)).cuda()

        pred_class,pred_masks = net([batch_images,batch_impulses])

        pred_class = F.softmax(pred_class,dim=-1).squeeze()
        maxs, indices = torch.topk(pred_class,5,-1)

        pred_mask = pred_masks[1].squeeze()
        pred_mask = F.sigmoid(pred_mask)
#         pred_mask = F.threshold(pred_mask,0.5,0)
        pred_mask = pred_mask.squeeze().cpu().numpy()*255
#         Image.fromarray(pred_mask.astype(np.uint8),"L").convert("RGB").show()

        pred_mask = pred_masks[1].squeeze()
        pred_mask = F.sigmoid(pred_mask)
        pred_mask = F.threshold(pred_mask,0.5,0)
        pred_mask = pred_mask.squeeze().cpu().numpy()*255
#         Image.fromarray(pred_mask.astype(np.uint8),"L").convert("RGB").show()

        for i in range(5):
            print(maxs[i],indices[i],config.CLASS_NAMES[int(indices[i])])
        print("=====")
        show_and_quit([impulse[0]*255,pred_mask])


def get_image(image_path):
    thumbnail_shape = (224,224)
    z = Image.new("RGB", thumbnail_shape, "black")
    image_obj = Image.open(image_path).convert("RGB")
    image_obj.thumbnail(thumbnail_shape, Image.ANTIALIAS)
    (w, h) = image_obj.size
    z.paste(image_obj, ((thumbnail_shape[0] - w) // 2, (thumbnail_shape[1] - h) // 2))
    return z



In [63]:
import tkinter as tk
from PIL import ImageTk, Image
image_path = image_dir+random.choice(os.listdir(image_dir))
print(image_path)
#This creates the main window of an application
window = tk.Tk()
window.title("Join")
window.geometry("224x224")
window.configure(background='grey')

#Creates a Tkinter-compatible photo image, which can be used everywhere Tkinter expects an image object.
img = ImageTk.PhotoImage(get_image(image_path))

#The Label widget is a standard Tkinter widget used to display a text or image on the screen.
panel = tk.Label(window, image = img)

#The Pack geometry manager packs widgets in rows or columns.
panel.pack(side = "bottom", fill = "both", expand = "yes")

#Bind motion to motion function
window.bind("<Button-1>", lambda event, arg=get_image(image_path) : detect_on_loc(event, arg))

#Start the GUI
window.mainloop()


/media/Data1/interns/aravind/test2017/000000407258.jpg
tensor(0.8104, device='cuda:0') tensor(62, device='cuda:0') toilet
tensor(1.00000e-02 *
       5.9867, device='cuda:0') tensor(72, device='cuda:0') sink
tensor(1.00000e-02 *
       4.2428, device='cuda:0') tensor(73, device='cuda:0') refrigerator
tensor(1.00000e-02 *
       2.2653, device='cuda:0') tensor(60, device='cuda:0') bed
tensor(1.00000e-02 *
       2.1791, device='cuda:0') tensor(0, device='cuda:0') BG
=====
tensor(0.8537, device='cuda:0') tensor(62, device='cuda:0') toilet
tensor(1.00000e-02 *
       6.2543, device='cuda:0') tensor(72, device='cuda:0') sink
tensor(1.00000e-02 *
       2.0894, device='cuda:0') tensor(60, device='cuda:0') bed
tensor(1.00000e-02 *
       2.0626, device='cuda:0') tensor(73, device='cuda:0') refrigerator
tensor(1.00000e-02 *
       1.5027, device='cuda:0') tensor(0, device='cuda:0') BG
=====
