In [1]:
%cd /ibex/user/slimhy/PADS/code
"""
Extracting features into HDF5 files for each split.
"""
import argparse
import torch

import util.misc as misc
import models.s2vs as ae_mods


def get_args_parser():
    """Build the command-line parser used by this script.

    The parser is created with ``add_help=False``, so it does not define a
    ``-h/--help`` option of its own.

    Returns:
        argparse.ArgumentParser: parser exposing the model, dataset and
        runtime options declared below. ``--ae_pth`` is the only required
        option; every other option falls back to its default (``None`` when
        no default is given).
    """
    parser = argparse.ArgumentParser("Extracting Features", add_help=False)

    # Model parameters
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU"
        " (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    # Every other flag uses underscores; accept the underscore spelling as an
    # alias so both forms work. The destination stays `ae_latent_dim` because
    # argparse derives it from the first long option string.
    parser.add_argument(
        "--ae-latent-dim",
        "--ae_latent_dim",
        type=int,
        default=512 * 8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint",
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size",
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )

    # Runtime parameters
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient "
        "(sometimes) transfer to GPU.",
    )

    return parser


# Dummy command line standing in for real CLI arguments while debugging
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --data_path /ibex/project/c2273/PADS/3DCoMPaT \
    --batch_size 32 \
    --num_workers 8 \
    --device cuda"""

# Build the parser and parse the dummy command line in a single step,
# so `args` only ever names the parsed namespace.
args = get_args_parser().parse_args(call_string.split())

# --------------------
# Seed everything for reproducibility
misc.set_all_seeds(args.seed)

# Target device for the autoencoder
device = torch.device(args.device)

torch.backends.cudnn.benchmark = True
# --------------------

# Look up the autoencoder constructor by name in the models module
# and put the fresh instance in eval mode.
ae = vars(ae_mods)[args.ae]()
ae.eval()

# Restore the weights stored under the "model" key of the checkpoint.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
print("Loading autoencoder %s" % args.ae_pth)
ae.load_state_dict(
    torch.load(args.ae_pth, map_location="cpu")["model"]
)

# Move to the target device, then compile with torch.compile
ae = torch.compile(ae.to(device), mode="max-autotune")

/ibex/user/slimhy/PADS/code
Set seed to 0
Loading autoencoder ckpt/ae_m512.pth


In [2]:
from datasets.latents import ShapeLatentDataset, ComposedPairedShapesLoader

class PairType:
    """Named constants for comma-separated pair-type strings."""

    # Second element of the pair is "orig", first is "part_drop"
    PART_DROP = "part_drop,orig"
    # Both elements of the pair are "rand_no_rot"
    NO_ROT_PAIR = "rand_no_rot,rand_no_rot"

# Training split of the shape-latent dataset
dataset_train = ShapeLatentDataset(args.data_path, split="train", shuffle_parts=True)

# Number of replicas to emulate; one loader is built per rank
N_REPLICAS = 4

# Build one distributed-style loader per rank over the same dataset.
# The pair types come from the PairType constants rather than repeated
# string literals, so the two cannot drift apart. Each loader gets its own
# list object, as before.
all_loaders = [
    ComposedPairedShapesLoader(
        dataset_train,
        batch_size=64,
        pair_types_list=[PairType.NO_ROT_PAIR, PairType.PART_DROP],
        num_workers=4,
        use_distributed=True,
        num_replicas=N_REPLICAS,
        rank=rank,
        seed=0,
        shuffle=True,
        drop_last=True,
    )
    for rank in range(N_REPLICAS)
]


In [3]:
from tqdm.notebook import tqdm

# Number of batches yielded by each loader, one entry per (epoch, loader) pass
all_samples = []
for epoch in tqdm(range(10)):
    for ddp_loader in all_loaders:
        ddp_loader.set_epoch(epoch)

        # Walk the loader once; the unpacking also checks the batch layout
        data_count = 0
        for _pair_types, (_l_a, _bb_a, _bb_l_a, _meta_a), (_l_b, _bb_b, _bb_l_b, _meta_b) in ddp_loader:
            data_count += 1

        # A loader that yields nothing is treated as an error
        assert data_count > 0, "No data seen!"
        all_samples.append(data_count)

  0%|          | 0/10 [00:00<?, ?it/s]



In [5]:
# Create the DataLoader using the sampler
# ddp_loader_0 = ComposedPairedShapesLoader(
#     dataset_train,
#     batch_size=64,
#     pair_types_list=['rand_no_rot,rand_no_rot', 'part_drop,orig'],
#     num_workers=4,
#     use_distributed=True,
#     num_replicas=2,
#     rank=0,
#     seed=0,
#     shuffle=True,
#     drop_last=True,
# ) 
# ddp_loader_1 = ComposedPairedShapesLoader(
#     dataset_train,
#     batch_size=64,
#     pair_types_list=['rand_no_rot,rand_no_rot', 'part_drop,orig'],
#     num_workers=4,
#     use_distributed=True,
#     num_replicas=2,
#     rank=1,
#     seed=0,
#     shuffle=True,
#     drop_last=True,
# ) 
# seq_loader = ComposedPairedShapesLoader(
#     dataset_train,
#     batch_size=64,
#     pair_types_list=['rand_no_rot,rand_no_rot', 'part_drop,orig'],
#     num_workers=4,
#     shuffle=True,
#     use_distributed=False,
#     drop_last=True,
# ) 
# for epoch in tqdm(range(0, 800)):
#     ddp_loader_0.set_epoch(epoch)
#     ddp_loader_1.set_epoch(epoch)
# 
#     # Use the dataloader in your training loop
#     seen_models_a = set()
#     seen_batches_a = 0
#     for pair_types, (l_a, bb_a, bb_l_a, meta_a), (l_b, bb_b, bb_l_b, meta_b) in ddp_loader_0:
#         seen_batches_a += 1
#         seen_models_a |= set(list(meta_a))
#         seen_models_a |= set(list(meta_b))
# 
#     # Use the dataloader in your training loop
#     seen_models_b = set()
#     seen_batches_b = 0
#     for pair_types, (l_a, bb_a, bb_l_a, meta_a), (l_b, bb_b, bb_l_b, meta_b) in ddp_loader_1:
#         seen_batches_b += 1
#         seen_models_b |= set(list(meta_a))
#         seen_models_b |= set(list(meta_b))
# 
#     # Use the dataloader in your training loop
#     seen_models_seq = set()
#     seen_batches_seq = 0
#     for pair_types, (l_a, bb_a, bb_l_a, meta_a), (l_b, bb_b, bb_l_b, meta_b) in seq_loader:
#         seen_batches_seq += 1
#         seen_models_seq |= set(list(meta_a))
#         seen_models_seq |= set(list(meta_b))
# 