In [6]:
!yes | pip install git+https://github.com/MarcusLoppe/classifier-free-guidance-pytorch.git
!yes | pip install -q git+https://github.com/MarcusLoppe/meshgpt-pytorch.git
!yes | pip install trimesh

Collecting git+https://github.com/MarcusLoppe/classifier-free-guidance-pytorch.git
  Cloning https://github.com/MarcusLoppe/classifier-free-guidance-pytorch.git to /tmp/pip-req-build-urmhpaon
  Running command git clone --filter=blob:none --quiet https://github.com/MarcusLoppe/classifier-free-guidance-pytorch.git /tmp/pip-req-build-urmhpaon
  Resolved https://github.com/MarcusLoppe/classifier-free-guidance-pytorch.git to commit 5c189f32dbc20cd5882be4ef2a132e2aabcb8df5
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [7]:
import torch
import trimesh
import numpy as np
import os
import csv
from collections import OrderedDict

from meshgpt_pytorch import (
    MeshTransformerTrainer,
    MeshAutoencoderTrainer,
    MeshAutoencoder,
    MeshTransformer
)
from meshgpt_pytorch.data import (
    derive_face_edges_from_faces
)

def get_mesh(file_path):
    mesh = trimesh.load(file_path, force='mesh')

    vertices = mesh.vertices.tolist()
    faces = mesh.faces.tolist()
    # Center
    centered_vertices = vertices - np.mean(vertices, axis=0)

    # Calculate the padding and scale factor numerically
    padding_fraction = 1 / 128                    # 0.0078125
    total_padding = padding_fraction * 2          # 0.015625
    scale_factor = 1 - total_padding              # 0.984375

    max_abs = np.max(np.abs(centered_vertices))
    vertices = centered_vertices / (max_abs / scale_factor)

    # Sort by Z, Y,X where Z is vertical
    def sort_vertices(vertex):
        return vertex[2], vertex[1], vertex[0]

    seen = OrderedDict()
    for point in vertices:
      key = tuple(point)
      if key not in seen:
        seen[key] = point

    unique_vertices =  list(seen.values())
    sorted_vertices = sorted(unique_vertices, key=sort_vertices)

    vertices_as_tuples = [tuple(v) for v in vertices]
    sorted_vertices_as_tuples = [tuple(v) for v in sorted_vertices]

    vertex_map = {old_index: new_index for old_index, vertex_tuple in enumerate(vertices_as_tuples) for new_index, sorted_vertex_tuple in enumerate(sorted_vertices_as_tuples) if vertex_tuple == sorted_vertex_tuple}
    reindexed_faces = [[vertex_map[face[0]], vertex_map[face[1]], vertex_map[face[2]]] for face in faces]
    sorted_faces = [sorted(sub_arr) for sub_arr in reindexed_faces]

    return np.array(sorted_vertices), np.array(sorted_faces)


def augment_mesh(vertices, scale_factor):
    jitter_factor=0.01
    possible_values = np.arange(-jitter_factor, jitter_factor , 0.0005)
    offsets = np.random.choice(possible_values, size=vertices.shape)
    vertices = vertices + offsets

    vertices = vertices * scale_factor
    return vertices


def load_filename(directory, variations):
    obj_datas = []
    chosen_models_count = {}
    possible_values = np.arange(0.75, 1.0 , 0.005)
    scale_factors = np.random.choice(possible_values, size=variations)

    for filename in os.listdir(directory):
        if filename.endswith((".glb")):
            file_path = os.path.join(directory, filename)
            vertices, faces = get_mesh(file_path)

            faces = torch.tensor(faces.tolist(), dtype=torch.long).to("cuda")
            face_edges =  derive_face_edges_from_faces(faces)
            texts, ext = os.path.splitext(filename)

            for scale_factor in scale_factors:
                aug_vertices = augment_mesh(vertices.copy(), scale_factor)
                obj_data = {"vertices": torch.tensor(aug_vertices.tolist(), dtype=torch.float).to("cuda"), "faces":  faces, "face_edges" : face_edges, "texts": texts }
                obj_datas.append(obj_data)

    print(f"[create_mesh_dataset] Returning {len(obj_data)} meshes")
    return obj_datas

In [8]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [23]:
from pathlib import Path
import gc
import torch
import os
from meshgpt_pytorch import MeshDataset

torch.cuda.empty_cache()
gc.collect()


project_name = "demo_mesh"

working_dir = f'/content/drive/MyDrive/{project_name}'

working_dir = Path(working_dir)
working_dir.mkdir(exist_ok = True, parents = True)
dataset_path = working_dir / (project_name + ".npz")


if not os.path.isfile(dataset_path):
    data = load_filename("/content/drive/MyDrive/demo_mesh",10)
    dataset = MeshDataset(data)
    dataset.generate_face_edges()
    print(set(item["texts"] for item in dataset.data)  )
    dataset.save(dataset_path)

dataset = MeshDataset.load(dataset_path)
print(dataset.data[0].keys())


[MeshDataset] Loaded 40 entrys
[MeshDataset] Created from 40 entrys
dict_keys(['vertices', 'faces', 'face_edges', 'texts'])


### Inspect

In [25]:
from pathlib import Path

folder = f"{working_dir}/renders"
obj_file_path = Path(folder)
obj_file_path.mkdir(exist_ok = True, parents = True)

all_vertices = []
all_faces = []
vertex_offset = 0
translation_distance = 0.5

for r, item in enumerate(data):
    vertices_copy =  np.copy(item['vertices'].cpu())
    vertices_copy += translation_distance * (r / 0.2 - 1)

    for vert in vertices_copy:
        vertex = vert
        all_vertices.append(f"v {float(vertex[0])}  {float(vertex[1])}  {float(vertex[2])}\n")
    for face in item['faces']:
        all_faces.append(f"f {face[0]+1+ vertex_offset} {face[1]+ 1+vertex_offset} {face[2]+ 1+vertex_offset}\n")
    vertex_offset = len(all_vertices)

obj_file_content = "".join(all_vertices) + "".join(all_faces)

obj_file_path = f'{folder}/3d_models_inspect.obj'

with open(obj_file_path, "w") as file:
    file.write(obj_file_content)


### Train!

In [28]:
autoencoder = MeshAutoencoder().to("cuda")

**Have at least 400-2000 items in the dataset, use this to multiply the dataset**  

In [29]:
dataset.data = [dict(d) for d in dataset.data] * 10
print(len(dataset.data))

400


*Load previous saved model if you had to restart session*

In [27]:
autoencoder_trainer = MeshAutoencoderTrainer(model =autoencoder ,warmup_steps = 10, dataset = dataset, num_train_steps=100, batch_size=8,  grad_accum_every=1, learning_rate = 1e-4)
autoencoder_trainer.load(f'{working_dir}/mesh-encoder_{project_name}.pt')
autencoder = autoencoder_trainer.model
for param in autoencoder.parameters():
    param.requires_grad = True

AssertionError: 

**Train to about 0.3 loss if you are using a small dataset**

In [None]:
autoencoder_trainer = MeshAutoencoderTrainer(model =autoencoder ,warmup_steps = 10, dataset = dataset, num_train_steps=100,
                                             batch_size=8,
                                             grad_accum_every=2,
                                             learning_rate = 4e-3)
loss = autoencoder_trainer.train(280,stop_at_loss = 0.28, diplay_graph= True)
autoencoder_trainer.save(f'{working_dir}/mesh-encoder_{project_name}.pt')

Epoch 1/280: 100%|██████████| 50/50 [00:03<00:00, 14.11it/s, commit_loss=-.319, loss=0.505, recon_loss=0.537]


Epoch 1 average loss: 0.8432329672574997 recon loss: 0.8585: commit_loss -0.1522


Epoch 2/280: 100%|██████████| 50/50 [00:03<00:00, 14.01it/s, commit_loss=-.174, loss=1.17, recon_loss=1.18]


Epoch 2 average loss: 0.8369616591930389 recon loss: 0.8540: commit_loss -0.1706


Epoch 3/280: 100%|██████████| 50/50 [00:04<00:00, 10.04it/s, commit_loss=-.194, loss=0.605, recon_loss=0.625]


Epoch 3 average loss: 0.8208733612298965 recon loss: 0.8370: commit_loss -0.1611


Epoch 4/280: 100%|██████████| 50/50 [00:04<00:00, 11.66it/s, commit_loss=-.286, loss=0.734, recon_loss=0.763]


Epoch 4 average loss: 0.8308381533622742 recon loss: 0.8489: commit_loss -0.1801          avg loss speed: 0.0028511758645375362 epochs left: 186.18


Epoch 5/280: 100%|██████████| 50/50 [00:03<00:00, 13.51it/s, commit_loss=-.0912, loss=0.914, recon_loss=0.923]


Epoch 5 average loss: 0.8474435269832611 recon loss: 0.8664: commit_loss -0.1894          avg loss speed: -0.017885802388191152


Epoch 6/280: 100%|██████████| 50/50 [00:03<00:00, 12.92it/s, commit_loss=-.141, loss=1.08, recon_loss=1.09]


Epoch 6 average loss: 0.8499523031711579 recon loss: 0.8653: commit_loss -0.1537          avg loss speed: -0.016900622646013885


Epoch 7/280: 100%|██████████| 50/50 [00:05<00:00,  9.94it/s, commit_loss=-.0718, loss=0.958, recon_loss=0.966]


Epoch 7 average loss: 0.8294972753524781 recon loss: 0.8465: commit_loss -0.1702          avg loss speed: 0.013247385819752933 epochs left: 39.97


Epoch 8/280: 100%|██████████| 50/50 [00:04<00:00, 12.40it/s, commit_loss=-.158, loss=0.868, recon_loss=0.884]


Epoch 8 average loss: 0.8366289049386978 recon loss: 0.8536: commit_loss -0.1693          avg loss speed: 0.005668796896934447 epochs left: 94.66


Epoch 9/280: 100%|██████████| 50/50 [00:03<00:00, 14.17it/s, commit_loss=-.21, loss=1.36, recon_loss=1.38]


Epoch 9 average loss: 0.8474504554271698 recon loss: 0.8649: commit_loss -0.1743          avg loss speed: -0.00875762760639176


Epoch 10/280: 100%|██████████| 50/50 [00:04<00:00, 11.18it/s, commit_loss=-.273, loss=1.02, recon_loss=1.04]


Epoch 10 average loss: 0.8503154933452606 recon loss: 0.8678: commit_loss -0.1749          avg loss speed: -0.012456614772478614


Epoch 11/280: 100%|██████████| 50/50 [00:03<00:00, 13.55it/s, commit_loss=-.188, loss=1.19, recon_loss=1.21]


Epoch 11 average loss: 0.818656479716301 recon loss: 0.8369: commit_loss -0.1824          avg loss speed: 0.026141804854075135 epochs left: 19.84


Epoch 12/280: 100%|██████████| 50/50 [00:03<00:00, 14.00it/s, commit_loss=-.199, loss=0.94, recon_loss=0.96]


Epoch 12 average loss: 0.7962182116508484 recon loss: 0.8140: commit_loss -0.1776          avg loss speed: 0.04258926451206202 epochs left: 11.65


Epoch 13/280: 100%|██████████| 50/50 [00:03<00:00, 13.88it/s, commit_loss=-.216, loss=0.907, recon_loss=0.929]


Epoch 13 average loss: 0.8037458086013793 recon loss: 0.8226: commit_loss -0.1881          avg loss speed: 0.017984252969423964 epochs left: 28.01


Epoch 14/280: 100%|██████████| 50/50 [00:05<00:00,  9.11it/s, commit_loss=-.264, loss=0.669, recon_loss=0.695]


Epoch 14 average loss: 0.82264264523983 recon loss: 0.8413: commit_loss -0.1868          avg loss speed: -0.016435811916987175


Epoch 15/280: 100%|██████████| 50/50 [00:04<00:00, 10.30it/s, commit_loss=-.34, loss=0.498, recon_loss=0.532]


Epoch 15 average loss: 0.8214196532964706 recon loss: 0.8391: commit_loss -0.1767          avg loss speed: -0.013884098132451284


Epoch 16/280: 100%|██████████| 50/50 [00:04<00:00, 11.73it/s, commit_loss=-.164, loss=0.492, recon_loss=0.509]


Epoch 16 average loss: 0.7915117239952087 recon loss: 0.8106: commit_loss -0.1911          avg loss speed: 0.024424311717351377 epochs left: 20.12


Epoch 17/280: 100%|██████████| 50/50 [00:07<00:00,  6.71it/s, commit_loss=-.186, loss=0.834, recon_loss=0.853]


Epoch 17 average loss: 0.82513925075531 recon loss: 0.8450: commit_loss -0.1982          avg loss speed: -0.013281243244807062


Epoch 18/280: 100%|██████████| 50/50 [00:04<00:00, 11.55it/s, commit_loss=-.161, loss=0.671, recon_loss=0.687]


Epoch 18 average loss: 0.8113855820894241 recon loss: 0.8306: commit_loss -0.1921          avg loss speed: 0.0013046272595722552 epochs left: 391.98


Epoch 19/280: 100%|██████████| 50/50 [00:03<00:00, 13.32it/s, commit_loss=-.369, loss=0.571, recon_loss=0.608]


Epoch 19 average loss: 0.8258328437805176 recon loss: 0.8448: commit_loss -0.1893          avg loss speed: -0.016487324833869832


Epoch 20/280: 100%|██████████| 50/50 [00:04<00:00, 11.12it/s, commit_loss=-.209, loss=0.788, recon_loss=0.809]


Epoch 20 average loss: 0.8238621890544892 recon loss: 0.8418: commit_loss -0.1796          avg loss speed: -0.003076296846071913


Epoch 21/280: 100%|██████████| 50/50 [00:03<00:00, 14.02it/s, commit_loss=-.194, loss=0.848, recon_loss=0.868]


Epoch 21 average loss: 0.8278667843341827 recon loss: 0.8466: commit_loss -0.1874          avg loss speed: -0.007506579359372378


Epoch 22/280: 100%|██████████| 50/50 [00:03<00:00, 14.37it/s, commit_loss=-.23, loss=0.908, recon_loss=0.931]


Epoch 22 average loss: 0.9118725955486298 recon loss: 0.9308: commit_loss -0.1895          avg loss speed: -0.08601865649223328


Epoch 23/280: 100%|██████████| 50/50 [00:03<00:00, 12.59it/s, commit_loss=-.139, loss=0.81, recon_loss=0.824]


Epoch 23 average loss: 0.8489308542013169 recon loss: 0.8639: commit_loss -0.1495          avg loss speed: 0.0056030021111169726 epochs left: 97.97


Epoch 24/280: 100%|██████████| 50/50 [00:04<00:00, 12.05it/s, commit_loss=-.14, loss=0.81, recon_loss=0.824]


Epoch 24 average loss: 0.820583564043045 recon loss: 0.8386: commit_loss -0.1805          avg loss speed: 0.042306513984998184 epochs left: 12.31


Epoch 25/280: 100%|██████████| 50/50 [00:03<00:00, 14.44it/s, commit_loss=-.103, loss=0.871, recon_loss=0.881]


Epoch 25 average loss: 0.7937628495693206 recon loss: 0.8117: commit_loss -0.1797          avg loss speed: 0.06669948836167672 epochs left: 7.40


Epoch 26/280: 100%|██████████| 50/50 [00:03<00:00, 14.03it/s, commit_loss=-.12, loss=0.802, recon_loss=0.814]


Epoch 26 average loss: 0.8060168963670731 recon loss: 0.8235: commit_loss -0.1745          avg loss speed: 0.015075526237487757 epochs left: 33.57


Epoch 27/280: 100%|██████████| 50/50 [00:04<00:00, 12.22it/s, commit_loss=-.161, loss=0.966, recon_loss=0.982]


Epoch 27 average loss: 0.7903568947315216 recon loss: 0.8086: commit_loss -0.1826          avg loss speed: 0.016430875261624567 epochs left: 29.84


Epoch 28/280: 100%|██████████| 50/50 [00:04<00:00, 12.49it/s, commit_loss=-.333, loss=0.41, recon_loss=0.443]


Epoch 28 average loss: 0.7615941172838211 recon loss: 0.7793: commit_loss -0.1768          avg loss speed: 0.03511809627215068 epochs left: 13.14


Epoch 29/280: 100%|██████████| 50/50 [00:03<00:00, 14.16it/s, commit_loss=-.251, loss=0.716, recon_loss=0.741]


Epoch 29 average loss: 0.677488683462143 recon loss: 0.6975: commit_loss -0.2005          avg loss speed: 0.10850061933199562 epochs left: 3.48


Epoch 30/280: 100%|██████████| 50/50 [00:03<00:00, 14.31it/s, commit_loss=-.13, loss=0.589, recon_loss=0.602]


Epoch 30 average loss: 0.6526379525661469 recon loss: 0.6702: commit_loss -0.1757          avg loss speed: 0.09050861259301501 epochs left: 3.90


Epoch 31/280: 100%|██████████| 50/50 [00:04<00:00, 11.67it/s, commit_loss=-.23, loss=0.956, recon_loss=0.979]


Epoch 31 average loss: 0.5932538884878159 recon loss: 0.6113: commit_loss -0.1809          avg loss speed: 0.10398636261622107 epochs left: 2.82


Epoch 32/280: 100%|██████████| 50/50 [00:03<00:00, 12.91it/s, commit_loss=-.135, loss=0.572, recon_loss=0.585]


Epoch 32 average loss: 0.5874183249473571 recon loss: 0.6061: commit_loss -0.1863          avg loss speed: 0.053708516558011565 epochs left: 5.35


Epoch 33/280: 100%|██████████| 50/50 [00:03<00:00, 12.52it/s, commit_loss=-.0946, loss=0.597, recon_loss=0.606]


Epoch 33 average loss: 0.5816872483491897 recon loss: 0.6007: commit_loss -0.1900          avg loss speed: 0.02941614031791684 epochs left: 9.58


Epoch 34/280: 100%|██████████| 50/50 [00:03<00:00, 14.06it/s, commit_loss=-.207, loss=0.604, recon_loss=0.625]


Epoch 34 average loss: 0.5519315600395203 recon loss: 0.5715: commit_loss -0.1954          avg loss speed: 0.035521593888600655 epochs left: 7.09


Epoch 35/280: 100%|██████████| 50/50 [00:04<00:00, 10.96it/s, commit_loss=-.159, loss=0.656, recon_loss=0.672]


Epoch 35 average loss: 0.5112252014875412 recon loss: 0.5296: commit_loss -0.1842          avg loss speed: 0.06245384295781442 epochs left: 3.38


Epoch 36/280: 100%|██████████| 50/50 [00:03<00:00, 14.10it/s, commit_loss=-.238, loss=0.553, recon_loss=0.577]


Epoch 36 average loss: 0.48497660756111144 recon loss: 0.5038: commit_loss -0.1884          avg loss speed: 0.06330472906430568 epochs left: 2.92


Epoch 37/280: 100%|██████████| 50/50 [00:03<00:00, 13.97it/s, commit_loss=-.209, loss=0.404, recon_loss=0.425]


Epoch 37 average loss: 0.47614225149154665 recon loss: 0.4955: commit_loss -0.1935          avg loss speed: 0.03990220487117768 epochs left: 4.41


Epoch 38/280: 100%|██████████| 50/50 [00:03<00:00, 12.94it/s, commit_loss=-.148, loss=0.568, recon_loss=0.583]


Epoch 38 average loss: 0.446226099729538 recon loss: 0.4653: commit_loss -0.1905          avg loss speed: 0.04455525378386177 epochs left: 3.28


Epoch 39/280: 100%|██████████| 50/50 [00:04<00:00, 11.44it/s, commit_loss=-.247, loss=0.365, recon_loss=0.39]


Epoch 39 average loss: 0.4129255497455597 recon loss: 0.4316: commit_loss -0.1866          avg loss speed: 0.05618943651517233 epochs left: 2.01


Epoch 40/280: 100%|██████████| 50/50 [00:03<00:00, 14.19it/s, commit_loss=-.36, loss=0.303, recon_loss=0.339]


Epoch 40 average loss: 0.4071366423368454 recon loss: 0.4271: commit_loss -0.1998          avg loss speed: 0.037961324652036055 epochs left: 2.82


Epoch 41/280: 100%|██████████| 50/50 [00:03<00:00, 13.85it/s, commit_loss=-.122, loss=0.339, recon_loss=0.352]


Epoch 41 average loss: 0.38592670261859896 recon loss: 0.4061: commit_loss -0.2020          avg loss speed: 0.03616939465204877 epochs left: 2.38


Epoch 42/280: 100%|██████████| 50/50 [00:04<00:00, 12.25it/s, commit_loss=-.198, loss=0.421, recon_loss=0.44]


Epoch 42 average loss: 0.36275006413459776 recon loss: 0.3824: commit_loss -0.1963          avg loss speed: 0.03924623409907024 epochs left: 1.60


Epoch 43/280: 100%|██████████| 50/50 [00:04<00:00, 11.98it/s, commit_loss=-.274, loss=0.307, recon_loss=0.335]


Epoch 43 average loss: 0.3377884659171104 recon loss: 0.3586: commit_loss -0.2084          avg loss speed: 0.04748267044623694 epochs left: 0.80


Epoch 44/280: 100%|██████████| 50/50 [00:03<00:00, 14.32it/s, commit_loss=-.181, loss=0.416, recon_loss=0.434]


Epoch 44 average loss: 0.3593804454803467 recon loss: 0.3813: commit_loss -0.2196          avg loss speed: 0.002774632076422312 epochs left: 21.40


Epoch 45/280: 100%|██████████| 50/50 [00:03<00:00, 13.94it/s, commit_loss=-.0459, loss=0.442, recon_loss=0.446]


Epoch 45 average loss: 0.5032330232858658 recon loss: 0.5228: commit_loss -0.1955          avg loss speed: -0.14992669810851422


Epoch 46/280: 100%|██████████| 50/50 [00:04<00:00, 11.56it/s, commit_loss=-.319, loss=0.331, recon_loss=0.363]


Epoch 46 average loss: 0.38060089349746706 recon loss: 0.3990: commit_loss -0.1845          avg loss speed: 0.019533084730307293 epochs left: 4.13


Epoch 47/280: 100%|██████████| 50/50 [00:03<00:00, 13.04it/s, commit_loss=-.199, loss=0.331, recon_loss=0.351]


Epoch 47 average loss: 0.3384891849756241 recon loss: 0.3576: commit_loss -0.1914          avg loss speed: 0.07591560244560241 epochs left: 0.51


Epoch 48/280: 100%|██████████| 50/50 [00:03<00:00, 13.83it/s, commit_loss=-.211, loss=0.31, recon_loss=0.331]


Epoch 48 average loss: 0.3199825304746628 recon loss: 0.3384: commit_loss -0.1838          avg loss speed: 0.08745850344498957 epochs left: 0.23


Epoch 49/280: 100%|██████████| 50/50 [00:03<00:00, 14.19it/s, commit_loss=-.0946, loss=0.306, recon_loss=0.316]


Epoch 49 average loss: 0.307504768371582 recon loss: 0.3282: commit_loss -0.2074          avg loss speed: 0.03885276794433595 epochs left: 0.19


Epoch 50/280: 100%|██████████| 50/50 [00:04<00:00, 11.02it/s, commit_loss=-.124, loss=0.307, recon_loss=0.319]


Epoch 50 average loss: 0.30487499475479124 recon loss: 0.3261: commit_loss -0.2125          avg loss speed: 0.01711716651916506 epochs left: 0.28


Epoch 51/280: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s, commit_loss=-.221, loss=0.298, recon_loss=0.32]


Epoch 51 average loss: 0.2999150338768959 recon loss: 0.3219: commit_loss -0.2199          avg loss speed: 0.010872397323449468 epochs left: 0.01


Epoch 52/280:  44%|████▍     | 22/50 [00:02<00:02, 11.42it/s, commit_loss=-.13, loss=0.298, recon_loss=0.311] 

In [None]:
import gc
torch.cuda.empty_cache()
gc.collect()

max_length =  max(len(d["faces"]) for d in dataset if "faces" in d)
max_seq = max_length * 6
print("Highest face count:" , max_length)
print("Max token sequence:" , max_seq)

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    coarse_pre_gateloop_depth = 6, # Better performance using more gateloop layers
    fine_pre_gateloop_depth= 4,
    #attn_depth = 24, # GPT-2 medium have 24 layer depth, change if needed
    max_seq_len = max_seq,
    condition_on_text = True,
    gateloop_use_heinsen = False,
    text_condition_model_types = "bge", ## Change or remove this line if you are using:  https://github.com/MarcusLoppe/classifier-free-guidance-pytorch
    text_condition_cond_drop_prob = 0.0
)

total_params = sum(p.numel() for p in transformer.decoder.parameters())
total_params = f"{total_params / 1000000:.1f}M"
print(f"Decoder total parameters: {total_params}")
total_params = sum(p.numel() for p in transformer.parameters())
total_params = f"{total_params / 1000000:.1f}M"
print(f"Total parameters: {total_params}")

## **Required!**, embed the text and run generate_codes to save 4-96 GB VRAM (dependant on dataset) ##

**If you don't;** <br>
During each during each training step the autoencoder will generate the codes and the text encoder will embed the text.
<br>
After these fields are generate: **they will be deleted and next time it generates the code again:**<br>

This is due to the dataloaders nature, it writes this information to a temporary COPY of the dataset


In [None]:
labels = set(item["texts"] for item in dataset.data)
print(labels)
dataset.embed_texts(transformer)
dataset.generate_codes(autoencoder)
print(dataset.data[0].keys())

*Load previous saved model if you had to restart session*

In [None]:
trainer = MeshTransformerTrainer(model = transformer,warmup_steps = 10,grad_accum_every=1,num_train_steps=100, dataset = dataset, learning_rate = 1e-1, batch_size=2)
trainer.load(f'{working_dir}/mesh-transformer_{project_name}.pt')
transformer = trainer.model

**Train to about 0.0001 loss (or less) if you are using a small dataset**

In [None]:
trainer = MeshTransformerTrainer(model = transformer,warmup_steps = 10,grad_accum_every=4,num_train_steps=100, dataset = dataset,
                                 learning_rate = 1e-3, batch_size=2)
loss = trainer.train(100, stop_at_loss = 0.009)

trainer = MeshTransformerTrainer(model = transformer,warmup_steps = 10,grad_accum_every=4,num_train_steps=100, dataset = dataset,
                                 learning_rate = 5e-4, batch_size=2)
loss = trainer.train(200, stop_at_loss = 0.00001)

In [None]:

trainer.save(f'{working_dir}/mesh-transformer_{project_name}.pt')

## Generate and view mesh

In [None]:
def combind_mesh(path, mesh):
    all_vertices = []
    all_faces = []
    vertex_offset = 0
    translation_distance = 0.5

    for r, faces_coordinates in enumerate(mesh):
        numpy_data = faces_coordinates[0].cpu().numpy().reshape(-1, 3)

        for vertex in numpy_data:
            all_vertices.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}\n")

        for i in range(1, len(numpy_data), 3):
            all_faces.append(f"f {i + vertex_offset} {i + 1 + vertex_offset} {i + 2 + vertex_offset}\n")

        vertex_offset += len(numpy_data)

    obj_file_content = "".join(all_vertices) + "".join(all_faces)

    with open(path , "w") as file:
        file.write(obj_file_content)

def combind_mesh_with_rows(path, meshes):
    all_vertices = []
    all_faces = []
    vertex_offset = 0
    translation_distance = 0.5

    for row, mesh in enumerate(meshes):
        for r, faces_coordinates in enumerate(mesh):
            numpy_data = faces_coordinates[0].cpu().numpy().reshape(-1, 3)
            numpy_data[:, 0] += translation_distance * (r / 0.2 - 1)
            numpy_data[:, 2] += translation_distance * (row / 0.2 - 1)

            for vertex in numpy_data:
                all_vertices.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}\n")

            for i in range(1, len(numpy_data), 3):
                all_faces.append(f"f {i + vertex_offset} {i + 1 + vertex_offset} {i + 2 + vertex_offset}\n")

            vertex_offset += len(numpy_data)

        obj_file_content = "".join(all_vertices) + "".join(all_faces)

    with open(path , "w") as file:
        file.write(obj_file_content)


def write_mesh_output(path, coords):
    numpy_data = faces_coordinates[0].cpu().numpy().reshape(-1, 3)
    obj_file_content = ""

    for vertex in numpy_data:
        obj_file_content += f"v {vertex[0]} {vertex[1]} {vertex[2]}\n"

    for i in range(1, len(numpy_data), 3):
        obj_file_content += f"f {i} {i + 1} {i + 2}\n"

    with open(path, "w") as file:
        file.write(obj_file_content)


**Using only text**

In [None]:

from pathlib import Path

folder = working_dir / 'renders'
obj_file_path = Path(folder)
obj_file_path.mkdir(exist_ok = True, parents = True)

text_coords = []
for text in labels:
    print(f"Generating {text}")
    faces_coordinates = transformer.generate(texts = [text],  temperature = 0.0)
    text_coords.append(faces_coordinates)

    write_mesh_output(f'{folder}/3d_output_{text}.obj', faces_coordinates)


combind_mesh(f'{folder}/3d_models_all.obj', text_coords)

**Text + prompt of tokens**

Grab fresh copy of dataset

In [None]:
dataset = MeshDataset.load(dataset_path)
dataset.generate_codes(autoencoder)

**Prompt with 10% of codes/tokens**

In [None]:
from pathlib import Path
token_length_procent = 0.10
codes = []
texts = []
for label in labels:
    for item in dataset.data:
        if item['texts'] == label:
            num_tokens = int(item["codes"].shape[0] * token_length_procent)

            texts.append(item['texts'])
            codes.append(item["codes"].flatten()[:num_tokens].unsqueeze(0))
            break

folder = working_dir / f'renders/text+codes'
obj_file_path = Path(folder)
obj_file_path.mkdir(exist_ok = True, parents = True)

coords = []



for text, prompt in zip(texts, codes):
    print(f"Generating {text} with {prompt.shape[1]} tokens")
    faces_coordinates = transformer.generate(texts = [text],  prompt = prompt, temperature = 0)
    coords.append(faces_coordinates)

    obj_file_path = f'{folder}/{text}_{prompt.shape[1]}_tokens.obj'
    write_mesh_output(obj_file_path, faces_coordinates)

    print(obj_file_path)


combind_mesh(f'{folder}/text+prompt_all.obj', coords)

if text_coords is not None:
    combind_mesh_with_rows(f'{folder}/both_verisons.obj', [text_coords , coords])

**Prompt with 0% to 80% of tokens**

In [None]:
from pathlib import Path

folder = working_dir / f'renders/text+codes_rows'
obj_file_path = Path(folder)
obj_file_path.mkdir(exist_ok = True, parents = True)

mesh_rows = []
for token_length_procent in np.arange(0, 0.8, 0.1):
    codes = []
    texts = []
    for label in labels:
        for item in dataset.data:
            if item['texts'] == label:
                num_tokens = int(item["codes"].shape[0] * token_length_procent)

                texts.append(item['texts'])
                codes.append(item["codes"].flatten()[:num_tokens].unsqueeze(0))
                break

    coords = []
    for text, prompt in zip(texts, codes):

        print(f"Generating {text} with {prompt.shape[1]} tokens")
        faces_coordinates = transformer.generate(texts = [text],  prompt = prompt, temperature = 0)
        coords.append(faces_coordinates)

        obj_file_path = f'{folder}/{text}_{prompt.shape[1]}_tokens.obj'
        write_mesh_output(obj_file_path, coords)
        print(obj_file_path)


    mesh_rows.append(coords)
    combind_mesh(f'{folder}/text+prompt_all_{token_length_procent}.obj', coords)

combind_mesh_with_rows(f'{folder}/all.obj', mesh_rows)


**Just some testing for text embedding similarity**

In [None]:
import numpy as np
texts = list(labels)
vectors = [transformer.conditioner.text_models[0].embed_text([text], return_text_encodings = False).cpu().flatten() for text in texts]

max_label_length = max(len(text) for text in texts)

# Print the table header
print(f"{'Text':<{max_label_length}} |", end=" ")
for text in texts:
    print(f"{text:<{max_label_length}} |", end=" ")
print()

# Print the similarity matrix as a table with fixed-length columns
for i in range(len(texts)):
    print(f"{texts[i]:<{max_label_length}} |", end=" ")
    for j in range(len(texts)):
        # Encode the texts and calculate cosine similarity manually
        vector_i = vectors[i]
        vector_j = vectors[j]

        dot_product = torch.sum(vector_i * vector_j)
        norm_vector1 = torch.norm(vector_i)
        norm_vector2 = torch.norm(vector_j)
        similarity_score = dot_product / (norm_vector1 * norm_vector2)

        # Print with fixed-length columns
        print(f"{similarity_score.item():<{max_label_length}.4f} |", end=" ")
    print()