In [64]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [65]:
from dataclasses import dataclass, field

from typing import Dict, List, Tuple, Optional, Callable, Any

import math
import numpy as np
from PIL import Image

from pysolotools.consumers import Solo
from pysolotools.converters.solo2coco import SOLO2COCOConverter
from pysolotools.core.models import KeypointAnnotationDefinition, RGBCameraCapture
from pysolotools.core.models import BoundingBox2DLabel, BoundingBox2DAnnotation
from pysolotools.core.models import BoundingBox3DLabel, BoundingBox3DAnnotation
from pysolotools.core.models import Frame, Capture
from scipy.spatial.transform import Rotation as R

from torchvision.datasets import ImageFolder
from torchvision.models import swin_v2_t, Swin_V2_T_Weights
from torchvision.models import swin_v2_b, Swin_V2_B_Weights
# from torch.utils.data import ConcatDataset, DataLoader
from collections import OrderedDict
from torch.utils.data import DataLoader

import torch
from torch import nn, Tensor

from torch.nn import functional as F

from torchvision.ops import FeaturePyramidNetwork, MLP, sigmoid_focal_loss
import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor

import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
from scipy.io import savemat

import reader
import utils
import network
import transformer
import cvmpca
from torchvision.ops import FeaturePyramidNetwork
from my_trainer import SetCriterion
torch.set_float32_matmul_precision('medium')
# %matplotlib ipympl

In [66]:
train_folder = 'D:/Unity/dataset/solo'
# training_dir = './data/train'
# testing_dir = './data/test'

In [67]:
train_loader = reader.UnityDataset.from_unity_to_loader(root=train_folder, batch_size=4)

for batch in train_loader:
    image_dicts, object_list = batch

    # for key, image in image_dicts.items():
    #     print(key, image.shape)
    # print('-'*30)
    # for targets in object_list:
    print(object_list.position.shape)

    break

torch.Size([4, 10, 3])


In [68]:
captures = train_loader.dataset.captures
cameras = {capture.id: utils.Camera.from_unity(capture) for capture in captures}
for key, item in image_dicts.items():
    print(key, item.shape)

camera_0 torch.Size([4, 3, 256, 256])
camera_2 torch.Size([4, 3, 256, 256])
camera torch.Size([4, 3, 256, 256])
camera_1 torch.Size([4, 3, 256, 256])


In [69]:
swin = transformer.Swin(is_trainable=True)
fpn = FeaturePyramidNetwork(swin.embed_dims, swin.embed_dim)
embed_dim = swin.embed_dim
embed_dims = swin.embed_dims

In [70]:
keys = list(image_dicts.keys())  # camera ids
B = image_dicts[keys[0]].size(0)

visual_features = {
    camera_key: fpn({
        f'feat{i}': x.permute(0, 3, 1, 2)
        for i, x in enumerate(swin(images))
    })
    for camera_key, images in image_dicts.items()
}

In [71]:
for camera_key, features in visual_features.items():
    print(camera_key)
    for key, feature in features.items():
        print(key, feature.shape)
        # break

camera_0
feat0 torch.Size([4, 96, 64, 64])
feat1 torch.Size([4, 96, 32, 32])
feat2 torch.Size([4, 96, 16, 16])
feat3 torch.Size([4, 96, 8, 8])
camera_2
feat0 torch.Size([4, 96, 64, 64])
feat1 torch.Size([4, 96, 32, 32])
feat2 torch.Size([4, 96, 16, 16])
feat3 torch.Size([4, 96, 8, 8])
camera
feat0 torch.Size([4, 96, 64, 64])
feat1 torch.Size([4, 96, 32, 32])
feat2 torch.Size([4, 96, 16, 16])
feat3 torch.Size([4, 96, 8, 8])
camera_1
feat0 torch.Size([4, 96, 64, 64])
feat1 torch.Size([4, 96, 32, 32])
feat2 torch.Size([4, 96, 16, 16])
feat3 torch.Size([4, 96, 8, 8])


In [72]:

def _voxelize(space, voxel_size):
    X = torch.arange(space[0][0], space[0][1] + voxel_size[0]/8, voxel_size[0])
    Y = torch.arange(space[1][0], space[1][1] + voxel_size[1]/8, voxel_size[1])
    Z = torch.arange(space[2][0], space[2][1] + voxel_size[2]/8, voxel_size[2])
    # print(X.shape, Y.shape, Z.shape)

    grid_x, grid_y, grid_z = torch.meshgrid(X, Y, Z, indexing='ij')
    return torch.stack([grid_x.flatten(), grid_y.flatten(), grid_z.flatten()], dim=1)

In [73]:
n_class, k = 2, 6
layers = nn.ModuleList([
    cvmpca.VoxelMHA(
        embed_dim=embed_dim, num_heads=12, attention_dropout=0.1, dropout=0.1,
        cameras=cameras
    ) for _ in range(len(swin.embed_dims))
])
cls_embedding = nn.Embedding(n_class, embed_dim)

space = torch.tensor([[-11, 11], [0, 3], [-7, 7]])
voxel_size = [2.7, 2.7, 2.7]

def _original(indices, n_bins):
    X, Y, Z = n_bins

    x, y, z = (indices//Z)//Y, (indices//Z)%Y, indices%Z
    print(x, y, z)
    # return x, y, z
    return torch.stack([x, y, z], dim=-1)

query = cls_embedding.weight.unsqueeze(0).unsqueeze(-2).expand(B, -1, -1, -1)
spaces = space.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(B, n_class, -1, -1, -1)
for l_idx, layer in enumerate(layers):
    print('spaces', spaces.shape, query.shape)
    B, C, K, *_ = spaces.shape
    next_voxel_size = np.array(voxel_size)/(3**l_idx)
    
    voxels = torch.stack([
        torch.stack([
            torch.stack([_voxelize(s, next_voxel_size) for s in sp])
            for sp in space
        ], dim=0)
        for space in spaces
    ], dim=0)
    print("voxels", voxels.shape)

    B, C, K, N, _ = voxels.shape
    
    x, (space_idx, voxel_idx) = layer(
        voxels + next_voxel_size/2, k,
        query,
        {key: list(features.values())[~l_idx] for key, features in visual_features.items()}
    )  # voxels + next_voxel_size/2: center of voxels
    # space_ids, voxel_ids = flatten_voxel_idx // n_bins, flatten_voxel_idx % n_bins

    top_voxels = voxels.gather(
        2, space_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, N, 3)
    ).gather(
        3, voxel_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 3)
    ).squeeze(-2)
    

    spaces = torch.stack([
        torch.stack([top_voxels[i], top_voxels[i] + next_voxel_size], dim=-1)
        for i in range(top_voxels.shape[0])
    ])
    # spaces = torch.stack([top_voxels.unsqueeze(-1).unsqueeze(-1).expand(-1, C, K), top_voxels + next_voxel_size], dim=-1)
    # print(top_voxels.shape, spaces.shape)
    
    if l_idx == 2: break
    print('-'*30)
    

spaces torch.Size([4, 2, 1, 3, 2]) torch.Size([4, 2, 1, 96])
voxels torch.Size([4, 2, 1, 108, 3])
------------------------------
spaces torch.Size([4, 2, 6, 3, 2]) torch.Size([4, 2, 1, 96])
voxels torch.Size([4, 2, 6, 64, 3])
------------------------------
spaces torch.Size([4, 2, 6, 3, 2]) torch.Size([4, 2, 1, 96])
voxels torch.Size([4, 2, 6, 64, 3])


In [30]:
net = cvmpca.CVMPCA(
    train_loader.dataset.captures,
    n_classes=len(train_loader.dataset.category_lookup),  # ignore background
    spaces=[[-11, 11], [0, 3], [-7, 7]],
    voxel_size=[2.7, 2.7, 2.7],
    ratio=3,
)
# outputs = net(50, image_dicts)
print()




In [31]:
image_dicts, object_list = batch
output_list, voxel_list = net(50, image_dicts)
# net.cuda()
# output_list, voxel_list = net(50, {key: item.cuda() for key, item in image_dicts.items()})
# B, L, C, P, _ = cls_list.shape

In [32]:
net.next_voxel_size.shape

torch.Size([4, 3])

In [50]:
(pred_ids, label_ids) = net._match((output_list, voxel_list), object_list)

torch.Size([4, 10, 50, 3]) torch.Size([4, 1, 1, 3])
torch.Size([4, 10, 50, 3]) torch.Size([4, 1, 1, 3])
torch.Size([4, 10, 50, 3]) torch.Size([4, 1, 1, 3])
torch.Size([4, 10, 50, 3]) torch.Size([4, 1, 1, 3])


In [34]:
pred_cls = output_list[..., :net.n_classes].log_softmax(-1)
pred_pos = output_list[..., net.n_classes:net.n_classes+3].sigmoid()
pred_link = output_list[..., -net.n_cameras:]
pred_voxel = voxel_list

_device = pred_cls.device
B, L, C, P, _ = pred_cls.shape
G = object_list.position.shape[1]
b_ids, l_ids, c_ids, p_idx, g_idx = torch.arange(B, device=_device), torch.arange(L, device=_device), torch.arange(C, device=_device), torch.arange(P, device=_device), torch.arange(G, device=_device)

pred_cls.shape, pred_pos.shape, pred_link.shape, pred_voxel.shape

(torch.Size([4, 4, 3, 50, 3]),
 torch.Size([4, 4, 3, 50, 3]),
 torch.Size([4, 4, 3, 50, 4]),
 torch.Size([4, 4, 3, 50, 3]))

In [35]:
b_ids[..., None, None, None].shape, l_ids[None, ..., None, None].shape, c_ids[None, None, ..., None].shape

(torch.Size([4, 1, 1, 1]), torch.Size([1, 4, 1, 1]), torch.Size([1, 1, 3, 1]))

In [36]:
pred_cls.shape, c_ids[None, None, ..., None, None].shape, object_list.category.shape, pred_ids.shape

(torch.Size([4, 4, 3, 50, 3]),
 torch.Size([1, 1, 3, 1, 1]),
 torch.Size([4, 10, 3]),
 torch.Size([4, 4, 3, 10]))

In [37]:

class_label = torch.zeros_like(pred_cls)
class_label[
    b_ids[..., None, None, None],
    l_ids[None, ..., None, None],
    c_ids[None, None, ..., None],
] = 1
class_label[..., 0] = 1 # every thing is a background
# only a few have non-background labels
class_label[
    b_ids[..., None, None, None],
    l_ids[None, ..., None, None],
    c_ids[None, None, ..., None],
    label_ids
] = object_list.category.unsqueeze(1).unsqueeze(1)

cross_entropy = F.cross_entropy(pred_cls.flatten(0, -2), class_label.flatten(0, -2).argmax(-1))
cross_entropy

tensor(1.1326, grad_fn=<NllLossBackward0>)

In [38]:
pred_pos.shape, voxel_list.shape, net.next_voxel_size.shape, object_list.position.shape

(torch.Size([4, 4, 3, 50, 3]),
 torch.Size([4, 4, 3, 50, 3]),
 torch.Size([4, 3]),
 torch.Size([4, 10, 3]))

In [56]:
pred_pos[
    b_ids[..., None, None, None],
    l_ids[None, ..., None, None],
    c_ids[None, None, ..., None],
    label_ids
].shape, net.next_voxel_size.unsqueeze(1).unsqueeze(1).unsqueeze(0).shape

# mse_loss = F.mse_loss(
#     pred_pos[
#         b_ids[..., None, None, None],
#         l_ids[None, ..., None, None],
#         c_ids[None, None, ..., None],
#         label_ids
#     ] * net.next_voxel_size.unsqueeze(0).unsqueeze(1).unsqueeze(1) + voxel_list[
#         b_ids[..., None, None, None],
#         l_ids[None, ..., None, None],
#         c_ids[None, None, ..., None],
#         label_ids
#     ],
#     object_list.position.unsqueeze(1).unsqueeze(1)
# )
# mse_loss

(torch.Size([4, 4, 3, 10, 3]), torch.Size([1, 4, 1, 1, 3]))

In [17]:
pred_link.shape, object_list.los.shape

(torch.Size([4, 4, 3, 50, 4]), torch.Size([4, 10, 4]))

In [19]:
los_link_loss = F.mse_loss(
    pred_link[
        b_ids[..., None, None, None],
        l_ids[None, ..., None, None],
        c_ids[None, None, ..., None],
        label_ids
    ],
    object_list.los.unsqueeze(1).unsqueeze(1)
)
los_link_loss

  los_link_loss = F.mse_loss(


tensor(0.8090, grad_fn=<MseLossBackward0>)

In [63]:
trainer = pl.Trainer(
    max_epochs=100, precision='16-mixed',
    gradient_clip_val=35,
    log_every_n_steps=20,
    # accelerator="cpu"
    # profiler="simple",
)
net = cvmpca.CVMPCA(
    train_loader.dataset.captures,
    n_classes=len(train_loader.dataset.category_lookup), 
    spaces=torch.tensor([[-11, 11], [0, 3], [-7, 7]]),
    voxel_size=[2.7, 2.7, 2.7],
    ratio=3,
)
# loss_func = SetCriterion(
#     num_ue=50, num_sbs=2, num_class=1,
#     num_layers=3, pc_range=[[-5, 5], [0, 3], [-5, 5]]
# )

trainer.fit(net, train_dataloaders=train_loader)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  self.register_buffer('spaces', torch.tensor(spaces))
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                  | Params
--------------------------------------------------
0 | swin    | Swin                  | 27.6 M
1 | fpn     | FeaturePyramidNetwork | 470 K 
2 | decoder | ModuleList            | 522 K 
3 | heads   | ModuleList            | 38.7 K
4 | query   | Embedding             | 288   
--------------------------------------------------
28.6 M    Trainable params
0         Non-trainable params
28.6 M    Total params
114.454   Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
B, L, C, P, G, D = 2, 4, 3, 5, 3, 10