File Structure:

Original Image: ./assets/images/sagittal_mouse.nii

Organs: ./data
- Contains .nii files for organ voxels, and .obj files for meshes

Annotations: ./annotations

Import Libraries

In [1]:
from os.path import join
from os import listdir
import os

import json
import numpy as np
from skimage.measure import marching_cubes
import cv2
import nibabel as nib
import meshio

import plotly.graph_objects as go

Define global variables

In [2]:
WORKDIR = os.getcwd()
# Every project path below is resolved from the parent of the current working directory.
ANNOTATION_FOLDER = os.path.join(WORKDIR, os.pardir, "annotations")
MOUSE_NIFTI_FILE = os.path.join(WORKDIR, os.pardir, "assets", "images", "sagittal_mouse.nii")
MESHES_AND_VOXELS_FOLDER = os.path.join(WORKDIR, os.pardir, "data")

Define the Mesh class and mesh-processing helper functions

In [3]:
class Mesh:
    """Container for a triangle mesh: vertex coordinates plus triangle face indices."""

    def __init__(self, vertices: np.ndarray, faces: np.ndarray) -> None:
        """
        Parameters
        vertices: input vertices from marching_cubes
        faces: input faces from marching_cubes
        """
        self.vertices = vertices
        self.faces = faces

    def saveMesh(self, file_path: str) -> bool:
        """
        Write the mesh to ``file_path`` in Wavefront OBJ format.

        Parameters
        file_path: path to save the mesh to
        Returns
        True on success, False on failure (the reason is printed)
        """
        try:
            mesh = meshio.Mesh(self.vertices, {"triangle": self.faces})
            meshio.write(file_path, mesh, file_format="obj")
        except Exception as err:
            # Saving stays best-effort, but the failure is reported instead of hidden.
            # Catching Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            print(f"Failed to save mesh to {file_path}: {err}")
            return False
        return True

    def getVertices(self) -> np.ndarray:
        """Return the vertex coordinate array."""
        return self.vertices

    def getFaces(self) -> np.ndarray:
        """Return the triangle face index array."""
        return self.faces

def generate_mesh_from_voxels(voxels: np.ndarray, *, threshold: float = None, step_size: int = 1, file_path: str = None) -> Mesh:
    """
    Extract an iso-surface mesh from a voxel volume with marching cubes.

    Parameters:
    voxels: 3D array of voxels (from nifti, or generated)
    threshold: iso-level to include from voxels in volume mesh. if not set, threshold is 95% of the maximum value
    step_size: step size for creating vertices
    file_path: file path to save .obj file to. if not set, no file is saved

    Returns:
    Mesh: mesh object of vertices and faces
    """
    if threshold is None:
        # Keep the fractional part. Truncating with int() collapses the level to 0
        # for volumes whose values lie in [0, 1), such as normalised float data.
        threshold = 0.95 * float(np.max(voxels))
    vertices, faces, _, _ = marching_cubes(voxels, level=threshold, step_size=step_size)
    mesh = Mesh(vertices, faces)
    if file_path is not None and not mesh.saveMesh(file_path):
        # saveMesh reports failure only through its return value; surface it.
        import warnings
        warnings.warn(f"Could not save mesh to {file_path}", stacklevel=2)

    return mesh

def read_mesh_from_file(file_path: str) -> Mesh:
    """
    Load a triangle mesh from disk.

    Parameters:
    file_path: .obj file path to read mesh from

    Returns:
    Mesh: mesh object built from the file's points and its "triangle" cells
    """
    loaded = meshio.read(file_path)
    return Mesh(loaded.points, loaded.cells_dict["triangle"])

def visualize_meshes(meshes: list) -> go.Figure:
    """
    Render meshes together in a single interactive Plotly figure.

    Parameters:
    meshes: list of all Meshes you want to render (one Mesh3d trace each)

    Returns:
    figure: Plotly GO figure (already displayed via show())
    """
    def _as_trace(mesh):
        # Split the coordinate and index arrays into the column vectors Mesh3d expects.
        verts = mesh.getVertices()
        tris = mesh.getFaces()
        return go.Mesh3d(
            x=verts[:, 0],
            y=verts[:, 1],
            z=verts[:, 2],
            i=tris[:, 0],
            j=tris[:, 1],
            k=tris[:, 2],
        )

    figure = go.Figure(data=[_as_trace(m) for m in meshes])
    figure.show()
    return figure

Define Atlas and Organ objects

In [4]:
class Atlas:
    """
    Build per-organ voxel maps and meshes from slice annotation files.

    The constructor does the whole pipeline in five steps:
    1. Count the annotations in every file of ``annotationDirectory``.
    2. Calibrate the slice depth from ``calibrationAnnotation``.
    3. Read the NIfTI image header to size the voxel canvas.
    4. Rasterise every organ annotation.
    5. Build each organ's voxel map (and mesh).
    """

    def __init__(self, annotationDirectory, calibrationAnnotation, save=False, niftiPath=None) -> None:
        """
        Parameters:
        annotationDirectory: folder containing the annotation JSON files
        calibrationAnnotation: path of the calibration JSON file
        save: forwarded to Organ.constructVoxelMap (writes .nii/.obj outputs)
        niftiPath: NIfTI image whose shape/affine define the voxel canvas;
            defaults to assets/images/sagittal_mouse.nii (relative to the cwd)
        """
        # TODO: Get directory length
        self.calibration: float = None
        self.calibrationFile: str = None
        self.xOffset: int = None
        self.yOffset: int = None
        self.zOffset: int = None

        self.atlasDims: tuple = None
        self.affine = None

        self.organs = {}

        self.annotationDirectory = annotationDirectory

        if niftiPath is None:
            niftiPath = join("assets", "images", "sagittal_mouse.nii")
        # Only this image is needed. An earlier hard-coded sample-CT load was dead
        # code that crashed when the sample file was missing.
        ct = nib.load(niftiPath)

        annFiles = listdir(annotationDirectory)
        # NOTE: every file in the directory is counted, including the calibration
        # file when it lives here; the tuned offsets in Organ assume this count.
        numSlices = sum(len(self._loadJson(join(self.annotationDirectory, file))) for file in annFiles)

        self.calibrateDepth(calibrationAnnotation, numSlices)
        self.calibrateImgSize(ct)
        self.constructAtlasFromList(annFiles, numSlices)
        self.constructImgVoxels(save=save)

    @staticmethod
    def _loadJson(path):
        """Parse the JSON file at ``path``, closing the file handle afterwards."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def constructAtlasFromList(self, fileList, numSlices):
        """
        Rasterise every organ annotation in ``fileList`` into its Organ object.

        Organs are created on first sight of their entity name. The calibration
        file is skipped. Annotations whose entities cannot be processed are
        reported and skipped.
        """
        if self.calibration is None:
            print("Requires calibration")
            return
        calibrationFile = getattr(self, "calibrationFile", None)
        # Compare normalized paths so equivalent spellings of the calibration path still match.
        calibrationPath = os.path.normpath(calibrationFile) if calibrationFile else None
        for file in fileList:
            path = join(self.annotationDirectory, file)
            if calibrationPath is not None and os.path.normpath(path) == calibrationPath:
                print("Calibration file. Skipping...")
                continue
            print(f"Reading file {file}")
            for annotation in self._loadJson(path):
                fname = annotation["documents"][0]["name"]
                # Slice index is taken from the document name, e.g. "rat12.png" -> 12.
                index = int(fname.split('.')[0].replace("rat", ""))
                try:
                    entities = annotation["annotation"]["annotationGroups"][0]["annotationEntities"]
                    for entity in entities:
                        name = entity["name"]
                        organ = self.organs.get(name)
                        if organ is None:
                            organ = Organ(name,
                                          numSlices,
                                          self.calibration,
                                          self.atlasDims,
                                          self.affine)
                            self.organs[name] = organ
                        organ.appendOrganSlice(index, entity)
                except Exception as err:
                    # Best-effort: report the problematic annotation and keep going.
                    print(f"Skipping annotation ({type(err).__name__}: {err})")
                    print(annotation)

    def calibrateDepth(self, calibrationAnnotation, numSlices):
        """
        Derive the per-slice depth from the calibration annotation.

        The first entity of the first annotation is read. The spread
        (max - min) of the second coordinate of its first segment's points is
        divided by ``numSlices``.

        Returns the calibration value, which is also stored on ``self.calibration``.
        Raises ValueError when ``numSlices`` is not positive.
        """
        if numSlices <= 0:
            raise ValueError(f"numSlices must be positive to calibrate depth, got {numSlices}")
        with open(calibrationAnnotation, encoding="utf-8") as f:
            annotation = json.load(f)[0]
        body = annotation["annotation"]["annotationGroups"][0]["annotationEntities"][0]
        segment = body["annotationBlocks"][0]["annotations"][0]["segments"][0]
        domain = [point[1] for point in segment]
        diff = max(domain) - min(domain)
        self.calibration = float(diff) / float(numSlices)
        self.calibrationFile = calibrationAnnotation
        return self.calibration

    def calibrateImgSize(self, inputNifti):
        """
        Record the voxel-grid shape and affine of the reference image.

        Parameters:
            inputNifti: nibabel.nifti1.Nifti1Image
        """
        # TODO: use nifi image size to calibrate canvas size for voxelclouds
        # The header already knows the shape; get_fdata() would decode the whole
        # volume into float64 just to read it.
        self.atlasDims = inputNifti.shape
        self.affine = inputNifti.affine
        print(self.affine, type(self.affine))

    def constructImgVoxels(self, save=False):
        """Build the voxel map (and mesh) of every organ collected so far."""
        for organName, organ in self.organs.items():
            if organ is None:
                continue
            print(f"Constructing voxel map for {organName}")
            organ.constructVoxelMap(save)
            print("Done")

class Organ:
    """
    Voxel map and mesh of a single organ, reconstructed from 2D slice annotations.

    Polygons are rasterised into ``imageSlices``, one uint8 image per annotation
    slice. ``constructVoxelMap`` interpolates between neighbouring slices to fill
    ``voxelCloud`` and then extracts a mesh with marching cubes.
    """

    # Gaussian kernel used for every blur. The original calls passed
    # cv2.BORDER_DEFAULT (== 4) as the third positional argument, which OpenCV
    # reads as sigmaX, so sigma 4 is kept explicitly to preserve the output.
    _BLUR_KSIZE = (27, 27)
    _BLUR_SIGMA = 4
    # Blended slice intensity above which a voxel is marked as inside the organ.
    _INSIDE_LEVEL = 196

    def __init__(self, name, numSlices, depth, dims, affine) -> None:
        """
        Parameters:
        name: organ name (also used for output file names under ./data)
        numSlices: number of annotation slices to allocate
        depth: voxel layers per annotation slice (calibration value)
        dims: shape of the voxel canvas; dims[1] and dims[2] size each slice image
        affine: affine stored in the saved NIfTI image
        """
        self.name = name
        self.numSlices = numSlices
        # One independent list per slice. ``[[]] * numSlices`` would alias a single
        # list, so every append would land in every slice.
        self.slices = [[] for _ in range(numSlices)]
        self.scale = 1.0
        self.depth = depth * self.scale
        self.affine = affine
        # offset: [offset_x, offset_y, offset_z]
        # x/y are subtracted from annotation point coordinates; z is subtracted
        # from the voxel layer index. Trailing comments are earlier tuning values.
        self.offset = {
            "x": 250, # 124, # OFFSET_X - 36 * 6, # 424.2
            "y": 173, # 45, # OFFSET_Y - 80 * 6, # 33
            "z": 385 # 50 # 65
        }
        imgSliceDims = (numSlices, dims[1], dims[2])
        self.imageSlices: np.ndarray = np.zeros(imgSliceDims, dtype=np.uint8)
        self.voxelCloud: np.ndarray = np.zeros(dims, dtype=np.uint8)
        # Filled in by generateMesh.
        self.vertices = None
        self.faces = None

    @classmethod
    def _blur(cls, img):
        """Apply the shared Gaussian blur to a 2D image."""
        return cv2.GaussianBlur(img, cls._BLUR_KSIZE, sigmaX=cls._BLUR_SIGMA)

    @staticmethod
    def _slicePosition(layer, depth):
        """
        Map a voxel layer (already shifted by offset z) to its annotation slices.

        Returns (ind0, alpha). ``ind0`` is the lower slice index. ``alpha`` is the
        fractional distance from slice ``ind0`` towards ``ind0 + 1``. For integer
        depth this equals the former ``(layer % depth) / depth``. It also stays
        correct for non-integer depth, including depth < 0.5, where the modulo
        form divided by zero.
        """
        index = float(layer) / depth
        ind0 = int(np.floor(index))
        alpha = (layer - ind0 * depth) / depth
        return ind0, alpha

    def appendOrganSlice(self, index, entity):
        """
        Rasterise the polygons of one annotation entity into slice ``index``.

        Each point [a, b] is truncated to int and then mapped to
        [scale*b - offset_y, scale*a - offset_x] (coordinate swap plus offset).
        """
        shift = np.array([self.offset['y'], self.offset['x']])
        for polygon in entity["annotationBlocks"][0]["annotations"]:
            rawPts = np.array(polygon["segments"][0]).astype(int).reshape(-1, 2)
            if len(rawPts) == 0:
                continue  # nothing to fill
            # cv2.fillPoly only accepts 32-bit integer points; numpy's default int
            # is 64-bit on most platforms. astype truncates toward zero, like the
            # former per-point assignment into an int array.
            polyPts = (self.scale * rawPts[:, ::-1] - shift).astype(np.int32)
            self.slices[index].append(polyPts)
            cv2.fillPoly(self.imageSlices[index], pts=[polyPts], color=(255, 255, 255))

    def constructVoxelMap(self, save=False):
        """
        Fill voxelCloud by blending neighbouring blurred annotation slices.

        After filling, the volume is reoriented and smoothed, and the mesh is
        built. When ``save`` is true, the .nii/.obj outputs are written to ./data
        and the large arrays are released.
        """
        # Blur each annotation slice once. The same two slices used to be
        # re-blurred for every voxel layer, producing identical images.
        blurred = [self._blur(img) for img in self.imageSlices]
        for z in range(self.voxelCloud.shape[0]):
            ind0, alpha = self._slicePosition(z - self.offset['z'], self.depth)
            if ind0 < 0:
                continue
            if ind0 + 1 >= len(blurred):
                break
            blended = blurred[ind0] * (1.0 - alpha) + blurred[ind0 + 1] * alpha
            self.voxelCloud[z][blended > self._INSIDE_LEVEL] = 255

        self.customCalibration()
        self.generateMesh(save)

        if save:
            img = nib.Nifti1Image(self.voxelCloud, self.affine)
            nib.save(img, f"./data/{self.name}.nii")
            print(f"Saved {self.name} image at ./data/{self.name}.nii")
            # Release the large arrays once the outputs are on disk.
            del self.imageSlices
            del self.slices
            del self.voxelCloud

    def customCalibration(self):
        """Reorient voxelCloud (swap axes 0 and 1, flip both) and blur each resulting slice."""
        self.voxelCloud = np.swapaxes(self.voxelCloud, 0, 1)
        self.voxelCloud = self.voxelCloud[::-1, ::-1, ::]

        # smooth between slices
        for i, img in enumerate(self.voxelCloud):
            self.voxelCloud[i] = self._blur(img)

    def generateMesh(self, save=False, *, threshold=50, step_size=3):
        """
        Run marching cubes on voxelCloud and store vertices/faces.

        ``threshold`` and ``step_size`` default to the previously hard-coded
        values. When ``save`` is true, the mesh is written via saveMesh.
        """
        print("Getting mesh")
        vertices, faces, _, _ = marching_cubes(self.voxelCloud, level=threshold, step_size=step_size)
        self.vertices = vertices
        self.faces = faces
        if save:
            self.saveMesh()

    def getMesh(self):
        """Return (vertices, faces); both are None until generateMesh has run."""
        return self.vertices, self.faces

    def saveMesh(self):
        """Write the mesh to ./data/<name>.obj."""
        mesh = meshio.Mesh(self.vertices, {"triangle": self.faces})
        meshio.write(f"./data/{self.name}.obj", mesh, file_format="obj")

Generate the atlas (organ voxel maps and meshes) if it doesn't exist yet

In [5]:
# Build the atlas (and save its organ voxels/meshes) only when no output exists yet.
# NOTE(review): this checks MESHES_AND_VOXELS_FOLDER (../data from the cwd),
# while Organ writes its outputs to ./data. Confirm both refer to the same folder.
data = listdir(MESHES_AND_VOXELS_FOLDER)
if not data:
    atlas = Atlas(
        join("assets", "annotations"),
        join("assets", "annotations", "calibrate.json"),
        save=True,
    )

Load the sagittal mouse NIfTI image, downsample it by a factor of 4, and save the result

In [9]:
from scipy.ndimage import zoom

# Downsampling factor applied along every axis of the volume.
RESCALE_FACTOR = 0.25

mouse_data = nib.load(MOUSE_NIFTI_FILE).get_fdata()

print("Image Loaded")

original_shape = np.array(mouse_data.shape)

rescaled_data = zoom(mouse_data, RESCALE_FACTOR, order=1)
# scipy's zoom rounds each output dimension, so integer-dividing the original
# shape can under-report it. Take the shape from the actual result instead.
new_shape = np.array(rescaled_data.shape)

print("Rescaled image size: ", new_shape)

# NOTE(review): the identity affine drops the source image's orientation and
# voxel spacing. Confirm downstream consumers do not rely on them.
rescaled_image = nib.Nifti1Image(rescaled_data, np.eye(4))

print("Saving Image")

# Save next to the source image, resolved the same way as MOUSE_NIFTI_FILE.
nib.save(rescaled_image, join(os.path.dirname(MOUSE_NIFTI_FILE), "scaled_mouse.nii"))

# Drop the large arrays (note: rescaled_image still references the rescaled data).
del mouse_data
del rescaled_data
