In [1]:
import os

os.chdir("/root/dev/triplora/")
os.getcwd()

'/root/dev/triplora'

In [2]:
import torch

torch.cuda.set_device(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
!gpustat

[1m[37maa4173e0a5f2              [m  Mon May 20 22:32:13 2024  [1m[30m535.129.03[m
[36m[0][m [34mNVIDIA GeForce RTX 4090[m |[31m 27°C[m, [32m  0 %[m | [36m[1m[33m    5[m / [33m24564[m MB |
[36m[1][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    5[m / [33m24564[m MB |
[36m[2][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    5[m / [33m24564[m MB |
[36m[3][m [34mNVIDIA GeForce RTX 4090[m |[31m 31°C[m, [32m  0 %[m | [36m[1m[33m    5[m / [33m24564[m MB |
[36m[4][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    5[m / [33m24564[m MB |
[36m[5][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    5[m / [33m24564[m MB |


# DataLoader

In [33]:
import numpy as np

In [34]:
grid_res = 128
mesh_scale = 1.1
fitted_tet_path = "/root/dataset_sj/DMTet/res_128/chair/tets_pre/dmt_dict_00000.pt"
init_tet_path = "/root/dev/DMTet_Models/MeshDiffusion/nvdiffrec/data_/tets/128_tets_cropped.npz"

In [35]:
tet_init = np.load(init_tet_path)
init_verts = torch.tensor(tet_init["vertices"], dtype=torch.float32, device=device) * mesh_scale
tet_indices = torch.tensor(tet_init["indices"], dtype=torch.long)

In [36]:
def get_deformed(verts, deform, grid_res=128, deform_scale=0.45, no_grad=False):
    """ deform_scale is 0.45 for resolution 128 and 2.0 for resolution 64 """
    deform = deform.detach() if no_grad else deform
    return verts + 2 / (grid_res * 2) * deform * deform_scale

## DatasetMesh from MeshDiffusion

In [37]:
!ls /root/dev/DMTet_Models/MeshDiffusion/data/03001627/1006be65e7bc937e9141f9b58470d646

model.mtl  model.obj  models


In [38]:
mesh_path = "/root/dev/DMTet_Models/MeshDiffusion/data/03001627/1006be65e7bc937e9141f9b58470d646/model.obj"

In [40]:
from utils.render.mesh import load_mesh
from utils.render import texture

mtl_default = {
    'name' : '_default_mat',
    'bsdf': 'diffuse',
    'uniform': True,
    'kd'   : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda'), trainable=False),
    'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), trainable=False)
}

ref_mesh = load_mesh(mesh_path, mtl_override=None, mtl_default=mtl_default, use_default=True, no_additional=True)
print(vars(ref_mesh).keys())

dict_keys(['v_pos', 'v_nrm', 'v_tex', 'v_tng', 't_pos_idx', 't_nrm_idx', 't_tex_idx', 't_tng_idx', 'material', 'f_nrm'])


In [41]:
import kaolin

sampled_pts, face_indices = kaolin.ops.mesh.sample_points(ref_mesh.v_pos.unsqueeze(0), ref_mesh.t_pos_idx, 50000)

In [42]:
sampled_pts.shape

torch.Size([1, 50000, 3])

In [43]:
# import trimesh

# scene = trimesh.Scene()
# sampled_pc = trimesh.PointCloud(sampled_pts[0].detach().cpu().numpy())
# sampled_pc.colors = [255, 0, 0, 0]
# scene.add_geometry(sampled_pc)
# scene.show()

In [44]:
print("max: ", sampled_pts.max())
print("min: ", sampled_pts.min())
print("mean: ", sampled_pts.mean())
print("mean: ", sampled_pts.std())

max:  tensor(0.3420, device='cuda:0')
min:  tensor(-0.3423, device='cuda:0')
mean:  tensor(0.0139, device='cuda:0')
mean:  tensor(0.1493, device='cuda:0')


In [45]:
def normalize_point_clouds(pcs: torch.Tensor, return_shift_scale=False, padding=0.0):
    # refactored version for batched processing
    pcs = pcs.clone()
    pc_max = pcs[..., :3].amax(dim=1, keepdim=True)  # (B, 1, 3)
    pc_min = pcs[..., :3].amin(dim=1, keepdim=True)
    shift = (pc_min + pc_max) / 2
    scale = 2 / (pc_max - pc_min).amax(dim=-1, keepdim=True) * (1 - padding)
    pcs[..., :3] = (pcs[..., :3] - shift) * scale

    if return_shift_scale:
        return pcs, shift, scale
    else:
        return pcs

In [46]:
torch.mean(sampled_pts, dim=0).shape

torch.Size([50000, 3])

In [47]:
sampled_pts_norm, shift, scale = normalize_point_clouds(sampled_pts, return_shift_scale=True, padding=0.1)
print("max: ", sampled_pts_norm.max())
print("min: ", sampled_pts_norm.min())
print("mean: ", sampled_pts_norm.mean())
print("std: ", sampled_pts_norm.std())

max:  tensor(0.9000, device='cuda:0')
min:  tensor(-0.9000, device='cuda:0')
mean:  tensor(0.0369, device='cuda:0')
std:  tensor(0.3927, device='cuda:0')


In [48]:
shift.shape

torch.Size([1, 1, 3])

In [49]:
def denormalize_point_clouds(pcs: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    pcs = pcs.clone()
    pcs[..., :3] = pcs[..., :3] / scale + shift
    return pcs

In [50]:
denormed_pts = denormalize_point_clouds(sampled_pts_norm, shift, scale)
print("max: ", denormed_pts.max())
print("min: ", denormed_pts.min())
print("mean: ", denormed_pts.mean())
print("std: ", denormed_pts.std())

max:  tensor(0.3420, device='cuda:0')
min:  tensor(-0.3423, device='cuda:0')
mean:  tensor(0.0139, device='cuda:0')
std:  tensor(0.1493, device='cuda:0')


In [51]:
!ls /root/dataset_sj/DMTet/data/irrmaps

README.txt  aerodynamics_workshop_2k.hdr  bsdf_256_256.bin


In [53]:
import time
import numpy as np

import nvdiffrast.torch as dr
import kaolin

from utils.render import light, util, render, mesh


glctx = dr.RasterizeCudaContext()
train_res = [1000, 1000]
fovy = np.deg2rad(45)
cam_near_far = [0.1, 1000.0]
cam_radius = 2.0
env_scale = 1.0
envlight = light.load_env("/root/dataset_sj/DMTet/data/irrmaps/aerodynamics_workshop_2k.hdr", env_scale)
layers = 1
flat_shading = False

def _random_scene(train_res, fovy, cam_near_far, cam_radius):
    # ==============================================================================================
    #  Setup projection matrix
    # ==============================================================================================
    proj_mtx = util.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1])

    # ==============================================================================================
    #  Random camera & light position
    # ==============================================================================================

    # Random rotation/translation matrix for optimization.
    mv     = util.translate(0, 0, -cam_radius) @ util.random_rotation_translation(0.2)
    mvp    = proj_mtx @ mv
    campos = torch.linalg.inv(mv)[:3, 3]

    return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), train_res # Add batch dimension


def get_item(glctx, ref_mesh, envlight, train_res, fovy, cam_near_far, cam_radius, layers=1, flat_shading=False, random_light=True, camera_light=False):
    mv, mvp, campos, train_res = _random_scene(train_res, fovy, cam_near_far, cam_radius)
    if random_light: # True
        rnd_rot = util.random_rotation()
        camera_mv = rnd_rot.unsqueeze(0).clone()
    elif camera_light: # False
        camera_mv = mv.clone()
    else:
        camera_mv = None
    
    with torch.no_grad():
        render_out = render.render_mesh(glctx, ref_mesh, mvp, campos, envlight, train_res, spp=1, 
                            num_layers=layers, msaa=True, background=None, xfm_lgt=camera_mv, flat_shading=flat_shading)
        img = render_out['shaded']
        img_second = render_out['shaded_second']
        normal = render_out['normal']
        depth = render_out['depth']
        geo_normal = render_out['geo_normal']
        pos = render_out['pos']
    
        sample_points = torch.tensor(kaolin.ops.mesh.sample_points(ref_mesh.v_pos.unsqueeze(0), ref_mesh.t_pos_idx, 50000)[0][0])
        vertex_points = ref_mesh.v_pos
    
    return_dict = {
        'mv' : mv,
        'mvp' : mvp,
        'campos' : campos,
        'resolution' : train_res,
        'spp' : 1, # from FLAGS
        'img' : img,
        'img_second' : img_second,
        'spts': sample_points,
        'vpts': vertex_points,
        'faces': ref_mesh.t_pos_idx,
        'depth': depth,
        'normal': normal,
        'geo_normal': geo_normal,
        'geo_viewdir': render_out['geo_viewdir'],
        'pos': pos,
        'envlight_transform': camera_mv,
        'mask': render_out['mask'],
        'mask_cont': render_out['mask_cont'],
        'rast_triangle_id': render_out['rast_triangle_id']
    }

start = time.time()
data = get_item(glctx, ref_mesh, envlight, train_res, fovy, cam_near_far, cam_radius)
end = time.time()
print(f"Time: {end - start} (s)")

Using /root/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
No modifications detected for re-loaded extension module renderutils_plugin, skipping build step...
Loading extension module renderutils_plugin...


Time: 0.1502225399017334 (s)


  sample_points = torch.tensor(kaolin.ops.mesh.sample_points(ref_mesh.v_pos.unsqueeze(0), ref_mesh.t_pos_idx, 50000)[0][0])


In [55]:
import csv
import json

from torch.utils.data import Dataset

from utils.render import mesh, texture, render, light, util

In [72]:
class BaseDataset(Dataset):
    """Basic dataset interface"""
    def __init__(self): 
        super().__init__()

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self):
        raise NotImplementedError

    def collate(self, batch):
        iter_res, iter_spp = batch[0]['resolution'], batch[0]['spp']
        res_dict = {
                'mv' : torch.cat(list([item['mv'] for item in batch]), dim=0),
                'mvp' : torch.cat(list([item['mvp'] for item in batch]), dim=0),
                'campos' : torch.cat(list([item['campos'] for item in batch]), dim=0),
                'resolution' : iter_res,
                'spp' : iter_spp,
                'img' : torch.cat(list([item['img'] for item in batch]), dim=0)
            }

        if 'sampled_pts' in batch[0]:
            res_dict['sampled_pts'] = batch[0]['sampled_pts']
        if 'vertex_pts' in batch[0]:
            res_dict['vertex_pts'] = batch[0]['vertex_pts']
        if 'faces' in batch[0]:
            res_dict['faces'] = batch[0]['faces']
        if 'rast_triangle_id' in batch[0]:
            res_dict['rast_triangle_id'] = batch[0]['rast_triangle_id']
        
        if 'depth' in batch[0]:
            res_dict['depth'] = torch.cat(list([item['depth'] for item in batch]), dim=0)
        if 'normal' in batch[0]:
            res_dict['normal'] = torch.cat(list([item['normal'] for item in batch]), dim=0)
        if 'geo_normal' in batch[0]:
            res_dict['geo_normal'] = torch.cat(list([item['geo_normal'] for item in batch]), dim=0)
        if 'geo_viewdir' in batch[0]:
            res_dict['geo_viewdir'] = torch.cat(list([item['geo_viewdir'] for item in batch]), dim=0)
        if 'pos' in batch[0]:
            res_dict['pos'] = torch.cat(list([item['pos'] for item in batch]), dim=0)
        if 'mask' in batch[0]:
            res_dict['mask'] = torch.cat(list([item['mask'] for item in batch]), dim=0)
        if 'mask_cont' in batch[0]:
            res_dict['mask_cont'] = torch.cat(list([item['mask_cont'] for item in batch]), dim=0)
        if 'envlight_transform' in batch[0]:
            if batch[0]['envlight_transform'] is not None:
                res_dict['envlight_transform'] = torch.cat(list([item['envlight_transform'] for item in batch]), dim=0)
            else:
                res_dict['envlight_transform'] = None

        try:
            res_dict['depth_second'] = torch.cat(list([item['depth_second'] for item in batch]), dim=0)
        except:
            pass
        try:
            res_dict['normal_second'] = torch.cat(list([item['normal_second'] for item in batch]), dim=0)
        except:
            pass
        try:
            res_dict['img_second'] = torch.cat(list([item['img_second'] for item in batch]), dim=0)
        except:
            pass


        return res_dict


class ShapeNetDataset(BaseDataset):
    def __init__(
        self,
        glctx,
        data_root,
        mesh_data_root,
        num_pts,
        rendering_kwargs,
        categories=["chair", "table"],
        add_camera_cond=False,
        split="train", # ["train", "val", "test"]
        split_type="text_cond" # ["text_cond", "uncond"]
    ):
        assert split in ["train", "val", "test"] and split_type in ["text_cond", "uncond"]
        
        if categories == "all":
            categories = ["airplane", "car", "chair", "table"] if split_type == "uncond" else ["chair", "table"]
        
        self.data_root = data_root
        self.mesh_data_root = mesh_data_root
        self.rendering_kwargs = rendering_kwargs
        self.num_pts = num_pts
        
        split_root = os.path.join(data_root, "splits")
        self.split = split
        self.split_type = split_type
        
        self.glctx = glctx
        self.fovy = np.deg2rad(rendering_kwargs.fovy)
        self.add_camera_cond = add_camera_cond
        self.envlight = light.load_env(os.path.join(data_root, "irrmaps", "aerodynamics_workshop_2k.hdr"), scale=rendering_kwargs.env_scale)
        if rendering_kwargs.mtl_type == "default":
            self.mtl = {
                'name' : '_default_mat',
                'bsdf': 'diffuse',
                'uniform': True,
                'kd'   : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda'), trainable=False),
                'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), trainable=False)
            }
        else:
            raise NotImplementedError

        # Get text info
        self.text_data = None
        if split_type == "text_cond":
            text_path = os.path.join(data_root, "text", "captions.tablechair.csv")
            self.text_data = {}
            with open(text_path, "r") as file:
                text_data = csv.DictReader(file)
                for col in text_data:
                    self.text_data[col["modelId"]] = col["description"]

        # Get meshes info according to the split
        split_model_ids = []
        self.mesh_paths, self.camera_paths, self.img_paths = [], [], []
        self.cat_ids, self.model_ids = [], []
        for cat, cat_id in self.get_category_dict().items():
            if cat in categories:
                with open(os.path.join(data_root, f"shapenet_v1_{cat_id}.json"), "r") as file:
                    mesh_paths_per_cat = json.load(file)
                
                if split_type == "text_cond":
                    with open(os.path.join(split_root, split_type, f"{cat_id}_{split}.lst"), "r") as file:
                        split_model_ids = [line.strip() for line in file]
                elif split_type == "uncond":
                    with open(os.path.join(split_root, split_type, f"shapenet_{cat}", f"{split}.txt"), "r") as file:
                        split_model_ids = [line.strip() for line in file]

                for path in mesh_paths_per_cat:
                    model_id = path.split("/")[-2]
                    if split_type == "text_cond":
                        if model_id in split_model_ids and model_id in self.text_data:
                            self.mesh_paths.append(os.path.join(mesh_data_root, path))
                            # self.camera_paths.append(os.path.join(data_root, "rendered_shapenet_v1", "camera", cat_id, model_id)) # TODO: check
                            # self.img_paths.append(os.path.join(data_root, "rendered_shapenet_v1", "img", cat_id, model_id)) # TODO: check
                            self.cat_ids.append(cat_id)
                            self.model_ids.append(model_id)
                            
                    elif split_type == "uncond":
                        if model_id in split_model_ids:
                            self.mesh_paths.append(os.path.join(mesh_data_root, path))
                            # self.camera_paths.append(os.path.join(data_root, "rendered_shapenet_v1", "camera", cat_id, model_id)) # TODO: check
                            # self.img_paths.append(os.path.join(data_root, "rendered_shapenet_v1", "img", cat_id, model_id)) # TODO: check
                            self.cat_ids.append(cat_id)
                            self.model_ids.append(model_id)


    def get_category_dict(self):
        # should be updated if additional categories used
        return {
            "chair": "03001627", # 0
            "table": "04379243", # 1
            "airplane": "02691156", # 3
            "car": "02958343" # 4
        }


    def normalize_point_clouds(self, pcs: torch.Tensor, return_shift_scale=True, padding=0.0):
        # refactored version for batched processing
        pcs = pcs.clone()
        pc_max = pcs[..., :3].amax(dim=1, keepdim=True)  # (B, 1, 3)
        pc_min = pcs[..., :3].amin(dim=1, keepdim=True)
        shift = (pc_min + pc_max) / 2
        scale = 2 / (pc_max - pc_min).amax(dim=-1, keepdim=True) * (1 - padding)
        pcs[..., :3] = (pcs[..., :3] - shift) * scale

        if return_shift_scale:
            return pcs, shift, scale
        else:
            return pcs
    
    
    def denormalize_point_clouds(pcs: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
        pcs = pcs.clone()
        pcs[..., :3] = pcs[..., :3] / scale + shift
        return pcs
    
    
    def _rotate_scene(self, idx):
        proj_mtx = util.perspective(self.fovy, self.rendering_kwargs.render_res[1] / self.rendering_kwargs.render_res[0],
                                    self.rendering_kwargs.cam_near_far[0], self.rendering_kwargs.cam_near_far[1])

        # Smooth rotation for display.
        ang    = (idx / 50) * np.pi * 2
        mv     = util.translate(0, 0, -self.rendering_kwargs.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
        mvp    = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda()

    def _random_scene(self):
        # Setup projection matrix
        proj_mtx = util.perspective(self.fovy, self.rendering_kwargs.render_res[1] / self.rendering_kwargs.render_res[0],
                                    self.rendering_kwargs.cam_near_far[0], self.rendering_kwargs.cam_near_far[1])

        # Random camera & light position
        # Random rotation/translation matrix for optimization.
        mv     = util.translate(0, 0, -self.rendering_kwargs.cam_radius) @ util.random_rotation_translation(0.2)
        mvp    = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda() # Add batch dimension


    def __len__(self):
        return len(self.mesh_paths)
    
    
    def __getitem__(self, idx):
        self.ref_mesh = mesh.load_mesh(self.mesh_paths[idx], None, self.mtl,
                                       use_default=self.rendering_kwargs.normal_only, no_additional=True)
        self.ref_mesh = mesh.center_by_reference(self.ref_mesh, mesh.aabb_clean(self.ref_mesh), 1.)
        self.ref_mesh = mesh.auto_normals(self.ref_mesh)
        
        resolution = self.rendering_kwargs.render_res
        spp = self.rendering_kwargs.spp
        if self.split == "val" or "test":
            mv, mvp, campos = self._rotate_scene(idx)
            camera_mv = None
        else:
            mv, mvp, campos = self._random_scene()
            if self.rendering_kwargs.light_type == "random":
                rnd_rot = util.random_rotation()
                camera_mv = rnd_rot.unsqueeze(0).clone()
            elif self.rendering_kwargs.light_type == "camera":
                camera_mv = mv.clone()
            else:
                raise NotImplementedError() # camera_mv = None
        
        with torch.no_grad():
            render_out = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, resolution, spp,
                                            num_layers=self.rendering_kwargs.layers, msaa=True, background=None,
                                            xfm_lgt=camera_mv, flat_shading=self.rendering_kwargs.flat_shading)
            
            sampled_pts = kaolin.ops.mesh.sample_points(self.ref_mesh.v_pos.unsqueeze(0), self.ref_mesh.t_pos_idx, self.num_pts)[0][0]
        
        # TODO: remove unused variables
        return_dict = {
            "mv": mv,
            "mvp": mvp,
            "campos": campos,
            "resolution": resolution,
            "spp": spp,
            "envlight_transform": camera_mv,
            "img": render_out["shaded"],
            "img_second": render_out["shaded_second"],
            "depth": render_out["depth"],
            "normal": render_out["normal"],
            "geo_normal": render_out["geo_normal"],
            "geo_viewdir": render_out["geo_viewdir"],
            "pos": render_out["pos"],
            "mask": render_out["mask"],
            "mask_cont": render_out["mask_cont"],
            "rast_triangle_id": render_out["rast_triangle_id"],
            "sampled_pts": sampled_pts,
            "vertex_pts": self.ref_mesh.v_pos,
            "faces": self.ref_mesh.t_pos_idx,
            "cat_id": self.cat_ids[idx],
            "model_id": self.model_ids[idx]
        }
        
        if render_out["depth_second"] is not None:
            return_dict["depth_second"] = render_out["depth_second"]
        
        if render_out["normal_second"] is not None:
            return_dict["normal_second"] = render_out["normal_second"]
        
        return return_dict

In [73]:
!ls /root/dev/DMTet_Models/MeshDiffusion/data

02691156  02958343  03001627  04090263	04379243


In [74]:
!ls /root/dev/DMTet_Models/Meshdiffusion/data/03001627/

ls: cannot access '/root/dev/DMTet_Models/Meshdiffusion/data/03001627/': No such file or directory


In [75]:
!ls /root/dev/triplora/datasets/data

irrmaps			   shapenet_v1_03790512.json  tets
shapenet_v1_02691156.json  shapenet_v1_04090263.json  text
shapenet_v1_02958343.json  shapenet_v1_04379243.json
shapenet_v1_03001627.json  splits


In [85]:
from easydict import EasyDict

data_root = "/root/dev/triplora/datasets/data"
mesh_data_root = "/root/dev/DMTet_Models/MeshDiffusion/data"
num_pts = 20480
rendering_kwargs = EasyDict({
    "env_scale": 1.0,
    "fovy": 45,
    "cam_radius": 2.0,
    "mtl_type": "default",
    "spp": 1,
    "render_res": [1000, 1000],
    "cam_near_far": [0.1, 1000.0],
    "normal_only": True,
    "flat_shading": False,
    "light_type": "random",
    "layers": 1
})

ds_train = ShapeNetDataset(glctx, data_root, mesh_data_root, num_pts, rendering_kwargs,
                           categories=["chair"], split="train", split_type="uncond")
ds_val = ShapeNetDataset(glctx, data_root, mesh_data_root, num_pts, rendering_kwargs,
                        categories=["chair"], split="val", split_type="uncond")
ds_test = ShapeNetDataset(glctx, data_root, mesh_data_root, num_pts, rendering_kwargs,
                          categories=["chair"], split="val", split_type="uncond")

ds_train = ShapeNetDataset(glctx, data_root, mesh_data_root, num_pts, rendering_kwargs,
                           categories=["chair"], split="train", split_type="text_cond")
ds_test = ShapeNetDataset(glctx, data_root, mesh_data_root, num_pts, rendering_kwargs,
                          categories=["chair"], split="test", split_type="text_cond")

print(len(ds_train) + len(ds_test))

6577


In [77]:
from typing import Sequence, Union

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

def build_dataloaders(
    batch_size: int,
    num_workers: int,
    ds: Dataset,
    ddp=None,
    shuffle=False,
    pin_memory=None,
    persistent_workers=None,
    **kwargs,
) -> Union[DataLoader, Sequence[DataLoader]]:
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    if persistent_workers is None:
        persistent_workers = num_workers > 0
    dl_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
    )
    dl_kwargs.update(kwargs)

    if ddp is None:
        ddp = dist.is_initialized() and dist.get_world_size() > 1

    if ddp:
        sampler = DistributedSampler(ds, shuffle=shuffle)
        dl = DataLoader(ds, sampler=sampler, **dl_kwargs)
    else:
        dl = DataLoader(ds, shuffle=shuffle, **dl_kwargs)

    return dl

In [78]:
dl_train = DataLoader(ds_train, batch_size=8, num_workers=0, collate_fn=ds_train.collate, shuffle=True)

In [79]:
data = next(iter(dl_train))

In [80]:
data.keys()

dict_keys(['mv', 'mvp', 'campos', 'resolution', 'spp', 'img', 'sampled_pts', 'vertex_pts', 'faces', 'rast_triangle_id', 'depth', 'normal', 'geo_normal', 'geo_viewdir', 'pos', 'mask', 'mask_cont', 'envlight_transform', 'depth_second', 'normal_second', 'img_second'])

In [81]:
print(data["img"].shape)
print(data["img_second"].shape)
print(data["depth"].shape)
print(data["depth_second"].shape)
print(data["sampled_pts"].shape)
print(data["vertex_pts"].shape)
print(data["faces"].shape)

torch.Size([8, 1000, 1000, 4])
torch.Size([8, 1000, 1000, 4])
torch.Size([8, 1000, 1000, 2])
torch.Size([8, 1000, 1000, 2])
torch.Size([20480, 3])
torch.Size([324, 3])
torch.Size([1258, 3])


## Data loading of MeshDiffusion and GET3D

In [112]:
# import trimesh

# scene = trimesh.Scene()
# scene.add_geometry(trimesh.load_mesh(mesh_path))
# scene.show()

In [None]:
from MeshDiffusion.nvdiffrec.lib.render import render

with torch.no_grad():
    render_out = render.render_mesh(glctx, ref_mesh, )

# DMTet related classes

In [37]:
# dmtet_utils.py!!
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch


def get_center_boundary_index(verts):
    length_ = torch.sum(verts ** 2, dim=-1)
    center_idx = torch.argmin(length_)
    boundary_neg = verts == verts.max()
    boundary_pos = verts == verts.min()
    boundary = torch.bitwise_or(boundary_pos, boundary_neg)
    boundary = torch.sum(boundary.float(), dim=-1)
    boundary_idx = torch.nonzero(boundary)
    return center_idx, boundary_idx.squeeze(dim=-1)

In [14]:
!ls /root/dev/DMTet_Models/MeshDiffusion/nvdiffrec/data_/tets

128_tets_cropped.npz  README.md     generate_tets.py
64_tets_cropped.npz   crop_tets.py


In [15]:
!ls /root/dev/DMTet_Models/GET3D/tets

100_compress.npz  70_compress.npz  90_compress.npz
64_compress.npz   80_compress.npz


In [16]:
import numpy as np

In [21]:
tets_meshdiffusion = np.load("./MeshDiffusion/nvdiffrec/data_/tets/64_tets_cropped.npz")
tets_get3d = np.load("./GET3D/tets/64_compress.npz")

print(tets_meshdiffusion["vertices"].shape)
print(tets_meshdiffusion["indices"].shape)
print()
print(tets_get3d["vertices"].shape)
print(tets_get3d["tets"].shape)

(30512, 3)
(159330, 4)

(36562, 3)
(192492, 4)


In [35]:
import torch
from torch import nn

class DMTet:
    def __init__(self, device="cuda"):
        self.triangle_table = torch.tensor([
                [-1, -1, -1, -1, -1, -1],
                [ 1,  0,  2, -1, -1, -1],
                [ 4,  0,  3, -1, -1, -1],
                [ 1,  4,  2,  1,  3,  4],
                [ 3,  1,  5, -1, -1, -1],
                [ 2,  3,  0,  2,  5,  3],
                [ 1,  4,  0,  1,  5,  4],
                [ 4,  2,  5, -1, -1, -1],
                [ 4,  5,  2, -1, -1, -1],
                [ 4,  1,  0,  4,  5,  1],
                [ 3,  2,  0,  3,  5,  2],
                [ 1,  3,  5, -1, -1, -1],
                [ 4,  1,  2,  4,  3,  1],
                [ 3,  0,  4, -1, -1, -1],
                [ 2,  0,  1, -1, -1, -1],
                [-1, -1, -1, -1, -1, -1]
                ], dtype=torch.long, device=device)

        self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device=device)
        self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=device)
        self.v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))

        self.tet_table = torch.tensor(
            [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
             [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1],
             [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1],
             [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8],
             [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1],
             [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9],
             [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9],
             [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9],
             [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1],
             [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9],
             [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9],
             [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9],
             [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8],
             [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8],
             [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6],
             [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device)
        self.num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device)

    ###############################################################################
    # Utility functions # NOTE → needed?
    ###############################################################################

    def sort_edges(self, edges_ex2):
        with torch.no_grad():
            order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)      
            b = torch.gather(input=edges_ex2, index=1-order, dim=1)  

        return torch.stack([a, b],-1)

    def map_uv(self, faces, face_gidx, max_idx):
        N = int(np.ceil(np.sqrt((max_idx+1)//2)))
        tex_y, tex_x = torch.meshgrid(
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=self.device),
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=self.device),
            indexing='ij'
        )

        pad = 0.9 / N

        uvs = torch.stack([
            tex_x      , tex_y,
            tex_x + pad, tex_y,
            tex_x + pad, tex_y + pad,
            tex_x      , tex_y + pad
        ], dim=-1).view(-1, 2)

        def _idx(tet_idx, N):
            x = tet_idx % N
            y = torch.div(tet_idx, N, rounding_mode='trunc')
            return y * N + x

        tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
        tri_idx = face_gidx % 2

        uv_idx = torch.stack((
            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
        ), dim = -1). view(-1, 3)

        return uvs, uv_idx

    ###############################################################################
    # Marching tets implementation
    ###############################################################################

    def __call__(self, pos_nx3, sdf_n, tet_fx4, return_tet_mesh=False, ori_v=None):
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)  
            
            unique_edges = unique_edges.long()
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=self.device)
            idx_map = mapping[idx_map] # map edges to verts

            interp_v = unique_edges[mask_edges]
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        tetindex = (occ_fx4[valid_tets] * self.v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat((
            torch.gather(input=idx_map[num_triangles==1], dim=1, index=self.triangle_table[tetindex[num_triangles==1]][:, :3]).reshape(-1, 3),
            torch.gather(input=idx_map[num_triangles==2], dim=1, index=self.triangle_table[tetindex[num_triangles==2]][:, :6]).reshape(-1, 3),
        ), dim=0)

        # # Get global face index (static, does not depend on topology)
        # num_tets = tet_fx4.shape[0]
        # tet_gidx = torch.arange(num_tets, dtype=torch.long, device=self.device)[valid_tets]
        # face_gidx = torch.cat((
        #     tet_gidx[num_triangles == 1]*2,
        #     torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
        # ), dim=0)

        # uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2)

        # face_to_valid_tet = torch.cat((
        #     tet_gidx[num_triangles == 1],
        #     torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1)
        # ), dim=0)

        # valid_vert_idx = tet_fx4[tet_gidx[num_triangles > 0]].long().unique()
        
        if not return_tet_mesh:
            return verts, faces # , uvs, uv_idx, face_to_valid_tet.long(), valid_vert_idx

        occupied_verts = ori_v[occ_n]
        mapping = torch.ones(pos_nx3.shape[0], dtype=torch.long, device=self.device) * -1
        mapping[occ_n] = torch.arange(occupied_verts.shape[0], device=self.device)
        tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape(-1, 4)

        idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10
        tet_verts = torch.cat([verts, occupied_verts], 0)
        num_tets = self.num_tets_table[tetindex]

        tets = torch.cat((
            torch.gather(input=idx_map[num_tets==1], dim=1, index=self.tet_table[tetindex[num_tets==1]][:, :4]).reshape(-1, 4),
            torch.gather(input=idx_map[num_tets==3], dim=1, index=self.tet_table[tetindex[num_tets==3]][:, :12]).reshape(-1, 4)
        ), dim=0)

        # Add fully occupied tets
        fully_occupied = occ_fx4.sum(-1) == 4
        tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0]
        tets = torch.cat([tets, tet_fully_occupied])

        return verts, faces, tet_verts, tets


class DMTetGeometry(nn.Module):
    def __init__(
        self,
        grid_res,
        scale,
        data_root = "./",
        renderer = None,
        render_type = "neural_render",
        device = "cuda",
        **kwargs
    ):
        super().__init__()

        self.device = device
        self.grid_res = grid_res
        self.dmtet = DMTet()

        init_tets = np.load(os.path.join(data_root, f"tets/{grid_res}_tets_cropped.npz"))
        self.verts = torch.tensor(init_tets["vertices"], dtype=torch.float32, device=device)
        length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0]
        
        if isinstance(scale, list):
            # TODO: check (GET3D: scale[1])
            self.verts[:, 0] = self.verts[:, 0] * scale[0]
            self.verts[:, 1] = self.verts[:, 1] * scale[1]
            self.verts[:, 2] = self.verts[:, 2] * scale[2]
        else:
            self.verts = self.verts * scale
        self.indices = torch.tensor(init_tets["indices"], dtype=torch.long, device=device)

        # Generate edges
        edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
        all_edges = self.indices[:, edges].reshape(-1, 2)
        all_edges_sorted = torch.sort(all_edges, dim=1)[0]
        self.all_edges = torch.unique(all_edges_sorted, dim=0)
        
        # Random init
        sdf = torch.rand_like(self.verts[:, 0]).clamp(-1., 1.) - 0.1
        self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
        self.register_parameter("sdf", self.sdf)

        # Parameters used for fix boundary sdf
        self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts)
        self.renderer = renderer
        self.render_type = render_type


    def getAABB(self):
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values


    def get_mesh(self, v_deformed, sdf, with_uv=False, indices=None):
        if indices is None:
            indices = self.indices

        verts, faces = self.dmtet(v_deformed, sdf, indices)
        faces = torch.cat([
            faces[:, 0:1], faces[:, 2:3], faces[:, 1:2]
        ], dim=-1)
        return verts, faces

    def get_tet_mesh(self, v_deformed, sdf, with_uv=False, indices=None):
        if indices is None:
            indices = self.indices

        verts, faces, tet_verts, tets = self.dmtet(v_deformed, sdf, indices, return_tet_mesh=True, ori_v=v_deformed)
        faces = torch.cat([
            faces[:, 0:1], faces[:, 2:3], faces[:, 1:2]
        ], dim=-1)
        return verts, faces, tet_verts, tets


    def render_mesh(self, mesh_verts, mesh_faces, camera_mv, resolution=256, hierarchical_mask=False):
        """ Add description """
        out_dict = dict()

        if self.render_type == "neural_render":
            tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
                mesh_verts.unsqueeze(0),
                mesh_faces.int(),
                camera_mv,
                mesh_verts.unsqueeze(0),
                resolution=resolution,
                device=self.device,
                hierarchical_mask=hierarchical_mask
            )

            out_dict["tex_pos"] = tex_pos
            out_dict["mask"] = mask
            out_dict["hard_mask"] = hard_mask
            out_dict["rast"] = rast
            out_dict["v_pos_clip"] = v_pos_clip
            out_dict["mask_pyramid"] = mask_pyramid
            out_dict["depth"] = depth
        else:
            raise NotImplementedError
        
        return out_dict

    def render(self, v_deformed=None, sdf=None, camera_mv=None, resolution=256):
        """ Add description 
        Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1. """
        verts_list, faces_list, all_render_output, batch_size = [], [], [], v_deformed.shape[0]
        for i in range(batch_size):
            verts, faces = self.get_mesh(v_deformed[i], sdf[i]) # (N,3), (F,3)
            verts_list.append(verts)
            faces_list.append(faces)
            render_output = self.render_mesh(verts, faces, camera_mv[i], resolution)
            all_render_output.append(render_output)

        # Concat all render outputs
        out_keys = all_render_output[0].keys()
        out_dict = dict()
        for k in out_keys:
            value = [v[k] for v in all_render_output]
            out_dict[k] = value
            # We can do concatenation outside of the render
        return out_dict

# TODO: add useful functions from meshdiffusion
# NOTE: Consider deform_scale!!!!!!

# Applying LoRA

In [12]:
# !export HF_HOME=~/.cache/huggingface 

In [4]:
import peft
import transformers
import diffusers

In [5]:
def model_params(model):
    model_size = 0
    for param in model.parameters():
        model_size += param.data.nelement()
    return model_size

In [6]:
from diffusers import UNet2DConditionModel

# unet = UNet2DConditionModel.from_pretrained(
#     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
# )
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-base",
    subfolder="unet"
    # cache_dir="/root/dataset_sj/.cache/huggingface/models/cuda/stable-diffusion-2-base/unet"
)

In [7]:
model_params(unet) / 10e6

86.5910724

In [8]:
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-base",
    subfolder="vae"
).to(device)

In [9]:
model_params(vae) / 10e6

8.3653863

In [10]:
vars(vae.encoder).keys()

dict_keys(['training', '_parameters', '_buffers', '_non_persistent_buffers_set', '_backward_pre_hooks', '_backward_hooks', '_is_full_backward_hook', '_forward_hooks', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_state_dict_hooks', '_state_dict_pre_hooks', '_load_state_dict_pre_hooks', '_load_state_dict_post_hooks', '_modules', 'layers_per_block', 'gradient_checkpointing'])

In [11]:
vae.encoder(torch.rand((8, 3, 256, 256)).to(device)).shape

torch.Size([8, 8, 32, 32])

In [57]:
vae.decoder(torch.rand((8, 4, 64, 64))).shape

torch.Size([8, 3, 512, 512])

## AutoencoderKL

In [None]:
from models.TripLoRA.triplane_ae import AutoencoderTriplane



## UnetTriplaneModel

In [25]:
from models.TripLoRA.triplane.triplane_unet import UNetTriplaneModel

model = UNetTriplaneModel(
    triplane_res=256,
    in_channels=32,
    out_channels=2*4,
    down_block_types=("DownBlockTriplane", "ResnetDownsampleBlockTriplane"),
    up_block_types=("UpBlockTriplane", "ResnetUpsampleBlockTriplane"),
    block_out_channels=(32, 64),
    layers_per_block=1
)

In [26]:
model_params(model) / 10e6

0.5117784

In [27]:
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
print(len(expected_keys))

252


In [28]:
pretrained_state_dict = unet.state_dict()
loaded_keys = list(pretrained_state_dict.keys())
print(len(loaded_keys))

686


In [29]:
expected_keys[:20]

['conv_in.conv_xy.weight',
 'conv_in.conv_xy.bias',
 'conv_in.conv_xz.weight',
 'conv_in.conv_xz.bias',
 'conv_in.conv_yz.weight',
 'conv_in.conv_yz.bias',
 'down_blocks.0.resnets.0.norm1.norm_xy.weight',
 'down_blocks.0.resnets.0.norm1.norm_xy.bias',
 'down_blocks.0.resnets.0.norm1.norm_xz.weight',
 'down_blocks.0.resnets.0.norm1.norm_xz.bias',
 'down_blocks.0.resnets.0.norm1.norm_yz.weight',
 'down_blocks.0.resnets.0.norm1.norm_yz.bias',
 'down_blocks.0.resnets.0.conv1.conv_xy.weight',
 'down_blocks.0.resnets.0.conv1.conv_xy.bias',
 'down_blocks.0.resnets.0.conv1.conv_xz.weight',
 'down_blocks.0.resnets.0.conv1.conv_xz.bias',
 'down_blocks.0.resnets.0.conv1.conv_yz.weight',
 'down_blocks.0.resnets.0.conv1.conv_yz.bias',
 'down_blocks.0.resnets.0.norm2.norm_xy.weight',
 'down_blocks.0.resnets.0.norm2.norm_xy.bias']

In [30]:
loaded_keys[57:80]

['down_blocks.0.attentions.1.proj_out.bias',
 'down_blocks.0.resnets.0.norm1.weight',
 'down_blocks.0.resnets.0.norm1.bias',
 'down_blocks.0.resnets.0.conv1.weight',
 'down_blocks.0.resnets.0.conv1.bias',
 'down_blocks.0.resnets.0.time_emb_proj.weight',
 'down_blocks.0.resnets.0.time_emb_proj.bias',
 'down_blocks.0.resnets.0.norm2.weight',
 'down_blocks.0.resnets.0.norm2.bias',
 'down_blocks.0.resnets.0.conv2.weight',
 'down_blocks.0.resnets.0.conv2.bias',
 'down_blocks.0.resnets.1.norm1.weight',
 'down_blocks.0.resnets.1.norm1.bias',
 'down_blocks.0.resnets.1.conv1.weight',
 'down_blocks.0.resnets.1.conv1.bias',
 'down_blocks.0.resnets.1.time_emb_proj.weight',
 'down_blocks.0.resnets.1.time_emb_proj.bias',
 'down_blocks.0.resnets.1.norm2.weight',
 'down_blocks.0.resnets.1.norm2.bias',
 'down_blocks.0.resnets.1.conv2.weight',
 'down_blocks.0.resnets.1.conv2.bias',
 'down_blocks.0.downsamplers.0.conv.weight',
 'down_blocks.0.downsamplers.0.conv.bias']

In [31]:
cnt = 0
new_state_dict = {}
for key in model_state_dict.keys():
    # Transform the key to match the pre-trained model format
    # ex. down_blocks.0.resnets.0.norm1.norm_xy.weight -> down_blocks.0.resnets.0.norm1.weight
    mapped_key = key.replace(".norm_xy.", ".").replace(".norm_xz.", ".").replace(".norm_yz.", ".")
    mapped_key = mapped_key.replace(".conv_xy.", ".").replace(".conv_xz.", ".").replace(".conv_yz.", ".")
    
    # Check if this transformed key exists in the pre-trained model's keys
    if mapped_key in pretrained_state_dict:
        # Map the pre-trained weights to the new model key
        new_state_dict[key] = pretrained_state_dict[mapped_key]
        print("loaded pretrained weights: ", key)
        cnt += 1
    else:
        new_state_dict[key] = model_state_dict[key]
        print("not loaded: ", key)

loaded pretrained weights:  conv_in.conv_xy.weight
loaded pretrained weights:  conv_in.conv_xy.bias
loaded pretrained weights:  conv_in.conv_xz.weight
loaded pretrained weights:  conv_in.conv_xz.bias
loaded pretrained weights:  conv_in.conv_yz.weight
loaded pretrained weights:  conv_in.conv_yz.bias
loaded pretrained weights:  down_blocks.0.resnets.0.norm1.norm_xy.weight
loaded pretrained weights:  down_blocks.0.resnets.0.norm1.norm_xy.bias
loaded pretrained weights:  down_blocks.0.resnets.0.norm1.norm_xz.weight
loaded pretrained weights:  down_blocks.0.resnets.0.norm1.norm_xz.bias
loaded pretrained weights:  down_blocks.0.resnets.0.norm1.norm_yz.weight
loaded pretrained weights:  down_blocks.0.resnets.0.norm1.norm_yz.bias
loaded pretrained weights:  down_blocks.0.resnets.0.conv1.conv_xy.weight
loaded pretrained weights:  down_blocks.0.resnets.0.conv1.conv_xy.bias
loaded pretrained weights:  down_blocks.0.resnets.0.conv1.conv_xz.weight
loaded pretrained weights:  down_blocks.0.resnets.0

In [32]:
cnt

252

- all params is loaded from the pretrained model

In [33]:
trainable_params = 0
all_params = 0
for _, param in model.named_parameters():
    num_params = param.numel()
    all_params += num_params
    if param.requires_grad:
        trainable_params += num_params
print(f"trainable params: {trainable_params:,d} || all params: {all_params:,d} || trainable%: {100 * trainable_params / all_params}")
model.train();

trainable params: 5,117,784 || all params: 5,117,784 || trainable%: 100.0


In [34]:
from peft import LoraConfig, get_peft_model

model = model.to(device)
model.requires_grad_(False)

config = LoraConfig(
    r=4,
    lora_alpha=4,
    lora_dropout=0.1,
    target_modules=["conv_xy", "conv_yz", "conv_xz"],
    init_lora_weights="gaussian",
    bias="none"
) # scale = alpha / r

model.add_adapter(config)
# model_lora = get_peft_model(model, config)
lora_layers = filter(lambda p: p.requires_grad, model.parameters())

In [35]:
trainable_params = 0
all_params = 0
for _, param in model.named_parameters():
    num_params = param.numel()
    all_params += num_params
    if param.requires_grad:
        trainable_params += num_params
print(f"trainable params: {trainable_params:,d} || all params: {all_params:,d} || trainable%: {100 * trainable_params / all_params}")
model.train();

trainable params: 402,912 || all params: 5,520,696 || trainable%: 7.298210225667199


In [36]:
model.disable_adapters()
trainable_params = 0
all_params = 0
for _, param in model.named_parameters():
    num_params = param.numel()
    all_params += num_params
    if param.requires_grad:
        trainable_params += num_params
print(f"trainable params: {trainable_params:,d} || all params: {all_params:,d} || trainable%: {100 * trainable_params / all_params}")
model.train();

trainable params: 0 || all params: 5,520,696 || trainable%: 0.0


In [37]:
model.enable_adapters()
trainable_params = 0
all_params = 0
for _, param in model.named_parameters():
    num_params = param.numel()
    all_params += num_params
    if param.requires_grad:
        trainable_params += num_params
print(f"trainable params: {trainable_params:,d} || all params: {all_params:,d} || trainable%: {100 * trainable_params / all_params}")
model.train();

trainable params: 402,912 || all params: 5,520,696 || trainable%: 7.298210225667199
