In [1]:
import torch
from pytorch_lightning import Trainer, seed_everything
from torch.utils.data import DataLoader

from smotdm.data.collate import length_to_mask
from smotdm.data.text_motion_dataset import TextMotionDataset
from smotdm.models.smotdm import SMOTDM
from smotdm.models.text_encoder import TextToEmb
from smotdm.renderer.matplotlib import SingleMotionRenderer
from smotdm.rifke import feats_to_joints

In [2]:
# Processed text/motion dataset file consumed by TextMotionDataset.
DATASET_PATH = "deps/interx/processed/dataset.h5"

text_motion_dataset = TextMotionDataset(DATASET_PATH)

# DataLoader settings for training; the dataset supplies its own collate_fn.
train_loader_options = {
    "batch_size": 1,
    "shuffle": True,
    "num_workers": 7,
    # Keep worker processes alive after each pass over the dataset instead of
    # re-spawning them every epoch.
    "persistent_workers": True,
}
train_dataloader = DataLoader(
    text_motion_dataset,
    collate_fn=text_motion_dataset.collate_fn,
    **train_loader_options,
)

# To inspect one collated batch (the output below came from this):
# next(iter(train_dataloader))

{'reactor_x_dict': {'x': tensor([[[ 1.8316e+01, -2.0167e-03,  3.4866e-02,  ..., -4.4550e-03,
            -2.1530e+00, -5.1591e-01],
           [ 1.8319e+01, -1.9892e-03,  3.4858e-02,  ..., -1.3289e+00,
             8.7408e-01,  4.9365e+00],
           [ 1.8320e+01, -1.9674e-03,  3.4849e-02,  ..., -2.5023e+00,
            -3.8664e+00, -1.7030e+00],
           ...,
           [ 2.3589e+01, -6.0322e-03,  2.2960e-02,  ...,  4.3576e+00,
             3.0433e+00, -7.5178e-01],
           [ 2.3590e+01, -6.0837e-03,  2.2961e-02,  ...,  5.3177e+00,
             2.6497e+00, -3.0327e+00],
           [ 2.3592e+01, -6.1257e-03,  2.2961e-02,  ...,  6.1875e+00,
             2.2193e+00, -5.1982e+00]]]),
  'length': [600],
  'mask': tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True, True, True, True,
           True, True, True,

In [3]:
# Seed Python/NumPy/torch RNGs before building the model so weight
# initialisation and batch shuffling are reproducible across Restart & Run All.
# workers=True also derives per-worker seeds for the multi-worker DataLoader.
SEED = 42
seed_everything(SEED, workers=True)

model = SMOTDM(text_motion_dataset)

MAX_EPOCHS = 1
trainer = Trainer(max_epochs=MAX_EPOCHS)

# NOTE(review): only a training loader is passed, so Lightning skips the
# model's validation_step entirely — pass a val_dataloaders= argument to use it.
trainer.fit(model, train_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/cogniveon/src/uos/smotdm/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Users/cogniveon/src/uos/smotdm/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.

  | Name                   | Type              | Params | Mode 
----------------------------------------------------------------

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [5]:
# Generate a motion from a text prompt with the model trained above and render it.
# To skip training, load a saved checkpoint into `model` instead, e.g.
# SMOTDM.load_from_checkpoint("lightning_logs/version_21/checkpoints/epoch=11-step=1200.ckpt").

TEXT_MODEL_NAME = "distilbert/distilbert-base-uncased"
# Number of frames to generate. The decoder's output length is set by the mask
# it receives, so this must be a *motion* length, independent of the prompt.
# NOTE(review): 300 is a placeholder — choose it from the dataset's frame rate /
# typical clip length (the sample batch printed earlier has length 600).
MOTION_LENGTH = 300
PROMPT = (
    "Two people walk towards each other. "
    "After they meet, the first person hugs the second person around the "
    "shoulders, gently patting his/her back with his/her right hand. "
    "Meanwhile, the second person puts his/her arms around the first "
    "person's waist and pats his/her waist with his/her right hand."
)
RENDER_COLORS = ("red",) * 5

# Inference only: disable train-time behaviour (e.g. dropout) and autograd.
model.eval()
with torch.no_grad():
    text_embeds = TextToEmb(TEXT_MODEL_NAME, device=model.device)([PROMPT])

    # Mask over text tokens — only meaningful for the text encoder.
    text_mask = length_to_mask(text_embeds["length"], device=model.device)
    encoded = model.text_encoder({"x": text_embeds["x"], "mask": text_mask})

    # encoded unbinds along dim 1 into (mu, logvar); use the mean as the latent.
    mu, _logvar = encoded.unbind(1)

    # Mask over output frames — determines how many frames the decoder produces.
    motion_mask = length_to_mask([MOTION_LENGTH], device=model.device)
    motion_feats = model.motion_decoder({"z": mu, "mask": motion_mask}).squeeze(dim=0)
    # reverse_norm's result is passed to torch.from_numpy below, i.e. a NumPy array.
    motion = text_motion_dataset.reverse_norm(motion_feats)

    joints = feats_to_joints(torch.from_numpy(motion)).cpu().numpy()

renderer = SingleMotionRenderer(colors=RENDER_COLORS)
renderer.render_animation_single(joints)