In [6]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import imageio
import os

In [8]:
# Load the one-hot encoded segmentation volume; the author's stated layout
# is (C, D, H, W), i.e. channels first.
# NOTE(review): this is a machine-specific absolute path; adjust it (or move
# it to a config cell) before running elsewhere.
path = '/shared/scratch/0/home/v_nishchay_nilabh/oasis_data/scans/OASIS_OAS1_0001_MR1/seg4_onehot.npy'
seg = np.load(path)

# The cell below takes argmax over axis 0 and slices the three remaining
# axes, so fail fast here on anything that is not a 4-D array.
assert seg.ndim == 4, f"expected a 4-D (C, D, H, W) array, got shape {seg.shape}"

# The recorded run prints (5, 256, 256, 256): five channels, not four.
print("Shape:", seg.shape)

Shape: (5, 256, 256, 256)


In [55]:
# Render every slice of the segmentation along one axis and stitch the
# rendered slices into an animated GIF.

# Collapse the one-hot channel axis into one integer label per voxel.
label_map = np.argmax(seg, axis=0)

# Visualization settings
axis = 2  # slicing axis; author's note: 0 = axial, 1 = coronal, 2 = sagittal
colors = ['white', 'red', 'green', 'blue', 'yellow']
cmap = ListedColormap(colors[:seg.shape[0]])

# Half-integer limits put each integer label k in the middle of colour bin k
# for any number of colours. A fixed vmax=4 only lines up when the colormap
# has four or five entries; with fewer, distinct labels share a bin.
# Labels >= cmap.N (more classes than listed colours) clip to the last colour.
color_limits = {"vmin": -0.5, "vmax": cmap.N - 0.5}

gif_frames = []
temp_dir = "tmp_slices"
os.makedirs(temp_dir, exist_ok=True)

for idx in range(label_map.shape[axis]):
    # np.take selects slice `idx` along `axis`, replacing the per-axis
    # if/elif ladder and treating negative axes consistently.
    slice_2d = np.take(label_map, idx, axis=axis)

    # Save current slice as image
    temp_img_path = os.path.join(temp_dir, f"frame_{idx:03d}.png")
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(slice_2d, cmap=cmap, interpolation='none', **color_limits)
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(temp_img_path, bbox_inches='tight', pad_inches=0)
    # Close explicitly so one figure per slice does not pile up in memory.
    plt.close(fig)

    # Read the rendered PNG back as an array and queue it as a GIF frame.
    gif_frames.append(imageio.v2.imread(temp_img_path))

# Save as GIF
gif_path = f"./images/segmentation_axis{axis}_slices.gif"
# The output directory is not created anywhere else in this notebook, so
# make sure it exists before writing.
os.makedirs(os.path.dirname(gif_path), exist_ok=True)
# NOTE(review): the unit of `duration` appears to depend on the imageio
# version/plugin (seconds in older releases, milliseconds in newer
# Pillow-based ones) -- confirm against the installed version.
imageio.mimsave(gif_path, gif_frames, duration=0.08)

print(f"Saved GIF: {gif_path}")

Saved GIF: ./images/segmentation_axis2_slices.gif
