In [1]:
import os
import random
import re
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import MinkowskiEngine as ME
#from torchvision import transforms
#from PIL import Image
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from sklearn.metrics import f1_score
import shutil
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from MinkowskiEngine.utils import sparse_quantize
import traceback
from pytorch3d.ops import knn_points
from torch.optim.lr_scheduler import ReduceLROnPlateau

  from .autonotebook import tqdm as notebook_tqdm


# Dataset

In [2]:
scannet_dir = './scannet_train_detection_data'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scannet_files = os.listdir(scannet_dir)
print(len(scannet_files))

6052


In [3]:
# Set random seed for reproducibility
def create_split():
    random.seed(42)
    
    # Regex to match ScanNet-style scene IDs (e.g., scene0000_00)
    scene_id_pattern = re.compile(r'(scene\d{4}_\d{2})')
    
    # Gather all scene IDs
    scene_ids = set()
    for fname in scannet_files:
        if fname.endswith(".npy"):
            match = scene_id_pattern.match(fname)
            if match:
                scene_ids.add(match.group(1))
    
    # Convert to list and shuffle
    scene_ids = list(scene_ids)
    random.shuffle(scene_ids)
    
    # Split (80% train, 10% val, 10% test)
    n = len(scene_ids)
    train_end = int(0.8 * n)
    val_end = int(0.9 * n)
    
    train_ids = scene_ids[:train_end]
    val_ids = scene_ids[train_end:val_end]
    test_ids = scene_ids[val_end:]
    
    # Save each split to its own text file
    with open("train.txt", "w") as f:
        f.write("\n".join(train_ids))
    
    with open("val.txt", "w") as f:
        f.write("\n".join(val_ids))
    
    with open("test.txt", "w") as f:
        f.write("\n".join(test_ids))
    
    print(f"Train: {len(train_ids)}, Val: {len(val_ids)}, Test: {len(test_ids)}")

In [None]:
class ScanNetDetectionDataset(Dataset):
    def __init__(self, data_dir, split_list, voxel_size=0.2, transform=None):
        """
        Args:
            data_dir (str): Path to the directory containing the preprocessed .npy files.
            split_list (list): List of scene IDs to load (e.g., ['scene0000_00', 'scene0001_00']).
            transform (callable, optional): Optional transform to be applied on a sample.
            
        Returns:
            sparse_vert (SparseTensor): SparseTensor that has xyz and rgb features of given scene, used for TR3D
            coords_dict (Tensor): Only xyz coordinates of given scene, used for VoteNet
            bbox (Tensor): Ground truth bounding boxes for the given scene
        """
        self.data_dir = data_dir
        self.split_list = split_list
        self.transform = transform
        self.voxel_size = voxel_size
        self.class_map = {
        3 : 0,
        4 : 1,
        5 : 2,
        6 : 3,
        7 : 4,
        8 : 5,
        9 : 6,
        10 : 7,
        11 : 8,
        12 : 9,
        14 : 10,
        16 : 11,
        24 : 12,
        28 : 13,
        33 : 14,
        34 : 15,
        36 : 16,
        39 : 17
    }

    def __len__(self):
        return len(self.split_list)

    def __getitem__(self, idx):
        scene_id = self.split_list[idx]

        # Load files
        vert = np.load(os.path.join(self.data_dir, f"{scene_id}_vert.npy"))            # (N, 6)
            
        vert = normalize_points_color(vert)
        vert = sample_points(vert, 0.33)
        coords = vert[:,0:3]                                                            # (N, 3)
        features = vert[:,3:]

        coords = torch.from_numpy(coords)
        
        bbox = np.load(os.path.join(self.data_dir, f"{scene_id}_bbox.npy"))  # (K, 7)
        if self.transform:
            coords, bbox = self.transform(coords, bbox)

        
        if bbox.shape[0] > 0:
            class_ids = bbox[:, 6]
            
            map_func = np.vectorize(lambda x: self.class_map.get(x, x))  # fallback to x if not in map
            new_class_ids = map_func(class_ids)
            bbox[:, 6] = new_class_ids

        bbox = torch.from_numpy(bbox).float()

        # Convert to tensors
        coords_with_dummy = torch.hstack([coords, torch.zeros((coords.shape[0], 1))])
        batched_coords = torch.unsqueeze(coords_with_dummy, axis=0)
        votenet_coords = batched_coords.float()        
        # TR3D Preprocessing
        feats = torch.from_numpy(features)

        coords = torch.floor(coords / self.voxel_size).int()

        return coords, feats, votenet_coords, bbox
    
def get_data(file_name, voxel_size, transform=None):
    file = open(file_name, 'r')
    ids = [l[0:-1] for l in file if os.path.exists(f"/nfs/home/adatay20/votenet/scannet/scannet_train_detection_data/{l[0:-1]}_vert.npy")]
    data_dir = "/nfs/home/adatay20/votenet/scannet/scannet_train_detection_data"
    data = ScanNetDetectionDataset(scannet_dir, ids, voxel_size, transform)
    return data

In [None]:
def collate_fn(batch):
    coords_list = []
    feats_list = []
    xyz_list = []
    bboxes_list = []

    for i, (coords, feats, xyz, bboxes) in enumerate(batch):
        coords_list.append(coords)
        feats_list.append(feats)
        
        N = xyz.shape[1]
        if N < 50000:
            repeat_idx = np.random.choice(N, 50000 - N)
            repeated_points = xyz[:, repeat_idx, :]
            padded_points = torch.cat([xyz, repeated_points], axis=1)
            xyz_list.append(padded_points)
        else:
            xyz_list.append(xyz)     # just append the dicts — they remain per-sample
        bboxes_list.append(bboxes)    # variable-size tensors, stay in list
    batched_xyz = torch.cat(xyz_list, dim=0)
    xyz_dict = {'point_clouds': batched_xyz}

    return feats_list, coords_list, xyz_dict, bboxes_list

# Helper Functions

In [6]:
def sample_points(points, percentage):
    num_samples = int(np.random.uniform(percentage, 1.) * points.shape[0])
    point_range = range(len(points))
    choices = np.random.choice(point_range, num_samples, replace=False)
    return points[choices]

In [None]:
def to_batch_dim(ft, ct):
    batches = ct[:,0]
    batch_size = batches.max().item() + 1
    #nr_features = ft.shape[0] // batch_size
    nr_features = torch.bincount(batches)
    max_nr_features = torch.max(nr_features)
    return_coords = torch.zeros([batch_size, max_nr_features , 3], device=ct.device)
    return_features = torch.zeros([batch_size, 128, max_nr_features], device=ft.device)
    for b in range(batch_size):
        mask = batches == b
        selected_feats = ft[mask]
        selected_coords = ct[mask][:, 1:]
        
        return_features[b, :, :selected_feats.shape[0]] = selected_feats.T
        return_coords[b, :selected_coords.shape[0], :] = selected_coords
    return return_features, return_coords, nr_features

def to_batch_output(ct, bb, cl, po):
    batches = ct[:,0]
    batch_size = batches.max().item() + 1
    
    out_bb, out_cl, out_po = [], [], []
    for b in range(batch_size):
        mask = batches == b
        bb_batch = bb[mask]
        cl_batch = cl[mask]
        po_batch = po[mask]
        
        out_bb.append(bb_batch)
        out_cl.append(cl_batch)
        out_po.append(po_batch)
    return [out_bb], [out_cl], [out_po]
    

In [None]:
def to_sparse_dim(ct, bt):
    B, C, N, _ = bt.shape
    device = bt.device
    batch_ids = ct[:, 0].long()  # [B*N]

    sorted_ids, sorted_idx = torch.sort(batch_ids)
    counts = torch.bincount(batch_ids, minlength=B)

    point_indices_sorted = torch.cat([torch.arange(c, device=device) for c in counts.tolist()])
    point_indices = torch.empty_like(point_indices_sorted)
    point_indices[sorted_idx] = point_indices_sorted
    out = bt[batch_ids, :, point_indices, 0]
    return out

In [9]:
def one_hot_encode(label_list):
    class_map = {
        3 : 0,
        4 : 1,
        5 : 2,
        6 : 3,
        7 : 4,
        8 : 5,
        9 : 6,
        10 : 7,
        11 : 8,
        12 : 9,
        14 : 10,
        16 : 11,
        24 : 12,
        28 : 13,
        33 : 14,
        34 : 15,
        36 : 16,
        39 : 17
    }
    return_tensor = torch.zeros([label_list.shape[0], 18], device = device)
    for i, l in enumerate(label_list):
        return_tensor[i, class_map[int(l)]] = 1
    #print(return_tensor)
    return return_tensor

In [10]:
def normalize(tensor, new_min=0.0, new_max=1.0):
    t_min, t_max = tensor.min(), tensor.max()
    if t_min == t_max:
        return torch.full_like(tensor, new_min)
    return new_min + (tensor - t_min) * (new_max - new_min) / (t_max - t_min)

def normalize_points_color(points, color_mean=None):
    """
    Normalize RGB color channels in point cloud.
    
    Args:
        points (np.ndarray): (N, 6) array, with columns [x, y, z, r, g, b]
        color_mean (list or None): Mean values for r, g, b channels. 
                                   If None, compute from data.
    
    Returns:
        np.ndarray: Normalized points with shape (N, 6)
    """
    assert points.shape[1] >= 6, "Expected points with at least 6 dimensions (x, y, z, r, g, b)"
    
    colors = points[:, 3:6].astype(np.float32)
    
    if color_mean is None:
        color_mean = colors.mean(axis=0)
    
    # Normalize RGB values: (value - mean) / 255.0
    normalized_colors = (colors - color_mean) / 255.0
    
    # Replace original colors with normalized ones
    points[:, 3:6] = normalized_colors
    return points

def to_real_coords(voxels, voxel_size):
    #norm_voxels = []
    
    """for p, v in zip(points, voxels):
        p_min = p.min()
        p_max = p.max()
        #print("voxels:", v.min(), v.max())
        #print("points:", p_min, p_max)
        
        #norm_voxels.append(normalize(v, p_min, p_max))
    """
        
    return (voxels.float()) * voxel_size

In [11]:
def match_feat(vote_coords, vote_feat, tr3d_coords, tr3d_feat, lengths):
    #print("Vote coords", vote_coords.shape)
    #print("TR3D coords", tr3d_coords.shape)
    _, ids, _ = knn_points(tr3d_coords[:,:,0:3], vote_coords[:,:,0:3],K=1, lengths1=lengths)
    index = ids.squeeze(-1).unsqueeze(1)  # remove last dim, add dim for channels
    index = index.expand(-1, 256, -1)
    output = torch.gather(vote_feat, dim=2, index=index).to(device)
    #print("MATCHED:", output.shape)
    return output
    

# RGBD Backbone

In [12]:
from TR3D.mink_resnet_TR3D import TR3DMinkResNet
from TR3D.tr3d_neck import TR3DNeck
import torch.nn.functional as F

class RGBDBackbone(nn.Module):
    def __init__(self, weights_path="TR3D.pth"):
        super().__init__()
        self.in_channels = 3
        self.depth = 34
        self.norm = 'batch'
        self.num_planes=(64, 128, 128, 128)
        self.backbone = TR3DMinkResNet(in_channels=self.in_channels, depth=self.depth, 
                                    norm=self.norm, num_planes=self.num_planes, pool=False)
        
        self.neck_in_channels=(64, 128, 128, 128) 
        self.neck_out_channels=128
        self.neck = TR3DNeck(in_channels=self.neck_in_channels, out_channels=self.neck_out_channels)
        
        # initialize weights of backbone and neck
        if os.path.exists(weights_path):
            state_dict = torch.load(weights_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), weights_only=False)
            backbone_state = {k.replace('backbone.', ''): v for k, v in state_dict['state_dict'].items() if k.startswith('backbone.')}
            neck_state = {k.replace('neck.', ''): v for k, v in state_dict['state_dict'].items() if k.startswith('neck.')}
            self.neck.load_state_dict(neck_state, strict=True)
            self.backbone.load_state_dict(backbone_state, strict=True)
            print(f"Loaded pretrained weights from {weights_path}")
        else:
            raise FileNotFoundError(f"Pretrained weights not found at {weights_path}")

    def forward(self,input_rgb):
        features = self.backbone(input_rgb)
        modified_features = self.neck(features)
        return modified_features



# LiDAR Backbone

In [None]:
from votenet.models.backbone_module import Pointnet2Backbone
from votenet.models.voting_module import VotingModule
from votenet.models.proposal_module import ProposalModule
from votenet.models.dump_helper import dump_results
from votenet.models.loss_helper import get_loss


class LiDARBackbone(nn.Module):
    r"""
        A deep neural network for 3D object detection with end-to-end optimizable hough voting.

        Parameters
        ----------
        num_class: int
            Number of semantics classes to predict over -- size of softmax classifier
        num_heading_bin: int
        num_size_cluster: int
        input_feature_dim: (default: 0)
            Input dim in the feature descriptor for each point.  If the point cloud is Nx9, this
            value should be 6 as in an Nx9 point cloud, 3 of the channels are xyz, and 6 are feature descriptors
        num_proposal: int (default: 128)
            Number of proposals/detections generated from the network. Each proposal is a 3D OBB with a semantic class.
        vote_factor: (default: 1)
            Number of votes generated from each seed point.
    """

    def __init__(self, num_class, num_heading_bin, num_size_cluster, mean_size_arr,
        input_feature_dim=1, num_proposal=128, vote_factor=1, sampling='vote_fps', backbone_path='votenet_backbone.pth',neck_path='votenet_neck.pth'):
        super().__init__()

        self.num_class = num_class
        self.num_heading_bin = num_heading_bin
        self.num_size_cluster = num_size_cluster
        self.mean_size_arr = mean_size_arr
        assert(mean_size_arr.shape[0] == self.num_size_cluster)
        self.input_feature_dim = input_feature_dim
        self.num_proposal = num_proposal
        self.vote_factor = vote_factor
        self.sampling=sampling

        # Backbone point feature learning
        self.backbone_net = Pointnet2Backbone(input_feature_dim=self.input_feature_dim)

        # Hough voting (Neck)
        self.vgen = VotingModule(self.vote_factor, 256)
            
        # initialize weights of backbone and neck
        if os.path.exists(neck_path) and os.path.exists(backbone_path):
            neck_state = torch.load(neck_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), weights_only=False)
            backbone_state = torch.load(backbone_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), weights_only=False)
            self.backbone_net.load_state_dict(backbone_state, strict=True)
            self.vgen.load_state_dict(neck_state, strict=True)
            print(f"Loaded pretrained weights from {backbone_path} and {neck_path}")
        else:
            raise FileNotFoundError(f"Pretrained weights not found at {backbone_path} or {neck_path}")

    def forward(self, inputs):
        """ Forward pass of the network

        Args:
            inputs: dict
                {point_clouds}

                point_clouds: Variable(torch.cuda.FloatTensor)
                    (B, N, 3 + input_channels) tensor
                    Point cloud to run predicts on
                    Each point in the point-cloud MUST
                    be formated as (x, y, z, features...)
        Returns:
            end_points: dict
        """
        
        end_points = {}
        batch_size = inputs['point_clouds'].shape[0]

        end_points = self.backbone_net(inputs['point_clouds'], end_points)
                
        # --------- HOUGH VOTING ---------
        xyz = end_points['fp2_xyz']
        features = end_points['fp2_features']
        end_points['seed_inds'] = end_points['fp2_inds']
        end_points['seed_xyz'] = xyz
        end_points['seed_features'] = features
        
        xyz, features = self.vgen(xyz, features)
        features_norm = torch.norm(features, p=2, dim=1)
        features = features.div(features_norm.unsqueeze(1))
        end_points['vote_xyz'] = xyz
        end_points['vote_features'] = features

        return end_points['vote_features'], end_points['vote_xyz']

# Fusion Network

# Detection Head

In [14]:
class TR3DHead(nn.Module):
    def __init__(self, in_channels, num_reg_outs, num_classes, voxel_size = 0.02, pts_center_threshold=10, weights_path="TR3D.pth"):
        super().__init__()
        self.voxel_size = voxel_size
        self.pts_center_threshold = pts_center_threshold

        self.conv_reg = ME.MinkowskiConvolution(
            in_channels, num_reg_outs, kernel_size=1, bias=True, dimension=3
        )
        self.conv_cls = ME.MinkowskiConvolution(
            in_channels, num_classes, kernel_size=1, bias=True, dimension=3
        )

        # Init
        nn.init.normal_(self.conv_reg.kernel, std=0.01)
        nn.init.normal_(self.conv_cls.kernel, std=0.01)
        nn.init.constant_(self.conv_cls.bias, -torch.log(torch.tensor((1 - 0.01) / 0.01)))
        
    def forward_single(self, x):
        reg_out = self.conv_reg(x).features
        cls_out = self.conv_cls(x).features

        reg_distance = torch.exp(reg_out[:, 3:6])
        reg_angle = reg_out[:, 6:] if reg_out.shape[1] > 6 else None
        bbox_pred = torch.cat([reg_out[:, :3], reg_distance, reg_angle], dim=1) if reg_angle is not None else torch.cat([reg_out[:, :3], reg_distance], dim=1)

        return bbox_pred, cls_out, x.coordinates[:, 1:] * self.voxel_size

    def forward(self, sparse_tensor_list):
        all_bbox_preds, all_cls_preds, all_points = [], [], []
        for x in sparse_tensor_list:
            bbox_pred, cls_pred, points = self.forward_single(x)
            all_bbox_preds.append(bbox_pred)
            all_cls_preds.append(cls_pred)
            all_points.append(points)
        return all_bbox_preds, all_cls_preds, all_points

# Loss Function / Accuracy

In [None]:
from TR3D.IoU_TR3D_loss import TR3DAxisAlignedIoULoss
from TR3D.Focal_loss import FocalLoss
from typing import List, Optional, Tuple
from torch import Tensor
from TR3D.iou_3d import calculate_map
from TR3D.IoU_TR3D_loss import AxisAlignedBboxOverlaps3D

from collections import defaultdict

@torch.no_grad()
def calculate_map(pred_boxes, gt_boxes, pred_cls, gt_labels, iou_threshold=0.25, num_classes=18):
    """
    Compute mean Average Precision (mAP) for 3D object detection.
    
    Parameters:
        pred_boxes: List of predicted boxes [x, y, z, dx, dy, dz, class_id, confidence]
        gt_boxes:   List of ground truth boxes [x, y, z, dx, dy, dz, class_id]
        iou_threshold: IoU threshold for matching
        num_classes: number of classes
        
    Returns:
        mAP: Mean Average Precision across classes
        ap_per_class: Dictionary of AP per class
    """
    gt_by_class = defaultdict(list)
    pred_by_class = defaultdict(list)

    for gt, l in zip(gt_boxes, gt_labels):
        gt_by_class[int(l)].append(gt[:6])
    for pred, l in zip(pred_boxes, pred_cls):
        pred_by_class[int(torch.argmax(l))].append(pred[:6])  # (box, confidence)

    ap_per_class = {}
    eps = 1e-6

    for cls in range(num_classes):
        gts = gt_by_class[cls]
        preds = pred_by_class[cls]

        if len(gts) == 0 and len(preds) == 0:
            ap_per_class[cls] = None  # no predictions or targets — perfect
            continue
        elif len(gts) == 0:
            ap_per_class[cls] = 0.0
            continue

        # Sort predictions by confidence
        preds = sorted(preds, key=lambda x: -x[1])
        tp = np.zeros(len(preds))
        fp = np.zeros(len(preds))
        matched_gt = set()
        for i, pred_box in enumerate(preds):
            best_iou = 0
            best_gt_idx = -1
            for j, gt_box in enumerate(gts):
                if j in matched_gt:
                    continue
                if len(pred_box.shape) != 3:
                    pred_box = pred_box.unsqueeze(0).unsqueeze(0)
                if len(gt_box.shape) != 3:
                    gt_box = gt_box.unsqueeze(0).unsqueeze(0)
                iou = AxisAlignedBboxOverlaps3D()(LossFunction._bbox_to_loss(pred_box), LossFunction._bbox_to_loss(gt_box))
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = j

            if best_iou >= iou_threshold:
                tp[i] = 1
                matched_gt.add(best_gt_idx)
            else:
                fp[i] = 1

        tp_cumsum = np.cumsum(tp)
        fp_cumsum = np.cumsum(fp)
        recalls = tp_cumsum / (len(gts) + eps)
        precisions = tp_cumsum / (tp_cumsum + fp_cumsum + eps)

        ap = 0.0
        for t in np.linspace(0, 1, 101):
            if np.sum(recalls >= t) == 0:
                p = 0
            else:
                p = np.max(precisions[recalls >= t])
            ap += p / 101
        ap_per_class[cls] = ap

    values = [v for v in ap_per_class.values() if v is not None]
    mAP = np.mean(values)
    return mAP, ap_per_class

class LossFunction(nn.Module):
    def __init__(self, weights = None):
        super().__init__()
        self.bbox_loss = TR3DAxisAlignedIoULoss(mode='diou', reduction='none')
        self.cls_loss = FocalLoss(gamma=2.0, alpha=0.25, reduction = "none")
        self.label2level = [0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0]
        self.pts_center_threshold = 6
        self.weights = weights
    
    def _loss_by_feat_single(self, bbox_preds: List[Tensor],
                         cls_preds: List[Tensor], points: List[Tensor],
                         gt_bboxes: Tensor, gt_labels: Tensor) -> Tuple[Tensor, ...]:
        """Loss function of single sample.

        Args:
            bbox_preds (list[Tensor]): Bbox predictions for all levels.
            cls_preds (list[Tensor]): Classification predictions for all
                levels.
            points (list[Tensor]): Final location coordinates for all levels.
            gt_bboxes (:obj:`BaseInstance3DBoxes`): Ground truth boxes.
            gt_labels (Tensor): Ground truth labels.
            input_meta (dict): Scene meta info.

        Returns:
            tuple[Tensor, ...]: Bbox and classification loss
                values and a boolean mask of assigned points.
        """
        
        num_classes = cls_preds[0].shape[1]
        bbox_targets, cls_targets = self.get_targets(points, gt_bboxes,
                                                     gt_labels, num_classes)
        bbox_preds_cat = torch.cat(bbox_preds)
        cls_preds_cat = torch.cat(cls_preds)
        points = torch.cat(points)

        
        # cls loss
        num_classes = cls_preds_cat.size(1)
        cls_loss = self.cls_loss(cls_preds_cat, cls_targets, self.weights[cls_targets])

        
        # bbox loss
        pos_mask = cls_targets < num_classes
        pos_bbox_preds = bbox_preds_cat[pos_mask]
        
        
        if pos_mask.sum() > 0:
            pos_points = points[pos_mask]
            pos_bbox_preds = bbox_preds_cat[pos_mask]
            pos_bbox_targets = bbox_targets[pos_mask]
            pos_pred_cls = cls_preds_cat[pos_mask]
            pos_target_cls = cls_targets[pos_mask]
            
            size_loss = (pos_bbox_preds[..., 3:] - 5.0).clamp(min=0) ** 2
            
            transformed_pred_bbox = self._bbox_to_loss(self._bbox_pred_to_bbox(pos_points, pos_bbox_preds))
            transformed_bbox_targets = self._bbox_to_loss(pos_bbox_targets)
            
            bbox_loss = self.bbox_loss(transformed_pred_bbox, transformed_bbox_targets)
            cls_thresholds = torch.max(torch.sigmoid(pos_pred_cls), dim=1, keepdim=True)[0][:,0]
            cls_ret = cls_thresholds.max()
            mAP, AP = calculate_map(self._bbox_pred_to_bbox(pos_points, pos_bbox_preds)[cls_thresholds > 0.05], gt_bboxes,
                                  pos_pred_cls[cls_thresholds > 0.05], gt_labels)
            
        else:
            cls_ret = 0
            mAP = 0
            AP = {i:0 for i in range(18)}
            size_loss = torch.zeros_like(pos_bbox_preds[...,3:])
            bbox_loss = pos_bbox_preds
        return bbox_loss, cls_loss, pos_mask, size_loss, mAP, AP, cls_ret

    def loss_by_feat(self,
                 bbox_preds: List[List[Tensor]],
                 cls_preds: List[List[Tensor]],
                 points: List[List[Tensor]],
                 batch_gt_bboxes_3d: List[Tensor],
                 batch_gt_instances_ignore = None) -> dict:
        """Loss function about feature.

        Args:
            bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
                The first list contains predictions from different
                levels. The second list contains predictions in a mini-batch.
            cls_preds (list[list[Tensor]]): Classification predictions for all
                scenes. The first list contains predictions from different
                levels. The second list contains predictions in a mini-batch.
            points (list[list[Tensor]]): Final location coordinates for all
                scenes. The first list contains predictions from different
                levels. The second list contains predictions in a mini-batch.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d.  It usually includes ``bboxes_3d``、`
                `labels_3d``、``depths``、``centers_2d`` and attributes.
            batch_input_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            dict: Bbox, and classification losses.
        """
        bbox_losses, cls_losses, pos_masks, size_losses, mAPs, APs, ths = [], [], [], [], [], [], []
        for i in range(len(batch_gt_bboxes_3d)):
            bbox_loss, cls_loss, pos_mask, size_loss, mAP, AP, cls_th = self._loss_by_feat_single(
                bbox_preds=[x[i] for x in bbox_preds],
                cls_preds=[x[i] for x in cls_preds],
                points=[x[i][:,0:3] for x in points],
                gt_bboxes=batch_gt_bboxes_3d[i][:, 0:6],
                gt_labels=batch_gt_bboxes_3d[i][:,6].long())
            if len(bbox_loss) > 0:
                bbox_losses.append(bbox_loss)
            cls_losses.append(cls_loss)
            pos_masks.append(pos_mask)
            size_losses.append(size_loss)
            mAPs.append(mAP)
            APs.append(AP)
            ths.append(cls_th)
        #print(max(ths))
        return dict(
            bbox_loss=torch.mean(torch.cat(bbox_losses)),
            cls_loss=torch.sum(torch.cat(cls_losses)) / 
            torch.sum(torch.cat(pos_masks)),
            size_loss=torch.mean(torch.cat(size_losses))), np.mean(mAPs), APs



    @staticmethod
    def _bbox_to_loss(bbox):
        """Transform box to the axis-aligned or rotated iou loss format.

        Args:
            bbox (Tensor): 3D box of shape (N, 6) or (N, 7).

        Returns:
            Tensor: Transformed 3D box of shape (N, 6) or (N, 7).
        """
        # rotated iou loss accepts (x, y, z, w, h, l, heading)
        if bbox.shape[-1] != 6:
            return bbox

        # axis-aligned case: x, y, z, w, h, l -> x1, y1, z1, x2, y2, z2
        return torch.stack(
            (bbox[..., 0] - bbox[..., 3] / 2, bbox[..., 1] - bbox[..., 4] / 2, bbox[..., 2] - bbox[..., 5] / 2, 
             bbox[..., 0] + bbox[..., 3] / 2, bbox[..., 1] + bbox[..., 4] / 2, bbox[..., 2] + bbox[..., 5] / 2),
            dim=-1)

    @staticmethod
    def _bbox_pred_to_bbox(points, bbox_pred):
        """Transform predicted bbox parameters to bbox.

        Args:
            points (Tensor): Final locations of shape (N, 3)
            bbox_pred (Tensor): Predicted bbox parameters of shape (N, 6)
                or (N, 8).
        Returns:
            Tensor: Transformed 3D box of shape (N, 6) or (N, 7).
        """
        if bbox_pred.shape[0] == 0:
            return bbox_pred

        x_center = points[:, 0] + bbox_pred[:, 0]
        y_center = points[:, 1] + bbox_pred[:, 1]
        z_center = points[:, 2] + bbox_pred[:, 2]
        base_bbox = torch.stack([
            x_center, y_center, z_center, bbox_pred[:, 3], bbox_pred[:, 4],
            bbox_pred[:, 5]
        ], -1)

        # axis-aligned case
        if bbox_pred.shape[1] == 6:
            return base_bbox

        # rotated case: ..., sin(2a)ln(q), cos(2a)ln(q)
        scale = bbox_pred[:, 3] + bbox_pred[:, 4]
        q = torch.exp(
            torch.sqrt(
                torch.pow(bbox_pred[:, 6], 2) + torch.pow(bbox_pred[:, 7], 2)))
        alpha = 0.5 * torch.atan2(bbox_pred[:, 6], bbox_pred[:, 7])
        return torch.stack(
            (x_center, y_center, z_center, scale / (1 + q), scale /
             (1 + q) * q, bbox_pred[:, 5] + bbox_pred[:, 4], alpha),
            dim=-1)

    @torch.no_grad()
    def get_targets(self, points: Tensor, gt_bboxes: Tensor,
                    gt_labels: Tensor, num_classes: int) -> Tuple[Tensor, ...]:
        """Compute targets for final locations for a single scene.

        Args:
            points (list[Tensor]): Final locations for all levels.
            gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
            gt_labels (Tensor): Ground truth labels.
            num_classes (int): Number of classes.

        Returns:
            tuple[Tensor, ...]: Bbox and classification targets for all
                locations.
        """
        float_max = points[0].new_tensor(1e8)
        levels = torch.cat([
            points[i].new_tensor(i, dtype=torch.long).expand(len(points[i]))
            for i in range(len(points))
        ])
        points = torch.cat(points)
        n_points = len(points)
        n_boxes = len(gt_bboxes)

        if len(gt_labels) == 0:
            return points.new_tensor([]), \
                gt_labels.new_full((n_points,), num_classes)
        
        zero_col = torch.zeros([n_boxes,1], device=gt_bboxes.device)
        boxes = torch.cat([gt_bboxes, zero_col], dim=1)
        #print("box", boxes.shape)
        boxes = boxes.expand(n_points, n_boxes, 7)
        points = points.unsqueeze(1).expand(n_points, n_boxes, 3)

        # condition 1: fix level for label
        label2level = gt_labels.new_tensor(self.label2level)
        label_levels = label2level[gt_labels].unsqueeze(0).expand(
            n_points, n_boxes)
        point_levels = torch.unsqueeze(levels, 1).expand(n_points, n_boxes)
        level_condition = label_levels == point_levels

        # condition 2: keep topk location per box by center distance
        center = boxes[..., :3]
        center_distances = torch.sum(torch.pow(center - points, 2), dim=-1)
        center_distances = torch.where(level_condition, center_distances,
                                       float_max)
        topk_distances = torch.topk(
            center_distances,
            min(self.pts_center_threshold + 1, len(center_distances)),
            largest=False,
            dim=0).values[-1]
        topk_condition = center_distances < topk_distances.unsqueeze(0)

        # condition 3: min center distance to box per point
        center_distances = torch.where(topk_condition, center_distances,
                                       float_max)
        min_values, min_ids = center_distances.min(dim=1)
        min_inds = torch.where(min_values < float_max, min_ids, -1)

        bbox_targets = boxes[0][min_inds]
        
        bbox_targets = bbox_targets[:, :-1]
        cls_targets = torch.where(min_inds >= 0, gt_labels[min_inds],
                                  num_classes)
        return bbox_targets, cls_targets

    def _single_scene_multiclass_nms(self, bboxes: Tensor, scores: Tensor,
                                     input_meta: dict) -> Tuple[Tensor, ...]:
        """Multi-class nms for a single scene.

        Args:
            bboxes (Tensor): Predicted boxes of shape (N_boxes, 6) or
                (N_boxes, 7).
            scores (Tensor): Predicted scores of shape (N_boxes, N_classes).
            input_meta (dict): Scene meta data.

        Returns:
            tuple[Tensor, ...]: Predicted bboxes, scores and labels.
        """
        num_classes = scores.shape[1]
        with_yaw = bboxes.shape[1] == 7
        nms_bboxes, nms_scores, nms_labels = [], [], []
        for i in range(num_classes):
            ids = scores[:, i] > self.test_cfg.score_thr
            if not ids.any():
                continue

            class_scores = scores[ids, i]
            class_bboxes = bboxes[ids]
            if with_yaw:
                nms_function = nms3d
            else:
                class_bboxes = torch.cat(
                    (class_bboxes, torch.zeros_like(class_bboxes[:, :1])),
                    dim=1)
                nms_function = nms3d_normal

            nms_ids = nms_function(class_bboxes, class_scores,
                                   self.test_cfg.iou_thr)
            nms_bboxes.append(class_bboxes[nms_ids])
            nms_scores.append(class_scores[nms_ids])
            nms_labels.append(
                bboxes.new_full(
                    class_scores[nms_ids].shape, i, dtype=torch.long))

        if len(nms_bboxes):
            nms_bboxes = torch.cat(nms_bboxes, dim=0)
            nms_scores = torch.cat(nms_scores, dim=0)
            nms_labels = torch.cat(nms_labels, dim=0)
        else:
            nms_bboxes = bboxes.new_zeros((0, bboxes.shape[1]))
            nms_scores = bboxes.new_zeros((0, ))
            nms_labels = bboxes.new_zeros((0, ))

        if not with_yaw:
            nms_bboxes = nms_bboxes[:, :6]

        return nms_bboxes, nms_scores, nms_labels
    
    def forward(self,
                 bbox_preds: List[List[Tensor]],
                 cls_preds: List[List[Tensor]],
                 points: List[List[Tensor]],
                 batch_gt_instances_3d: List[Tensor],
                 batch_gt_instances_ignore = None):
        return self.loss_by_feat(bbox_preds, cls_preds, points, batch_gt_instances_3d)
    


# Full Network

In [None]:
class CombinedModel(nn.Module):
    def __init__(self, rgb_backbone, lidar_backbone, voxel_size, n_classes=18, weights_path="TR3D.pth"):
        super().__init__()
        self.voxel_size = voxel_size
        self.rgb_backbone = rgb_backbone
        self.lidar_backbone = lidar_backbone
        
        self.detection_head = TR3DHead(128 + 256, 6, 18)
        
        state_dict = torch.load(weights_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), weights_only=False)
        head_state = {k.replace('head.', ''): v for k, v in state_dict['state_dict'].items() if k.startswith('head.')}
        head_state.pop("conv_reg.kernel", None)
        head_state.pop("conv_cls.kernel", None)
        self.detection_head.load_state_dict(head_state, strict=False)
        print(f"Loaded pretrained weights from {weights_path}")


    def forward(self, rgb_input, lidar_input):
        rgb_tensors = self.rgb_backbone(rgb_input)  # Dropout applied inside the RGBBackbone
        
        rgb_coords = [t.coordinates for t in rgb_tensors]
        rgb_coords = torch.cat(rgb_coords)
        
        rgb_feats = [t.features for t in rgb_tensors]
        rgb_feats = torch.cat(rgb_feats)
        
        lidar_features, lidar_coords = self.lidar_backbone(lidar_input)  # Dropout applied inside LiDARBackbone
        
        rgb_features, _, _ = to_batch_dim(rgb_feats, rgb_coords)
        
        pad_amount = rgb_features.size(2) - lidar_features.size(2)
        
        padded_lidar = F.pad(lidar_features, (0, pad_amount))
        fused_feats = torch.cat([rgb_features, padded_lidar], dim=1)
        fused_feats = fused_feats.unsqueeze(3)
        
        feats = to_sparse_dim(rgb_coords, fused_feats)
        
        
             
        sparse_fused = ME.SparseTensor(features=feats, coordinates=rgb_coords, device=device)
        
        bb, cl, po = self.detection_head([sparse_fused])
        bb, cl, po = to_batch_output(sparse_fused.coordinates, 
                                     torch.cat(bb), torch.cat(cl), torch.cat(po))
        return bb, cl, po

# Test Run

In [None]:
def test_run():
    mean_sizes=np.array([[0.76966727, 0.8116021, 0.92573744],
                            [1.876858, 1.8425595, 1.1931566],
                            [0.61328, 0.6148609, 0.7182701],
                            [1.3955007, 1.5121545, 0.83443564],
                            [0.97949594, 1.0675149, 0.6329687],
                            [0.531663, 0.5955577, 1.7500148],
                            [0.9624706, 0.72462326, 1.1481868],
                            [0.83221924, 1.0490936, 1.6875663],
                            [0.21132214, 0.4206159, 0.5372846],
                            [1.4440073, 1.8970833, 0.26985747],
                            [1.0294262, 1.4040797, 0.87554324],
                            [1.3766412, 0.65521795, 1.6813129],
                            [0.6650819, 0.71111923, 1.298853],
                            [0.41999173, 0.37906948, 1.7513971],
                            [0.59359556, 0.5912492, 0.73919016],
                            [0.50867593, 0.50656086, 0.30136237],
                            [1.1511526, 1.0546296, 0.49706793],
                            [0.47535285, 0.49249494, 0.5802117]])

    voxel_size = 0.2

    votenet = LiDARBackbone(num_class=18, num_heading_bin=1, num_size_cluster=18, mean_size_arr = mean_sizes).to(device)

    TR3D = RGBDBackbone().to(device)

    test_fusion = CombinedModel(TR3D, votenet, voxel_size).to(device)

    #bb, cl, po = test_fusion(test_dataset[0][0], test_dataset[0][1])
    #print(bb[0][0], cl[0][0])

    #--------------------------------

    data = get_data('scannetv1_train.txt', TR3D, voxel_size)

    batch_size = 16

    train_loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn
    )

    #batch = next(iter(train_loader))
    count = 0
    count1 = 0

    test_loss = LossFunction()
    for feats, coords, votenet, bboxes_list in train_loader:
        count1 += 1
        #try:
        bboxes = [bbox.to(device) for bbox in bboxes_list]
        votenet['point_clouds'] = votenet['point_clouds'].to(device)

        #print("Feats", feats[0].shape)
        coords, feats = ME.utils.sparse_collate(coords, feats)
        #print("after sparse collate", coords.shape, feats.shape)
        sparse_tensor = ME.SparseTensor(features=feats.to(device), coordinates=coords.to(device),  # coordinates must be defined in a integer grid. If the scale
        quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
        bb, cl, po = test_fusion(sparse_tensor, votenet)
        loss = test_loss(bb,cl,po,bboxes)
        count += 1
    print(count, count1)
    
#test_run()

In [18]:
def count_instances():
    train_data = get_data('scannet_train.txt', 0.2)
    train_loader = DataLoader(
        train_data,
        batch_size=1,
        shuffle=True,
        num_workers=0,  # change to >0 once it works stably
        collate_fn=collate_fn
    )
    counts = [0] * 18
    for feats, coords, lidar_inputs, bboxes_list in train_loader:
        for bbox in bboxes_list[0]:
            counts[int(bbox[6])] += 1

    return counts

def get_class_weights(method='log_inv'):
    """
    Returns normalized class weights as a torch.Tensor.
    
    Args:
        class_counts (list or array-like): A list of the number of occurrences for each class.
        method (str): Normalization method. Options are:
                      'inverse' - inverse frequency (1 / count)
                      'sqrt_inv' - inverse sqrt frequency (1 / sqrt(count))
    
    Returns:
        torch.Tensor: A tensor of normalized weights summing to 1.
    """
    class_counts = count_instances()
    counts = torch.tensor(class_counts, dtype=torch.float32, device=device)
    if method == 'inverse':
        weights = 1.0 / counts  # Avoid division by zero
    elif method == 'sqrt_inv':
        weights = 1.0 / torch.sqrt(counts)
    elif method == "log_inv":
        weights = torch.log(counts.sum() / counts)
    else:
        raise ValueError("Unsupported method. Use 'inverse' or 'sqrt_inv'.")
    # Normalize to sum to 1
    #weights = weights / weights.sum()
    print(weights)
    zero_to_add = torch.tensor([2.0], device=device)  # must be a tensor with same dtype
    return torch.cat((weights, zero_to_add))
#print(count_instances())

# Training

In [None]:
def train_model(n_classes=18, num_epochs=3, batch_size=16, 
    learning_rate=0.001, debug=False, output_folder="output_fused_D", class_names=None,
    voxel_size = 0.2):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    mean_sizes=np.array([[0.76966727, 0.8116021, 0.92573744],
                        [1.876858, 1.8425595, 1.1931566],
                        [0.61328, 0.6148609, 0.7182701],
                        [1.3955007, 1.5121545, 0.83443564],
                        [0.97949594, 1.0675149, 0.6329687],
                        [0.531663, 0.5955577, 1.7500148],
                        [0.9624706, 0.72462326, 1.1481868],
                        [0.83221924, 1.0490936, 1.6875663],
                        [0.21132214, 0.4206159, 0.5372846],
                        [1.4440073, 1.8970833, 0.26985747],
                        [1.0294262, 1.4040797, 0.87554324],
                        [1.3766412, 0.65521795, 1.6813129],
                        [0.6650819, 0.71111923, 1.298853],
                        [0.41999173, 0.37906948, 1.7513971],
                        [0.59359556, 0.5912492, 0.73919016],
                        [0.50867593, 0.50656086, 0.30136237],
                        [1.1511526, 1.0546296, 0.49706793],
                        [0.47535285, 0.49249494, 0.5802117]])



    # Initialize backbones
    rgb_backbone = RGBDBackbone()  # Correctly instantiate the RGB backbone
    lidar_backbone = LiDARBackbone(num_class=18, num_heading_bin=1, num_size_cluster=18, mean_size_arr = mean_sizes)  # Correctly instantiate the LiDAR backbone

    # Multi-GPU setup
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    model = CombinedModel(rgb_backbone, lidar_backbone, voxel_size)
    model = torch.nn.DataParallel(model, device_ids=[0, 1])  # Wrap model for multi-GPU usage
    model = model.to(device)

    def lidar_transform(lidar_data, gt_boxes):
        if torch.rand(1).item() > 0.5:  # Apply random vertical flip with 50% chance
            
            lidar_data[:, 0] = -lidar_data[:, 0]
            gt_boxes[:,0] = -gt_boxes[:,0]
        if torch.rand(1).item() > 0.5:  # Apply random horizontal flip with 50% chance
            lidar_data[:, 1] = -lidar_data[:, 1]
            gt_boxes[:,1] = -gt_boxes[:,1]
        return lidar_data, gt_boxes

    print(f"Using GPUs: {torch.cuda.device_count()}")

    train_data = get_data('train.txt', voxel_size, lidar_transform)
    val_data = get_data('val.txt', voxel_size)
    test_data = get_data('test.txt', voxel_size)
    

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,  # change to >0 once it works stably
        collate_fn=collate_fn
    )
    #train_dataset.lidar_transform = lidar_transform  # Attach lidar_transform to train_dataset

    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,  # change to >0 once it works stably
        collate_fn=collate_fn
    )

    test_loader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,  # change to >0 once it works stably
        collate_fn=collate_fn
    )

    model = CombinedModel(rgb_backbone, lidar_backbone, voxel_size).to(device)
    
    for param in model.rgb_backbone.parameters():
        param.requires_grad = False
    model.rgb_backbone.eval()

    # Freeze LiDAR backbone
    for param in model.lidar_backbone.parameters():
        param.requires_grad = False
    model.lidar_backbone.eval()

    weights = get_class_weights()
    criterion = LossFunction(weights)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3, threshold=0.001)
    
    os.makedirs(output_folder, exist_ok=True)
    metrics_path = os.path.join(output_folder, "metrics.csv")

    log_dir = os.path.join(output_folder, "tensorboard_logs")
    
    writer = SummaryWriter(log_dir)

    dummy_rgb = torch.randn(batch_size*370, 4).to(device)
    dummy_lidar = torch.randn(1, 50000, 3).to(device)

    #with torch.no_grad():
        #writer.add_graph(model, (dummy_rgb, dummy_lidar))

    columns = [
        'epoch', 'train_loss', 'train_accuracy', 'train_f1',
        'val_loss', 'val_accuracy', 'val_f1',
        'test_loss', 'test_accuracy', 'test_f1'
    ]
    metrics_df = pd.DataFrame(columns=columns)
    metrics_df.to_csv(metrics_path, index=False)
    torch.cuda.empty_cache()
    
    mean_APs = {i:0 for i in range(18)}
    best_val_accuracy = 0
    counter = 0
    for epoch in range(num_epochs):
        epoch_metrics = {'epoch': epoch + 1}

        with tqdm(total=len(train_loader) + len(val_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False) as pbar:
            for phase in ['train', 'val']:
                model.train() if phase == 'train' else model.eval()
                data_loader = train_loader if phase == 'train' else val_loader

                running_loss, all_preds, all_labels = 0.0, [], []

                with torch.set_grad_enabled(phase == 'train'):
                    for feats, coords, lidar_inputs, bboxes_list in data_loader:
                        labels = [bbox.to(device) for bbox in bboxes_list]
                        lidar_inputs['point_clouds'] = lidar_inputs['point_clouds'].to(device)
                        coords, feats = ME.utils.sparse_collate(coords, feats)
                        in_field = ME.TensorField(
                            features = feats,
                            coordinates = coords,
                            quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
                            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
                            device=device
                        )
                        rgb_inputs = in_field.sparse()
                        """
                        rgb_inputs = ME.SparseTensor(features=feats.to(device), coordinates=coords.to(device),  # coordinates must be defined in a integer grid. If the scale
                        quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)"""
                        
                        pred_bb, pred_cl, points = model(rgb_inputs, lidar_inputs)
                        
                        loss, mAP, AP = criterion(pred_bb, pred_cl, points, labels)
                        for ap_per_class in AP:
                            for i in range(18):
                                if ap_per_class[i] is not None:
                                    mean_APs[i] += ap_per_class[i]
                            counter += 1
                        loss = loss['cls_loss'] + loss['bbox_loss']
                        if phase == 'train':
                            optimizer.zero_grad(); loss.backward(); optimizer.step()

                        running_loss += loss.item() * batch_size
                        all_preds.append(mAP)
                        pbar.update(1)

                epoch_loss = running_loss / len(data_loader.dataset)
                epoch_accuracy = np.mean(all_preds)
                #epoch_f1 = f1_score(all_labels, all_preds, average="weighted")

                epoch_metrics[f'{phase}_loss'] = epoch_loss
                epoch_metrics[f'{phase}_accuracy'] = epoch_accuracy
                #epoch_metrics[f'{phase}_f1'] = epoch_f1

                if phase == 'val':
                    scheduler.step(epoch_loss)
                    if epoch_accuracy > best_val_accuracy:
                        best_val_accuracy = epoch_accuracy
                        torch.save(model.state_dict(), os.path.join(output_folder, "best_model.pth"))

        writer.add_scalar('Loss/train', epoch_metrics['train_loss'], epoch)
        writer.add_scalar('Loss/validation', epoch_metrics['val_loss'], epoch)
        writer.add_scalar('Accuracy/train', epoch_metrics['train_accuracy'], epoch)
        writer.add_scalar('Accuracy/validation', epoch_metrics['val_accuracy'], epoch)
        #writer.add_scalar('F1/train', epoch_metrics['train_f1'], epoch)
        #writer.add_scalar('F1/validation', epoch_metrics['val_f1'], epoch)

        metrics_df = pd.concat([metrics_df, pd.DataFrame([epoch_metrics])], ignore_index=True)
        metrics_df.to_csv(metrics_path, index=False)
    for i in range(18):
        mean_APs[i] /= counter
    print("\nEvaluating best model on test set...")
    model.load_state_dict(torch.load(os.path.join(output_folder, "best_model.pth")))
    model.eval()

    running_loss, all_preds, all_labels = 0.0, [], []
    with torch.no_grad():
        for feats, coords, lidar_inputs, bboxes_list in data_loader:
            labels = [bbox.to(device) for bbox in bboxes_list]
            lidar_inputs['point_clouds'] = lidar_inputs['point_clouds'].to(device)
            coords, feats = ME.utils.sparse_collate(coords, feats)
            rgb_inputs = ME.SparseTensor(features=feats.to(device), coordinates=coords.to(device))
            pred_bb, pred_cl, points = model(rgb_inputs, lidar_inputs)
            loss, mAP, AP = criterion(pred_bb, pred_cl, points, labels)
            loss = loss['bbox_loss'] + loss['cls_loss']
            
            running_loss += loss.item() * batch_size
            all_preds.append(mAP)

    test_loss = running_loss / len(test_loader.dataset)
    test_accuracy = np.mean(all_preds)
    print(pred_bb[0][0].shape)
    print(mean_APs)
    #test_f1 = f1_score(all_labels, all_preds, average="weighted")

    test_metrics_row = {col: None for col in columns}
    test_metrics_row.update({
        'epoch': 'best_model',
        'test_loss': test_loss,
        'test_accuracy': test_accuracy,
        #'test_f1': test_f1
    })
    
    metrics_df = pd.concat([metrics_df, pd.DataFrame([test_metrics_row])], ignore_index=True)
    metrics_df.to_csv(metrics_path, index=False)

    # Save confusion matrix as CSV
    """conf_matrix = confusion_matrix(all_labels, all_preds, normalize='true')
    conf_matrix_csv_path = os.path.join(output_folder, 'normalized_confusion_matrix.csv')
    pd.DataFrame(conf_matrix, index=class_names, coluns=class_names).to_csv(conf_matrix_csv_path)
    print(f"Confusion matrix CSV saved at: {conf_matrix_csv_path}")"""

    writer.close()

In [20]:
train_model(learning_rate=0.001, num_epochs=20, batch_size=10, voxel_size = 0.02)

<class 'MinkowskiNormalization.MinkowskiBatchNorm'>
Loaded pretrained weights from TR3D.pth
Loaded pretrained weights from votenet_backbone.pth and votenet_neck.pth
Loaded pretrained weights from TR3D.pth
Using GPUs: 2
Loaded pretrained weights from TR3D.pth
tensor([2.4158, 3.9594, 1.2447, 3.6952, 2.5133, 2.0973, 2.8201, 3.9728, 3.1347,
        4.3157, 3.3919, 4.0510, 4.4236, 4.9219, 4.3465, 3.7117, 4.9288, 2.0817],
       device='cuda:0')


                                                                                                                        


Evaluating best model on test set...




torch.Size([8947, 6])
{0: 0.3068562548544587, 1: 0.0, 2: 0.35974233907946307, 3: 0.0, 4: 0.0, 5: 0.5271795823265206, 6: 0.0, 7: 0.0, 8: 0.17642712067486124, 9: 0.0, 10: 0.0, 11: 0.0, 12: 0.12726695585544634, 13: 0.07554424973368508, 14: 0.1464666646135584, 15: 0.22845055664403818, 16: 0.0, 17: 0.4986162498151042}
