In [52]:
import torch
import numpy as np
from torchvision import transforms
from funcs_transmorph import *

import pydicom as dicom
import matplotlib.pylab as plt
import numpy as np
import os
import skimage as ski
from skimage.transform import warp, AffineTransform, pyramid_expand, pyramid_reduce
import cv2
import scipy
from natsort import natsorted

from skimage.registration import phase_cross_correlation
from scipy import ndimage as scp
from tqdm import tqdm

import time
import pickle

from scipy.optimize import minimize as minz



In [53]:
def BrightestCenterSquareCrop(image, crop_size=None):
    """Crop a square window centred, as far as the borders allow, on the brightest pixel.

    Parameters
    ----------
    image : numpy.ndarray
        2-D array of shape (H, W).
    crop_size : int, optional
        Side length of the square crop. Defaults to ``min(H, W)`` (the largest square
        that fits, i.e. the previous behaviour); larger values are clamped to it.

    Returns
    -------
    numpy.ndarray
        A ``(crop_size, crop_size)`` slice (view) of ``image``. If several pixels share
        the maximum value, the first one in row-major order is used as the centre.

    Raises
    ------
    ValueError
        If ``image`` is not 2-D, is empty, or ``crop_size`` is not positive.
    """
    if image.ndim != 2:
        raise ValueError(f'expected a 2-D image, got shape {image.shape}')
    H, W = image.shape
    if H == 0 or W == 0:
        raise ValueError('cannot crop an empty image')

    max_size = min(H, W)
    if crop_size is None:
        crop_size = max_size
    elif crop_size <= 0:
        raise ValueError(f'crop_size must be positive, got {crop_size}')
    else:
        crop_size = min(int(crop_size), max_size)

    # Row/column of the brightest pixel (first maximum in row-major order).
    y, x = np.unravel_index(np.argmax(image), (H, W))

    half = crop_size // 2
    # Centre the window on (y, x), then shift it back inside the image if it would
    # run past the bottom/right edge.
    top = min(max(0, int(y) - half), H - crop_size)
    left = min(max(0, int(x) - half), W - crop_size)

    return image[top:top + crop_size, left:left + crop_size]


In [96]:
# === Preprocessing applied to each image before inference ===
# ToTensor only: turns an (H, W) array into a (1, H, W) tensor. The crop/resize steps
# (BrightestCenterSquareCrop, Resize((64, 64))) are NOT applied here, so images passed
# to `infer` presumably must already be 64x64.
# NOTE(review): the original header said "same as training" — confirm whether training
# also cropped/resized.
preprocess = transforms.Compose([
    transforms.ToTensor(),
])

# === Load model on CPU ===
# Set the TRANSMORPH_MODEL_PATH environment variable to use a different checkpoint
# without editing this cell; the default is the original hard-coded path.
model_path = os.environ.get(
    'TRANSMORPH_MODEL_PATH',
    '/Users/akapatil/Documents/Transmorph_2D_translation/model_transmorph_batch32_ncc_nonnormalized_shiftrange5.pt',
)
# The checkpoint holds a whole pickled model object (it is .eval()'d and called
# directly), not a state_dict. PyTorch >= 2.6 defaults to weights_only=True, which
# refuses to unpickle arbitrary classes, so opt out explicitly.
# SECURITY: unpickling executes arbitrary code — only load checkpoints you trust.
model = torch.load(model_path, map_location='cpu', weights_only=False)
model.eval()

# === Spatial transformer that applies the predicted translation ===
# Grid size must match the image size fed to `infer` (the training size).
warper = SpatialTransformer(size=(64, 64))
warper.to('cpu')

# === Inference function ===
def infer(static_np, moving_np):
    """Predict the translation aligning ``moving_np`` to ``static_np`` and apply it.

    Relies on the module-level ``preprocess``, ``model`` and ``warper`` objects
    created in the setup cell.

    Parameters
    ----------
    static_np, moving_np : array_like
        Fixed and moving images of identical shape. Presumably 2-D (H, W) arrays
        matching the warper's grid size — TODO confirm against the training setup.

    Returns
    -------
    warped_np : numpy.ndarray
        ``moving_np`` resampled by ``warper`` with the predicted translation,
        with all singleton dimensions squeezed out.
    translation : numpy.ndarray
        The model's predicted translation, squeezed (shape (2,) for one image pair).

    Raises
    ------
    ValueError
        If the two images do not have the same shape.
    """
    static_np = np.asarray(static_np, dtype=np.float32)
    moving_np = np.asarray(moving_np, dtype=np.float32)
    if static_np.shape != moving_np.shape:
        raise ValueError(
            f'static and moving images must have the same shape, '
            f'got {static_np.shape} and {moving_np.shape}'
        )

    # With `preprocess` = ToTensor, an (H, W) array becomes a (1, H, W) tensor
    # (the channel axis); unsqueeze(0) then adds the batch axis -> (1, 1, H, W).
    static = preprocess(static_np).unsqueeze(0)
    moving = preprocess(moving_np).unsqueeze(0)

    with torch.no_grad():
        # Stack along the channel axis -> (1, 2, H, W); inputs are cast to float64.
        input_pair = torch.cat([static, moving], dim=1).double()
        # The model also returns its own moved image; only the translation is used,
        # and it is re-applied with the standalone warper.
        _moved_img, pred_translation = model(input_pair)
        warped = warper(moving.double(), pred_translation)

    return warped.squeeze().numpy(), pred_translation.squeeze().numpy()


In [97]:
def load_data_dcm(scan_num, crop=False, row_margin=100):
    """Load every frame in directory ``scan_num`` into one float32 stack.

    Files ending in ``.dcm``, ``.DCM`` or ``.PNG`` are read with ``pydicom.dcmread``
    in natural-sort order (so ``2.dcm`` comes before ``10.dcm``), stacked along
    axis 0, and ``row_margin`` rows are trimmed from the top and bottom of each frame.

    Parameters
    ----------
    scan_num : str or os.PathLike
        Directory containing the frames.
    crop : bool, optional
        Unused; kept so existing calls keep working.
    row_margin : int, optional
        Rows removed from each of the top and bottom edges. Defaults to 100, the
        value that was previously hard-coded.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (n_frames, H - 2 * row_margin, W). Every file must
        have the same (H, W) as the first one.

    Raises
    ------
    FileNotFoundError
        If the directory contains no matching files.
    ValueError
        If ``row_margin`` is negative.
    """
    if row_margin < 0:
        raise ValueError(f'row_margin must be non-negative, got {row_margin}')

    path = str(scan_num)
    # NOTE(review): .PNG files are also parsed with dcmread — presumably DICOM data
    # saved with a .PNG extension; a real PNG image would fail here. Confirm.
    suffixes = ('.dcm', '.DCM', '.PNG')
    names = [name for name in os.listdir(path) if name.endswith(suffixes)]
    if not names:
        raise FileNotFoundError(f'no .dcm/.DCM/.PNG files found in {path!r}')
    pic_paths = natsorted(names)

    first = dicom.dcmread(os.path.join(path, pic_paths[0])).pixel_array
    n_rows, n_cols = first.shape[0], first.shape[1]

    data = np.empty((len(pic_paths), n_rows, n_cols))
    data[0] = first  # reuse the already-decoded first frame
    for i, name in enumerate(pic_paths[1:], start=1):
        data[i] = dicom.dcmread(os.path.join(path, name)).pixel_array

    # Explicit stop index so row_margin=0 keeps every row (a `0:-0` slice would be empty).
    return data[:, row_margin:n_rows - row_margin, :].astype(np.float32)


In [98]:
# Frame stack from the 'scan20' directory (relative to the notebook's working directory).
data = load_data_dcm(scan_num='scan20')

In [119]:
# Synthetic sanity check: a 10x10 bright square in `static`, and the same square
# shifted by +1 row and +2 columns (with a slightly different intensity) in `moving`.
static = np.zeros((64, 64))
static[20:30, 30:40] = 4.870

moving = np.zeros_like(static)
moving[21:31, 32:42] = 4.909

In [120]:
# Displays (warped moving image, predicted translation) for the synthetic pair.
infer(static_np=static, moving_np=moving)

(array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 array([2.89223887, 0.84435774]))