In [1]:
from __future__ import absolute_import, division, print_function

import click
import cv2
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from addict import Dict
import os
import time

from libs.models import *
from libs.utils import DenseCRF


def get_classtable(CONFIG):
    with open(CONFIG.DATASET.LABELS) as f:
        classes = {}
        for label in f:
            label = label.rstrip().split("\t")
            classes[int(label[0])] = label[1].split(",")[0]
    return classes


def get_device(cuda):
    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    if cuda:
        current_device = torch.cuda.current_device()
        print("Device:", torch.cuda.get_device_name(current_device))
    else:
        print("Device: CPU")
    return device


def setup_postprocessor(CONFIG):
    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )
    return postprocessor


def preprocessing(image, device, CONFIG):
    # Resize
    scale = CONFIG.IMAGE.SIZE.TEST / max(image.shape[:2])
    image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
    raw_image = image.astype(np.uint8)

    # Subtract mean values
    image = image.astype(np.float32)
    image -= np.array(
        [
            float(CONFIG.IMAGE.MEAN.B),
            float(CONFIG.IMAGE.MEAN.G),
            float(CONFIG.IMAGE.MEAN.R),
        ]
    )

    # Convert to torch.Tensor and add "batch" axis
    image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
    image = image.to(device)

    return image, raw_image


def inference(model, image, raw_image=None, postprocessor=None):
    _, _, H, W = image.shape

    # Image -> Probability map
    logits = model(image)
    logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
    probs = F.softmax(logits, dim=1)[0]
    probs = probs.cpu().numpy()

    # Refine the prob map with CRF
    if postprocessor and raw_image is not None:
        probs = postprocessor(raw_image, probs)

    labelmap = np.argmax(probs, axis=0)

    return labelmap


In [2]:
config_path = "C:\\Users\\oranl\\Desktop\\proj\\deeplab-pytorch-master\\configs\\cocostuff10k.yaml"
image_path = "C:\\Users\\oranl\\Desktop\\tttt.jpg"
model_path = "C:\\Users\\oranl\\Desktop\\proj\\deeplab-pytorch-master\\deeplabv2_resnet101_msc-cocostuff10k-20000.pth"

folder_path = "C:\\Users\\oranl\\Desktop\\proj\\landscapes_small\\field"
cuda = True
crf = None

CONFIG = Dict(yaml.load(open(config_path)))
device = get_device(cuda)
torch.set_grad_enabled(False)

classes = get_classtable(CONFIG)
postprocessor = setup_postprocessor(CONFIG) if crf else None

model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict)
model.eval()
model.to(device)
print("Model:", CONFIG.MODEL.NAME)

# Inference
# image = cv2.imread(image_path, cv2.IMREAD_COLOR)
# image, raw_image = preprocessing(image, device, CONFIG)
# labelmap = inference(model, image, raw_image, postprocessor)
# labels = np.unique(labelmap)

print("done")
total = 20
start = time.time()
count = 0
image_addresses = {}
labelmaps = []
labellist = []
for dr in os.listdir(folder_path):
    image_path = os.path.join(folder_path, dr)
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
                       
    if image.shape[1] < image.shape[0]:
        continue
        
    image, raw_image = preprocessing(image, device, CONFIG)
    labelmap = inference(model, image, raw_image, postprocessor)
    labels = np.unique(labelmap)
    
    labellist.append(labels)
    labelmaps.append(labelmap)
    image_addresses[count] = image_path

    count += 1
    
    if count % 5 == 0:
        print(count)
    if count >= total:
        break
        
end = time.time()
print((end - start)/total)
print(end-start)

# # Show result for each class
# rows = np.floor(np.sqrt(len(labels) + 1))
# cols = np.ceil((len(labels) + 1) / rows)

# plt.figure(figsize=(10, 10))
# ax = plt.subplot(rows, cols, 1)
# ax.set_title("Input image")
# ax.imshow(raw_image[:, :, ::-1])
# ax.axis("off")

# for i, label in enumerate(labels):
#     mask = labelmap == label
# #     print(mask.shape)
#     ax = plt.subplot(rows, cols, i + 2)
#     ax.set_title(classes[label])
# #     ax.imshow(raw_image[..., ::-1])
# #     print(mask.astype(np.float32))
#     ax.imshow(mask.astype(np.float32), alpha=1)
#     ax.axis("off")

# plt.tight_layout()
# plt.show()

  if __name__ == '__main__':


Device: CPU
Model: DeepLabV2_ResNet101_MSC
done
5
10
15
20
13.734725069999694
274.6945013999939


In [3]:
from PIL import Image

# print(labels)
to_one_hot = {105 : 0, # clouds
              
              123: 1, # grass
              96: 1, # bush
              141: 1, # plant-other
              144 :1, # playingfield
              124: 1, # gravel
              93 : 1, 
              125 : 1, 
              128 : 1, 
              133 : 1, 
              153 : 1, 
              181 : 1, 
              
              156: 2, # sky-other
              
              168: 3, # tree
              
              126: 4, # hill
              110: 4, # dirt
              134: 4,
              149: 4,
              
              154: 5, # sea-other
              147: 5
              
             }

label_to_color = {0: np.array([255, 255, 255]),
                 1: np.array([0, 255, 0]),
                 2: np.array([0, 197, 229]),
                 3: np.array([0, 154, 78]),
                 4: np.array([143, 74, 0]),
                 5: np.array([0, 64, 143])}


one_hot_arrs = np.zeros((len(labellist), 347, 513, 6))
arrs = []
start = time.time()
for i in range(len(labellist)):
    labels = labellist[i]
    labelmap = labelmaps[i]
    for i, label in enumerate(labels):
        if label not in to_one_hot:
            to_one_hot[label] = 0

    pic = np.vectorize(to_one_hot.__getitem__)(labelmap)
    arr = (np.arange(6) == pic[...,None]).astype('uint8')
    
    arrs.append(arr)

one_hot_arrs = np.stack(arrs, axis=0)

end = time.time()
print((end-start)/len(labelmap))

def to_pic(one_hot_pic):
    x, y, _ = one_hot_pic.shape
    image = np.zeros(shape = (x, y, 3))
    
    
    
    for i in range(x):
        for j in range(y):
            image[i, j] = label_to_color.get(np.dot(one_hot_pic[i, j], np.arange(6)), label_to_color[0])
    
    
    print(image.shape)
    pic = Image.fromarray(image.astype('uint8'), mode="RGB")
    
    return pic

0.0011753396960431628


In [4]:
np.save("landscape", one_hot_arrs)

In [5]:
temp = np.load("landscape.npy")
print(temp.shape)

(20, 347, 513, 6)


In [7]:
from matplotlib.pyplot import imshow
from IPython.display import display

for i in range(one_hot_arrs.shape[0]):
    im1 = to_pic(one_hot_arrs[i])
    im1.save("test.png")
    print("A")
    im = Image.open(image_addresses[i])
    break
#     display(im)
#     display(im1)
#     a = input()
#     if a == "q":
#         break

(347, 513, 3)
A
