In [4]:
from openlrm.models import ModelLRM
from openlrm.runners.infer.lrm import LRMInferrer

In [5]:
class CFG:
    def __init__(self, image_input='./assets/sample_input/owl.png'):
        self.config = "./configs/infer-b.yaml"
        self.infer = {'lrm': None}
        self.model_name = 'zxhezexin/openlrm-mix-base-1.1'
        self.image_input = image_input
        self.export_video = True
        self.export_mesh = True
        self.source_size = 336
        self.render_size = 288
        self.source_cam_dist = 2.0
        self.video_dump = 'dumps/zxhezexin/openlrm-mix-base-1.1/videos'
        self.mesh_dump = 'dumps/zxhezexin/openlrm-mix-base-1.1/meshes'
        self.render_views = 60
        self.render_fps = 10
        self.mesh_size = 384
        self.mesh_thres = 3.0
        self.frame_size = 2
        self.logger = 'INFO'
        self.app_enabled = False

In [6]:
cfg = CFG('./assets/sample_input/pawn.jpg')

In [7]:
cfg

<__main__.CFG at 0x7f851afaad90>

In [8]:
cfg

<__main__.CFG at 0x7f851afaad90>

In [9]:
# {'source_size': 336, 'source_cam_dist': 2.0, 'render_size': 288, 'render_views': 160, 'render_fps': 40, 'frame_size': 2, 'mesh_size': 384, 'mesh_thres': 3.0, 'video_dump': 'dumps/zxhezexin/openlrm-mix-
# base-1.1/videos', 'mesh_dump': 'dumps/zxhezexin/openlrm-mix-base-1.1/meshes', 'infer': {'lrm': None}, 'model_name': 'zxhezexin/openlrm-mix-base-1.1', 'image_input': './assets/sample_input/owl.png', 'export_video': True, 'export_mesh': True, 'logger': 'INFO', 'app_enabled': False}


In [10]:
# {'experiment': {'type': 'lrm', 'seed': 42, 'parent': 'lrm-objaverse', 'child': 'small-dummyrun'}, 'model': {'camera_embed_dim': 1024, 'rendering_samples_per_ray': 96, 'transformer_dim': 512, 'transform
# er_layers': 12, 'transformer_heads': 8, 'triplane_low_res': 32, 'triplane_high_res': 64, 'triplane_dim': 32, 'encoder_type': 'dinov2', 'encoder_model_name': 'dinov2_vits14_reg', 'encoder_feat_dim': 384
# , 'encoder_freeze': False}, 'dataset': {'subsets': [{'name': 'objaverse', 'root_dirs': ['<REPLACE_WITH_RENDERING_ROOT>'], 'meta_path': {'train': '<TRAIN_UIDS_IN_JSON>', 'val': '<VAL_UIDS_IN_JSON>'}, 's
# ample_rate': 1.0}], 'sample_side_views': 3, 'source_image_res': 224, 'render_image': {'low': 64, 'high': 192, 'region': 64}, 'normalize_camera': True, 'normed_dist_to_center': 'auto', 'num_train_worker
# s': 4, 'num_val_workers': 2, 'pin_mem': True}, 'train': {'mixed_precision': 'bf16', 'find_unused_parameters': False, 'loss': {'pixel_weight': 1.0, 'perceptual_weight': 1.0, 'tv_weight': 0.0005}, 'optim
# ': {'lr': 0.0004, 'weight_decay': 0.05, 'beta1': 0.9, 'beta2': 0.95, 'clip_grad_norm': 1.0}, 'scheduler': {'type': 'cosine', 'warmup_real_iters': 3000}, 'batch_size': 16, 'accum_steps': 1, 'epochs': 60
# , 'debug_global_steps': None, 'lrm': None}, 'val': {'batch_size': 4, 'global_step_period': 1000, 'debug_batches': None}, 'saver': {'auto_resume': True, 'load_model': None, 'checkpoint_root': './exps/ch
# eckpoints', 'checkpoint_global_steps': 1000, 'checkpoint_keep_level': 5}, 'logger': {'stream_level': 'WARNING', 'log_level': 'INFO', 'log_root': './exps/logs', 'tracker_root': './exps/trackers', 'enabl
# e_profiler': False, 'trackers': ['tensorboard'], 'image_monitor': {'train_global_steps': 100, 'samples_per_log': 4}}, 'compile': {'suppress_errors': True, 'print_specializations': True, 'disable': True}}

In [11]:
import torch
import os
import argparse
import mcubes
import trimesh
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from accelerate.logging import get_logger

from openlrm.runners.infer.base_inferrer import Inferrer
from openlrm.datasets.cam_utils import build_camera_principle, build_camera_standard, surrounding_views_linspace, create_intrinsics
from openlrm.utils.logging import configure_logger
from openlrm.runners import REGISTRY_RUNNERS
from openlrm.utils.video import images_to_video
from openlrm.utils.hf_hub import wrap_model_hub


logger = get_logger(__name__)

In [12]:
np.__version__

'1.26.4'

In [13]:
import openlrm

openlrm.__file__

'/mnt/c/Users/Robinson/OneDrive/Desktop/Classes_Fall_2024/CAP6411/Project/testing/pythonProject/opnlrm_real/OpenLRM/openlrm/__init__.py'

In [14]:
class LRMInferrer(Inferrer):

    EXP_TYPE: str = 'lrm'

    def __init__(self):
        super().__init__()

        self.cfg = cfg
        configure_logger(
            stream_level=self.cfg.logger,
            log_level=self.cfg.logger,
        )



        self.model = self._build_model(self.cfg).to(self.device)

    def _build_model(self, cfg):
        from openlrm.models import model_dict
        hf_model_cls = wrap_model_hub(model_dict[self.EXP_TYPE])
        model = hf_model_cls.from_pretrained(cfg.model_name)
        return model

    def _default_source_camera(self, dist_to_center: float = 2.0, batch_size: int = 1, device: torch.device = torch.device('cpu')):
        # return: (N, D_cam_raw)
        canonical_camera_extrinsics = torch.tensor([[
            [1, 0, 0, 0],
            [0, 0, -1, -dist_to_center],
            [0, 1, 0, 0],
        ]], dtype=torch.float32, device=device)
        canonical_camera_intrinsics = create_intrinsics(
            f=0.75,
            c=0.5,
            device=device,
        ).unsqueeze(0)
        source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics)
        return source_camera.repeat(batch_size, 1)

    def _default_render_cameras(self, n_views: int, batch_size: int = 1, device: torch.device = torch.device('cpu')):
        # return: (N, M, D_cam_render)
        render_camera_extrinsics = surrounding_views_linspace(n_views=n_views, device=device)
        render_camera_intrinsics = create_intrinsics(
            f=0.75,
            c=0.5,
            device=device,
        ).unsqueeze(0).repeat(render_camera_extrinsics.shape[0], 1, 1)
        render_cameras = build_camera_standard(render_camera_extrinsics, render_camera_intrinsics)
        return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)

    def infer_planes(self, image: torch.Tensor, source_cam_dist: float):
        N = image.shape[0]
        source_camera = self._default_source_camera(dist_to_center=source_cam_dist, batch_size=N, device=self.device)

        planes = self.model.forward_planes(image, source_camera)

        assert N == planes.shape[0]
        return planes

    def infer_video(self, planes: torch.Tensor, frame_size: int, render_size: int, render_views: int, render_fps: int, dump_video_path: str):
        N = planes.shape[0]
        render_cameras = self._default_render_cameras(n_views=render_views, batch_size=N, device=self.device)
        render_anchors = torch.zeros(N, render_cameras.shape[1], 2, device=self.device)
        render_resolutions = torch.ones(N, render_cameras.shape[1], 1, device=self.device) * render_size
        render_bg_colors = torch.ones(N, render_cameras.shape[1], 1, device=self.device, dtype=torch.float32) * 1.

        frames = []
        for i in range(0, render_cameras.shape[1], frame_size):
            frames.append(
                self.model.synthesizer(
                    planes=planes,
                    cameras=render_cameras[:, i:i+frame_size],
                    anchors=render_anchors[:, i:i+frame_size],
                    resolutions=render_resolutions[:, i:i+frame_size],
                    bg_colors=render_bg_colors[:, i:i+frame_size],
                    region_size=render_size,
                )
            )
        # merge frames
        frames = {
            k: torch.cat([r[k] for r in frames], dim=1)
            for k in frames[0].keys()
        }
        # dump
        os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
        for k, v in frames.items():
            if k == 'images_rgb':
                images_to_video(
                    images=v[0],
                    output_path=dump_video_path,
                    fps=render_fps,
                    gradio_codec=self.cfg.app_enabled,
                )

    def infer_mesh(self, planes: torch.Tensor, mesh_size: int, mesh_thres: float, dump_mesh_path: str):
        grid_out = self.model.synthesizer.forward_grid(
            planes=planes,
            grid_size=mesh_size,
        )
        
        vtx, faces = mcubes.marching_cubes(grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres)
        vtx = vtx / (mesh_size - 1) * 2 - 1

        vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
        vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy()  # (0, 1)
        vtx_colors = (vtx_colors * 255).astype(np.uint8)
        
        mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)

        # dump
        os.makedirs(os.path.dirname(dump_mesh_path), exist_ok=True)
        mesh.export(dump_mesh_path)

    def infer_single(self, image_path: str, source_cam_dist: float, export_video: bool, export_mesh: bool, dump_video_path: str, dump_mesh_path: str):
        source_size = self.cfg.source_size
        render_size = self.cfg.render_size
        render_views = self.cfg.render_views
        render_fps = self.cfg.render_fps
        mesh_size = self.cfg.mesh_size
        mesh_thres = self.cfg.mesh_thres
        frame_size = self.cfg.frame_size
        source_cam_dist = self.cfg.source_cam_dist if source_cam_dist is None else source_cam_dist

        # prepare image: [1, C_img, H_img, W_img], 0-1 scale
        image = torch.from_numpy(np.array(Image.open(image_path))).to(self.device)
        image = image.permute(2, 0, 1).unsqueeze(0) / 255.0
        if image.shape[1] == 4:  # RGBA
            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
        image = torch.nn.functional.interpolate(image, size=(source_size, source_size), mode='bicubic', align_corners=True)
        image = torch.clamp(image, 0, 1)

        with torch.no_grad():
            planes = self.infer_planes(image, source_cam_dist=source_cam_dist)

            results = {}
            if export_video:
                frames = self.infer_video(planes, frame_size=frame_size, render_size=render_size, render_views=render_views, render_fps=render_fps, dump_video_path=dump_video_path)
                results.update({
                    'frames': frames,
                })
            if export_mesh:
                mesh = self.infer_mesh(planes, mesh_size=mesh_size, mesh_thres=mesh_thres, dump_mesh_path=dump_mesh_path)
                results.update({
                    'mesh': mesh,
                })

    def infer(self):

        image_paths = []
        if os.path.isfile(self.cfg.image_input):
            omit_prefix = os.path.dirname(self.cfg.image_input)
            image_paths.append(self.cfg.image_input)
        else:
            omit_prefix = self.cfg.image_input
            for root, dirs, files in os.walk(self.cfg.image_input):
                for file in files:
                    if file.endswith('.png'):
                        image_paths.append(os.path.join(root, file))
            image_paths.sort()

        # alloc to each DDP worker
        image_paths = image_paths[self.accelerator.process_index::self.accelerator.num_processes]

        for image_path in tqdm(image_paths, disable=not self.accelerator.is_local_main_process):

            # prepare dump paths
            image_name = os.path.basename(image_path)
            uid = image_name.split('.')[0]
            subdir_path = os.path.dirname(image_path).replace(omit_prefix, '')
            subdir_path = subdir_path[1:] if subdir_path.startswith('/') else subdir_path
            dump_video_path = os.path.join(
                self.cfg.video_dump,
                subdir_path,
                f'{uid}.mov',
            )
            dump_mesh_path = os.path.join(
                self.cfg.mesh_dump,
                subdir_path,
                f'{uid}.ply',
            )

            self.infer_single(
                image_path,
                source_cam_dist=None,
                export_video=self.cfg.export_video,
                export_mesh=self.cfg.export_mesh,
                dump_video_path=dump_video_path,
                dump_mesh_path=dump_mesh_path,
            )

In [15]:
# class ModelLRMv2(ModelLRM):
#     def __init__(self, camera_embed_dim: int, rendering_samples_per_ray: int,
#                  transformer_dim: int, transformer_layers: int, transformer_heads: int,
#                  triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
#                  encoder_freeze: bool = True, encoder_type: str = 'dino',
#                  encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768):
#         super().__init__(camera_embed_dim, rendering_samples_per_ray,
#                  transformer_dim, transformer_layers, transformer_heads,
#                  triplane_low_res, triplane_high_res, triplane_dim,
#                  encoder_freeze, encoder_type, encoder_model_name, encoder_feat_dim)
#         self.model = self._build_model(self.cfg).to(self.device)

    
    

In [16]:
lrm_i = LRMInferrer()

[2024-09-22 20:43:42,894] openlrm.models.modeling_lrm: [INFO] Using DINOv2 as the encoder




In [17]:
from transformers import BertModel, BertTokenizer

In [18]:
bert = BertModel.from_pretrained('bert-base-uncased').to("cuda")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')



In [19]:
text = "pawn"
inputs = tokenizer(text, return_tensors="pt").to("cuda")

In [20]:
inputs
bert(**inputs)

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.5961,  0.0817, -0.0259,  ..., -0.1485,  0.0682,  1.1705],
         [-0.6663, -0.8367,  0.4010,  ...,  0.3661, -0.0528,  0.3260],
         [ 0.9318,  0.0092, -0.3196,  ...,  0.1259, -0.8594, -0.1575]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-0.9324, -0.3545, -0.0097,  0.8321,  0.1828, -0.2639,  0.9565,  0.1587,
         -0.3545, -1.0000, -0.4712,  0.7049,  0.9891,  0.0572,  0.9486, -0.6808,
         -0.4614, -0.5864,  0.5341, -0.9010,  0.7179,  0.9990,  0.4551,  0.3323,
          0.5417,  0.9128, -0.6911,  0.9471,  0.9690,  0.7658, -0.7925,  0.2185,
         -0.9921, -0.3486, -0.4658, -0.9930,  0.3867, -0.8437, -0.1463, -0.0235,
         -0.9416,  0.4951,  0.9999,  0.4803,  0.4872, -0.4367, -1.0000,  0.3117,
         -0.9295,  0.4664,  0.3824,  0.1156,  0.2471,  0.5555,  0.6331,  0.0898,
          0.0425,  0.2780, -0.3225, -0.6382, -0.6747,  0.4071, -0.4357, -0.9256,

In [21]:
@torch.compile
def forward_planes(image, camera):
    # image: [N, C_img, H_img, W_img]
    # camera: [N, D_cam_raw]
    
    N = image.shape[0]
    
    # encode image
    image_feats = lrm_i.model.encoder(image)
    assert image_feats.shape[-1] == lrm_i.model.encoder_feat_dim, \
        f"Feature dimension mismatch: {image_feats.shape[-1]} vs {lrm_i.model.encoder_feat_dim}"

    # embed camera
    camera_embeddings = lrm_i.model.camera_embedder(camera)
    assert camera_embeddings.shape[-1] == lrm_i.model.camera_embed_dim, \
        f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {lrm_i.model.camera_embed_dim}"
    
    text_feats = bert(**inputs).last_hidden_state
    # transformer generating planes
    feats = image_feats/2 + text_feats[:,0,:]/2
    

    tokens = lrm_i.model.forward_transformer(feats, camera_embeddings)
    planes = lrm_i.model.reshape_upsample(tokens)
    assert planes.shape[0] == N, "Batch size mismatch for planes"
    assert planes.shape[1] == 3, "Planes should have 3 channels"

    return planes

In [22]:
lrm_i.model.forward_planes = forward_planes

In [23]:
lrm_i.infer()

100%|██████████| 1/1 [04:11<00:00, 251.91s/it]
