In [None]:
import torch
from lcmr_ext.encoder.dino_v2_encoder import DinoV2Encoder
from lcmr_ext.loss import ImageMaeLoss, SceneLoss
from lcmr_ext.modeler import ConditionalDETRModeler, EfdModuleConfig, EfdModuleMode, ModelerConfig
from lcmr_ext.renderer.renderer2d import PyTorch3DRenderer2D
from lcmr_ext.utils.sample_efd import ellipse_efd, heart_efd, square_efd
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers.models.conditional_detr import ConditionalDetrConfig

from lcmr.dataset import  DatasetOptions, EfdGeneratorOptions, RandomDataset
from lcmr.grammar.scene_data import SceneData
from lcmr.reconstruction_model import ReconstructionModel
from lcmr.utils.colors import colors
from lcmr.utils.presentation import display_img, make_img_grid

# Allow TF32 / reduced-precision matmuls for float32 where the backend supports it.
torch.set_float32_matmul_precision("high")

# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

raster_size = (128, 128)  # (height, width) of rendered images
batch_size = 128
epochs = 100
show_step = 1  # display the visualization grid every `show_step` epochs

In [None]:
torch.manual_seed(1234)

# Candidate shapes (elliptic Fourier descriptors) the scene generator samples from,
# stacked along a new leading dimension.
choices = torch.stack([heart_efd(), square_efd(), ellipse_efd()], dim=0)


def make_dataset_options(n_samples: int, n_objects: int = 4) -> DatasetOptions:
    """Return options for a random-scene dataset split with ``n_samples`` scenes.

    Every split shares the raster size, renderer, renderer device and shape choices;
    only the number of samples (and optionally objects per scene) differs.

    Args:
        n_samples: number of scenes in the split.
        n_objects: number of objects per generated scene.
    """
    return DatasetOptions(
        raster_size=raster_size,
        n_samples=n_samples,
        n_objects=n_objects,
        Renderer=PyTorch3DRenderer2D,
        renderer_device=device,
        n_jobs=1,
        return_images=True,
        efd_options=EfdGeneratorOptions(choices=choices),
    )


train_options = make_dataset_options(50_000)
val_options = make_dataset_options(5_000)
test_options = make_dataset_options(5_000)

# Datasets are built in a fixed order after seeding, so each split is reproducible.
train_dataset = RandomDataset(train_options)
val_dataset = RandomDataset(val_options)
test_dataset = RandomDataset(test_options)

In [3]:
# Training batches are shuffled and the final incomplete batch is dropped, so every
# training batch holds exactly `batch_size` scenes.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=train_dataset.collate_fn,
)

# Validation keeps a fixed order and every sample, including a short last batch.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=val_dataset.collate_fn,
)

In [None]:
# Renderer used both to render predictions and to build synthetic training targets.
renderer = PyTorch3DRenderer2D(
    raster_size,
    background_color=colors.black,
    return_alpha=True,
    n_verts=64,
    device=device,
)

# NOTE(review): the encoder takes 126x126 inputs while raster_size is 128x128 -- presumably
# chosen as a multiple of the DINOv2 patch size; confirm the encoder resizes its input.
encoder = DinoV2Encoder(input_size=(126, 126)).to(device)

modeler_config = ModelerConfig(
    encoder_feature_dim=768,
    use_single_scale=True,
    use_confidence=True,
    efd_module_config=EfdModuleConfig(order=16, num_prototypes=3, mode=EfdModuleMode.PrototypeAttention),
)
detr_config = ConditionalDetrConfig(num_queries=8, dropout=0)
modeler = ConditionalDETRModeler(config=modeler_config, detr_config=detr_config).to(device)

model = ReconstructionModel(encoder, modeler, renderer)

# Only the modeler's parameters are handed to Adam; the encoder's are not optimized here.
lr = 0.0001
optimizer = torch.optim.Adam(modeler.parameters(), lr=lr)

In [None]:
# Load pre-computed shape prototypes into the EFD module, then freeze them.
# The copy must run under no_grad: an in-place write into a leaf tensor that requires
# grad (as a trainable parameter does) raises a RuntimeError otherwise.
# map_location loads straight onto `device`, so a file saved on GPU also loads on CPU-only machines.
with torch.no_grad():
    model.modeler.to_efd.prototypes[:] = torch.load("prototypes.pt", map_location=device)
modeler.to_efd.prototypes.requires_grad_(False)

In [5]:
# torch.compile is currently disabled; use `torch.compile(model)` here to enable it.
compiled_model = model

In [None]:
# Training: each real batch seeds a synthetic target scene built from the model's own
# prediction (random rotations + random prototype shapes), which is rendered and then
# reconstructed. Validation runs on the real validation dataset.
log_dir = "TP"
writer = SummaryWriter(log_dir)

# Objects per predicted scene; must equal ConditionalDetrConfig(num_queries=8) in the model cell.
n_objects = 8

# Only the scene loss is tracked during training: training predictions are never
# rendered, so there is no training image to compute an image loss against.
loss_fn_scene = SceneLoss().to(device)

val_loss_fn_image = ImageMaeLoss().to(device)
val_loss_fn_scene = SceneLoss().to(device)


def make_synthetic_target(pred: SceneData, n_batch: int, n_obj: int) -> SceneData:
    """Build and render a gradient-free synthetic target scene from a model prediction.

    Color, scale and translation are copied (detached) from ``pred`` and confidence is
    rounded. Rotations are replaced by random unit 2-vectors and shapes by randomly
    chosen EFD prototypes. The scene is then rendered so it carries an image to train on.

    Args:
        pred: unrendered model prediction (``render=False``) to derive the target from.
        n_batch: number of scenes in ``pred``.
        n_obj: number of objects per scene in ``pred``.

    Returns:
        The rendered synthetic target.
    """
    target = pred.clone().detach()
    obj = target.scene.layer.object

    obj.appearance.color = obj.appearance.color.detach()
    obj.appearance.confidence = obj.appearance.confidence.detach().round()
    obj.transformation.scale = obj.transformation.scale.detach()
    obj.transformation.translation = obj.transformation.translation.detach()

    # Random rotations as unit vectors, generated on the same device as the rest of the scene.
    rotation = torch.rand(n_batch * n_obj, 2, device=device) * 2 - 1
    obj.transformation.rotation_vec = torch.nn.functional.normalize(rotation, dim=-1).view(n_batch, 1, n_obj, 2)

    # Random shape per object, drawn from the frozen EFD prototypes.
    prototypes = model.modeler.to_efd.prototypes
    shape_idx = torch.randint(0, prototypes.shape[0], (n_batch, 1, n_obj), device=prototypes.device)
    obj.efd = prototypes[shape_idx].detach()

    with torch.no_grad():
        return renderer.render(target.scene)


try:
    for epoch in range(epochs):
        # ---- training ----
        loss_fn_scene.reset()
        modeler.train()

        train_bar = tqdm(train_dataloader, desc="Epoch")
        for batch in train_bar:
            batch: SceneData = batch.to(device)

            # Seed pass: its output is fully detached when building the target, so no graph is needed.
            with torch.no_grad():
                seed_pred: SceneData = compiled_model(batch.image_rgb, render=False)
            # The training loader uses drop_last=True, so every batch holds `batch_size` scenes.
            target = make_synthetic_target(seed_pred, batch_size, n_objects)

            optimizer.zero_grad(set_to_none=True)
            pred: SceneData = compiled_model(target.image_rgb, render=False)
            loss = loss_fn_scene(target, pred)
            loss.backward()
            optimizer.step()

            train_bar.set_postfix({"scene_loss": str(loss_fn_scene)})

        # ---- visualization + validation ----
        modeler.eval()
        with torch.inference_mode():
            # Reconstruct the last synthetic training target and show it next to its source.
            vis_pred: SceneData = renderer.render(compiled_model(target.image_rgb, render=False).scene)
            img_grid = make_img_grid((vis_pred.image_rgb_top, target.image_rgb_top))
            writer.add_image("visualization", img_grid.permute(2, 0, 1), global_step=epoch)

            if epoch % show_step == 0:
                display_img(img_grid)

            print(f"Epoch {epoch}")
            loss_fn_scene.show()

            val_loss_fn_image.reset()
            val_loss_fn_scene.reset()

            val_bar = tqdm(val_dataloader, desc="Validation")
            for batch in val_bar:
                batch: SceneData = batch.to(device)

                pred: SceneData = renderer.render(compiled_model(batch.image_rgb, render=False).scene)
                val_loss_fn_image(batch, pred)
                val_loss_fn_scene(batch, pred)

                val_bar.set_postfix({"image_loss": str(val_loss_fn_image), "scene_loss": str(val_loss_fn_scene)})

            val_loss_fn_image.show()
            val_loss_fn_scene.show()

        # No training image loss exists (see above), so only validation is logged for it.
        writer.add_scalars(
            "image_loss",
            {"validation": val_loss_fn_image.compute()},
            global_step=epoch,
        )
        writer.add_scalars(
            "scene_loss",
            {"training": loss_fn_scene.compute(), "validation": val_loss_fn_scene.compute()},
            global_step=epoch,
        )

        torch.save(modeler, f"{log_dir}/model_{epoch}.pt")
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()