In [13]:
%cd /ibex/user/slimhy/PADS/code
%reload_ext autoreload
%set_env CUBLAS_WORKSPACE_CONFIG=:4096:8
"""
Latents dataset.
"""
import json
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, BatchSampler, DataLoader
from collections import defaultdict


class PairType:
    NO_ROT_PAIR = "rand_no_rot,rand_no_rot"
    PART_DROP = "part_drop,orig"


TEST_V = None
class ShapeLatentDataset(Dataset):
    """
    Shape latent dataset.
    """

    PART_CAP = 24

    def __init__(
        self,
        data_dir,
        exclude_types=None,
        cap_parts=True,
        shuffle_parts=True,
        class_code=None,
        split=None,
        filter_n_ids=None,
        get_part_points=False,
    ):
        exclude_types = set(exclude_types) if exclude_types else set()
        self.shuffle_parts = shuffle_parts
        self.get_part_points = get_part_points

        # Load file list
        file_list = "capped_list.json" if cap_parts else "full_list.json"
        file_list = json.load(open(os.path.join(data_dir, file_list)))
        latents_dir = os.path.join(data_dir, "latents")
        bbs_dir = os.path.join(data_dir, "bounding_boxes")
        points_dir = os.path.join(data_dir, "part_points")

        # Load the split
        if split is not None:
            split = json.load(open(os.path.join(data_dir, "split_" + split + ".json")))
            split = set(split)

        final_list = []
        for f in file_list:
            if split is not None and f[:6] not in split:
                continue

            file_type = "_".join(f.split("_")[2:-1])

            # Filter by class code
            valid_cls = class_code is None or f.startswith(class_code)
            if not valid_cls:
                continue

            # Filter by file type
            if file_type not in exclude_types:
                bb_coords_f = f + "_part_bbs"
                bb_labels_f = f + "_part_labels"
                part_points_f = f[:6]
                if self.get_part_points:
                    final_list += [[k + ".npy" for k in [f, bb_coords_f, bb_labels_f, part_points_f]]]
                else:
                    final_list += [[k + ".npy" for k in [f, bb_coords_f, bb_labels_f]]]

        # Create a list of file paths
        file_list = final_list
        file_list.sort()
        self.file_list = file_list
        self.file_tuples = []
        for idx in range(len(file_list)):
            # Unpack file paths
            file_paths = [os.path.join(latents_dir, file_list[idx][0])]
            file_paths = file_paths + [
                os.path.join(bbs_dir, f) for f in file_list[idx][1:3]
            ]
            if self.get_part_points:
                file_paths.append(os.path.join(points_dir, file_list[idx][3]))
            
            latent_f, bb_coords_f, bb_labels_f = file_paths[:3]
            part_points_f = file_paths[3] if self.get_part_points else None

            # Extract model ID from the filename
            basename = os.path.basename(latent_f)
            model_id = basename.split("_")[0:2][0] + basename.split("_")[0:2][1]
            model_id = int(model_id, 16)
            self.file_tuples += [(latent_f, bb_coords_f, bb_labels_f, part_points_f, model_id)]

        if filter_n_ids is not None:
            # Only keep the samples corresponding to N=filter_n_ids distinct model IDs
            unique_model_ids = list(set(tup[4] for tup in self.file_tuples))
            random.shuffle(unique_model_ids)
            selected_model_ids = set(unique_model_ids[:filter_n_ids])

            self.file_tuples = [
                tuple for tuple in self.file_tuples if tuple[4] in selected_model_ids
            ]
            filtered_file_paths = [tuple[0] for tuple in self.file_tuples]
            self.file_list = [
                sublist
                for sublist in self.file_list
                if os.path.join(latents_dir, sublist[0]) in filtered_file_paths
            ]

        self.rng = torch.Generator()
        self.rng_counter = 0

    def __len__(self):
        return len(self.file_tuples)

    def __getitem__(self, idx):
        global TEST_V

        # Unpack file paths
        latent_f, bb_coords_f, bb_labels_f, part_points_f, model_id = self.file_tuples[idx]

        # Loading latent and bounding box data
        latent = np.load(latent_f)
        bb_coords = np.load(bb_coords_f)
        bb_labels = np.load(bb_labels_f)

        # Convert numpy array to torch tensor
        latent_tensor = torch.from_numpy(latent).float()
        bb_coords_tensor = torch.from_numpy(bb_coords).float()
        bb_labels_tensor = torch.from_numpy(bb_labels).long()

        if self.get_part_points:
            part_points = np.load(part_points_f, allow_pickle=True)
            TEST_V = part_points
            # print(part_points[0])
            # part_points_stacked = np.stack(part_points.values())
            # part_points_tensor = torch.from_numpy(part_points_stacked).float()

        # Shuffle the order of parts if self.shuffle is True
        if self.shuffle_parts:
            self.rng.manual_seed(model_id + self.rng_counter)

            num_parts = bb_coords_tensor.size(0)
            shuffle_indices = torch.randperm(num_parts, generator=self.rng)
            bb_coords_tensor = bb_coords_tensor[shuffle_indices]
            bb_labels_tensor = bb_labels_tensor[shuffle_indices]
            if self.get_part_points:
                part_points_tensor = part_points_tensor[shuffle_indices]

        # Pad bb coords and labels
        pad_size = self.PART_CAP - bb_coords_tensor.size(0)

        # Pad the tensors
        bb_coords_tensor = F.pad(bb_coords_tensor, (0, 0, 0, 0, 0, pad_size))
        bb_labels_tensor = F.pad(bb_labels_tensor, (0, pad_size), value=-1)
        if self.get_part_points:
            part_points_tensor = F.pad(part_points_tensor, (0, 0, 0, pad_size))

        # Extract metadata from filename
        meta = os.path.basename(latent_f).split(".")[0]

        if self.get_part_points:
            return (
                latent_tensor,
                bb_coords_tensor,
                bb_labels_tensor,
                part_points_tensor,
                meta,
            )
        else:
            return (
                latent_tensor,
                bb_coords_tensor,
                bb_labels_tensor,
                meta,
            )

class PairedSampler(BatchSampler):
    """
    Sampling augmented shape pairs.
    """

    def __init__(self, dataset, batch_size, pair_types, shuffle=True, drop_last=False):
        pair_types = [t.strip() for t in pair_types.split(",")]
        if len(pair_types) != 2:
            raise ValueError(
                "pair_types should contain exactly two types separated by a comma"
            )

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Group indices by ID and type
        id_to_indices = defaultdict(lambda: defaultdict(list))
        for idx, file_info in enumerate(dataset.file_list):
            if dataset.get_part_points:
                filename, _, _, _ = file_info  # Unpack considering part points file
            else:
                filename, _, _ = file_info  # Unpack without part points file
            
            parts = filename.split("_")
            id_part = "_".join(parts[:2])
            type_part = "_".join(parts[2:-1])
            
            id_to_indices[id_part][type_part].append(idx)
    
        # Filter out IDs that don't have both required types
        valid_ids = [
            id_part
            for id_part, type_dict in id_to_indices.items()
            if all(p_type in type_dict for p_type in pair_types)
        ]

        self.paired_indices = self._create_paired_indices(
            id_to_indices, valid_ids, pair_types
        )

    def _create_paired_indices(self, id_to_indices, valid_ids, pair_types):
        paired_indices = []

        if self.shuffle:
            random.shuffle(valid_ids)

        for id_part in valid_ids:
            type1, type2 = pair_types
            indices1 = id_to_indices[id_part][type1]
            indices2 = id_to_indices[id_part][type2]

            if type1 == type2:
                # If the same type is requested, ensure we have at least 2 files
                if len(indices1) < 2:
                    continue
                pair = random.sample(indices1, 2)
            else:
                # If different types, take one from each
                pair = [random.choice(indices1), random.choice(indices2)]

            paired_indices.extend(pair)

        return paired_indices

    def __iter__(self):
        batch = []
        for idx in self.paired_indices:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []

    def __len__(self):
        return len(self.paired_indices) * 2 // self.batch_size


class DistributedPairedSampler(BatchSampler):
    def __init__(
        self,
        dataset,
        batch_size,
        pair_types,
        num_replicas=None,
        rank=None,
        seed=0,
        shuffle=True,
        drop_last=False,
    ):
        pair_types = [t.strip() for t in pair_types.split(",")]
        if len(pair_types) != 2:
            raise ValueError(
                "pair_types should contain exactly two types separated by a comma"
            )

        self.dataset = dataset
        self.batch_size = batch_size
        self.epoch = 0
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.shuffle = shuffle
        self.drop_last = drop_last

        # Create paired indices
        self.paired_indices = self._create_paired_indices(dataset.file_list, pair_types)
        self.num_samples = len(self.paired_indices) // self.num_replicas

        # Create RNG
        self.rng = torch.Generator()
        self.rng.manual_seed(self.seed + self.epoch)
        self.indices = None

    def _create_paired_indices(self, file_list, pair_types):
        """
        Initialize a list of paired indices.
        """

        # Group indices by ID and type
        id_to_indices = defaultdict(lambda: defaultdict(list))
        for idx, (filename, _, _) in enumerate(file_list):
            parts = filename.split("_")
            id_part = "_".join(parts[:2])
            file_type = "_".join(parts[2:-1])
            id_to_indices[id_part][file_type].append(idx)

        # Filter out IDs that don't have both required types
        valid_ids = [
            id_part
            for id_part, type_dict in id_to_indices.items()
            if all(type in type_dict for type in pair_types)
        ]

        paired_indices = []

        for id_part in valid_ids:
            type1, type2 = pair_types
            indices1 = id_to_indices[id_part][type1]
            indices2 = id_to_indices[id_part][type2]

            if type1 == type2:
                # If the same type is requested, ensure we have at least 2 files
                if len(indices1) < 2:
                    continue
                pair = random.sample(indices1, 2)
            else:
                # If different types, take one from each
                pair = [random.choice(indices1), random.choice(indices2)]

            paired_indices.extend(pair)

        return paired_indices

    def sample_indices(self):
        """
        Sample indices for the current epoch.
        """
        # Deterministically shuffle based on epoch and seed
        if self.shuffle:
            n = len(self.paired_indices)

            # Generate a permutation for N/2 pairs
            pair_perm = torch.randperm(n // 2, generator=self.rng).tolist()

            # Use the permutation to reindex the paired list
            indices = [j for i in pair_perm for j in (2 * i, 2 * i + 1)]
        else:
            indices = list(range(len(self.paired_indices)))

        # Subsample while preserving pairs
        n_pairs = len(indices) // 2
        pair_indices = list(range(n_pairs))
        subsampled_pair_indices = pair_indices[self.rank : n_pairs : self.num_replicas]

        self.indices = [
            idx
            for pair_idx in subsampled_pair_indices
            for idx in indices[2 * pair_idx : 2 * pair_idx + 2]
        ]

    def __iter__(self):
        if self.indices is None:
            self.sample_indices()

        # Create batches
        batches = []
        batch = []
        for idx in self.indices:
            batch.append(self.paired_indices[idx])
            if len(batch) == self.batch_size:
                batches.append(batch)
                batch = []

        return iter(batches)

    def __len__(self):
        num_samples = len(self.paired_indices)
        return num_samples // self.batch_size

    def set_epoch(self, epoch):
        self.epoch = epoch
        self.rng.manual_seed(self.seed + self.epoch)
        self.sample_indices()


class PairedShapesLoader:
    """
    Paired shapes loader.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        pair_types,
        num_workers,
        shuffle,
        use_distributed=False,
        num_replicas=None,
        rank=None,
        get_part_points=False,
        **kwargs,
    ):
        # Filter out keys from kwargs that are not DataLoader arguments
        valid_keys = set(DataLoader.__init__.__code__.co_varnames)
        kwargs = {k: v for k, v in kwargs.items() if k in valid_keys}
        self.kwargs = kwargs
        self.dataset = dataset
        self.batch_size = batch_size
        self.pair_types = pair_types
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.use_distributed = use_distributed
        self.num_replicas = num_replicas
        self.rank = rank
        self.get_part_points = get_part_points
        self.create_dataloader()

    def create_dataloader(self):
        if self.use_distributed:
            batch_sampler = DistributedPairedSampler(
                self.dataset,
                self.batch_size,
                self.pair_types,
                shuffle=self.shuffle,
                num_replicas=self.num_replicas,
                rank=self.rank,
            )
        else:
            batch_sampler = PairedSampler(
                self.dataset,
                pair_types=self.pair_types,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
            )
        self.sampler = batch_sampler
        self.dataloader = DataLoader(
            self.dataset,
            batch_sampler=batch_sampler,
            num_workers=self.num_workers,
        )
        self.iterator = iter(self.dataloader)

    def set_epoch(self, epoch):
        if self.use_distributed:
            self.sampler.set_epoch(epoch)
            self.iterator = iter(self.dataloader)
        else:
            raise ValueError("set_epoch is only supported in distributed mode")

    def split_tensor(self, tensor):
        tensor_A = tensor[::2]
        tensor_B = tensor[1::2]
        return tensor_A, tensor_B

    def __iter__(self):
        return self

    def __next__(self):
        try:
            tuple_data = next(self.iterator)
            tuple_A, tuple_B = zip(*(self.split_tensor(t) for t in tuple_data))
            return tuple_A, tuple_B

        except StopIteration:
            self.iterator = iter(self.dataloader)
            raise StopIteration

    def __len__(self):
        return len(self.dataloader)


class ComposedPairedShapesLoader:
    """
    Composed loader that alternates between batches of multiple shape pair types.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        pair_types_list,
        num_workers,
        shuffle=False,
        use_distributed=False,
        num_replicas=None,
        rank=None,
        reset_every=100,
        get_part_points=False,
        **kwargs,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.pair_types_list = pair_types_list
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.use_distributed = use_distributed
        self.num_replicas = num_replicas
        self.rank = rank
        self.kwargs = kwargs
        self.reset_every = reset_every
        self.get_part_points = get_part_points
        self.loaders = None

    def create_loaders(self):
        self.loaders = [
            (
                pair_types,
                PairedShapesLoader(
                    self.dataset,
                    self.batch_size,
                    pair_types,
                    self.num_workers,
                    shuffle=self.shuffle,
                    use_distributed=self.use_distributed,
                    num_replicas=self.num_replicas,
                    rank=self.rank,
                    get_part_points=self.get_part_points,
                    **self.kwargs,
                ),
            )
            for pair_types in self.pair_types_list
        ]
        self.num_loaders = len(self.loaders)

    def set_epoch(self, epoch, force_reset=False):
        if epoch % self.reset_every == 0 or force_reset:
            self.create_loaders()
        for _, loader in self.loaders:
            loader.set_epoch(epoch)

    def get_tuple(self, device=None, return_single=True):
        """
        Get a data tuple for debugging.
        """
        for pair_types, tuple_a, tuple_b in self:
            if device is not None:
                tuple_a = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in tuple_a)
                tuple_b = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in tuple_b)
            
            if return_single:
                return tuple_a
            else:
                return tuple_a, tuple_b

        raise ValueError("No data available")

    def __iter__(self):
        if self.loaders is None:
            self.create_loaders()
        while True:
            for pair_types, loader in [
                (pair_types, loader) for pair_types, loader in self.loaders
            ]:
                try:
                    yield pair_types, *next(loader)
                except StopIteration:
                    self.dataset.rng_counter += 1
                    return

    def __len__(self):
        if self.loaders is None:
            self.create_loaders()
        return max(len(loader) for _, loader in self.loaders)

/ibex/user/slimhy/PADS/code
env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [14]:
"""
Test demo for MLP mapper.
"""

import argparse
import numpy as np

import torch
import torch.backends.cudnn as cudnn

import models.s2vs as autoencoders
import util.misc as misc


def get_args_parser():
    parser = argparse.ArgumentParser("Latent Diffusion", add_help=False)

    # Model parameters
    parser.add_argument(
        "--model",
        default="kl_d512_m512_l8_edm",
        type=str,
        metavar="MODEL",
        help="Name of model to train",
    )
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        "--ae-latent-dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--resume",
        default="",
        help="Resume from checkpoint"
    )
    parser.add_argument(
        "--resume_weights",
        action="store_true",
        default=False,
        help="Only resume weights, not optimizer state",
    )
    parser.add_argument(
        "--resume_full_weights",
        action="store_true",
        default=False,
        help="Resume the full model weights with the EDM wrapper",
    )

    # Dataset parameters
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser


# Set dummy arg string to debug the parser
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --num_workers 8 \
    --batch_size 2 \
    --device cuda \
    --fetch_keys \
    --use_embeds \
    --seed 0"""

# Parse the arguments
args = get_args_parser()
args = args.parse_args(call_string.split())
args.data_path = "/ibex/project/c2273/PADS/3DCoMPaT/"

# --------------------
device = torch.device(args.device)

# Fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

cudnn.benchmark = True
args.fetch_keys = True
# --------------------

# Instantiate autoencoder
ae = autoencoders.__dict__[args.ae]()
ae.eval()
print("Loading autoencoder %s" % args.ae_pth)
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae = ae.to(device)

Loading autoencoder ckpt/ae_m512.pth


In [15]:
from util.misc import visualize_bounding_boxes
from datasets.metadata import class_to_hex
import util.s2vs as s2vs


def decode_latents(model, latents, grid_density=256, batch_size=128**3):
    # Decode the latents
    with torch.no_grad():
        mesh = s2vs.decode_latents(
            ae=model,
            latent=latents[0].unsqueeze(0),
            grid_density=grid_density,
            batch_size=batch_size,
        )
    return mesh


# Create your datasets
dataset_train = ShapeLatentDataset(
    args.data_path,
    class_code=class_to_hex("chair"),
    split="train",
    get_part_points=True,
    shuffle_parts=True,
    filter_n_ids=8
)

loader = ComposedPairedShapesLoader(
    dataset_train,
    batch_size=args.batch_size,
    pair_types_list=[PairType.NO_ROT_PAIR],
    num_workers=4,
    seed=0,
    shuffle=True,
    drop_last=True,
)

(l_a, bb_a, bb_l_a, part_pts_a, meta_a), (l_b, bb_b, bb_l_b, part_pts_b, meta_b) = loader.get_tuple(device=device, return_single=False)



UnboundLocalError: Caught UnboundLocalError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/slimhy/conda/envs/3D2VS_flexicubes/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/slimhy/conda/envs/3D2VS_flexicubes/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/slimhy/conda/envs/3D2VS_flexicubes/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_2142383/4266975062.py", line 157, in __getitem__
    part_points_tensor = part_points_tensor[shuffle_indices]
UnboundLocalError: local variable 'part_points_tensor' referenced before assignment


In [16]:
TEST_V

In [4]:
def show_boxes(bb_coords, latents):
    bounding_boxes = np.array(bb_coords.cpu()).squeeze()
    mesh = decode_latents(ae, latents.float()).trimesh_mesh
    
    return visualize_bounding_boxes(mesh, bounding_boxes, box_type='lines', colormap='hsv').show()

In [None]:
show_boxes(bb_a, l_a)

In [None]:
show_boxes(bb_b, l_b)