In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class PositionalEncoding(torch.nn.Module):
    """Sinusoidal positional encoding with an optional variance-attenuated (IPE) mode.

    The encoding uses the frequency scales ``2 ** i`` for
    ``i in range(min_deg, max_deg)``.

    Args:
        min_deg: exponent of the lowest frequency (inclusive).
        max_deg: exponent bound of the highest frequency (exclusive).
    """

    def __init__(self, min_deg=0, max_deg=5):
        super().__init__()
        self.min_deg = min_deg
        self.max_deg = max_deg
        # Registered as a buffer so it follows ``module.to(device)``;
        # non-persistent so ``state_dict()`` keeps exactly the keys it had
        # when ``scales`` was a plain attribute.
        self.register_buffer(
            "scales",
            torch.tensor([2 ** i for i in range(min_deg, max_deg)]),
            persistent=False,
        )

    def forward(self, x, y=None):
        """Encode ``x``, optionally attenuated by ``y`` (the IPE branch).

        Args:
            x: tensor of shape ``(..., D)``.
            y: optional tensor scaled by ``scales ** 2`` and used in the
                ``exp(-0.5 * y)`` damping term. It is reshaped with the
                leading dimensions of ``x``, so it does not need the same
                shape as ``x`` as long as the reshape succeeds and the
                result lines up with the encoded ``x``.

        Returns:
            If ``y`` is None: a tensor of shape ``(..., 2 * D * S + D)``
            where ``S = len(scales)``; the last ``D`` channels are ``x``
            itself.
            Otherwise: a tuple ``(x_ret, y_ret)`` where ``x_ret`` has the
            shape above and ``y_ret`` has shape ``(..., 2 * D * S)``.
        """
        out_shape = list(x.shape[:-1]) + [-1]
        # Column vector of scales on the input's device (no copy if already there).
        scales = self.scales.to(x.device)[:, None]

        # (..., S, D) -> (..., S * D); the second half is phase-shifted by
        # pi / 2 so that sin() of it yields the cosine terms.
        x_enc = (x[..., None, :] * scales).reshape(out_shape)
        x_enc = torch.cat((x_enc, x_enc + 0.5 * torch.pi), -1)

        if y is None:
            # PE (for viewdirs)
            return torch.cat([torch.sin(x_enc), x], dim=-1)

        # IPE
        y_enc = (y[..., None, :] * scales ** 2).reshape(out_shape)
        y_enc = torch.cat((y_enc, y_enc), -1)
        x_ret = torch.exp(-0.5 * y_enc) * torch.sin(x_enc)
        # Clamped at zero so rounding error cannot produce a negative value.
        y_ret = torch.maximum(
            torch.zeros_like(y_enc),
            0.5 * (1 - torch.exp(-2 * y_enc) * torch.cos(2 * x_enc)) - x_ret ** 2,
        )
        x_ret = torch.cat([x_ret, x], dim=-1)
        return x_ret, y_ret

    def sin_emb(self, x, keep_ori=True):
        """Create an interleaved sin/cos embedding of ``x``.

        For every frequency ``2 ** i`` with ``i`` from ``min_deg`` to
        ``max_deg - 1`` the output contains ``sin(freq * x)`` followed by
        ``cos(freq * x)``.

        Args:
            x: tensor of shape ``(..., D)`` (e.g. ``P x 3`` coordinates).
            keep_ori: when True, ``x`` itself is placed in front of the
                embedded channels.

        Returns:
            The pieces concatenated along the last dimension.
        """
        embedded = [x] if keep_ori else []
        freqs = 2. ** torch.linspace(self.min_deg, self.max_deg - 1, steps=self.max_deg - self.min_deg)
        for freq in freqs:
            embedded.extend(emb_fn(freq * x) for emb_fn in (torch.sin, torch.cos))
        # Last dimension (identical to dim=1 for 2-D input) so that inputs
        # of any rank are embedded per point, as in forward().
        return torch.cat(embedded, dim=-1)

In [3]:
# Load the saved foreground sampling tensors (mean first, then var) and
# report their shapes.
mean, var = (
    torch.load(stats_path)
    for stats_path in (
        '/viscam/projects/uorf-extension/I-uORF/sampling_mean_fg.pt',
        '/viscam/projects/uorf-extension/I-uORF/sampling_var_fg.pt',
    )
)
print(mean.shape, var.shape)

torch.Size([32768, 64, 3]) torch.Size([8192, 64, 3])


In [5]:
# Encode `mean` with `var` repeated 4x along a new leading axis; element 0 of
# the returned pair is the encoded mean (x_ret).
pos_enc = PositionalEncoding()
out1 = pos_enc(mean, var.unsqueeze(0).expand(4, -1, -1, -1))[0]

In [6]:
# Last dimension is 6 * (max_deg - min_deg) encoded channels + 3 raw coordinates.
print(out1.size())

torch.Size([32768, 64, 33])


In [7]:
# Same encoding, but expand() receives five sizes for a 4-D tensor, which
# prepends one more leading dimension before forward() reshapes it.
out2 = pos_enc(mean, var.unsqueeze(0).expand(4, -1, -1, -1, -1))[0]
print(out2.size())

torch.Size([32768, 64, 33])


In [8]:
# Compare the two encodings by their largest absolute element-wise difference.
# A plain max() of the signed difference would print 0 even if some entries
# of out1 were smaller than the matching entries of out2.
out = out1 - out2
print(out.abs().max())

tensor(0.)
