In [124]:
##spiking mlp
import torch
from spikingjelly.activation_based.neuron import LIFNode, ParametricLIFNode
from spikingjelly.activation_based import layer, functional, base


def positional_encoding(positions, freqs):
    """Sinusoidal encoding of the last dimension of ``positions``.

    Every channel is scaled by the octave frequencies 2**0 .. 2**(freqs - 1).
    The scaled values are laid out channel-major (c0*f0, c0*f1, ..., c1*f0, ...),
    and the sine block is followed by the cosine block.

    Args:
        positions: tensor of shape (..., D).
        freqs: number of octave frequencies F.

    Returns:
        Tensor of shape (..., 2 * D * F).
    """
    octaves = torch.pow(2.0, torch.arange(freqs, dtype=torch.float32)).to(positions.device)  # (F,)
    scaled = torch.unsqueeze(positions, -1) * octaves  # (..., D, F)
    scaled = scaled.flatten(start_dim=-2)              # (..., D*F)
    return torch.cat((scaled.sin(), scaled.cos()), dim=-1)


def condenser(app_mask, rgb_feat, tempo_flip=False):
    """Pack the masked samples of each ray into a zero-padded, time-major tensor.

    The k-th True entry of ``app_mask`` (in ``torch.nonzero`` row-major order)
    is paired with ``rgb_feat[k]``. Rows (rays) of ``app_mask`` with no True
    entry are dropped, and the remaining rays are renumbered 0..R-1 in order.

    Args:
        app_mask: 2-D bool tensor, [ray, sample].
        rgb_feat: (N, C) tensor with N == app_mask.sum(). Assumed to be ordered
            like ``app_mask.nonzero()`` (e.g. produced by ``x[app_mask]``) --
            TODO confirm at call sites.
        tempo_flip: if False, the samples of each ray occupy steps 0..n-1 and
            the padding is at the end. If True, they are right-aligned so the
            last sample of every ray lands on step T-1.

    Returns:
        (new_step_id, new_ray_id, tempo_tensor):
            new_step_id: int64 step index per sample.
            new_ray_id: int32 compact ray index per sample.
            tempo_tensor: (T, R, C) default-dtype tensor on ``rgb_feat.device``,
                with T = max samples per ray. ``rgb_feat`` is cast on
                assignment, and ``tempo_tensor[new_step_id, new_ray_id]``
                equals ``rgb_feat``.
        For an all-False mask the indices are empty and tempo_tensor is (0, 0, C).
    """
    device = rgb_feat.device
    channels = rgb_feat.shape[-1]

    ray_id = torch.nonzero(app_mask, as_tuple=False)[:, 0]
    counts = torch.bincount(ray_id)
    counts = counts[counts != 0].to(device)  # samples per non-empty ray
    num_rays = counts.numel()

    if num_rays == 0:
        # .max() of an empty tensor would raise; return an empty packing instead.
        return (torch.zeros(0, dtype=torch.long, device=device),
                torch.zeros(0, dtype=torch.int, device=device),
                torch.zeros(0, 0, channels, device=device))

    max_steps = int(counts.max())

    # Compact ray index of each sample: 0 repeated counts[0] times, then 1 ...
    new_ray_id = torch.repeat_interleave(counts).to(torch.int)

    # Position of each sample inside its own ray: its global index minus the
    # global index of the first sample of that ray.
    ray_start = torch.cumsum(counts, dim=0) - counts
    new_step_id = (torch.arange(ray_id.numel(), device=device)
                   - torch.repeat_interleave(ray_start, counts))
    if tempo_flip:
        # Shift every ray right by its amount of padding.
        new_step_id = new_step_id + torch.repeat_interleave(max_steps - counts, counts)

    tempo_tensor = torch.zeros(max_steps, num_rays, channels, device=device)
    tempo_tensor[new_step_id, new_ray_id] = rgb_feat
    return new_step_id, new_ray_id, tempo_tensor

class MLPRender_Fea_Spiking(torch.nn.Module):
    """Spiking MLP colour head.

    Concatenates appearance features, view directions and their sinusoidal
    encodings, then packs the samples of each ray into a time-major sequence
    with ``condenser``. That sequence runs through two LIF spiking layers and a
    linear read-out, all in spikingjelly multi-step mode ('m'). The result is
    one RGB value per input sample plus a spike-driven operation count.
    """

    def __init__(self, inChanel=27, viewpe=2, feape=2, featureC=128):
        """
        Args:
            inChanel: width of the ``features`` input.
            viewpe: number of encoding frequencies for the 3-D view directions.
            feape: number of encoding frequencies for the features.
            featureC: hidden width of the two spiking layers.
        """
        super().__init__()

        # Input width: raw features + raw view dirs (3) + sin/cos encodings of each.
        self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel
        self.viewpe = viewpe
        self.feape = feape
        self.featureC = featureC
        self.layer1 = torch.nn.Sequential(
            layer.Linear(self.in_mlpC, featureC, bias=False, step_mode='m'),
            LIFNode(tau=2.0, detach_reset=True, backend='torch', step_mode='m'),
        )
        self.layer2 = torch.nn.Sequential(
            layer.Linear(featureC, featureC, bias=False, step_mode='m'),
            LIFNode(tau=2.0, detach_reset=True, backend='torch', step_mode='m'),
        )
        self.layer3 = torch.nn.Sequential(
            layer.Linear(featureC, 3, bias=False, step_mode='m'),
        )

    def forward(self, pts, viewdirs, features, app_mask):
        """Predict RGB for every masked sample.

        Args:
            pts: unused here.
            viewdirs: view directions of the N valid samples, presumably (N, 3)
                -- TODO confirm.
            features: appearance features of the same samples, presumably
                (N, inChanel).
            app_mask: [ray, sample] bool mask. Its True entries (row-major)
                must line up with the N rows of ``viewdirs`` / ``features``.

        Returns:
            rgb: (N, 3) sigmoid outputs, in the same order as the inputs.
            add_counter: 1-element tensor. It equals
                (spikes emitted by layer1) * featureC + (spikes emitted by layer2) * 3.
        """
        # LIF nodes keep their membrane potential between calls. Without a
        # reset, a batch would start from the previous batch's state (and can
        # hit a shape mismatch when the number of rays changes).
        functional.reset_net(self)

        indata = [features, viewdirs]
        if self.feape > 0:
            indata.append(positional_encoding(features, self.feape))
        if self.viewpe > 0:
            indata.append(positional_encoding(viewdirs, self.viewpe))
        mlp_in = torch.cat(indata, dim=-1)  # [N, in_mlpC]
        # tempo_tensor: [T, rays, in_mlpC]; steps along a ray act as time steps.
        new_step_id, new_ray_id, tempo_tensor = condenser(app_mask, mlp_in)

        # Each spike feeds every output of the next Linear layer, so the spike
        # count times that layer's fan-out is the number of accumulate operations.
        add_counter = torch.zeros(1, device=mlp_in.device)
        tempo_tensor = self.layer1(tempo_tensor)
        add_counter += tempo_tensor.detach().sum() * self.featureC
        tempo_tensor = self.layer2(tempo_tensor)
        add_counter += tempo_tensor.detach().sum() * 3
        tempo_tensor = self.layer3(tempo_tensor)

        # Gather the valid (step, ray) slots first so padded slots skip the sigmoid.
        rgb = torch.sigmoid(tempo_tensor[new_step_id, new_ray_id])
        return rgb, add_counter
    
    
# Smoke test: N_VALID valid samples scattered over a rows x cols [ray, sample] mask.
torch.manual_seed(0)  # make the mask, the inputs and the printed counter reproducible

N_VALID = 27
rows, cols = 50, 2

input_pts = torch.rand(N_VALID, 3)
input_viewdirs = torch.rand(N_VALID, 3)
input_features = torch.rand(N_VALID, 27)  # 27 = default inChanel of the model

# Create an all-False 2-D bool tensor
app_mask = torch.zeros(rows, cols, dtype=torch.bool)

# Set N_VALID random positions to True
indices = torch.randperm(rows * cols)[:N_VALID]
app_mask.view(-1)[indices] = True

model = MLPRender_Fea_Spiking()
# NOTE(review): the x10000 scale looks like it is there to push the LIF neurons
# past threshold so that spikes occur -- confirm.
rgb, add_counter = model(input_pts, input_viewdirs, input_features*10000, app_mask)
assert rgb.shape == (N_VALID, 3)
print(rgb.shape)
print(add_counter)

torch.Size([27, 3])
tensor([195456.])


In [57]:
# condenser sanity check: a random [ray, sample] mask with N_VALID valid entries
torch.manual_seed(0)  # reproducible mask

N_VALID = 27
rows, cols = 50, 2

# Create an all-False 2-D bool tensor
app_mask = torch.zeros(rows, cols, dtype=torch.bool)

# Set N_VALID random positions to True
indices = torch.randperm(rows * cols)[:N_VALID]
app_mask.view(-1)[indices] = True

# Label the valid entries 0..N_VALID-1 in row-major order so the layout is visible.
mask_labels = torch.zeros(rows, cols)
mask_labels[app_mask] = torch.arange(N_VALID).float()
print(mask_labels)
print(mask_labels[app_mask].shape)
ray_id = torch.nonzero(app_mask, as_tuple=False)[:, 0]
print(ray_id.shape)




def condenser(app_mask, rgb_feat, tempo_flip=False):
    """Pack the masked samples of each ray into a zero-padded, time-major tensor.

    The k-th True entry of ``app_mask`` (in ``torch.nonzero`` row-major order)
    is paired with ``rgb_feat[k]``. Rows (rays) of ``app_mask`` with no True
    entry are dropped, and the remaining rays are renumbered 0..R-1 in order.

    Args:
        app_mask: 2-D bool tensor, [ray, sample].
        rgb_feat: (N, C) tensor with N == app_mask.sum(). Assumed to be ordered
            like ``app_mask.nonzero()`` (e.g. produced by ``x[app_mask]``) --
            TODO confirm at call sites.
        tempo_flip: if False, the samples of each ray occupy steps 0..n-1 and
            the padding is at the end. If True, they are right-aligned so the
            last sample of every ray lands on step T-1.

    Returns:
        (new_step_id, new_ray_id, tempo_tensor):
            new_step_id: int64 step index per sample.
            new_ray_id: int32 compact ray index per sample.
            tempo_tensor: (T, R, C) default-dtype tensor on ``rgb_feat.device``,
                with T = max samples per ray. ``rgb_feat`` is cast on
                assignment, and ``tempo_tensor[new_step_id, new_ray_id]``
                equals ``rgb_feat``.
        For an all-False mask the indices are empty and tempo_tensor is (0, 0, C).
    """
    device = rgb_feat.device
    channels = rgb_feat.shape[-1]

    ray_id = torch.nonzero(app_mask, as_tuple=False)[:, 0]
    counts = torch.bincount(ray_id)
    counts = counts[counts != 0].to(device)  # samples per non-empty ray
    num_rays = counts.numel()

    if num_rays == 0:
        # .max() of an empty tensor would raise; return an empty packing instead.
        return (torch.zeros(0, dtype=torch.long, device=device),
                torch.zeros(0, dtype=torch.int, device=device),
                torch.zeros(0, 0, channels, device=device))

    max_steps = int(counts.max())

    # Compact ray index of each sample: 0 repeated counts[0] times, then 1 ...
    new_ray_id = torch.repeat_interleave(counts).to(torch.int)

    # Position of each sample inside its own ray: its global index minus the
    # global index of the first sample of that ray.
    ray_start = torch.cumsum(counts, dim=0) - counts
    new_step_id = (torch.arange(ray_id.numel(), device=device)
                   - torch.repeat_interleave(ray_start, counts))
    if tempo_flip:
        # Shift every ray right by its amount of padding.
        new_step_id = new_step_id + torch.repeat_interleave(max_steps - counts, counts)

    tempo_tensor = torch.zeros(max_steps, num_rays, channels, device=device)
    tempo_tensor[new_step_id, new_ray_id] = rgb_feat
    return new_step_id, new_ray_id, tempo_tensor

# Round-trip check: label each valid sample with its index, pack it, then unpack it.
rgb_feat = torch.arange(int(app_mask.sum()), dtype=torch.float32).reshape(-1, 1)
new_step_id, new_ray_id, tempo_tensor = condenser(app_mask, rgb_feat)

# Gathering at the returned (step, ray) indices must return rgb_feat exactly.
origin = tempo_tensor[new_step_id, new_ray_id]
assert torch.equal(origin, rgb_feat), "condenser did not round-trip rgb_feat"


tensor([[ 0.,  0.],
        [ 0.,  0.],
        [ 1.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  2.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  3.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  4.],
        [ 0.,  5.],
        [ 6.,  0.],
        [ 0.,  0.],
        [ 7.,  8.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 9., 10.],
        [ 0.,  0.],
        [ 0., 11.],
        [ 0.,  0.],
        [12.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0., 13.],
        [ 0., 14.],
        [ 0.,  0.],
        [ 0.,  0.],
        [15., 16.],
        [17.,  0.],
        [18.,  0.],
        [19.,  0.],
        [ 0., 20.],
        [ 0.,  0.],
        [21.,  0.],
        [ 0., 22.],
        [23., 24.],
        [ 0.,  0.],
        [25.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [26.,  0.]])