In [6]:
import functools

import torch
from transformers import BertTokenizerFast, BertModel


# Hugging Face hub id of the sentence encoder used to condition generation.
_LEALLA_CHECKPOINT = "setu4993/LEALLA-small"


@functools.lru_cache(maxsize=None)
def _load_lealla(device):
    """Load the LEALLA tokenizer and encoder once per device and cache them.

    Loading via ``from_pretrained`` is expensive, so it is done on first use
    for a given ``device`` and reused afterwards. The cached encoder stays
    resident in memory for the lifetime of the kernel.
    """
    tokenizer = BertTokenizerFast.from_pretrained(_LEALLA_CHECKPOINT)
    encoder = BertModel.from_pretrained(_LEALLA_CHECKPOINT).to(device).eval()
    return tokenizer, encoder


def tokenize_sentences(sentence, device):
    """Embed a single sentence with LEALLA-small.

    Despite the name, this returns a sentence *embedding* (the encoder's
    ``pooler_output``), not token ids.

    Parameters
    ----------
    sentence : str
        Text to embed.
    device : str or torch.device
        Device the encoder runs on (e.g. ``'mps'``, ``'cuda'``, ``'cpu'``).
        Must be hashable, since it is the cache key for the loaded encoder.

    Returns
    -------
    numpy.ndarray
        The pooled embedding of ``sentence`` (row 0 of the batch of one).
    """
    tokenizer, encoder = _load_lealla(device)
    # Inputs longer than 512 tokens are truncated to the encoder's max length.
    inputs = tokenizer(
        [sentence], return_tensors="pt", padding=True, max_length=512, truncation=True
    ).to(device)
    with torch.no_grad():
        pooled = encoder(**inputs).pooler_output
    return pooled.cpu().numpy()[0]


In [65]:
import os

import torch
from models.vae_gaussian import *
from models.vae_flow import *

# --- Sampling configuration -------------------------------------------------
ckpt_path = './pretrained/conditioned/GEN_2023_11_12__12_10_24/ckpt_0.000000_46500.pt'
# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# notebook still runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
sample_num_points = 3000  # points per generated cloud
batch_size = 2            # clouds sampled per generate_conditioned call

# --- Load checkpoint and model ----------------------------------------------
print('Loading model')
# The checkpoint holds a pickled Python object under 'args' (presumably an
# argparse.Namespace), so it must be unpickled: only load trusted files.
# weights_only=False is explicit because PyTorch >= 2.6 defaults to True,
# which would likely refuse to unpickle that object.
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

model_type = ckpt['args'].model
if model_type == 'gaussian':
    model = GaussianVAE(ckpt['args']).to(device)
elif model_type == 'flow':
    model = FlowVAE(ckpt['args']).to(device)
else:
    # Without this, `model` would be unbound (or stale from an earlier run)
    # and fail much later with a confusing error.
    raise ValueError(f"Unsupported model type in checkpoint: {model_type!r}")

print('Loading state dict')
model.load_state_dict(ckpt['state_dict'])
# Inference only: put any dropout / batch-norm layers into evaluation mode.
model.eval()

def generate_conditioned(text):
    """Sample one point cloud conditioned on a text prompt.

    The prompt is embedded with ``tokenize_sentences``. The embedding is
    repeated ``batch_size`` times and passed with random latents to
    ``model.sample``.

    Relies on notebook globals: ``model``, ``ckpt``, ``device``,
    ``batch_size`` and ``sample_num_points``.

    Parameters
    ----------
    text : str
        Conditioning prompt, e.g. ``"Table"``.

    Returns
    -------
    torch.Tensor
        The first of the ``batch_size`` sampled clouds, moved to the CPU.
    """
    print('Generating 3D model')
    args = ckpt['args']
    with torch.no_grad():
        embedding = torch.as_tensor(tokenize_sentences(text, device), device=device)
        # One copy of the embedding per batch row. This is the same tiling the
        # previous np.resize produced, but it no longer depends on `np`
        # leaking in from a star import.
        condition = embedding.unsqueeze(0).repeat(batch_size, 1)

        # Latents are drawn on the CPU and then moved, as before.
        z = torch.randn([batch_size, args.latent_dim]).to(device)
        samples = model.sample(z, condition, sample_num_points, flexibility=args.flexibility)

    return samples.cpu()[0]

Loading model
Loading state dict


In [68]:
import plotly.graph_objects as go
from matplotlib import pyplot as plt

# Generate a point cloud for the prompt "Table" and show it as an
# interactive 3D scatter plot.
coordinates = generate_conditioned("Table")

# One marker per generated point; columns 0/1/2 are plotted as x/y/z.
point_trace = go.Scatter3d(
    x=coordinates[:, 0],
    y=coordinates[:, 1],
    z=coordinates[:, 2],
    mode='markers',
    marker={'size': 5, 'color': 'blue'},
)

fig = go.Figure(data=[point_trace])
fig.update_layout(
    title='Interactive 3D Point Cloud Visualization',
    scene={
        'xaxis_title': 'X-axis',
        'yaxis_title': 'Y-axis',
        'zaxis_title': 'Z-axis',
    },
)
fig.show()

Generating 3D model
