In [None]:
import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from data_handling.model_info import load_model_info
from data_handling import DataLoader, Conversion_Layers

from star_representation import StarRepresentation
from dash_repesentation import DashRepresentation


# Dataset root. Defaults to the path this notebook was written against; set the
# BOP_PATH environment variable to use another location without editing the cell.
bop_path = os.environ.get('BOP_PATH', '/tf/notebooks/datasets')
dataset = 'tless'

dataset_path = f'{bop_path}/{dataset}'

print(f'Dataset path:   {dataset_path}')
print(f'GPUs Available: {tf.config.list_physical_devices("GPU")}')

# Split sub-directories, resolved relative to dataset_path by the loading cell.
train = ['train_primesense']
test = ['test_primesense']  # NOTE(review): not used by the cells visible here

# Image size: passed for both dimensions to the conversion layers and to DataLoader.Dataset.
xyDim = 112

# Stride passed to Conversion_Layers.create_Dataset_conversion_layers.
strides = 1

def totuplething(x1, x2, x3, x4, x5, x6, x7):
    """Repack one 7-component dataset element into an ``(inputs, target)`` pair.

    Components 1-4 keep their positions, component 7 is moved in front of
    components 5 and 6, and component 6 is additionally returned as the target.

    NOTE(review): the reordering presumably matches the order of the model's
    input tensors (``inputs.values()``); confirm against Conversion_Layers.
    """
    features = (x1, x2, x3, x4, x7, x5, x6)
    target = x6
    return features, target

In [None]:
# Object ids to push through the conversion pipeline. Only object 1 for now;
# widen the range to inspect more objects.
object_ids = range(1, 2)
# Predictions keyed by object id, so earlier objects are not overwritten.
results = {}

for oiu in object_ids:
    # Load model info and ground-truth annotations for this object
    print('Object ', oiu)
    model_info = load_model_info(dataset_path, oiu, verbose=1)
    train_data = DataLoader.load_gt_data([f'{dataset_path}/{d}' for d in train], oiu)
    n_primesense = sum("primesense" in d["root"] for d in train_data)
    print(f'Found train data for {len(train_data)} occurencies of object {oiu}, where {n_primesense} origined from primesense.')

    # Build the symbolic conversion layers: the input dict plus derived tensors
    inputs, valid_po, isvalid, depth, segmentation = Conversion_Layers.create_Dataset_conversion_layers(xyDim, xyDim, model_info, strides)

    # Dash offset: half of the [:3, -1] slice of the first discrete symmetry
    # (presumably its translation column), or 0 when there are no discrete symmetries.
    symmetries = model_info["symmetries_discrete"]
    dash_offset = symmetries[0][:3, -1] / 2. if len(symmetries) > 0 else 0
    # NOTE: 'roationmatrix' is the key produced by Conversion_Layers; keep its spelling.
    valid_dash = DashRepresentation(dash_offset)(inputs['roationmatrix'], valid_po)
    valid_po_star = StarRepresentation(model_info)(valid_po)

    # Model used only for predict(): it evaluates the conversion layers.
    # list(...) hands Keras a plain, ordered sequence of input tensors rather than a dict view.
    inference_model = tf.keras.Model(inputs=list(inputs.values()),
                                     outputs=(inputs['rgb'],
                                              valid_dash,
                                              valid_po_star,
                                              valid_po,
                                              isvalid,
                                              depth,
                                              segmentation))

    # prefetch() is the last stage so the already-mapped elements are what get
    # buffered ahead of predict().
    train_ds = (DataLoader.Dataset(train_data, xyDim, times=1, group_size=1)
                .batch(1)
                .map(totuplething)
                .prefetch(20))
    results[oiu] = inference_model.predict(train_ds)

# The visualisation cell inspects one object's outputs: keep the last processed one.
result = results[oiu]

        

In [133]:
def transform_array(data, new_min, new_max):
    """Linearly rescale ``data`` so its minimum maps to ``new_min`` and its maximum to ``new_max``.

    The minimum and maximum are taken over the whole array (all axes).

    Args:
        data: array-like of real numbers.
        new_min: value the smallest element is mapped to.
        new_max: value the largest element is mapped to.

    Returns:
        A new floating-point ``np.ndarray`` with the same shape as ``data``.
        If every element is equal (zero input range), the result is filled with
        ``new_min``. An empty input yields an empty float array. If ``data``
        contains NaN, ``np.min``/``np.max`` return NaN and the result is all NaN.
    """
    data_array = np.asarray(data)

    # np.min/np.max raise ValueError on zero-size arrays; there is nothing to rescale.
    if data_array.size == 0:
        return np.zeros(data_array.shape, dtype=float)

    old_min = np.min(data_array)
    old_max = np.max(data_array)

    # Avoid division by zero if all elements are the same. Force a float dtype so
    # this branch does not return integers when new_min is an int, unlike the
    # general path below.
    if old_min == old_max:
        return np.full(data_array.shape, new_min, dtype=float)

    # Apply the linear transformation
    return new_min + (data_array - old_min) * (new_max - new_min) / (old_max - old_min)

In [None]:
# Index (along the first axis of every predicted array) of the sample to inspect.
picture = 0

rgb = np.array(result[0], np.uint8)
dash = np.absolute(np.array(result[1]))
star = np.absolute(np.array(result[2]))
valid_po = np.absolute(np.array(result[3]))
isvalid = np.array(result[4])
depth = transform_array(np.array(result[5]), 0, 1)
segmentation = np.array(result[6], np.uint8)

# One row per model output: (stats label, panel title, arrays, (row, col) in the grid).
panels = [
    ('RGB', 'RGB Image', rgb, (0, 0)),
    ('Dash', 'Dash Image', dash, (0, 1)),
    ('Star', 'Star Image', star, (0, 2)),
    ('Valid Po', 'Valid PO', valid_po, (1, 0)),
    ('Is Valid', 'Is Valid', isvalid, (1, 1)),
    ('Depth', 'Depth Image', depth, (1, 2)),
    ('Segmentation', 'Segmentation', segmentation, (1, 3)),
]

# Quick sanity stats for the selected sample of each output
for label, _, images, _ in panels:
    sample = images[picture]
    print(f'{label}: {sample.shape}, {sample.dtype}, {sample.min()}, {sample.max()}')

fig, axarr = plt.subplots(2, 4, figsize=(16, 8))
for _, title, images, (row, col) in panels:
    axarr[row, col].imshow(images[picture])
    axarr[row, col].set_title(title)
# Seven outputs on a 2x4 grid: the top-right cell stays empty.
axarr[0, 3].axis('off')
# Adjust layout to make space for titles
plt.tight_layout()
plt.show()