## Load the model

In [None]:
import argparse
import datetime
import os
import os.path as osp
import shutil
import subprocess
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from tqdm import tqdm
from ultralytics import YOLO

from human_models.human_models import SMPLX
from main.base import Tester
from main.config import Config
from utils.data_utils import load_img, process_bbox, generate_patch_image
from utils.visualization_utils import render_mesh
from utils.inference_utils import non_max_suppression


In [None]:
# Let cuDNN benchmark and cache the fastest convolution algorithms.
cudnn.benchmark = True

# --- run identifiers ---
ckpt_name = "smplest_x_h"  # sub-folder of ./pretrained_models holding config + weights
file_name = "01.mp4"       # input name; also names the frame folders below
time_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
root_dir = ""

# --- checkpoint config and weights ---
ckpt_dir = osp.join('./pretrained_models', ckpt_name)
config_path = osp.join(ckpt_dir, 'config_base.py')
cfg = Config.load_config(config_path)
checkpoint_path = osp.join(ckpt_dir, f'{ckpt_name}.pth.tar')

# --- frame input / output folders ---
img_folder = osp.join(root_dir, 'demo', 'input_frames', file_name)
output_folder = osp.join(root_dir, 'demo', 'output_frames', file_name)
os.makedirs(output_folder, exist_ok=True)

# --- override the loaded config with this run's checkpoint and log location ---
exp_name = f'inference_{file_name}_{ckpt_name}_{time_str}'
new_config = {
    "model": {"pretrained_model_path": checkpoint_path},
    "log": {
        'exp_name': exp_name,
        'log_dir': osp.join(root_dir, 'outputs', exp_name, 'log'),
    },
}
cfg.update_config(new_config)

# Build the SMPL-X body model, then the tester, which constructs the network and loads the checkpoint.
smpl_x = SMPLX(cfg.model.human_model_path)
demoer = Tester(cfg)
demoer._make_model()

[92m08-11 14:38:39[0m Load checkpoint from ./pretrained_models\smplest_x_h\smplest_x_h.pth.tar
[92m08-11 14:38:39[0m Creating graph...


Total #parameters: 687223152 (0.69B)


[92m08-11 14:38:45[0m [93mWRN: Attention: Strict=False is set for checkpoint loading. Please check manually.[0m


In [3]:
# Display the module tree (DataParallel -> Model -> ViT encoder) before token merging is applied.
demoer.model

DataParallel(
  (module): Model(
    (encoder): ViT(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16), padding=(2, 2))
      )
      (blocks): ModuleList(
        (0): Block(
          (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
          (attn): Attention_ViT(
            (qkv): Linear(in_features=1280, out_features=3840, bias=True)
            (proj): Linear(in_features=1280, out_features=1280, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1280, out_features=5120, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=5120, out_features=1280, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): Block(
          (norm1): LayerNorm((1280

In [4]:
# Person detector: YOLO weights from the checkpoint config if it defines
# `model_path`, otherwise the bundled yolov8x weights.
detection_cfg = cfg.inference.detection
bbox_model = getattr(detection_cfg, "model_path", './pretrained_models/yolov8x.pt')
detector = YOLO(bbox_model)

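## Convert the video into frames

Extract frames at `FPS` frames per second with ffmpeg into `demo/input_frames/<file name>/`; the inference cells below read them back as `000001.jpg`, `000002.jpg`, …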

In [None]:
# Extract frames from the input video (or copy a single input image) into the
# folder the inference cells read from. Paths reuse `root_dir`, `file_name` and
# `img_folder` from the configuration cell so they cannot drift apart.
FPS = 15
DEMO_DIR = Path(root_dir) / "demo"
FILE_NAME = file_name
IMG_PATH = Path(img_folder)

VIDEO_EXTENSIONS = {'mp4', 'avi', 'mov', 'mkv', 'flv', 'wmv', 'webm', 'mpeg', 'mpg'}
IMAGE_EXTENSIONS = {'jpg', 'jpeg', 'png', 'bmp', 'gif', 'tiff', 'tif', 'webp', 'svg'}

input_file = DEMO_DIR / FILE_NAME
IMG_PATH.mkdir(parents=True, exist_ok=True)

# Raise real exceptions: in a notebook sys.exit() only raises SystemExit and
# the original traceback/context is lost.
if not input_file.is_file():
    raise FileNotFoundError(f"Input file not found at '{input_file}'")

ext = input_file.suffix[1:].lower()

if ext in VIDEO_EXTENSIONS:
    if shutil.which('ffmpeg') is None:
        raise RuntimeError("ffmpeg was not found on PATH; it is required to extract video frames.")
    print(f"Processing video: {FILE_NAME}")
    command = [
        'ffmpeg',
        '-y',  # overwrite frames from a previous run; otherwise ffmpeg stops at an overwrite prompt on re-runs
        '-i', str(input_file),
        '-f', 'image2',
        '-vf', f'fps={FPS}/1',
        '-qscale:v', '2',  # JPEG quality scale: lower is better, 2 is near the best quality
        str(IMG_PATH / '%06d.jpg'),
    ]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        # Surface ffmpeg's own diagnostics (e.g. corrupted file) in the raised error.
        raise RuntimeError(f"ffmpeg failed (exit code {e.returncode}):\n{e.stderr}") from e
elif ext in IMAGE_EXTENSIONS:
    print(f"Processing image: {FILE_NAME}")
    # NOTE: the inference cells read '<n>.jpg' frames; non-JPEG images are copied with their own extension.
    shutil.copy(str(input_file), str(IMG_PATH / f"000001.{ext}"))
else:
    raise ValueError(f"Unknown file type: '{ext}'")

# Count the resulting files
end_count = len(os.listdir(IMG_PATH))
print(f"\nSuccess! Total files in output directory: {end_count}")

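## Run baseline inference on the first extracted frames

For each frame: detect the person with YOLO, recover the SMPL-X mesh with the unmodified encoder, and save the rendered overlay to `output_folder`.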

In [None]:
num_example = 10
# ToTensor() only rescales uint8 input to [0, 1]; the patch is cast to float32
# first, so the explicit /255 below does the normalization. Built once, not per frame.
transform = transforms.ToTensor()

for frame in tqdm(range(1, num_example + 1)):
    # prepare input image: frames are named 000001.jpg, 000002.jpg, ... under img_folder
    img_path = osp.join(img_folder, f'{frame:06d}.jpg')
    original_img = load_img(img_path)
    vis_img = original_img.copy()
    original_img_height, original_img_width = original_img.shape[:2]

    # detection, boxes in xyxy; classes=0 keeps only class 0 (person for COCO-trained YOLO weights)
    yolo_bbox = detector.predict(original_img,
                                 device='cuda',
                                 classes=0,
                                 conf=cfg.inference.detection.conf,
                                 save=cfg.inference.detection.save,
                                 verbose=cfg.inference.detection.verbose
                                 )[0].boxes.xyxy.detach().cpu().numpy()

    # Single-person demo: use at most the first detected box. A frame where
    # nobody is detected is saved without overlay instead of raising IndexError.
    num_bbox = min(len(yolo_bbox), 1)
    for bbox_id in range(num_bbox):
        x1, y1, x2, y2 = yolo_bbox[bbox_id][:4]
        # xyxy -> xywh, the format process_bbox takes
        yolo_bbox_xywh = np.array([x1, y1, abs(x2 - x1), abs(y2 - y1)], dtype=np.float64)
        bbox = process_bbox(bbox=yolo_bbox_xywh,
                            img_width=original_img_width,
                            img_height=original_img_height,
                            input_img_shape=cfg.model.input_img_shape,
                            ratio=getattr(cfg.data, "bbox_ratio", 1.25))
        img, _, _ = generate_patch_image(cvimg=original_img,
                                         bbox=bbox,
                                         scale=1.0,
                                         rot=0.0,
                                         do_flip=False,
                                         out_shape=cfg.model.input_img_shape)
        img = transform(img.astype(np.float32)) / 255
        img = img[None, :, :, :]  # add batch dimension
        inputs = {'img': img}
        targets = {}
        meta_info = {}

        # mesh recovery
        with torch.no_grad():
            out = demoer.model(inputs, targets, meta_info, 'test')
        mesh = out['smplx_mesh_cam'].detach().cpu().numpy()[0]

        # rescale the configured camera intrinsics from input_body_shape to the crop bbox in image coordinates
        focal = [cfg.model.focal[0] / cfg.model.input_body_shape[1] * bbox[2],
                 cfg.model.focal[1] / cfg.model.input_body_shape[0] * bbox[3]]
        princpt = [cfg.model.princpt[0] / cfg.model.input_body_shape[1] * bbox[2] + bbox[0],
                   cfg.model.princpt[1] / cfg.model.input_body_shape[0] * bbox[3] + bbox[1]]

        # draw the detection box, then the rendered mesh
        vis_img = cv2.rectangle(vis_img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
        vis_img = render_mesh(vis_img, mesh, smpl_x.face, {'focal': focal, 'princpt': princpt},
                              mesh_as_vertices=False)

    # save rendered image; channels are reversed because cv2.imwrite expects BGR
    frame_name = os.path.basename(img_path)
    cv2.imwrite(os.path.join(output_folder, frame_name), vis_img[:, :, ::-1])


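## Apply token merging (ToMe) with unmerging via `tmu`

Patch the ViT encoder in place with `tmu.patch.timm`, then rerun the same frames with the encoder's `r` parameter set, writing the results to a separate output folder.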

In [None]:
import tmu

# Patch the ViT encoder in place; afterwards it is a ToMeVisionTransformer built from ToMeBlocks.
vit_encoder = demoer.model.module.encoder
tmu.patch.timm(vit_encoder, trace_source=True)
demoer.model

DataParallel(
  (module): Model(
    (encoder): ToMeVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16), padding=(2, 2))
      )
      (blocks): ModuleList(
        (0): ToMeBlock(
          (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
          (attn): ToMeAttention(
            (qkv): Linear(in_features=1280, out_features=3840, bias=True)
            (proj): Linear(in_features=1280, out_features=1280, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1280, out_features=5120, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=5120, out_features=1280, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): ToMeBlock(
        

In [None]:
# Token-merging amount for the patched encoder (in ToMe, tokens merged per block — presumably the same in tmu).
demoer.model.module.encoder.r = 2
out_folder = osp.join("demo", "output_frames", f"tmu_r{demoer.model.module.encoder.r}")
os.makedirs(out_folder, exist_ok=True)

num_example = 10
# ToTensor() only rescales uint8 input to [0, 1]; the patch is cast to float32
# first, so the explicit /255 below does the normalization. Built once, not per frame.
transform = transforms.ToTensor()

for frame in tqdm(range(1, num_example + 1)):
    # prepare input image: frames are named 000001.jpg, 000002.jpg, ... under img_folder
    img_path = osp.join(img_folder, f'{frame:06d}.jpg')
    original_img = load_img(img_path)
    vis_img = original_img.copy()
    original_img_height, original_img_width = original_img.shape[:2]

    # detection, boxes in xyxy; classes=0 keeps only class 0 (person for COCO-trained YOLO weights)
    yolo_bbox = detector.predict(original_img,
                                 device='cuda',
                                 classes=0,
                                 conf=cfg.inference.detection.conf,
                                 save=cfg.inference.detection.save,
                                 verbose=cfg.inference.detection.verbose
                                 )[0].boxes.xyxy.detach().cpu().numpy()

    # Single-person demo: use at most the first detected box. A frame where
    # nobody is detected is saved without overlay instead of raising IndexError.
    num_bbox = min(len(yolo_bbox), 1)
    for bbox_id in range(num_bbox):
        x1, y1, x2, y2 = yolo_bbox[bbox_id][:4]
        # xyxy -> xywh, the format process_bbox takes
        yolo_bbox_xywh = np.array([x1, y1, abs(x2 - x1), abs(y2 - y1)], dtype=np.float64)
        bbox = process_bbox(bbox=yolo_bbox_xywh,
                            img_width=original_img_width,
                            img_height=original_img_height,
                            input_img_shape=cfg.model.input_img_shape,
                            ratio=getattr(cfg.data, "bbox_ratio", 1.25))
        img, _, _ = generate_patch_image(cvimg=original_img,
                                         bbox=bbox,
                                         scale=1.0,
                                         rot=0.0,
                                         do_flip=False,
                                         out_shape=cfg.model.input_img_shape)
        img = transform(img.astype(np.float32)) / 255
        img = img[None, :, :, :]  # add batch dimension
        inputs = {'img': img}
        targets = {}
        meta_info = {}

        # mesh recovery with the token-merging encoder
        with torch.no_grad():
            out = demoer.model(inputs, targets, meta_info, 'test')
        mesh = out['smplx_mesh_cam'].detach().cpu().numpy()[0]

        # rescale the configured camera intrinsics from input_body_shape to the crop bbox in image coordinates
        focal = [cfg.model.focal[0] / cfg.model.input_body_shape[1] * bbox[2],
                 cfg.model.focal[1] / cfg.model.input_body_shape[0] * bbox[3]]
        princpt = [cfg.model.princpt[0] / cfg.model.input_body_shape[1] * bbox[2] + bbox[0],
                   cfg.model.princpt[1] / cfg.model.input_body_shape[0] * bbox[3] + bbox[1]]

        # draw the detection box, then the rendered mesh
        vis_img = cv2.rectangle(vis_img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
        vis_img = render_mesh(vis_img, mesh, smpl_x.face, {'focal': focal, 'princpt': princpt},
                              mesh_as_vertices=False)

    # save rendered image; channels are reversed because cv2.imwrite expects BGR
    frame_name = os.path.basename(img_path)
    cv2.imwrite(os.path.join(out_folder, frame_name), vis_img[:, :, ::-1])
