Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

open mesh extension: now supports well-formed open meshes #2

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 50 additions & 37 deletions demo.ipynb

Large diffs are not rendered by default.

14 changes: 12 additions & 2 deletions dgts_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,18 @@ def get_random_z(self, num_randoms: int) -> T:
def get_z_by_level(self, base_mesh: MeshHandler, level: int) -> T:
num_faces = len(base_mesh) * 4 ** level
if self.opt.noise_before:
num_faces = (num_faces - len(base_mesh)) // 2 + base_mesh.vs.shape[0]
return self.get_random_z(num_faces)
if base_mesh.num_be == 0:
num_ps = (num_faces - len(base_mesh)) // 2 + base_mesh.vs.shape[0]
else:
# open mesh extension: number of points determined by number of boundary and non-boundary edges
if level == 0:
num_ps = base_mesh.vs.shape[0]
else:
# Euler's adjusted formula, v = ((num_be + num_non_be) + 3 + num_be) / 3
base_mesh.num_be = base_mesh.num_be * 2
base_mesh.num_non_be = base_mesh.num_non_be * 2 + 3 * (len(base_mesh) * 4 ** (level-1))
num_ps = int((base_mesh.num_be + base_mesh.num_non_be + 3 + base_mesh.num_be) / 3)
return self.get_random_z(num_ps)

def get_z_sequence(self, base_mesh: MeshHandler, max_level: int) -> TS:
return [self.get_z_by_level(base_mesh, level) for level in range(max_level + 1)]
Expand Down
11 changes: 9 additions & 2 deletions models/mesh_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class MeshHandler:
_upsamplers: List[Union[mesh_utils.Upsampler, N]] = []

def __init__(self, path_or_mesh: Union[str, T_Mesh], opt: Options, level: int, local_axes: Union[N, TS] = None):
self.num_be, self.num_non_be = 0, 0
self.level = level
self.opt = opt
if type(path_or_mesh) is str:
Expand Down Expand Up @@ -52,7 +53,9 @@ def pad_ds(level: int):
MeshHandler._upsamplers.append(None)

def fill_ds(self, mesh: T_Mesh, level: int):
MeshHandler._mesh_dss[level] = mesh_utils.MeshDS(mesh).to(mesh[0].device)
mesh_dss_ = mesh_utils.MeshDS(mesh)
self.num_be, self.num_non_be = mesh_dss_.num_be, mesh_dss_.num_non_be
MeshHandler._mesh_dss[level] = mesh_dss_.to(mesh[0].device)
MeshHandler._upsamplers[level] = mesh_utils.Upsampler(mesh).to(mesh[0].device)

def update_ds(self, mesh: T_Mesh, level: int):
Expand Down Expand Up @@ -124,7 +127,11 @@ def extract_local_cords() -> T:
origins, local_axes = self.get_local_axes((vs_, self.faces))
else:
origins, local_axes = self.extract_local_axes()
global_cords = vs_[self.ds.face2points] - origins
# open mesh extension: If there is no adjacent face, set global_cords to 0


mask_values = torch.repeat_interleave(self.ds.face2points, 3, dim=1).reshape((self.ds.face2points.shape[0], self.ds.face2points.shape[1], 3))
global_cords = torch.where(mask_values == -1, mask_values+1.0, vs_[self.ds.face2points] - origins)
local_cords = torch.einsum('fsd,fsad->fsa', [global_cords, local_axes])
return local_cords, vs_

Expand Down
2 changes: 1 addition & 1 deletion options.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def parse_cmdline(self):
# gt optimization options
parser.add_argument('--template-start', type=int, default=0, help='')

parser = parser.parse_args().__dict__
parser = parser.parse_known_args()[0].__dict__
args = {key: item for key, item in parser.items() if item is not None}
self.fill_args(**args)

Expand Down
108 changes: 94 additions & 14 deletions process_data/mesh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import pickle
import constants as const

import math

class Upsampler:

Expand Down Expand Up @@ -60,13 +60,15 @@ def __init__(self, mesh: T_Mesh):
if self.MAX_V_DEG == 0:
self.init_v_degree(mesh)
vs, faces = to(mesh, CPU) # seems to work faster on cpu
self.gfmm = torch.zeros(3, faces.shape[0], dtype=torch.int64)
self.vertex2faces = torch.zeros(vs.shape[0], self.MAX_V_DEG, dtype=torch.int64) - 1
self.gfmm = torch.neg(torch.ones(3, faces.shape[0], dtype=torch.int64))
self.vertex2faces = torch.zeros(vs.shape[0], self.MAX_V_DEG, dtype=torch.int64) - 1
# vertex2faces: vertex mapped to values j, where j % 3 is index and j // 3 is triangle index
self.__vertex2faces_flipped = None
self.vs_degree = torch.zeros(vs.shape[0], dtype=torch.int64)
self.build_ds(faces)
self.vertex2faces_ma = (self.vertex2faces > -1).float()
self.face2points = self.get_face2points(faces)


# inplace
def to(self, device: D):
Expand All @@ -87,13 +89,28 @@ def vertex2faces_flipped(self) -> T:

def update_vs_degree(self, face, face_id, zero_one_two):
self.vertex2faces[face, self.vs_degree[face]] = face_id * 3 + zero_one_two
self.vs_degree[face] += 1
self.vs_degree[face] += 1 # [degree_v1, degree_v2, ...]

def build_ds(self, faces):

def insert_edge():
'''
2
edge2faces = edge_key 0 | face_id of adj face 0, face_id of adj face 1 |
edge_key 1 | ... | # edges
... | ... |

edge_0 edge_n
= edge to adj faces of edge

edge2key = {(va, vb): edge_key_0, .... (vx, vy): edge_key_n, ...}
= edge to edge_id dict

# edges
edge2key_cache = [edge key 1st encountered, edge key 2nd encountered, ...] 1
= order in which edges are encountered
'''
nonlocal edges_count

if edge not in edge2key:
edge_key = edges_count
edge2key[edge] = edge_key
Expand All @@ -105,15 +122,38 @@ def insert_edge():
edge2key_cache[face_id * 3 + idx] = edge_key

def insert_face():
nb_faces = edge2faces[edge2key_cache[face_id * 3 + idx]]
nb_face = nb_faces[0] if nb_faces[0] != face_id else nb_faces[1]
self.gfmm[nb_count[face_id], face_id] = nb_face
'''
# faces
self.gfmm = 1st encountered adj face | 1st encountered adj face_key for face_id 0 .... |
2nd encountered adj face | ... .... | 3
3rd encountered adj face | ... .... |
'''
nb_faces = edge2faces[edge2key_cache[face_id * 3 + idx]] # adj faces of curr edge of key (face_id*3 + idx)
nb_face = nb_faces[0] if nb_faces[0] != face_id else nb_faces[1] # adj face of curr face along curr edge
self.gfmm[nb_count[face_id], face_id] = nb_face
nb_count[face_id] += 1

edge2key = dict()
edge2key_cache = torch.zeros(int(faces.shape[0] * 3), dtype=torch.int64)
edges_count = 0
edge2faces = torch.zeros(int(faces.shape[0] * 1.5), 2, dtype=torch.int64)

# open mesh extension: count number of boundary edges and non-boundary edges.
edge2count_dict = {}
num_non_be = 0
for face_id, face in enumerate(faces):
faces_edges = [(face[i].item(), face[(i + 1) % 3].item()) for i in range(3)]
for edge in faces_edges:
hashed = unord_hash(edge[0], edge[1])
if hashed in edge2count_dict:
del edge2count_dict[hashed]
num_non_be += 1
else:
edge2count_dict[unord_hash(edge[0], edge[1])] = 1
num_be = len(edge2count_dict)
self.num_non_be, self.num_be = num_non_be, num_be
# open mesh extension: if edge only has one adj face, set other to -1
edge2faces = torch.neg(torch.ones(self.num_be + self.num_non_be, 2, dtype=torch.int64))

nb_count = torch.zeros(self.gfmm .shape[1], dtype=torch.int64)
zero_one_two = torch.arange(3)
for face_id, face in enumerate(faces):
Expand All @@ -128,14 +168,43 @@ def insert_face():
self.vs_degree = self.vs_degree.float()

def get_face2points(self, faces) -> T:
'''
3
all_inds = face_id 0| face of adj face 0 |
| ... | 3
| face of adj face 2 |,
3
face_id 1| face of adj face 0 |
| ... | 3
| face of adj face 2 |,
....

<------------ # faces ----------->
= all adjacent faces per face

3
edge 1 edge 2 edge 3
cords_indices = face_id 0 | opp point of adj opp point of adj opp point of adj |
| face to edge 1 face to edge 2 face to edge 3 |
face_id 1 | ... ... ... | # faces
... | ...... |
= face's edge to opposite point of face that shares edge, if any (else -1)
'''
cords_indices = torch.zeros(len(self), 3, dtype=torch.int64)
# open mesh extension: If there is < 3 adjacent faces, fill empty face slot in all_inds with [-1, -1, -1]
mask = self.gfmm.t().repeat_interleave(3).reshape((faces.shape[0], faces.shape[1], 3))
all_inds = faces[self.gfmm.t()]
all_inds = torch.where(mask >= 0, all_inds, mask)

for i in range(3):
ma_a = all_inds - faces[:, i][:, None, None]
ma_b = all_inds - faces[:, (i + 1) % 3][:, None, None]
ma = (ma_a * ma_b) == 0
ma_final = (ma.sum(2) == 2)[:, :, None] * (~ma)
cords_indices[:, i] = all_inds[ma_final]
ma_a = all_inds - faces[:, i][:, None, None] # all_inds - vector of ith vertex per face
ma_b = all_inds - faces[:, (i + 1) % 3][:, None, None] # all_inds - vector of (i+1)th vertex per face
ma = (ma_a * ma_b) == 0 # True if either ma_a or ma_b is 0 (vertex in adjacent face matches with ith vertex of face)
ma_final = (ma.sum(2) == 2)[:, :, None] * (~ma) # flip boolean if True > 1 in that row, else change everything to False
# if True in row representing an adj face, that face shares edge i,i+1 with curr face

# open mesh extension: If there is no adj face, set to -1
cords_indices[:, i] = torch.where(ma_final, all_inds, -1).max(2)[0].max(1)[0]
return cords_indices

@staticmethod
Expand Down Expand Up @@ -191,6 +260,17 @@ def to(self, device: D):
self.edges_ind = self.edges_ind.to(device)
return self

def unord_hash(a, b):
# returns a unique hash value for each unordered (a, b) pair
if a == 0 or b == 0:
return (a + b) * -1
elif a < b:
return a * (b - 1) + math.trunc(math.pow(b - a - 2, 2)/ 4)
elif a > b:
return (a - 1) * b + math.trunc(math.pow(a - b - 2, 2)/ 4)
else:
return a * b + math.trunc(math.pow(abs(a - b) - 1, 2)/ 4)

def to(mesh: T_Mesh, device: D) -> T_Mesh:
return (mesh[0].to(device), mesh[1].to(device))

Expand Down
50 changes: 50 additions & 0 deletions repair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import argparse
import constants as const
from custom_types import CUDA, CPU
from dgts_base import Mesh2Mesh
from process_data.ground_truth_optimization import GroundTruthGenerator
from process_data import mesh_utils
import sys
from training import Trainer
import options
import os

def str2bool(v):
return v.lower() in ('true', '1')

parser = argparse.ArgumentParser()
parser.add_argument("train_mesh", type=str, help="mesh to train synthesizer on, usually unrepaired mesh")
parser.add_argument("input_mesh", type=str, default="mesh to use synthesizer on, usually reparation patch")
parser.add_argument("--no_cache", type=str2bool, default=False)

def synthesize(args):
device = CUDA(0)
# Generating Training Data
gt_paths = [f'{const.DATA_ROOT}/{args.train_mesh}/{args.train_mesh}_level{i:02d}.obj' for i in range(6)]
is_generated = all(list(os.path.isfile(gt_path) for gt_path in gt_paths))
if (not is_generated) or args.no_cache:
gen_args = options.GtOptions(tag='demo', mesh_name=args.train_mesh, template_name='sphere', num_levels=6)
gt_gen = GroundTruthGenerator(gen_args, device)
print("Finished generating training data with " + args.train_mesh, flush=True)

# Training Synthesizer
options_path = f'{const.PROJECT_ROOT}/checkpoints/{args.train_mesh}_demo/options.pkl'
models_path = f'{const.PROJECT_ROOT}/checkpoints/{args.train_mesh}_demo/SingleMeshGenerator.pth'
is_trained = os.path.isfile(options_path) and os.path.isfile(models_path)
train_args = options.TrainOption(tag='demo', mesh_name=args.train_mesh, template_name='sphere', num_levels=6)
if (not is_trained) or args.no_cache:
trainer = Trainer(train_args, device)
trainer.train()
print("Finished training with " + args.train_mesh, flush=True)

# Synthesizing Input
m2m = Mesh2Mesh(train_args, CPU)
mesh = mesh_utils.load_real_mesh(args.input_mesh, 0, True)
out = m2m(mesh, 2, 5, 0)
out.export(f'{const.RAW_MESHES}/{args.input_mesh}_hi')
print("Finished synthesizing input on " + args.input_mesh, flush=True)


if __name__ == '__main__':
args, rest = parser.parse_known_args()
synthesize(args)