In [1]:
import torch
from pytorch_lightning import Trainer, seed_everything
from torch.utils.data import DataLoader

from smotdm.data.text_motion_dataset import TextMotionDataset
from smotdm.models.smotdm import SMOTDM
from smotdm.renderer.matplotlib import SingleMotionRenderer
from smotdm.rifke import feats_to_joints
from smotdm.data.collate import length_to_mask
from smotdm.models.text_encoder import TextToEmb

In [2]:
# Location of the preprocessed text–motion dataset file.
DATASET_PATH = "deps/interx/processed/dataset.h5"

text_motion_dataset = TextMotionDataset(DATASET_PATH)

# Shuffled mini-batches collated by the dataset's own collate function.
# `persistent_workers` keeps the worker processes alive between epochs instead
# of re-spawning them at the start of every epoch.
loader_options = dict(
    batch_size=32,
    shuffle=True,
    num_workers=7,
    persistent_workers=True,
)
train_dataloader = DataLoader(
    text_motion_dataset,
    collate_fn=text_motion_dataset.collate_fn,
    **loader_options,
)

In [3]:
# Train the text-to-motion VAE.
# Seed every RNG *before* building the model so that weight initialisation,
# the DataLoader's shuffling and its worker processes are reproducible on a
# fresh-kernel Restart & Run All.
SEED = 42
MAX_EPOCHS = 100

seed_everything(SEED, workers=True)

model = SMOTDM(
    vae=True,
)

# NOTE: this is the expensive cell of the notebook (MAX_EPOCHS full passes).
trainer = Trainer(max_epochs=MAX_EPOCHS)

trainer.fit(model, train_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Generate a motion from a free-text prompt with the model trained above.
# To use a saved run instead of the in-memory model, load it first, e.g.
#   model = SMOTDM.load_from_checkpoint("<path/to/checkpoint.ckpt>", vae=True)
PROMPT = (
    "Two people walk towards each other. "
    "After they meet, the first person hugs the second person around the shoulders, "
    "gently patting his/her back with his/her right hand. "
    "Meanwhile, the second person puts his/her arms around the first person's waist "
    "and pats his/her waist with his/her right hand."
)

# Number of frames to generate. The mask handed to the motion decoder defines
# the output sequence length, so it must describe the *motion*. Reusing the
# text-token mask would tie the motion's duration to the prompt's token count.
# NOTE(review): 100 is a placeholder; choose a length suited to the dataset's frame rate.
MOTION_LENGTH = 100

# The module may still be in training mode after `trainer.fit`. Switch to eval
# so that layers such as dropout do not make generation stochastic.
model.eval()

with torch.no_grad():
    text_embeds = TextToEmb("distilbert/distilbert-base-uncased", device=model.device)(
        [PROMPT]
    )

    # Padding mask over the prompt's tokens; consumed only by the text encoder.
    text_mask = length_to_mask(text_embeds["length"], device=model.device)
    encoded = model.text_encoder(
        {
            "x": text_embeds["x"],
            "mask": text_mask,
        }
    )

    # The VAE text encoder's output is split along dim 1 into (mu, logvar).
    # Decode from the mean rather than sampling, for a deterministic result.
    mu, _logvar = encoded.unbind(1)

    # One sequence (batch size 1, matching the single prompt) of MOTION_LENGTH frames.
    motion_mask = length_to_mask(
        torch.tensor([MOTION_LENGTH], device=model.device), device=model.device
    )
    motion = model.motion_decoder(
        {
            "z": mu,
            "mask": motion_mask,
        }
    ).squeeze(dim=0)

In [None]:
# Undo the dataset's feature normalisation (via `reverse_norm`), convert the
# features to joint positions and animate the result in a single colour.
renderer = SingleMotionRenderer(colors=("red",) * 5)

generated_feats = motion.detach().cpu().numpy()
mot = text_motion_dataset.reverse_norm(generated_feats)
joints = feats_to_joints(torch.tensor(mot))

renderer.render_animation_single(joints.detach().cpu().numpy())