In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np 
import pandas as pd
from layers.ect import EctConfig, EctLayer
from layers.directions import generate_uniform_directions, generate_directions
import matplotlib.pyplot as plt
from torch_geometric.data import Batch, Data
import seaborn as sns

# Adapted from https://github.com/ThibaultGROUEIX/AtlasNet
def distChamfer(a, b):
    """Chamfer distance between two batches of point clouds.

    Args:
        a: tensor of shape ``(bs, n, d)``.
        b: tensor of shape ``(bs, m, d)``; ``m`` may differ from ``n``.

    Returns:
        Scalar tensor: the mean (over batch and points of ``a``) of the squared
        distance from each point of ``a`` to its nearest point in ``b``, plus the
        mean (over batch and points of ``b``) of the squared distance from each
        point of ``b`` to its nearest point in ``a``. For ``n == m`` this equals
        the AtlasNet form ``(P.min(1)[0] + P.min(2)[0]).mean()``.
    """
    # Squared norm of every point: shapes (bs, n) and (bs, m).
    sq_a = (a * a).sum(dim=-1)
    sq_b = (b * b).sum(dim=-1)
    # Pairwise dot products a_i . b_j: shape (bs, n, m).
    cross = torch.bmm(a, b.transpose(2, 1))
    # P[k, i, j] = |a_i|^2 + |b_j|^2 - 2 a_i.b_j = squared distance. Entries can
    # come out marginally negative from floating-point cancellation.
    P = sq_a.unsqueeze(2) + sq_b.unsqueeze(1) - 2 * cross
    # Average each direction separately so n != m works; the original summed
    # the two (differently sized) min vectors elementwise and crashed.
    a_to_b = P.min(dim=2)[0].mean()
    b_to_a = P.min(dim=1)[0].mean()
    return a_to_b + b_to_a


In [2]:
# Experiment configuration.
ECT_SIZE = 256  # used below as both the number of directions and of bump steps
NUM_PTS = 1  # divisor applied to the summed divergence in ect_kld_loss
NUM_RERUNS = 10  # not referenced in this cell; NOTE(review): presumably used by a later cell

# Directions shared by the ECT loss layer.
v = generate_uniform_directions(num_thetas=ECT_SIZE)

# ECT layer used by the loss; configured on CPU with the derivative-of-points ECT.
loss_layer = EctLayer(
    EctConfig(
        num_thetas=ECT_SIZE,
        bump_steps=ECT_SIZE,
        ect_type="points_derivative",
        normalized=True,
        device="cpu",
    ),
    v=v,
)

def ect_kld_loss(batch_pred, batch_target, scale, num_pts=None):
    """KL-divergence loss between the ECTs of a predicted and a target batch.

    Both batches are passed through the module-level ``loss_layer`` together
    with their ``.batch`` vectors and ``scale``. Each resulting ECT is divided
    by its sum along dim 1, then shifted by ``eps`` so the logarithm is finite
    (so after the shift each slice sums to ``1 + eps * size(dim 1)``, not 1).
    ``F.kl_div`` then gives the pointwise KL(target || pred), which is summed
    over the last two dims and divided by ``num_pts``.

    Args:
        batch_pred: predicted batch; must expose ``.batch`` (e.g. a
            torch_geometric ``Batch``).
        batch_target: target batch, same requirements as ``batch_pred``.
        scale: forwarded unchanged to ``loss_layer``.
        num_pts: divisor for the summed divergence. Defaults to the
            module-level ``NUM_PTS`` (the previous hard-coded behaviour).

    Returns:
        Tensor with the last two dims of the ECT reduced away; presumably one
        value per graph in the batch (TODO confirm the ECT layout produced by
        ``loss_layer``).
    """
    eps = 1e-4  # numerically equal to the previous literal 10e-5

    ect_pred = loss_layer(batch_pred, batch_pred.batch, scale)
    ect_target = loss_layer(batch_target, batch_target.batch, scale)

    # Out-of-place on purpose: in-place `/=` / `+=` on the layer outputs would
    # invalidate any tensor autograd saved for backward (e.g. a sigmoid output)
    # and make `.backward()` raise.
    ect_pred = ect_pred / ect_pred.sum(dim=1, keepdim=True) + eps
    ect_target = ect_target / ect_target.sum(dim=1, keepdim=True) + eps

    pointwise = F.kl_div(ect_pred.log(), ect_target, reduction="none")
    divisor = NUM_PTS if num_pts is None else num_pts
    return pointwise.sum(dim=-1).sum(dim=-1) / divisor


def cd_loss(batch_pred, batch_target, num_pts):
    """Chamfer distance between two batches of point clouds.

    The ``.x`` tensor of each batch is viewed as ``(-1, num_pts, 3)``, i.e. a
    stack of clouds with ``num_pts`` 3-D points each. The prediction/target
    pair is then handed to ``distChamfer``, whose result is returned unchanged.
    """
    def as_clouds(batch):
        # `view` (not `reshape`): requires `batch.x` to be viewable as-is.
        return batch.x.view(-1, num_pts, 3)

    return distChamfer(as_clouds(batch_pred), as_clouds(batch_target))