In [1]:
import SimpleITK as sitk
from skimage import filters

import pyvips
# from skimage.transform import resize, rescale
import PIL
from PIL import Image
import imutils
from scipy import ndimage
from scipy.ndimage import shift
from skimage.transform import rescale

import cv2

from glob import glob

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc

import gc
import sys
import time
import os
import pathlib
from pathlib import Path
import warnings
from scipy.signal import fftconvolve

import tifffile

import matplotlib as mpl
# import matplotlib.pyplot as plt
# import matplotlib.colors
mpl.rcParams['pdf.fonttype'] = 42

from scipy.signal import find_peaks
sc._settings.settings._vector_friendly=True

from IPython.display import clear_output

In [10]:
def check_core(cid, base_dir='/common/lamt2/HPV/data/xenium/alignment_v2'):
    """Display a core's saved alignment overlay and ask the user what to do with it.

    The image `{base_dir}/aligned_cores/images/{cid}_aligned.png` is shown and
    the user is prompted. Answering 'reload' (or 'r') re-reads the file from
    disk and shows it again, so an image rewritten in the meantime is picked up.

    Parameters
    ----------
    cid : str
        Core identifier; used in the image file name, the plot title and the prompt.
    base_dir : str
        Root directory that contains `aligned_cores/images/`.

    Returns
    -------
    str
        The user's final answer other than 'reload'/'r'. An empty answer
        becomes 'y'; any other text ('n', 'quit', 'q', ...) is returned as typed.
    """
    img_path = f'{base_dir}/aligned_cores/images/{cid}_aligned.png'
    prompt = f'Core {cid} alignment: [y]/n or reload/quit'

    def show_image():
        # Every display (first and reloads) is titled identically; the context
        # manager releases the file handle once imshow has taken the pixels.
        with Image.open(img_path) as img:
            plt.imshow(img)
        plt.title(f'{cid}')
        plt.show()
        plt.close()

    show_image()
    do_alignment = input(prompt) or 'y'
    while do_alignment in ('reload', 'r'):
        show_image()
        do_alignment = input(prompt) or 'y'

    clear_output(wait=False)
    return do_alignment

In [3]:
# Progress callback for SimpleITK registration iterations
def command_iteration(method):
    """Report registration progress; intended as an iteration-event callback.

    At iteration 0 the optimizer's estimated scales are printed. On every
    iteration divisible by 10 (0 included) one line is printed:
    the iteration number, the metric value and the optimizer position.
    """
    iteration = method.GetOptimizerIteration()
    if iteration == 0:
        print("Estimated Scales: ", method.GetOptimizerScales())
    if iteration % 10:
        return
    metric = method.GetMetricValue()
    position = method.GetOptimizerPosition()
    print(f"{iteration:3} = {metric:7.5f} : {position}")

# Function to do image alignment and save transformation
def alignimgs(fixed, moving, savepth_md, snm='', verbose=True, init_translation=(0, 0)):
    """Register `moving` onto `fixed` with a 2-D similarity transform and return an overlay.

    Registration setup: correlation metric, regular-step gradient descent
    (learning rate 2.0, min step 1e-4, at most 400 iterations, gradient
    magnitude tolerance 1e-8), optimizer scales from index shift, linear
    interpolation. The initial transform comes from
    CenteredTransformInitializer; its translation is then replaced by
    `init_translation`.

    Parameters
    ----------
    fixed, moving : SimpleITK.Image
        2-D images to register.
    savepth_md : str
        Directory in which the fitted transform is written when `snm` is non-empty.
    snm : str
        If non-empty, the transform is saved to `{savepth_md}/tfm_{snm}.hdf`.
    verbose : bool
        If True, print optimizer progress every 10 iterations and a final summary.
    init_translation : sequence of two numbers
        Initial (x, y) translation, in the images' physical units.

    Returns
    -------
    numpy.ndarray
        Array of shape (rows, cols, 2): component 0 is `fixed` and component 1
        is `moving` resampled through the fitted transform. Both are
        intensity-rescaled and cast to uint8.
    """
    R = sitk.ImageRegistrationMethod()
    R.SetMetricAsCorrelation()
    R.SetOptimizerAsRegularStepGradientDescent(learningRate=2.0,
                                               minStep=1e-4,
                                               numberOfIterations=400,
                                               gradientMagnitudeTolerance=1e-8)
    R.SetOptimizerScalesFromIndexShift()

    tx = sitk.CenteredTransformInitializer(fixed, moving, sitk.Similarity2DTransform())
    # Keep the initializer's rotation centre but start from the caller's translation guess.
    tx.SetTranslation(list(init_translation))
    R.SetInitialTransform(tx)
    R.SetInterpolator(sitk.sitkLinear)
    # Per-iteration progress is part of the verbose output, not unconditional.
    if verbose:
        R.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(R))

    outTx = R.Execute(fixed, moving)

    # If a save name is given, persist the transform in savepth_md.
    if snm:
        savenm = f'{savepth_md}/tfm_{snm}.hdf'
        sitk.WriteTransform(outTx, savenm)
        print('saved: ', savenm)

    if verbose:
        print("-------")
        print(outTx)
        print("Optimizer stop condition: {0}".format(R.GetOptimizerStopConditionDescription()))
        print(" Iteration: {0}".format(R.GetOptimizerIteration()))
        print(" Metric value: {0}".format(R.GetMetricValue()))

    # Resample moving onto the fixed grid and stack both as a 2-component overlay.
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(outTx)

    out = resampler.Execute(moving)
    simg1 = sitk.Cast(sitk.RescaleIntensity(fixed), sitk.sitkUInt8)
    simg2 = sitk.Cast(sitk.RescaleIntensity(out), sitk.sitkUInt8)
    cimg = sitk.Compose(simg1, simg2)

    return sitk.GetArrayFromImage(cimg)


In [12]:
def _invert_to_255(arr):
    """Return 255 - arr * (255 / arr.max()): scale so the maximum is 255, then invert.

    An array whose maximum is not positive (e.g. all zeros) would otherwise
    divide by zero and become NaN; it is mapped to a constant 255 instead.
    """
    peak = arr.max()
    scale = 255 / peak if peak > 0 else 0
    return 255 - scale * arr


def _mask_to_window(arr, x0, x1, y0, y1):
    """Return a same-shape copy of `arr` keeping only arr[y0:y1, x0:x1]; zero elsewhere."""
    masked = np.zeros_like(arr)
    masked[y0:y1, x0:x1] = arr[y0:y1, x0:x1]
    return masked


def _show_pair(left, right, left_title, right_title, figsize):
    """Show two arrays side by side: left with cmap 'Reds', right with cmap 'Blues'."""
    fig, ax = plt.subplots(1, 2, figsize=figsize)
    ax[0].imshow(left, cmap='Reds')
    ax[0].set_title(left_title)
    ax[1].imshow(right, cmap='Blues')
    ax[1].set_title(right_title)
    plt.show()
    plt.close(fig)


def _interactive_threshold(arrf, arrm):
    """Let the user choose intensity thresholds; return (arrf_filtered, arrm_filtered).

    Shows intensity histograms and Otsu masks first. The Otsu values become the
    default answer. Pixels <= threshold are set to 0; with saturate=1, pixels
    above it are set to 255. Inputs are not modified.
    """
    fig, ax = plt.subplots(1, 2, figsize=(7, 3))
    bins = [5 * x for x in range((255 // 5) + 1)]
    for axis, arr, title in ((ax[0], arrf, 'dapi pixel intensity'),
                             (ax[1], arrm, 'he pixel intensity')):
        axis.hist(arr.flatten(), bins=bins)
        axis.set_title(title)
        axis.set_yscale('log')
    plt.show()
    plt.close(fig)

    thresh_f = filters.threshold_otsu(arrf)
    thresh_m = filters.threshold_otsu(arrm)
    _show_pair(arrf > thresh_f, arrm > thresh_m,
               f'dapi, threshold + saturation {thresh_f:.2f}',
               f'he, threshold + saturation {thresh_m:.2f}', figsize=(8, 4))

    pass_thresh = 'n'
    while pass_thresh != 'y':
        # Answer format: "fixed, moving[, saturate]"; default re-uses the last thresholds.
        thresholds = input('Input thresholds: (fixed, moving, saturate)/[default]') or f'{thresh_f}, {thresh_m}, 0'
        parts = thresholds.split(',')
        saturate = bool(int(parts[-1])) if len(parts) > 2 else False
        thresh_f, thresh_m = float(parts[0]), float(parts[1])

        arrf_filtered = arrf.copy()
        arrm_filtered = arrm.copy()

        print(f'fixed threshold: {thresh_f}')
        arrf_filtered[arrf_filtered <= thresh_f] = 0
        if saturate:
            arrf_filtered[arrf_filtered > thresh_f] = 255

        print(f'moving threshold: {thresh_m}')
        arrm_filtered[arrm_filtered <= thresh_m] = 0
        if saturate:
            arrm_filtered[arrm_filtered > thresh_m] = 255

        _show_pair(arrf_filtered, arrm_filtered, 'filtered dapi', 'filtered he', figsize=(10, 15))
        pass_thresh = input('Passes threshold: [y]/n') or 'y'

    return arrf_filtered, arrm_filtered


def _show_crop_preview(arrf, arrm, arrf_filtered, arrm_filtered, x0, x1, y0, y1):
    """Show raw and filtered DAPI/H&E restricted to the window [y0:y1, x0:x1]."""
    fig, ax = plt.subplots(2, 2, figsize=(8, 8))
    panels = ((ax[0][0], arrf, 'Reds', 'raw dapi'),
              (ax[0][1], arrm, 'Blues', 'raw he'),
              (ax[1][0], arrf_filtered, 'Reds', 'filtered dapi'),
              (ax[1][1], arrm_filtered, 'Blues', 'filtered he'))
    for axis, arr, cmap, title in panels:
        axis.imshow(arr[y0:y1, x0:x1], cmap=cmap)
        axis.set_title(title)
    fig.suptitle(f'{x0}, {x1}, {y0}, {y1}')
    plt.show()
    plt.close(fig)


def _interactive_crop(arrf, arrm, arrf_filtered, arrm_filtered):
    """Ask for a crop window until accepted; return (x0, x1, y0, y1), or None on quit."""
    x0, x1, y0, y1 = 0, arrf_filtered.shape[1], 0, arrf_filtered.shape[0]
    cropped = 'n'
    while cropped != 'y':
        coords = input("Input cropping coords:") or f'{x0}, {x1}, {y0}, {y1}'
        x0, x1, y0, y1 = [int(a) for a in coords.split(',')]
        _show_crop_preview(arrf, arrm, arrf_filtered, arrm_filtered, x0, x1, y0, y1)
        cropped = input('Cropping okay: [y]/n or quit/reload') or 'y'
        # Re-prompt after each redraw; otherwise 'reload' would loop forever.
        while cropped in ('reload', 'r'):
            _show_crop_preview(arrf, arrm, arrf_filtered, arrm_filtered, x0, x1, y0, y1)
            cropped = input('Cropping okay: [y]/n or quit/reload') or 'y'
        if cropped in ('quit', 'q'):
            return None
    return x0, x1, y0, y1


def _show_offset_grid(arrm_cropped, arrf_cropped, width, height):
    """Overlay cropped H&E (Blues) and DAPI (Reds) on a 100-px grid to read off an offset."""
    fig, ax = plt.subplots(1, 1, figsize=(35, 35))
    ax.imshow(arrm_cropped, cmap="Blues")
    ax.imshow(arrf_cropped, cmap="Reds", alpha=0.5)
    ax.set_xticks(np.arange(0, width, 100))
    ax.set_yticks(np.arange(0, height, 100))
    ax.grid(True, which='both')
    ax.set_title('Initial offset')
    plt.show()
    plt.close(fig)


def _interactive_offset(arrm_cropped, arrf_cropped, width, height):
    """Ask for an initial (x, y) translation; 'reload'/'r' redraws on a fresh figure."""
    _show_offset_grid(arrm_cropped, arrf_cropped, width, height)
    offset = input("Initial offset: (x, y) or reload") or '0.0, 0.0'
    while offset in ('reload', 'r'):
        _show_offset_grid(arrm_cropped, arrf_cropped, width, height)
        offset = input("Initial offset: (x, y) or reload") or '0.0, 0.0'
    x_init, y_init = [float(a) for a in offset.split(',')]
    return x_init, y_init


def _show_overlay(nda):
    """Overlay component 1 (Blues) and component 0 (Reds) of a (rows, cols, 2) array."""
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(35, 35))
    ax.imshow(nda[:, :, 1], cmap='Blues')
    ax.imshow(nda[:, :, 0], cmap='Reds', alpha=0.5)
    ax.set_axis_off()
    plt.show()
    plt.close(fig)


def doit_percore_afterprealign_cropped(cid, pth_img, savepth_md, savepth_img, verbose=True, do_threshold = True, do_cropping = True, do_offset = True):
    """Interactively register one core's H&E image onto its DAPI image and save the result.

    Steps:
    1. Read `{pth_img}/{cid}_dapi.png` (fixed) and `{pth_img}/{cid}_he.png` (moving).
    2. Scale both images to a 0-255 range and invert them.
    3. Optionally threshold them.
    4. Loop (crop window -> initial offset -> `alignimgs`) until the user accepts.
    5. Apply the saved transform to the full-size images.

    Parameters
    ----------
    cid : str
        Core id. Its last '/'-separated component names the output files.
    pth_img : str
        Directory holding the input PNGs.
    savepth_md : str
        Directory for the transform `tfm_{name}.hdf` (written by `alignimgs`).
    savepth_img : str
        Directory for `{name}.npy` and `images/{name}_aligned.png`.
    verbose : bool
        Forwarded to `alignimgs`.
    do_threshold, do_cropping, do_offset : bool
        Enable the interactive threshold, crop-window and initial-offset steps.

    Returns
    -------
    None
        Returns early, without saving the full-size overlay, if the user
        answers 'quit'/'q' at the crop or alignment prompt.
    """
    # Computed once; nesting '/' quotes inside a '...' f-string is a SyntaxError before Python 3.12.
    core_name = cid.split('/')[-1]

    fixed = sitk.ReadImage(f'{pth_img}/{cid}_dapi.png', sitk.sitkFloat32)
    moving = sitk.ReadImage(f'{pth_img}/{cid}_he.png', sitk.sitkFloat32)

    arrf = _invert_to_255(sitk.GetArrayViewFromImage(fixed))
    arrm = _invert_to_255(sitk.GetArrayViewFromImage(moving))

    if do_threshold:
        arrf_filtered, arrm_filtered = _interactive_threshold(arrf, arrm)
    else:
        arrf_filtered, arrm_filtered = arrf.copy(), arrm.copy()
        _show_pair(arrf_filtered, arrm_filtered, 'dapi', 'he', figsize=(10, 15))

    aligned = 'n'
    while aligned != 'y':
        x0, x1, y0, y1 = 0, arrf_filtered.shape[1], 0, arrf_filtered.shape[0]
        if do_cropping:
            window = _interactive_crop(arrf, arrm, arrf_filtered, arrm_filtered)
            if window is None:
                return
            x0, x1, y0, y1 = window

        print(f'Doing alignment with coords {x0}, {x1}, {y0}, {y1}')
        arrf_cropped = arrf_filtered[y0:y1, x0:x1]
        arrm_cropped = arrm_filtered[y0:y1, x0:x1]

        # Zero outside the window instead of crop + ConstantPad. SimpleITK's pad
        # filter moves the output origin to -(x0, y0), so a transform fitted on
        # padded images lives in a different physical frame than the full-size
        # `fixed`/`moving` it is applied to below (wrong whenever the fitted
        # rotation/scale is non-trivial). CopyInformation keeps the frames identical.
        fixed_masked = sitk.GetImageFromArray(_mask_to_window(arrf_filtered, x0, x1, y0, y1))
        fixed_masked.CopyInformation(fixed)
        moving_masked = sitk.GetImageFromArray(_mask_to_window(arrm_filtered, x0, x1, y0, y1))
        moving_masked.CopyInformation(moving)

        x_init, y_init = 0, 0
        if do_offset:
            x_init, y_init = _interactive_offset(arrm_cropped, arrf_cropped, x1 - x0, y1 - y0)
            print(f'Initial alignment [{x_init}, {y_init}]')

        nda = alignimgs(fixed_masked, moving_masked, savepth_md, core_name,
                        verbose=verbose, init_translation=[x_init, y_init])

        _show_overlay(nda)
        aligned = input('Image aligned: [y]/n or quit/reload') or 'y'
        while aligned in ('reload', 'r'):
            _show_overlay(nda)
            aligned = input('Image aligned: [y]/n or quit/reload') or 'y'
        if aligned in ('quit', 'q'):
            return

        del fixed_masked, moving_masked, arrf_cropped, arrm_cropped
        gc.collect()

    # Apply the accepted transform to the full-size images.
    outTx = sitk.ReadTransform(f'{savepth_md}/tfm_{core_name}.hdf')
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(outTx)

    out = resampler.Execute(moving)
    simg1 = sitk.Cast(sitk.RescaleIntensity(fixed), sitk.sitkUInt8)
    simg2 = sitk.Cast(sitk.RescaleIntensity(out), sitk.sitkUInt8)
    nda = sitk.GetArrayFromImage(sitk.Compose(simg1, simg2))

    np.save(f'{savepth_img}/{core_name}.npy', nda)

    fig, ax = plt.subplots(figsize=(35, 35))
    ax.imshow(nda[:, :, 1], cmap='Blues_r')
    ax.imshow(nda[:, :, 0], cmap='Reds_r', alpha=0.5)
    if do_cropping:
        # Dashed outline of the window the transform was estimated on.
        ax.plot([x0, x0, x1, x1, x0], [y0, y1, y1, y0, y0],
                color='black', linestyle='--', linewidth=2)
        ax.annotate('Alignment Region', xy=(x1, y0))
    ax.set_title(cid)
    fig.savefig(f'{savepth_img}/images/{core_name}_aligned.png')
    plt.show()
    plt.close(fig)

    del fixed, moving
    gc.collect()

In [None]:
# Input/output locations for this alignment run, all under one project root.
_alignment_root = '/common/lamt2/HPV/data/xenium/alignment_v2'
pth_img = f'{_alignment_root}/prealigned_cores'      # per-core *_dapi.png / *_he.png inputs
savepth_md = f'{_alignment_root}/transformations'    # tfm_<core>.hdf transforms
savepth_img = f'{_alignment_root}/aligned_cores'     # <core>.npy overlays and images/

In [6]:
# Per-core alignment validation table; the first CSV column is the index (core id).
validation_csv = '/common/lamt2/HPV/data/xenium/alignment_v2/alignment_validation.csv'
df_cores = pd.read_csv(validation_csv, index_col=0)

In [15]:
# Core ids whose recorded 'aligned' value is 0 — these are re-aligned below.
needs_realignment = df_cores['aligned'] == 0
bad_cores = df_cores.loc[needs_realignment].index.values
bad_cores

array(['output-XETG00206__0060366__Region_1__20250305__223715___39',
       'output-XETG00206__0060366__Region_1__20250305__223715___15',
       'output-XETG00206__0060364__Region_1__20250305__223715___31',
       'output-XETG00206__0060488__Region_1__20250312__004017___55'],
      dtype=object)

In [None]:
def _ask_core_score():
    """Prompt until the answer is an integer, 'skip'/'s' or 'quit'/'q'; empty means '0'.

    Invalid answers are re-asked instead of crashing the loop in int().
    """
    while True:
        answer = input("Core aligned: [0]/1/2 or skip/quit").strip() or '0'
        if answer in ('skip', 's', 'quit', 'q'):
            return answer
        try:
            int(answer)
        except ValueError:
            print(f'Unrecognised answer {answer!r}; expected 0, 1, 2, skip or quit')
            continue
        return answer


# For each core scored 0: preview the saved overlay, optionally re-run the
# interactive alignment, record the integer score in df_cores, then save the CSV.
for core in bad_cores:
    print(f'Aligning core {core}')
    do_alignment = check_core(core)
    if do_alignment == 'n':
        continue
    if do_alignment in ('quit', 'q'):
        break
    print(df_cores.loc[core])
    aligned = '0'
    # A score of '0' means "not aligned yet": run the alignment again.
    while aligned == '0':
        doit_percore_afterprealign_cropped(core, pth_img, savepth_md, savepth_img, verbose=True, do_threshold=False, do_cropping=True, do_offset=True)
        aligned = _ask_core_score()
        if aligned in ('skip', 's', 'quit', 'q'):
            break
        df_cores.at[core, 'aligned'] = int(aligned)
    if aligned in ('quit', 'q'):
        break
    clear_output(wait=False)
    # Save after every core so progress survives quitting part-way through.
    df_cores.to_csv('/common/lamt2/HPV/data/xenium/alignment_v2/alignment_validation.csv')

In [44]:
# Write the (possibly updated) validation scores back to the same CSV.
validation_csv = '/common/lamt2/HPV/data/xenium/alignment_v2/alignment_validation.csv'
df_cores.to_csv(validation_csv)