In [None]:
import glob
import ntpath
import os
import pickle
from typing import Dict, Union

import einops
import torch
import yaml
from fvcore.common.config import CfgNode as CN
from tqdm import tqdm

from data.ava_dataset import MultiCaptureDataset as AvaMultiCaptureDataset
from data.ava_dataset import SingleCaptureDataset as AvaSingleCaptureDataset
from data.ava_dataset import none_collate_fn
from data.utils import MugsyCapture
from utils import get_autoencoder, load_checkpoint, render_img, tocuda, train_csv_loader


In [None]:
def fetch_id_embedding(uid, embedding_dir: str = "id_embeddings"):
    """Load the pickled identity embedding saved for ``uid``.

    Args:
        uid: identity folder name; the file read is
            ``{embedding_dir}/{uid}.pickle``.
        embedding_dir: directory holding the pickles. Defaults to
            ``id_embeddings``, the folder the extraction loop writes to.

    Returns:
        The unpickled object. In this notebook that is the dict with keys
        "uid", "id_cond" and "cudadriver" written by the extraction loop.

    Raises:
        FileNotFoundError: if no embedding file exists for ``uid``.
    """
    # SECURITY: pickle.load can execute arbitrary code. Only load files this
    # notebook produced, never pickles from untrusted sources.
    with open(f"{embedding_dir}/{uid}.pickle", "rb") as f:
        return pickle.load(f)


def id_cond_to_device(id_cond, device=torch.device("cuda")):
    """Detach an identity-conditioning dict and move its tensors to ``device``.

    Args:
        id_cond: dict with tensor entries "z_tex" and "z_geo" and list-of-tensor
            entries "b_tex" and "b_geo" (as returned by ``ae.id_encoder`` in
            this notebook). Any other keys are not copied.
        device: target device; defaults to CUDA.

    Returns:
        A new dict with exactly those four keys. Its tensors are detached
        (no autograd history) and placed on ``device``. The input dict is
        not modified.
    """
    # Detach so the returned tensors carry no autograd graph; they are only
    # used for inference and serialization.
    return {
        "z_tex": id_cond["z_tex"].detach().to(device),
        "z_geo": id_cond["z_geo"].detach().to(device),
        # Iterate over the actual bias lists instead of assuming exactly 8
        # levels. The old fixed-size loop raised IndexError on shorter lists
        # and silently truncated longer ones.
        "b_tex": [bias.detach().to(device) for bias in id_cond["b_tex"]],
        "b_geo": [bias.detach().to(device) for bias in id_cond["b_geo"]],
    }


def generate_image(ae, id_cond, cudadriver):
    """Render one RGB image from the autoencoder for a given identity.

    Args:
        ae: the loaded autoencoder. It is called with the camera, geometry and
            texture entries of ``cudadriver`` plus ``id_cond``.
        id_cond: identity-conditioning dict (keys "z_tex", "z_geo", "b_tex",
            "b_geo"). It is moved to CUDA here via ``id_cond_to_device``, so a
            CPU copy is acceptable.
        cudadriver: a driver batch already on the GPU (as produced by
            ``tocuda``), providing the keys read below.

    Returns:
        numpy array of shape (h, w, c) taken from ``output["irgbrec"]``. The
        rearrange pattern requires a batch size of exactly 1.
    """
    id_cond = id_cond_to_device(id_cond)

    # Pure inference: disable autograd so no graph is built for the forward
    # pass, regardless of whether the caller wrapped this in torch.no_grad().
    with torch.no_grad():
        output = ae(
            camrot=cudadriver["camrot"],
            campos=cudadriver["campos"],
            focal=cudadriver["focal"],
            princpt=cudadriver["princpt"],
            modelmatrix=cudadriver["modelmatrix"],
            avgtex=cudadriver["avgtex"],
            verts=cudadriver["verts"],
            neut_avgtex=cudadriver["neut_avgtex"],
            neut_verts=cudadriver["neut_verts"],
            target_neut_avgtex=None,
            target_neut_verts=None,
            id_cond=id_cond,
            pixelcoords=cudadriver["pixelcoords"],
        )

    rgb = output["irgbrec"].detach().cpu().numpy()
    # Drop the singleton batch axis and move channels last for image saving.
    rgb = einops.rearrange(rgb, "1 c h w -> h w c")

    return rgb


def render(ae, id_cond, cudadriver, out_path: str = "test.png"):
    """Generate an image for ``id_cond`` and write it to ``out_path``.

    The image comes from ``generate_image``. It is wrapped in a nested list
    (``[[image]]``) before being handed to ``render_img``, which presumably
    treats that as a 1x1 grid (confirm against ``utils.render_img``).
    """
    image = generate_image(ae, id_cond, cudadriver)
    grid = [[image]]
    render_img(grid, out_path)


In [5]:
checkpoint = "aeparams_1440000.pt"  # the pretrained model
config_path = "configs/config.yaml"
opts = []  # optional overrides for config.merge_from_list (empty = use the file as-is)

# NOTE: yaml.UnsafeLoader can construct arbitrary Python objects. It is kept
# as in the original setup, but only use it on trusted, local config files.
with open(config_path, "r") as file:
    config = CN(yaml.load(file, Loader=yaml.UnsafeLoader))

config.merge_from_list(opts)

train_params = config.train

# Train dataset mean/std texture and vertex for normalization
train_captures, train_dirs = train_csv_loader(
    train_params.dataset_dir, train_params.data_csv, train_params.nids
)
dataset = AvaMultiCaptureDataset(
    train_captures, train_dirs, downsample=train_params.downsample
)

batchsize = 1
numworkers = 1

# NOTE(review): `dataloader` is not used by any visible cell. It also keeps a
# reference to `dataset`, so the `del dataset` below only removes the name.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batchsize,
    shuffle=False,
    drop_last=True,
    num_workers=numworkers,
    collate_fn=none_collate_fn,
)

# Get Autoencoder
assetpath = "assets"
ae = get_autoencoder(dataset, assetpath=assetpath)
# Load from checkpoint
ae = load_checkpoint(ae, checkpoint).cuda()
# Set to Evaluation mode
ae.eval()

id_model = ae.id_encoder
texmean = dataset.texmean
vertmean = dataset.vertmean
texstd = dataset.texstd
vertstd = dataset.vertstd

# Only the normalization stats above are needed from here on
del dataset

# Identity folders to process. They are listed from the same root the
# extraction loop reads (`train_params.dataset_dir`) instead of a hardcoded
# drive path. Only directories are kept, and the result is sorted so that
# `user_ids[0:n_people]` selects the same identities on every machine
# (glob order is filesystem-dependent).
dataset_root = train_params.dataset_dir
user_ids = sorted(
    ntpath.basename(path)
    for path in glob.glob(f"{glob.escape(dataset_root)}/*")
    if os.path.isdir(path)
)

Loading single id captures: 100%|██████████| 256/256 [02:59<00:00,  1.42it/s]


@@@ Get autoencoder ABLATION CONFIG FILE : length of data set : 256
dataset vertmean: (7306, 3)
id_encoder params: 5062060
encoder params: 5_551_232
decoder params: 35_918_504
colorcal params: 3_252
bgmodel params: 454_739
total params: 46_991_899


  checkpoint = th.load(filename)


In [8]:
n_people = 256
all_uids = user_ids[0:n_people]

# Preferred front-facing camera IDs, in priority order.
FRONT_CAMERAS = ["401031", "401880", "401878"]
EMBEDDING_DIR = "id_embeddings"
SEED = 0  # makes the randomly chosen driver frame reproducible

os.makedirs(EMBEDDING_DIR, exist_ok=True)
missing_uids = []  # identities for which no usable frame was found

with torch.no_grad():
    for uid in tqdm(all_uids):
        # Identity folder names have the form "<mcd>--<mct>--<sid>".
        uid_parts = uid.split("--")
        driver_capture = MugsyCapture(
            mcd=uid_parts[0], mct=uid_parts[1], sid=uid_parts[2]
        )
        driver_dir = f"{train_params.dataset_dir}/{uid}/decoder"
        driver_dataset = AvaSingleCaptureDataset(
            driver_capture, driver_dir, downsample=train_params.downsample
        )

        # Normalize with the training-set statistics computed in the setup cell.
        driver_dataset.texmean = texmean
        driver_dataset.texstd = texstd
        driver_dataset.vertmean = vertmean
        driver_dataset.vertstd = vertstd

        # If possible, restrict to front-facing cameras, but only to those this
        # capture actually has. The previous code listed all three whenever any
        # one was present, which could point the dataset at cameras missing
        # from this capture.
        front_cameras = [
            cam for cam in FRONT_CAMERAS if cam in driver_dataset.cameras
        ]
        if front_cameras:
            driver_dataset.cameras = front_cameras

        driver_loader = torch.utils.data.DataLoader(
            driver_dataset,
            batch_size=batchsize,
            shuffle=True,
            drop_last=False,
            num_workers=numworkers,
            collate_fn=none_collate_fn,
            generator=torch.Generator().manual_seed(SEED),
        )

        # Use the first non-empty frame, then stop.
        for driver in driver_loader:
            # Skip if any of the frames is empty
            if driver is None:
                continue

            cudadriver: Dict[str, Union[torch.Tensor, int, str]] = tocuda(driver)

            id_cond = id_model(cudadriver["neut_verts"], cudadriver["neut_avgtex"])
            id_embedding_dict = {
                "uid": uid,
                # Store the identity code on the CPU
                "id_cond": id_cond_to_device(id_cond, torch.device("cpu")),
                "cudadriver": cudadriver,
            }

            # serialize each embedding
            with open(f"{EMBEDDING_DIR}/{uid}.pickle", "wb") as f:
                pickle.dump(id_embedding_dict, f)
            break
        else:
            # The loader was exhausted without a usable frame, so no file was written.
            missing_uids.append(uid)

if missing_uids:
    print(f"No usable frame for {len(missing_uids)} identities: {missing_uids}")

100%|██████████| 256/256 [32:10<00:00,  7.54s/it]


In [17]:
# Make sure the output directory exists before any image is written.
os.makedirs("out", exist_ok=True)

# Inference only: no autograd graph is needed for rendering.
with torch.no_grad():
    for uid in tqdm(all_uids):
        try:
            id_embedding = fetch_id_embedding(uid)
        except FileNotFoundError:
            # No embedding was saved for this identity (e.g. the extraction
            # loop found no usable frame); skip it instead of aborting the run.
            tqdm.write(f"Skipping {uid}: no saved id embedding")
            continue
        render(
            ae,
            id_embedding["id_cond"],
            id_embedding["cudadriver"],
            f"out/{id_embedding['uid']}.png",
        )

100%|██████████| 256/256 [01:05<00:00,  3.93it/s]
