In [1]:
%cd /ibex/user/slimhy/PADS/code
%reload_ext autoreload
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import torch

import util.misc as misc
import models.s2vs as ae_mods


def get_args_parser():
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU"
        " (effective batch size is batch_size * accum_iter * # gpus",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient "
        "(sometimes) transfer to GPU.",
    )

    return parser


# Set dummy arg string to debug the parser
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --data_path /ibex/project/c2273/PADS/3DCoMPaT \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""
    

# Parse the arguments
args = get_args_parser()
args = args.parse_args(call_string.split())

# --------------------
device = torch.device(args.device)

# Fix the seed for reproducibility
misc.set_all_seeds(args.seed)

torch.backends.cudnn.benchmark = True
# --------------------

# Instantiate autoencoder
ae = ae_mods.__dict__[args.ae]()
ae.eval()
print("Loading autoencoder %s" % args.ae_pth)
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae = ae.to(device)

# Compile using torch.compile
ae = torch.compile(ae, mode="max-autotune")

/ibex/user/slimhy/PADS/code
Set seed to 0
Loading autoencoder ckpt/ae_m512.pth


In [2]:
from datasets.latents import ShapeLatentDataset, ComposedPairedShapesLoader

class PairType():
    NO_ROT_PAIR = "rand_no_rot,rand_no_rot"
    PART_DROP = "part_drop,orig"

# Create your datasets
dataset_train = ShapeLatentDataset(args.data_path, split="train", shuffle_parts=True, filter_n_ids=1)
dataset_val = ShapeLatentDataset(args.data_path, split="test", shuffle_parts=False, filter_n_ids=1)

# Create the DataLoader using the sampler
data_loader_train = ComposedPairedShapesLoader(
    dataset_train,
    batch_size=2,
    pair_types_list=[PairType.PART_DROP, PairType.NO_ROT_PAIR],
    num_workers=0,
    shuffle=True,
    use_distributed=False
) 

In [3]:
from losses.partvae import KLRecLoss, ScaleInvariantLoss, PartDropLoss
from torch.nn import functional as F


def forward_pass(
    pvae,
    data_tuple,
    kl_rec_loss,
    scale_inv_loss,
    part_drop_loss,
    pair_types,
):
    """
    Compute a single forward pass of the model.
    """
    # Unpack the data tuple
    pair_types, (l_a, bb_a, bb_l_a, meta_a), (l_b, bb_b, bb_l_b, meta_b) = data_tuple
    device = pvae.device

    # Compute the mask from batch labels
    mask_a = (bb_l_a != -1).to(device)  # B x 24
    mask_b = (bb_l_b != -1).to(device)  # B x 24

    l_a, l_b = l_a.to(device), l_b.to(device)  # B x 24 x 512
    bb_a, bb_b = bb_a.to(device), bb_b.to(device)  # B x 24 x 4 x 3
    bb_l_a, bb_l_b = bb_l_a.to(device), bb_l_b.to(device)  # B x 24
    
    print(l_a.shape)

    # Optionally compute the KL Reg loss
    if pvae.is_vae:
        logits_a, kl_a, part_latents_a = pvae(
            latents=l_a, part_bbs=bb_a, part_labels=bb_l_a, batch_mask=mask_a
        )
        logits_b, kl_b, part_latents_b = pvae(
            latents=l_b, part_bbs=bb_b, part_labels=bb_l_b, batch_mask=mask_b
        )

        # KL Reg loss
        kl_reg = kl_rec_loss(kl_a, mask=mask_a) + kl_rec_loss(kl_b, mask=mask_b)
        kl_reg /= 2.0
    else:
        logits_a, part_latents_a = pvae(
            latents=l_a, part_bbs=bb_a, part_labels=bb_l_a, batch_mask=mask_a
        )
        logits_b, part_latents_b = pvae(
            latents=l_b, part_bbs=bb_b, part_labels=bb_l_b, batch_mask=mask_b
        )
        kl_reg = torch.tensor(0.0).to(device)

    # L2 loss
    rec_loss = F.mse_loss(logits_a, l_a) + F.mse_loss(logits_b, l_b)
    rec_loss /= 2.0

    if pair_types == PairType.NO_ROT_PAIR:
        inv_loss = scale_inv_loss(part_latents_a, part_latents_b, mask_a)
    elif pair_types == PairType.PART_DROP:
        inv_loss = part_drop_loss(
            part_latents_a, part_latents_b, bb_a, bb_b, mask_a, mask_b
        )

    return {
        "kl_reg": kl_reg,
        "rec_loss": rec_loss,
        "inv_loss": inv_loss,
    }

In [4]:
from schedulefree import AdamWScheduleFree
from models.partvae import PartAwareVAE, PartAwareAE


def get_losses():
    """
    Instantiate the losses.
    """
    return (
        KLRecLoss(),
        ScaleInvariantLoss(),
        PartDropLoss(),
    )


# Initialize the model
pvae = PartAwareAE(
    dim=512,
    latent_dim=128,
    heads=8,
    dim_head=64,
    depth=2,
).to(device)
pvae = pvae.to(device)

pvae.train(True)

metric_logger = misc.MetricLogger(delimiter="  ")
for epoch in enumerate(
    metric_logger.log_every(range(300), 10)
):
    # Reset the metric logger
    # metric_logger.reset()
    # Initialize the optimizer
    optimizer = AdamWScheduleFree(
        pvae.parameters(), lr=1e-3, weight_decay=1e-5
    )
    optimizer.zero_grad()

    # Instantiate the losses
    kl_loss, scale_inv_loss, part_drop_loss = get_losses()

    for data_step, data_tuple in enumerate(data_loader_train):
        # Computing loss
        loss = forward_pass(
            pvae=pvae,
            data_tuple=data_tuple,
            kl_rec_loss=kl_loss,
            scale_inv_loss=scale_inv_loss,
            part_drop_loss=part_drop_loss,
            pair_types=PairType.PART_DROP,
        )
        total_loss = (
            0.005 * loss["kl_reg"]
            + 0.8 * loss["rec_loss"]
            + 0. * loss["inv_loss"]
        )

        # Backward pass
        total_loss.backward()
        optimizer.zero_grad()
        torch.cuda.synchronize()

        # Log the losses
        loss_update = {
            "train_loss": float(total_loss.item()),
            "kl_reg": float(loss["kl_reg"].item()),
            "rec_loss": float(loss["rec_loss"].item()),
            "inv_loss": float(loss["inv_loss"].item()),
        }
        metric_logger.update(**loss_update)

torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
  [  0/300]  eta: 0:02:47  train_loss: 0.7355 (0.7537)  kl_reg: 0.0000 (0.0000)  rec_loss: 0.9194 (0.9421)  inv_loss: 0.0008 (0.0015)  time: 0.5593  data: 0.0000  max mem: 596
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
  [ 10/300]  eta: 0:00:21  train_loss: 0.7595 (0.7525)  kl_reg: 0.0000 (0.0000)  rec_loss: 0.9493 (0.9406)  inv_loss: 0.0017 (0.0019)  time: 0.0734  data: 0.0000  max mem: 596
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
torch.Size([1, 512, 8])
  [ 20/300]  eta: 0:00:13  train_loss: 0.7595 (0.7516)  kl_reg: 0.0000 (0.0000)  rec_loss: 0.9493 (0.9395)  inv_loss: 0.0020 (0.0020)  time: 0.0

KeyboardInterrupt: 

In [None]:
from torch_linear_assignment import batch_linear_assignment


def debug_rec_loss(x, x_rec):
    """
    Call the loss function.
    """
    B, D, N = (
        x.shape
    )  # B: batch size, D: latent dimension (512), N: number of vectors (8)

    # Reshape x and x_rec to (B, N, D) for easier processing
    # x = set_a.transpose(1, 2)
    # x_rec = set_b.transpose(1, 2)

    # Compute the cost matrix using cdist
    cost_matrix = torch.cdist(x, x_rec)

    # Compute the linear assignment
    assignment = batch_linear_assignment(cost_matrix)

    # Compute the loss
    total_loss = 0
    for b in range(B):
        x_matched = x[b, assignment[b, 0]]
        x_rec_matched = x_rec[b, assignment[b, 1]]
        loss = F.mse_loss(x_matched, x_rec_matched)
        total_loss += loss

    return total_loss / B

# Define two sets of vectors
B = 32
set_A = torch.tensor([
    [0, 0, 0, 0],
    [1, 1, 1, 1],
    [1, 2, 2, 2],
]).repeat(B, 1, 1).type(torch.float32)

set_B = torch.tensor([
    [0, 0, 0, 0],
    [1, 1, 1, 1],
    [2, 2, 2, 2],
]).repeat(B, 1, 1).type(torch.float32)


print(set_A.shape, set_B.shape)

debug_loss(set_A, set_B)