# Import libraries

In [12]:
import os

import sys
import torch
import h5py
import argparse
import numpy as np
import imageio
import matplotlib.pyplot as plt

from renderformer import RenderFormerRenderingPipeline
from renderformer.models.renderformer import RenderFormer
from renderformer.models.config import RenderFormerConfig

from simple_ocio import ToneMapper
import blenderproc as bproc
import numpy as np
import os
import random
from math import radians
import json

# Generate synthetic dataset

In [13]:
# from generate_dataset import SceneGenerator
# import asyncio


# CONFIG = {
#     "DATA_PATH": "/home/devel/.draft/renderformer/datasets",
#     "JSON_PATH": "/home/devel/.draft/renderformer/datasets/json",
#     "H5_PATH": "/home/devel/.draft/renderformer/datasets/h5",
#     "GT_PATH": "/home/devel/.draft/renderformer/datasets/gt",
#     "TEMP_MESH_PATH": "/home/devel/.draft/renderformer/datasets/temp",
#     "OBJ_PATH": "/home/devel/.draft/renderformer/examples/objects",
#     "BASE_DIR": "/home/devel/.draft/renderformer/examples",
#     "SCRIPT_NAME": "render_scene.py",
#     "NUM_RANDOM_SCENES": 5,
#     "MAX_CONCURRENT_TASKS": 4,
# }


# generator = SceneGenerator(CONFIG)

# # Asynchronous generation
# await generator.generate_dataset()

# Model setup

In [14]:
def load_single_h5_data(file_path):
    """Load one scene from an HDF5 file as a dict of torch tensors.

    Args:
        file_path: Path of the HDF5 file. It must contain the datasets
            'triangles', 'texture', 'vn', 'c2w' and 'fov'.

    Returns:
        dict with keys 'triangles', 'texture', 'mask', 'c2w', 'fov' and 'vn'
        (in that order). Every entry except 'mask' is the dataset of the same
        name converted to a float32 tensor. 'mask' is a bool tensor filled
        with True, with one element per entry along the first axis of
        'triangles', i.e. every triangle in the file is marked valid.
    """
    def to_float32_tensor(dataset):
        # np.array() reads the whole dataset into memory; the float32 cast is
        # done on the NumPy side before the buffer is wrapped by torch.
        return torch.from_numpy(np.array(dataset).astype(np.float32))

    with h5py.File(file_path, 'r') as h5:
        tris_t = to_float32_tensor(h5['triangles'])
        all_valid = torch.ones(tris_t.shape[0], dtype=torch.bool)
        texture_t = to_float32_tensor(h5['texture'])
        vn_t = to_float32_tensor(h5['vn'])
        c2w_t = to_float32_tensor(h5['c2w'])
        fov_t = to_float32_tensor(h5['fov'])

    return {
        'triangles': tris_t,
        'texture': texture_t,
        'mask': all_valid,
        'c2w': c2w_t,
        'fov': fov_t,
        'vn': vn_t,
    }

In [15]:
# Build the model configuration.
config = RenderFormerConfig(
    latent_dim=768,
    num_layers=12,
    num_heads=6,
    dim_feedforward=768 * 4,
    num_register_tokens=16,
    dropout=0.0,
    activation='swiglu',
    norm_type='rms_norm',
    norm_first=True,
    pe_type='rope',
    rope_type='triangle',
    use_vn_encoder=True,
    texture_encode_patch_size=32,
    texture_channels=13,
    view_transformer_latent_dim=768,
    view_transformer_ffn_hidden_dim=768 * 4,
    view_transformer_n_heads=6,
    view_transformer_n_layers=6,
    patch_size=8,
    use_dpt_decoder=True
)

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Instantiate the model and move it to the selected device.
# NOTE(review): the model is built from the config only. Nothing in this
# section loads a checkpoint, so unless RenderFormer(config) loads weights
# itself, the renders below come from an untrained network. TODO: load
# pretrained weights if meaningful images are expected.
model = RenderFormer(config)
model.to(device)
# This notebook only runs inference, so switch off training-mode behaviour.
model.eval()

# Prepare placeholder input tensors (apparently shaped for a direct model
# forward call).
# NOTE(review): none of these tensors are used in this cell. The pipeline
# call below consumes the HDF5 scene instead. Delete them if no later cell
# needs them.
# They are allocated on `device` (not via `.cuda()`) so the cell also runs
# on MPS- or CPU-only machines.
batch_size = 1
max_num_tri = 1000
num_views = 1

tri_vpos_list = torch.randn(batch_size, max_num_tri, 9).to(device)
texture_patch_list = torch.randn(batch_size, max_num_tri, 13, 32, 32).to(device)
valid_mask = torch.ones(batch_size, max_num_tri, dtype=torch.bool).to(device)
vns = torch.randn(batch_size, max_num_tri, 3, 3).to(device)
rays_o = torch.randn(batch_size, num_views, 3).to(device)
rays_d = torch.randn(batch_size, num_views, 512, 512, 3).to(device)
tri_vpos_view_tf = torch.randn(batch_size, num_views, max_num_tri, 9).to(device)


# Load one scene from HDF5 and move it to `device`.
# The path can be overridden with the RENDERFORMER_SCENE_H5 environment
# variable. The default is the original author's local path.
h5_file = os.environ.get(
    'RENDERFORMER_SCENE_H5',
    '/home/devel/.draft/renderformer/tmp/random_scene_16_6k.h5',
)
data = load_single_h5_data(h5_file)

# Prepend a batch dimension of size 1 to every tensor.
triangles = data['triangles'].unsqueeze(0).to(device)
texture = data['texture'].unsqueeze(0).to(device)
mask = data['mask'].unsqueeze(0).to(device)
vn = data['vn'].unsqueeze(0).to(device)
c2w = data['c2w'].unsqueeze(0).to(device)
# fov also gets a trailing singleton axis.
# NOTE(review): presumably this is the layout the pipeline expects. Confirm.
fov = data['fov'].unsqueeze(0).unsqueeze(-1).to(device)
resolution = 512  # render resolution passed to the pipeline
torch_dtype = torch.float32

# Wrap the model in the rendering pipeline and render the loaded scene.
pipeline = RenderFormerRenderingPipeline(model)

# Inference only: disable autograd so no computation graph is kept alive.
with torch.no_grad():
    rendered_imgs = pipeline(
        triangles=triangles,
        texture=texture,
        mask=mask,
        vn=vn,
        c2w=c2w,
        fov=fov,
        resolution=resolution,
        torch_dtype=torch_dtype,
    )