In [1]:
%cd /home/slimhy/Documents/PADS/code
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import torch

import util.misc as misc
import models.s2vs as ae_mods


def get_args_parser():
    """Build the argument parser for the feature-extraction script.

    The parser is created with ``add_help=False`` so it can be composed into a
    parent parser. Every option keeps its original name, type and default;
    ``--ae_pth`` is the only required option.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU"
        " (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    # Hyphenated flag (dest: ae_latent_dim); kept as-is because callers pass it this way.
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient "
        "(sometimes) transfer to GPU.",
    )

    return parser


# Dummy command line used to exercise the script's argument parser from the notebook.
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""

# Whitespace-split the dummy command line and parse it in one go.
args = get_args_parser().parse_args(call_string.split())

# -------------------- runtime setup
device = torch.device(args.device)
misc.set_all_seeds(args.seed)  # seeded for reproducibility
torch.backends.cudnn.benchmark = True
# --------------------

# Pretrained shape autoencoder: constructor looked up by name in models.s2vs.
ae = vars(ae_mods)[args.ae]()
ae.eval()
print("Loading autoencoder %s" % args.ae_pth)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only point --ae_pth at trusted checkpoints.
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])

# Move to the target device, then wrap with torch.compile (max-autotune).
ae = torch.compile(ae.to(device), mode="max-autotune")

/home/slimhy/Documents/PADS/code
Set seed to 0
Loading autoencoder ckpt/ae_m512.pth


In [2]:
from datasets.latents import ShapeLatentDataset, ComposedPairedShapesLoader

from itertools import islice


class PairType:
    """Pair-type specifiers passed to ComposedPairedShapesLoader."""

    ROT_PAIR = "rand_no_rot,rand_no_rot"
    PART_DROP = "part_drop,orig"


# NOTE(review): absolute local path; consider moving it into a config cell.
latents_dir = "/home/slimhy/Documents/datasets/PADS/3DCoMPaT"

# Training split of the ShapeLatentDataset stored under latents_dir.
dataset = ShapeLatentDataset(latents_dir, split="train", shuffle_parts=False)

# Paired loader mixing both pair types; single-process loading (num_workers=0)
# keeps debugging simple. Batch size comes from the parsed arguments.
dataloader = ComposedPairedShapesLoader(
    dataset,
    batch_size=args.batch_size,
    pair_types_list=[PairType.ROT_PAIR, PairType.PART_DROP],
    num_workers=0,
    shuffle=False,
)

# Grab the k_break-th batch (1-indexed) for inspection. Raises StopIteration
# if the loader yields fewer than k_break batches, so later cells never run
# on a stale or undefined batch.
k_break = 1
pair_types, (latent_A, bb_coords_A, bb_labels_A, meta_A), (
    latent_B,
    bb_coords_B,
    bb_labels_B,
    meta_B,
) = next(islice(dataloader, k_break - 1, None))

In [3]:
from torch import nn
from datasets.metadata import COMPAT_FINE_PARTS
from models.modules import (
    Attention,
    DiagonalGaussianDistribution,
    FeedForward,
    GEGLU,
    PointEmbed,
    PreNorm,
)
from util.misc import cache_fn


# ================== MODULE
class PartAwareAE(nn.Module):
    """Part-aware autoencoder over global shape latents.

    One token is built per part slot from a bounding-box centroid embedding
    and a part-label embedding. These part tokens cross-attend to the global
    latent (whose last axis is transposed into a token axis and tiled up to
    ``max_parts`` tokens). They are refined by a self-attention stack and
    compressed to ``latent_dim`` per part. Decoding reverses this: expand,
    refine, then linearly mix the part axis back down to ``n_latent_tokens``.

    Args:
        dim: Width of the part tokens inside the transformer stacks. The
            global latent is used as cross-attention context of width ``dim``,
            so its axis 1 presumably must have size ``dim``.
        latent_dim: Size of each compressed part latent.
        max_parts: Number of part slots per shape (padding slots carry label -1).
        heads: Attention heads in the self-attention stacks.
        dim_head: Per-head width in the self-attention stacks.
        depth: Number of (attention, feed-forward) layers in the encoder and
            in the decoder.
        weight_tie_layers: Forwarded as ``_cache`` to the ``cache_fn``-wrapped
            layer factories.
        n_latent_tokens: Size of the last axis of the input global latent
            (8 for the B x 512 x 8 latents used in this notebook). Must divide
            ``max_parts``.

    Raises:
        ValueError: if ``max_parts`` is not a multiple of ``n_latent_tokens``.
    """

    def __init__(
            self,
            dim=512,
            latent_dim=128,
            max_parts=24,
            heads=8,
            dim_head=64,
            depth=2,
            weight_tie_layers=False,
            n_latent_tokens=8,
        ):
        super().__init__()

        # The global tokens are tiled to max_parts context tokens in encode().
        if max_parts % n_latent_tokens != 0:
            raise ValueError(
                f"max_parts ({max_parts}) must be a multiple of "
                f"n_latent_tokens ({n_latent_tokens})"
            )

        self.dim = dim
        self.latent_dim = latent_dim
        self.max_parts = max_parts
        self.heads = heads
        self.dim_head = dim_head
        self.depth = depth
        self.weight_tie_layers = weight_tie_layers
        self.n_latent_tokens = n_latent_tokens

        cache_args = {"_cache": self.weight_tie_layers}

        # Centroid embedding -> first half of each part token
        self.point_embed = PointEmbed(dim=dim // 2)
        # Part-label embedding -> second half of each part token
        self.part_label_embed = nn.Embedding(COMPAT_FINE_PARTS, dim // 2)

        # Input cross-attention: part tokens (queries) attend to the global latent.
        self.in_block = PreNorm(
            dim, Attention(dim, dim, heads=1, dim_head=dim), context_dim=dim
        )
        # Output projection: mixes the part axis back down to the global token count.
        self.out_proj = nn.Linear(max_parts, n_latent_tokens)

        # Stacked self-attention / feed-forward layers over the part tokens.
        def get_latent_attn():
            return PreNorm(
                dim, Attention(dim, heads=heads, dim_head=dim_head, drop_path_rate=0.1)
            )

        def get_latent_ff():
            return PreNorm(dim, FeedForward(dim, drop_path_rate=0.1))

        # NOTE(review): the same cached factories build both the encoder and the
        # decoder stacks, so with weight_tie_layers=True the decoder presumably
        # reuses the encoder's modules too -- confirm against util.misc.cache_fn.
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        # Creation order (encoder before decoder) is kept so seeded inits match.
        self.encoder_layers = nn.ModuleList([
            nn.ModuleList([get_latent_attn(**cache_args), get_latent_ff(**cache_args)])
            for _ in range(depth)
        ])
        self.decoder_layers = nn.ModuleList([
            nn.ModuleList([get_latent_attn(**cache_args), get_latent_ff(**cache_args)])
            for _ in range(depth)
        ])

        # Compress dim -> latent_dim. GEGLU halves its input width (as in
        # expand_latents, where dim becomes dim // 2), so the first projection
        # must emit 2 * latent_dim features.
        self.compress_latents = nn.Sequential(
            nn.Linear(dim, 2 * latent_dim),
            GEGLU(),
            nn.Linear(latent_dim, latent_dim),
        )
        # Expand latent_dim -> dim (GEGLU: dim -> dim // 2).
        self.expand_latents = nn.Sequential(
            nn.Linear(latent_dim, dim),
            GEGLU(),
            nn.Linear(dim // 2, dim),
        )

    def encode(self, latents, part_bbs, part_labels):
        """Encode a global latent plus part boxes/labels into per-part latents.

        Args:
            latents: Global latent, (B, dim, n_latent_tokens) (B x 512 x 8 in
                this notebook).
            part_bbs: Part bounding-box points, averaged over dim -2 to one
                3-D centroid per part (B x max_parts x 8 x 3 in this notebook).
            part_labels: Integer part labels, (B, max_parts); -1 marks padding.

        Returns:
            Part latents of shape (B, max_parts, latent_dim).
        """
        # Padding slots are labelled -1.
        batch_mask = part_labels != -1
        batch_mask = batch_mask.to(latents.device)

        # Embed bounding-box centroids
        bb_centroids = torch.mean(part_bbs, dim=-2)  # B x max_parts x 3
        bb_embeds = self.point_embed(bb_centroids)   # B x max_parts x dim//2

        # Map padding labels (-1) to 0 so they are valid embedding indices;
        # batch_mask is passed separately to the input cross-attention.
        part_labels = part_labels * batch_mask
        part_labels_embed = self.part_label_embed(part_labels)  # B x max_parts x dim//2

        # Tile the global tokens up to max_parts context tokens:
        # B x dim x T -> B x T x dim -> B x max_parts x dim
        n_repeats = self.max_parts // self.n_latent_tokens
        latents_in = latents.transpose(1, 2).repeat(1, n_repeats, 1)
        part_embeds = torch.cat((bb_embeds, part_labels_embed), dim=-1)

        # NOTE(review): if Attention applies `mask` over the context keys (as the
        # 3DShape2VecSet implementation does), padding slot i masks tiled latent
        # token i % n_latent_tokens rather than part query i -- confirm intent.
        x = self.in_block(part_embeds, context=latents_in, mask=batch_mask)

        # Residual encoder stack
        for attn, ff in self.encoder_layers:
            x = attn(x) + x
            x = ff(x) + x

        # Compress each part token to latent_dim
        return self.compress_latents(x)

    def decode(self, part_latents):
        """Decode per-part latents back to the global latent layout.

        Args:
            part_latents: Tensor of shape (B, max_parts, latent_dim).

        Returns:
            Tensor of shape (B, dim, n_latent_tokens).
        """
        x = self.expand_latents(part_latents)  # B x max_parts x dim

        # Residual decoder stack
        for attn, ff in self.decoder_layers:
            x = attn(x) + x
            x = ff(x) + x

        # B x max_parts x dim -> B x dim x max_parts -> B x dim x n_latent_tokens
        return self.out_proj(x.transpose(1, 2))

    def forward(self, latents, part_bbs, part_labels):
        """Encode then decode; returns the reconstructed global latent."""
        encoded = self.encode(latents, part_bbs, part_labels)
        return self.decode(encoded)
    

class PartAwareVAE(PartAwareAE):
    """Variational variant of PartAwareAE.

    The compressed part latents parameterise a diagonal Gaussian (mean and
    log-variance heads). A sample from it is decoded, and its KL term is
    returned alongside. Constructor arguments are forwarded unchanged to
    PartAwareAE.
    """

    def __init__(
            self,
            *args,
            **kwargs,
        ):
        super().__init__(*args, **kwargs)
        self.mean_fc = nn.Linear(self.latent_dim, self.latent_dim)
        self.logvar_fc = nn.Linear(self.latent_dim, self.latent_dim)

    def encode(self, latents, part_bbs, part_labels):
        """Encode to a posterior over part latents and sample from it.

        Returns:
            Tuple ``(kl, z)``: the posterior's ``kl()`` value and a sample ``z``
            with the same shape as the deterministic part latents.
        """
        x = super().encode(latents, part_bbs, part_labels)

        mean = self.mean_fc(x)
        logvar = self.logvar_fc(x)

        posterior = DiagonalGaussianDistribution(mean, logvar)
        # Sample before computing KL (same call order as before, so RNG use is unchanged).
        z = posterior.sample()
        kl = posterior.kl()

        return kl, z

    def forward(self, latents, part_bbs, part_labels):
        """Encode, sample and decode.

        Returns:
            Tuple ``(reconstructed_latents, kl)``. The reconstruction has the
            same layout as ``PartAwareAE.decode``; no axis is squeezed, so a
            single-token global latent keeps its token axis.
        """
        kl, part_latents = self.encode(latents, part_bbs, part_labels)
        reconstructed = self.decode(part_latents)

        return reconstructed, kl

# ==================

In [4]:
# "A" side of the sampled batch, moved to the compute device.
latents = latent_A.to(device)        # B x 512 x 8
part_bbs = bb_coords_A.to(device)    # B x 24 x 8 x 3
part_labels = bb_labels_A.to(device) # B x 24

# Randomly initialised part-aware VAE (default hyper-parameters, no checkpoint),
# used below for shape and parameter-count checks.
dummy = PartAwareVAE().to(device)

In [5]:
# Encode the latents
kl, sampled_part_latents = dummy.encode(latents, part_bbs, part_labels)
latents = dummy.decode(sampled_part_latents)

In [10]:
# Shape of the sampled part latents: batch x part slots x part-latent width.
sampled_part_latents.size()

torch.Size([16, 24, 128])

In [6]:
# Linear layer 512 -> 24, placed on the same device as every other tensor/module
# in this notebook so it can be applied to them directly.
# NOTE(review): not used anywhere in this notebook yet.
bb_decoder = nn.Linear(512, 24).to(device)

In [8]:
# Parameter count of the part-aware VAE, displayed as a formatted string.
misc.count_params(dummy)

'18,319,432'

In [9]:
# Parameter count of the pretrained (torch.compile-wrapped) shape autoencoder, for comparison.
misc.count_params(ae)

'106,128,913'