In [8]:
import pickle
import numpy as np
import time,sys,glob,os
import pandas as pd
import scipy.ndimage as ndi
import cc3d
import cv2
import torch
import open3d as o3d

from skimage import color, morphology, measure
from skimage.transform import downscale_local_mean
from skimage.registration import phase_cross_correlation

from scipy.stats import zscore
from scipy import sparse

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from functions.cellregister import *
from functions.iterive_non_rigid import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
def load_and_extract_matches(lut_path, device):
    """Load LUT and extract matched coordinates and images with GPU support"""
    with open(lut_path, 'rb') as f:
        lookup_table = pickle.load(f)
    
    # Get images and convert to torch tensors
    invivo_image = torch.tensor(lookup_table['in_vivo']['Transformed'], 
                              dtype=torch.float32).to(device)
    exvivo_image = torch.tensor(lookup_table['exvivo_GCAMP']['Transformed'], 
                               dtype=torch.float32).to(device)
    
    # Extract matched coordinates
    in_vivo_coords = []
    ex_vivo_coords = []
    
    for ex_cell in lookup_table['exvivo_GCAMP']['cells']:
        if 'in_vivo_id' in ex_cell:
            in_vivo_cell = next(cell for cell in lookup_table['in_vivo']['cells'] 
                               if cell['id'] == ex_cell['in_vivo_id'])
            in_vivo_coords.append(in_vivo_cell['coordinates'])
            ex_vivo_coords.append(ex_cell['coordinates'])
    
    # Convert coordinates to torch tensors
    in_vivo_coords = torch.tensor(in_vivo_coords, dtype=torch.float32).to(device)
    ex_vivo_coords = torch.tensor(ex_vivo_coords, dtype=torch.float32).to(device)
    
    return in_vivo_coords, ex_vivo_coords, invivo_image, exvivo_image

In [4]:
def compute_rigid_transform(source_points, target_points):
    """Compute rigid transformation using GPU"""
    # Center the point sets
    source_centroid = torch.mean(source_points, dim=0)
    target_centroid = torch.mean(target_points, dim=0)
    
    centered_source = source_points - source_centroid
    centered_target = target_points - target_centroid
    
    # Compute optimal rotation
    H = centered_source.T @ centered_target
    U, _, V = torch.linalg.svd(H)
    R = V.T @ U.T
    
    # Ensure right-handed coordinate system
    if torch.linalg.det(R) < 0:
        V[2, :] *= -1
        R = V.T @ U.T
    
    # Compute scale
    scale = torch.sum(centered_target * (R @ centered_source.T).T) / torch.sum(centered_source * centered_source)
    
    # Compute translation
    t = target_centroid - scale * (R @ source_centroid)
    
    return R, t, scale


In [5]:
def apply_rigid_transform(image, target_image, R, t, scale=1.0):
    """Apply rigid transformation using GPU"""
    # Convert inputs to torch tensors if they aren't already
    if not torch.is_tensor(image):
        image = torch.tensor(image, dtype=torch.float32).to(R.device)
    if not torch.is_tensor(target_image):
        target_image = torch.tensor(target_image, dtype=torch.float32).to(R.device)
    
    # Create affine grid
    batch_size = 1
    depth, height, width = target_image.shape
    
    # Create 4x4 transformation matrix
    transform_matrix = torch.eye(4, device=R.device)
    transform_matrix[:3, :3] = scale * R
    transform_matrix[:3, 3] = t
    
    # Create normalized coordinate grid
    grid_d, grid_h, grid_w = torch.meshgrid(
        torch.linspace(-1, 1, depth, device=R.device),
        torch.linspace(-1, 1, height, device=R.device),
        torch.linspace(-1, 1, width, device=R.device)
    )
    
    grid = torch.stack([grid_w, grid_h, grid_d, torch.ones_like(grid_w)], dim=-1)
    grid = grid.reshape(-1, 4).T
    
    # Apply transformation
    transformed_grid = transform_matrix @ grid
    transformed_grid = transformed_grid[:3].T.reshape(depth, height, width, 3)
    
    # Normalize coordinates back to [-1, 1]
    transformed_grid[..., 0] = transformed_grid[..., 0] * 2 / (width - 1) - 1
    transformed_grid[..., 1] = transformed_grid[..., 1] * 2 / (height - 1) - 1
    transformed_grid[..., 2] = transformed_grid[..., 2] * 2 / (depth - 1) - 1
    
    # Use grid_sample for interpolation
    transformed = torch.nn.functional.grid_sample(
        image.unsqueeze(0).unsqueeze(0),
        transformed_grid.unsqueeze(0),
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True
    )
    
    return transformed.squeeze()

def learn_and_apply_deformable(source_img, target_img, vec_ds=3, device='cuda'):
    """GPU-accelerated non-rigid registration"""
    ncc_list = []
    vec_field_smooth_list = []
    current_img = source_img.clone()
    
    # Convert to torch tensors if not already
    if not torch.is_tensor(source_img):
        source_img = torch.tensor(source_img, dtype=torch.float32).to(device)
    if not torch.is_tensor(target_img):
        target_img = torch.tensor(target_img, dtype=torch.float32).to(device)
    
    for i in range(30):
        if i % 5 == 0:
            print(f'iteration {i}...')
        
        # Move to CPU for phase correlation
        current_cpu = current_img.cpu().numpy()
        target_cpu = target_img.cpu().numpy()
        
        # Compute displacement field
        shifts = torch.tensor(
            phase_cross_correlation(current_cpu, target_cpu, 
                                  upsample_factor=vec_ds)[0],
            device=device
        )
        
        # Smooth displacement field
        vec_field_smooth = torch.tensor(
            gaussian_filter(shifts.cpu().numpy(), sigma=2),
            device=device
        )
        vec_field_smooth_list.append(vec_field_smooth)
        
        # Apply displacement
        for d in range(3):
            current_img = torch.roll(current_img, 
                                   shifts=int(vec_field_smooth[d].item()), 
                                   dims=d)
        
        # Compute correlation
        ncc = torch.corrcoef(current_img.flatten(), 
                            target_img.flatten())[0,1]
        ncc_list.append(ncc.item())
    
    return current_img, source_img, ncc_list, vec_field_smooth_list
    return current_img, source_img, ncc_list, vec_field_smooth_list

In [6]:
def main(lut_path):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Load data and process
    in_vivo_coords, ex_vivo_coords, invivo_image, exvivo_image = load_and_extract_matches(lut_path, device)
    
    # Compute and apply transformations
    R, t, scale = compute_rigid_transform(in_vivo_coords, ex_vivo_coords)
    rigid_transformed = apply_rigid_transform(invivo_image, exvivo_image, R, t, scale)
    
    # Non-rigid registration
    final_image, _, ncc_list, _ = learn_and_apply_deformable(rigid_transformed, exvivo_image, device=device)
    
    # Move results back to CPU for visualization
    final_image = final_image.cpu().numpy()
    R = R.cpu().numpy()
    t = t.cpu().numpy()
    
    return R, t, scale, final_image

In [9]:
if __name__ == "__main__":
    lut_path = "/scratch/jl10897/Automatic_Registration/LUT_multimodal_487_Region2.pkl"
    R, t, scale, registered_image = main(lut_path)

Using device: cuda


FileNotFoundError: [Errno 2] No such file or directory: 'LUT_multimodal_487_Region2.pkl'