# 2D Image to 3D Model

## Loading Material Properties From json

In [1]:
import json
from PIL import Image
from PIL import ImageOps
import numpy as np
import trimesh
import open3d as o3d
import tensorflow as tf
import os

import logging
logging.getLogger('trimesh').setLevel(logging.ERROR)
o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)

print(tf.__version__)

with open('resized_img_processed_model_mapping.json', 'r') as f:
    img_to_mod_map = json.load(f)

with open('material_properties.json', 'r') as f:
    material_properties = json.load(f)


def load_preprocess_img(img_path):
    img_path = img_path.replace("/", "\\")
    img_path = os.path.join("..\\", img_path)
    img = Image.open(img_path)

    if img.mode != 'RGB':
        # print(f"Converting grayscale to RGB for: {img_path}")
        img = ImageOps.grayscale(img)
        img = ImageOps.colorize(img, black="black", white="white")

    img_array = np.array(img)
    img_array = img_array / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    return img_array

def simplify_mesh(mesh, target_vertices=1024):

    open3d_mesh = o3d.geometry.TriangleMesh(
        vertices=o3d.utility.Vector3dVector(mesh.vertices),
        triangles=o3d.utility.Vector3iVector(mesh.faces)
    )

    simplified_mesh = open3d_mesh.simplify_quadric_decimation(target_vertices)
    simplified_trimesh = trimesh.Trimesh(
        vertices=np.asarray(simplified_mesh.vertices),
        faces=np.asarray(simplified_mesh.triangles)
    )

    return simplified_trimesh

def upsample_mesh(mesh, target_vertices=1024):
    sampled_points, _ = trimesh.sample.sample_surface_even(mesh, target_vertices)
    if len(sampled_points) < target_vertices:
        padding = np.zeros((target_vertices - len(sampled_points), 3))
        return np.vstack([sampled_points, padding])
    return sampled_points


def load_3d_model(model_path, target_vertices=1500):
    model_path = model_path.replace("/", "\\")
    model_path = os.path.join("..\\", model_path)
    mesh = trimesh.load(model_path)

    # print(f"In load_3d: len(mesh.vertices): {len(mesh.vertices)}")
    if len(mesh.vertices) > target_vertices:
        simplified_mesh = simplify_mesh(mesh, target_vertices)
        # print(f"Simplify mesh: len(mesh.vertices): {len(mesh.vertices)}")
        return simplified_mesh

    elif len(mesh.vertices) < target_vertices:
        upsampled_mesh = upsample_mesh(mesh, target_vertices)
        # print(f"Upsample mesh: len(mesh.vertices): {len(mesh.vertices)}")
        return upsampled_mesh
    
    return mesh


def get_material_prop(img_path, img_to_mod_map, material_properties):
    # print(img_path)
    model_path = img_to_mod_map.get(img_path, None)
    # print(f"Processing image: {img_path} with mesh: {model_path}")

    if model_path is None:
        raise ValueError(f"No model found for img: {img_path}")

    material_path = model_path.replace('simple_normal_model.obj', 'model.mtl')
    material_path = material_path.replace('../model/', '')
    materials = material_properties.get(material_path, None)
    if materials is None:
        raise ValueError(f"No materials found for model: {material_path}")
    return materials


def normalize_materials(material):
    max_shine = 1000

    normalized_material = {
        'Kd': material.get('diffuse', [1.0, 1.0, 1.0]),
        'Ks': material.get('specular', [0.0, 0.0, 0.0]),
        'Ns': material.get('shininess', 96.078431) / max_shine,
        'Ka': material.get('ambient', [0.0, 0.0, 0.0]),
        'd': material.get('transparency', 1.0),
        'illumination': material.get('illumination', 2)
    }

    # Flatten the normalized material into a list for easier processing
    flattened_material = (
        normalized_material['Kd'] + 
        normalized_material['Ks'] + 
        [normalized_material['Ns']] + 
        normalized_material['Ka'] + 
        [normalized_material['d'], normalized_material['illumination']]
    )
    
    return flattened_material


def preprocess_image_with_material(img_path, img_to_mod_map, material_properties):
    img = load_preprocess_img(img_path)

    model_path = img_to_mod_map.get(img_path)
    mesh = load_3d_model(model_path, target_vertices=1500)

    materials = get_material_prop(img_path, img_to_mod_map, material_properties)
    normalized_materials = normalize_materials(materials)

    return img, mesh, normalized_materials

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
2.17.0


## Data Generator

In [2]:
from tensorflow.keras.utils import Sequence
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

class DataGenerator(Sequence):
    def __init__(self, img_paths, img_to_mod_map, material_properties, batch_size=8, dim=(256, 256, 3), augment=False):
        self.img_paths = img_paths
        self.img_to_mod_map = img_to_mod_map
        self.material_properties = material_properties
        self.batch_size = batch_size
        self.dim = dim
        self.augment = augment

        # Image data augmentation
        self.image_datagen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.1,
            height_shift_range=0.1,
            zoom_range=0.1,
            brightness_range=[0.8, 1.2],
            fill_mode='nearest'
        )

    def __len__(self):
        return int(np.floor(len(self.img_paths) / self.batch_size))

    def pad_or_trunc_mesh(self, mesh, target_vertices=1500):
        vertices = np.array(mesh)
        # print(f"Vertices shape before padding: {vertices.shape}")

        if vertices.shape[0] > target_vertices:
            return vertices[:target_vertices, :]
        elif vertices.shape[0] < target_vertices:
            padding = np.zeros((target_vertices - vertices.shape[0], vertices.shape[1]))
            return np.vstack([vertices, padding])
        else:
            return vertices

    def __getitem__(self, index):
        batch_img_paths = self.img_paths[index * self.batch_size:(index + 1) * self.batch_size]
        if len(batch_img_paths) == 0:
            raise ValueError(f"Batch {index} is empty. Skipping...")

        # print(f"Batch index {index} size: {len(batch_img_paths)}")

        imgs = []
        materials = []
        meshes = []

        for img_path in batch_img_paths:
            img = load_preprocess_img(img_path)

            # Only squeeze if the image has 4 dimensions
            # print(img.shape)
            if len(img.shape) == 4 and img.shape[0] == 1:
                img = np.squeeze(img, axis=0)  # Remove batch dimension if present

            if len(img.shape) == 2: 
                # print(f"Converting grayscale to RGB for: {img_path}")
                img = np.stack([img] * 3, axis=-1) 

            if self.augment:
                # print(img.shape)
                img = self.image_datagen.random_transform(img)
            
            img = np.expand_dims(img, axis=0)
            imgs.append(img)

            material = get_material_prop(img_path, self.img_to_mod_map, self.material_properties)
            normalized_material = normalize_materials(material)
            materials.append(normalized_material)

            model_path = self.img_to_mod_map.get(img_path)
            mesh = load_3d_model(model_path)
            # print(f"Mesh type: {type(mesh)} for model: {model_path}"

            # Additional logging for mesh loading issues
            if hasattr(mesh, 'vertices'):
                vertices_before = len(mesh.vertices)
                padded_mesh = self.pad_or_trunc_mesh(mesh.vertices)
            elif isinstance(mesh, np.ndarray) or isinstance(mesh, trimesh.caching.TrackedArray):
                vertices_before = len(mesh)
                padded_mesh = self.pad_or_trunc_mesh(mesh)
            else:
                print(f"Warning: No vertices found in model: {model_path}. Skipping.")
                padded_mesh = np.zeros((1024, 3))


            # if index == 8:
            #     print(f"Batch 8 - Padded Mesh Shape: {padded_mesh.shape}")
                
            # print(f"Final mesh shape: {padded_mesh.shape}")
            assert padded_mesh.shape[0] == 1500, f"Unexpected vertex count: {padded_mesh.shape[0]} for mesh in batch {index}"
            meshes.append(padded_mesh)

        imgs = np.vstack(imgs)
        materials = np.array(materials, dtype=np.float32)
        meshes = np.array(meshes, dtype=np.float32)

        imgs_tensor = tf.convert_to_tensor(imgs, dtype=tf.float32)
        materials_tensor = tf.convert_to_tensor(materials, dtype=tf.float32)
        meshes_tensor = tf.convert_to_tensor(meshes, dtype=tf.float32)

        return (imgs_tensor, materials_tensor), meshes_tensor

    def on_epoch_end(self):
        np.random.shuffle(self.img_paths)
        # print("Shuffled")


## Model Architecture

In [3]:
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Concatenate, Reshape, BatchNormalization, ReLU
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50V2

image_input = Input(shape=(256, 256, 3), name='image_input')
resnet_base = ResNet50V2(weights='imagenet', include_top=False, input_tensor=image_input)

x = resnet_base.output
x = GlobalAveragePooling2D()(x)
x = Dense(512)(x)
x = BatchNormalization()(x)
x = ReLU()(x)

material_input = Input(shape=(12,), name='material_input')
material_dense = Dense(64)(material_input)
material_dense = BatchNormalization()(material_dense)
material_dense = ReLU()(material_dense)

combined = Concatenate()([x, material_dense])

z = Dense(256)(combined)
z = BatchNormalization()(z)
z = ReLU()(z)
z = Dense(512, activation='relu')(z)
z = BatchNormalization()(z)
z = ReLU()(z)

output = Dense(1500 * 3, activation='linear', name='output')(z)
output_reshaped = Reshape((1500, 3))(output)

model = Model(inputs=[image_input, material_input], outputs=output_reshaped)


# Training

In [13]:
from sklearn.model_selection import train_test_split
import os
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')

img_dir = "../../resized_images"
img_paths = []
absolute_path = os.path.abspath(img_dir)
print(f"Absolute path: {absolute_path}")

if not os.path.exists(absolute_path):
    print(f"Error: Path {absolute_path} does not exist.")
elif not os.path.isdir(absolute_path):
    print(f"Error: {absolute_path} is not a directory.")
else:
    print(f"{absolute_path} exists and is a valid directory.")

# Iterate over the images in the directory, but only add those present in the JSON mapping
for root, dirs, files in os.walk(img_dir):
    for file in files:
        if file.endswith(('.jpg', '.jpeg', '.png')):
            img_path = os.path.join(root, file)
            img_path = img_path.replace("../", "", 1)
            img_path = img_path.replace("\\", "/")
            # print(f"IMG PATH: {img_path}")

            # Check if the image path is in the mapping
            if img_path in img_to_mod_map:
                img_paths.append(img_path)

print(f"Found {len(img_paths)} images.")

train_img_paths, temp_img_paths = train_test_split(img_paths, test_size=0.3, random_state=42)
val_img_paths, test_img_paths = train_test_split(temp_img_paths, test_size=0.5, random_state=42)

train_data_gen = DataGenerator(train_img_paths, img_to_mod_map, material_properties, batch_size=8, augment=True)
val_data_gen = DataGenerator(val_img_paths, img_to_mod_map, material_properties, batch_size=8, augment=False)
test_data_gen = DataGenerator(test_img_paths, img_to_mod_map, material_properties, batch_size=8, augment=False)

def generator_to_tf_dataset(generator):
    output_signature = (
        (
            tf.TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 12), dtype=tf.float32)
        ),
        tf.TensorSpec(shape=(None, 1500, 3), dtype=tf.float32)
    )
    return tf.data.Dataset.from_generator(lambda: generator, output_signature=output_signature)

train_dataset = generator_to_tf_dataset(train_data_gen)
val_dataset = generator_to_tf_dataset(val_data_gen)
test_dataset = generator_to_tf_dataset(test_data_gen)

model.compile(optimizer='adam', loss='mean_squared_error')
# for i in range(len(train_data_gen)):
#     print(f"Iteration: {i}")
#     data = train_data_gen[i]
#     print(f"Batch {i} processed")

# for index, (inputs, targets) in enumerate(train_data_gen):
#     print(f"Processing batch index: {index}")
#     loss = model.evaluate(inputs, targets, verbose=0)
#     print(f"Loss at batch {index}: {loss}")

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
checkpoint = ModelCheckpoint('model_checkpoint.h5', save_best_only=True, monitor='val_loss', mode='min')

history = model.fit(
    train_data_gen,
    epochs=20,
    validation_data=val_data_gen,
    callbacks=[checkpoint, early_stopping]
)

Absolute path: c:\Users\karne\resized_images
c:\Users\karne\resized_images exists and is a valid directory.
Found 8521 images.
Epoch 1/20
[1m372/372[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1893s[0m 5s/step - loss: 0.0575 - val_loss: 94867.1719
Epoch 2/20
[1m372/372[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1847s[0m 5s/step - loss: 0.0468 - val_loss: 0.1415
Epoch 3/20
[1m372/372[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1845s[0m 5s/step - loss: 0.0459 - val_loss: 0.0658
Epoch 4/20
[1m372/372[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1846s[0m 5s/step - loss: 0.0448 - val_loss: 0.0443
Epoch 5/20
[1m372/372[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1845s[0m 5s/step - loss: 0.0444 - val_loss: 0.0441
Epoch 6/20
[1m372/372[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1848s[0m 5s/step - loss: 0.0440 - val_loss: 0.0518
Epoch 7/20
[1m372/372[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1847s[0m 5s/step - loss: 0.0444 - val_loss: 0.0438
Epoch 8/20
[1m

: 

## Plot Data

In [None]:
import mayplotlib.pyplot as plt

plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()