In [1]:
# VesselMNIST3D Dataset Analysis
# Medical Imaging Dataset Exploration

# Select the rendering backend; this runs before pyplot is imported below.
# NOTE(review): 'TkAgg' is a desktop GUI backend, so figures presumably open in
# separate windows instead of inline in the notebook, and the call is expected
# to fail on a kernel without a display. The Slider import further down hints
# at interactive use -- confirm this is intended before changing it.
import matplotlib
matplotlib.use('TkAgg')  # or 'Qt5Agg' depending on your system

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from matplotlib import colors
from matplotlib.widgets import Slider
import tensorflow as tf

# Font setup: use the sans-serif family and list the preferred typefaces for it
matplotlib.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Tahoma', 'DejaVu Sans', 'Verdana'],
})




In [2]:
from medmnist import VesselMNIST3D


In [3]:
print("="*60)
print("LOADING VESSELMNIST3D DATASET")
print("="*60)

# Load the train / test / validation splits
train_dataset = VesselMNIST3D(split='train', size=28, download=True)
test_dataset = VesselMNIST3D(split='test', size=28, download=True)
val_dataset = VesselMNIST3D(split='val', size=28, download=True)


def _collect_samples(dataset):
    """Materialise a dataset into parallel lists of images and labels.

    Every sample is fetched exactly once. Indexing ``dataset[i]`` separately
    for the image and for the label would run the dataset's item-loading code
    twice per sample.

    Args:
        dataset: Object supporting ``len()`` and integer indexing, whose items
            expose the image at position 0 and the label at position 1.

    Returns:
        tuple[list, list]: ``(images, labels)``, two lists of equal length in
        dataset order.
    """
    images, labels = [], []
    for i in range(len(dataset)):
        sample = dataset[i]  # single lookup, reused for both fields
        images.append(sample[0])
        labels.append(sample[1])
    return images, labels


# Convert to lists
print("\nProcessing training data...")
trainx, trainy = _collect_samples(train_dataset)

print("Processing test data...")
testx, testy = _collect_samples(test_dataset)

print("Processing validation data...")
valx, valy = _collect_samples(val_dataset)

# Convert to tensors (float16 for both volumes and labels, as before)
trainx_tensor = tf.convert_to_tensor(trainx, dtype=tf.float16)
trainy_tensor = tf.convert_to_tensor(trainy, dtype=tf.float16)
testx_tensor = tf.convert_to_tensor(testx, dtype=tf.float16)
testy_tensor = tf.convert_to_tensor(testy, dtype=tf.float16)
valx_tensor = tf.convert_to_tensor(valx, dtype=tf.float16)
valy_tensor = tf.convert_to_tensor(valy, dtype=tf.float16)

print("\nDataset loaded successfully!")


LOADING VESSELMNIST3D DATASET

Processing training data...
Processing test data...
Processing validation data...

Dataset loaded successfully!


In [4]:
# ============================================================================
# SECTION 1: DATASET STATISTICS
# ============================================================================
# Reports how many samples each split holds, followed by the shape and dtype
# of the stacked training tensors and the shape of one individual sample.

print("\n" + "="*60)
print("DATASET STATISTICS")
print("="*60)

# One line per split, in the order train / test / validation
for split_name, split_dataset in (
    ("Training", train_dataset),
    ("Test", test_dataset),
    ("Validation", val_dataset),
):
    print(f"{split_name} samples: {len(split_dataset)}")

# Shapes of the stacked tensors, then of a single (un-stacked) sample
print(
    f"\nInput shape: {trainx_tensor.shape}",
    f"Label shape: {trainy_tensor.shape}",
    f"\nSingle image shape: {trainx[0].shape}",
    f"Data type: {trainx_tensor.dtype}",
    sep="\n",
)


DATASET STATISTICS
Training samples: 1335
Test samples: 382
Validation samples: 191

Input shape: (1335, 1, 28, 28, 28)
Label shape: (1335, 1)

Single image shape: (1, 28, 28, 28)
Data type: <dtype: 'float16'>


In [8]:
# ============================================================================
# SECTION 2: LABEL DISTRIBUTION ANALYSIS
# ============================================================================

print("\n" + "="*60)
print("LABEL DISTRIBUTION")
print("="*60)


def _print_class_distribution(split_name, labels):
    """Print the per-class sample count and percentage for one split.

    Args:
        split_name: Text used in the heading, e.g. ``"Training"``.
        labels: Sequence of per-sample labels. It is handed to ``np.unique``
            and its ``len()`` is the denominator of the percentages.

    Returns:
        tuple: ``(unique_labels, counts)`` exactly as returned by
        ``np.unique(labels, return_counts=True)``.
    """
    unique_labels, counts = np.unique(labels, return_counts=True)
    total = len(labels)
    print(f"\n{split_name} set class distribution:")
    for label, count in zip(unique_labels, counts):
        percentage = (count / total) * 100
        print(f"  Class {label}: {count} samples ({percentage:.2f}%)")
    return unique_labels, counts


# The results are kept under their original names; the plotting cell reads
# unique_train / counts_train.
unique_train, counts_train = _print_class_distribution("Training", trainy)
unique_test, counts_test = _print_class_distribution("Test", testy)
unique_val, counts_val = _print_class_distribution("Validation", valy)


LABEL DISTRIBUTION

Training set class distribution:
  Class 0: 1185 samples (88.76%)
  Class 1: 150 samples (11.24%)

Test set class distribution:
  Class 0: 339 samples (88.74%)
  Class 1: 43 samples (11.26%)

Validation set class distribution:
  Class 0: 169 samples (88.48%)
  Class 1: 22 samples (11.52%)


In [6]:

# ============================================================================
# SECTION 3: DATA VALUE RANGE ANALYSIS
# ============================================================================
# All statistics below describe the first training sample only.

print("\n" + "="*60)
print("DATA VALUE RANGES")
print("="*60)

sample_vol = np.array(trainx[0])

# Extremes are computed once and reused for the summary line and the check
lowest, highest = sample_vol.min(), sample_vol.max()

print(f"Min value: {lowest}")
print(f"Max value: {highest}")
print(f"Mean value: {sample_vol.mean():.4f}")
print(f"Std deviation: {sample_vol.std():.4f}")

print(f"\nData range: [{lowest}, {highest}]")
if highest > 1.0:
    print("Note: Data may need normalization for neural network training")

# ============================================================================
# SECTION 4: MEMORY USAGE ANALYSIS
# ============================================================================

import sys  # no longer needed by this section; kept in case later cells use it


def _tensor_size_mb(tensor):
    """Return the size of a tensor's element data in mebibytes.

    Uses ``ndarray.nbytes`` (element count x item size) instead of
    ``sys.getsizeof``: the latter adds the array object's header and only
    reports the header for arrays that do not own their buffer.

    Args:
        tensor: Object exposing ``.numpy()`` that returns a NumPy array.

    Returns:
        float: Data size in MiB.
    """
    return tensor.numpy().nbytes / (1024**2)


trainx_size_mb = _tensor_size_mb(trainx_tensor)
testx_size_mb = _tensor_size_mb(testx_tensor)
valx_size_mb = _tensor_size_mb(valx_tensor)

print("\n" + "="*60)
print("MEMORY USAGE")
print("="*60)
print(f"Training data: {trainx_size_mb:.2f} MB")
print(f"Test data: {testx_size_mb:.2f} MB")
print(f"Validation data: {valx_size_mb:.2f} MB")
print(f"Total: {trainx_size_mb + testx_size_mb + valx_size_mb:.2f} MB")
print(f"\nUsing float16 saves ~50% memory compared to float32")




DATA VALUE RANGES
Min value: 0.0
Max value: 1.0
Mean value: 0.0374
Std deviation: 0.1896

Data range: [0.0, 1.0]

MEMORY USAGE
Training data: 55.90 MB
Test data: 15.99 MB
Validation data: 8.00 MB
Total: 79.89 MB

Using float16 saves ~50% memory compared to float32


In [7]:

# ============================================================================
# VISUALIZATION 1: CLASS DISTRIBUTION BAR CHART
# ============================================================================
# Bar chart of the training-set class counts, followed by the ratio between
# the largest and the smallest class.

print("\n" + "="*60)
print("GENERATING VISUALIZATIONS...")
print("="*60)

class_ids = unique_train.flatten()

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(class_ids, counts_train, edgecolor='black', alpha=0.7, color='steelblue')
ax.set_xlabel('Class Label', fontsize=12)
ax.set_ylabel('Number of Samples', fontsize=12)
ax.set_title('Class Distribution in Training Set', fontsize=14, fontweight='bold')
ax.set_xticks(class_ids)
ax.grid(True, alpha=0.3, axis='y')

# Write each count just above the top centre of its bar
for bar, count in zip(bars, counts_train):
    x_centre = bar.get_x() + bar.get_width()/2.
    ax.text(x_centre, bar.get_height(), f'{count}',
            ha='center', va='bottom', fontsize=10)

fig.tight_layout()
# plt.savefig('class_distribution.png', dpi=150, bbox_inches='tight')
# print("✓ Saved: class_distribution.png")
plt.show()

# Imbalance ratio = size of the largest class / size of the smallest class
min_count, max_count = counts_train.min(), counts_train.max()
imbalance_ratio = max_count / min_count
print(f"\nClass imbalance ratio: {imbalance_ratio:.2f}:1")
if imbalance_ratio > 3:
    for warning_line in (
        "⚠️  Significant class imbalance detected!",
        "   Consider using class weights or data augmentation during training.",
    ):
        print(warning_line)

# ============================================================================
# VISUALIZATION 2: VOXEL INTENSITY HISTOGRAM
# ============================================================================
# Histogram over every voxel of the first training sample.

sample_vol = np.array(trainx[0]).squeeze()

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(sample_vol.ravel(), bins=50, edgecolor='black', alpha=0.7, color='coral')
ax.set_xlabel('Voxel Intensity', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Distribution of Voxel Intensities (Sample Volume)', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
fig.tight_layout()
# plt.savefig('voxel_intensity_histogram.png', dpi=150, bbox_inches='tight')
# print("✓ Saved: voxel_intensity_histogram.png")
plt.show()

# ============================================================================
# VISUALIZATION 3: SAMPLES FROM EACH CLASS
# ============================================================================
# Shows the middle slice (along the first axis) of the first training sample
# found for every class.

# Get one example from each class.
# A single np.unique call yields every distinct label together with the index
# of its first occurrence, so the label array is not re-scanned once per
# sample. Only the first entry of each label is used (as label[0] did before).
first_label_entries = np.asarray(trainy)[:, 0]
class_ids, first_positions = np.unique(first_label_entries, return_index=True)

# Maps class label -> index of its first sample, in ascending label order
class_examples = {
    class_id: int(position)
    for class_id, position in zip(class_ids, first_positions)
}

num_classes = len(class_examples)
# squeeze=False always returns a 2-D grid, so a single class needs no special case
fig, axes = plt.subplots(1, num_classes, figsize=(4*num_classes, 4), squeeze=False)
axes = axes[0]  # the only row

for idx, (class_label, sample_idx) in enumerate(class_examples.items()):
    vol = np.array(trainx[sample_idx]).squeeze()
    mid_slice = vol.shape[0] // 2

    axes[idx].imshow(vol[mid_slice], cmap='gray')
    axes[idx].set_title(f'Class {class_label}\n(Sample {sample_idx})', fontsize=12)
    axes[idx].axis('off')

plt.suptitle('Middle Slice from Each Class', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
# plt.savefig('class_examples.png', dpi=150, bbox_inches='tight')
# print("✓ Saved: class_examples.png")
plt.show()

# ============================================================================
# VISUALIZATION 4: ORTHOGONAL VIEWS (AXIAL, CORONAL, SAGITTAL)
# ============================================================================
# Cuts the second training sample through its centre along each array axis.
# NOTE(review): the anatomical names are attached to array axes 0/1/2 as in
# the original code -- confirm they match the dataset's real orientation.

vol = np.array(trainx[1]).squeeze()

# Centre index along each of the first three axes
i_mid, j_mid, k_mid = (extent // 2 for extent in vol.shape[:3])

# (panel name, slice number shown in the title, 2-D plane to draw)
views = (
    ('Axial', i_mid, vol[i_mid, :, :]),
    ('Coronal', j_mid, vol[:, j_mid, :]),
    ('Sagittal', k_mid, vol[:, :, k_mid]),
)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

for panel, (view_name, slice_no, plane) in zip(axes, views):
    panel.imshow(plane, cmap='gray')
    panel.set_title(f'{view_name} View (slice {slice_no})', fontsize=12, fontweight='bold')
    panel.axis('off')

fig.suptitle('Orthogonal Views of 3D Volume', fontsize=14, fontweight='bold')
fig.tight_layout()
# plt.savefig('orthogonal_views.png', dpi=150, bbox_inches='tight')
# print("✓ Saved: orthogonal_views.png")
plt.show()

# ============================================================================
# VISUALIZATION 5: ALL SLICES GRID
# ============================================================================
# Draws every slice along the first axis of `vol` (the volume selected by the
# preceding section) in a 4-column grid.

num_slices = vol.shape[0]

# Derive the row count from the volume depth instead of hard-coding 7x4, so a
# deeper volume is not silently cut off after 28 slices.
cols = 4
rows = max(1, (num_slices + cols - 1) // cols)  # ceiling division

# Panel size is chosen so that the 7x4 case keeps the original (10, 18) figure.
# squeeze=False keeps `axes` 2-D for any row count.
fig, axes = plt.subplots(rows, cols, figsize=(2.5 * cols, rows * 18 / 7), squeeze=False)

for i, ax in enumerate(axes.flat):
    if i < num_slices:
        ax.imshow(vol[i], cmap='gray')
        ax.set_title(f"Slice {i}", fontsize=10)
    # Unused trailing panels are blanked as well
    ax.axis('off')

plt.suptitle('All Slices of Sample Volume', fontsize=14, fontweight='bold')
plt.tight_layout()
# plt.savefig('all_slices_grid.png', dpi=150, bbox_inches='tight')
# print("✓ Saved: all_slices_grid.png")
plt.show()

# ============================================================================
# VISUALIZATION 6: MULTIPLE SAMPLES COMPARISON
# ============================================================================
# Each row is one randomly drawn training volume; the three columns show the
# slices at 1/4, 1/2 and 3/4 of its extent along the first axis.

SAMPLE_SEED = 42  # fixed so a re-run shows the same volumes; None = fresh draw

# Drawing without replacement cannot return more samples than exist
num_samples = min(4, len(trainx))

# squeeze=False keeps `axes` two-dimensional even if only one row is drawn
fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples), squeeze=False)

random_indices = np.random.default_rng(SAMPLE_SEED).choice(
    len(trainx), num_samples, replace=False
)

for row, idx in enumerate(random_indices):
    vol = np.array(trainx[idx]).squeeze()
    label = trainy[idx][0]

    slice_indices = [vol.shape[0]//4, vol.shape[0]//2, 3*vol.shape[0]//4]

    for col, slice_idx in enumerate(slice_indices):
        cell = axes[row, col]
        cell.imshow(vol[slice_idx], cmap='gray')
        if row == 0:
            cell.set_title(f'Slice {slice_idx}', fontsize=11)
        if col == 0:
            cell.set_ylabel(f'Sample {idx}\nClass {label}',
                            fontsize=10, rotation=0,
                            ha='right', va='center')
        # axis('off') would also hide the y-label set above, so only the ticks
        # and the frame are removed; the row label stays visible.
        cell.set_xticks([])
        cell.set_yticks([])
        for spine in cell.spines.values():
            spine.set_visible(False)

plt.suptitle('Multiple Samples at Different Slice Positions', fontsize=14, fontweight='bold')
plt.tight_layout()
# plt.savefig('multiple_samples_comparison.png', dpi=150, bbox_inches='tight')
# print("✓ Saved: multiple_samples_comparison.png")
plt.show()

# ============================================================================
# VISUALIZATION 7: 3D VOXEL PLOT
# ============================================================================
# Renders the second training sample as voxels: colour comes from the viridis
# map over the volume's own min..max range, opacity from the voxel value.

vol = np.squeeze(trainx[1], axis=0)

# Only strictly positive voxels are drawn
filled = vol > 0

# RGBA colour per voxel
norm = colors.Normalize(vmin=vol.min(), vmax=vol.max())
cmap = plt.cm.viridis
facecolors = cmap(norm(vol))

# Alpha channel: the voxel value limited to [0, 1], and 0 for undrawn voxels
alpha = np.where(filled, np.clip(vol, 0, 1), 0.0)
facecolors[..., 3] = alpha

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
ax.voxels(filled, facecolors=facecolors)

for set_axis_label, text in (
    (ax.set_xlabel, 'X-axis'),
    (ax.set_ylabel, 'Y-axis'),
    (ax.set_zlabel, 'Z-axis'),
):
    set_axis_label(text)

ax.set_title('3D Voxel Visualization with Magnitude-Based Transparency', fontweight='bold')
fig.tight_layout()
# plt.savefig('3d_voxel_plot.png', dpi=150, bbox_inches='tight')
# print("✓ Saved: 3d_voxel_plot.png")
plt.show()

print("\n" + "="*60)
print("ANALYSIS COMPLETE!")
print("="*60)
# The plt.savefig(...) calls in this cell are commented out, so the figures are
# only displayed. The closing message must not claim that PNG files exist.
print("\nAll visualizations have been generated and displayed.")
print("No image files were written (saving is disabled).")


GENERATING VISUALIZATIONS...

Class imbalance ratio: 7.90:1
⚠️  Significant class imbalance detected!
   Consider using class weights or data augmentation during training.

ANALYSIS COMPLETE!

All visualizations have been generated and saved.
Check the current directory for PNG files.
