In [76]:
import os
import numpy as np
import matplotlib.pyplot as plt

import itertools

import torch
from torch.utils.data import DataLoader

import kornia.geometry.conversions as conversions


from torchinfo import summary
from models import XSwinFusion, XNetSwinTransformer
from pose_estimation import PoseDataNPZTorch, PoseDataNPZ, PoseData, icp, PoseDataNPZSegmentationTorch
from pose_estimation.utils import crop_and_resize_multiple, enumerate_symmetries, COLOR_PALETTE

# Data locations, all expressed relative to the current working directory:
#   WORKDIR          -> <cwd>/..
#   DATA_FOLDER      -> <cwd>/../data_folder
#   DATASET_NPZ_PATH -> <cwd>/../data_folder/dataset_npz
WORKDIR = os.getcwd() + "/.."
DATA_FOLDER = os.path.join(WORKDIR, "data_folder")
DATASET_NPZ_PATH = os.path.join(WORKDIR, "data_folder", "dataset_npz")


In [85]:
# Constructor arguments for the XNetSwinTransformer segmentation network.
segmentation_model_kwargs = dict(
    patch_size=[4, 4],
    embed_dim=192,
    depths=[3, 3, 3],
    num_heads=[8, 16, 32],
    window_size=[8, 8],
    mlp_ratio=2.0,
    num_classes=82,
    global_stages=0,
    input_size=(216, 384),
    final_downsample=False,
    residual_cross_attention=True,
    smooth_conv=True,
)
xswin = XNetSwinTransformer(**segmentation_model_kwargs)

# Layer-by-layer summary for one 3-channel input of the configured size,
# i.e. input_size == [1, 3, 216, 384].
summary(xswin, input_size=[1, 3, *segmentation_model_kwargs["input_size"]])


Layer (type:depth-idx)                                  Output Shape              Param #
XNetSwinTransformer                                     [1, 82, 216, 384]         --
├─ConvolutionTriplet: 1-1                               [1, 192, 216, 384]        --
│    └─Sequential: 2-1                                  [1, 192, 216, 384]        --
│    │    └─Conv2d: 3-1                                 [1, 192, 216, 384]        5,376
│    │    └─BatchNorm2d: 3-2                            [1, 192, 216, 384]        384
│    │    └─LeakyReLU: 3-3                              [1, 192, 216, 384]        --
│    │    └─Conv2d: 3-4                                 [1, 192, 216, 384]        331,968
│    │    └─BatchNorm2d: 3-5                            [1, 192, 216, 384]        384
│    │    └─LeakyReLU: 3-6                              [1, 192, 216, 384]        --
│    │    └─Conv2d: 3-7                                 [1, 192, 216, 384]        331,968
│    │    └─BatchNorm2d: 3-8                 

In [74]:
# Segmentation dataset resized to (144, 256), served in shuffled batches of 64.
data = PoseDataNPZSegmentationTorch(DATASET_NPZ_PATH, resize=(144, 256))
dataloader = DataLoader(data, batch_size=64, shuffle=True)
# Disabled debug snippet (prints the minimum label value of every batch):
# for i, (im, l, k) in enumerate(dataloader):
#     print(torch.min(l))

# One slot per entry of data.data.info, followed by three extra trailing slots
# that stay at zero volume (presumably non-object classes -- TODO confirm).
object_infos = data.data.info
volumes = np.zeros(len(object_infos) + 3)
for slot, obj_info in enumerate(object_infos):
    # Bounding-box volume from the object's width, height and length.
    volumes[slot] = obj_info["width"] * obj_info["height"] * obj_info["length"]

# Inverse-volume weighting: weight = 1 / (volume / mean_volume + 0.2).
# NOTE(review): the mean is taken over every slot, including the three
# trailing zeros -- confirm that this is intended.
normalized_volumes = (volumes / volumes.mean() + 0.2) ** -1
# The three trailing slots get a fixed weight instead of the computed one.
normalized_volumes[-3:] = 0.1
normalized_volumes


Presumed Preloaded NPZ Dataset: /Users/armanommid/Code/CSE/CSE275/HW2/XSwinDiffusion/../data_folder/dataset_npz


array([2.27220729, 3.85512333, 0.06790825, 2.42104538, 2.07067204,
       3.93925143, 0.64182094, 1.18777512, 0.4407341 , 0.52512214,
       0.34137755, 1.86846364, 3.81615646, 3.64147644, 0.31318209,
       2.69653439, 1.63886531, 3.56737491, 3.64305267, 1.47233098,
       2.76495753, 3.64190419, 0.60392237, 1.31382749, 3.93023501,
       1.97133623, 2.24107948, 3.32767968, 1.14904461, 0.33874604,
       2.3076569 , 1.05357246, 0.52383978, 0.93881299, 0.87111473,
       2.13725172, 3.43294173, 0.92985386, 3.95793929, 0.51232972,
       2.475748  , 0.79718815, 0.68471138, 0.85126059, 0.15536348,
       1.9916658 , 0.13429935, 0.42351451, 1.27801061, 0.39196135,
       1.04016916, 1.51138543, 2.26624893, 1.88724767, 2.24454911,
       0.92508584, 1.31921162, 2.03160651, 0.45497933, 0.84568012,
       1.15797339, 0.93588777, 1.27437227, 1.66202315, 1.79084296,
       1.43061129, 3.72109191, 4.05522843, 4.00312151, 1.03946704,
       1.45463902, 0.95911586, 0.71062844, 0.96072785, 0.61686

In [2]:
# Optional visual sanity check of crop_and_resize_multiple on a single scene.
# Flip SHOWCASE to True to run it; nothing below executes while it is False.
SHOWCASE = False
if SHOWCASE:

    # Fetch one scene and pull out the arrays it carries.
    data = PoseDataNPZ(DATASET_NPZ_PATH)
    scene = data[1, 1, 4]
    rgb, depth, label, meta = (
        scene[field] for field in ("color", "depth", "label", "meta")
    )

    # Crop around the pixels whose label equals the first value returned by
    # np.unique(label), i.e. the smallest label value present in the image.
    mask = label == np.unique(label)[0]
    target_size = (432, 768)
    margin = 8
    aspect_ratio = True
    mask_fill = False

    # All four images are cropped/resized together, driven by the same mask.
    crops, scale, translate = crop_and_resize_multiple(
        (rgb, depth, COLOR_PALETTE[label], mask),
        mask,
        target_size=target_size,
        margin=margin,
        aspect_ratio=aspect_ratio,
        mask_fill=mask_fill,
    )
    rgb_cr, depth_cr, label_cr, mask_cr = crops

    # Disabled debug snippet (compare masked depth before and after cropping):
    # print(depth[mask][:200])
    # print(depth_cr[mask_cr][:200])

    # Side-by-side panels: cropped color, cropped depth, cropped colorized labels.
    plt.figure(figsize=(15, 10))
    for panel, image in enumerate((rgb_cr, depth_cr, label_cr), start=1):
        plt.subplot(1, 3, panel)
        plt.imshow(image)


In [3]:
# --- Model settings --------------------------------------------------------
feature_dims = 64
quaternion = True
pretrained_resnet = False
resize = (144, 256)  # shared by the model and the dataset below

# --- Dataset / loader settings ----------------------------------------------
batch_size = 10
samples = 1_000
aspect_ratio = True
margin = 12

model = XSwinFusion(
    feature_dims=feature_dims,
    resize=resize,
    quaternion=quaternion,
    pretrained=pretrained_resnet,
)

# Clear the name before rebuilding the dataset (presumably so a previously
# loaded dataset can be released when this cell is re-run -- TODO confirm).
dataset = None
dataset = PoseDataNPZTorch(
    DATASET_NPZ_PATH,
    samples=samples,
    resize=resize,
    aspect_ratio=aspect_ratio,
    margin=margin,
    symmetry_pad=64,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Disabled: model summary.
# summary(model, depth=1)




Presumed Preloaded NPZ Dataset: /Users/armanommid/Code/CSE/CSE275/HW2/XSwinDiffusion/../data_folder/dataset_npz


In [4]:


# Peek at the first batch only and print the first entry of its `key` field.
# islice(..., 1) stops after one batch, so the loader is not iterated further.
for i, batch in enumerate(itertools.islice(dataloader, 1)):
    s, t, c, mi, p, sym, key, obj_id = batch
    print(key[0])


tensor([ 2, 97, 19])
tensor([ 1, 25,  2])
tensor([  3, 158,   8])
tensor([ 2, 92,  1])
tensor([ 1, 49, 19])
tensor([ 1,  9, 15])
tensor([ 2, 82, 39])
tensor([ 2, 47, 13])
tensor([ 1, 44,  3])


KeyboardInterrupt: 

In [None]:
# Build a fixed-size stack of 3x3 matrices for one object: its enumerated
# symmetries first, with the remaining slots left as identity matrices.

# Number of 3x3 slots in the padded stack; same value as the `symmetry_pad`
# argument given to PoseDataNPZTorch in this notebook.
SYMMETRY_PAD = 64
# Index (into the dataset's object info list) of the object being inspected.
SYMMETRY_OBJECT_INDEX = 5

objects = dataloader.dataset.data.info

# (SYMMETRY_PAD, 3, 3) tensor in which every slot starts as the identity.
sym_pad = torch.eye(3).unsqueeze(0).repeat(SYMMETRY_PAD, 1, 1)

# enumerate_symmetries yields a sequence of tensors that are concatenated
# along dim 0 (presumably each is a batch of 3x3 matrices -- TODO confirm).
symmetry_chunks = enumerate_symmetries(
    objects[SYMMETRY_OBJECT_INDEX]["geometric_symmetry"]
)
s = torch.cat(symmetry_chunks)
print(len(s))

# Fail with a readable message instead of an opaque shape-mismatch error when
# the object has more symmetries than the padded stack can hold.
if len(s) > SYMMETRY_PAD:
    raise ValueError(
        f"object {SYMMETRY_OBJECT_INDEX} has {len(s)} symmetries, "
        f"which exceeds the pad size of {SYMMETRY_PAD}"
    )

# Overwrite the leading slots; the rest keep their identity padding.
sym_pad[:len(s), :, :] = s
sym_pad
