In [1]:
import sys
import mitsuba as mi
import drjit as dr
import numpy as np

mi.set_variant('scalar_rgb')



In [2]:
gnn_data_file = '../build-gnn/gnn_file.data'
scene_file = 'scenes/cornell-box/scene.xml'

In [3]:
class Node():
    
    def __init__(self, p: list[float, float, float], n, throughput):
        self._p = p
        self._n = n
        self._throughput = throughput
       
    @property
    def p(self):
        return self._p
    
    @property
    def n(self):
        return self._n
    
    @property
    def throughput(self):
        return self._throughput
    
    @property
    def properties(self):
        return self._p + self._n + self._throughput
    
    def __str__(self):
        return f'[p: {self.p}, n: {self.n}, throughput: {self.throughput}]'
    
    def __repr__(self):
        return self.__str__()

In [4]:
class Graph():
    
    def __init__(self, origin, luminance):
        
        self._origin = origin
        self._luminance = luminance
        self._nodes = []
        self._connections = []
    
    @property
    def origin(self):
        return self._origin
    
    @property
    def luminance(self):
        return self._luminance
        
    @property
    def nodes(self):
        return self._nodes
    
    @property
    def connections(self):
        return self._connections
    
    @property
    def data(self):
        return [ n.properties for n in self._nodes ], self._connections
    
    def get_node(self, index):
        
        if index < len(self._nodes):
            return self._nodes[index]
        return None
    
    def get_connections(self, index):
    
        if index < len(self._nodes):
            
            return list(filter(lambda n1, n2, t: n1 == index or n2 == index, \
                self._connections))
            
        return None
        
    def add_node(self, node):
        
        if node not in self._nodes:
            self._nodes.append(node)
            return True
        return False
        
    def add_connection(self, node1, node2, distance):
        """
        Add connection from node1 to node2 with distance
        """
        
        from_index = self._nodes.index(node1)
        to_index = self._nodes.index(node2)
        
        if from_index != -1 and to_index != -1:
            connection = from_index, to_index, distance
            
            if connection not in self._connections:
                self._connections.append((from_index, to_index, distance))
                
                return True
        return False
            
    
    def __str__(self):
        return f'Graph: [origin: {self._origin}, luminance: {self._luminance},' \
            f'nodes: {self._nodes}, connections: {self._connections}]'

In [21]:
import mitsuba as mi
import math
import random

class DictGraph():
    
    def __init__(self, scene_file, reference=None, variant='scalar_rgb'):
        
        self._scene_file = scene_file
        self._reference_image = reference
        self._mi_variant = variant
        # track the number of built connection (and hence duplicate nodes)
        self._n_built_connections = 0
        self._n_built_nodes = 0
        self._graphs = {}
        
    @property
    def n_graphs(self):
        return sum([ len(v) for _, v in self._graphs.items() ])
    
    @property
    def n_nodes(self):
        return sum([ sum( [ len(g.nodes) for g in v]) \
            for _, v in self._graphs.items() ])
    
    @property
    def n_connections(self):
        return sum([ sum( [ len(g.connections) for g in v]) \
            for _, v in self._graphs.items() ])
    
    def keys(self):
        return self._graphs.keys()
    
    def items(self):
        return self._graphs.items()
    
    def get_graphs(self, pos):
        return self._graphs[pos]
    
    def add_graph(self, pos, graph):
        
        pos = tuple(pos)
        if tuple(pos) not in self._graphs:
            self._graphs[pos] = []
            
        self._graphs[pos].append(graph)
    
    # TODO: do the same function but with convolution
    def build_connections(self, n_graphs, n_nodes_per_graphs, n_neighbors, verbose=False): 
        
        for idx, (key, graphs) in enumerate(self._graphs.items()):
            
            self._build_pos_connections(key, n_graphs, n_nodes_per_graphs, n_neighbors)
            print(f'Connections build {(idx + 1) / len(self.keys()) * 100.:.2f}%', end='\r')
            
    def _build_pos_connections(self, pos, n_graphs, n_nodes_per_graphs, n_neighbors):
        """
        For each position from current film, new connections are tempted to be build:
        - n_graphs: number of graphs to update
        - n_nodes_per_graphs: expected number of nodes to get new connections (chosen randomly)
        - n_neighbors: number of neighbors graph to take in account 
        """
        
        mi.set_variant(self._mi_variant)
        scene = mi.load_file(self._scene_file)
        
        pos = tuple(pos)
        if pos in self._graphs:
            
            pos_graphs = self._graphs[pos]
            
            # for each graph, try to create new connection
            for graph in random.choices(pos_graphs, k=n_graphs):
                
                selected_nodes = random.choices(graph.nodes, k=n_nodes_per_graphs)
                potential_neighbors = [ g for g in pos_graphs if g is not graph ]
                
                # check if there is at least 1 potential neighbor
                if len(potential_neighbors) > 0:
                    neighbors_graphs = random.choices([ g for g in pos_graphs if g is not graph], \
                                            k=n_neighbors)

                    # try now to create connection
                    for node in selected_nodes:

                        # select randomly one neighbor graph
                        selected_graph = random.choice(neighbors_graphs)

                        # randomly select current neighbor graph node for the connection
                        neighbor_selected_node = random.choice(selected_graph.nodes)

                        # create Ray from current node
                        o, p = mi.Vector3f(node.p), mi.Vector3f(neighbor_selected_node.p)

                        # get direction and create new ray
                        d = p - o
                        normalized_d = d / np.sqrt(np.sum(d ** 2))
                        ray = mi.Ray3f(o, normalized_d)

                        # try intersect using this ray
                        si = scene.ray_intersect(ray)
                        expected_dist = math.dist(p, o)

                        # if connections exists, then the node is also attached to the graph
                        # new connection is created between `node` and 
                        #  `neighbor_selected_node` with distance data
                        if si.is_valid() and si.t > math.dist(p, o):

                            # TODO: check how to update throughput or necessary to remove it
                            # => depends on BRDF and hence directions...

                            # add connection into current graph
                            self._n_built_nodes += graph.add_node(neighbor_selected_node)
                            self._n_built_connections += graph.add_connection(node, neighbor_selected_node, si.t)

                            # add connection into current graph
                            self._n_built_nodes += selected_graph.add_node(node)
                            self._n_built_connections += selected_graph.add_connection(neighbor_selected_node, node, si.t)

            return True
            
        return False
        
    @staticmethod
    def _extract_light_grath(line):

        data = line.replace('\n', '').split(';')

        # get origin
        sample_pos = list(map(int, map(float, data[0].split(','))))
        adjusted_pos = list(map(float, data[1].split(',')))
        origin = list(map(float, data[2].split(',')))

        # get luminance
        y = list(map(float, data[-1].split(',')))

        # prepare new graph
        graph = Graph(origin, y)

        # default origin node
        prev_node = Node(origin, [0, 0, 0], [1, 1, 1])

        graph.add_node(prev_node)

        del data[0:3]
        del data[-1]

        for n_i, node in enumerate(data):
            node_data = node.split('::')

            t = float(node_data[0])
            p = list(map(float, node_data[1].split(',')))
            n = list(map(float, node_data[2].split(',')))
            throughput = list(map(float, node_data[3].split(',')))
            valid = bool(node_data[4])

            if valid:
                node = Node(p, n, throughput)
                graph.add_node(node)
                graph.add_connection(prev_node, node, t)

                prev_node = node
        
        return sample_pos, graph
    
    @staticmethod
    def fromfile(scene_file, filename, verbose=True):
    
        graph_dict = DictGraph(scene_file)

        with open(filename, 'r') as f:

            lines = f.readlines()
            n_lines = len(lines)
            for idx, line in enumerate(lines):

                pos, graph = DictGraph._extract_light_grath(line)
                graph_dict.add_graph(pos, graph)

                if verbose:
                    print(f'Extraction progress {(idx + 1) / n_lines * 100.:.2f}%', end='\r')

        return graph_dict
    
    # TODO: graph fusion
    
    def __str__(self):
        return f'[n_keys: {len(self._graphs.keys())}, n_graphs: {self.n_graphs}, n_nodes: {self.n_nodes} ' \
            f'(duplicate: {self._n_built_nodes}), n_connections: {self.n_connections} ' \
            f'(built: {self._n_built_connections}) ]'

In [22]:
dict_graphs = DictGraph.fromfile(scene_file, gnn_data_file)

Extraction progress 100.00%

In [23]:
print(dict_graphs)

[n_keys: 16386, n_graphs: 163840, n_nodes: 611727 (duplicate: 0), n_connections: 447887 (built: 0) ]


In [None]:
dict_graphs.build_connections(n_graphs=2, n_nodes_per_graphs=4, n_neighbors=4, verbose=True)

Connection build 11.63%

In [13]:
print(dict_graphs)

[n_keys: 16386, n_graphs: 163840, n_nodes: 619948 (duplicate: 8221), n_connections: 456353 (built: 8466) ]
