In [None]:
import torch
from torch.utils.data import DataLoader
from lcmr.dataset import DatasetOptions, RandomDataset
from lcmr.utils.presentation import display_img, make_img_grid
from lcmr.utils.colors import colors

# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
raster_size = (128, 128)

# Settings for the randomly generated dataset. `renderer_device` is set to the
# device selected above, and only images are requested (`return_scenes=False`).
options = DatasetOptions(
    raster_size=raster_size,
    n_objects=1,
    n_samples=1024,
    return_scenes=False,
    background_color=colors.black,
    n_jobs=4,
    renderer_device=device,
)
dataset = RandomDataset(options)

# Shuffled batches of 128; a trailing incomplete batch is dropped.
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    collate_fn=RandomDataset.collate_fn,
)

In [None]:
from lcmr.reconstruction_model import ReconstructionModel
from lcmr.encoder import ResNet50Encoder
from lcmr.modeler import DummyModeler
from lcmr.renderer.renderer2d import PyTorch3DRenderer2D

# Image encoder, moved to the selected device.
encoder = ResNet50Encoder().to(device)

# 2D renderer built with the same raster size and black background that the
# dataset options use.
renderer = PyTorch3DRenderer2D(
    raster_size,
    background_color=colors.black,
    device=device,
    n_verts=32,
    faces_per_pixel=4,
)

# NOTE(review): encoder_feature_dim=2048 presumably matches the encoder's
# output width — confirm against ResNet50Encoder.
modeler = DummyModeler(
    encoder_feature_dim=2048,
    hidden_dim=128,
).to(device)

# Combine encoder, modeler and renderer into a single model.
model = ReconstructionModel(encoder, modeler, renderer)

In [None]:
import numpy as np
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files and the saved model are both written under `logdir`.
logdir = "./runs/experiment_1"
writer = SummaryWriter(logdir)

epochs = 201
show_step = 10  # an image grid is logged and displayed on every `show_step`-th epoch
lr = 0.00002

# Only the modeler's parameters are handed to the optimizer.
# NOTE(review): the encoder is neither optimized nor explicitly frozen in this
# notebook — confirm that this is intended.
optimizer = torch.optim.Adam(list(modeler.parameters()), lr=lr)

# Not used in this cell; kept in case another cell reads it.
batch_len = dataloader.batch_size

# The try/finally guarantees the writer is closed (flushing pending events) even
# when training is interrupted or raises part-way through.
try:
    for epoch in (bar := tqdm(range(epochs))):
        losses = []
        for j, target_img in enumerate(dataloader):
            optimizer.zero_grad()

            target_img = target_img.to(device)

            pred_scene, pred_img = model(target_img)
            # Keep the first three entries of the last axis so the prediction can be
            # compared with the target (presumably drops an alpha channel — TODO confirm).
            pred_img = pred_img[..., :3]

            # Mean squared error between target and prediction.
            loss = (target_img - pred_img).pow(2).mean()

            # `.item()` already yields a detached Python float.
            losses.append(loss.item())

            loss.backward()
            optimizer.step()

            # Running mean of the losses seen so far in this epoch.
            bar.set_description(f"loss: {np.mean(losses):.4f}")

            # Visualize the first batch of every `show_step`-th epoch.
            if epoch % show_step == 0 and j == 0:
                # No gradients are needed for building and logging the preview grid.
                with torch.no_grad():
                    img_grid = make_img_grid((pred_img[:8], target_img[:8]))
                    # `add_image` expects channels-first data by default, so the last
                    # axis is moved to the front (grid presumably is H x W x C — TODO confirm).
                    writer.add_image("visualization", img_grid.permute(2, 0, 1), epoch)
                    display_img(img_grid)

        # Refresh the dataset contents before the next epoch.
        dataset.regenerate()

        writer.add_scalar("training loss", np.mean(losses), epoch)

    # NOTE(review): this pickles the whole module object rather than its
    # state_dict, so loading it requires the same class definitions to be importable.
    torch.save(modeler, f"{logdir}/model.pt")
finally:
    writer.close()