In [None]:
import numpy as np
import torch
from kornia.geometry.transform import rotate
from lcmr_ext.encoder.dino_v2_encoder import DinoV2Encoder
from lcmr_ext.loss import CombinedLoss, EfdRegularizerLoss, ImageMaeLoss
from lcmr_ext.modeler import ConditionalDETRModeler, EfdModuleConfig, EfdModuleMode, ModelerConfig
from lcmr_ext.modeler.efd_module import plot_prototypes
from lcmr_ext.renderer.renderer2d import PyTorch3DRenderer2D
from lcmr_ext.utils import optimize_params
from lcmr_ext.utils.sample_efd import ellipse_efd, heart_efd, square_efd
from skimage.draw import polygon2mask
from sklearn.metrics import silhouette_score
from sklearn_extra.cluster import KMedoids
from torch.utils.data import DataLoader
from transformers.models.conditional_detr import ConditionalDetrConfig

from lcmr.dataset import DatasetOptions, EfdGeneratorOptions, RandomDataset
from lcmr.grammar.scene_data import SceneData
from lcmr.reconstruction_model import ReconstructionModel
from lcmr.utils.colors import colors
from lcmr.utils.elliptic_fourier_descriptors import reconstruct_contour
from lcmr.utils.presentation import display_img, make_img_grid

# Relax float32 matmul precision from the default in exchange for speed
# (see the PyTorch docs for what "high" permits on a given backend).
torch.set_float32_matmul_precision("high")

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Rendering resolution passed to the datasets and the renderer.
raster_size = (128, 128)
# Batch size used by both data loaders.
batch_size = 128
# NOTE(review): `epochs` and `show_step` are not referenced by the visible cells.
epochs = 100
show_step = 1

In [None]:
# Fix the torch RNG before anything random is built.
torch.manual_seed(1234)

# Shape vocabulary handed to the EFD generator: heart, square and ellipse
# descriptors stacked along a new leading axis.
choices = torch.stack([heart_efd(), square_efd(), ellipse_efd()], dim=0)


def _make_dataset_options(n_samples: int) -> DatasetOptions:
    """Build the dataset options shared by all splits; only the sample count differs."""
    return DatasetOptions(
        raster_size=raster_size,
        n_samples=n_samples,
        n_objects=4,
        Renderer=PyTorch3DRenderer2D,
        renderer_device=device,
        n_jobs=1,
        return_images=True,
        efd_options=EfdGeneratorOptions(choices=choices),
    )


train_options = _make_dataset_options(50_000)
val_options = _make_dataset_options(5_000)
test_options = _make_dataset_options(5_000)

# Datasets are constructed in train -> val -> test order.
train_dataset, val_dataset, test_dataset = (
    RandomDataset(split_options) for split_options in (train_options, val_options, test_options)
)
    

In [3]:
# Training batches: shuffled, and the trailing incomplete batch is dropped.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=train_dataset.collate_fn,
)
# Validation batches: fixed order, every sample kept.
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=val_dataset.collate_fn,
)

In [None]:
# Differentiable 2D renderer: black background, alpha channel returned.
renderer = PyTorch3DRenderer2D(raster_size, background_color=colors.black, return_alpha=True, n_verts=64, device=device)

# Image encoder, moved to the selected device.
encoder = DinoV2Encoder(input_size=(126, 126)).to(device)

# Modeler configuration, spelled out step by step.
efd_module_config = EfdModuleConfig(order=16, num_prototypes=3, mode=EfdModuleMode.PrototypeAttention)
modeler_config = ModelerConfig(
    encoder_feature_dim=768,
    use_single_scale=True,
    use_confidence=True,
    efd_module_config=efd_module_config,
)
detr_config = ConditionalDetrConfig(num_queries=8, dropout=0)
modeler = ConditionalDETRModeler(config=modeler_config, detr_config=detr_config).to(device)

# Full pipeline: encoder -> modeler -> renderer.
model = ReconstructionModel(encoder, modeler, renderer)

# Only the modeler's parameters are handed to the optimizer.
lr = 1e-4
optimizer = torch.optim.Adam(params=list(modeler.parameters()), lr=lr)

In [None]:
# Load a pickled module from disk and copy its weights into the modeler.
# `map_location=device` remaps the stored tensors onto the active device, so a
# checkpoint written on a CUDA machine still loads when running on the CPU fallback.
# NOTE(security): `weights_only=False` unpickles arbitrary Python objects; only
# load checkpoint files that come from a trusted source.
checkpoint = torch.load("./GP/model_4.pt", map_location=device, weights_only=False)
state_dict = checkpoint.state_dict()
# strict=False tolerates missing/unexpected keys; the returned report is the
# cell's output, so any mismatched keys remain visible.
model.modeler.load_state_dict(state_dict, strict=False)

In [6]:
# Plain alias of `model`: no compilation step (e.g. torch.compile) is applied here.
compiled_model = model

In [8]:
# Use the first training batch as the fixed target for the optimisation below.
first_batch = next(iter(train_dataloader))
target = first_batch.to(device)

with torch.no_grad():
    # Predict the scene parameters only; rendering happens after the edits below.
    pred: SceneData = model(target.image_rgb, render=False)

    # Replace every predicted shape descriptor with the same ellipse descriptor.
    efd_order = modeler.to_efd.config.order
    pred.scene.layer.object.efd[:] = ellipse_efd(efd_order)[None, None, None, ...]

    # Round the confidences in place.
    pred.scene.layer.object.appearance.confidence[:] = torch.round(pred.scene.layer.object.appearance.confidence)

    # Render the edited scene and build a prediction/target comparison grid.
    pred = renderer.render(pred.scene)
    img_grid = make_img_grid((pred.image_rgb_top, target.image_rgb_top))

    # Continue with a cloned copy of the scene.
    pred.scene = pred.scene.clone()

In [9]:
def draw_contours(contours, resolution: int = 64):
    """Rasterise a batch of closed contours into boolean masks.

    Args:
        contours: tensor holding one polygon per entry along the first axis.
            Assumed to be (N, P, 2) with coordinates in roughly [-1, 1], which
            the affine map below sends to [0, resolution] -- TODO confirm
            against `reconstruct_contour`.
        resolution: side length of each square output mask, in pixels.

    Returns:
        Bool tensor of shape (N, resolution, resolution) on the same device as
        `contours`. An empty batch yields an empty (0, resolution, resolution)
        mask tensor instead of a shapeless float array.
    """
    device = contours.device

    # Guard the empty batch explicitly: np.array([]) would produce a float64
    # array of shape (0,), losing both the mask shape and the bool dtype.
    if contours.shape[0] == 0:
        return torch.zeros((0, resolution, resolution), dtype=torch.bool, device=device)

    # Map contour coordinates to pixel coordinates; polygon2mask runs on NumPy/CPU.
    pixel_contours = (contours * resolution / 2 + resolution / 2).detach().cpu().numpy()
    masks = np.stack([polygon2mask((resolution, resolution), contour) for contour in pixel_contours])
    return torch.from_numpy(masks).to(device)

In [None]:
# Objective for the per-batch parameter optimisation: image MAE plus a lightly
# weighted EFD regulariser.
loss_func = CombinedLoss((0.01, EfdRegularizerLoss()), (1.0, ImageMaeLoss())).to(device)

# Number of optimise -> cluster -> snap rounds.
n_rounds = 10
# During the first rounds one more cluster than the silhouette peak is kept.
over_cluster_rounds = 5
# Candidate cluster counts evaluated in every round.
k_range = range(2, 9)

for i in range(n_rounds):
    # Two optimisation passes over the scene fields: first the "tsrcb" group with
    # a larger step, then "tsrcbe" with a much smaller one (the extra "e" is
    # presumably the EFD coefficients -- TODO confirm against SceneData.fields).
    optimize_params(pred, target, renderer, loss_func=loss_func, params=pred.scene.fields["tsrcb"], show_progress=False, lr=0.05, epochs=100)
    optimize_params(pred, target, renderer, loss_func=loss_func, params=pred.scene.fields["tsrcbe"], show_progress=False, lr=0.0005, epochs=100)

    # Nothing below needs gradients; staying under no_grad also keeps the
    # extracted prototypes free of autograd history.
    with torch.no_grad():
        pred = renderer.render(pred.scene)
        img_grid = make_img_grid((pred.image_rgb_top, target.image_rgb_top))
        display_img(img_grid)

        # Keep only the shape descriptors of objects with confidence above 0.5.
        prototypes_mask = pred.scene.layer.object.appearance.confidence.flatten() > 0.5
        prototypes = pred.scene.layer.object.efd.flatten(0, 2)[prototypes_mask]

        # Rasterise each descriptor to a mask and build a pairwise L2 distance
        # matrix, taking the minimum over in-plane rotations (1 degree steps) of
        # the second mask.
        # NOTE(review): only the second operand is rotated, so the matrix may not
        # be exactly symmetric after interpolation -- confirm this is acceptable.
        contours = draw_contours(reconstruct_contour(prototypes), resolution=128).to(torch.float32)
        flat_contours = contours.flatten(-2, -1)
        precomputed = torch.cdist(flat_contours, flat_contours, p=2)
        for a in range(0, 360, 1):
            angle = torch.tensor([a], device=device, dtype=torch.float32)
            # rotate() works on (B, C, H, W); add and then drop a channel axis.
            rotated = rotate(contours[:, None], angle=angle)[:, 0]
            precomputed = torch.minimum(precomputed, torch.cdist(flat_contours, rotated.flatten(-2, -1), p=2))
        # Force the self-distances to zero before the precomputed-metric consumers use the matrix.
        precomputed.fill_diagonal_(0)

    # Move the distance matrix to the host once instead of twice per candidate k.
    precomputed_cpu = precomputed.cpu()

    silhouette_scores = []
    clustering_temp = []
    for k in k_range:
        print(f"Running KMedoids for k={k}...")
        clustering = KMedoids(metric="precomputed", init="k-medoids++", n_clusters=k, method="pam", max_iter=1200).fit(precomputed_cpu)
        score_s = silhouette_score(precomputed_cpu, clustering.labels_, metric="precomputed")
        silhouette_scores.append(score_s)
        clustering_temp.append(clustering)

    # Pick the silhouette peak; in the early rounds step one k higher, clamped
    # to the largest candidate.
    idx = np.argmax(silhouette_scores)
    if i < over_cluster_rounds:
        idx += 1
    idx = min(idx, len(k_range) - 1)
    print(f"Optimal k based on Silhouette Score (peak): {k_range[idx]}")

    clustering = clustering_temp[idx]

    indices = torch.from_numpy(clustering.medoid_indices_).to(device)
    labels = torch.from_numpy(clustering.labels_).to(device)

    plot_prototypes(prototypes[indices])

    with torch.no_grad():
        # Snap every confident object's descriptor to the medoid of its cluster.
        # The in-place write is done under no_grad so that it cannot trip
        # autograd's in-place checks if `efd` requires grad after optimisation.
        # NOTE(review): this relies on flatten(0, 2) returning a view of `efd`
        # (i.e. `efd` being contiguous) -- confirm.
        pred.scene.layer.object.efd.flatten(0, 2)[prototypes_mask] = prototypes[indices][labels]

        pred = renderer.render(pred.scene)
        img_grid = make_img_grid((pred.image_rgb_top, target.image_rgb_top))
        display_img(img_grid)

In [None]:
# Persist the medoid descriptors selected by the last clustering round
# (`prototypes` and `indices` come from the refinement loop's final iteration).
torch.save(obj=prototypes[indices], f="prototypes.pt")

In [None]:
# Reload the saved prototypes and plot them. `map_location=device` remaps the
# stored tensor onto the active device, so a file written from a CUDA session
# can still be opened on a CPU-only machine.
prototypes = torch.load("prototypes.pt", map_location=device)
plot_prototypes(prototypes)