In [1]:
import torch
from pytorch_lightning import Trainer, seed_everything
from torch.utils.data import DataLoader

from smort.data.collate import length_to_mask
from smort.data.data_module import InterXDataModule
from smort.data.text_motion_dataset import TextMotionDataset
from smort.models.smort import SMORT
from smort.models.text_encoder import TextToEmb
from smort.renderer.matplotlib import SingleMotionRenderer
from smort.rifke import feats_to_joints

In [2]:
# Preprocessed Inter-X HDF5 file. Kept in a single constant so the dataset that
# supplies normalization statistics (next cell) and the data module that feeds
# training cannot silently point at different files.
DATASET_PATH = "deps/interx/processed/dataset_2k.h5"

# Full dataset; the next cell calls `get_mean_std()` on it to build the model's
# normalization statistics.
# NOTE(review): stats come from this full dataset while training below uses
# `use_tiny=True` -- confirm that mismatch is intentional.
text_motion_dataset = TextMotionDataset(DATASET_PATH)

# Ad-hoc loader over the full dataset for manually inspecting a batch, e.g.
# `next(iter(train_dataloader))`. `trainer.fit` uses `data_module`, not this.
train_dataloader = DataLoader(
    text_motion_dataset,
    batch_size=1,
    collate_fn=text_motion_dataset.collate_fn,
    shuffle=True,
)

# Data module passed to `trainer.fit`. The exact meaning of `use_tiny` and
# `return_scene` is defined in InterXDataModule -- presumably a small subset and
# scene inputs for the scene encoder; TODO confirm.
data_module = InterXDataModule(
    DATASET_PATH,
    batch_size=1,
    num_workers=1,
    use_tiny=True,
    return_scene=True,
)

In [3]:
# Seed Python, NumPy and torch RNGs (and dataloader workers) so weight init and
# batch shuffling are reproducible across Restart & Run All.
SEED = 42
seed_everything(SEED, workers=True)

# Normalization statistics from the full dataset built in the previous cell.
mean, std = text_motion_dataset.get_mean_std()
assert isinstance(mean, torch.Tensor) and isinstance(std, torch.Tensor), (
    "get_mean_std() must return tensors, got "
    f"{type(mean).__name__} and {type(std).__name__}"
)
model = SMORT(mean, std)

# Training is pinned to CPU even though MPS is available (Lightning warns about
# this in the saved output); switching `accelerator` to "mps"/"auto" would need
# the model's ops to be MPS-supported -- TODO verify before changing.
# NOTE(review): the saved output stopped at a `pdb.set_trace()` left in
# smort/models/modules.py (forward, line 88); remove it before unattended runs.
trainer = Trainer(
    accelerator="cpu",
    max_epochs=10,
    fast_dev_run=False,
    num_sanity_val_steps=0,
)

trainer.fit(model, data_module)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/cogniveon/src/uos/smort/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name                   | Type              | Params | Mode 
---------------------------------------------------------------------
0 | scene_encoder          | ACTORStyleEncoder | 12.8 M | train
1 | text_encoder           | ACTORStyleEncoder | 13.0 M | train
2 | motion_decoder         | ACTORStyleDecoder | 19.0 M | train
3 | reconstruction_loss_fn | SmoothL1Loss      | 0      | train
4 | latent_loss_fn         | SmoothL1Loss      | 0      | train
5 | joint_loss_fn          | JointLoss         | 0      | train
---------------------------------------------------------------------
44.8 M    Trainable params
0         Non-trainable params
44.8 M    Total params
179.246   Total estimated 

Training: |          | 0/? [00:00<?, ?it/s]

> /Users/cogniveon/src/uos/smort/smort/models/modules.py(89)forward()
     87         token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device)  # type: ignore
     88         import pdb; pdb.set_trace()
---> 89         aug_mask = torch.cat((token_mask, mask), 1)
     90 
     91         # add positional encoding

In [None]:
# Text-to-motion demo (currently disabled: every line below is commented out).
# Pipeline: embed one caption with TextToEmb -> build a mask from the embedding
# lengths -> run the model's text encoder -> take `mu`, the first of the two
# vectors from `encoded.unbind(1)` -> decode with the motion decoder -> undo
# normalization with `text_motion_dataset.reverse_norm` -> convert features to
# joints and render them.
#
# NOTE(review) before re-enabling:
# - `model` must exist: either trained in the cell above or loaded via the
#   commented `SMORT.load_from_checkpoint(...)` line (checkpoint path is local).
# - `mask` is built from the *text* embedding lengths and is reused as the
#   decoder mask. If the decoder's output length follows its mask (presumably),
#   the generated motion would have as many frames as the caption has tokens.
#   A motion-length mask is likely intended -- confirm against ACTORStyleDecoder.
# - For inference, consider `model.eval()` and wrapping in `torch.no_grad()`
#   (the trailing `.detach()` suggests autograd history is being built).
# - `torch.from_numpy(motion)` assumes `reverse_norm` returns a NumPy array;
#   if it returns a tensor this call fails -- TODO confirm.
# # model = SMORT.load_from_checkpoint(
# #     "lightning_logs/version_21/checkpoints/epoch=11-step=1200.ckpt",
# # )
# text_embeds = TextToEmb(
#     "distilbert/distilbert-base-uncased",
#     device=model.device,
# )(
#     [
#         "Two people walk towards each other. "
#         "After they meet, the first person hugs the second person around the "
#         "shoulders, gently patting his/her back with his/her right hand. "
#         "Meanwhile, the second person puts his/her arms around the first "
#         "person's waist and pats his/her waist with his/her right hand."
#     ]
# )
# # text_embeds

# mask = length_to_mask(text_embeds["length"], device=model.device)
# encoded = model.text_encoder(
#     {
#         "x": text_embeds["x"],
#         "mask": mask,
#     }
# )

# dists = encoded.unbind(1)
# mu, logvar = dists
# latent_vectors = mu
# motion = text_motion_dataset.reverse_norm(
#     model.motion_decoder(
#         {
#             "z": latent_vectors,
#             "mask": mask,
#         }
#     ).squeeze(dim=0)
# )

# renderer = SingleMotionRenderer(
#     colors=("red", "red", "red", "red", "red"),
# )

# renderer.render_animation_single(
#     feats_to_joints(torch.from_numpy(motion)).detach().cpu().numpy()
# )