## Visualization of the Training Result 
 - ptv3_small
 - with 8 * Specimens
 - for 6 classes: switch the dataloader import to `SpineDepthDataLoader_ptv3_6`

In [None]:
import torch, random
import numpy as np
import os
from data_utils.augmentation import NormalizeFeatures, AdjustRGBColor
from data_utils.SpineDepthDataLoader_ptv3 import TrainDataset, TestDataset
# from data_utils.SpineDepthDataLoader_ptv3_6 import TrainDataset, TestDataset
from torchvision.transforms import Compose
from ptv3_model import PTv3Wrap
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
# Reproducibility: seed every random number generator used in this notebook
# from a single fixed value.
manual_seed = 42
for _seed_rng in (
    random.seed,                 # Python's random module
    np.random.seed,              # NumPy random module
    torch.manual_seed,           # PyTorch CPU
    torch.cuda.manual_seed_all,  # PyTorch GPU (for all devices)
):
    _seed_rng(manual_seed)

# cuDNN flags chosen for repeatable runs.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# For DataLoader reproducibility
def seed_worker(worker_id):
    """Seed the Python and NumPy RNGs from the current PyTorch seed.

    Used as the ``worker_init_fn`` of the ``DataLoader`` built later in this
    notebook. ``worker_id`` is accepted to fit that callback signature but is
    not read.
    """
    # Fold the torch seed into 32 bits before handing it to the other RNGs.
    seed32 = torch.initial_seed() & 0xFFFFFFFF
    for reseed in (np.random.seed, random.seed):
        reseed(seed32)

# Dedicated generator seeded with `manual_seed` so the DataLoader draws the
# same sequence of random numbers on every run.
g = torch.Generator().manual_seed(manual_seed)

In [None]:
# Run configuration shared by the cells below.
NUM_CLASSES = 2  # number of classes; logits are reshaped to (-1, NUM_CLASSES) and it sizes the confusion matrix
NUM_CHANNELS = 6  # default `in_channels` of ptv3_small
NUM_POINT = 10000  # `num_points` passed to TestDataset and (by default) to ptv3_small
BATCH_SIZE = 2  # batch size of the test DataLoader
NUM_WORKERS = 12  # `num_workers` of the test DataLoader
test_specimen_idx = 3  # specimen index handed to TestDataset
adjust_strength = 0.3  # only referenced by the commented-out AdjustRGBColor transform in this notebook
randomize_rate = 0.5  # only referenced by the commented-out AdjustRGBColor transform in this notebook

def _device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Device used for the model and for every tensor moved in the cells below.
device = _device()

In [None]:
def ptv3_small(
    num_classes: int = NUM_CLASSES,
    in_channels: int = NUM_CHANNELS,
    num_points: int = NUM_POINT,
    patch_size: int = 1024,
    device: torch.device = _device(),
):
    """Build the "small" PTv3Wrap configuration and move it to ``device``.

    Args:
        num_classes: Forwarded to ``PTv3Wrap`` as ``num_classes``.
        in_channels: Forwarded to ``PTv3Wrap`` as ``in_channels``.
        num_points: Forwarded to ``PTv3Wrap`` as ``num_points``.
        patch_size: Patch size repeated for all four encoder entries and all
            three decoder entries.
        device: Target device. The default is evaluated once, when this
            function is defined.

    Returns:
        The ``PTv3Wrap`` instance after ``.to(device)``.
    """
    # Encoder settings: four entries per option, three strides.
    encoder_cfg = dict(
        enc_channels=(32, 64, 128, 256),
        enc_depths=(2, 2, 2, 2),
        enc_num_head=(2, 4, 8, 16),
        enc_patch_size=(patch_size,) * 4,
        stride=(2, 2, 2),
    )

    # Decoder settings: three entries per option.
    decoder_cfg = dict(
        dec_channels=(16, 64, 128),
        dec_depths=(2, 2, 2),
        dec_num_head=(4, 8, 16),
        dec_patch_size=(patch_size,) * 3,
    )

    network = PTv3Wrap(
        num_classes=num_classes,
        in_channels=in_channels,
        num_points=num_points,
        **encoder_cfg,
        **decoder_cfg,
    )
    return network.to(device)

class WeightedFocalLoss(nn.Module):
    """Focal loss with a per-sample weight looked up from class weights.

    Each sample contributes ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``,
    where ``p_t`` is the softmax probability of the sample's target class and
    ``alpha_t`` is the weight of that class.
    """

    def __init__(self, gamma=5.0, reduction='mean'):
        super().__init__()
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, seg_pred, targets, alpha):
        """
        seg_pred: Logits from the model [B, C] where C = number of classes.
        targets: Ground truth labels [B].
        alpha: Class weights (1D tensor), indexed by class label.

        Returns the mean for reduction 'mean', the sum for reduction 'sum',
        and the unreduced [B, 1] tensor for any other reduction value.
        """
        target_idx = targets.unsqueeze(1)  # [B, 1]

        # Weight of each sample's target class, shaped [B, 1]
        class_weight = alpha.gather(0, targets.view(-1)).unsqueeze(1)

        # Work in log space for numerical stability
        log_prob_all = F.log_softmax(seg_pred, dim=-1)  # [B, C]
        prob_all = log_prob_all.exp()  # [B, C]

        # Keep only the entries belonging to the target class
        log_pt = log_prob_all.gather(1, target_idx)  # [B, 1]
        pt = prob_all.gather(1, target_idx)  # [B, 1]

        per_sample = -class_weight * (1 - pt) ** self.gamma * log_pt

        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample


In [None]:
# Checkpoint 'model.pth'; the per-epoch training history is read from it below.
log_train_dir = '/home/travail/ptv3-vertebrae-segmentation/log/ptv3_small/class_2/num_examples_25/S_3/checkpoints/model.pth'
# map_location remaps stored tensors onto the active device, so a checkpoint
# written during a GPU run also loads on a CPU-only machine.
checkpoint = torch.load(log_train_dir, map_location=device)
print(checkpoint.keys())

In [None]:
# Training progress comes from 'model.pth'
train_loss, train_accuracy = checkpoint['training_loss'], checkpoint['train_accuracy']
eval_loss, eval_accuracy = checkpoint['eval_loss'], checkpoint['eval_accuracy']

# Value stored under 'epoch' in the checkpoint
epoch = checkpoint['epoch']
print(epoch)

In [None]:
def _as_plottable(history):
    """Return ``history`` as a list with every tensor entry converted to a CPU NumPy value."""
    return [
        item.detach().cpu().numpy() if isinstance(item, torch.Tensor) else item
        for item in history
    ]


# The checkpoint histories are expected to hold one value per epoch; tensor
# entries (possibly on the GPU) are converted before plotting.
training_loss_epoch = _as_plottable(train_loss)
train_accuracy_epoch = _as_plottable(train_accuracy)

validation_loss_epoch = _as_plottable(eval_loss)
validation_accuracy_epoch = _as_plottable(eval_accuracy)

# Two stacked subplots: loss on top, accuracy below (the x-axes are not shared).
fig, (ax1, ax2) = plt.subplots(2, figsize=(10, 7))

_panels = (
    (ax1, 'Loss', training_loss_epoch, "Train Loss", validation_loss_epoch, "Eval Loss",
     'Training Loss and Validation Loss per Epoch'),
    (ax2, 'Accuracy', train_accuracy_epoch, "Train Accuracy", validation_accuracy_epoch, "Eval Accuracy",
     'Training Accuracy and Validation Accuracy per Epoch'),
)
for _axis, _ylabel, _train_curve, _train_name, _eval_curve, _eval_name, _title in _panels:
    _axis.set_xlabel('Epochs')
    _axis.set_ylabel(_ylabel)
    _axis.plot(range(len(_train_curve)), _train_curve, label=_train_name, color="blue")
    _axis.plot(range(len(_eval_curve)), _eval_curve, label=_eval_name, color="orange")
    _axis.tick_params(axis='y')
    _axis.set_title(_title)
    _axis.legend(loc='best')

# Keep the titles from overlapping the neighbouring panel
plt.tight_layout()

plt.show()

In [None]:
# Checkpoint 'best_model.pth': supplies the model weights and the stored vertebrae DSC.
log_best_dir = '/home/travail/ptv3-vertebrae-segmentation/log/ptv3_small/class_2/num_examples_25/S_3/checkpoints/best_model.pth'
# map_location remaps stored tensors onto the active device, so a checkpoint
# written during a GPU run also loads on a CPU-only machine.
checkpoint = torch.load(log_best_dir, map_location=device)
print(checkpoint.keys())

# ptv3_small() returns the model already moved to its default device.
model = ptv3_small()
model.load_state_dict(checkpoint['model_state_dict'])

# NOTE: an earlier typo in this key was fixed in 'train_sem_seg_ptv3.py' (where
# the best-model state is saved). Presumably checkpoints written before that
# fix use a different key -- verify if a KeyError is raised here.
DSC = checkpoint['Vertebrae DSC']
print(' - Vertebrae DSC: {:.3f}'.format(DSC))

## Validation 
**Notes**
- choose test_specimen_idx = i , i=[2,9]

In [None]:
# Root directory passed to TestDataset for this visualization run.
# debug_dir = '/home/travail/PointTransformerV3/SpineDepth_labeled_symlink'
S3_vis_dir = '/home/travail/ptv3-vertebrae-segmentation/S3_vis'

# Only NormalizeFeatures is active; the colour augmentation is disabled.
# Go to augmentation.py to change 'adjust_strength'
transforms = Compose(
    [
        # AdjustRGBColor(adjust_strength=adjust_strength, randomize_rate = randomize_rate),
        NormalizeFeatures(),
    ]
)

# TRAIN_DATASET = TrainDataset(root_dir=debug_dir, num_points=NUM_POINT, test_specimen_idx=test_specimen_idx, sample_ratio=0.2, transforms=transforms)
TEST_DATASET = TestDataset(
    root_dir=S3_vis_dir,
    num_points=NUM_POINT,
    test_specimen_idx=test_specimen_idx,
    transforms=transforms,
)

# trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, worker_init_fn=seed_worker, generator=g)
# drop_last=True discards a final incomplete batch, so up to BATCH_SIZE - 1
# samples at the end of the dataset are not evaluated.
testDataLoader = torch.utils.data.DataLoader(
    TEST_DATASET,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g,
)

# Label weights exposed by the dataset; passed to the loss as `alpha` during evaluation.
# train_weights = torch.Tensor(TRAIN_DATASET.labelweights).to(device)
test_weights = torch.Tensor(TEST_DATASET.labelweights).to(device)


**Evaluation on the whole test dataset**

In [None]:
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN when the denominator is zero."""
    if denominator == 0:
        return float('nan')
    return float(numerator) / float(denominator)


# Evaluate the loaded model on every batch of `testDataLoader`, then derive
# accuracy, recall, precision, F1, IoU and DSC from the confusion matrix.
all_gt_labels = []
all_pred_labels = []
eval_loss_list = []
eval_accuracy_list = []

# Cell size used to quantise coordinates into integer `grid_coord`
grid_size = 0.01

loss_fn = WeightedFocalLoss(gamma=5).to(device)
model.eval()

eval_total_correct = 0
eval_total_seen = 0
eval_loss = 0
with torch.no_grad():
    for eval_points, eval_target in tqdm(testDataLoader, total=len(testDataLoader)):
        # Prepare input_dict: points of all samples are flattened into one
        # long list and `batch` records which sample each point belongs to.
        coord = eval_points[:, :, :3].float().to(device)
        feat = eval_points[:, :, 0:6].float().to(device)
        input_dict = {
            "coord": coord.view(-1, 3),  # [x, y, z]
            "feat": feat.view(-1, 6),  # [x, y, z, r, g, b]
            "batch": torch.repeat_interleave(
                torch.arange(eval_points.size(0)), eval_points.size(1)
            ).to(device),
        }
        # NOTE(review): the minimum is taken over all points of the batch
        # together, not per sample -- confirm this matches training.
        input_dict["grid_coord"] = torch.div(
            input_dict["coord"] - input_dict["coord"].min(0)[0],
            grid_size, rounding_mode='trunc'
        ).int().to(device)

        # Flatten the targets so they line up with the flattened logits
        eval_target = eval_target.long().to(device).view(-1)
        pred, output = model(input_dict)
        seg_pred = pred.contiguous().view(-1, NUM_CLASSES)
        loss = loss_fn(seg_pred, eval_target, test_weights)

        pred_choice = seg_pred.detach().cpu().max(1)[1].numpy()
        batch_label = eval_target.cpu().numpy()

        eval_total_correct += np.sum(pred_choice == batch_label)
        # Count the labels actually evaluated instead of assuming every batch
        # holds exactly BATCH_SIZE * NUM_POINT points.
        eval_total_seen += batch_label.size
        eval_loss += loss.item()

        # Collect ground truth and predicted labels for the confusion matrix
        all_gt_labels.append(batch_label)
        all_pred_labels.append(pred_choice)

eval_loss /= len(testDataLoader)
print(f'Eval Loss: {eval_loss:.4f}')
eval_accuracy = eval_total_correct / float(eval_total_seen)
print(f'Eval Accuracy: {eval_accuracy:.4f}')
eval_loss_list.append(eval_loss)
eval_accuracy_list.append(eval_accuracy)

# Concatenate all collected labels to compute the confusion matrix
all_gt_labels = np.concatenate(all_gt_labels)
all_pred_labels = np.concatenate(all_pred_labels)

# Rows are ground-truth classes, columns are predicted classes
CM = confusion_matrix(all_gt_labels, all_pred_labels, labels=np.arange(NUM_CLASSES))
# Vertebrae classes are 1 .. NUM_CLASSES - 1; class 0 is everything else
vertebrae_range = range(1, NUM_CLASSES)

tp = np.sum([CM[i, i] for i in vertebrae_range])  # True positives for vertebrae
fn = np.sum([CM[i, j] for i in vertebrae_range for j in range(NUM_CLASSES) if j != i])  # False negatives
fp = np.sum([CM[j, i] for i in vertebrae_range for j in range(NUM_CLASSES) if j != i])  # False positives
tn = CM[0, 0]  # True negatives; not used by the metrics below

# A ratio with a zero denominator yields NaN rather than a RuntimeWarning
acc = _safe_ratio(np.trace(CM), np.sum(CM))
recall = _safe_ratio(tp, tp + fn)
precision = _safe_ratio(tp, tp + fp)
IoU = _safe_ratio(tp, tp + fn + fp)
DSC = _safe_ratio(2 * tp, 2 * tp + fn + fp)
f1 = _safe_ratio(2 * recall * precision, recall + precision)

# Print with 3 decimal precision
print('\nConfusion Matrix:')
print(CM)

print('\nTestset Accuracy (mean): {:.3f} %'.format(100 * acc))
print('- Recall     : {:.3f}'.format(recall))
print('- Precision  : {:.3f}'.format(precision))
print('- F1 Score   : {:.3f}'.format(f1))
print('- Vertebrae IoU: {:.3f}'.format(IoU))
print('- Vertebrae DSC: {:.3f}'.format(DSC))

**Visualization**

In [None]:
# Assemble the last evaluated batch into one array for plotting. `eval_points`,
# `pred_choice` and `eval_target` are the values left over from the final
# iteration of the evaluation loop.
n_frames, n_points = eval_points.size(0), eval_points.size(1)

point_clouds = eval_points.cpu().numpy()
print(f'point cloud per batch: {point_clouds.shape}')   #(batch_size, num_point, features)

# The flat predictions / targets are folded back to one row per sample
pred_label = pred_choice.reshape(n_frames, n_points)
print(f'pred label: {pred_label.shape}')    #(batch_size, num_point)

gd_label = eval_target.cpu().numpy().reshape(n_frames, n_points)
print(f'ground truth label: {gd_label.shape}')  #(batch_size, num_point)

# Give both label maps a trailing axis: (batch_size, num_point, 1)
pred_label = np.expand_dims(pred_label, axis=2)
gd_label = np.expand_dims(gd_label, axis=2)

# Final layout per point: [features..., predicted label, ground-truth label]
vis_pc = np.concatenate((point_clouds, pred_label, gd_label), axis=2)
print(f'visualize point clouds: {vis_pc.shape}')

**2 classes Visualization**

In [None]:
def _label_trace(points, colour, marker_size, legend_name):
    """Build a 3-D marker trace for ``points``; z is plotted with its sign flipped."""
    return go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=-points[:, 2],
        mode='markers',
        marker=dict(size=marker_size, color=colour),
        name=legend_name,
    )


# Sample at index 1 of the assembled batch; column 9 holds the predicted label
# and column 10 the ground-truth label.
batch_idx = vis_pc[1]

frame_xyz = batch_idx[:, :3]
pred_column = batch_idx[:, 9]
gt_column = batch_idx[:, 10]

# Split the xyz coordinates by label
gt_label_1 = frame_xyz[gt_column == 1]  # Ground truth label = 1
gt_label_0 = frame_xyz[gt_column == 0]  # Ground truth label = 0
pred_label_1 = frame_xyz[pred_column == 1]  # Predicted label = 1
pred_label_0 = frame_xyz[pred_column == 0]  # Predicted label = 0

# One trace per group
trace_gt_1 = _label_trace(gt_label_1, 'blue', 2, 'Ground Truth Vertebrae (Label = 1)')
trace_gt_0 = _label_trace(gt_label_0, 'gray', 2, 'Ground Truth Others (Label = 0) ')
trace_pred_1 = _label_trace(pred_label_1, 'red', 1.5, 'Predicted Vertebrae (Label = 1)')

# trace_pred_0 = go.Scatter3d(
#     x=pred_label_0[:, 0], y=pred_label_0[:, 1], z=pred_label_0[:, 2],
#     mode='markers',
#     marker=dict(size=2.5, color='gray'),
#     name='Predicted Others (Label = 0)'
# )

# Ground truth first, predicted vertebrae drawn last
fig = go.Figure(data=[trace_gt_1, trace_gt_0, trace_pred_1])

fig.update_layout(
    title="3D Point Cloud Visualization with Ground Truth and Predictions",
    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
    height=750,
)

fig.show()


In [None]:
# batch_idx = vis_pc[0]  # Shape: (num_points, 11)

# # Separate points based on their labels
# gt_label_1 = batch_idx[batch_idx[:, 10] == 1][:, :3]  # Ground truth label = 1
# gt_label_0 = batch_idx[batch_idx[:, 10] == 0][:, :3]  # Ground truth label = 0
# pred_label_1 = batch_idx[batch_idx[:, 9] == 1][:, :3]  # Predicted label = 1
# pred_label_0 = batch_idx[batch_idx[:, 9] == 0][:, :3] # Predicted label = 0

# # Create subplot layout: 1 row, 2 columns
# fig = make_subplots(
#     rows=1, cols=2,
#     specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
#     subplot_titles=("Ground Truth", "Prediction")
# )

# # Add ground truth traces to the first subplot (column 1)
# fig.add_trace(
#     go.Scatter3d(
#         x=gt_label_1[:, 0], y=gt_label_1[:, 1], z=gt_label_1[:, 2],
#         mode='markers',
#         marker=dict(size=2, color='gold'),
#         name='GT Vertebrae (Label = 1)'
#     ), row=1, col=1
# )

# fig.add_trace(
#     go.Scatter3d(
#         x=gt_label_0[:, 0], y=gt_label_0[:, 1], z=gt_label_0[:, 2],
#         mode='markers',
#         marker=dict(size=2, color='gray'),
#         name='GT Others (Label = 0)'
#     ), row=1, col=1
# )

# # Add prediction traces to the second subplot (column 2)
# fig.add_trace(
#     go.Scatter3d(
#         x=pred_label_1[:, 0], y=pred_label_1[:, 1], z=pred_label_1[:, 2],
#         mode='markers',
#         marker=dict(size=2, color='gold'),
#         name='Pred Vertebrae (Label = 1)'
#     ), row=1, col=2
# )

# fig.add_trace(
#     go.Scatter3d(
#         x=pred_label_0[:, 0], y=pred_label_0[:, 1], z=pred_label_0[:, 2],
#         mode='markers',
#         marker=dict(size=2, color='gray'),
#         name='Pred Others (Label = 0)'
#     ), row=1, col=2
# )

# # Set layout for the figure
# fig.update_layout(
#     title="3D Point Cloud: Ground Truth vs Prediction",
#     scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
#     height=800,
# )

# # Show the figure
# fig.show()


**6 classes Visualization**

In [None]:
# batch_idx = vis_pc[1]  # Shape: (num_points, 11)

# # Colors for ground truth (lighter) and predictions (darker)
# gt_colors = ['lightslategray', 'lightcoral', 'lightyellow', 'mistyrose', 'lightblue', 'lightgreen']
# pred_colors = ['lightslategray', 'red', 'orange', 'hotpink', 'blue', 'green']
# # Separate points based on their labels and create traces
# ground_truth_traces = []
# prediction_traces = []

# for label in range(6):  # Iterate over labels 0 to 5
#     # Filter points for the current ground truth label
#     gt_points = batch_idx[batch_idx[:, 10] == label][:, :3]
#     if gt_points.size > 0:  # Only add trace if there are points
#         ground_truth_traces.append(
#             go.Scatter3d(
#                 x=gt_points[:, 0], y=gt_points[:, 1], z=-gt_points[:, 2], # reverse z-axis
#                 mode='markers',
#                 marker=dict(size=1.5, color=gt_colors[label]),
#                 name=f'Ground Truth (Label = {label})'
#             )
#         )
    
#     # Filter points for the current predicted label
#     pred_points = batch_idx[batch_idx[:, 9] == label][:, :3]
#     if pred_points.size > 0:  # Only add trace if there are points
#         prediction_traces.append(
#             go.Scatter3d(
#                 x=pred_points[:, 0], y=pred_points[:, 1], z=-pred_points[:, 2], # reverse z-axis
#                 mode='markers',
#                 marker=dict(size=1.5, color=pred_colors[label], symbol='cross'),
#                 name=f'Predicted (Label = {label})'
#             )
#         )

# # Combine all traces
# fig = go.Figure(data=ground_truth_traces + prediction_traces)

# # Set layout for the figure
# fig.update_layout(
#     title="3D Point Cloud Visualization with Ground Truth and Predictions (6 Classes)",
#     scene=dict(
#         xaxis=dict(
#             title='X',
#             backgroundcolor="slategray"
#         ),
#         yaxis=dict(
#             title='Y',
#             backgroundcolor="slategray"
#         ),
#         zaxis=dict(
#             title='Z',
#             backgroundcolor="slategray"
#         ),
#     ),
#     height=750,
# )

# # Show the figure
# fig.show()

In [None]:
# batch_idx = vis_pc[1]  # Shape: (num_points, 11)

# # Colors for ground truth (lighter) and predictions (darker)
# # gt_colors = ['lightslategray', 'lightcoral', 'lightyellow', 'mistyrose', 'lightblue', 'lightgreen']
# gt_colors = ['gray', 'red', 'orange', 'pink', 'blue', 'green']
# pred_colors = ['gray', 'red', 'orange', 'pink', 'blue', 'green']

# # Create subplot layout: 1 row, 2 columns
# fig = make_subplots(
#     rows=1, cols=2,
#     specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
#     subplot_titles=("Ground Truth", "Prediction")
# )

# # Add ground truth traces to the first subplot (column 1)
# for label in range(6):  # Iterate over labels 0 to 5
#     gt_points = batch_idx[batch_idx[:, 10] == label][:, :3]
#     if gt_points.size > 0:  # Only add trace if there are points
#         fig.add_trace(
#             go.Scatter3d(
#                 x=gt_points[:, 0], y=gt_points[:, 1], z=-gt_points[:, 2], # reverse z-axis
#                 mode='markers',
#                 marker=dict(size=1.5, color=gt_colors[label]),
#                 name=f'GT (Label = {label})'
#             ), row=1, col=1
#         )

# # Add prediction traces to the second subplot (column 2)
# for label in range(6):  # Iterate over labels 0 to 5
#     pred_points = batch_idx[batch_idx[:, 9] == label][:, :3]
#     if pred_points.size > 0:  # Only add trace if there are points
#         fig.add_trace(
#             go.Scatter3d(
#                 x=pred_points[:, 0], y=pred_points[:, 1], z=-pred_points[:, 2], # reverse z-axis
#                 mode='markers',
#                 marker=dict(size=1.5, color=pred_colors[label]),
#                 name=f'Pred (Label = {label})'
#             ), row=1, col=2
#         )

# # Set layout for the figure
# fig.update_layout(
#     title="3D Point Cloud: Ground Truth vs Prediction (6 Classes)",
#     scene=dict(
#         xaxis_title='X',
#         yaxis_title='Y',
#         zaxis_title='Z'
#     ),
#     height=800,
#     width=1200,
# )

# # Show the figure
# fig.show()