# Install stable-dreamfusion
- After downloading the repo, manually set `-std=c++17` in every `setup.py`!

In [None]:
%%bash
# NOTE: `%%bash` has to be the very first line of the cell; IPython only
# recognises a cell magic there, so the comment now follows it.
# Assume that pwd is "/content"
git clone https://github.com/ashawkey/stable-dreamfusion.git
cd /content/stable-dreamfusion

pip install -r requirements.txt

# No need for this reproduction.
# cd /content/stable-dreamfusion/pretrained/zero123
# wget https://zero123.cs.columbia.edu/assets/zero123-xl.ckpt

# cd /content/stable-dreamfusion
# mkdir pretrained/omnidata
# cd pretrained/omnidata
# assume gdown is installed
# gdown '1Jrh-bRnJEjyMCS7f-WsaFlccfPjJPPHI&confirm=t' # omnidata_dpt_depth_v2.ckpt
# gdown '1wNxVO4vVbDEMEpnAi_jwQObf2MFodcBR&confirm=t' # omnidata_dpt_normal_v2.ckpt

# Training and testing

## Ablation: R Precision for optimization options
1. Start from a simplified DreamFusion
2. Add a large range of viewpoints (ViewAug) -- Omitted
3. View-dependent prompts (ViewDep)
4. Illuminated renders (Lighting)
5. Textureless Shaded geometry images (Textureless)

In [None]:
## 1. Remove textureless rendering
%cd /content/notebooks/stable-dreamfusion/nerf

import torch
from os import path
import random
from utils import *
from utils import Trainer

def train_step_without_textureless(self, data, save_guidance_path:Path=None):
    """Run one training step with 'textureless' shading disabled (ablation 1).

    Intended as a drop-in replacement for ``Trainer.train_step``. The only
    intentional difference is in the random-shading phase: where the step
    would pick ``'textureless'`` shading, ``'lambertian'`` is used instead.

    Args:
        data: batch dict. This function reads ``rays_o``, ``rays_d``, ``mvp``,
            ``H`` and ``W``; the novel-view losses additionally read
            ``azimuth`` (and ``polar`` / ``radius`` for zero123 guidance).
            Replaced by ``self.default_view_data`` on known-view steps.
        save_guidance_path: forwarded to the guidance ``train_step``; an image
            that combines the NeRF render, the added latent noise, the
            denoised result and optionally the fully-denoised image.

    Returns:
        Tuple ``(pred_rgb, pred_depth, loss)``.
    """

    # perform RGBD loss instead of SDS if is image-conditioned
    do_rgbd_loss = self.opt.images is not None and \
        (self.global_step % self.opt.known_view_interval == 0)

    # override random camera with fixed known camera
    if do_rgbd_loss:
        data = self.default_view_data

    # experiment iterations ratio
    # i.e. what proportion of this experiment have we completed (in terms of iterations) so far?
    exp_iter_ratio = (self.global_step - self.opt.exp_start_iter) / (self.opt.exp_end_iter - self.opt.exp_start_iter)

    # progressively relaxing view range
    if self.opt.progressive_view:
        r = min(1.0, self.opt.progressive_view_init_ratio + 2.0*exp_iter_ratio)
        self.opt.phi_range = [self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[0] * r,
                              self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[1] * r]
        self.opt.theta_range = [self.opt.default_polar * (1 - r) + self.opt.full_theta_range[0] * r,
                                self.opt.default_polar * (1 - r) + self.opt.full_theta_range[1] * r]
        self.opt.radius_range = [self.opt.default_radius * (1 - r) + self.opt.full_radius_range[0] * r,
                                self.opt.default_radius * (1 - r) + self.opt.full_radius_range[1] * r]
        self.opt.fovy_range = [self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[0] * r,
                                self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[1] * r]

    # progressively increase max_level
    if self.opt.progressive_level:
        self.model.max_level = min(1.0, 0.25 + 2.0*exp_iter_ratio)

    rays_o = data['rays_o'] # [B, N, 3]
    rays_d = data['rays_d'] # [B, N, 3]
    mvp = data['mvp'] # [B, 4, 4]

    B, N = rays_o.shape[:2]
    H, W = data['H'], data['W']

    # When ref_data has B images > opt.batch_size
    if B > self.opt.batch_size:
        # choose batch_size images out of those B images
        choice = torch.randperm(B)[:self.opt.batch_size]
        B = self.opt.batch_size
        rays_o = rays_o[choice]
        rays_d = rays_d[choice]
        mvp = mvp[choice]

    if do_rgbd_loss:
        ambient_ratio = 1.0
        shading = 'lambertian' # use lambertian instead of albedo to get normal
        as_latent = False
        binarize = False
        bg_color = torch.rand((B * N, 3), device=rays_o.device)

        # add camera noise to avoid grid-like artifact
        if self.opt.known_view_noise_scale > 0:
            noise_scale = self.opt.known_view_noise_scale #* (1 - self.global_step / self.opt.iters)
            rays_o = rays_o + torch.randn(3, device=self.device) * noise_scale
            rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale

    elif exp_iter_ratio <= self.opt.latent_iter_ratio:
        ambient_ratio = 1.0
        shading = 'normal'
        as_latent = True
        binarize = False
        bg_color = None

    else:
        if exp_iter_ratio <= self.opt.albedo_iter_ratio:
            ambient_ratio = 1.0
            shading = 'albedo'
        else:
            # random shading
            ambient_ratio = self.opt.min_ambient_ratio + (1.0-self.opt.min_ambient_ratio) * random.random()
            # 1st modification: removing textureless renders.
            # The textureless/lambertian coin flip is still drawn (and discarded)
            # so this step consumes the same number of `random` values as before;
            # the shading is always 'lambertian'.
            random.random()
            shading = 'lambertian'

        as_latent = False

        # random weights binarization (like mobile-nerf) [NOT WORKING NOW]
        binarize = False

        # random background
        rand = random.random()
        if self.opt.bg_radius > 0 and rand > 0.5:
            bg_color = None # use bg_net
        else:
            bg_color = torch.rand(3).to(self.device) # single color random bg

    outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, binarize=binarize)
    pred_depth = outputs['depth'].reshape(B, 1, H, W)
    pred_mask = outputs['weights_sum'].reshape(B, 1, H, W)
    if 'normal_image' in outputs:
        pred_normal = outputs['normal_image'].reshape(B, H, W, 3)

    if as_latent:
        # abuse normal & mask as latent code for faster geometry initialization (ref: fantasia3D)
        pred_rgb = torch.cat([outputs['image'], outputs['weights_sum'].unsqueeze(-1)], dim=-1).reshape(B, H, W, 4).permute(0, 3, 1, 2).contiguous() # [B, 4, H, W]
    else:
        pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]

    # known view loss
    if do_rgbd_loss:
        gt_mask = self.mask # [B, H, W]
        gt_rgb = self.rgb   # [B, 3, H, W]
        gt_normal = self.normal # [B, H, W, 3]
        gt_depth = self.depth   # [B, H, W]

        if len(gt_rgb) > self.opt.batch_size:
            gt_mask = gt_mask[choice]
            gt_rgb = gt_rgb[choice]
            gt_normal = gt_normal[choice]
            gt_depth = gt_depth[choice]

        # color loss: composite the reference image over the same random background used for the render
        gt_rgb = gt_rgb * gt_mask[:, None].float() + bg_color.reshape(B, H, W, 3).permute(0,3,1,2).contiguous() * (1 - gt_mask[:, None].float())
        loss = self.opt.lambda_rgb * F.mse_loss(pred_rgb, gt_rgb)

        # mask loss
        loss = loss + self.opt.lambda_mask * F.mse_loss(pred_mask[:, 0], gt_mask.float())

        # normal loss
        if self.opt.lambda_normal > 0 and 'normal_image' in outputs:
            valid_gt_normal = 1 - 2 * gt_normal[gt_mask] # [B, 3]
            valid_pred_normal = 2 * pred_normal[gt_mask] - 1 # [B, 3]

            lambda_normal = self.opt.lambda_normal * min(1, self.global_step / self.opt.iters)
            loss = loss + lambda_normal * (1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean())

        # relative depth loss
        if self.opt.lambda_depth > 0:
            valid_gt_depth = gt_depth[gt_mask] # [B,]
            valid_pred_depth = pred_depth[:, 0][gt_mask] # [B,]
            lambda_depth = self.opt.lambda_depth * min(1, self.global_step / self.opt.iters)
            loss = loss + lambda_depth * (1 - self.pearson(valid_pred_depth, valid_gt_depth))

    # novel view loss
    else:

        loss = 0

        if 'SD' in self.guidance:
            # interpolate text_z
            azimuth = data['azimuth'] # [-180, 180]

            # ENHANCE: remove loop to handle batch size > 1
            text_z = [self.embeddings['SD']['uncond']] * azimuth.shape[0]
            if self.opt.perpneg:

                text_z_comp, weights = adjust_text_embeddings(self.embeddings['SD'], azimuth, self.opt)
                text_z.append(text_z_comp)

            else:
                for b in range(azimuth.shape[0]):
                    if azimuth[b] >= -90 and azimuth[b] < 90:
                        if azimuth[b] >= 0:
                            r = 1 - azimuth[b] / 90
                        else:
                            r = 1 + azimuth[b] / 90
                        start_z = self.embeddings['SD']['front']
                        end_z = self.embeddings['SD']['side']
                    else:
                        if azimuth[b] >= 0:
                            r = 1 - (azimuth[b] - 90) / 90
                        else:
                            r = 1 + (azimuth[b] + 90) / 90
                        start_z = self.embeddings['SD']['side']
                        end_z = self.embeddings['SD']['back']
                    text_z.append(r * start_z + (1 - r) * end_z)

            text_z = torch.cat(text_z, dim=0)
            if self.opt.perpneg:
                loss = loss + self.guidance['SD'].train_step_perpneg(text_z, weights, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance,
                                                save_guidance_path=save_guidance_path)
            else:
                loss = loss + self.guidance['SD'].train_step(text_z, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance,
                                                                save_guidance_path=save_guidance_path)

        if 'IF' in self.guidance:
            # interpolate text_z
            azimuth = data['azimuth'] # [-180, 180]

            # ENHANCE: remove loop to handle batch size > 1
            text_z = [self.embeddings['IF']['uncond']] * azimuth.shape[0]
            if self.opt.perpneg:
                text_z_comp, weights = adjust_text_embeddings(self.embeddings['IF'], azimuth, self.opt)
                text_z.append(text_z_comp)
            else:
                for b in range(azimuth.shape[0]):
                    if azimuth[b] >= -90 and azimuth[b] < 90:
                        if azimuth[b] >= 0:
                            r = 1 - azimuth[b] / 90
                        else:
                            r = 1 + azimuth[b] / 90
                        start_z = self.embeddings['IF']['front']
                        end_z = self.embeddings['IF']['side']
                    else:
                        if azimuth[b] >= 0:
                            r = 1 - (azimuth[b] - 90) / 90
                        else:
                            r = 1 + (azimuth[b] + 90) / 90
                        start_z = self.embeddings['IF']['side']
                        end_z = self.embeddings['IF']['back']
                    text_z.append(r * start_z + (1 - r) * end_z)

            text_z = torch.cat(text_z, dim=0)

            if self.opt.perpneg:
                loss = loss + self.guidance['IF'].train_step_perpneg(text_z, weights, pred_rgb, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance)
            else:
                loss = loss + self.guidance['IF'].train_step(text_z, pred_rgb, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance)

        if 'zero123' in self.guidance:

            polar = data['polar']
            azimuth = data['azimuth']
            radius = data['radius']

            loss = loss + self.guidance['zero123'].train_step(self.embeddings['zero123']['default'], pred_rgb, polar, azimuth, radius, guidance_scale=self.opt.guidance_scale,
                                                              as_latent=as_latent, grad_scale=self.opt.lambda_guidance, save_guidance_path=save_guidance_path)

        if 'clip' in self.guidance:

            # Bind azimuth here as well: previously it was only assigned by the
            # SD / IF / zero123 branches, so CLIP-only guidance raised
            # UnboundLocalError on the next line.
            azimuth = data['azimuth']

            # empirical, far view should apply smaller CLIP loss
            lambda_guidance = 10 * (1 - abs(azimuth) / 180) * self.opt.lambda_guidance

            loss = loss + self.guidance['clip'].train_step(self.embeddings['clip'], pred_rgb, grad_scale=lambda_guidance)

    # regularizations
    if not self.opt.dmtet:

        if self.opt.lambda_opacity > 0:
            loss_opacity = (outputs['weights_sum'] ** 2).mean()
            loss = loss + self.opt.lambda_opacity * loss_opacity

        if self.opt.lambda_entropy > 0:
            # clamp away from 0 and 1 so both log2 terms stay finite
            alphas = outputs['weights'].clamp(1e-5, 1 - 1e-5)
            loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
            lambda_entropy = self.opt.lambda_entropy * min(1, 2 * self.global_step / self.opt.iters)
            loss = loss + lambda_entropy * loss_entropy

        if self.opt.lambda_2d_normal_smooth > 0 and 'normal_image' in outputs:
            # total-variation
            loss_smooth = (pred_normal[:, 1:, :, :] - pred_normal[:, :-1, :, :]).square().mean() + \
                          (pred_normal[:, :, 1:, :] - pred_normal[:, :, :-1, :]).square().mean()
            loss = loss + self.opt.lambda_2d_normal_smooth * loss_smooth

        if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
            loss_orient = outputs['loss_orient']
            loss = loss + self.opt.lambda_orient * loss_orient

        if self.opt.lambda_3d_normal_smooth > 0 and 'loss_normal_perturb' in outputs:
            loss_normal_perturb = outputs['loss_normal_perturb']
            loss = loss + self.opt.lambda_3d_normal_smooth * loss_normal_perturb

    else:

        if self.opt.lambda_mesh_normal > 0:
            loss = loss + self.opt.lambda_mesh_normal * outputs['normal_loss']

        if self.opt.lambda_mesh_laplacian > 0:
            loss = loss + self.opt.lambda_mesh_laplacian * outputs['lap_loss']

    return pred_rgb, pred_depth, loss


# Monkey patching: make the `Trainer` class imported in this notebook use the
# ablation variant defined above as its `train_step`.
# NOTE(review): this only changes the class object living in this kernel. The
# training cell that follows launches `!python main.py`, i.e. a separate
# process, which presumably imports its own unpatched Trainer - confirm that
# the ablation is really in effect for those runs.
Trainer.train_step = train_step_without_textureless

/content/notebooks/stable-dreamfusion/nerf


  @torch.cuda.amp.autocast(enabled=False)


In [None]:
# Run main.py on the text prompt with workspace `ablation_trial`, then run it
# again on the same workspace with `--test`.
# NOTE(review): `!python` starts a new interpreter process, so patches applied
# to `Trainer` inside this notebook kernel are presumably not visible to
# main.py - verify before interpreting the results as an ablation.
%cd "/content/notebooks/stable-dreamfusion"
!python main.py -O --text "A bulldog is wearing a black pirate hat." --workspace ablation_trial
!python main.py --workspace ablation_trial -O --test

/content/notebooks/stable-dreamfusion
  @torch.cuda.amp.autocast(enabled=False)
  @custom_fwd(cast_inputs=torch.float)
  @custom_bwd
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_bwd
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
  @torch.cuda.amp.autocast(enabled=False)
  @torch.cuda.amp.autocast(enabled=False)
  @torch.cuda.amp.autocast(enabled=False)
Namespace(file=None, text='A bulldog is wearing a black pirate hat.', negative='', O=True, O2=False, test=False, six_views=False, eval_interval=1, test_interval=100, workspace='ablation_trial', seed=None, image=None, image_config=None, known_view_interval=4, IF=False, guidance=['SD'], guidance_scale=100, save_mesh=False, mcubes_resolution=256, decimate_target=50000.0, dmtet=False, tet_grid_size=128,

In [None]:
## 2. Remove illuminated rendering
%cd /content/stable-dreamfusion/nerf

import torch
from os import path
import random
from utils import *
from utils import Trainer

# Adjust Trainer.train_step [function]

def train_step_without_textureless_illumination(self, data, save_guidance_path:Path=None):
    """Run one training step with textureless AND illuminated renders disabled (ablation 2).

    Intended as a drop-in replacement for ``Trainer.train_step``. Every
    novel-view render uses ``'albedo'`` shading: the early "latent" phase no
    longer renders normals as a latent code, and the random-shading phase no
    longer picks ``'textureless'`` or ``'lambertian'``. The known-view (RGBD)
    branch still renders with ``'lambertian'``, as that branch was left
    untouched by the ablation.

    Args:
        data: batch dict. This function reads ``rays_o``, ``rays_d``, ``mvp``,
            ``H`` and ``W``; the novel-view losses additionally read
            ``azimuth`` (and ``polar`` / ``radius`` for zero123 guidance).
            Replaced by ``self.default_view_data`` on known-view steps.
        save_guidance_path: forwarded to the guidance ``train_step``; an image
            that combines the NeRF render, the added latent noise, the
            denoised result and optionally the fully-denoised image.

    Returns:
        Tuple ``(pred_rgb, pred_depth, loss)``.
    """

    # perform RGBD loss instead of SDS if is image-conditioned
    do_rgbd_loss = self.opt.images is not None and \
        (self.global_step % self.opt.known_view_interval == 0)

    # override random camera with fixed known camera
    if do_rgbd_loss:
        data = self.default_view_data

    # experiment iterations ratio
    # i.e. what proportion of this experiment have we completed (in terms of iterations) so far?
    exp_iter_ratio = (self.global_step - self.opt.exp_start_iter) / (self.opt.exp_end_iter - self.opt.exp_start_iter)

    # progressively relaxing view range
    if self.opt.progressive_view:
        r = min(1.0, self.opt.progressive_view_init_ratio + 2.0*exp_iter_ratio)
        self.opt.phi_range = [self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[0] * r,
                              self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[1] * r]
        self.opt.theta_range = [self.opt.default_polar * (1 - r) + self.opt.full_theta_range[0] * r,
                                self.opt.default_polar * (1 - r) + self.opt.full_theta_range[1] * r]
        self.opt.radius_range = [self.opt.default_radius * (1 - r) + self.opt.full_radius_range[0] * r,
                                self.opt.default_radius * (1 - r) + self.opt.full_radius_range[1] * r]
        self.opt.fovy_range = [self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[0] * r,
                                self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[1] * r]

    # progressively increase max_level
    if self.opt.progressive_level:
        self.model.max_level = min(1.0, 0.25 + 2.0*exp_iter_ratio)

    rays_o = data['rays_o'] # [B, N, 3]
    rays_d = data['rays_d'] # [B, N, 3]
    mvp = data['mvp'] # [B, 4, 4]

    B, N = rays_o.shape[:2]
    H, W = data['H'], data['W']

    # When ref_data has B images > opt.batch_size
    if B > self.opt.batch_size:
        # choose batch_size images out of those B images
        choice = torch.randperm(B)[:self.opt.batch_size]
        B = self.opt.batch_size
        rays_o = rays_o[choice]
        rays_d = rays_d[choice]
        mvp = mvp[choice]

    if do_rgbd_loss:
        ambient_ratio = 1.0
        # Debug trace; can be removed if it is not needed.
        # The trailing escape now resets the terminal colour (it used to
        # re-emit the cyan code, leaving all later output coloured).
        print("\033[96m[INFO] Enter the do_rgb_loss branch!\033[0m")
        shading = 'lambertian' # use lambertian instead of albedo to get normal
        as_latent = False
        binarize = False
        bg_color = torch.rand((B * N, 3), device=rays_o.device)

        # add camera noise to avoid grid-like artifact
        if self.opt.known_view_noise_scale > 0:
            noise_scale = self.opt.known_view_noise_scale #* (1 - self.global_step / self.opt.iters)
            rays_o = rays_o + torch.randn(3, device=self.device) * noise_scale
            rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale

    elif exp_iter_ratio <= self.opt.latent_iter_ratio:
        # 2nd modification: remove illuminated rendering.
        # Render plain albedo RGB here instead of normals-as-latent.
        ambient_ratio = 1.0
        shading = 'albedo'
        # These three must be bound in every branch: `binarize` and `bg_color`
        # are passed to render() and `as_latent` is read right after it.
        # Leaving them commented out raised UnboundLocalError in this phase.
        as_latent = False   # the albedo image is RGB, not a latent code
        binarize = False
        bg_color = None     # same background setting this phase used before

    else:
        if exp_iter_ratio <= self.opt.albedo_iter_ratio:
            ambient_ratio = 1.0
            shading = 'albedo'
        else:
            # random shading
            ambient_ratio = self.opt.min_ambient_ratio + (1.0-self.opt.min_ambient_ratio) * random.random()
            # 2nd modification: remove illuminated rendering.
            # The textureless/lambertian coin flip is still drawn (and discarded)
            # so this step consumes the same number of `random` values as before;
            # the shading is always 'albedo'.
            random.random()
            shading = 'albedo'

        as_latent = False

        # random weights binarization (like mobile-nerf) [NOT WORKING NOW]
        binarize = False

        # random background
        rand = random.random()
        if self.opt.bg_radius > 0 and rand > 0.5:
            bg_color = None # use bg_net
        else:
            bg_color = torch.rand(3).to(self.device) # single color random bg

    outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, binarize=binarize)
    pred_depth = outputs['depth'].reshape(B, 1, H, W)
    pred_mask = outputs['weights_sum'].reshape(B, 1, H, W)
    if 'normal_image' in outputs:
        pred_normal = outputs['normal_image'].reshape(B, H, W, 3)

    if as_latent:
        # abuse normal & mask as latent code for faster geometry initialization (ref: fantasia3D)
        pred_rgb = torch.cat([outputs['image'], outputs['weights_sum'].unsqueeze(-1)], dim=-1).reshape(B, H, W, 4).permute(0, 3, 1, 2).contiguous() # [B, 4, H, W]
    else:
        pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]

    # known view loss
    if do_rgbd_loss:
        gt_mask = self.mask # [B, H, W]
        gt_rgb = self.rgb   # [B, 3, H, W]
        gt_normal = self.normal # [B, H, W, 3]
        gt_depth = self.depth   # [B, H, W]

        if len(gt_rgb) > self.opt.batch_size:
            gt_mask = gt_mask[choice]
            gt_rgb = gt_rgb[choice]
            gt_normal = gt_normal[choice]
            gt_depth = gt_depth[choice]

        # color loss: composite the reference image over the same random background used for the render
        gt_rgb = gt_rgb * gt_mask[:, None].float() + bg_color.reshape(B, H, W, 3).permute(0,3,1,2).contiguous() * (1 - gt_mask[:, None].float())
        loss = self.opt.lambda_rgb * F.mse_loss(pred_rgb, gt_rgb)

        # mask loss
        loss = loss + self.opt.lambda_mask * F.mse_loss(pred_mask[:, 0], gt_mask.float())

        # normal loss
        if self.opt.lambda_normal > 0 and 'normal_image' in outputs:
            valid_gt_normal = 1 - 2 * gt_normal[gt_mask] # [B, 3]
            valid_pred_normal = 2 * pred_normal[gt_mask] - 1 # [B, 3]

            lambda_normal = self.opt.lambda_normal * min(1, self.global_step / self.opt.iters)
            loss = loss + lambda_normal * (1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean())

        # relative depth loss
        if self.opt.lambda_depth > 0:
            valid_gt_depth = gt_depth[gt_mask] # [B,]
            valid_pred_depth = pred_depth[:, 0][gt_mask] # [B,]
            lambda_depth = self.opt.lambda_depth * min(1, self.global_step / self.opt.iters)
            loss = loss + lambda_depth * (1 - self.pearson(valid_pred_depth, valid_gt_depth))

    # novel view loss
    else:

        loss = 0

        if 'SD' in self.guidance:
            # interpolate text_z
            azimuth = data['azimuth'] # [-180, 180]

            # ENHANCE: remove loop to handle batch size > 1
            text_z = [self.embeddings['SD']['uncond']] * azimuth.shape[0]
            if self.opt.perpneg:

                text_z_comp, weights = adjust_text_embeddings(self.embeddings['SD'], azimuth, self.opt)
                text_z.append(text_z_comp)

            else:
                for b in range(azimuth.shape[0]):
                    if azimuth[b] >= -90 and azimuth[b] < 90:
                        if azimuth[b] >= 0:
                            r = 1 - azimuth[b] / 90
                        else:
                            r = 1 + azimuth[b] / 90
                        start_z = self.embeddings['SD']['front']
                        end_z = self.embeddings['SD']['side']
                    else:
                        if azimuth[b] >= 0:
                            r = 1 - (azimuth[b] - 90) / 90
                        else:
                            r = 1 + (azimuth[b] + 90) / 90
                        start_z = self.embeddings['SD']['side']
                        end_z = self.embeddings['SD']['back']
                    text_z.append(r * start_z + (1 - r) * end_z)

            text_z = torch.cat(text_z, dim=0)
            if self.opt.perpneg:
                loss = loss + self.guidance['SD'].train_step_perpneg(text_z, weights, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance,
                                                save_guidance_path=save_guidance_path)
            else:
                loss = loss + self.guidance['SD'].train_step(text_z, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance,
                                                                save_guidance_path=save_guidance_path)

        if 'IF' in self.guidance:
            # interpolate text_z
            azimuth = data['azimuth'] # [-180, 180]

            # ENHANCE: remove loop to handle batch size > 1
            text_z = [self.embeddings['IF']['uncond']] * azimuth.shape[0]
            if self.opt.perpneg:
                text_z_comp, weights = adjust_text_embeddings(self.embeddings['IF'], azimuth, self.opt)
                text_z.append(text_z_comp)
            else:
                for b in range(azimuth.shape[0]):
                    if azimuth[b] >= -90 and azimuth[b] < 90:
                        if azimuth[b] >= 0:
                            r = 1 - azimuth[b] / 90
                        else:
                            r = 1 + azimuth[b] / 90
                        start_z = self.embeddings['IF']['front']
                        end_z = self.embeddings['IF']['side']
                    else:
                        if azimuth[b] >= 0:
                            r = 1 - (azimuth[b] - 90) / 90
                        else:
                            r = 1 + (azimuth[b] + 90) / 90
                        start_z = self.embeddings['IF']['side']
                        end_z = self.embeddings['IF']['back']
                    text_z.append(r * start_z + (1 - r) * end_z)

            text_z = torch.cat(text_z, dim=0)

            if self.opt.perpneg:
                loss = loss + self.guidance['IF'].train_step_perpneg(text_z, weights, pred_rgb, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance)
            else:
                loss = loss + self.guidance['IF'].train_step(text_z, pred_rgb, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance)

        if 'zero123' in self.guidance:

            polar = data['polar']
            azimuth = data['azimuth']
            radius = data['radius']

            loss = loss + self.guidance['zero123'].train_step(self.embeddings['zero123']['default'], pred_rgb, polar, azimuth, radius, guidance_scale=self.opt.guidance_scale,
                                                              as_latent=as_latent, grad_scale=self.opt.lambda_guidance, save_guidance_path=save_guidance_path)

        if 'clip' in self.guidance:

            # Bind azimuth here as well: previously it was only assigned by the
            # SD / IF / zero123 branches, so CLIP-only guidance raised
            # UnboundLocalError on the next line.
            azimuth = data['azimuth']

            # empirical, far view should apply smaller CLIP loss
            lambda_guidance = 10 * (1 - abs(azimuth) / 180) * self.opt.lambda_guidance

            loss = loss + self.guidance['clip'].train_step(self.embeddings['clip'], pred_rgb, grad_scale=lambda_guidance)

    # regularizations
    if not self.opt.dmtet:

        if self.opt.lambda_opacity > 0:
            loss_opacity = (outputs['weights_sum'] ** 2).mean()
            loss = loss + self.opt.lambda_opacity * loss_opacity

        if self.opt.lambda_entropy > 0:
            # clamp away from 0 and 1 so both log2 terms stay finite
            alphas = outputs['weights'].clamp(1e-5, 1 - 1e-5)
            loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
            lambda_entropy = self.opt.lambda_entropy * min(1, 2 * self.global_step / self.opt.iters)
            loss = loss + lambda_entropy * loss_entropy

        if self.opt.lambda_2d_normal_smooth > 0 and 'normal_image' in outputs:
            # total-variation
            loss_smooth = (pred_normal[:, 1:, :, :] - pred_normal[:, :-1, :, :]).square().mean() + \
                          (pred_normal[:, :, 1:, :] - pred_normal[:, :, :-1, :]).square().mean()
            loss = loss + self.opt.lambda_2d_normal_smooth * loss_smooth

        if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
            loss_orient = outputs['loss_orient']
            loss = loss + self.opt.lambda_orient * loss_orient

        if self.opt.lambda_3d_normal_smooth > 0 and 'loss_normal_perturb' in outputs:
            loss_normal_perturb = outputs['loss_normal_perturb']
            loss = loss + self.opt.lambda_3d_normal_smooth * loss_normal_perturb

    else:

        if self.opt.lambda_mesh_normal > 0:
            loss = loss + self.opt.lambda_mesh_normal * outputs['normal_loss']

        if self.opt.lambda_mesh_laplacian > 0:
            loss = loss + self.opt.lambda_mesh_laplacian * outputs['lap_loss']

    return pred_rgb, pred_depth, loss


# Monkey patching: make the `Trainer` class imported in this notebook use the
# ablation variant defined above as its `train_step`.
# NOTE(review): this only changes the class object living in this kernel. The
# training cell that follows launches `!python main.py`, i.e. a separate
# process, which presumably imports its own unpatched Trainer - confirm that
# the ablation is really in effect for those runs.
Trainer.train_step = train_step_without_textureless_illumination

/content/notebooks/stable-dreamfusion/nerf


In [None]:
# Run main.py on the text prompt with workspace `ablation_trial2`, then run it
# again on the same workspace with `--test`.
# NOTE(review): `!python` starts a new interpreter process, so patches applied
# to `Trainer` inside this notebook kernel are presumably not visible to
# main.py - verify before interpreting the results as an ablation.
%cd "/content/notebooks/stable-dreamfusion"
!python main.py -O --text "A bulldog is wearing a black pirate hat." --workspace ablation_trial2
!python main.py --workspace ablation_trial2 -O --test

/content/notebooks/stable-dreamfusion
  @torch.cuda.amp.autocast(enabled=False)
  @custom_fwd(cast_inputs=torch.float)
  @custom_bwd
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_bwd
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
  @torch.cuda.amp.autocast(enabled=False)
  @torch.cuda.amp.autocast(enabled=False)
  @torch.cuda.amp.autocast(enabled=False)
Namespace(file=None, text='A bulldog is wearing a black pirate hat.', negative='', O=True, O2=False, test=False, six_views=False, eval_interval=1, test_interval=100, workspace='ablation_trial2', seed=None, image=None, image_config=None, known_view_interval=4, IF=False, guidance=['SD'], guidance_scale=100, save_mesh=False, mcubes_resolution=256, decimate_target=50000.0, dmtet=False, tet_grid_size=128

In [None]:
# 3. Disable view-dependent text prompts
%cd /content/notebooks/stable-dreamfusion/nerf

import torch
from os import path
import random
from utils import *
from utils import adjust_text_embeddings

def foo_adjust_text_embeddings(embeddings, azimuth, opt):
    """View-independent stand-in for ``adjust_text_embeddings`` (ablation 3).

    3rd modification: view-dependent text prompts are disabled by always
    querying ``get_pos_neg_text_embeddings`` with a fixed view of 0 instead of
    ``azimuth[b]``. ``azimuth`` is therefore only used for its batch size.

    Args:
        embeddings: passed through to ``get_pos_neg_text_embeddings``.
        azimuth: only ``azimuth.shape[0]`` (number of samples) is read.
        opt: passed through to ``get_pos_neg_text_embeddings``.

    Returns:
        ``(text_embeddings, weights)``: the per-sample embeddings and weights
        stacked along dim 0, interleaved so that entry ``i`` of every sample
        comes before entry ``i + 1`` of any sample.
    """
    num_samples = azimuth.shape[0]

    # One (embeddings, weights) pair per sample, all for the same fixed view 0.
    per_sample = [get_pos_neg_text_embeddings(embeddings, 0, opt) for _ in range(num_samples)]
    sample_embeds = [pair[0] for pair in per_sample]
    sample_weights = [pair[1] for pair in per_sample]

    # Largest number of weight entries any sample produced (0 for an empty batch).
    K = max((w.shape[0] for w in sample_weights), default=0)

    # Interleave embeddings: slot-major, sample-minor. A sample with fewer
    # entries than K is padded with its first embedding.
    text_embeddings = torch.stack(
        [emb[slot] if slot < len(emb) else emb[0]
         for slot in range(K)
         for emb in sample_embeds],
        dim=0,
    )

    # Interleave weights in the same order; padded slots get a zero weight.
    weights = torch.stack(
        [w[slot] if slot < len(w) else torch.zeros_like(w[0])
         for slot in range(K)
         for w in sample_weights],
        dim=0,
    )
    return text_embeddings, weights

# Monkey patching
# `adjust_text_embeddings` was brought in with `from utils import ...`, so
# rebinding the bare name only changes this notebook's namespace (which is what
# functions defined in this notebook look the name up in). Code defined inside
# `utils` resolves the name through the module's own globals, so the module
# attribute is patched as well.
# NOTE(review): both patches live in this kernel only; a `!python main.py`
# subprocess presumably does not see them - verify the ablation is in effect.
import utils as _nerf_utils
adjust_text_embeddings = foo_adjust_text_embeddings
_nerf_utils.adjust_text_embeddings = foo_adjust_text_embeddings

/content/notebooks/stable-dreamfusion/nerf


In [None]:
# Run main.py on the text prompt with workspace `ablation_trial3`, then run it
# again on the same workspace with `--test`.
# NOTE(review): `!python` starts a new interpreter process, so the
# `adjust_text_embeddings` replacement made inside this notebook kernel is
# presumably not visible to main.py - verify before interpreting the results
# as an ablation.
%cd "/content/notebooks/stable-dreamfusion"
!python main.py -O --text "A bulldog is wearing a black pirate hat." --workspace ablation_trial3
!python main.py --workspace ablation_trial3 -O --test

/content/notebooks/stable-dreamfusion
  @torch.cuda.amp.autocast(enabled=False)
  @custom_fwd(cast_inputs=torch.float)
  @custom_bwd
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32)
  @custom_bwd
  @custom_fwd(cast_inputs=torch.float32)
  @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
  @torch.cuda.amp.autocast(enabled=False)
  @torch.cuda.amp.autocast(enabled=False)
  @torch.cuda.amp.autocast(enabled=False)
Namespace(file=None, text='A bulldog is wearing a black pirate hat.', negative='', O=True, O2=False, test=False, six_views=False, eval_interval=1, test_interval=100, workspace='ablation_trial2', seed=None, image=None, image_config=None, known_view_interval=4, IF=False, guidance=['SD'], guidance_scale=100, save_mesh=False, mcubes_resolution=256, decimate_target=50000.0, dmtet=False, tet_grid_size=128