In [None]:
import torch
from lcmr_ext.renderer.renderer2d import PyTorch3DRenderer2D
from lcmr_ext.utils.sample_efd import ellipse_efd, heart_efd, square_efd
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from lcmr.dataset import  DatasetOptions, EfdGeneratorOptions, RandomDataset
from lcmr.grammar.scene_data import SceneData
from lcmr.reconstruction_model import ReconstructionModel
from lcmr.utils.colors import colors
from lcmr.utils.presentation import display_img, make_img_grid

from lcmr.grammar import Scene
from torch.nn.functional import l1_loss

from torchmetrics.aggregation import MeanMetric
from omegaconf import OmegaConf

# Allow lower-precision (faster) float32 matmul kernels where the backend supports them.
torch.set_float32_matmul_precision("high")

# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Global experiment settings shared by the cells below.
raster_size = (128, 128)  # size of the rendered images passed to DatasetOptions
batch_size = 64  # samples per DataLoader batch
epochs = 100  # number of passes over the training set
show_step = 1  # display the reconstruction grid every `show_step` epochs

In [None]:
# Fix the global torch RNG before any dataset-related object is created.
torch.manual_seed(1234)

# Stack the three reference EFD shapes (heart, square, ellipse) along a new
# leading axis; this is the pool handed to EfdGeneratorOptions as `choices`.
choices = torch.stack([heart_efd(), square_efd(), ellipse_efd()], dim=0)


def _make_dataset_options(n_samples: int) -> DatasetOptions:
    """Build the DatasetOptions shared by every split; only `n_samples` differs.

    A fresh EfdGeneratorOptions is created on every call so that the splits do
    not share an options instance.
    """
    return DatasetOptions(
        raster_size=raster_size,
        n_samples=n_samples,
        n_objects=4,
        Renderer=PyTorch3DRenderer2D,
        renderer_device=device,
        n_jobs=1,
        return_images=True,
        efd_options=EfdGeneratorOptions(choices=choices),
    )


# Options are built first, in train / val / test order, then the datasets.
train_options = _make_dataset_options(50_000)
val_options = _make_dataset_options(5_000)
test_options = _make_dataset_options(5_000)

train_dataset = RandomDataset(train_options)
val_dataset = RandomDataset(val_options)
test_dataset = RandomDataset(test_options)
    

In [3]:
# Training batches: reshuffled every epoch, incomplete final batch dropped so
# every training step sees exactly `batch_size` samples.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=train_dataset.collate_fn,
)

# Validation batches: fixed order, every sample kept (the last batch may be smaller).
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=val_dataset.collate_fn,
)

In [4]:
import sys

# Make the vendored `object_centric_library` checkout importable.
# NOTE(review): the path is relative, so it resolves against the kernel's
# current working directory — confirm the notebook is run from the repo root.
_object_centric_library_dir = "./object_centric_library"

# Guard the append so that re-running this cell does not keep stacking the
# same entry onto sys.path.
if _object_centric_library_dir not in sys.path:
    sys.path.append(_object_centric_library_dir)

In [None]:
# These imports live here rather than in the first cell because they are
# placed after the cell that appends "./object_centric_library" to sys.path.
from object_centric_library.models.monet.model import Monet
from object_centric_library.utils.viz import make_recon_img

# Image resolution the model is built for.
width = 128
height = 128

# Start from the library's MONet configuration and override selected fields.
monet_config = OmegaConf.load("./object_centric_library/config/model/monet.yaml").model

monet_config.num_slots = 5
monet_config.num_blocks_unet = 6

# NOTE(review): the +8 margin on the broadcast grid presumably compensates for
# spatial shrinkage inside the decoder — confirm against the Monet decoder.
monet_config.decoder_params.w_broadcast = width + 8
monet_config.decoder_params.h_broadcast = height + 8

# Drop `_target_` so it is not forwarded to Monet(...) as a keyword argument
# (presumably a Hydra instantiation key — verify against the YAML).
del monet_config["_target_"]

model = Monet(**monet_config, width=width, height=height)
model = model.to(device)

In [6]:
# Adam over every model parameter with a fixed learning rate.
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# TensorBoard event files and per-epoch checkpoints both go under `log_dir`.
log_dir = "MONet"
writer = SummaryWriter(log_dir)

# Per-epoch running means:
#   train_loss - the model's own objective (pred["loss"]), the value being optimised
#   train_mae  - L1 error between the reconstruction and the input image (training)
#   val_mae    - the same L1 error, measured on the validation set
train_loss = MeanMetric()
train_mae = MeanMetric()
val_mae = MeanMetric()

try:
    for epoch in range(epochs):
        # ---- training ------------------------------------------------------
        train_loss.reset()
        train_mae.reset()

        model.train()
        for target in tqdm(train_dataloader):
            optimizer.zero_grad(set_to_none=True)

            target = target.to(device)

            # The model is fed channels-first input, so the channel axis of
            # `image_rgb` is moved from last to second position.
            pred = model(target.image_rgb.permute(0, 3, 1, 2))
            # Reconstruction is permuted back to channels-last for comparison
            # with `target.image_rgb`.
            pred_img = make_recon_img(pred["slot"], pred["mask"]).permute(0, 2, 3, 1)
            # Not used in this cell; kept because it is a notebook-level name
            # that other cells may read.
            pred_alpha = pred["mask"].permute(0, 1, 3, 4, 2)

            # Monitoring metric only - must not contribute to the gradient.
            with torch.no_grad():
                mae = l1_loss(pred_img, target.image_rgb)
                train_mae.update(mae.item())

            loss = pred["loss"]

            loss.backward()
            optimizer.step()

            train_loss.update(loss.item())

        # ---- visualisation + validation -------------------------------------
        val_mae.reset()
        model.eval()

        with torch.no_grad():
            # Grid built from the last training batch of this epoch: first 8
            # reconstructions next to `target.image_rgb_top`.
            # NOTE(review): assumes `image_rgb_top` yields a matching number of
            # target images - confirm against the Scene definition.
            img_grid = make_img_grid((pred_img[:8], target.image_rgb_top))
            writer.add_image("visualization", img_grid.permute(2, 0, 1), global_step=epoch)

            if epoch % show_step == 0:
                display_img(img_grid)

            for target_scene in val_dataloader:
                target_scene: Scene = target_scene.to(device)

                # Evaluate on the validation batch itself. Separate names are
                # used so `pred_img` / `target` keep referring to the same
                # (training) batch after the loop.
                val_pred = model(target_scene.image_rgb.permute(0, 3, 1, 2))
                val_pred_img = make_recon_img(val_pred["slot"], val_pred["mask"]).permute(0, 2, 3, 1)

                val_mae.update(l1_loss(val_pred_img, target_scene.image_rgb).item())

            # Both printed values are the image-space L1 error.
            print(f"train loss: {train_mae.compute():.4f}\tval loss: {val_mae.compute():.4f}")

        writer.add_scalars(
            "image_loss",
            {"training": train_mae.compute(), "validation": val_mae.compute()},
            global_step=epoch,
        )
        # The optimised objective is logged under its own tag so it is not
        # confused with the L1 image error above.
        writer.add_scalar("model_loss/training", train_loss.compute(), global_step=epoch)

        # Saves the whole pickled module (not just a state_dict): loading the
        # checkpoint requires the Monet class to be importable under the same path.
        torch.save(model, f"{log_dir}/model_{epoch}.pt")
finally:
    # Flush and close the event file even if training is interrupted or fails.
    writer.close()