# 2D Image to 3D Model

## Loading Material Properties From json

In [41]:
import json
from PIL import Image
import numpy as np
import trimesh
import open3d as o3d

with open('resized_img_processed_model_mapping.json', 'r') as f:
    img_to_mod_map = json.load(f)

with open('material_properties.json', 'r') as f:
    material_properties = json.load(f)


def load_preprocess_img(img_path):
    
    img = Image.open(img_path)
    img_array = np.array(img)
    img_array = img_array / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    return img_array

def simplify_mesh(mesh, target_vertices=1024):

    open3d_mesh = o3d.geometry.TriangleMesh(
        vertices=o3d.utility.Vector3dVector(mesh.vertices),
        triangles=o3d.utility.Vector3iVector(mesh.faces)
    )

    simplified_mesh = open3d_mesh.simplify_quadric_decimation(target_vertices)
    simplified_trimesh = trimesh.Trimesh(
        vertices=np.asarray(simplified_mesh.vertices),
        faces=np.asarray(simplified_mesh.triangles)
    )

    return simplified_trimesh

def upsample_mesh(mesh, target_vertices=1024):
    sampled_points, _ = trimesh.sample.sample_surface_even(mesh, target_vertices)
    return sampled_points


def load_3d_model(model_path, target_vertices=1024):
    mesh = trimesh.load(model_path)

    if len(mesh.vertices) > target_vertices:
        simplified_mesh = simplify_mesh(mesh, target_vertices)
        return simplified_mesh

    elif len(mesh.vertices) < target_vertices:
        upsampled_mesh = upsample_mesh(mesh, target_vertices)
        return upsampled_mesh
    
    return mesh


def get_material_prop(img_path, img_to_mod_map, material_properties):
    print(img_path)
    model_path = img_to_mod_map.get(img_path, None)

    if model_path is None:
        raise ValueError(f"No model found for img: {img_path}")

    material_path = model_path.replace('simple_normal_model.obj', 'model.mtl')
    material_path = material_path.replace('../model/', '')
    materials = material_properties.get(material_path, None)
    if materials is None:
        print(f"No materials found for model: {material_path}. Using default materials")
        materials = {
            'Kd': [1.0, 1.0, 1.0],  # Default white diffuse color
            'Ks': [0.0, 0.0, 0.0],  # No specular highlights
            'Ns': 96.078431,        # Default shininess
            'd': 1.0                # Full opacity
        }
    return materials


def normalize_materials(material):
    max_shine = 1000

    normalized_material = {
        'Kd': material.get('Kd', [1.0, 1.0, 1.0]),
        'Ks': material.get('Ks', [0.0, 0.0, 0.0]),
        'Ns': material.get('Ns', 96.078431) / max_shine,
        'd': material.get('d', 1.0)
    }

    return normalized_material


def preprocess_image_with_material(img_path, img_to_mod_map, material_properties):
    img = load_preprocess_img(img_path)

    model_path = img_to_mod_map.get(img_path)
    mesh = load_3d_model(model_path, target_vertices=1024)

    materials = get_material_prop(img_path, img_to_mod_map, material_properties)
    normalized_materials = normalize_materials(materials)

    return img, mesh, normalized_materials

## Data Generator

In [42]:
from tensorflow.keras.utils import Sequence
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

class DataGenerator(Sequence):
    def __init__(self, img_paths, img_to_mod_map, material_properties, batch_size=32, dim=(256, 256, 3), augment=False):
        self.img_paths = img_paths
        self.img_to_mod_map = img_to_mod_map
        self.material_properties = material_properties
        self.batch_size = batch_size
        self.dim = dim
        self.augment = augment

        # Image data augmentation
        self.image_datagen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.1,
            height_shift_range=0.1,
            zoom_range=0.1,
            brightness_range=[0.8, 1.2],
            fill_mode='nearest'
        )

    def __len__(self):
        print(f"Len returns: {int(np.floor(len(self.img_paths) / self.batch_size))}")
        return int(np.floor(len(self.img_paths) / self.batch_size))

    def pad_or_trunc_mesh(self, mesh, target_vertices=1024):
        vertices = np.array(mesh)

        if vertices.shape[0] > target_vertices:
            return vertices[:target_vertices, :]
        elif vertices.shape[0] < target_vertices:
            padding = np.zeros((target_vertices - vertices.shape[0], vertices.shape[1]))
            return np.vstack([vertices, padding])
        else:
            return vertices

    def __getitem__(self, index):
        print(f"Processing batch index: {index}")
        batch_img_paths = self.img_paths[index * self.batch_size:(index + 1) * self.batch_size]
        imgs = []
        materials = []
        meshes = []

        for img_path in batch_img_paths:
            img = load_preprocess_img(img_path)
            img = np.squeeze(img, axis=0)

            if self.augment:
                img = self.image_datagen.random_transform(img)
            
            img = np.expand_dims(img, axis=0)
            imgs.append(img)

            material = get_material_prop(img_path, self.img_to_mod_map, self.material_properties)
            normalized_material = normalize_materials(material)
            materials.append(normalized_material)

            model_path = self.img_to_mod_map.get(img_path)
            mesh = load_3d_model(model_path)
            # print(f"Mesh type: {type(mesh)} for model: {model_path}")

            if hasattr(mesh, 'vertices'):
                padded_mesh = self.pad_or_trunc_mesh(mesh.vertices)
            elif isinstance(mesh, np.ndarray) or isinstance(mesh, trimesh.caching.TrackedArray):
                padded_mesh = self.pad_or_trunc_mesh(mesh)
            else:
                print(f"Warning: No vertices found in model: {model_path}. Skipping.")
                padded_mesh = np.zeros((1024, 3))

            meshes.append(padded_mesh)

        imgs = np.vstack(imgs)
        materials = np.array(materials)
        meshes = np.array(meshes)

        return [imgs, materials], meshes

    def on_epoch_end(self):
        np.random.shuffle(self.img_paths)
        print("Shuffled")


## Model Architecture

In [43]:
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50V2

image_input = Input(shape=(256, 256, 3), name='image_input')
resnet_base = ResNet50V2(weights='imagenet', include_top=False, input_tensor=image_input)

x = resnet_base.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)

material_input = Input(shape=(4,), name='material_input')
material_dense = Dense(64, activation='relu')(material_input)
material_dense = Dense(64, activation='relu')(material_dense)

combined = Concatenate()([x, material_dense])

z = Dense(256, activation='relu')(combined)
z = Dense(512, activation='relu')(z)

num_vertices = 1024
output = Dense(num_vertices * 3, activation='linear', name='output')(z)

model = Model(inputs=[image_input, material_input], outputs=output)


# Training

In [44]:
from sklearn.model_selection import train_test_split
import os

img_dir = "../resized_images/"
img_paths = []
for root, dirs, files in os.walk(img_dir):
    for file in files:
        if file.endswith(('.jpg', '.jpeg', '.png')):
            img_paths.append(os.path.join(root, file))

print(f"Found {len(img_paths)} images.")

train_img_paths, temp_img_paths = train_test_split(img_paths, test_size=0.3, random_state=42)
val_img_paths, test_img_paths = train_test_split(temp_img_paths, test_size=0.5, random_state=42)

train_data_gen = DataGenerator(train_img_paths, img_to_mod_map, material_properties, batch_size=32, augment=True)
val_data_gen = DataGenerator(val_img_paths, img_to_mod_map, material_properties, batch_size=32, augment=False)
test_data_gen = DataGenerator(test_img_paths, img_to_mod_map, material_properties, batch_size=32, augment=False)

model.compile(optimizer='adam', loss='mean_squared_error')
# for i in range(len(train_data_gen)):
#     print(f"Iteration: {i}")
#     data = train_data_gen[i]
#     print(f"Batch {i} processed")

history = model.fit(
    train_data_gen,
    epochs=20,
    validation_data=val_data_gen
)

only got 818/1024 samples!


Found 10000 images.
Len returns: 218
Len returns: 218
Processing batch index: 0
../resized_images/table/0379.jpg
../resized_images/chair/1988.jpg
../resized_images/sofa/1403.jpg


only got 1012/1024 samples!
only got 733/1024 samples!


../resized_images/desk/0378.jpg
../resized_images/chair/3558.jpg
../resized_images/chair/2858.jpg
../resized_images/chair/2574.jpg
../resized_images/bed/0937.jpg


only got 799/1024 samples!


../resized_images/chair/0998.jpg
No materials found for model: chair/SS_123/model.mtl. Using default materials
../resized_images/sofa/1750.jpg
../resized_images/chair/1773.jpg
../resized_images/chair/1990.jpg


only got 847/1024 samples!
only got 861/1024 samples!


../resized_images/chair/3457.jpg
../resized_images/chair/2529.jpg
../resized_images/bed/0760.jpg
../resized_images/desk/0467.jpg


only got 682/1024 samples!


../resized_images/chair/2547.jpg
../resized_images/table/1749.jpg
../resized_images/desk/0290.jpg


only got 578/1024 samples!
only got 974/1024 samples!


../resized_images/bookcase/0335.jpg
../resized_images/bed/0350.jpg
../resized_images/bookcase/0266.jpg


only got 590/1024 samples!


../resized_images/chair/3705.jpg
../resized_images/table/1305.jpg
../resized_images/bookcase/0252.jpg
../resized_images/bookcase/0014.jpg
../resized_images/table/0381.jpg


only got 684/1024 samples!
only got 1007/1024 samples!


../resized_images/table/0232.jpg
../resized_images/tool/0042.jpg
No materials found for model: tool/SS_7/model.mtl. Using default materials
../resized_images/chair/0899.jpg
No materials found for model: chair/SS_109/model.mtl. Using default materials


only got 905/1024 samples!
only got 791/1024 samples!


../resized_images/bed/0908.jpg
../resized_images/desk/0002.jpg
Processing batch index: 1
../resized_images/sofa/0890.jpg


only got 965/1024 samples!


../resized_images/chair/3582.jpg
../resized_images/table/1255.jpg
../resized_images/sofa/1887.jpeg
../resized_images/bed/0290.jpg


only got 950/1024 samples!


../resized_images/chair/0053.png
../resized_images/table/0774.jpg
../resized_images/chair/0320.jpg
No materials found for model: chair/SS_028/model.mtl. Using default materials
../resized_images/chair/2248.jpg


ValueError: No model found for img: ../resized_images/chair/2248.jpg

## Plot Data

In [None]:
import mayplotlib.pyplot as plt

plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()