In [1]:
%matplotlib inline
from skimage import transform
import numpy as np
from pathlib import Path
import re
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import colormaps
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from floortrans.models import get_model     
from floortrans.loaders import FloorplanSVG, DictToTensor, Compose, RotateNTurns
from floortrans.loaders.svg_loader import FloorplanSVGSample

from floortrans.plotting import segmentation_plot, polygons_to_image, draw_junction_from_dict, discrete_cmap
# NOTE(review): called for its side effect only. Presumably this registers the
# 'rooms' and 'icons' colormaps that the plotting code below refers to by name
# — TODO confirm in floortrans.plotting.
discrete_cmap()
from floortrans.post_prosessing import split_prediction, get_polygons, split_validation
from mpl_toolkits.axes_grid1 import AxesGrid
import panel as pn
from PIL import Image, ImageDraw

# Load Panel's notebook extension so that Panel objects render in cell output.
pn.extension()

# Shared rotation helper; run_segmentation uses it to rotate the input image
# and to rotate each prediction back.
rot = RotateNTurns()
# Tick labels for the room segmentation colour bar; position == class index.
room_classes = [
    "Background",
    "Outdoor",
    "Wall",
    "Kitchen",
    "Living Room",
    "Bed Room",
    "Bath",
    "Entry",
    "Railing",
    "Storage",
    "Garage",
    "Undefined",
]
# Tick labels for the icon segmentation colour bar; position == class index.
icon_classes = [
    "No Icon",
    "Window",
    "Door",
    "Closet",
    "Electrical Applience",
    "Toilet",
    "Sink",
    "Sauna Bench",
    "Fire Place",
    "Bathtub",
    "Chimney",
]

In [2]:
# Working directory for inference.
# Copy a model_best_val_loss_var.pkl from a training run here and, optionally,
# epoch-specific checkpoints (e.g. model_epoch_41.pkl).
# Also copy the PNG images to segment here; the '*.png' glob further down picks them up.
target_dir = Path('predict_cubi/')

In [35]:
# Presumably the location of the CubiCasa5k dataset; not read by the cells shown here.
data_folder = Path('data/cubicasa5k/')

# Number of room / icon classes; used to scale the colour bars of the plots.
n_rooms = 12
n_icons = 11

# Channel counts of the three groups in the network output, in channel order.
# run_segmentation reads the rooms and icons from the 2nd and 3rd group; the
# first 21 channels are presumably the junction heatmaps — TODO confirm.
split = [21, n_rooms, n_icons]

# Total number of output channels of the network.
n_classes = sum(split)
def setup_model(target_file = 'model_best_val_loss_var.pkl'):
    """Build the network, load a checkpoint from ``target_dir`` and return it ready for inference.

    Args:
        target_file: File name of the checkpoint, resolved relative to the
            module-level ``target_dir``.

    Returns:
        The model with ``checkpoint['model_state']`` loaded, switched to eval
        mode, and moved to the GPU when CUDA is available.

    Raises:
        Propagates whatever ``torch.load`` / ``load_state_dict`` raise for a
        missing, unreadable or incompatible checkpoint, and the lookup error
        raised if the checkpoint has no ``'model_state'`` or ``'epoch'`` entry.
    """
    # The base architecture is instantiated with 51 output channels, then its
    # last two layers are replaced so that the output has n_classes channels
    # (presumably mirroring how the checkpoint was trained — TODO confirm).
    model = get_model('hg_furukawa_original', 51)
    model.conv4_ = torch.nn.Conv2d(256, n_classes, bias=True, kernel_size=1)
    model.upsample = torch.nn.ConvTranspose2d(n_classes, n_classes, kernel_size=4, stride=4)

    # Always deserialise onto the CPU. Without map_location, torch.load tries
    # to restore tensors to the device they were saved from, which fails when
    # the checkpoint was written on a GPU index this machine does not have.
    # The weights are copied into the model and moved to the GPU below.
    # NOTE(security): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(target_dir / target_file, map_location=torch.device('cpu'))

    model.load_state_dict(checkpoint['model_state'])
    model.eval()
    if torch.cuda.is_available():
        model.cuda()
    print(f"Model loaded, {target_file=}, epoch={checkpoint['epoch']}")
    return model

In [25]:
def _segmentation_pane(class_map, cmap_name, n_labels, class_names):
    """Render a 2-D array of class indices as a Panel pane with a labelled colour bar.

    Args:
        class_map: 2-D array of integer class indices to display.
        cmap_name: Name of the matplotlib colormap to use.
        n_labels: Number of classes; sets the colour range and the tick positions.
        class_names: Tick labels for the colour bar, one per class.

    Returns:
        A ``pn.pane.Matplotlib`` wrapping the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.axis('off')
    seg = ax.imshow(class_map, cmap=cmap_name, vmin=0, vmax=n_labels - 0.1)
    # Ticks at k + 0.5 put each label roughly in the middle of its colour band
    # (assuming the colormap has n_labels discrete colours).
    cbar = fig.colorbar(seg, ax=ax, ticks=np.arange(n_labels) + 0.5, fraction=0.046, pad=0.01)
    cbar.ax.set_yticklabels(class_names, fontsize=20)
    # Release the pyplot handle so the figure is not also drawn inline by the
    # notebook backend; the Figure object itself is handed to Panel.
    plt.close(fig)
    return pn.pane.Matplotlib(fig, dpi=144, tight=True, sizing_mode="scale_both")


def run_segmentation(model, image):
    """Segment one image with rotation test-time augmentation and build the result panes.

    The image is run through ``model`` in four rotations; each prediction is
    rotated back, resized to the input resolution, and the four results are
    averaged. Room and icon classes are the per-pixel argmax over their
    respective channel groups (see the module-level ``split``).

    Args:
        model: The network to evaluate (expected to be in eval mode).
        image: Input tensor indexed as ``image.shape[2]`` = height and
            ``image.shape[3]`` = width (i.e. a 4-D batch; only the first
            element of each prediction is used).

    Returns:
        Tuple ``(rooms_pane, icons_pane, prediction, img_size)`` where the
        panes are ``pn.pane.Matplotlib`` objects, ``prediction`` is the
        averaged CPU tensor of shape ``[1, n_classes, height, width]`` and
        ``img_size`` is ``(height, width)``.
    """
    # Use the image dimensions, not those of the label!
    height, width = image.shape[2], image.shape[3]
    img_size = (height, width)

    # The model may live on the GPU (see setup_model) while callers build the
    # image on the CPU; move the input to wherever the weights are so the
    # forward pass does not fail with a device mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image = image.to(first_param.device)

    # (turns applied to the input, turns applied to undo them on the output)
    rotations = [(0, 0), (1, -1), (2, 2), (-1, 1)]
    prediction = torch.zeros([len(rotations), n_classes, height, width])
    with torch.no_grad():
        for i, (forward, back) in enumerate(rotations):
            # We rotate first the image
            rot_image = rot(image, 'tensor', forward)
            pred = model(rot_image)
            # We rotate prediction back
            pred = rot(pred, 'tensor', back)
            # We fix heatmaps
            pred = rot(pred, 'points', back)
            # We make sure the size is correct
            pred = F.interpolate(pred, size=(height, width), mode='bilinear', align_corners=True)
            # Accumulate on the CPU so the averaged result is device-independent.
            prediction[i] = pred[0].cpu()

    prediction = torch.mean(prediction, 0, True)

    # Channel layout follows `split`: the first group is skipped, then rooms, then icons.
    rooms_start = split[0]
    icons_start = rooms_start + split[1]

    rooms_pred = F.softmax(prediction[0, rooms_start:icons_start], 0).cpu().numpy()
    rooms_pred = np.argmax(rooms_pred, axis=0)

    icons_pred = F.softmax(prediction[0, icons_start:], 0).cpu().numpy()
    icons_pred = np.argmax(icons_pred, axis=0)

    rooms_pane = _segmentation_pane(rooms_pred, 'rooms', n_rooms, room_classes)
    icons_pane = _segmentation_pane(icons_pred, 'icons', n_icons, icon_classes)
    return rooms_pane, icons_pane, prediction, img_size

In [38]:
def run_all(target_images: tuple[Path, ...], target_file = 'model_best_val_loss_var.pkl'):
    """Segment every image and show input, rooms and icons side by side.

    A three-column Panel grid is displayed up front and one row (source image,
    room segmentation, icon segmentation) is appended to it per image.

    Args:
        target_images: Paths of the images to segment.
        target_file: Checkpoint file name forwarded to ``setup_model``.
    """
    model = setup_model(target_file)

    # Display the (still empty) grid first; rows are added to it as they are computed.
    layout = pn.GridBox(ncols=3)
    display(layout)

    for target_image in target_images:
        with Image.open(target_image) as im:
            # Force three channels; np.array copies the pixels out of the file.
            rgb = np.array(im.convert("RGB"))
            # Channels first, scaled by 1/255, as float32.
            chw = (np.moveaxis(rgb, -1, 0) / 255).astype(np.float32)
            # Prepend a batch axis of size 1.
            image = torch.from_numpy(np.expand_dims(chw, 0))

        rooms_pane, icons_pane, _prediction, _img_size = run_segmentation(model, image)
        source_pane = pn.pane.Image(target_image, sizing_mode="scale_both")
        layout.extend([source_pane, rooms_pane, icons_pane])

In [39]:
# Collect the PNG images to segment. glob() yields entries in a
# filesystem-dependent order, so sort them to make the result grid reproducible.
target_images = sorted(target_dir.glob('*.png'))

In [40]:
# Segment every collected image with the default checkpoint and display the results grid.
run_all(target_images)

Model loaded, target_file='model_best_val_loss_var.pkl', epoch=191
