In [1]:
import torch
import random
from torch.utils.data import DataLoader
from smort.data.text_motion_dataset import TextMotionDataset
from smort.data.data_module import InterXDataModule

from smort.models.smort import SMORT
from smort.rifke import feats_to_joints
from smort.data.collate import length_to_mask
from smort.models.text_encoder import TextToEmb

In [2]:
# Both the plain dataset and the data module are pointed at the same
# processed dataset file, so the path is spelled out once.
DATASET_PATH = "deps/interx/processed/dataset.h5"

text_motion_dataset = TextMotionDataset(DATASET_PATH)

# Single-sample, shuffled batches collated by the dataset's own collate_fn.
# Worker options (num_workers=7, persistent_workers=True) are intentionally
# left disabled here.
loader_options = dict(
    batch_size=1,
    collate_fn=text_motion_dataset.collate_fn,
    shuffle=True,
)
train_dataloader = DataLoader(text_motion_dataset, **loader_options)

# Data module configured with the tiny split and scene return enabled;
# setup("train") is called right away so it can be queried in later cells.
data_module_options = dict(
    batch_size=1,
    num_workers=1,
    use_tiny=True,
    return_scene=True,
)
data_module = InterXDataModule(DATASET_PATH, **data_module_options)
data_module.setup("train")

# Uncomment to pull one collated batch for a quick inspection:
# next(iter(train_dataloader))

In [3]:
# Disabled training run, kept for reference. When re-enabled it would build a
# SMORT model from the dataset's mean/std and fit it on `data_module` with a
# CPU trainer for up to 10 epochs.
# NOTE(review): `Trainer` is not imported in the visible import cell -- add
# the import before re-enabling this cell.
# mean, std = text_motion_dataset.get_mean_std()
# assert type(mean) == torch.Tensor and type(std) == torch.Tensor
# model = SMORT(mean, std)

# trainer = Trainer(
#     accelerator="cpu", max_epochs=10, fast_dev_run=False, num_sanity_val_steps=0
# )

# trainer.fit(model, data_module)

In [4]:
# Disabled Weights & Biases download, kept for reference. When re-enabled it
# would fetch the model artifact "model-df4n45vn:v0" into a local directory.
# NOTE(review): presumably this is how the checkpoint under
# "artifacts/model-df4n45vn:v0/" used by the inference cell was obtained --
# confirm before deleting.
# import wandb

# run = wandb.init()
# artifact = run.use_artifact("rohit-k-kesavan/smort/model-df4n45vn:v0", type="model")
# artifact_dir = artifact.download()

In [22]:
# Reconstruction demo: pick a random scene, encode the actor's motion with the
# trained model, decode a motion from the latent mean using the reactor's
# mask, then render the prediction and the ground truth each next to the actor.
from smort.renderer.matplotlib import SceneRenderer

# Inclusive upper bound for the randomly drawn scene index.
# NOTE(review): hard-coded; confirm it does not exceed the number of scenes
# that `data_module.get_scene` can serve.
MAX_SCENE_IDX = 2000
CHECKPOINT_PATH = "artifacts/model-df4n45vn:v0/model.ckpt"


def _normalized_feats_to_joints(normalized_feats):
    """Undo the dataset normalisation on `normalized_feats` and convert to joints.

    The result of `text_motion_dataset.reverse_norm` is passed through
    `torch.from_numpy`, so it is expected to be a NumPy array.

    Args:
        normalized_feats: normalised motion features for one sequence.

    Returns:
        Whatever `feats_to_joints` returns for the de-normalised features.
    """
    denormalized = text_motion_dataset.reverse_norm(normalized_feats)
    return feats_to_joints(torch.from_numpy(denormalized))


mean, std = text_motion_dataset.get_mean_std()
scene_idx = random.randint(0, MAX_SCENE_IDX)
# Printed so that an interesting scene can be revisited by index later.
print(f"Scene: {scene_idx}")
# success: 227, 127
# success: 430, 1726, 253, 642, 424, 1676, 982
# failure: 277, 1155, 292, 613
sample = data_module.dataset.collate_fn([data_module.get_scene(scene_idx)])

# map_location="cpu" keeps the weights on the same device as the freshly
# collated sample even if the checkpoint was saved from a GPU run.
model = SMORT.load_from_checkpoint(
    CHECKPOINT_PATH,
    map_location="cpu",
    data_mean=mean,
    data_std=std,
)
# A freshly loaded module is in training mode; switch to eval so any
# train-only behaviour (e.g. dropout, if the model uses it) is disabled and
# the reconstruction is deterministic.
model.eval()

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    encoded = model.motion_encoder(sample["actor_x_dict"])

    # The encoder output is split along dim 1 into two parts, unpacked as the
    # latent mean and log-variance; only the mean is decoded (no sampling).
    dists = encoded.unbind(1)
    mu, logvar = dists
    latent_vectors = mu

    decoded = model.motion_decoder(
        {
            "z": latent_vectors,
            "mask": sample["reactor_x_dict"]["mask"],
        },
        sample["actor_x_dict"],
    ).squeeze(dim=0)

motion = text_motion_dataset.reverse_norm(decoded)

# Joint trajectories for rendering. The actor's joints are computed once and
# reused in both animations.
predicted_joints = feats_to_joints(torch.from_numpy(motion))
actor_joints = _normalized_feats_to_joints(sample["actor_x_dict"]["x"][0])
reactor_joints = _normalized_feats_to_joints(sample["reactor_x_dict"]["x"][0])

# First animation: model prediction together with the actor.
SceneRenderer().render_animation([predicted_joints, actor_joints])
# Second animation: ground-truth reactor together with the actor.
SceneRenderer().render_animation([reactor_joints, actor_joints])

Scene: 165
