In [None]:
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

from lcmr.utils.presentation import display_img
from lcmr_ext.dataset.dataset_7seg import Dataset7Seg
from lcmr_ext.renderer.renderer2d import PyTorch3DRenderer2D
from lcmr_ext.utils import collate_fn

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Seed every RNG the notebook may draw from (random scene generation, DataLoader
# shuffling, weight initialisation) so "Restart & Run All" repeats the same run.
# NOTE: some GPU kernels can still be nondeterministic even with fixed seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

In [None]:
from lcmr.renderer.renderer2d import OpenGLRenderer2D
from lcmr_ext.dataset.dataset_options import DatasetOptions
from lcmr_ext.dataset.dataset_random import RandomDataset


# Resolution used by the model's renderer (built in the next cell).
# NOTE(review): raster_size is not passed to DatasetOptions, so target images use the
# dataset's default resolution — confirm it is also 128x128, otherwise the pixel loss
# in the training loop will not line up with the rendered predictions.
raster_size = (128, 128)

dataset_options = DatasetOptions(
    data_len=1024,
    scenes=True,
    Renderer=OpenGLRenderer2D,
)
dataset = RandomDataset(dataset_options)

dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,  # skip a trailing incomplete batch so every batch has the same size
    collate_fn=collate_fn,
)

In [None]:
from transformers.models.detr.modeling_detr import DetrConfig

from lcmr.reconstruction_model import ReconstructionModel
from lcmr_ext.encoder import ResNet50Encoder
from lcmr_ext.modeler import DETRModeler
from lcmr.utils.matcher import Matcher

# Renderer used by the model to turn predicted scenes back into images.
# The background colour is presumably RGBA (opaque black) — confirm against the renderer.
renderer = PyTorch3DRenderer2D(
    raster_size,
    background_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
    device=device,
    with_alpha=False,
)

# Image -> features (encoder) -> scene (DETR modeler) -> image (renderer).
# num_queries is the number of DETR object slots, i.e. the maximum number of predicted objects.
encoder = ResNet50Encoder().to(device)
modeler = DETRModeler(DetrConfig(num_queries=7)).to(device)
model = ReconstructionModel(encoder, modeler, renderer)

# Used in the training loop to pair predicted objects with target objects.
matcher = Matcher()

In [None]:
import numpy as np
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment

# Training hyper-parameters.
epochs = 201      # epochs 0..200, so the final epoch is also a multiple of show_step
lr = 1e-4
show_step = 20    # display sample predictions every `show_step` epochs

# NOTE(review): only the modeler's parameters are optimised; the encoder's weights are
# never updated. Unless ResNet50Encoder disables requires_grad itself, its parameters
# still accumulate .grad on every backward (zero_grad only clears the modeler's).
# Confirm whether the encoder is meant to be frozen.
optimizer = torch.optim.Adam(modeler.parameters(), lr=lr)

def t_s_a_c(scene):
    """Unpack the per-object parameters of a scene's layer.

    Args:
        scene: object exposing ``scene.layer.object`` with ``transformation``
            (``translation``, ``scale``, ``angle``) and ``appearance.color``.

    Returns:
        Tuple ``(translation, scale, angle, color)``, taken as-is from the scene.
    """
    layer_object = scene.layer.object
    transform = layer_object.transformation
    return (
        transform.translation,
        transform.scale,
        transform.angle,
        layer_object.appearance.color,
    )

# NOTE(review): batch_len and layer_len are not used in this cell; kept in case other cells rely on them.
batch_len = dataloader.batch_size
layer_len = 1

for epoch in (bar := tqdm(range(epochs))):
    # Last batch of the epoch, kept (detached) for the periodic visualisation below.
    last_target_img = last_pred_img = None

    for target_img, target_scene in dataloader:
        optimizer.zero_grad()

        target_img = target_img.to(device)
        target_scene = target_scene.to(device)

        pred_scene, pred_img = model(target_img)
        pred_img = pred_img[..., :3]  # keep only the first three (colour) channels
        target_t, target_s, target_a, target_c = t_s_a_c(target_scene)
        pred_t, pred_s, pred_a, pred_c = t_s_a_c(pred_scene)

        # Match predicted objects to target objects using their translations, then
        # reorder both sides with the matched indices (presumably so position i on
        # each side refers to the same matched pair).
        ind_a, ind_b = matcher.match(target_t, pred_t)
        target_t, target_s, target_a, target_c = matcher.gather(ind_a, (target_t, target_s, target_a, target_c))
        pred_t, pred_s, pred_a, pred_c = matcher.gather(ind_b, (pred_t, pred_s, pred_a, pred_c))

        # Translation loss (MSE).
        loss = (target_t - pred_t).pow(2).mean()
        #loss = loss + 0.05 * (target_s - pred_s).pow(2).mean()
        # A loss directly on the angle makes little sense, but the cosine of the angle difference does.
        #loss = loss + 0.1 * (1 - torch.cos((target_a - pred_a) * 4)).mean() # Accept any axis alignment

        #cos_axis_aligned = torch.cos((target_a - pred_a) * 2).detach()
        #loss = loss + 0.5 * ((1 - cos_axis_aligned) * (target_s - pred_s.flip(dims=[-1])).pow(2)).mean() # Good angle: minimise the wrong (swapped) scale
        #loss = loss + 0.5 * ((cos_axis_aligned + 1) * (target_s - pred_s).pow(2)).mean() # Wrong angle: minimise the correct scale
        #loss = loss + (target_s - pred_s).pow(2).mean()

        # Colour loss.
        loss = loss + 0.5 * (target_c - pred_c).pow(2).mean()
        # Scale loss on the mean of the scale components only, which ignores their order.
        loss = loss + 0.1 * (target_s.mean(dim=-1) - pred_s.mean(dim=-1)).pow(2).mean()

        # Pixel-space reconstruction loss.
        loss = loss + (target_img - pred_img).pow(2).mean()

        loss.backward()
        optimizer.step()

        with torch.no_grad():
            bar.set_description(f"loss: {loss.detach().cpu().item():.4f}")
            # TODO: some output may need clamping with .clamp_(0.0, 1.0)
            last_target_img, last_pred_img = target_img, pred_img.detach()

    # Show one example every `show_step` epochs. Displaying inside the batch loop
    # emitted two images per batch and flooded the output.
    if epoch % show_step == 0 and last_pred_img is not None:
        #print((1 - torch.cos((target_a - pred_a) * 2)).mean())
        #print((target_s - pred_s.flip(dims=[-1])).pow(2).mean())
        #print((target_s - pred_s).pow(2).mean())
        #print(((1 - cos_axis_aligned) * (target_s - pred_s.flip(-1)).pow(2)).mean())
        #print(((cos_axis_aligned + 1) * (target_s - pred_s).pow(2)).mean())

        display_img(last_pred_img[0])
        display_img(last_target_img[0])

In [None]:
from lcmr_ext.dataset.dataset_7seg import random_image_7seg, font_file
from PIL import ImageFont

# Evaluate on images made by random_image_7seg (not the RandomDataset used for training).
font_size = min(*raster_size) // 2
font = ImageFont.truetype(font_file, font_size)

# [None, ...] adds a leading batch dimension so torch.cat stacks 16 images into one batch.
test_batch = torch.cat([random_image_7seg((0, 0, 0), 1, raster_size, font)[None, ...] for _ in range(16)], dim=0).to(device)
#test_batch = torch.cat([dataset.data[i][0] for i in range(16)], dim=0).to(device)

# Run inference in eval mode: without it, any dropout / BatchNorm layers behave as in
# training. Every submodule's train/eval flag is saved first and restored afterwards,
# so re-running the training cell later is unaffected.
# Assumes encoder and modeler are torch.nn.Module instances.
saved_modes = [(module, module.training) for net in (encoder, modeler) for module in net.modules()]
encoder.eval()
modeler.eval()
try:
    with torch.inference_mode():
        pred = model(test_batch)[1]
finally:
    for module, was_training in saved_modes:
        module.training = was_training

# One row per test image: target and prediction (first three channels) side by side,
# assuming channel-last (H, W, C) images.
display_img(torch.vstack([torch.hstack((t, p[..., :3])) for p, t in zip(pred, test_batch)]))