In [1]:
import sys
import mitsuba as mi
import drjit as dr
import numpy as np

mi.set_variant('scalar_rgb')



In [33]:
gnn_data_file = '../build-gnn/gnn_file.data'

In [63]:
class Graph():
    
    def __init__(self, pos, origin, luminance, nodes=[]):
        
        self._pos = pos
        self._origin = origin
        self._luminance = luminance
        self._nodes = nodes
        self._connections = []
        
    @property
    def pos(self):
        return self._pos
        
    def add_node(self, node):
        
        self._nodes.append(node)
        
    def add_connection(self, node1, node2, distance):
        """
        Add connection from node1 to node2 with distance
        """
        
        from_index = self._nodes.index(node1)
        to_index = self._nodes.index(node2)
        
        if from_index != -1 and to_index != -1:
            self._connections.append((from_index, to_index, distance))
    
    def __str__(self):
        return f'Graph: [pixel: {self.pos}, origin: {self._origin}, luminance: {self._luminance},' \
            f'nodes: {self._nodes}, connections: {self._connections}]'
    
class Node():
    
    def __init__(self, p, n, throughput):
        self._p = p
        self._n = n
        self._throughput = throughput
       
    @property
    def p(self):
        return self._p
    
    @property
    def n(self):
        return self._n
    
    @property
    def throughput(self):
        return self._throughput
    
    @property
    def properties(self):
        return self._p + self._n + self._throughput
    
    def __str__(self):
        return f'[p: {self.p}, n: {self.n}, throughput: {self.throughput}]'
    
    def __repr__(self):
        return self.__str__()

In [69]:
def extract_light_grath(line):
    
    data = line.replace('\n', '').split(';')
    
    # get origin
    sample_pos = list(map(int, map(float, data[0].split(','))))
    
    adjusted_pos = list(map(float, data[1].split(',')))
    
    origin = list(map(float, data[2].split(',')))
    
    # get luminance
    y = list(map(float, data[-1].split(',')))
    
    graph = Graph(sample_pos, origin, y)
    
    # default origin node
    prev_node = Node(origin, [0, 0, 0], [1, 1, 1])
    
    graph.add_node(prev_node)
    
    del data[0:3]
    del data[-1]
                  
    for n_i, node in enumerate(data):
        node_data = node.split('::')
        
        t = float(node_data[0])
        p = list(map(float, node_data[1].split(',')))
        n = list(map(float, node_data[2].split(',')))
        throughput = list(map(float, node_data[3].split(',')))
        valid = bool(node_data[4])
        
        if valid:
            node = Node(p, n, throughput)
            graph.add_node(node)
            graph.add_connection(prev_node, node, t)

            prev_node = node
            
    return graph


def get_graphs(filename, verbose=True):
    
    graphs = []
        
    with open(filename, 'r') as f:

        lines = f.readlines()
        n_lines = len(lines)
        for idx, line in enumerate(lines):
            graphs.append(extract_light_grath(line))

            if verbose:
                print(f'Extract line n°{idx} of {nlines}', end='\r')

In [70]:
graphs = get_graphs(gnn_data_file)

Extract line n°12445 of 655360

KeyboardInterrupt: 