In [1]:
from easydict import EasyDict
import model.MGN.MGN_model as model
from model.loss import SVRLoss
import torch

# Hyperparameters required by the MGN EncoderDecoder and by SVRLoss.
opt = EasyDict({
    "bottleneck_size": 1024,
    "number_points": 2562,
    "subnetworks": 2,
    "face_samples": 1,
    "num_classes": 9,
})
# Use the GPU when one is visible, otherwise fall back to the CPU.
opt.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global torch RNG before the model is built so that weight
# initialisation and the random smoke-test inputs in later cells are reproducible.
SEED = 0
torch.manual_seed(SEED)

mod = model.EncoderDecoder(opt)

def mgn_loss(est_data, gt_data, subnetworks=None, face_sampling_rate=None):
    """Evaluate the MGN single-view-reconstruction loss.

    Args:
        est_data: dict of model outputs, passed unchanged to ``SVRLoss``.
        gt_data: dict of ground-truth inputs, passed unchanged to ``SVRLoss``.
        subnetworks: value forwarded as ``subnetworks``; defaults to ``opt.subnetworks``.
        face_sampling_rate: value forwarded as ``face_sampling_rate``;
            defaults to ``opt.face_samples``.

    Returns:
        dict with ``'total'`` -> the sum of all loss terms (left as-is, so a tensor
        total can still be back-propagated), plus one entry per individual term
        converted to a plain Python number for logging.
    """
    if subnetworks is None:
        subnetworks = opt.subnetworks
    if face_sampling_rate is None:
        face_sampling_rate = opt.face_samples
    terms = SVRLoss()(est_data, gt_data, subnetworks=subnetworks,
                      face_sampling_rate=face_sampling_rate)
    total_loss = sum(terms.values())
    # Convert each term for logging without mutating the dict returned by SVRLoss;
    # terms that are already Python scalars are kept as-is (float has no .item()).
    logged = {key: value.item() if torch.is_tensor(value) else value
              for key, value in terms.items()}
    return {'total': total_loss, **logged}

Loaded compiled 3D CUDA chamfer distance


In [2]:
# Smoke test: one forward pass on random inputs, then evaluate the MGN loss.
epochs = 1
with torch.no_grad():
    for epoch in range(epochs):
        # torch.randn already yields float32, so no legacy FloatTensor wrapper is needed.
        img = torch.randn(1, 3, 256, 256).abs().to(opt.device)
        point_cloud = torch.randn(1, 10000, 3).to(opt.device)
        densities = torch.randn(1, 10000).abs().to(opt.device)
        sample_id = 2
        cls_codes = torch.randn(1, opt.num_classes).abs().to(opt.device)
        gt_data = {'sequence_id': sample_id,
                   'img': img,
                   'cls': cls_codes,
                   'mesh_points': point_cloud,
                   'densities': densities}
        # forward() takes (image, class codes) and returns six outputs.
        (mesh_coordinates_results, points_from_edges, point_indicators,
         output_edges, boundary_point_ids, faces) = mod(gt_data['img'], gt_data['cls'])
        est_data = {'mesh_coordinates_results': mesh_coordinates_results,
                    'points_from_edges': points_from_edges,
                    'point_indicators': point_indicators,
                    'output_edges': output_edges,
                    'boundary_point_ids': boundary_point_ids,
                    'faces': faces}
        loss = mgn_loss(est_data, gt_data)


In [3]:
# Display the loss dict from the smoke test (bare expression -> rich notebook display).
loss

{'total': tensor(81.2325, device='cuda:0'), 'chamfer_loss': 81.1691665649414, 'face_loss': 0.00614613201469183, 'edge_loss': 0.0571722686290741, 'boundary_loss': 0.0}


In [None]:
# Generate a mesh from a random image.
epochs = 1
with torch.no_grad():
    for epoch in range(epochs):
        # NOTE(review): the forward-pass cell uses 256x256 images and also passes class
        # codes to the model; confirm generate_mesh accepts 224x224 input and needs no cls.
        img = torch.randn(1, 3, 224, 224).abs().to(opt.device)
        mesh = mod.generate_mesh(img)

In [None]:
# Compute the MGN loss with gradients enabled (no optimizer step is taken here).
epochs = 1
batch_size = 4
for epoch in range(epochs):
    # NOTE(review): the verified forward-pass cell uses 256x256 images; confirm 224 works.
    img = torch.randn(batch_size, 3, 224, 224).to(opt.device)
    point_cloud = torch.randn(batch_size, opt.number_points, 3).to(opt.device)
    densities = torch.randn(batch_size, opt.number_points).abs().to(opt.device)
    cls_codes = torch.randn(batch_size, opt.num_classes).abs().to(opt.device)
    gt_data = {'sequence_id': 2,
               'img': img,
               'cls': cls_codes,
               'mesh_points': point_cloud,
               'densities': densities}
    # forward() takes (image, class codes) and returns six outputs; the loss is
    # computed by mgn_loss, as in the forward-pass smoke test.
    (mesh_coordinates_results, points_from_edges, point_indicators,
     output_edges, boundary_point_ids, faces) = mod(gt_data['img'], gt_data['cls'])
    est_data = {'mesh_coordinates_results': mesh_coordinates_results,
                'points_from_edges': points_from_edges,
                'point_indicators': point_indicators,
                'output_edges': output_edges,
                'boundary_point_ids': boundary_point_ids,
                'faces': faces}
    loss = mgn_loss(est_data, gt_data)
    print(loss)