## Setup

In [1]:
%cd ..

/Users/jason/repos/diffusion-motion-inbetweening


In [2]:
import os

import torch
from torchinfo import summary

from data_loaders.get_data import DatasetConfig, get_dataset_loader
from model.cfg_sampler import ClassifierFreeSampleModel
from utils import dist_util
from utils.model_util import create_model_and_diffusion, load_saved_model

  import scipy.ndimage.filters as filters


In [3]:
from argparse import Namespace

# Dumped from a previous run
args = Namespace(
    edit_mode="benchmark_sparse",
    transition_length=20,
    n_keyframes=5,
    editable_features="pos_rot_vel",
    text_condition="",
    imputate=False,
    replacement_distribution="conditional",
    reconstruction_guidance=False,
    reconstruction_weight=5.0,
    gradient_schedule=None,
    cutoff_point=0,
    stop_imputation_at=0,
    stop_recguidance_at=0,
    use_fixed_dataset=False,
    use_fixed_subset=False,
    no_text=False,
    motion_length=11.2,
    motion_length_cut=6.0,
    input_text="",
    action_file="",
    text_prompt="",
    action_name="",
    model_path="./save/condmdi_random_frames/model000750000.pt",
    output_dir="",
    num_samples=10,
    num_repetitions=3,
    guidance_param=2.5,
    keyframe_guidance_param=1.0,
    save_dir="save/nm4d9951",
    overwrite=False,
    batch_size=64,
    train_platform_type="NoPlatform",
    lr=0.0001,
    weight_decay=0.01,
    grad_clip=1.0,
    use_fp16=True,
    avg_model_beta=0.9999,
    adam_beta2=0.999,
    lr_anneal_steps=0,
    eval_batch_size=32,
    eval_split="test",
    eval_during_training=False,
    eval_rep_times=3,
    eval_num_samples=1000,
    log_interval=1000,
    save_interval=50000,
    num_steps=3000000,
    num_frames=224,
    resume_checkpoint="save/nm4d9951/model000850000.pt",
    apply_zero_mask=False,
    traj_extra_weight=1.0,
    time_weighted_loss=False,
    train_x0_as_eps=False,
    noise_schedule="cosine",
    diffusion_steps=1000,
    sigma_small=True,
    predict_xstart=True,
    use_ddim=False,
    clip_range=6.0,
    arch="unet",
    emb_trans_dec=False,
    layers=8,
    latent_dim=512,
    ff_size=1024,
    dim_mults=[2, 2, 2, 2],
    unet_adagn=True,
    unet_zero=True,
    out_mult=False,
    cond_mask_prob=0.1,
    keyframe_mask_prob=0.1,
    lambda_rcxyz=0.0,
    lambda_vel=0.0,
    lambda_fc=0.0,
    unconstrained=False,
    keyframe_conditioned=True,
    keyframe_selection_scheme="random_frames",
    zero_keyframe_loss=False,
    dataset="humanml",
    data_dir="",
    abs_3d=True,
    traj_only=False,
    xz_only=False,
    use_random_proj=False,
    random_proj_scale=10.0,
    augment_type="none",
    std_scale_shift=[1.0, 0.0],
    drop_redundant=False,
    cuda=True,
    device=0,
    seed=10,
)

In [4]:
def load_dataset(args, max_frames, split="test", num_workers=1):
    conf = DatasetConfig(
        name=args.dataset,
        batch_size=args.batch_size,
        num_frames=max_frames,
        split=split,
        hml_mode="train",  # in train mode, you get both text and motion.
        use_abs3d=args.abs_3d,
        traject_only=args.traj_only,
        use_random_projection=args.use_random_proj,
        random_projection_scale=args.random_proj_scale,
        augment_type="none",
        std_scale_shift=args.std_scale_shift,
        drop_redundant=args.drop_redundant,
    )
    data = get_dataset_loader(conf, num_workers=num_workers)
    return data

In [5]:
###########################################################################
# * Build Output Path
###########################################################################

max_frames = 196

###########################################################################
# * Prepare Text/Action Prompts
###########################################################################

# this block must be called BEFORE the dataset is loaded
use_test_set_prompts = False
if args.text_prompt != "":
    texts = [args.text_prompt]
    args.num_samples = 1
elif args.input_text != "":
    assert os.path.exists(args.input_text)
    with open(args.input_text, "r") as fr:
        texts = fr.readlines()
    texts = [s.replace("\n", "") for s in texts]
    args.num_samples = len(texts)
elif args.action_name:
    action_text = [args.action_name]
    args.num_samples = 1
elif args.action_file != "":
    assert os.path.exists(args.action_file)
    with open(args.action_file, "r") as fr:
        action_text = fr.readlines()
    action_text = [s.replace("\n", "") for s in action_text]
    args.num_samples = len(action_text)
elif args.no_text:
    texts = [""] * args.num_samples
    args.guidance_param = 0.0  # Force unconditioned generation # TODO: This is part of inbetween.py --> Will I need it here?
else:
    # use text from the test set
    use_test_set_prompts = True

###########################################################################
# * Load Dataset and Model
###########################################################################

print("Loading dataset...")
assert (
    args.num_samples <= args.batch_size
), f"Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})"
# So why do we need this check? In order to protect GPU from a memory overload in the following line.
# If your GPU can handle batch size larger then default, you can specify it through --batch_size flag.
# If it doesn't, and you still want to sample more prompts, run this script with different seeds
# (specify through the --seed flag)
args.batch_size = (
    args.num_samples
)  # Sampling a single batch from the testset, with exactly args.num_samples
split = "fixed_subset" if args.use_fixed_subset else "test"
# returns a DataLoader with the Text2MotionDatasetV2 dataset
data = load_dataset(args, max_frames, split=split)


print("Creating model and diffusion...")
model, diffusion = create_model_and_diffusion(args, data)

###########################################################################
# * Load Model Checkpoint
###########################################################################

print(f"Loading checkpoints from [{args.model_path}]...")
load_saved_model(model, args.model_path)  # , use_avg_model=args.gen_avg_model)
if args.guidance_param != 1 and args.keyframe_guidance_param != 1:
    raise NotImplementedError("Classifier-free sampling for keyframes not implemented.")
elif args.guidance_param != 1:
    model = ClassifierFreeSampleModel(
        model
    )  # wrapping model with the classifier-free sampler
model.to(dist_util.dev())
model.eval()  # disable random masking

print("Model ready")

Loading dataset...
Reading ././dataset/humanml_opt.txt
Loading dataset t2m ...
mode = train
t2m dataset aug: none std_scale_shift: [1.0, 0.0]
t2m dataset drop redundant information: False


100%|██████████| 4384/4384 [00:03<00:00, 1448.31it/s]


Pointer Pointing at 0
Creating model and diffusion...
Using UNET with lantent dim:  512  and mults:  [2, 2, 2, 2]
dims:  [263, 1024, 1024, 1024, 1024] mults:  [2, 2, 2, 2]
[ models/temporal ] Channel dimensions: [(263, 1024), (1024, 1024), (1024, 1024), (1024, 1024)]
EMBED TEXT
Loading CLIP...




Loading checkpoints from [./save/condmdi_random_frames/model000750000.pt]...
loading avg model
Model ready


## Full Model Summary (`MDM_UNET`)

- `n_frames=196` gets right-padded to `224` to be a multiple of `16`


In [23]:
size = (2, 263, 1, 196)
timesteps = torch.ones(size[0], dtype=torch.int)
obs_x0 = torch.zeros(size, dtype=torch.float16)
obs_mask = torch.ones(size, dtype=torch.bool)
y = {
    "text": ["lorem testum"],
}

summary(
    model=model.model,
    device="mps",
    input_size=size,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
    depth=2,
    #### extra model forward kwargs ####
    timesteps=timesteps,
    obs_x0=obs_x0,
    obs_mask=obs_mask,
    y=y,
)

Layer (type (var_name))                                 Input Shape          Output Shape         Param #              Trainable
MDM_UNET (MDM_UNET)                                     [2, 263, 1, 196]     [2, 263, 1, 196]     --                   Partial
├─TimestepEmbedder (embed_timestep)                     [2]                  [1, 2, 512]          --                   True
│    └─Sequential (time_embed)                          [2, 1, 512]          [2, 1, 512]          525,312              True
├─CLIP (clip_model)                                     --                   --                   88,150,785           False
│    └─Embedding (token_embedding)                      [1, 77]              [1, 77, 512]         (25,296,896)         False
│    └─Transformer (transformer)                        [77, 1, 512]         [77, 1, 512]         (37,828,608)         False
│    └─LayerNorm (ln_final)                             [1, 77, 512]         [1, 77, 512]         (1,024)              Fa

## Only the `TemporalUNet` Module
- HYPOTHESIS: motion vectors and masks are concatenated -> `2 * 263 = 526` 

In [34]:
cond = model.model.embed_timestep(timesteps).clone().detach().squeeze(0)
unet = model.model.unet
unet_size = (224, 2, 526)

summary(
    model=unet,
    device="mps",
    input_size=unet_size,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
    depth=3,
    #### extra model forward kwargs ####
    cond=cond,
)

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
TemporalUnet (TemporalUnet)                        [224, 2, 526]        [224, 2, 263]        --                   True
├─Sequential (time_mlp)                            [2, 512]             [2, 512]             --                   True
│    └─Linear (0)                                  [2, 512]             [2, 2048]            1,050,624            True
│    └─Mish (1)                                    [2, 2048]            [2, 2048]            --                   --
│    └─Linear (2)                                  [2, 2048]            [2, 512]             1,049,088            True
├─ModuleList (downs)                               --                   --                   --                   True
│    └─ModuleList (0)                              --                   --                   --                   True
│    │    └─ResidualTemporalBlock (0)        

In [39]:
cond.shape

torch.Size([2, 512])