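# Segment each floorplan with the model and save it as a 512x256 (width x height) pix2pix training pair: the original plan on the left, the predicted room/icon segmentation on the right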

In [None]:
%matplotlib inline
import io
import os

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from IPython.display import clear_output
from skimage import transform
from torch.utils.data import DataLoader

from floortrans.loaders import FloorplanSVG, DictToTensor, Compose, RotateNTurns
from floortrans.models import get_model
from floortrans.plotting import segmentation_plot, polygons_to_image, draw_junction_from_dict, discrete_cmap
discrete_cmap()
from floortrans.post_prosessing import split_prediction, get_polygons, split_validation
from mpl_toolkits.axes_grid1 import AxesGrid

# Rotation helper used for the test-time augmentation in the processing loop.
rot = RotateNTurns()

# Display names for the room / icon label indices (used as colorbar tick labels).
room_classes = ["Background", "Outdoor", "Wall", "Kitchen", "Living Room", "Bed Room", "Bath", "Entry", "Railing", "Storage", "Garage", "Undefined"]
icon_classes = ["No Icon", "Window", "Door", "Closet", "Electrical Appliance", "Toilet", "Sink", "Sauna Bench", "Fire Place", "Bathtub", "Chimney"]

data_folder = 'data/cubicasa5k/'
data_file = 'train.txt'
output_folder = 'data/cubicasa5k/output/'
# Create the destination up front; saving into a missing directory would
# otherwise fail on the very first floorplan.
os.makedirs(output_folder, exist_ok=True)

# Running index used as the output file name (0.png, 1.png, ...).
filename = 0

# NOTE(review): original_size=True presumably keeps each plan at its native
# resolution instead of a fixed training size — confirm in floortrans.loaders.
normal_set = FloorplanSVG(data_folder, data_file, format='txt', original_size=True)
# batch_size=1: every batched field is indexed with [0] further down.
data_loader = DataLoader(normal_set, batch_size=1, num_workers=0)
data_iter = iter(data_loader)

# Setup Model
# Build the hourglass network, then replace its output head with one that
# matches the channel layout of the checkpoint loaded below.
# NOTE(review): 51 is presumably the head size of the stock
# 'hg_furukawa_original' configuration; conv4_/upsample are swapped out
# immediately, so only the backbone must match — confirm in floortrans.models.
model = get_model('hg_furukawa_original', 51)

# Output channel layout: [heatmaps, room classes, icon classes]
# (the order split_prediction() unpacks them in).
split = [21, 12, 11]
n_classes = sum(split)  # 44 channels in total
model.conv4_ = torch.nn.Conv2d(256, n_classes, bias=True, kernel_size=1)
model.upsample = torch.nn.ConvTranspose2d(n_classes, n_classes, kernel_size=4, stride=4)

# Load onto the CPU: load_state_dict copies the weights into the (still
# CPU-resident) model, which is then moved to the GPU once by model.cuda().
# This also works when the checkpoint was saved from a different device.
# torch.load unpickles the file — only use trusted checkpoints.
checkpoint = torch.load('model_best_val_loss_var.pkl', map_location='cpu')

model.load_state_dict(checkpoint['model_state'])
model.eval()
model.cuda()
print("Model loaded!")

# Number of room / icon classes, kept in sync with the head layout above.
n_rooms = split[1]
n_icons = split[2]

# Looping through the whole folder of floorplans.
# Each plan is segmented by the model and written out as a side-by-side
# pix2pix pair: [original plan | predicted room + icon segmentation].


def _predict_with_rotations(image, height, width):
    """Average the model's prediction over four rotated copies of ``image``.

    Each copy is rotated ``forward`` turns with ``rot`` (RotateNTurns,
    presumably 90-degree steps), run through the global ``model``, and rotated
    back ``back`` turns. Both the tensor and the heatmap points are rotated
    back. Each copy is then resized to ``(height, width)``.

    Returns:
        CPU tensor of shape ``(1, n_classes, height, width)``: the mean of the
        four predictions.
    """
    rotations = [(0, 0), (1, -1), (2, 2), (-1, 1)]
    prediction = torch.zeros([len(rotations), n_classes, height, width])
    with torch.no_grad():
        for i, (forward, back) in enumerate(rotations):
            pred = model(rot(image, 'tensor', forward))
            pred = rot(pred, 'tensor', back)
            # Heatmap channels also need their point positions rotated back.
            pred = rot(pred, 'points', back)
            pred = F.interpolate(pred, size=(height, width), mode='bilinear', align_corners=True)
            prediction[i] = pred[0]
    return torch.mean(prediction, 0, True)


def _show_label_map(seg, cmap, n_labels, class_names):
    """Draw a per-pixel label map in a new figure with a class-name colorbar.

    ``cmap`` names a colormap ('rooms' / 'icons'); these are presumably
    registered by ``discrete_cmap()`` at import time.
    """
    plt.figure(figsize=(12, 12))
    ax = plt.subplot(1, 1, 1)
    ax.axis('off')
    handle = ax.imshow(seg, cmap=cmap, vmin=0, vmax=n_labels - 0.1)
    cbar = plt.colorbar(handle, ticks=np.arange(n_labels) + 0.5, fraction=0.046, pad=0.01)
    cbar.ax.set_yticklabels(class_names, fontsize=20)


def _render_label_map(seg, cmap, n_labels, fig_w, fig_h):
    """Rasterise ``seg`` edge-to-edge (no frame, no axes) into an RGBA PIL image.

    ``fig_w`` / ``fig_h`` should follow the plan's aspect ratio so that the
    rendered PNG has no padding borders. The figure is always closed, so
    figures do not pile up across the loop.
    """
    fig = plt.figure(frameon=False)
    fig.set_size_inches(fig_w, fig_h, forward=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    ax.imshow(seg, cmap=cmap, vmin=0, vmax=n_labels - 0.1)
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            with PIL.Image.open(buf) as png:
                # convert() fully decodes into a new image that outlives the buffer.
                return png.convert('RGBA')
    finally:
        plt.close(fig)


def _fit_into_square(img, size, background):
    """Shrink ``img`` in place to fit ``size`` and centre it on a new RGBA canvas.

    Returns:
        ``(canvas, offset)``, where ``offset`` is the top-left paste position
        of ``img`` on the canvas.
    """
    # LANCZOS is the filter the removed Image.ANTIALIAS alias pointed to, so
    # the output matches what older Pillow versions produced.
    img.thumbnail(size, PIL.Image.LANCZOS)
    offset = (max((size[0] - img.size[0]) // 2, 0), max((size[1] - img.size[1]) // 2, 0))
    canvas = PIL.Image.new(mode='RGBA', size=size, color=background)
    canvas.paste(img, offset)
    return canvas, offset


def _icons_overlay(icon_img, size):
    """Shrink the rendered icon map to fit ``size`` and make its black background transparent.

    Only pixels that are exactly opaque black after resampling become
    transparent. NOTE(review): this assumes black is the 'No Icon' colour of the
    'icons' colormap — confirm in floortrans.plotting.
    """
    icon_img.thumbnail(size, PIL.Image.LANCZOS)
    pixels = np.array(icon_img.convert('RGBA'))
    pixels[(pixels == [0, 0, 0, 255]).all(axis=2)] = [255, 255, 255, 0]
    # uint8 HxWx4 array -> RGBA image (mode is inferred).
    return PIL.Image.fromarray(pixels)


# Size of each half of the output pair; the combined image is twice as wide.
resize = (256, 256)
# Channel ranges of the room and icon logits inside the prediction tensor.
room_start = split[0]
icon_start = split[0] + split[1]

for item in data_iter:
    # Clear the previous plan's plots only once new output arrives, so the
    # progress images stay visible while the next plan is processed.
    clear_output(wait=True)

    folder = item['folder'][0]
    image = item['image'].cuda()
    label_np = item['label'].data.numpy()[0]
    # NOTE(review): the label is assumed to be (channels, H, W); only its
    # spatial size is used here, as the target prediction size.
    height = label_np.shape[1]
    width = label_np.shape[2]

    # Undo the loader's normalisation for display (presumably [-1, 1] -> [0, 1]).
    np_img = np.moveaxis(image[0].cpu().data.numpy(), 0, -1) / 2 + 0.5
    plt.figure(figsize=(10, 10))
    plt.title('Source Image', fontsize=20)
    plt.axis('off')
    plt.imshow(np_img)
    plt.show()

    prediction = _predict_with_rotations(image, height, width)

    # Rough per-pixel class maps (argmax over each head), shown for progress only.
    rooms_pred = np.argmax(F.softmax(prediction[0, room_start:icon_start], 0).cpu().data.numpy(), axis=0)
    icons_pred = np.argmax(F.softmax(prediction[0, icon_start:], 0).cpu().data.numpy(), axis=0)
    _show_label_map(rooms_pred, 'rooms', n_rooms, room_classes)
    _show_label_map(icons_pred, 'icons', n_icons, icon_classes)
    plt.show()

    # Post-process the raw prediction into polygons, then back into label images.
    heatmaps, rooms, icons = split_prediction(prediction, (height, width), split)
    polygons, types, room_polygons, room_types = get_polygons((heatmaps, rooms, icons), 0.2, [1, 2])
    pol_room_seg, pol_icon_seg = polygons_to_image(polygons, types, room_polygons, room_types, height, width)

    # Match the figure's aspect ratio to the plan so the render has no borders.
    fig_w = width / height * 10
    fig_h = 10
    room_img = _render_label_map(pol_room_seg, 'rooms', n_rooms, fig_w, fig_h)
    icon_img = _render_label_map(pol_icon_seg, 'icons', n_icons, fig_w, fig_h)

    # Right half: rooms on a pink background (to match the generated images'
    # background), with the icons pasted on top at the same offset.
    seg_half, seg_offset = _fit_into_square(room_img, resize, (255, 0, 255, 255))
    overlay = _icons_overlay(icon_img, resize)
    seg_half.paste(overlay, seg_offset, overlay)

    # Left half: the original plan image on a white background.
    # NOTE(review): folder[1:] presumably strips a leading '/' from the
    # dataset's folder entry — confirm against the entries in train.txt.
    folder_true = folder[1:]
    with PIL.Image.open(data_folder + folder_true + 'F1_original.png') as base_image:
        photo_half, _ = _fit_into_square(base_image, resize, (255, 255, 255, 255))

    combined_im = PIL.Image.new('RGBA', (2 * resize[0], resize[1]), color=(255, 255, 255, 255))
    combined_im.paste(photo_half, (0, 0))
    combined_im.paste(seg_half, (resize[0], 0))

    # PNG keeps the alpha band (JPEG would not).
    combined_im.save(output_folder + str(filename) + '.png', 'PNG')
    # Safety net for non-inline backends, where plt.show() does not close figures.
    plt.close('all')
    print('saved ' + str(filename))

    # Next output index, so files are never overwritten.
    filename += 1

print('run completed')