# In [80]:
import numpy as np
import nibabel as nib
import pyvista as pv
from pyvista import examples
import matplotlib.pyplot as plt
import os

# Render subcortical structures in 3D, colored by per-region degree values.
#
# Loads a subcortical atlas (NIfTI) and a text file of degree values, paints
# each atlas label with its degree value, extracts the labelled voxels as a
# PyVista mesh, and saves off-screen screenshots from three viewpoints
# (xy, yz, xz) into `output_path`.

# 1. Set parameters
atlas_file = 'G:/Code/Atlas/Tian2020MSA/3T/Subcortex-Only/Tian_Subcortex_S2_3T_1mm.nii.gz'  # Replace with your atlas file path
degree_file = 'G:/PHD_1/NeuroT/Code/PlotSubCor/degree.txt'    # Replace with your degree file path
output_path = 'G:/PHD_1/NeuroT/Code/PlotSubCor/'             # Replace with your output directory
# Visualization parameters -- step 7 renders with exactly these values.
min_val = 1           # Lower bound of the color range
max_val = 16          # Upper bound of the color range
colormap = 'rainbow'  # Colormap name passed to add_mesh

# 2. Read data
print("Reading data...")
# Load atlas
atlas = nib.load(atlas_file)
atlas_data = atlas.get_fdata()

# Load degree values: one value per region, in label order (label 1 first).
# atleast_1d keeps a one-value file from becoming a 0-d array (len() fails on
# it); ravel() accepts the values written as a single column or a single row.
degrees = np.atleast_1d(np.loadtxt(degree_file)).ravel()
print(f"Atlas shape: {atlas_data.shape}")
print(f"Unique labels in atlas: {np.unique(atlas_data)}")
print(f"Number of degree values: {len(degrees)}")
print(f"Degree range: [{degrees.min()}, {degrees.max()}]")

# 3. Create value mapping
print("Creating value mapping...")
# Print original data for verification
print("\nDegree values:", degrees)

# Label k maps to degrees[k-1]. The number of labels is taken from the degree
# file instead of being hard-coded, so a shorter or longer file no longer
# raises IndexError or silently drops regions.
value_map = {label: degrees[label - 1] for label in range(1, len(degrees) + 1)}
value_map[0] = 0  # Background

# Print mapping relationships for verification
print("\nValue mapping:")
for label, value in sorted(value_map.items()):
    print(f"Label {label:2d} -> Degree {value:4.1f}")

# Atlas labels that have no degree value stay at 0 and vanish from the render;
# report them rather than dropping them silently.
unmapped = sorted({int(lbl) for lbl in np.unique(atlas_data)} - set(value_map))
if unmapped:
    print(f"Warning: atlas labels without a degree value (rendered as background): {unmapped}")

# 4. Replace atlas labels with degree values
print("\nMapping values...")
volume_data = np.zeros_like(atlas_data)
for label, value in value_map.items():
    mask = (atlas_data == label)
    count = int(np.count_nonzero(mask))
    volume_data[mask] = value
    print(f"Label {label:2d} -> Value {value:4.1f}: {count:8d} voxels")

# Verify results
print("\nVerification:")
print("Unique values in volume_data:", np.unique(volume_data))

# 5. Convert to PyVista format
print("Converting to PyVista format...")
grid = pv.ImageData()
grid.dimensions = atlas_data.shape
# get_zooms() can carry an extra (e.g. time) entry; only the 3 spatial
# spacings apply to the grid.
grid.spacing = atlas.header.get_zooms()[:3]
# Voxels become grid points; Fortran order matches VTK's x-fastest layout.
grid.point_data["values"] = volume_data.flatten(order="F")

# 6. Extract the labelled region
print("Creating surface mesh...")
unique_values = np.unique(volume_data)
unique_values = unique_values[unique_values > 0]  # Exclude background value
if unique_values.size == 0:
    raise RuntimeError("No voxels carry a positive degree value; nothing to render.")

# A single threshold over [smallest, largest] positive value selects the same
# cells as thresholding each value separately and merging, but emits each cell
# only once. PyVista keeps a cell when *any* of its points is in range, so the
# per-value approach duplicated every cell straddling two regions (and the
# +/-0.1 windows also captured neighbouring values within 0.1).
surface = grid.threshold([unique_values.min(), unique_values.max()], scalars="values")
if surface.n_cells == 0:
    raise RuntimeError("Thresholding produced an empty mesh; nothing to render.")

# 7. Generate images from different viewpoints
print("Generating views...")
os.makedirs(output_path, exist_ok=True)

# View name -> Plotter method that positions the camera for that view.
camera_setters = {'xy': 'view_xy', 'yz': 'view_yz', 'xz': 'view_xz'}
views = ['xy', 'yz', 'xz']
for view in views:
    print(f"Processing {view} view...")
    plotter = pv.Plotter(off_screen=True, window_size=[1000, 1000])
    try:
        # All scalar-bar settings go here: a later add_scalar_bar() with the
        # same title is ignored by PyVista, so settings passed that way
        # (e.g. width) never took effect.
        plotter.add_mesh(surface,
                         scalars="values",
                         clim=[min_val, max_val],
                         cmap=colormap,
                         show_scalar_bar=True,
                         scalar_bar_args={'title': 'Degree Values',
                                          'n_labels': 8,
                                          'fmt': '%.1f',
                                          'position_x': 0.8,
                                          'position_y': 0.3,
                                          'width': 0.1})
        plotter.set_background('white')
        getattr(plotter, camera_setters[view])()
        plotter.camera.zoom(1.5)

        output_file = os.path.join(output_path, f"sub_temp_{view}.png")
        plotter.screenshot(output_file)
    finally:
        # Release the off-screen render window even if rendering fails.
        plotter.close()

    print(f"Saved {view} view to {output_file}")

print("Done! All views have been generated.")