In [1]:
import os
import math
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from accelerate.logging import get_logger

#from .base_trainer import Trainer
from openlrm.runners.train.base_trainer import Trainer
from openlrm.utils.profiler import DummyProfiler
from openlrm.runners import REGISTRY_RUNNERS
import os
import time
import math
import argparse
import shutil
import torch
import safetensors
from omegaconf import OmegaConf
from abc import abstractmethod
from contextlib import contextmanager
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed

from openlrm.utils.logging import configure_logger
from openlrm.utils.compile import configure_dynamo
from openlrm.runners.abstract import Runner
import logging
from transformers import BertModel, BertTokenizer


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class AttributeDict(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [3]:
logger = get_logger(__name__)
class CFG:
    def __init__(self):
        self.experiment = AttributeDict({
            'type': 'lrm',
            'seed': 42,
            'parent': 'lrm-objaverse',
            'child': 'small-dummyrun'
        })
        self.model = AttributeDict({
            'camera_embed_dim': 1024,
            'rendering_samples_per_ray': 96,
            'transformer_dim': 512,
            'transformer_layers': 12,
            'transformer_heads': 8,
            'triplane_low_res': 32,
            'triplane_high_res': 64,
            'triplane_dim': 32,
            'encoder_type': 'dinov2',
            'encoder_model_name': 'dinov2_vits14_reg',
            'encoder_feat_dim': 384,
            'encoder_freeze': False
        })
        self.dataset = AttributeDict({
            'subsets': [
                AttributeDict({
                    'name': 'objaverse',
                    'root_dirs': ['/mnt/c/Users/Robinson/OneDrive/Desktop/Classes_Fall_2024/CAP6411/Project/testing/pythonProject/OpenLRM/scripts/data/objaverse/views'],
                    'meta_path': AttributeDict({
                        'train': '/mnt/c/Users/Robinson/OneDrive/Desktop/Classes_Fall_2024/CAP6411/Project/testing/pythonProject/OpenLRM/scripts/data/lol.json',
                        'val': '/mnt/c/Users/Robinson/OneDrive/Desktop/Classes_Fall_2024/CAP6411/Project/testing/pythonProject/OpenLRM/scripts/data/lol.json'
                    }),
                    'sample_rate': 1.0
                })
            ],
            'sample_side_views': 3,
            'source_image_res': 224,
            'render_image': AttributeDict({
                'low': 64,
                'high': 192,
                'region': 64
            }),
            'normalize_camera': True,
            'normed_dist_to_center': 'auto',
            'num_train_workers': 4,
            'num_val_workers': 2,
            'pin_mem': True
        })
        self.train = AttributeDict({
            'mixed_precision': 'no',
            'find_unused_parameters': False,
            'loss': AttributeDict({
                'pixel_weight': 1.0,
                'perceptual_weight': 1.0,
                'tv_weight': 0.0005
            }),
            'optim': AttributeDict({
                'lr': 0.0004,
                'weight_decay': 0.05,
                'beta1': 0.9,
                'beta2': 0.95,
                'clip_grad_norm': 1.0
            }),
            'scheduler': AttributeDict({
                'type': 'cosine',
                'warmup_real_iters': 3000
            }),
            'batch_size': 1,
            'accum_steps': 1,
            'epochs': 60,
            'debug_global_steps': None,
            'lrm': None
        })
        self.val = AttributeDict({
            'batch_size': 4,
            'global_step_period': 1000,
            'debug_batches': None
        })
        self.saver = AttributeDict({
            'auto_resume': True,
            'load_model': None,
            'checkpoint_root': './exps/checkpoints',
            'checkpoint_global_steps': 1000,
            'checkpoint_keep_level': 5
        })
        self.logger = AttributeDict({
            'stream_level': 'WARNING',
            'log_level': 'INFO',
            'log_root': './exps/logs',
            'tracker_root': './exps/trackers',
            'enable_profiler': False,
            'trackers': ['tensorboard'],
            'image_monitor': AttributeDict({
                'train_global_steps': 100,
                'samples_per_log': 4
            })
        })
        self.compile = AttributeDict({
            'suppress_errors': True,
            'print_specializations': True,
            'disable': True
        })
        

In [4]:
cfg = CFG()

In [5]:
actual_logger = logger.logger

In [6]:
if not actual_logger.handlers:
    handler = logging.StreamHandler()  # Default to console output
    handler.setLevel(logging.INFO)
    actual_logger.addHandler(handler)

In [7]:
for handler in logger.logger.handlers:
    print(f"Handler: {handler}")
    print(f"Handler Type: {type(handler)}")
    if isinstance(handler, logging.FileHandler):
        print(f"Logging to file: {handler.baseFilename}")
    elif isinstance(handler, logging.StreamHandler):
        print("Logging to console/stream")

Handler: <StreamHandler stderr (INFO)>
Handler Type: <class 'logging.StreamHandler'>
Logging to console/stream


In [8]:
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.




logger = get_logger(__name__)


# def parse_configs():
#     # Define argparse arguments
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--config', type=str, default='./assets/config.yaml')
#     args, unknown = parser.parse_known_args()
# 
#     # Load configuration file
#     cfg = OmegaConf.load(args.config)
# 
#     # Override with command-line arguments
#     cli_cfg = OmegaConf.from_cli(unknown)
#     cfg = OmegaConf.merge(cfg, cli_cfg)
# 
#     return cfg


class Trainer(Runner):

    def __init__(self):
        super().__init__()

        self.cfg = cfg
        self.timestamp = time.strftime("%Y%m%d-%H%M%S")
        
        self.accelerator = Accelerator(
            mixed_precision=self.cfg.train.mixed_precision,
            gradient_accumulation_steps=self.cfg.train.accum_steps,
            log_with=tuple(self.cfg.logger.trackers),
            project_config=ProjectConfiguration(
                logging_dir=self.cfg.logger.tracker_root,
            ),
            use_seedable_sampler=True,
            kwargs_handlers=[
                DistributedDataParallelKwargs(
                    find_unused_parameters=self.cfg.train.find_unused_parameters,
                ),
            ],
        )
        set_seed(self.cfg.experiment.seed, device_specific=True)
        with self.accelerator.main_process_first():
            configure_logger(
                stream_level=self.cfg.logger.stream_level,
                log_level=self.cfg.logger.log_level,
                file_path=os.path.join(
                    self.cfg.logger.log_root,
                    self.cfg.experiment.parent, self.cfg.experiment.child,
                    f"{self.timestamp}.log",
                ) if self.accelerator.is_main_process else None,
            )
        logger.info(self.accelerator.state, main_process_only=False, in_order=True)
        configure_dynamo(dict(self.cfg.compile))

        # attributes with defaults
        self.model : torch.nn.Module = None
        self.optimizer: torch.optim.Optimizer = None
        self.scheduler: torch.optim.lr_scheduler.LRScheduler = None
        self.train_loader: torch.utils.data.DataLoader = None
        self.val_loader: torch.utils.data.DataLoader = None
        self.N_max_global_steps: int = None
        self.N_global_steps_per_epoch: int = None
        self.global_step: int = 0
        self.current_epoch: int = 0

    def __enter__(self):
        self.accelerator.init_trackers(
            project_name=f"{self.cfg.experiment.parent}/{self.cfg.experiment.child}",
        )
        self.prepare_everything()
        self.log_inital_info()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.accelerator.end_training()

    @staticmethod
    def control(option: str = None, synchronized: bool = False):
        def decorator(func):
            def wrapper(self, *args, **kwargs):
                if option is None or hasattr(self.accelerator, option):
                    accelerated_func = getattr(self.accelerator, option)(func) if option is not None else func
                    result = accelerated_func(self, *args, **kwargs)
                    if synchronized:
                        self.accelerator.wait_for_everyone()
                    return result
                else:
                    raise AttributeError(f"Accelerator has no attribute {option}")
            return wrapper
        return decorator

    @contextmanager
    def exec_in_order(self):
        for rank in range(self.accelerator.num_processes):
            try:
                if self.accelerator.process_index == rank:
                    yield
            finally:
                self.accelerator.wait_for_everyone()

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self) -> bool:
        return self.accelerator.num_processes > 1

    def prepare_everything(self, is_dist_validation: bool = True):
        # prepare with accelerator
        if is_dist_validation:
            self.model, self.optimizer, self.train_loader, self.val_loader = \
                self.accelerator.prepare(
                    self.model, self.optimizer, self.train_loader, self.val_loader,
                )
        else:
            self.model, self.optimizer, self.train_loader = \
                self.accelerator.prepare(
                    self.model, self.optimizer, self.train_loader,
                )
        self.accelerator.register_for_checkpointing(self.scheduler)
        # prepare stats
        N_total_batch_size = self.cfg.train.batch_size * self.accelerator.num_processes * self.cfg.train.accum_steps
        self.N_global_steps_per_epoch = math.ceil(len(self.train_loader) / self.cfg.train.accum_steps)
        self.N_max_global_steps = self.N_global_steps_per_epoch * self.cfg.train.epochs
        if self.cfg.train.debug_global_steps is not None:
            logger.warning(f"Overriding max global steps from {self.N_max_global_steps} to {self.cfg.train.debug_global_steps}")
            self.N_max_global_steps = self.cfg.train.debug_global_steps
        logger.info(f"======== Statistics ========")
        logger.info(f"** N_max_global_steps: {self.N_max_global_steps}")
        logger.info(f"** N_total_batch_size: {N_total_batch_size}")
        logger.info(f"** N_epochs: {self.cfg.train.epochs}")
        logger.info(f"** N_global_steps_per_epoch: {self.N_global_steps_per_epoch}")
        logger.debug(f"** Prepared loader length: {len(self.train_loader)}")
        logger.info(f"** Distributed validation: {is_dist_validation}")
        logger.info(f"============================")
        logger.info(f"======== Trainable parameters ========")
        logger.info(f"** Total: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
        for sub_name, sub_module in self.accelerator.unwrap_model(self.model).named_children():
            logger.info(f"** {sub_name}: {sum(p.numel() for p in sub_module.parameters() if p.requires_grad)}")
        logger.info(f"=====================================")
        self.accelerator.wait_for_everyone()
        # load checkpoint or model
        self.load_ckpt_or_auto_resume_(self.cfg)
        # register hooks
        self.register_hooks()

    @abstractmethod
    def register_hooks(self):
        pass

    def auto_resume_(self, cfg) -> bool:
        ckpt_root = os.path.join(
            cfg.saver.checkpoint_root,
            cfg.experiment.parent, cfg.experiment.child,
        )
        if not os.path.exists(ckpt_root):
            return False
        ckpt_dirs = os.listdir(ckpt_root)
        if len(ckpt_dirs) == 0:
            return False
        ckpt_dirs.sort()
        latest_ckpt = ckpt_dirs[-1]
        latest_ckpt_dir = os.path.join(ckpt_root, latest_ckpt)
        logger.info(f"======== Auto-resume from {latest_ckpt_dir} ========")
        self.accelerator.load_state(latest_ckpt_dir)
        self.global_step = int(latest_ckpt)
        self.current_epoch = self.global_step // self.N_global_steps_per_epoch
        return True

    def load_model_(self, cfg):
        logger.info(f"======== Loading model from {cfg.saver.load_model} ========")
        safetensors.torch.load_model(
            self.accelerator.unwrap_model(self.model),
            cfg.saver.load_model,
            strict=True,
        )
        logger.info(f"======== Model loaded ========")

    @control(synchronized=True)
    def load_ckpt_or_auto_resume_(self, cfg):
        # auto resume has higher priority, load model from path if auto resume is not available
        # cfg.saver.auto_resume and cfg.saver.load_model
        if cfg.saver.auto_resume:
            successful_resume = self.auto_resume_(cfg)
            if successful_resume:
                return
        if cfg.saver.load_model:
            successful_load = self.load_model_(cfg)
            if successful_load:
                return
        logger.debug(f"======== No checkpoint or model is loaded ========")

    @control('on_main_process', synchronized=True)
    def save_checkpoint(self):
        ckpt_dir = os.path.join(
            self.cfg.saver.checkpoint_root,
            self.cfg.experiment.parent, self.cfg.experiment.child,
            f"{self.global_step:06d}",
        )
        self.accelerator.save_state(output_dir=ckpt_dir, safe_serialization=True)
        logger.info(f"======== Saved checkpoint at global step {self.global_step} ========")
        # manage stratified checkpoints
        ckpt_dirs = os.listdir(os.path.dirname(ckpt_dir))
        ckpt_dirs.sort()
        max_ckpt = int(ckpt_dirs[-1])
        ckpt_base = int(self.cfg.saver.checkpoint_keep_level)
        ckpt_period = self.cfg.saver.checkpoint_global_steps
        logger.debug(f"Checkpoint base: {ckpt_base}")
        logger.debug(f"Checkpoint period: {ckpt_period}")
        print(ckpt_base, ckpt_period, max_ckpt)
         
        e = 0
        if max_ckpt // ckpt_period == 0:
            e = 1e-5
        cur_order = ckpt_base ** math.floor(math.log((max_ckpt // ckpt_period) + e, ckpt_base))
        cur_idx = 0
        while cur_order > 0:
            cur_digit = max_ckpt // ckpt_period // cur_order % ckpt_base
            while cur_idx < len(ckpt_dirs) and int(ckpt_dirs[cur_idx]) // ckpt_period // cur_order % ckpt_base < cur_digit:
                if int(ckpt_dirs[cur_idx]) // ckpt_period % cur_order != 0:
                    shutil.rmtree(os.path.join(os.path.dirname(ckpt_dir), ckpt_dirs[cur_idx]))
                    logger.info(f"Removed checkpoint {ckpt_dirs[cur_idx]}")
                cur_idx += 1
            cur_order //= ckpt_base

    @property
    def global_step_in_epoch(self):
        print(self.global_step, self.N_global_steps_per_epoch)
        if(self.N_global_steps_per_epoch == 0):
            self.N_global_steps_per_epoch = 1
        #input("remember to change global_step_in_epoch")
        return self.global_step % self.N_global_steps_per_epoch

    @abstractmethod
    def _build_model(self):
        pass

    @abstractmethod
    def _build_optimizer(self):
        pass

    @abstractmethod
    def _build_scheduler(self):
        pass

    @abstractmethod
    def _build_dataloader(self):
        pass

    @abstractmethod
    def _build_loss_fn(self):
        pass

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def evaluate(self):
        pass

    @staticmethod
    def _get_str_progress(epoch: int = None, step: int = None):
        if epoch is not None:
            log_type = 'epoch'
            log_progress = epoch
        elif step is not None:
            log_type = 'step'
            log_progress = step
        else:
            raise ValueError('Either epoch or step must be provided')
        return log_type, log_progress

    @control('on_main_process')
    def log_scalar_kwargs(self, epoch: int = None, step: int = None, split: str = None, **scalar_kwargs):
        log_type, log_progress = self._get_str_progress(epoch, step)
        split = f'/{split}' if split else ''
        for key, value in scalar_kwargs.items():
            self.accelerator.log({f'{key}{split}/{log_type}': value}, log_progress)

    @control('on_main_process')
    def log_images(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):
        for tracker in self.accelerator.trackers:
            if hasattr(tracker, 'log_images'):
                tracker.log_images(values, step=step, **log_kwargs.get(tracker.name, {}))

    @control('on_main_process')
    def log_optimizer(self, epoch: int = None, step: int = None, attrs: list[str] = [], group_ids: list[int] = []):
        log_type, log_progress = self._get_str_progress(epoch, step)
        assert self.optimizer is not None, 'Optimizer is not initialized'
        if not attrs:
            logger.warning('No optimizer attributes are provided, nothing will be logged')
        if not group_ids:
            logger.warning('No optimizer group ids are provided, nothing will be logged')
        for attr in attrs:
            assert attr in ['lr', 'momentum', 'weight_decay'], f'Invalid optimizer attribute {attr}'
            for group_id in group_ids:
                self.accelerator.log({f'opt/{attr}/{group_id}': self.optimizer.param_groups[group_id][attr]}, log_progress)

    @control('on_main_process')
    def log_inital_info(self):
        assert self.model is not None, 'Model is not initialized'
        assert self.optimizer is not None, 'Optimizer is not initialized'
        assert self.scheduler is not None, 'Scheduler is not initialized'
        self.accelerator.log({'Config': "```\n" + OmegaConf.to_yaml(self.cfg) + "\n```"})
        self.accelerator.log({'Model': "```\n" + str(self.model) + "\n```"})
        self.accelerator.log({'Optimizer': "```\n" + str(self.optimizer) + "\n```"})
        self.accelerator.log({'Scheduler': "```\n" + str(self.scheduler) + "\n```"})

    def run(self):
        self.train()


In [9]:
#5 ** math.floor(math.log(60 // 1000, 5))

In [10]:
e = 0
if 60 // 1000 == 0:
    e = 1e-5
5 ** math.floor(math.log((60 // 1000) + e, 5))

2.56e-06

In [11]:



# @REGISTRY_RUNNERS.register('train.lrm')
class LRMTrainer(Trainer):
    def __init__(self):
        super().__init__()

        self.model = self._build_model(self.cfg)
        self.optimizer = self._build_optimizer(self.model, self.cfg)
        self.train_loader, self.val_loader = self._build_dataloader(self.cfg)
        self.scheduler = self._build_scheduler(self.optimizer, self.cfg)
        self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = self._build_loss_fn(self.cfg)

    def _build_model(self, cfg):
        assert cfg.experiment.type == 'lrm', \
            f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}"
        from openlrm.models import ModelLRM
        model = ModelLRM(**cfg.model)
        return model

    def _build_optimizer(self, model: nn.Module, cfg):
        decay_params, no_decay_params = [], []

        # add all bias and LayerNorm params to no_decay_params
        for name, module in model.named_modules():
            if isinstance(module, nn.LayerNorm):
                no_decay_params.extend([p for p in module.parameters()])
            elif hasattr(module, 'bias') and module.bias is not None:
                no_decay_params.append(module.bias)

        # add remaining parameters to decay_params
        _no_decay_ids = set(map(id, no_decay_params))
        decay_params = [p for p in model.parameters() if id(p) not in _no_decay_ids]

        # filter out parameters with no grad
        decay_params = list(filter(lambda p: p.requires_grad, decay_params))
        no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params))

        # monitor this to make sure we don't miss any parameters
        logger.info("======== Weight Decay Parameters ========")
        logger.info(f"Total: {len(decay_params)}")
        logger.info("======== No Weight Decay Parameters ========")
        logger.info(f"Total: {len(no_decay_params)}")

        # Optimizer
        opt_groups = [
            {'params': decay_params, 'weight_decay': cfg.train.optim.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(
            opt_groups,
            lr=cfg.train.optim.lr,
            betas=(cfg.train.optim.beta1, cfg.train.optim.beta2),
        )

        return optimizer

    def _build_scheduler(self, optimizer, cfg):
        local_batches_per_epoch = math.floor(len(self.train_loader) / self.accelerator.num_processes)
        total_global_batches = cfg.train.epochs * math.ceil(local_batches_per_epoch / self.cfg.train.accum_steps)
        effective_warmup_iters = cfg.train.scheduler.warmup_real_iters
        logger.debug(f"======== Scheduler effective max iters: {total_global_batches} ========")
        logger.debug(f"======== Scheduler effective warmup iters: {effective_warmup_iters} ========")
        if cfg.train.scheduler.type == 'cosine':
            from openlrm.utils.scheduler import CosineWarmupScheduler
            scheduler = CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_iters=effective_warmup_iters,
                max_iters=total_global_batches,
            )
        else:
            raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented")
        return scheduler

    def _build_dataloader(self, cfg):
        # dataset class
        from openlrm.datasets import MixerDataset

        # build dataset
        train_dataset = MixerDataset(
            split="train",
            subsets=cfg.dataset.subsets,
            sample_side_views=cfg.dataset.sample_side_views,
            render_image_res_low=cfg.dataset.render_image.low,
            render_image_res_high=cfg.dataset.render_image.high,
            render_region_size=cfg.dataset.render_image.region,
            source_image_res=cfg.dataset.source_image_res,
            normalize_camera=cfg.dataset.normalize_camera,
            normed_dist_to_center=cfg.dataset.normed_dist_to_center,
        )
        val_dataset = MixerDataset(
            split="val",
            subsets=cfg.dataset.subsets,
            sample_side_views=cfg.dataset.sample_side_views,
            render_image_res_low=cfg.dataset.render_image.low,
            render_image_res_high=cfg.dataset.render_image.high,
            render_region_size=cfg.dataset.render_image.region,
            source_image_res=cfg.dataset.source_image_res,
            normalize_camera=cfg.dataset.normalize_camera,
            normed_dist_to_center=cfg.dataset.normed_dist_to_center,
        )

        # build data loader
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=cfg.train.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=cfg.dataset.num_train_workers,
            pin_memory=cfg.dataset.pin_mem,
            persistent_workers=True,
        )
 
       
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=cfg.val.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=cfg.dataset.num_val_workers,
            pin_memory=cfg.dataset.pin_mem,
            persistent_workers=False,
        )

        return train_loader, val_loader

    def _build_loss_fn(self, cfg):
        from openlrm.losses import PixelLoss, LPIPSLoss, TVLoss
        pixel_loss_fn = PixelLoss()
        with self.accelerator.main_process_first():
            perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True)
        tv_loss_fn = TVLoss()
        return pixel_loss_fn, perceptual_loss_fn, tv_loss_fn

    def register_hooks(self):
        pass

    def forward_loss_local_step(self, data):

        source_camera = data['source_camera']
        render_camera = data['render_camera']
        source_image = data['source_image']
        render_image = data['render_image']
        render_anchors = data['render_anchors']
        render_full_resolutions = data['render_full_resolutions']
        render_bg_colors = data['render_bg_colors']

    
        N, M, C, H, W = render_image.shape
        
        # forward
        # for params in self.model.parameters():
        #     print(params.device)
        outputs = self.model(
            image=source_image,
            source_camera=source_camera,
            render_cameras=render_camera,
            render_anchors=render_anchors,
            render_resolutions=render_full_resolutions,
            render_bg_colors=render_bg_colors,
            render_region_size=self.cfg.dataset.render_image.region,
        )

        # loss calculation
        loss = 0.
        loss_pixel = None
        loss_perceptual = None
        loss_tv = None

        if self.cfg.train.loss.pixel_weight > 0.:
            loss_pixel = self.pixel_loss_fn(outputs['images_rgb'], render_image)
            loss += loss_pixel * self.cfg.train.loss.pixel_weight
        if self.cfg.train.loss.perceptual_weight > 0.:
            loss_perceptual = self.perceptual_loss_fn(outputs['images_rgb'], render_image)
            loss += loss_perceptual * self.cfg.train.loss.perceptual_weight
        if self.cfg.train.loss.tv_weight > 0.: 
            loss_tv = self.tv_loss_fn(outputs['planes'])
            loss += loss_tv * self.cfg.train.loss.tv_weight

        return outputs, loss, loss_pixel, loss_perceptual, loss_tv
    
    

    def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile):
        self.model.train()

        local_step_losses = []
        global_step_losses = []

        logger.debug(f"======== Starting epoch {self.current_epoch} ========")
        
        
        for data in loader:
            
            logger.debug(f"======== Starting global step {self.global_step} ========")
            with self.accelerator.accumulate(self.model):

                # forward to loss
                outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data)
                
                # backward
                self.accelerator.backward(loss)
                if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()

                # track local losses
                local_step_losses.append(torch.stack([
                    _loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device)
                    for _loss in [loss, loss_pixel, loss_perceptual, loss_tv]
                ]))

            # track global step
            if self.accelerator.sync_gradients:
                profiler.step()
                self.scheduler.step()
                logger.debug(f"======== Scheduler step ========")
                self.global_step += 1
                global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu()
                loss, loss_pixel, loss_perceptual, loss_tv = global_step_loss.unbind()
                loss_kwargs = {
                    'loss': loss.item(),
                    'loss_pixel': loss_pixel.item(),
                    'loss_perceptual': loss_perceptual.item(),
                    'loss_tv': loss_tv.item(),
                }
                self.log_scalar_kwargs(
                    step=self.global_step, split='train',
                    **loss_kwargs
                )
                self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1])
                local_step_losses = []
                global_step_losses.append(global_step_loss)

                # manage display
                pbar.update(1)
                description = {
                    **loss_kwargs,
                    'lr': self.optimizer.param_groups[0]['lr'],
                }
                description = '[TRAIN STEP]' + \
                    ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v))
                pbar.set_description(description)

                # periodic actions
                if self.global_step % self.cfg.saver.checkpoint_global_steps == 0:
                    self.save_checkpoint()
                if self.global_step % self.cfg.val.global_step_period == 0:
                    self.evaluate()
                    self.model.train()
                if self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0:
                    self.log_image_monitor(
                        step=self.global_step, split='train',
                        renders=outs['images_rgb'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                        gts=data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                    )

                # progress control
                if self.global_step >= self.N_max_global_steps:
                    self.accelerator.set_trigger()
                    break

        # track epoch
        self.current_epoch += 1
        epoch_losses = torch.stack(global_step_losses).mean(dim=0)
        epoch_loss, epoch_loss_pixel, epoch_loss_perceptual, epoch_loss_tv = epoch_losses.unbind()
        epoch_loss_dict = {
            'loss': epoch_loss.item(),
            'loss_pixel': epoch_loss_pixel.item(),
            'loss_perceptual': epoch_loss_perceptual.item(),
            'loss_tv': epoch_loss_tv.item(),
        }
        self.log_scalar_kwargs(
            epoch=self.current_epoch, split='train',
            **epoch_loss_dict,
        )
        logger.info(
            f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' + \
                ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v))
        )

    def train(self):
        
        starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps
        skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch)
        logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========")

        with tqdm(
            range(0, self.N_max_global_steps),
            initial=self.global_step,
            disable=(not self.accelerator.is_main_process),
        ) as pbar:

            profiler = torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
                schedule=torch.profiler.schedule(
                    wait=10, warmup=10, active=100,
                ),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(
                    self.cfg.logger.tracker_root,
                    self.cfg.experiment.parent, self.cfg.experiment.child,
                )),
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
            ) if self.cfg.logger.enable_profiler else DummyProfiler()
            
            with profiler:

                self.optimizer.zero_grad()
                for _ in range(self.current_epoch, self.cfg.train.epochs):

                    loader = skipped_loader or self.train_loader
                    #loader = self.train_loader
                    #input("remember to get rid of line above")
                    skipped_loader = None
                    self.train_epoch(pbar=pbar, loader=loader, profiler=profiler)
                    if self.accelerator.check_trigger():
                        break

            logger.info(f"======== Training finished at global step {self.global_step} ========")

            # final checkpoint and evaluation
            self.save_checkpoint()
            self.evaluate()

    @torch.no_grad()
    @torch.compiler.disable
    def evaluate(self, epoch: int = None):
        self.model.eval()

        max_val_batches = self.cfg.val.debug_batches or len(self.val_loader)
        running_losses = []
        sample_data, sample_outs = None, None

        for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches):

            if len(running_losses) >= max_val_batches:
                logger.info(f"======== Early stop validation at {len(running_losses)} batches ========")
                break

            outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data)
            sample_data, sample_outs = data, outs

            running_losses.append(torch.stack([
                _loss if _loss is not None else torch.tensor(float('nan'), device=self.device)
                for _loss in [loss, loss_pixel, loss_perceptual, loss_tv]
            ]))

        total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu()
        total_loss, total_loss_pixel, total_loss_perceptual, total_loss_tv = total_losses.unbind()
        total_loss_dict = {
            'loss': total_loss.item(),
            'loss_pixel': total_loss_pixel.item(),
            'loss_perceptual': total_loss_perceptual.item(),
            'loss_tv': total_loss_tv.item(),
        }

        if epoch is not None:
            self.log_scalar_kwargs(
                epoch=epoch, split='val',
                **total_loss_dict,
            )
            logger.info(
                f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' + \
                    ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
            )
            self.log_image_monitor(
                epoch=epoch, split='val',
                renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
            )
        else:
            self.log_scalar_kwargs(
                step=self.global_step, split='val',
                **total_loss_dict,
            )
            logger.info(
                f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' + \
                    ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
            )
            self.log_image_monitor(
                step=self.global_step, split='val',
                renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
            )

    @Trainer.control('on_main_process')
    def log_image_monitor(
        self, epoch: int = None, step: int = None, split: str = None,
        renders: torch.Tensor = None, gts: torch.Tensor = None,
        ):
        M = renders.shape[1]
        merged = torch.stack([renders, gts], dim=1)[0].view(-1, *renders.shape[2:])
        renders, gts = renders.view(-1, *renders.shape[2:]), gts.view(-1, *gts.shape[2:])
        renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M)
        log_type, log_progress = self._get_str_progress(epoch, step)
        split = f'/{split}' if split else ''
        self.log_images({
            f'Images_split{split}/rendered': renders.unsqueeze(0),
            f'Images_split{split}/gt': gts.unsqueeze(0),
            f'Images_merged{split}': merged.unsqueeze(0),
        }, log_progress)


In [12]:
cfg.experiment

{'type': 'lrm',
 'seed': 42,
 'parent': 'lrm-objaverse',
 'child': 'small-dummyrun'}

In [13]:
# {'experiment': {'type': 'lrm', 'seed': 42, 'parent': 'lrm-objaverse', 'child': 'small-dummyrun'}, 'model': {'camera_embed_dim': 1024, 'rendering_samples_per_ray': 96, 'transformer_dim': 512, 'transform
# er_layers': 12, 'transformer_heads': 8, 'triplane_low_res': 32, 'triplane_high_res': 64, 'triplane_dim': 32, 'encoder_type': 'dinov2', 'encoder_model_name': 'dinov2_vits14_reg', 'encoder_feat_dim': 384
# , 'encoder_freeze': False}, 'dataset': {'subsets': [{'name': 'objaverse', 'root_dirs': ['<REPLACE_WITH_RENDERING_ROOT>'], 'meta_path': {'train': '<TRAIN_UIDS_IN_JSON>', 'val': '<VAL_UIDS_IN_JSON>'}, 's
# ample_rate': 1.0}], 'sample_side_views': 3, 'source_image_res': 224, 'render_image': {'low': 64, 'high': 192, 'region': 64}, 'normalize_camera': True, 'normed_dist_to_center': 'auto', 'num_train_worker
# s': 4, 'num_val_workers': 2, 'pin_mem': True}, 'train': {'mixed_precision': 'bf16', 'find_unused_parameters': False, 'loss': {'pixel_weight': 1.0, 'perceptual_weight': 1.0, 'tv_weight': 0.0005}, 'optim
# ': {'lr': 0.0004, 'weight_decay': 0.05, 'beta1': 0.9, 'beta2': 0.95, 'clip_grad_norm': 1.0}, 'scheduler': {'type': 'cosine', 'warmup_real_iters': 3000}, 'batch_size': 16, 'accum_steps': 1, 'epochs': 60
# , 'debug_global_steps': None, 'lrm': None}, 'val': {'batch_size': 4, 'global_step_period': 1000, 'debug_batches': None}, 'saver': {'auto_resume': True, 'load_model': None, 'checkpoint_root': './exps/ch
# eckpoints', 'checkpoint_global_steps': 1000, 'checkpoint_keep_level': 5}, 'logger': {'stream_level': 'WARNING', 'log_level': 'INFO', 'log_root': './exps/logs', 'tracker_root': './exps/trackers', 'enabl
# e_profiler': False, 'trackers': ['tensorboard'], 'image_monitor': {'train_global_steps': 100, 'samples_per_log': 4}}, 'compile': {'suppress_errors': True, 'print_specializations': True, 'disable': True}}

In [14]:
lrm_trainer = LRMTrainer()



dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)


In [15]:
lrm_trainer.prepare_everything()

In [16]:
bert = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

inputs = tokenizer("Man standing with gray shirt and brown pants, standing in a rocky surface", return_tensors="pt")



In [17]:
@torch.compile
def forward_planes(image, camera):
    # image: [N, C_img, H_img, W_img]
    # camera: [N, D_cam_raw]
    
    N = image.shape[0]
    
    # encode image
    image_feats = lrm_trainer.model.encoder(image)
    assert image_feats.shape[-1] == lrm_trainer.model.encoder_feat_dim, \
        f"Feature dimension mismatch: {image_feats.shape[-1]} vs {lrm_trainer.model.encoder_feat_dim}"

    # embed camera
    camera_embeddings = lrm_trainer.model.camera_embedder(camera)
    assert camera_embeddings.shape[-1] == lrm_trainer.model.camera_embed_dim, \
        f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {lrm_trainer.model.camera_embed_dim}"
    
    text_feats = bert(**inputs).last_hidden_state
    # transformer generating planes
    feats = image_feats/2 + text_feats[:,0,:]/2
    

    tokens = lrm_trainer.model.forward_transformer(feats, camera_embeddings)
    planes = lrm_trainer.model.reshape_upsample(tokens)
    assert planes.shape[0] == N, "Batch size mismatch for planes"
    assert planes.shape[1] == 3, "Planes should have 3 channels"

    return planes

In [18]:
lrm_trainer.model.forward_planes = forward_planes

In [19]:
lrm_trainer.train()

60 1


100%|██████████| 60/60 [00:00<?, ?it/s]

5 1000 60



  0%|          | 0/1 [00:01<?, ?it/s][A
100%|██████████| 60/60 [00:08<?, ?it/s]


RuntimeError: The size of tensor a (384) must match the size of tensor b (768) at non-singleton dimension 2