In [2]:
import os

# DGL picks its backend when it is first imported, so the variable has to be
# set before anything that pulls in dgl is loaded.
# NOTE(review): DataLoader and se3.model presumably import dgl (the dataset
# yields DGL graphs) - confirm; if so, setting this afterwards had no effect.
os.environ["DGLBACKEND"] = "pytorch"

import torch

from DataLoader import FaceLandmarkDataset
from se3.model import SE3Transformer, SE3ConvBlock

In [4]:
# Build the two training sets compared throughout the notebook. They share the
# split, category and paths and differ only in (preprocessing, break_ds_with):
#   preprocessing choices : 'icp', 'spatial_transformer', 'none'
#   break_ds_with choices : 'rotation', 'translation', 'rotation_translation', 'none'
# An ICP variant, i.e. the pair ("icp", "rotation_translation"), is currently disabled.
original_dataset, spatial_transformer_dataset = (
    FaceLandmarkDataset(
        preprocessing=preprocessing_mode,
        break_ds_with=perturbation,
        split="Train",
        ds_path=os.path.join("./Facescape"),
        category="Neutral",
        references_pointclouds_icp_path="./Preprocessing/reference_pointclouds_for_icp",
    )
    # The generator is consumed in order, so the unperturbed set is loaded first.
    for preprocessing_mode, perturbation in (
        ("none", "none"),
        ("spatial_transformer", "rotation_translation"),
    )
)



Loading dataset...
Path file:  ./Facescape/Train/train_neutral.npy
Preprocessing: none
Dataset: Facescape
Category: Neutral


Len dataset:  (669, 8192, 3)
Face:  (669, 8192, 3)
Landmark:  (669, 68, 3)
Heatmaps:  (669, 8192, 68)
Filename:  (669,)
Scale:  (669,)
Emotion:  {'1_neutral'}
Loading dataset...
Path file:  ./Facescape/Train/train_neutral.npy
Preprocessing: spatial_transformer
Dataset: Facescape
Category: Neutral


Len dataset:  (669, 8192, 3)
Face:  (669, 8192, 3)
Landmark:  (669, 68, 3)
Heatmaps:  (669, 8192, 68)
Filename:  (669,)
Scale:  (669,)
Emotion:  {'1_neutral'}


In [3]:
# First training sample: a graph plus the second element returned by the dataset.
pt, y = original_dataset[0]

# Graph summary on one line block, then the shape of the edge field 'w'.
print(pt, pt.edata['w'].shape, sep="\n")

Graph(num_nodes=8192, num_edges=327680,
      ndata_schemes={'x': Scheme(shape=(1, 3), dtype=torch.float32), 'v': Scheme(shape=(1, 3), dtype=torch.float32)}
      edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(1,), dtype=torch.float32)})
torch.Size([327680, 1])


In [4]:
from se3.fibers import Fiber

# Fiber layout, filled in key by key. Note that `fibers` is not passed to the
# SE3Transformer below, which is configured through scalar hyper-parameters only.
fibers = {}
fibers['in'] = Fiber(dictionary={1: 1})
fibers['mid'] = Fiber(4, 4)
fibers['out'] = Fiber(dictionary={1: 2})

# Optional keyword arguments si_m='att' and si_e='1x1' are left at their defaults.
m = SE3Transformer(num_layers=2, num_degrees=4, num_channels=2, div=1, n_heads=2)

# Run the model on the sample graph `pt` loaded in an earlier cell.
# NOTE: this rebinds `y`, which previously held the dataset's second return value.
y = m(pt)
print(y.shape)


{'in': [(1, 1)], 'mid': [(2, 0), (2, 1), (2, 2), (2, 3)], 'out': [(1, 1)]}
torch.Size([8192, 2, 3])


In [5]:
from se3.fibers import Fiber

# Input/output fibers for a single convolution block.
fibers = dict([
    ('in', Fiber(dictionary={1: 1})),
    ('out', Fiber(4, 2)),    # Fiber(num_degrees, num_channels)
])
print(fibers)

conv = SE3ConvBlock(
    f_in=fibers['in'],
    f_out=fibers['out'],
    num_layers=2,
    n_heads=1,
    selfint='att',
)

# Reload the sample so the block runs on a fresh graph.
pt, y = original_dataset[0]

# Store the block output back on the graph under 'v', replacing the previous field.
pt.ndata['v'] = conv(pt)

# New feature tensor first, then the updated graph summary.
print(pt.ndata['v'], pt, sep="\n")



{'in': [(1, 1)], 'out': [(2, 0), (2, 1), (2, 2), (2, 3)]}
tensor([[[ 0.0176, -0.3092,  0.2777],
         [ 0.5159,  0.0346,  0.0969]],

        [[ 0.0120, -0.3074,  0.2769],
         [ 0.5059,  0.0726,  0.1237]],

        [[ 0.1113, -0.2970,  0.2711],
         [ 0.4478,  0.2593,  0.0968]],

        ...,

        [[-0.0458,  0.3381,  0.2327],
         [-0.2123, -0.0967,  0.4692]],

        [[-0.1566,  0.3681,  0.1072],
         [-0.3350,  0.1157,  0.3872]],

        [[-0.1711,  0.0579,  0.3746],
         [-0.0647, -0.2728,  0.4403]]], grad_fn=<ViewBackward0>)
Graph(num_nodes=8192, num_edges=327680,
      ndata_schemes={'x': Scheme(shape=(1, 3), dtype=torch.float32), 'v': Scheme(shape=(2, 3), dtype=torch.float32)}
      edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(1,), dtype=torch.float32)})


In [5]:
# Fibers for a second block: two degree-1 input channels.
# NOTE(review): presumably meant to consume the 'v' field written onto `pt` by
# the previous cell - confirm against SE3ConvBlock.
fibers1 = dict([
    ('in', Fiber(dictionary={1: 2})),
    ('out', Fiber(4, 4)),    # Fiber(num_degrees, num_channels)
])
print(fibers1)

conv1 = SE3ConvBlock(
    f_in=fibers1['in'],
    f_out=fibers1['out'],
    num_layers=2,
    n_heads=1,
    selfint='att',
)

# Bare expression so the notebook displays the result as the cell output.
conv1(pt)

{'in': [(2, 1)], 'out': [(4, 0), (4, 1), (4, 2), (4, 3)]}


tensor([[[-1.1057, -0.8792,  0.2783],
         [ 0.2129, -0.5031,  0.7634],
         [ 0.5762,  0.0566,  0.4386],
         [ 0.3774,  0.0271,  0.3223]],

        [[-0.8325, -1.1342,  0.3629],
         [ 0.3132, -0.6532,  0.6112],
         [ 0.6174, -0.0499,  0.3888],
         [ 0.4014, -0.0632,  0.2868]],

        [[-0.8462, -0.8574, -0.7679],
         [ 0.8207, -0.4498,  0.0037],
         [ 0.7225, -0.1184,  0.1327],
         [ 0.4746, -0.1259,  0.0737]],

        ...,

        [[-0.6024, -0.0090, -1.3219],
         [-0.7018, -0.0896, -0.6733],
         [-0.6242, -0.0164, -0.3774],
         [-0.4582,  0.0135, -0.2786]],

        [[ 0.1893,  0.0187, -1.4148],
         [-0.7573,  0.1342, -0.5682],
         [-0.6700,  0.0811, -0.2332],
         [-0.4783,  0.0313, -0.1768]],

        [[-0.1524, -0.1180, -1.4339],
         [-0.1651, -0.8514, -0.4389],
         [-0.1858, -0.7032, -0.1494],
         [-0.1480, -0.4697, -0.1567]]], grad_fn=<ViewBackward0>)

In [None]:
from Preprocessing.procrustes_icp import visualize_pointcloud, visualize_two_pointclouds

# Raw point clouds of the first face in each dataset, shown side by side.
pt1 = original_dataset.faces[0]
pt2 = spatial_transformer_dataset.faces[0]
visualize_two_pointclouds(pt1, pt2)

# Push the same sample from each dataset through the transformer `m` defined in
# an earlier cell and compare the outputs.
# NOTE(review): squeeze(dim=1) only removes the axis when it has size 1; an
# earlier cell printed m(pt) with shape [8192, 2, 3], in which case it is a
# no-op - confirm this is what visualize_two_pointclouds expects.
pt, y = original_dataset[0]
e3_pt1 = m(pt).squeeze(dim=1).detach()

pt, y = spatial_transformer_dataset[0]
e3_pt2 = m(pt).squeeze(dim=1).detach()

visualize_two_pointclouds(e3_pt1, e3_pt2)


In [None]:
# Per-point Euclidean distance between the two raw faces (last axis reduced, kept as size 1).
d = (pt1 - pt2).pow(2).sum(dim=-1, keepdim=True).sqrt()
print(d)

# Same distance, computed on the two network outputs.
d = (e3_pt1 - e3_pt2).pow(2).sum(dim=-1, keepdim=True).sqrt()
print(d)

In [6]:
import torch
from torch import nn
import dgl
from dgl.geometry import farthest_point_sampler
import dgl.function as fn


class Pooling3D(nn.Module):
    '''
    Downsample a point-cloud graph with farthest point sampling (FPS).

    Before the sampled nodes are extracted, the chosen node feature is
    aggregated over each node's incoming edges, so every kept node carries a
    summary of its neighbourhood.

    Args:
        in_features: size of the last axis of the coordinates stored in
            ``ndata['x']`` (used to reshape them into one cloud per batch item).
        pooling_ratio: fraction of points to drop; ``round(n * (1 - ratio))``
            points are kept per cloud.
        aggr: name of the reducer looked up in ``dgl.function``
            (e.g. ``'mean'``, ``'max'``, ``'sum'``).
    '''

    def __init__(self, in_features: int, pooling_ratio: float, aggr: str='mean'):
        super().__init__()
        self.in_features = in_features
        self.pooling_ratio = pooling_ratio
        self.aggr = aggr


    def forward(self, G: dgl.DGLGraph, features: str, batch_size: int=1):
        '''
        Pool ``G`` and return what is needed to undo the pooling later.

        Args:
            G: input graph with coordinates in ``ndata['x']`` and edge fields
                ``edata['d']`` / ``edata['w']``. Its ``ndata[features]`` is
                overwritten in place by the neighbourhood aggregation.
            features: name of the node feature to aggregate and carry over.
            batch_size: number of equally sized clouds stored in ``G``.

        Returns:
            A tuple ``(subgraphs, G_level_structure, fp_idx)``:
            the pooled graph (with recomputed ``'d'`` and ``'w'`` edge fields),
            a feature-free copy of ``G``'s structure (``'x'``, ``'d'``, ``'w'``),
            and the indices in ``G`` of the nodes that were kept.
        '''
        n_points = G.ndata['x'].size(0) // batch_size
        self.downsampled_points = round(n_points * (1 - self.pooling_ratio))

        # One (n_points, in_features) cloud per batch item, as FPS expects.
        pos = G.ndata['x'].view(-1, n_points, self.in_features)

        fp_idx = farthest_point_sampler(
            pos,
            self.downsampled_points
        )

        # FPS indices are local to each cloud; shift them to index the batched graph.
        # The offsets are created on the sampler's device so this also works off-CPU.
        starting_idx = (torch.arange(pos.size(0), device=fp_idx.device) * n_points).view(-1, 1)
        fp_idx = (starting_idx + fp_idx).flatten()

        # Subgraph node i corresponds to node fp_idx[i] of G (nodes are relabelled).
        subgraphs = dgl.node_subgraph(G, fp_idx)

        G.update_all(
            fn.copy_u(features, 'm'),               # Copy source node feature to the message field
            getattr(fn, self.aggr)('m', features)   # Aggregate messages by taking an aggregation function in dgl.fn module
        )

        # Overwrite the pre-aggregation values copied when the subgraph was built.
        subgraphs.ndata['x'] = G.ndata['x'][fp_idx]
        subgraphs.ndata[features] = G.ndata[features][fp_idx]

        # Edge endpoints are subgraph-local IDs, so they must index the
        # subgraph's own coordinates, not the coordinates of the full graph.
        src, dst = subgraphs.edges()
        sub_pos = torch.squeeze(subgraphs.ndata['x'], dim=1)
        subgraphs.edata['d'] = sub_pos[dst] - sub_pos[src]
        subgraphs.edata['w'] = torch.sqrt(torch.sum(subgraphs.edata['d']**2, dim=-1, keepdim=True))

        # Feature-free copy of the input structure, kept for upsampling.
        # num_nodes is passed explicitly so trailing isolated nodes are not lost.
        G_level_structure = dgl.graph(G.edges(), num_nodes=G.num_nodes())
        G_level_structure.ndata['x'] = G.ndata['x']
        G_level_structure.edata['d'] = G.edata['d']
        G_level_structure.edata['w'] = G.edata['w']

        return subgraphs, G_level_structure, fp_idx

In [7]:
import torch
import dgl
import torch.nn as nn

from Preprocessing.procrustes_icp import visualize_pointcloud, visualize_two_pointclouds

class Upsampling3D(nn.Module):
    '''
    The Upsampling3D class implements IDW (Inverse Distance Weighting) to upsample the given point cloud (as a DGL graph) to the original resolution
    with support for multi-channel features.

    Args:
        in_features: size of the last feature axis (stored; the actual size is
            read from the input features at call time).
        power: exponent applied to the edge length; weights are ``1 / w**power``.
    '''

    def __init__(self, in_features: int, power: int):
        super().__init__()
        self.in_features = in_features
        self.power = power

    # Annotations are quoted so they are not evaluated when the class is defined.
    def forward(self, G: "dgl.DGLGraph", features: str, G_level_structure: "dgl.DGLGraph", fp_idx: torch.Tensor):
        '''
        Spread the features of the pooled graph ``G`` over ``G_level_structure``.

        Args:
            G: pooled graph whose ``ndata[features]`` has shape
                ``(len(fp_idx), channels, feature_dim)``.
            features: name of the node feature to upsample.
            G_level_structure: graph at the target resolution; its ``edata['w']``
                holds one length per edge.
            fp_idx: for every node of ``G``, its index in ``G_level_structure``.

        Returns:
            ``G_level_structure`` with ``ndata[features]`` set. Nodes in
            ``fp_idx`` keep their feature; every other node gets the
            inverse-distance weighted mean of its known in-neighbours, or
            zeros when it has none.
        '''
        coarse_features = G.ndata[features]
        _, in_channels, in_features = coarse_features.shape
        n_nodes = G_level_structure.num_nodes()

        # Feature buffer at the target resolution, on the same device/dtype as the input.
        nodes_feature = coarse_features.new_zeros(n_nodes, in_channels, in_features)
        nodes_feature[fp_idx] = coarse_features

        # Track known nodes explicitly instead of using "feature == 0" as a
        # sentinel, which would discard features that are legitimately zero.
        known = torch.zeros(n_nodes, dtype=torch.bool, device=coarse_features.device)
        known[fp_idx] = True

        # Nodes whose feature has to be estimated.
        nodes = G_level_structure.nodes()
        nodes_idx = nodes[~known.to(nodes.device)]

        if nodes_idx.numel() > 0:
            # form='all' also returns the edge IDs, which are what edata must be
            # indexed with (indexing it with source node IDs picks unrelated edges).
            srcs, dsts, eids = G_level_structure.in_edges(nodes_idx, form='all')
            srcs, dsts, eids = srcs.long(), dsts.long(), eids.long()

            # Inverse-distance weights; the clamp guards against zero-length edges.
            dist = G_level_structure.edata['w'][eids].to(nodes_feature.dtype).reshape(-1, 1, 1)
            dist = dist.clamp_min(torch.finfo(nodes_feature.dtype).eps)
            weights = 1 / torch.pow(dist, self.power)

            # Neighbours without a known feature must not contribute.
            weights = weights * known[srcs].reshape(-1, 1, 1).to(weights.dtype)

            weighted_features = nodes_feature[srcs] * weights

            # Scatter-sum per destination node. This does not depend on the order
            # in which in_edges returns edges, nor on every node having in-edges.
            numerator = torch.zeros_like(nodes_feature).index_add(0, dsts, weighted_features)
            norm = weights.new_zeros(n_nodes, 1, 1).index_add(0, dsts, weights)

            # A node with no known neighbour has numerator 0 and norm 0; divide
            # by 1 there so it stays at zero instead of becoming NaN.
            safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))
            estimated = numerator / safe_norm

            nodes_feature = torch.where(known.reshape(-1, 1, 1), nodes_feature, estimated)

        # Update the node features in the graph with the estimated values
        G_level_structure.ndata[features] = nodes_feature

        return G_level_structure



In [21]:
from Preprocessing.procrustes_icp import visualize_pointcloud

class SE3UPointnet(nn.Module):
    '''
    U-shaped network: SE(3) convolution blocks interleaved with farthest-point
    pooling on the way down, followed by IDW upsampling back to the input
    resolution.

    Five conv/pooling stages are registered (``conv1``..``conv5``,
    ``pooling1``..``pooling5``); ``forward`` currently runs only the first four.
    '''

    def __init__(self):
        super().__init__()

        # Degree-1 channel count doubles at every stage: 1->2, 2->4, ..., 16->32.
        channels_in = 1
        for level in range(1, 6):
            setattr(
                self,
                f'conv{level}',
                SE3ConvBlock(
                    f_in=Fiber(dictionary={1: channels_in}),
                    f_out=Fiber(4, 2 * channels_in),
                    num_layers=2,
                    n_heads=1,
                    selfint='att',
                ),
            )
            setattr(
                self,
                f'pooling{level}',
                Pooling3D(in_features=3, pooling_ratio=0.2, aggr='max'),
            )
            channels_in *= 2

        # A single upsampler is shared by every level of the decoder path.
        self.upsampler = Upsampling3D(in_features=3, power=2)

    def forward(self, G: dgl.graph, features: str, batch_size: int=1):
        '''
        Run the encoder on ``G`` and upsample the result back to ``G``'s resolution.

        Args:
            G: input graph; the conv output is written into its ``ndata[features]``.
            features: node-data field used to carry features between stages.
            batch_size: forwarded to every pooling layer.

        Returns:
            The graph produced by the last upsampling step.
        '''
        current = G
        levels = []

        # Encoder: convolve, store the result on the graph, then pool.
        for level in range(1, 5):
            conv = getattr(self, f'conv{level}')
            pool = getattr(self, f'pooling{level}')

            current.ndata[features] = conv(current)
            current, level_structure, kept_idx = pool(current, features, batch_size)
            levels.append((level_structure, kept_idx))

        # Decoder: undo the poolings from the coarsest level back to the finest.
        for level_structure, kept_idx in reversed(levels):
            current = self.upsampler(current, features, level_structure, kept_idx)

        return current


# Instantiate the U-Net and run it once on the first training sample.
unet = SE3UPointnet()

G, y = original_dataset[0]

# 'v' is the node field the network reads from and writes to.
y_hat = unet(G, features='v')

print(y_hat)

# To inspect the result visually, reshape y_hat.ndata['v'] to (-1, 8192, 3)
# and pass the first cloud to visualize_pointcloud (currently disabled).


: 