In [29]:
# Mount Google Drive so the project data under /content/drive/MyDrive is
# readable from this Colab runtime.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install SimpleITK

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting SimpleITK
  Downloading SimpleITK-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m20.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.2.1


In [3]:
from nibabel.testing import data_path
import SimpleITK as sitk
from SimpleITK import ResampleImageFilter
from google.colab import files
import os
import pandas as pd
import nibabel as nib 
import matplotlib.pyplot as plt
import numpy as np
# Project root on the mounted Drive, and the CSV listing (fixed, moving) image-ID pairs.
data_folder = '/content/drive/MyDrive/ISMI_final_proj'
validation_pairs = os.path.join(data_folder, 'pairs_val_training.csv')
pairs_df = pd.read_csv(validation_pairs)

# A single SimpleITK registration object, configured and executed later,
# plus an accumulator for the displacement fields it produces.
registration_method = sitk.ImageRegistrationMethod()
deformation_fields = []

In [None]:

# --- Pairwise rigid registration ----------------------------------------------
# For every (fixed, moving) ID pair in `pairs_df`, register the moving image to
# the fixed one with an Euler3D (rigid) transform, refined over `num_stages`
# passes. Each result is kept as a displacement field sampled on the fixed
# image grid.
# NOTE(review): the transform model is rigid, so the stored "deformation" fields
# can only encode rotation + translation.

num_stages = 3
img_dir = os.path.join(data_folder, 'L2R_Task4_HippocampusMRI_training/Training/img')
output_folder = os.path.join(data_folder, 'Displacement_Fields_warped')
# Later cells read displacement_field_<i>.nii.gz and warped_image_<id>.nii.gz
# from output_folder, so write them here.
save_outputs = True

# Configure the shared registration method once. Observers added with AddCommand
# accumulate on the object, so clear them first. Otherwise each re-run (or each
# pair, if added inside the loop) prints every iteration one more time.
registration_method.RemoveAllCommands()
registration_method.SetMetricAsMeanSquares()
registration_method.SetOptimizerAsRegularStepGradientDescent(learningRate=0.1, minStep=1e-4, numberOfIterations=100)
registration_method.SetInterpolator(sitk.sitkLinear)
registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: print(f"Iteration: {registration_method.GetOptimizerIteration()}"))


def _load_pair(fixed_id, moving_id):
    """Read a fixed/moving pair and cast the moving pixel type to the fixed one (float32).

    Raises:
        ValueError: if the two images do not have the same size.
    """
    fixed = sitk.Cast(sitk.ReadImage(os.path.join(img_dir, f'hippocampus_{fixed_id}.nii.gz')), sitk.sitkFloat32)
    moving = sitk.ReadImage(os.path.join(img_dir, f'hippocampus_{moving_id}.nii.gz'))
    if moving.GetPixelID() != fixed.GetPixelID():
        moving = sitk.Cast(moving, fixed.GetPixelID())
    if fixed.GetSize() != moving.GetSize():
        raise ValueError("Dimension mismatch between fixed and moving images.")
    return fixed, moving


def _register_pair(fixed, moving, num_stages):
    """Return a transform mapping fixed-space points into `moving`, refined over `num_stages` passes.

    Each pass warps the ORIGINAL moving image with everything found so far and
    estimates only the residual misalignment. If warped(x) = moving(T(x)) and
    the pass finds S aligning fixed(x) with warped(S(x)), the cumulative mapping
    is T(S(x)). sitk.CompositeTransform applies the last-added transform first,
    so appending S to the composite yields exactly that.
    """
    total_transform = sitk.CompositeTransform(sitk.Euler3DTransform())  # identity seed
    for _ in range(num_stages):
        warped = sitk.Resample(moving, fixed, total_transform, sitk.sitkLinear)
        stage_init = sitk.CenteredTransformInitializer(
            fixed, warped, sitk.Euler3DTransform(), sitk.CenteredTransformInitializerFilter.MOMENTS
        )
        # inPlace=False: optimise a copy, so each pass starts from a fresh
        # initialisation instead of the previous pass's (already applied) result.
        registration_method.SetInitialTransform(stage_init, inPlace=False)
        total_transform.AddTransform(registration_method.Execute(fixed, warped))
    return total_transform


def _to_displacement_field(transform, reference):
    """Sample `transform` as a float64 vector displacement field on `reference`'s grid."""
    to_field = sitk.TransformToDisplacementFieldFilter()
    # Take size/origin/spacing/direction from the reference image. Without this,
    # the filter's default output geometry is used instead of the fixed image's.
    to_field.SetReferenceImage(reference)
    to_field.SetOutputPixelType(sitk.sitkVectorFloat64)
    return to_field.Execute(transform)


if save_outputs:
    os.makedirs(output_folder, exist_ok=True)

# Start empty so re-running this cell does not append duplicate fields.
deformation_fields = []

for _, pair in pairs_df.iterrows():
    # Positional access via .iloc; Series[int] as a position is deprecated in recent pandas.
    fixed_id, moving_id = pair.iloc[0], pair.iloc[1]
    fixed_image, moving_image = _load_pair(fixed_id, moving_id)

    final_transform = _register_pair(fixed_image, moving_image, num_stages)
    deformation_field = _to_displacement_field(final_transform, fixed_image)
    deformation_fields.append(deformation_field)

    if save_outputs:
        field_index = len(deformation_fields) - 1
        sitk.WriteImage(deformation_field, os.path.join(output_folder, f"displacement_field_{field_index}.nii.gz"))
        warped_moving_image = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear)
        sitk.WriteImage(warped_moving_image, os.path.join(output_folder, f"warped_image_{moving_id}.nii.gz"))

In [None]:
import matplotlib.pyplot as plt
import nibabel as nib

# Show, for every pair, the middle slice (along the third array axis) of the
# warped moving image saved by the registration step.
output_folder = os.path.join(data_folder, 'Displacement_Fields_warped')
# Create the output folder for the warped images (no-op if it already exists).
os.makedirs(output_folder, exist_ok=True)

for _, pair in pairs_df.iterrows():
    # Positional access via .iloc; Series[int] as a position is deprecated in recent pandas.
    moving_id = pair.iloc[1]
    warped_image_path = os.path.join(output_folder, f"warped_image_{moving_id}.nii.gz")

    # Read the warped image (nib.load raises FileNotFoundError if it was never saved).
    warped_image = nib.load(warped_image_path)
    warped_image_data = warped_image.get_fdata()
    mid_slice = warped_image_data.shape[2] // 2

    fig, ax = plt.subplots()
    ax.imshow(warped_image_data[:, :, mid_slice], cmap='gray')
    ax.set_title(f"Warped Image - ID: {moving_id}")
    ax.axis('off')
    plt.show()
    # One figure per pair: close it so open figures do not pile up in memory.
    plt.close(fig)


In [None]:
# Install a pinned dependency set (presumably the baseline repository's
# requirements — TODO confirm the source).
# NOTE(review): the earlier SimpleITK install pulled a cp310 wheel, i.e. this
# runtime is Python 3.10. Several pins below (e.g. torch==1.4.0+cu92,
# numpy==1.16.0, pandas==0.23.4, scipy==1.4.1) predate Python 3.10 and likely
# have no matching wheels — verify before running.
# NOTE(review): SimpleITK==1.2.0 downgrades the SimpleITK 2.2.1 installed at the
# top of the notebook.
# NOTE(review): +cu92 torch builds are not on PyPI; pip usually needs
# `-f https://download.pytorch.org/whl/torch_stable.html` to find them — confirm.
!pip install SimpleITK==1.2.0
!pip install torch==1.4.0+cu92 torchvision==0.5.0+cu92
!pip install tqdm==4.30
!pip install numpy==1.16.0
!pip install pandas==0.23.4
!pip install matplotlib==3.0.2
!pip install nibabel==2.3.3
!pip install threadpoolctl==2.0.0
!pip install scipy==1.4.1
!pip install evalutils==0.2.3
!pip install surface_distance==0.1


In [None]:
# Install the latest torch/torchvision (unpinned).
# NOTE(review): this overrides the torch==1.4.0+cu92 / torchvision==0.5.0+cu92
# pins in the previous install cell; keep only one of the two and pin versions.
!pip install torch torchvision


In [42]:
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import argparse
# Make the project's helper modules (blocks.py, model_loader.py, ...) importable.
# Guard the append so re-running this cell does not keep growing sys.path.
project_dir = '/content/drive/MyDrive/ISMI_final_proj'
if project_dir not in sys.path:
    sys.path.append(project_dir)

import os

# Sanity check before importing: look in the directory added to sys.path, not
# os.getcwd(). The cwd is not necessarily that directory, so a cwd-based check
# can print False even when the import below succeeds.
blocks_path = os.path.join(project_dir, 'blocks.py')
print(os.path.exists(blocks_path))

# NOTE(review): star import hides which names come from blocks; consider
# importing the needed names explicitly.
from blocks import *
from model_loader import load_model

In [40]:
# Settings naming the architecture and the pretrained checkpoint path.
# argparse.Namespace is spelled out because `argparse` is the only import that
# guarantees it; a bare `Namespace` works only if a star import happens to
# provide it.
# NOTE(review): 'Theoest model' looks like a possible typo in the Drive folder
# name — confirm the path exists.
args = argparse.Namespace(
    arch='my_model',  # Unet
    model_abspath='/content/drive/MyDrive/ISMI_final_proj/Theoest model/Hippocampus_registration/Models/Baseline.pth.tar'  # Path to the pretrained model file
)
# Extra keyword arguments (none); presumably meant for load_model — verify.
kwargs = {}


In [43]:
from my_model import Decoder, Encoder

# --- U-Net configuration ------------------------------------------------------
# Channel counts at the network input and output.
in_channels, out_channels = 4, 1
# Number of pooling (down-sampling) blocks in the encoder.
pool_blocks = 5

# Feature width per block; must have pool_blocks + 1 entries.
channels = [64, 128, 256, 512, 1024, 1024]
assert len(channels) == pool_blocks + 1, "Invalid number of channels"

# `last_activation` is used on the final layer, `activation_type` everywhere else.
# NOTE(review): 'sigmoid' bounds outputs to (0, 1) and out_channels is 1; if the
# network is trained with MSE to reproduce displacement fields, signed or
# multi-component targets cannot be matched — TODO confirm intent.
last_activation, activation_type = 'sigmoid', 'leaky'
# Both normalisation variants are switched off.
instance_norm = batch_norm = False
# One convolution layer in each block.
nb_Convs = [1] * pool_blocks

encoder = Encoder(pool_blocks, channels, activation_type, in_channels,
                  instance_norm, batch_norm, nb_Convs)
decoder = Decoder(pool_blocks, channels, out_channels, last_activation,
                  activation_type, instance_norm, batch_norm, nb_Convs)

# Chain encoder -> decoder into one module.
# NOTE(review): nn.Sequential passes the encoder's return value straight to the
# decoder; confirm this matches how Decoder expects its skip connections.
model = nn.Sequential(encoder, decoder)








In [65]:


import os
import nibabel as nib

# Displacement fields written by the registration step, one file per pair.
deformation_fields_path = os.path.join(data_folder, 'Displacement_Fields_warped')
field_prefix, field_suffix = 'displacement_field_', '.nii.gz'


def _field_sort_key(filename):
    """Order 'displacement_field_<i>.nii.gz' files by the integer <i>.

    os.listdir returns entries in arbitrary order, and plain string sorting puts
    '_10' before '_2'. Either would misalign the fields with their pairs.
    Non-numeric suffixes sort after numeric ones, alphabetically.
    """
    stem = filename[len(field_prefix):-len(field_suffix)]
    return (0, int(stem), '') if stem.isdecimal() else (1, 0, stem)


field_files = sorted(
    (name for name in os.listdir(deformation_fields_path)
     if name.startswith(field_prefix) and name.endswith(field_suffix)),
    key=_field_sort_key,
)
# get_fdata() returns the voxel data as a floating-point numpy array (float64 by default).
deformation_fields = [nib.load(os.path.join(deformation_fields_path, name)).get_fdata() for name in field_files]

# Verify the contents of the deformation_fields list
print("Number of deformation fields:", len(deformation_fields))
assert deformation_fields, f"No displacement fields found in {deformation_fields_path}"


Number of deformation fields: 60


In [66]:
from torch.utils.data import DataLoader, Dataset

class DeformationFieldDataset(Dataset):
    """Map-style dataset over an in-memory list of deformation fields.

    Each item is returned with its original shape, converted to a float32
    tensor. The fields in this notebook are loaded with nibabel's get_fdata(),
    which yields float64 arrays. DataLoader's default collate would turn those
    into float64 batches that do not match the float32 weights of a freshly
    built nn.Module.

    Args:
        deformation_fields: sequence of array-likes, one field per sample.
    """

    def __init__(self, deformation_fields):
        self.deformation_fields = deformation_fields

    def __len__(self):
        return len(self.deformation_fields)

    def __getitem__(self, index):
        # The field itself is the sample; only the dtype/container changes.
        return torch.as_tensor(self.deformation_fields[index], dtype=torch.float32)


# Serve the loaded fields to the training loop in shuffled mini-batches.
batch_size = 32  # samples per optimisation step
dataset = DeformationFieldDataset(deformation_fields)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)



In [46]:
# Assemble the U-Net from the encoder/decoder built earlier.
model = nn.Sequential(encoder, decoder)

# Training objective: mean-squared error.
loss_fn = nn.MSELoss()

# Adam over every parameter of the assembled model.
learning_rate = 0.001
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)



In [None]:
# Train the network to reproduce its input batch (the target is the batch itself).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Parameters must live on the same device as the batches. Module.to() moves them
# in place, so the optimizer created earlier still tracks the same tensors.
model.to(device)
model.train()

num_epochs = 10
total_steps = len(dataloader)

for epoch in range(num_epochs):
    for step, batch_data in enumerate(dataloader):
        # Cast to float32 to match the model weights (fields are loaded as float64).
        batch_data = batch_data.to(device=device, dtype=torch.float32)

        outputs = model(batch_data)
        # NOTE(review): MSE requires outputs and batch_data to have matching
        # shapes; confirm the decoder's out_channels/activation fit the field layout.
        loss = loss_fn(outputs, batch_data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
              .format(epoch + 1, num_epochs, step + 1, total_steps, loss.item()))

# Save the trained weights as a checkpoint. Do not write to my_model.py: that
# file is the source module imported earlier (`from my_model import ...`), and a
# pickled state dict would overwrite the code.
torch.save(model.state_dict(), '/content/drive/MyDrive/ISMI_final_proj/my_model_trained.pth')
