In [3]:
import os
import torch
import numpy as np
import SimpleITK as sitk
import ipyvolume as ipv
from torch.utils.data import Dataset, DataLoader
from scipy.ndimage import map_coordinates
from voxelmorph.pytorch.networks import VxmDense
from torch.optim import Adam

class CTScanDataset(Dataset):
    def __init__(self, directory_path):
        self.files = [os.path.join(directory_path, f) for f in os.listdir(directory_path) if f.endswith('.dcm')]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        image = sitk.ReadImage(self.files[idx])
        image_array = sitk.GetArrayFromImage(image)
        image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min())
        return image_array

class CTScanPairDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        fixed_image, moving_image = self.pairs[idx]
        fixed_image = preprocess(fixed_image)
        moving_image = preprocess(moving_image)
        return fixed_image, moving_image

def load_data_for_training(directory_path):
    dataset = CTScanDataset(directory_path)
    pairs = [(dataset[i], dataset[i + 1]) for i in range(len(dataset) - 1)]
    return CTScanPairDataset(pairs)

def preprocess(image_array):
    image_array = image_array[np.newaxis, ...]
    image_tensor = torch.from_numpy(image_array)
    return image_tensor

def apply_displacement_field(moving_image_array, displacement_field):
    coords = np.mgrid[0:moving_image_array.shape[0], 0:moving_image_array.shape[1], 0:moving_image_array.shape[2]]
    coords += displacement_field
    warped_moving_image_array = map_coordinates(moving_image_array, coords, order=3)
    return warped_moving_image_array

def display_images(image_array):
    ipv.figure()
    ipv.volshow(image_array, level=[0.25, 0.75], opacity=0.03, level_width=0.1, data_min=0, data_max=1)
    ipv.show()

# Paths to the DICOM directories
image_directory_path = r"C:\Users\HP\Documents\GitHub\3D-3D_Image_Registration\SE000003"

# Load the DICOM images
dataset = load_data_for_training(image_directory_path)

# Initialize the model and optimizer
model = VxmDense(inshape=(None, None, None, 1), nb_unet_features=[32, 32, 32, 32])
optimizer = Adam(model.parameters())

# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    print(f'Starting epoch {epoch + 1}')
    for i, (fixed_image, moving_image) in enumerate(dataset):
        optimizer.zero_grad()
        y_pred, _ = model([fixed_image.float().unsqueeze(0), moving_image.float().unsqueeze(0)])
        loss = torch.nn.MSELoss()(y_pred, fixed_image.float().unsqueeze(0))
        loss.backward()
        optimizer.step()

# Save the model
model_save_path = r"C:\Users\HP\Documents\GitHub\3D-3D_Image_Registration\SE000003\model.pth"
torch.save(model.state_dict(), model_save_path)

# Load the model and apply it to a pair of images for visualization
model.load_state_dict(torch.load(model_save_path))

# Choose a pair of images to register and visualize
fixed_image_tensor, moving_image_tensor = dataset[0]  # Change index if necessary

# Compute the displacement field
displacement_field, _ = model([fixed_image_tensor.float().unsqueeze(0), moving_image_tensor.float().unsqueeze(0)])

# Convert displacement field back to numpy
displacement_field = displacement_field.detach().numpy()

# Apply displacement field to moving image
warped_moving_image_array = apply_displacement_field(moving_image_tensor.numpy(), displacement_field)

# Display the images after transformation
display_images(fixed_image_tensor.numpy().squeeze())
display_images(warped_moving_image_array.squeeze())


ModuleNotFoundError: No module named 'voxelmorph.pytorch'

In [4]:
!pip install voxelmorph.pytorch

ERROR: Could not find a version that satisfies the requirement voxelmorph.pytorch (from versions: none)
ERROR: No matching distribution found for voxelmorph.pytorch

[notice] A new release of pip is available: 23.0 -> 23.1.2
[notice] To update, run: C:\Users\HP\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip
