In [None]:
import copy
from collections import defaultdict
from contextlib import nullcontext
from glob import glob
from io import BytesIO

import cairosvg
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from bs4 import BeautifulSoup
from kornia.filters import canny, gaussian_blur2d
from lcmr_ext.encoder.dino_v2_encoder import DinoV2Encoder
from lcmr_ext.loss import CombinedLoss, EfdRegularizerLoss, ImageMaeLoss, ImageMseLoss, SceneLoss
from lcmr_ext.modeler import ConditionalDETRModeler, EfdModuleConfig, EfdModuleMode, ModelerConfig
from lcmr_ext.renderer.renderer2d import PyTorch3DRenderer2D
from lcmr_ext.utils import optimize_params
from lcmr_ext.utils.sample_efd import ellipse_efd, heart_efd, square_efd
from object_centric_library.evaluation.metrics.ari import ari
from omegaconf import OmegaConf
from PIL import Image
from skimage import measure
from torch.nn.functional import l1_loss, mse_loss
from torch.utils.data import DataLoader
from torchmetrics.aggregation import MeanMetric
from torchmetrics.classification import BinaryJaccardIndex
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchvision.transforms.functional import pil_to_tensor
from tqdm import tqdm
from tqdm.auto import tqdm
from transformers.models.conditional_detr import ConditionalDetrConfig

from lcmr.dataset import DatasetOptions, EfdGeneratorOptions, RandomDataset
from lcmr.grammar import Scene
from lcmr.grammar.scene_data import SceneData
from lcmr.grammar.shapes import Shape2D
from lcmr.reconstruction_model import ReconstructionModel
from lcmr.renderer.renderer2d import OpenGLRenderer2D
from lcmr.utils.colors import colors
from lcmr.utils.elliptic_fourier_descriptors import EfdGeneratorOptions
from lcmr.utils.presentation import display_img, make_img_grid

# Select the "high" float32 matmul precision mode for the whole notebook.
torch.set_float32_matmul_precision("high")

# First CUDA device when one is available, CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Raster size shared by the datasets and renderers, and the loader batch size.
raster_size = 128, 128
batch_size = 128

In [None]:
torch.manual_seed(1234)

# Shape vocabulary: heart, square and ellipse EFD tensors joined along a new leading axis.
choices = torch.cat([efd[None] for efd in (heart_efd(), square_efd(), ellipse_efd())], dim=0)

# Settings common to the three splits; only the sample count and `return_images` differ.
shared_dataset_kwargs = dict(
    raster_size=raster_size,
    n_objects=4,
    Renderer=PyTorch3DRenderer2D,
    renderer_device=device,
    n_jobs=1,
)

train_options = DatasetOptions(
    **shared_dataset_kwargs,
    n_samples=50_000,
    return_images=False,
    efd_options=EfdGeneratorOptions(choices=choices),
)
val_options = DatasetOptions(
    **shared_dataset_kwargs,
    n_samples=5_000,
    return_images=True,
    efd_options=EfdGeneratorOptions(choices=choices),
)
test_options = DatasetOptions(
    **shared_dataset_kwargs,
    n_samples=5_000,
    return_images=True,
    efd_options=EfdGeneratorOptions(choices=choices),
)

# Built in train / val / test order, as before.
train_dataset, val_dataset, test_dataset = (
    RandomDataset(split_options) for split_options in (train_options, val_options, test_options)
)
    

In [None]:
# Every loader uses the same batch size and drops the last incomplete batch; only training shuffles.
loader_kwargs = dict(batch_size=batch_size, drop_last=True)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    **loader_kwargs,
)
val_dataloader = DataLoader(
    val_dataset,
    shuffle=False,
    collate_fn=val_dataset.collate_fn,
    **loader_kwargs,
)
test_dataloader = DataLoader(
    test_dataset,
    shuffle=False,
    collate_fn=test_dataset.collate_fn,
    **loader_kwargs,
)

In [None]:
# Differentiable renderer used for targets, predictions and optimisation throughout the notebook.
renderer = PyTorch3DRenderer2D(
    raster_size,
    background_color=colors.black,
    return_alpha=True,
    n_verts=64,
    device=device,
)
encoder = DinoV2Encoder(input_size=(126, 126)).to(device)

# Modeler configuration, built piece by piece before the modeler itself.
efd_module_config = EfdModuleConfig(order=16, num_prototypes=3, mode=EfdModuleMode.PrototypeAttention)
modeler_config = ModelerConfig(
    encoder_feature_dim=768,
    use_single_scale=True,
    use_confidence=True,
    efd_module_config=efd_module_config,
)
detr_config = ConditionalDetrConfig(num_queries=8, dropout=0)
modeler = ConditionalDETRModeler(config=modeler_config, detr_config=detr_config).to(device)

# Encoder -> modeler -> renderer pipeline evaluated in the result cells.
model = ReconstructionModel(encoder, modeler, renderer)

In [None]:
def load_state_dict(f):
    """Copy the weights of a pickled checkpoint module into ``model.modeler``.

    The object stored in ``f`` must expose ``state_dict()``. Loading uses
    ``strict=False``, so keys that do not match the modeler are skipped
    without raising.

    Args:
        f: Path or file-like object accepted by ``torch.load``.
    """
    # map_location keeps checkpoints saved on a GPU loadable when `device`
    # fell back to the CPU.
    # NOTE(review): weights_only=False unpickles arbitrary objects - only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(f, map_location=device, weights_only=False)
    model.modeler.load_state_dict(checkpoint.state_dict(), strict=False)

In [None]:
def mask_from_scene(scene_data: SceneData):
    """Return the scene masks scaled by each object's confidence and rounded.

    The confidence tensor loses its last axis and gains a new axis after the
    first one so that it broadcasts against ``scene_data.mask``.
    """
    confidence = scene_data.scene.layer.object.appearance.confidence
    weights = confidence[:, None, ..., 0]
    weighted_mask = scene_data.mask * weights
    return weighted_mask.round()

def scene_to_index(scene_data: SceneData):
    """Turn per-object masks into an integer label map.

    Pixels claimed by exactly one object get that object's rank plus one,
    where objects are ranked by their exclusive pixel count (largest first).
    All other pixels get 0.
    """
    object_masks = mask_from_scene(scene_data)
    # Keep a pixel only where a single object covers it.
    exclusive = torch.logical_and(object_masks, object_masks.sum(-1, keepdim=True) == 1)
    # Exclusive pixel count per object, broadcast back to the mask shape for sorting.
    areas = exclusive.flatten(1, 2).sum(1)
    order = torch.argsort(areas[:, None, None].expand_as(exclusive), dim=-1, descending=True)
    ranked = torch.gather(exclusive, -1, order)
    hit, label = torch.max(ranked, dim=-1)
    # Shift so that "no object" becomes 0 and object ranks start at 1.
    label[hit == 0] = -1
    label += 1
    return label

# Ablations

In [None]:
# Objective passed to optimize_params below. The arguments are presumably (weight, loss)
# pairs, i.e. image MAE plus an EFD regulariser scaled by 0.01 - TODO confirm against CombinedLoss.
loss_func_image = CombinedLoss((0.01, EfdRegularizerLoss()), (1.0, ImageMaeLoss())).to(device)

# Fresh metric accumulators for this cell.
test_mae = ImageMaeLoss().to(device)
test_mse = ImageMseLoss().to(device)
test_loss_fn_scene = SceneLoss().to(device)
test_ssim = StructuralSimilarityIndexMeasure(data_range=(0.0, 1.0)).to(device)
test_iou = BinaryJaccardIndex().to(device)
test_ari = MeanMetric()

# Each validation batch is the starting scene; the test batch at the same position is the target.
for j, (random, target) in enumerate(epoch_bar := tqdm(zip(val_dataloader, test_dataloader), desc="Epoch")):
    pred = random.clone().to(device)
    target = target.clone().to(device)

    # Switch on the first `on` of the 4 objects in both scenes; `on` cycles 1..4 with the batch index.
    on = j % 4 + 1
    # Confidence pattern of shape (1, 1, 4, 1), broadcast over the confidence field.
    f = torch.tensor([1.0] * on + [0.0] * (4 - on), device=device)[None, None, :, None]
    pred.scene.layer.object.appearance.confidence[:] = f
    target.scene.layer.object.appearance.confidence[:] = f
    pred = renderer.render(pred.scene)
    target = renderer.render(target.scene)

    # Two-stage fit: field group "tsrcb" at lr=0.01, then "tsrcbe" at lr=0.001, 250 epochs each.
    # NOTE(review): optimize_params receives `lambda: pred` here, while other cells pass the
    # SceneData itself or a lambda that re-renders - confirm this call form is supported.
    optimize_params(lambda: pred, target, renderer=renderer, params=pred.scene.fields["tsrcb"], epochs=250, loss_func=loss_func_image, lr=0.01)
    optimize_params(lambda: pred, target, renderer=renderer, params=pred.scene.fields["tsrcbe"], epochs=250, loss_func=loss_func_image, lr=0.001)
    # Render again; presumably this picks up parameters that optimize_params updated in place.
    pred = renderer.render(pred.scene)

    # Show target (first) and prediction (second) for visual inspection.
    img_grid = make_img_grid([target.image_rgb_top, pred.image_rgb_top], padding=2, pad_value=1)
    display_img(img_grid)

    # Accumulate image, scene, SSIM, IoU and ARI metrics for this batch.
    test_mae(target, pred)
    test_mse(target, pred)
    test_loss_fn_scene(target, pred)
    # SSIM gets channels moved from the last axis to axis 1.
    test_ssim(target.image_rgb.permute(0, 3, 1, 2), pred.image_rgb.permute(0, 3, 1, 2))
    # IoU compares the union over objects (any along the last axis) of both mask sets.
    test_iou.update(mask_from_scene(target).any(dim=-1, keepdim=True), mask_from_scene(pred).any(dim=-1, keepdim=True))
    test_ari.update(ari(scene_to_index(target).cpu(), scene_to_index(pred).cpu(), num_ignored_objects=0).mean())

# Report the accumulated metrics.
test_mae.show()
test_mse.show()
test_loss_fn_scene.show()
print("SSIM:", test_ssim.compute().item())
print("IoU:", test_iou.compute().item())
print("ARI:", test_ari.compute().item())

In [None]:
# Load the "TI-TP" checkpoint into model.modeler; the next cell optimises the modeler's latent z.
load_state_dict("./TI-TP/model_X.pt")

In [None]:
# Plain image MAE is the objective for the latent optimisation below.
loss_func_image = ImageMaeLoss().to(device)

# Fresh metric accumulators for this cell.
test_mae = ImageMaeLoss().to(device)
test_mse = ImageMseLoss().to(device)
test_loss_fn_scene = SceneLoss().to(device)
test_ssim = StructuralSimilarityIndexMeasure(data_range=(0.0, 1.0)).to(device)
test_iou = BinaryJaccardIndex().to(device)
test_ari = MeanMetric()

# Each validation batch is the starting scene; the test batch at the same position is the target.
for j, (random, target) in enumerate(epoch_bar := tqdm(zip(val_dataloader, test_dataloader), desc="Epoch")):
    pred = random.clone().to(device)
    target = target.clone().to(device)

    # Switch on the first `on` of the 4 objects in both scenes; `on` cycles 1..4 with the batch index.
    on = j % 4 + 1
    f = torch.tensor([1.0] * on + [0.0] * (4 - on), device=device)[None, None, :, None]
    pred.scene.layer.object.appearance.confidence[:] = f
    target.scene.layer.object.appearance.confidence[:] = f
    pred = renderer.render(pred.scene)
    target = renderer.render(target.scene)

    # Encode the rendered starting scene and take the modeler's latent z. z is detached so that
    # it, not the network weights, is what optimize_params receives as `params`.
    # NOTE(review): modeler.eval() is not called in this cell or the one before it - confirm intended.
    latent = encoder(pred.image_rgb)
    z = modeler(latent, return_z=True)
    z = z.detach()
    pred = renderer.render(modeler(latent, custom_z=z))

    # Two stages over the same latent: lr=0.01 then lr=0.001, 250 epochs each.
    optimize_params(lambda: renderer.render(modeler(latent, custom_z=z)), target, renderer=renderer, params=[z], epochs=250, loss_func=loss_func_image, lr=0.01, show_progress=False)
    optimize_params(lambda: renderer.render(modeler(latent, custom_z=z)), target, renderer=renderer, params=[z], epochs=250, loss_func=loss_func_image, lr=0.001, show_progress=False)
    # Decode and render z after optimisation.
    pred = renderer.render(modeler(latent, custom_z=z))

    # Accumulate image, scene, SSIM, IoU and ARI metrics for this batch.
    test_mae(target, pred)
    test_mse(target, pred)
    test_loss_fn_scene(target, pred)
    test_ssim(target.image_rgb.permute(0, 3, 1, 2), pred.image_rgb.permute(0, 3, 1, 2))
    test_iou.update(mask_from_scene(target).any(dim=-1, keepdim=True), mask_from_scene(pred).any(dim=-1, keepdim=True))
    test_ari.update(ari(scene_to_index(target).cpu(), scene_to_index(pred).cpu(), num_ignored_objects=0).mean())

# Report the accumulated metrics.
test_mae.show()
test_mse.show()
test_loss_fn_scene.show()
print("SSIM:", test_ssim.compute().item())
print("IoU:", test_iou.compute().item())
print("ARI:", test_ari.compute().item())

# Main results

### DVP+GP, DVP+TI, DVP+TP, DVP+GI, DVP+TI+, DVP+TI-TP

In [None]:
# Metric accumulators; they are reset at the start of every checkpoint below.
test_mae = ImageMaeLoss().to(device)
test_mse = ImageMseLoss().to(device)
test_loss_fn_scene = SceneLoss().to(device)
test_ssim = StructuralSimilarityIndexMeasure(data_range=(0.0, 1.0)).to(device)
test_iou = BinaryJaccardIndex().to(device)
test_ari = MeanMetric()


# One checkpoint per model variant named in the heading above.
names = ["./GP/model_X.pt", "./TI/model_X.pt", "./TP/model_X.pt", "./GI/model_19.pt", "./TI+_2/model_X.pt", "./TI-TP/model_X.pt"]
for name in names:
    print("\n==========================")
    print(">", name)
    print("==========================")
    load_state_dict(name)
    modeler.eval()
    # Pure feed-forward evaluation: no gradients and no test-time optimisation.
    with torch.inference_mode():
        test_mae.reset()
        test_mse.reset()
        test_loss_fn_scene.reset()
        test_ssim.reset()
        test_iou.reset()
        test_ari.reset()
        for j, target in enumerate(epoch_bar := tqdm(test_dataloader, desc="Epoch")):
            target: SceneData = target.to(device)
            # Render the target with the notebook's renderer, predict a scene from its image,
            # then render the predicted scene with the same renderer.
            target = renderer.render(target.scene)
            pred: SceneData = renderer.render(model(target.image_rgb, render=False).scene)
            
            test_mae(target, pred)
            test_mse(target, pred)
            test_loss_fn_scene(target, pred)
            test_ssim(target.image_rgb.permute(0, 3, 1, 2), pred.image_rgb.permute(0, 3, 1, 2))
            test_iou.update(mask_from_scene(target).any(dim=-1, keepdim=True), mask_from_scene(pred).any(dim=-1, keepdim=True))
            test_ari.update(ari(scene_to_index(target).cpu(), scene_to_index(pred).cpu(), num_ignored_objects=0).mean())
            
            # Running values in the progress bar.
            epoch_bar.set_postfix({"mae": str(test_mae), "scene": str(test_loss_fn_scene)})
        # Final numbers for this checkpoint.
        test_mae.show()
        test_mse.show()
        test_loss_fn_scene.show()
        print("SSIM:", test_ssim.compute().item())
        print("IoU:", test_iou.compute().item())
        print("ARI:", test_ari.compute().item())

### DVP+TI-TP-OptP

In [None]:
# Metric accumulators; they are reset at the start of every checkpoint below.
test_mae = ImageMaeLoss().to(device)
test_mse = ImageMseLoss().to(device)
test_loss_fn_scene = SceneLoss().to(device)
test_ssim = StructuralSimilarityIndexMeasure(data_range=(0.0, 1.0)).to(device)
test_iou = BinaryJaccardIndex().to(device)
test_ari = MeanMetric()


names = ["./TI-TP/model_X.pt"]
for name in names:
    print("\n==========================")
    print(">", name)
    print("==========================")
    load_state_dict(name)
    modeler.eval()
    # nullcontext() does nothing; presumably kept so the cell mirrors the inference_mode
    # evaluation cell while leaving autograd enabled for optimize_params.
    with nullcontext():
        test_mae.reset()
        test_mse.reset()
        test_loss_fn_scene.reset()
        test_ssim.reset()
        test_iou.reset()
        test_ari.reset()
        for j, target in enumerate(epoch_bar := tqdm(test_dataloader, desc="Epoch")):
            target: SceneData = target.to(device)
            # Initial prediction from the network, computed without gradients.
            with torch.no_grad():
                target = renderer.render(target.scene)
                pred: SceneData = renderer.render(model(target.image_rgb, render=False).scene)
            
            # Refine the predicted scene parameters (field group "tsrcb") against the target
            # image for 100 epochs, starting from a clone of the prediction.
            loss_func_image = ImageMaeLoss().to(device)
            pred = pred.clone()
            optimize_params(pred, target, renderer=renderer, params=pred.scene.fields["tsrcb"], epochs=100, loss_func=loss_func_image, lr=0.01)
            pred = renderer.render(pred.scene)
            
            test_mae(target, pred)
            test_mse(target, pred)
            test_loss_fn_scene(target, pred)
            test_ssim(target.image_rgb.permute(0, 3, 1, 2), pred.image_rgb.permute(0, 3, 1, 2))
            test_iou.update(mask_from_scene(target).any(dim=-1, keepdim=True), mask_from_scene(pred).any(dim=-1, keepdim=True))
            test_ari.update(ari(scene_to_index(target).cpu(), scene_to_index(pred).cpu(), num_ignored_objects=0).mean())
            
            epoch_bar.set_postfix({"mae": str(test_mae), "scene": str(test_loss_fn_scene)})
        # Final numbers for this checkpoint.
        test_mae.show()
        test_mse.show()
        test_loss_fn_scene.show()
        print("SSIM:", test_ssim.compute().item())
        print("IoU:", test_iou.compute().item())
        print("ARI:", test_ari.compute().item())

### DVP+TI-TP-OptZ

In [None]:
# Metric accumulators; they are reset at the start of every checkpoint below.
test_mae = ImageMaeLoss().to(device)
test_mse = ImageMseLoss().to(device)
test_loss_fn_scene = SceneLoss().to(device)
test_ssim = StructuralSimilarityIndexMeasure(data_range=(0.0, 1.0)).to(device)
test_iou = BinaryJaccardIndex().to(device)
test_ari = MeanMetric()

# Image MAE is the objective for the latent optimisation.
loss_func_image = ImageMaeLoss().to(device)

# NOTE(review): this checkpoint path differs from "./TI-TP/model_X.pt" used by the other
# TI-TP cells - confirm it is the intended model for this experiment.
names = ["./TP_TI_2/model_98.pt"]
for name in names:
    print("\n==========================")
    print(">", name)
    print("==========================")
    load_state_dict(name)
    modeler.eval()
    # nullcontext() does nothing; autograd stays enabled for optimize_params below.
    with nullcontext():
        test_mae.reset()
        test_mse.reset()
        test_loss_fn_scene.reset()
        test_ssim.reset()
        test_iou.reset()
        test_ari.reset()
        for j, target in enumerate(epoch_bar := tqdm(test_dataloader, desc="Epoch")):
            target: SceneData = target.to(device)
            # Encode the target and take the modeler's latent z, all without gradients.
            with torch.no_grad():
                target = renderer.render(target.scene)
                latent = encoder(target.image_rgb)
                z = modeler(latent, return_z=True)
                pred = renderer.render(modeler(latent, custom_z=z))
            
            # NOTE(review): this clone has no effect on the result - `pred` is reassigned
            # two lines below before it is read.
            pred = pred.clone() 
            # Optimise z (passed as params) so that the decoded, rendered scene matches the target.
            optimize_params(lambda: renderer.render(modeler(latent, custom_z=z)), target, renderer=renderer, params=[z], epochs=100, loss_func=loss_func_image, lr=0.01, show_progress=False)
            pred = renderer.render(modeler(latent, custom_z=z))
            
            test_mae(target, pred)
            test_mse(target, pred)
            test_loss_fn_scene(target, pred)
            test_ssim(target.image_rgb.permute(0, 3, 1, 2), pred.image_rgb.permute(0, 3, 1, 2))
            test_iou.update(mask_from_scene(target).any(dim=-1, keepdim=True), mask_from_scene(pred).any(dim=-1, keepdim=True))
            test_ari.update(ari(scene_to_index(target).cpu(), scene_to_index(pred).cpu(), num_ignored_objects=0).mean())
            
            epoch_bar.set_postfix({"mae": str(test_mae), "scene": str(test_loss_fn_scene)})
        # Final numbers for this checkpoint.
        test_mae.show()
        test_mse.show()
        test_loss_fn_scene.show()
        print("SSIM:", test_ssim.compute().item())
        print("IoU:", test_iou.compute().item())
        print("ARI:", test_ari.compute().item())

### MONet

In [None]:
# Put the vendored object_centric_library checkout on the import path; presumably its modules
# import each other by top-level name - TODO confirm this is still required.
import sys
sys.path.append("./object_centric_library")

In [None]:
from object_centric_library.models.monet.model import Monet
from object_centric_library.utils.viz import make_recon_img


# MONet operates on the same 128x128 rasters as the rest of the notebook.
width, height = 128, 128
monet_config = OmegaConf.load("./object_centric_library/config/model/monet.yaml").model

# Overrides applied on top of the library's default MONet config.
monet_config.num_slots = 5
monet_config.decoder_params.w_broadcast = width + 8
monet_config.decoder_params.h_broadcast = height + 8
monet_config.num_blocks_unet = 6
# "_target_" is removed because the config is expanded straight into Monet(...) below.
del monet_config["_target_"]

model_monet = Monet(**monet_config, width=width, height=height).to(device)
# map_location keeps a checkpoint saved on a GPU loadable when `device` fell back to the CPU.
# NOTE(review): weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
# NOTE(review): model_monet.eval() is never called in this notebook - confirm that is intended.
state_dict = torch.load("./MONET/model_X.pt", map_location=device, weights_only=False).state_dict()
model_monet.load_state_dict(state_dict, strict=False)

In [None]:
# Metric accumulators for the MONet baseline.
test_mae = ImageMaeLoss().to(device)
test_mse = ImageMseLoss().to(device)
test_loss_fn_scene = SceneLoss().to(device)
test_ssim = StructuralSimilarityIndexMeasure(data_range=(0.0, 1.0)).to(device)
test_iou = BinaryJaccardIndex().to(device)
test_ari = MeanMetric()


names = ["MONet"]
for name in names:
    print("\n==========================")
    print(">", name)
    print("==========================")

    with torch.inference_mode():
        test_mae.reset()
        test_mse.reset()
        test_loss_fn_scene.reset()
        test_ssim.reset()
        test_iou.reset()
        test_ari.reset()
        for j, target in enumerate(epoch_bar := tqdm(test_dataloader, desc="Epoch")):
            target: SceneData = target.to(device)
            target = renderer.render(target.scene)
            
            # MONet takes channels on axis 1; its reconstruction is moved back to channels-last
            # and wrapped in an image-only SceneData so the image metrics accept it.
            pred = model_monet(target.image_rgb.permute(0, 3, 1, 2))
            pred_img = make_recon_img(pred["slot"], pred["mask"]).permute(0, 2, 3, 1)
            pred_mask = pred["mask"] 
            pred = SceneData(image=pred_img, batch_size=[batch_size])

            test_mae(target, pred)
            test_mse(target, pred)

            test_ssim(target.image_rgb.permute(0, 3, 1, 2), pred.image_rgb.permute(0, 3, 1, 2))
            
            # Reorder and flatten the slot masks so slots sit on the last axis, drop index 0
            # (presumably the background slot - TODO confirm) and binarise by rounding.
            mask = pred_mask.permute(0, 3, 4, 2, 1).flatten(-2, -1)[..., 1:].round()
            test_iou.update(mask_from_scene(target).any(dim=-1, keepdim=True), mask.any(dim=-1, keepdim=True))
            
            # Same mask -> label-map steps as scene_to_index, applied to the MONet masks.
            mask = torch.logical_and(mask, (mask.sum(-1, keepdim=True) == 1))
            indices = torch.argsort(mask.flatten(1, 2).sum(1)[:, None, None].expand_as(mask), dim=-1, descending=True)
            sorted_tensor = torch.gather(mask, -1, indices)
            mask_max_values, mask_max_indices = torch.max(sorted_tensor, dim=-1)
            mask_max_indices[mask_max_values == 0] = -1
            mask_max_indices += 1
            test_ari.update(ari(scene_to_index(target).cpu(), mask_max_indices.cpu(), num_ignored_objects=0).mean())
            
            epoch_bar.set_postfix({"mae": str(test_mae)})
            
        test_mae.show()
        test_mse.show()
        # NOTE(review): test_loss_fn_scene is never updated in this cell (the prediction has no
        # scene), so this shows an empty metric - confirm that is intended.
        test_loss_fn_scene.show()
        print("SSIM:", test_ssim.compute().item())
        print("IoU:", test_iou.compute().item())
        print("ARI:", test_ari.compute().item())

### Opt-Iter

In [None]:
def OptIter_object_init(pred, target):
    """Pick an initial position, colour and rotation for the next object.

    The absolute pred/target error map is thresholded and quantised, and for
    every sample the connected region with the largest summed error is chosen.

    Args:
        pred: Current rendered image batch.
        target: Target image batch, same shape as ``pred``.

    Returns:
        Tuple ``(translation, color, rotation_vec)``: the region pixel closest
        to the region centroid (coordinates in [0, 1]), the mean target colour
        inside the region, and ``(cos, sin)`` of the region orientation.
    """
    device = pred.device

    no_diff_thresh = 0.01
    quantile_interval = 100
    # Absolute error averaged over the last axis; errors below the threshold are zeroed.
    error_map = l1_loss(pred, target, reduction="none").mean(dim=-1, keepdim=True)
    error_map[error_map < no_diff_thresh] = 0
    # Integer error levels, so the map can be split into connected regions.
    bin_label = (error_map * quantile_interval).to(torch.uint8)

    def best_label(bin_labels):
        # measure.label groups connected pixels of equal value; zeros are background (label 0).
        regions = measure.label(bin_labels, background=0)
        # Summed error level per region; label 0 only ever collects zeros.
        region_weight = np.bincount(regions.flatten(), weights=bin_labels.flatten())
        mask = regions == np.argmax(region_weight)

        region_props = measure.regionprops(mask.squeeze(-1).astype(np.uint8))
        angle = region_props[0].orientation
        return mask, np.array([np.cos(angle), np.sin(angle)])

    # Run best_label once per sample of the batch.
    per_sample = np.vectorize(best_label, signature='(w,h,c)->(w,h,c),(n)')
    blobs_mask, rotation_vec = (torch.from_numpy(x).to(device) for x in per_sample(bin_label.cpu()))

    # 1 inside the chosen region and NaN outside (False / False), so the
    # nan-aware reductions below ignore everything outside the region.
    nan_mask = blobs_mask / blobs_mask

    xs = torch.linspace(0, 1, steps=raster_size[0], device=device)
    ys = torch.linspace(0, 1, steps=raster_size[1], device=device)
    # Coordinate grid; the roll swaps the two coordinate channels.
    grid = torch.cartesian_prod(ys, xs).view(1, *raster_size, 2)
    xy = torch.roll(grid, 1, dims=-1)
    xy_masked = (xy * nan_mask).flatten(1, 2)
    centroid = xy_masked.nanmean(dim=1, keepdim=True)
    # Squared distance to the centroid; pixels outside the region become +inf.
    sq_dist = torch.pow(xy_masked - centroid, 2).sum(dim=-1, keepdim=True)
    idx = sq_dist.nan_to_num(torch.inf).argmin(dim=1, keepdim=True)
    translation = torch.gather(xy_masked, 1, idx.expand((-1, -1, 2)))[:, 0]
    color = (target * nan_mask).nanmean(dim=(1, 2))
    return translation, color, rotation_vec

def OptIter(target, object_size=4):
    """Reconstruct ``target`` by adding and optimising one object at a time.

    Starts from a scene whose objects all have confidence 0. For each object in
    turn: switch it on, initialise it from the current error map via
    ``OptIter_object_init``, then optimise the scene against the target image.
    A final optimisation pass runs once all objects are placed.

    Args:
        target: Rendered target ``SceneData``.
        object_size: Number of objects to place.

    Returns:
        The rendered ``SceneData`` of the optimised scene.
    """
    device = target.image_rgb.device
    batch_size = target.shape[0]
    layer_size = 1
    efd_size = 16
    # Leading (batch, layer, object) dimensions shared by every per-object field.
    per_object = (batch_size, layer_size, object_size)

    init_translation = torch.zeros((*per_object, 2), dtype=torch.float32)
    init_scale = torch.ones((*per_object, 2), dtype=torch.float32)
    init_color = torch.zeros((*per_object, 3), dtype=torch.float32)
    init_confidence = torch.zeros((*per_object, 1), dtype=torch.float32)
    init_rotation = torch.ones((*per_object, 2), dtype=torch.float32)
    shape_ids = torch.ones((*per_object, 1), dtype=torch.uint8) * Shape2D.EFD_SHAPE.value
    # Every object starts as the same ellipse EFD.
    init_efd = ellipse_efd(efd_size, -0.75)[None, None, None].to(device).expand(*per_object, efd_size, 4)
    # Background colour: per-sample median over the flattened pixels.
    background = target.image_rgb.flatten(1, 2).median(dim=1)[0]

    initial_scene = Scene.from_tensors_sparse(
        translation=init_translation,
        scale=init_scale,
        color=init_color,
        confidence=init_confidence,
        rotation_vec=init_rotation,
        objectShape=shape_ids,
        efd=init_efd,
        backgroundColor=background,
    )
    pred = SceneData(scene=initial_scene, batch_size=[batch_size]).to(device)

    loss_func_image = CombinedLoss((0.01, EfdRegularizerLoss()), (1.0, ImageMaeLoss())).to(device)
    pred = renderer.render(pred.scene)

    for i in range(object_size):
        obj = pred.scene.layer.object
        obj.appearance.confidence[:, 0, i] = 1
        # Initialise object i from the error between the current render and the target.
        new_translation, new_color, new_rotation = OptIter_object_init(pred.image_rgb, target.image_rgb)
        obj.appearance.color[:, 0, i] = new_color
        obj.transformation.translation[:, 0, i] = new_translation
        obj.transformation.rotation_vec[:, 0, i] = new_rotation
        obj.transformation.scale[:, 0, i] = 0.1

        # Field group "tsrcb" first, then "tsrcbe" with a smaller learning rate.
        optimize_params(pred, target, renderer=renderer, params=pred.scene.fields["tsrcb"], epochs=100, loss_func=loss_func_image, lr=0.01)
        optimize_params(pred, target, renderer=renderer, params=pred.scene.fields["tsrcbe"], epochs=400, loss_func=loss_func_image, lr=0.0005)
        pred = renderer.render(pred.scene)

    # Final joint pass once every object has been placed.
    optimize_params(pred, target, renderer=renderer, params=pred.scene.fields["tsrcbe"], epochs=100, loss_func=loss_func_image, lr=0.0005)
    return renderer.render(pred.scene)

In [None]:
# Metric accumulators for the Opt-Iter baseline.
test_mae = ImageMaeLoss().to(device)
test_mse = ImageMseLoss().to(device)
test_loss_fn_scene = SceneLoss().to(device)
test_ssim = StructuralSimilarityIndexMeasure(data_range=(0.0, 1.0)).to(device)
test_iou = BinaryJaccardIndex().to(device)
test_ari = MeanMetric()

names = ["Opt-Iter"]
for name in names:
    print("\n==========================")
    print(">", name)
    print("==========================")
    test_mae.reset()
    test_mse.reset()
    test_loss_fn_scene.reset()
    test_ssim.reset()
    test_iou.reset()
    test_ari.reset()
    # No inference_mode here: OptIter runs gradient-based optimisation for every batch.
    for j, target in enumerate(epoch_bar := tqdm(test_dataloader, desc="Epoch")):
        target: SceneData = target.to(device)
        target = renderer.render(target.scene)
        # Reconstruct the batch purely by optimisation (default of 4 objects).
        pred: SceneData = OptIter(target)
        
        test_mae(target, pred)
        test_mse(target, pred)
        test_loss_fn_scene(target, pred)
        test_ssim(target.image_rgb.permute(0, 3, 1, 2), pred.image_rgb.permute(0, 3, 1, 2))
        test_iou.update(mask_from_scene(target).any(dim=-1, keepdim=True), mask_from_scene(pred).any(dim=-1, keepdim=True))
        test_ari.update(ari(scene_to_index(target).cpu(), scene_to_index(pred).cpu(), num_ignored_objects=0).mean())
        
        epoch_bar.set_postfix({"mae": str(test_mae), "scene": str(test_loss_fn_scene)})
        
    # Final numbers.
    test_mae.show()
    test_mse.show()
    test_loss_fn_scene.show()
    print("SSIM:", test_ssim.compute().item())
    print("IoU:", test_iou.compute().item())
    print("ARI:", test_ari.compute().item())

In [None]:
# Metric accumulators for the LIVE baseline, whose SVG outputs are read from disk.
test_mae = ImageMaeLoss().to(device)
test_mse = ImageMseLoss().to(device)
test_loss_fn_scene = SceneLoss().to(device)
test_ssim = StructuralSimilarityIndexMeasure(data_range=(0.0, 1.0)).to(device)
test_iou = BinaryJaccardIndex().to(device)
test_ari = MeanMetric()

names = ["LIVE"]
for name in names:
    print("\n==========================")
    print(">", name)
    print("==========================")
    test_mae.reset()
    test_mse.reset()
    test_loss_fn_scene.reset()
    test_ssim.reset()
    test_iou.reset()
    test_ari.reset()
    # Only the first 64 validation samples are evaluated; k selects the LIVE log directory.
    for k, target in zip(range(64), val_dataset):
        
        # First directory whose name ends in "_{k}".
        path = glob(f"./LIVE-Layerwise-Image-Vectorization/LIVE/val_dataset_log/*_{k}")[0]
        with open(f"{path}/output-svg/1-1-1-1-1.svg") as fp:
            soup = BeautifulSoup(fp, 'xml')
            
        # Rasterise the complete SVG on a white background at the notebook's raster size.
        pred = Image.open(BytesIO(cairosvg.svg2png(bytestring=soup.prettify(), background_color='white', output_width=raster_size[0], output_height=raster_size[1])))
        pred = pil_to_tensor(pred).permute(1, 2, 0)[None] / 255
        pred = SceneData(image=pred, batch_size=[1])
        
        test_mae(target, pred)
        test_mse(target, pred)
        test_ssim(target.image_rgb.permute(0, 3, 1, 2), pred.image_rgb.permute(0, 3, 1, 2))
        
        masks = []
        
        # One mask per SVG path index 0..4: keep only path i, fill it white and rasterise on black.
        for i in range(5):
            soup_copy = copy.copy(soup)
            # NOTE(review): this loop variable reuses the name `path` of the log directory
            # above; harmless here because the directory is not used again in this iteration.
            for j, path in enumerate(soup_copy.find_all("path")):
                if i != j:
                    path.decompose()
                else:
                    path["fill"] = "rgb(255, 255, 255)"

            svg_raster_bytes = cairosvg.svg2png(bytestring=soup_copy.prettify(), background_color='black', output_width=raster_size[0], output_height=raster_size[1]) 
            svg_raster = Image.open(BytesIO(svg_raster_bytes))
            
            # First channel only, scaled to [0, 1].
            mask = pil_to_tensor(svg_raster).permute(1, 2, 0)[None, ..., 0, None] / 255
            # A mask covering more than half of the 128x128 image is inverted.
            if (mask.sum() + 0.5).round() > 128*128 * 0.5:
                mask = 1 - mask
            masks.append(mask)
        
        # The mask with the smallest area is treated as the background and removed.
        background_idx = min(enumerate([mask.round().sum() for mask in masks]), key=lambda x: x[1])[0]
        
        del masks[background_idx]
        
        # Render the target with the notebook's renderer to obtain its object masks.
        target: SceneData = target.to(device)
        target = renderer.render(target.scene).cpu()
        
        test_iou.update(mask_from_scene(target).any(dim=-1, keepdim=True).cpu(), torch.cat(masks, dim=-1).round().any(dim=-1, keepdim=True).cpu())
        
        
        # Same mask -> label-map steps as scene_to_index, applied to the LIVE masks.
        def scene_to_index2():
            mask = torch.cat(masks, dim=-1).round()
            mask = torch.logical_and(mask, (mask.sum(-1, keepdim=True) == 1))
            indices = torch.argsort(mask.flatten(1, 2).sum(1)[:, None, None].expand_as(mask), dim=-1, descending=True)
            sorted_tensor = torch.gather(mask, -1, indices)
            mask_max_values, mask_max_indices = torch.max(sorted_tensor, dim=-1)
            mask_max_indices[mask_max_values == 0] = -1
            mask_max_indices += 1
            return mask_max_indices
        
        test_ari.update(ari(scene_to_index(target).cpu(), scene_to_index2().cpu(), num_ignored_objects=0).mean())
        
    test_mae.show()
    test_mse.show()
    # NOTE(review): test_loss_fn_scene is never updated in this cell, so this shows an empty
    # metric - confirm that is intended.
    test_loss_fn_scene.show()
    print("SSIM:", test_ssim.compute().item())
    print("IoU:", test_iou.compute().item())
    print("ARI:", test_ari.compute().item())

# Gradient

In [None]:
def iterate_pairs(iterable):
    """Yield consecutive, non-overlapping pairs ``(x0, x1), (x2, x3), ...``.

    A trailing element without a partner is consumed and dropped.
    """
    source = iter(iterable)
    # Zipping an iterator with itself pulls two items per step.
    yield from zip(source, source)

In [None]:
# Reference loss computed on scene parameters; its gradients are compared with image-loss gradients.
loss_func_scene = SceneLoss().to(device)

# INSTRUCTION
# Run this cell and the next 3 cells to collect results for MAE.
# Then uncomment the disabled MSE lines, comment out the corresponding MAE lines,
# and run the cells again to collect results for MSE.

loss_func_image = ImageMaeLoss().to(device)
# loss_func_image = ImageMseLoss().to(device)

# Interpolation weights: 0.05 followed by 0.1, 0.2, ..., 1.0.
ws = [0.05] + np.linspace(0.1, 1.0, 10).tolist()

# Maps each weight to {field letter: mean cosine similarity}.
global_metric_dict = dict()
for w in ws:
    metric_dict = defaultdict(MeanMetric)

    # Consecutive training batches are paired; only the first 16 pairs are used per weight.
    for i, (a, b) in enumerate(iterate_pairs(tqdm(train_dataloader))):
        if i == 16:
            break
        
        a = a.to(device)
        b = b.to(device)
        
        a = renderer.render(a.scene)

        # Scene interpolated between a and b with weight w; field "f" is rounded in place.
        ab: Scene = Scene.interpolate(a.scene, b.scene, weight=w).to(device)
        ab.fields["f"][:] = torch.round(ab.fields["f"][:])

        # One letter per scene field whose gradient is compared.
        names = "tsrceb"
        fields = ab.fields[names]
        for f in fields:
            f.requires_grad = True


        ab = renderer.render(ab)

        # Gradients of both losses with respect to the same fields.
        # The name grad_mse is kept even when the image loss is MAE.
        grad_scene = torch.autograd.grad(loss_func_scene(a, ab), fields)
        grad_mse = torch.autograd.grad(loss_func_image(a, ab), fields)

        for n, g1, g2 in zip(names, grad_scene, grad_mse):
            # "e": the last two axes are flattened into one vector per object.
            if n == "e":
                g1 = g1.flatten(-2, -1)
                g2 = g2.flatten(-2, -1)
            # "s": only the first component is compared.
            if n == "s":
                g1 = g1[..., 0, None]
                g2 = g2[..., 0, None]

            sim = torch.nn.functional.cosine_similarity(g1, g2, dim=-1)

            metric_dict[n].update(sim.mean().detach().cpu())

    # Replace each MeanMetric by its final scalar value.
    for n in metric_dict.keys():
        metric_dict[n] = metric_dict[n].compute().item()
    global_metric_dict[w] = metric_dict

In [None]:
# Rows are interpolation weights and columns are field letters after the transpose.
similarity_by_weight = pd.DataFrame(global_metric_dict).T
# Long format: one row per (weight, field) with columns "property" and "value".
similarity_long = pd.melt(similarity_by_weight, var_name="property", ignore_index=False)
# Move the weight from the index into a regular column named "t".
df = similarity_long.reset_index(names=["t"])

In [None]:
# Tag the frame built in the previous cell with the image loss it was computed with.
# For the MSE run, comment out the two MAE lines and uncomment the two MSE lines.
mae_df = df
mae_df["loss"] = "MAE"
# mse_df = df
# mse_df["loss"] = "MSE"

In [None]:
# Combine both runs and save them as tab-separated values.
# NOTE(review): mse_df only exists after the MSE pass described in the INSTRUCTION comment of
# the gradient cell; on a fresh kernel with only the MAE pass this line raises NameError.
df = pd.concat([mse_df, mae_df])
df.to_csv("gradient.csv", sep='\t', encoding='utf-8', index=False, header=True)

In [None]:
sns.set_theme(style="whitegrid")
sns.set_context("notebook")
pd.set_option("display.max_colwidth", 100)

# Figure typography, applied after the seaborn calls above.
matplotlib.rcParams.update(
    {
        "font.family": "Times New Roman",
        "axes.formatter.use_mathtext": True,
        "font.size": 36,
        "text.usetex": True,
    }
)

In [None]:
SMALL_SIZE = 15
MEDIUM_SIZE = 15
BIGGER_SIZE = 15

# (rc group, option, size) for each text element, applied in this order.
_font_size_settings = (
    ("font", "size", SMALL_SIZE),  # default text
    ("axes", "titlesize", SMALL_SIZE),  # axes title
    ("axes", "labelsize", MEDIUM_SIZE),  # x and y labels
    ("xtick", "labelsize", SMALL_SIZE),  # x tick labels
    ("ytick", "labelsize", SMALL_SIZE),  # y tick labels
    ("legend", "fontsize", SMALL_SIZE),  # legend
    ("figure", "titlesize", BIGGER_SIZE),  # figure title
)
for _group, _option, _size in _font_size_settings:
    plt.rc(_group, **{_option: _size})

In [None]:
# Reload the saved gradient-similarity results so plotting does not need the computation cells.
df = pd.read_csv("./gradient.csv", sep="\t")
df

In [None]:
import matplotlib.pyplot as plt
from matplotlib import font_manager

# Register the Times New Roman font file with matplotlib. The .ttf file must sit next to the
# notebook, because the path is relative.
font_path = "Times New Roman.ttf"
font_manager.fontManager.addfont(font_path)
# NOTE(review): `prop` is not referenced by any later cell in this notebook.
prop = font_manager.FontProperties(fname=font_path)

In [None]:
# gradient.csv is written with the raw column names "t", "property", "value" and "loss".
# Map them to the display names used in the figure. rename() skips keys that are absent,
# so a CSV that already carries display names (including the old misspelt one) still works.
gradient_plot_df = df.rename(
    columns={
        "t": "Noise strength",
        "value": "Cosine similarity",
        "Cosine simmilarity": "Cosine similarity",
        "property": "Property",
        "loss": "Loss",
    }
)

plt.gcf().set_size_inches(5, 3.5)
sns.lineplot(data=gradient_plot_df, x="Noise strength", y="Cosine similarity", hue="Property", style="Loss")

# The x axis shows the interpolation weight, labelled alpha in the figure.
plt.gca().set_xlabel('$\\alpha$', fontsize=16)
# Legend placed outside the axes, to the right.
plt.legend(ncol=1, loc="center right", bbox_to_anchor=(1.5, 0.46))
plt.savefig("gradient.pdf", bbox_inches="tight", pad_inches=0.01)

In [None]:
def iterate_pairs(iterable):
    """Yield consecutive, non-overlapping pairs from ``iterable``.

    This repeats the ``iterate_pairs`` definition from the gradient section,
    with the same behaviour: an unpaired trailing element is consumed and dropped.
    """
    missing = object()
    stream = iter(iterable)
    for first in stream:
        second = next(stream, missing)
        if second is missing:
            # Odd number of items: drop the last one.
            return
        yield first, second

In [None]:
# NOTE(review): loss_func_scene is created here but not used in this cell.
loss_func_scene = SceneLoss().to(device)
loss_func_image = ImageMaeLoss().to(device)

# Interpolation weights: 0.05 followed by 0.1, 0.2, ..., 1.0.
ws = [0.05] + np.linspace(0.1, 1.0, 10).tolist()

# One record per weight: {"t", "unoptimized", "optimized"}.
global_list = list()

for w in ws:
    unoptimized = MeanMetric().to(device)
    optimized = MeanMetric().to(device)
    
    # Consecutive training batches are paired; only the first 16 pairs are used per weight.
    for i, (a, b) in enumerate(iterate_pairs(tqdm(train_dataloader))):
        if i == 16:
            break
        
        a: SceneData = renderer.render(a.to(device).scene)
        b: SceneData = b.to(device)

        # Scene interpolated between a and b with weight w; field "f" is rounded in place.
        ab: Scene = Scene.interpolate(a.scene, b.scene, weight=w).to(device)
        ab.fields["f"][:] = torch.round(ab.fields["f"][:])
        # Render taken before any optimisation; used for the "unoptimized" error below.
        ab: SceneData = renderer.render(ab)

        # Three stages: "tsrcb" fields, then "e" alone, then all of "tsrcbe" together.
        optimize_params(ab, a, renderer=renderer, params=ab.scene.fields["tsrcb"], epochs=25, loss_func=loss_func_image, lr=0.01)
        optimize_params(ab, a, renderer=renderer, params=[ab.scene.fields["e"]], epochs=25, loss_func=loss_func_image, lr=0.001)
        optimize_params(ab, a, renderer=renderer, params=ab.scene.fields["tsrcbe"], epochs=50, loss_func=loss_func_image, lr=0.001)
        ab_optimized = renderer.render(ab.scene)

        # NOTE(review): these are mean *squared* errors, although the plotting cells label
        # the values "MAE" - confirm which one is intended.
        mse1 = mse_loss(a.image_rgb, ab.image_rgb).reshape(1)
        mse2 = mse_loss(a.image_rgb, ab_optimized.image_rgb).reshape(1)
        
        unoptimized.update(mse1)
        optimized.update(mse2)
        
    global_list.append(dict(t=w, unoptimized=unoptimized.compute().item(), optimized=optimized.compute().item()))

In [None]:
# One row per interpolation weight, saved as tab-separated values.
df = pd.DataFrame(data=global_list)
df.to_csv(
    "optimized.csv",
    sep="\t",
    encoding="utf-8",
    header=True,
    index=False,
)

In [None]:
df = pd.read_csv("./optimized.csv", sep="\t")
# optimized.csv stores the interpolation weight under "t". Rename it so the index key below
# exists; rename() skips absent keys, so a CSV with the display name already set still works.
df = df.rename(columns={"t": "Noise strength"})
df = df.set_index("Noise strength")
# Long format: " " holds the series name (unoptimized / optimized), "MAE" holds the value.
# NOTE(review): the stored values come from mse_loss in the sweep cell - confirm the "MAE" label.
df = pd.melt(df, var_name=" ", value_name="MAE", ignore_index=False)
df = df.reset_index()
df

In [None]:
fig = plt.gcf()
fig.set_size_inches(5, 3.5)
sns.lineplot(df, x="Noise strength", y="MAE", hue=" ")

# Replace the raw series names in the legend with readable labels.
legend_texts = plt.legend().get_texts()
before_text, after_text = legend_texts[0], legend_texts[1]
before_text.set_text('before optimization')
after_text.set_text('after optimization')

# The x axis shows the interpolation weight, labelled alpha in the figure.
ax = plt.gca()
ax.set_xlabel('$\\alpha$', fontsize=16)
plt.savefig("optimized.pdf", bbox_inches="tight")

# Overlapping objects

In [None]:
# Second renderer for the overlap figure: contours only (contours_only=True), white background.
opengl_renderer = OpenGLRenderer2D(raster_size, contours_only=True, background_color=colors.white)

In [None]:
load_state_dict("./TI-TP/model_X.pt")

# One image grid per horizontal offset d between the two objects.
img_grids = []
for d in [0.15, 0.1, 0.05]:
    # Samples 4 and 5 of the first validation batch serve as the template scene.
    target = next(iter(val_dataloader))[4:6].clone().to(device)

    # Hand-build the scene: every object is first centred with scale 0.1 and a fixed rotation,
    # then the first two objects are switched on via field "f", given the same colour and the
    # same EFD, and shifted by +d / -d along the first translation coordinate.
    target.scene.fields["t"][:] = 0.5
    target.scene.fields["s"][:] = 0.1
    target.scene.fields["r"][..., 0] = -1
    target.scene.fields["r"][..., 1] = 0
    target.scene.fields["r"][..., 1, 0] = 1
    target.scene.fields["f"][..., :2, :] = 1
    target.scene.fields["c"][..., 0, :] = target.scene.fields["c"][..., 1, :]
    target.scene.fields["e"][..., 1, :, :] = target.scene.fields["e"][..., 0, :, :]
    target.scene.fields["t"][..., 0, 0] = 0.5 + d
    target.scene.fields["t"][..., 1, 0] = 0.5 - d

    target = renderer.render(target.scene)

    # Feed-forward prediction; field "f" is rounded, the objects at indices 2, 3 and 4 are
    # recoloured red / blue / red, and the scene is drawn as contours.
    with torch.no_grad():
        pred_dvp: SceneData = model(target.image_rgb, render=False)
        pred_dvp.scene.fields["f"][:] = pred_dvp.scene.fields["f"].round()
        pred_dvp.scene.fields["c"][..., 2, :] = colors.red[:3]
        pred_dvp.scene.fields["c"][..., 3, :] = colors.blue[:3]
        pred_dvp.scene.fields["c"][..., 4, :] = colors.red[:3]
        pred_dvp = opengl_renderer.render(pred_dvp.scene)

    # NOTE(review): despite the name, pred_live holds the OptIter result, not LIVE output.
    pred_live = OptIter(target, object_size=2)
    pred_live.scene.fields["c"][..., 0, :] = colors.red[:3]
    pred_live.scene.fields["c"][..., 1, :] = colors.blue[:3]
    pred_live = opengl_renderer.render(pred_live.scene)

    # MONet masks with index 0 on the slot axis dropped (presumably the background slot - TODO
    # confirm). Canny edges of the first two remaining masks are drawn in red (A) and blue (B).
    mask = model_monet(target.image_rgb.permute(0, 3, 1, 2))["mask"].detach().permute(0, 3, 4, 2, 1).flatten(-2, -1)[..., 1:].round().cpu()
    A = canny(mask[..., 0, None].permute(0, 3, 1, 2), kernel_size=(9, 9))[1].permute(0, 2, 3, 1).to(torch.int32)
    B = canny(mask[..., 1, None].permute(0, 3, 1, 2), kernel_size=(9, 9))[1].permute(0, 2, 3, 1).to(torch.int32)
    # The integer edge maps index a two-row colour table: 0 -> black, 1 -> red (A) or blue (B).
    A = torch.tensor([[0, 0, 0], [1, 0, 0]], dtype=torch.float32)[A][..., 0, :]
    B = torch.tensor([[0, 0, 0], [0, 0, 1]], dtype=torch.float32)[B][..., 0, :]
    AB = A + B
    # Pixels with no edge (all-zero colour) become white, then the image is lightly blurred.
    AB[(AB == torch.zeros(3)[None, None, None]).all(dim=-1, keepdim=True).repeat(1, 1, 1, 3)] = 1
    AB = gaussian_blur2d(AB.permute(0, 3, 1, 2), (3, 3), (0.5, 0.5)).permute(0, 2, 3, 1)

    # Ground-truth contours, with the two objects in red and blue.
    target.scene.fields["c"][..., 0, :] = colors.red[:3]
    target.scene.fields["c"][..., 1, :] = colors.blue[:3]
    contours = opengl_renderer.render(target.scene)
    # Image order: target, target contours, network prediction, OptIter, MONet edges.
    img_grid = make_img_grid([target.image_rgb_top.cpu(), contours.image_rgb_top, pred_dvp.image_rgb, pred_live.image_rgb, AB], padding=2, pad_value=1, nrow=10)
    img_grids.append(img_grid)
    display_img(img_grid)

In [None]:
# Concatenate the per-offset grids along axis 0 after dropping the first two entries of each
# grid along that axis (presumably the top padding rows - TODO confirm), then show them together.
display_img(torch.cat([x[2:] for x in img_grids], dim=0))