In [None]:
# NOTE(review): both installs are unpinned — the git install pulls whatever the
# repo's default branch points to at install time — so re-running this cell may
# install different code. Consider pinning a commit (`...git@<sha>`) / version.
%pip install git+https://github.com/MarcusLoppe/meshgpt-pytorch.git
%pip install matplotlib
from pathlib import Path 
import gc    
import torch
import os
import torch  
from meshgpt_pytorch import (
    MeshTransformerTrainer,
    MeshAutoencoderTrainer,
    MeshAutoencoder,
    MeshTransformer,MeshDataset
)
from meshgpt_pytorch.data import ( 
    derive_face_edges_from_faces
)   



In [None]:

# Use the GPU when one is available and fall back to CPU otherwise, instead of
# hard-coding "cuda" (which makes torch.load / .to() fail on CPU-only machines).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained autoencoder checkpoint; its weights are read from the "model" key.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
AUTOENCODER_CKPT = "./16k_autoencoder_229M_0.338.pt"
pkg = torch.load(AUTOENCODER_CKPT, map_location=DEVICE)

# These constructor arguments must describe the same architecture the checkpoint
# was saved from; load_state_dict (strict by default) raises on missing,
# unexpected, or shape-mismatched keys.
autoencoder = MeshAutoencoder(
    decoder_dims_through_depth=(128,) * 6 + (192,) * 12 + (256,) * 24 + (384,) * 6,
    dim_codebook=192,
    dim_area_embed=16,
    dim_coor_embed=16,
    dim_normal_embed=16,
    dim_angle_embed=8,
    attn_decoder_depth=4,
    attn_encoder_depth=2,
).to(DEVICE)
autoencoder.load_state_dict(pkg["model"])

# Text-conditioned transformer built on top of `autoencoder`
# (presumably used to tokenize meshes — see meshgpt_pytorch docs).
transformer = MeshTransformer(
    autoencoder,
    dim=768,
    coarse_pre_gateloop_depth=2,
    fine_pre_gateloop_depth=2,
    attn_depth=12,
    attn_heads=12,
    fine_cross_attend_text=True,
    text_cond_with_film=False,
    cross_attn_num_mem_kv=4,
    num_sos_tokens=1,
    dropout=0.0,
    max_seq_len=1500,
    fine_attn_depth=2,
    condition_on_text=True,
    gateloop_use_heinsen=False,
    text_condition_model_types="bge",
    text_condition_cond_drop_prob=0.0,
).to(DEVICE)

In [None]:
# Move into the project directory so the dataset .npz files below resolve as
# relative paths. Set the TEXT_TO_MESH_DIR environment variable to override the
# user-specific default instead of editing this cell.
# NOTE(review): in cell order, the autoencoder checkpoint above is loaded
# relative to the *previous* working directory, before this change.
PROJECT_DIR = os.environ.get(
    "TEXT_TO_MESH_DIR", "C:/Users/ernest.lee/OneDrive/Desktop/text-to-mesh"
)
os.chdir(PROJECT_DIR)
print(os.getcwd())

In [10]:
import torch.optim as optim
from torch.optim import lr_scheduler

# Dataset files to train on, resolved relative to the current working directory.
datasets = [
    "labels_885_10x5_21720_mod.npz",
    # "objverse_250f_45.9M_3086_labels_53730_10_min_x1_aug.npz",
    # "objverse_250f_229.7M_3086_labels_268650_10_min_x5_aug.npz"
]

# Fail fast with a clear message rather than an opaque AttributeError on None
# further down (easy to hit when every entry above is commented out).
if not datasets:
    raise ValueError("`datasets` is empty: list at least one .npz file to load.")

# The first file becomes the base dataset; entries of the remaining files are
# appended to its `.data` list.
dataset = MeshDataset.load(datasets[0])
for extra_path in datasets[1:]:
    dataset.data.extend(MeshDataset.load(extra_path).data)

dataset.sort_dataset_keys()
batch_size = 16
grad_accum_every = 4

# Initial learning rate handed to MeshTransformerTrainer.
# NOTE(review): 0.5 is unusually high if the trainer uses an Adam-style
# optimizer internally — confirm this value is intended.
rate = 0.5

EARLY_STOP_LOSS = 0.00005
MAX_EPOCHS = 35


def make_trainer(learning_rate):
    """Build a MeshTransformerTrainer for `transformer` on `dataset` with the given LR."""
    return MeshTransformerTrainer(
        model=transformer,
        warmup_steps=10,
        grad_accum_every=grad_accum_every,
        num_train_steps=100,
        dataset=dataset,
        batch_size=batch_size,
        learning_rate=learning_rate,
        checkpoint_every_epoch=1,
    )


# ReduceLROnPlateau needs an optimizer whose param_groups hold the LR it adjusts.
# This SGD instance is never stepped; it only carries the LR value. The trainer
# owns the optimizer that actually updates the weights.
lr_holder = optim.SGD(transformer.parameters(), lr=rate)
# Created ONCE, outside the loop, so it accumulates its count of non-improving
# epochs across iterations. A scheduler rebuilt every epoch is stepped only
# once, never reaches `patience`, and therefore never lowers the LR.
# factor/patience spelled out (they equal PyTorch's defaults).
scheduler = lr_scheduler.ReduceLROnPlateau(lr_holder, mode="min", factor=0.1, patience=10)

trainer = make_trainer(rate)
for _ in range(MAX_EPOCHS):
    print(f"Learning rate is {rate}")
    loss = trainer.train(1, stop_at_loss=EARLY_STOP_LOSS)
    if loss <= EARLY_STOP_LOSS:
        break
    scheduler.step(loss)
    new_rate = lr_holder.param_groups[0]["lr"]
    if new_rate != rate:
        rate = new_rate
        # Rebuild only when the LR changed: a new trainer presumably creates a
        # fresh optimizer and restarts warmup, so avoid that on every epoch.
        trainer = make_trainer(rate)



In [None]:
# Persist the trained transformer weights under a "model" key — the same key
# the autoencoder checkpoint is read from when loading.
TRANSFORMER_CKPT = "./MeshGPT-transformer_trained.pt"
pkg = {"model": transformer.state_dict()}
torch.save(pkg, TRANSFORMER_CKPT)