In [1]:
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from smort.data.text_motion_dataset import TextMotionDataset
from smort.models.smort import SMORT
from smort.renderer.matplotlib import SingleMotionRenderer
from smort.rifke import feats_to_joints
from smort.data.collate import length_to_mask
from smort.models.text_encoder import TextToEmb

In [2]:
from smort.data.data_module import InterXDataModule


# Single source of truth for the processed dataset file, so the standalone
# dataset and the Lightning data module below are always built from the same
# file instead of two independently edited string literals.
DATASET_PATH = "deps/interx/processed/dataset_2k.h5"

# Standalone dataset object. Other cells use it for the normalisation
# statistics (`get_mean_std`) and for undoing the normalisation
# (`reverse_norm`).
text_motion_dataset = TextMotionDataset(DATASET_PATH)

# Plain loader over the same dataset, handy for inspecting a single batch by
# hand, e.g. `next(iter(train_dataloader))`.
train_dataloader = DataLoader(
    text_motion_dataset,
    batch_size=1,
    collate_fn=text_motion_dataset.collate_fn,
    shuffle=True,
    # Worker processes are left disabled here; uncomment to enable them.
    # num_workers=7,
    # persistent_workers=True,
)

# Data module that is handed to `trainer.fit`.
data_module = InterXDataModule(DATASET_PATH, batch_size=1, num_workers=1)

In [3]:
# Normalisation statistics come from the dataset and are passed to SMORT at
# construction time.
mean, std = text_motion_dataset.get_mean_std()

# Use isinstance rather than an exact `type(...) ==` comparison: it is the
# idiomatic type check and also accepts torch.Tensor subclasses.
assert isinstance(mean, torch.Tensor) and isinstance(std, torch.Tensor), (
    f"expected tensors from get_mean_std(), got "
    f"{type(mean).__name__} and {type(std).__name__}"
)
model = SMORT(mean, std)

# Smoke test only: with `fast_dev_run=True` Lightning runs a single batch
# through the loops and suppresses logging and checkpointing.
trainer = Trainer(accelerator="cpu", max_epochs=1, fast_dev_run=True)

trainer.fit(model, data_module)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/cogniveon/src/uos/smort/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name                   | Type                    | Params | Mode 
---------------------------------------------------------------------------
0 | reactor_encoder        | ACTORStyleEncoderWithCA | 5.2 M  | train
1 | text_encoder           | ACTORStyleEncoder       | 4.9 M  | train
2 | actor_encoder          | ACTORStyleEncoderWithCA | 5.2 M  | train
3 | motion_decoder         | ACTORStyleDecoder       | 6.4 M  | train
4 | reconstruction_loss_fn | SmoothL1Loss            | 0      | train
5 | latent_loss_fn         | SmoothL1Loss            | 0      | train
6 | joint

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


> [0;32m/Users/cogniveon/src/uos/smort/smort/models/smort.py[0m(266)[0;36mvalidation_step[0;34m()[0m
[0;32m    264 [0;31m            [0mrandom_idx[0m [0;34m=[0m [0mrandom[0m[0;34m.[0m[0mrandint[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0mbs[0m [0;34m-[0m [0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    265 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 266 [0;31m            [0mself[0m[0;34m.[0m[0mrender_motion[0m[0;34m([0m[0mjoints[0m[0;34m[[0m[0mrandom_idx[0m[0;34m][0m[0;34m,[0m [0;34m"viz.mp4"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    267 [0;31m            [0mself[0m[0;34m.[0m[0mrender_motion[0m[0;34m([0m[0mgt_joints[0m[0;34m[[0m[0mrandom_idx[0m[0;34m][0m[0;34m,[0m [0;34m"gt.mp4"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    268 [0;31m[0;34m[0m[0m
[0m
0
torch.S

`Trainer.fit` stopped: `max_steps=1` reached.


In [4]:
# To visualise a trained checkpoint instead of the model from the training
# cell, load it first:
# model = SMORT.load_from_checkpoint(
#     "lightning_logs/version_21/checkpoints/epoch=11-step=1200.ckpt",
# )

# This cell only runs inference. Put the model into evaluation mode so that
# train-time-only layers (e.g. dropout, if the model has any) behave
# deterministically, and disable autograd so no graph is recorded.
model.eval()

with torch.no_grad():
    # Embed the text prompt with the pretrained language model.
    text_embeds = TextToEmb(
        "distilbert/distilbert-base-uncased",
        device=model.device,
    )(
        [
            "Two people walk towards each other. "
            "After they meet, the first person hugs the second person around the "
            "shoulders, gently patting his/her back with his/her right hand. "
            "Meanwhile, the second person puts his/her arms around the first "
            "person's waist and pats his/her waist with his/her right hand."
        ]
    )

    # Mask derived from the text embedding lengths.
    mask = length_to_mask(text_embeds["length"], device=model.device)
    encoded = model.text_encoder(
        {
            "x": text_embeds["x"],
            "mask": mask,
        }
    )

    # The encoder output is split along dim 1 into two tensors; the first is
    # used directly as the latent (no sampling), the second is unused here.
    dists = encoded.unbind(1)
    mu, logvar = dists
    latent_vectors = mu

    # NOTE(review): the decoder is given the *text* mask built above, so the
    # generated sequence length presumably follows the text length rather than
    # a chosen motion length -- confirm this is intended.
    motion = text_motion_dataset.reverse_norm(
        model.motion_decoder(
            {
                "z": latent_vectors,
                "mask": mask,
            }
        ).squeeze(dim=0)
    )

renderer = SingleMotionRenderer(
    colors=("red", "red", "red", "red", "red"),
)

# Convert the de-normalised features to joint positions and render them.
renderer.render_animation_single(
    feats_to_joints(torch.from_numpy(motion)).detach().cpu().numpy()
)