In [None]:
import os

import torch
from lcmr_ext.renderer.renderer2d import PyTorch3DRenderer2D
from lcmr_ext.utils.sample_efd import ellipse_efd, heart_efd, square_efd
from torchmetrics.classification import BinaryJaccardIndex
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchvision.utils import save_image
from tqdm import tqdm

from lcmr.dataset import DatasetOptions, EfdGeneratorOptions, RandomDataset
from lcmr.grammar.scene_data import SceneData
from lcmr.utils.elliptic_fourier_descriptors import EfdGeneratorOptions

# Select PyTorch's "high" float32 matmul precision mode for the whole process
# (see the torch.set_float32_matmul_precision docs for the accuracy trade-off).
torch.set_float32_matmul_precision("high")

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Raster size handed to every DatasetOptions in this notebook.
raster_size = (128, 128)

In [None]:
# Fix PyTorch's global RNG seed before any dataset object is created.
torch.manual_seed(1234)

# Stack the three sample EFD tensors along a new leading axis, so index 0/1/2
# selects heart/square/ellipse respectively.
choices = torch.stack([heart_efd(), square_efd(), ellipse_efd()], dim=0)


def _split_options(n_samples: int, return_images: bool) -> DatasetOptions:
    """Build the DatasetOptions for one split.

    Every split shares the raster size, object count, renderer, renderer
    device, job count and EFD choices; only the arguments below differ.

    Args:
        n_samples: Value forwarded as ``DatasetOptions.n_samples``.
        return_images: Value forwarded as ``DatasetOptions.return_images``.

    Returns:
        A new ``DatasetOptions`` holding its own ``EfdGeneratorOptions``
        instance built from the shared ``choices`` tensor.
    """
    return DatasetOptions(
        raster_size=raster_size,
        n_samples=n_samples,
        n_objects=4,
        Renderer=PyTorch3DRenderer2D,
        renderer_device=device,
        n_jobs=1,
        return_images=return_images,
        efd_options=EfdGeneratorOptions(choices=choices),
    )


# NOTE(review): the train split is configured with return_images=False, yet the
# export cell reads scene_data.image_rgb for every split, train included --
# confirm RandomDataset still provides image_rgb in that mode.
train_options = _split_options(n_samples=50_000, return_images=False)
val_options = _split_options(n_samples=5_000, return_images=True)
test_options = _split_options(n_samples=5_000, return_images=True)

# Instantiate in train -> val -> test order (same order as the options above).
train_dataset = RandomDataset(train_options)
val_dataset = RandomDataset(val_options)
test_dataset = RandomDataset(test_options)
    

In [None]:
# Export every split to disk. For each sample index i this writes
#   ../MDS-HR/<split>/<i>.png  - the rendered image
#   ../MDS-HR/<split>/<i>.pt   - the state_dict of the sample's scene
#
# NOTE(review): train_options sets return_images=False, yet image_rgb is read
# for the train split here as well -- confirm RandomDataset still populates it.
scene_data: SceneData  # annotation only, for editor/type-checker support

for split_name, split_dataset in (
    ("train", train_dataset),
    ("val", val_dataset),
    ("test", test_dataset),
):
    split_dir = f"../MDS-HR/{split_name}"
    # Neither save_image nor torch.save creates missing parent directories,
    # so make sure the target folder exists before writing the first sample.
    os.makedirs(split_dir, exist_ok=True)

    for i, scene_data in enumerate(tqdm(split_dataset, desc=split_name)):
        # save_image expects channels-first data; permute(2, 0, 1) moves the
        # last axis to the front (presumably (H, W, C) -> (C, H, W); confirm).
        save_image(scene_data.image_rgb[0].permute(2, 0, 1), f"{split_dir}/{i}.png")
        torch.save(scene_data.scene.state_dict(), f"{split_dir}/{i}.pt")