In [None]:
import os
import SimpleITK as sitk
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np


In [None]:
class MedicalDataset(Dataset):
    """In-memory dataset of paired CT and PET images, one item per patient.

    Each patient folder must contain ``ct`` and ``pet`` subfolders. Every regular
    file in a subfolder is read with SimpleITK and resampled onto a grid of
    ``target_size`` voxels spanning the same physical extent as the input. The
    resulting arrays are stacked along a new leading axis.

    Attributes:
        data (list): One ``(ct, pet)`` tuple of float32 tensors per patient. Each
            tensor has shape ``(n_files, *reversed(target_size))``: SimpleITK sizes
            are given as (x, y, z), while the NumPy arrays are indexed (z, y, x).
    """

    def __init__(self, patient_folders, target_size=(128, 128, 64)):
        """Loads and resamples every CT and PET file for each patient.

        Args:
            patient_folders (list): List of patient folder paths.
            target_size (tuple): Output grid size in SimpleITK (x, y, z) order.

        Raises:
            ValueError: If a ``ct``/``pet`` folder contains no files, or an image's
                dimension does not match ``len(target_size)``.
        """
        self.data = []
        for folder in patient_folders:
            ct_volumes = self._load_modality(os.path.join(folder, 'ct'), target_size)
            pet_volumes = self._load_modality(os.path.join(folder, 'pet'), target_size)
            self.data.append((ct_volumes, pet_volumes))

    def _load_modality(self, modality_folder, target_size):
        """Loads, resamples and stacks every file of one modality folder.

        Files are processed in sorted name order. ``os.listdir`` order is
        arbitrary, so sorting makes the stacking order reproducible.

        Args:
            modality_folder (str): Path to a ``ct`` or ``pet`` folder.
            target_size (tuple): Output grid size in SimpleITK (x, y, z) order.

        Returns:
            torch.Tensor: float32 tensor of shape ``(n_files, *reversed(target_size))``.

        Raises:
            ValueError: If the folder contains no regular files.
        """
        filenames = sorted(
            name for name in os.listdir(modality_folder)
            if os.path.isfile(os.path.join(modality_folder, name))
        )
        if not filenames:
            raise ValueError(f"No image files found in {modality_folder!r}")
        arrays = [
            self.load_and_resize_image(os.path.join(modality_folder, name), target_size)
            for name in filenames
        ]
        # Stack in NumPy first: building a tensor from a list of ndarrays is very slow.
        return torch.as_tensor(np.stack(arrays), dtype=torch.float32)

    def load_and_resize_image(self, filepath, target_size):
        """Loads an image and resamples it to ``target_size`` voxels.

        The output grid keeps the input's direction and physical extent. Spacing
        is stretched by ``old_size / new_size`` and the origin is moved to the
        centre of the first new voxel, so the content is resized rather than
        cropped or shifted. Linear interpolation is used.

        Args:
            filepath (str): Path to an image file readable by ``sitk.ReadImage``.
            target_size (tuple): Output grid size in SimpleITK (x, y, z) order.

        Returns:
            np.ndarray: Resampled array, indexed in reversed (z, y, x) order.

        Raises:
            ValueError: If the image dimension differs from ``len(target_size)``.
        """
        image = sitk.ReadImage(filepath)
        if image.GetDimension() != len(target_size):
            raise ValueError(
                f"{filepath!r} is {image.GetDimension()}-D but target_size has {len(target_size)} entries"
            )
        old_size = image.GetSize()
        old_spacing = image.GetSpacing()
        new_size = [int(s) for s in target_size]

        # Same physical extent: new_size * new_spacing == old_size * old_spacing.
        new_spacing = [osz * osp / nsz for osz, osp, nsz in zip(old_size, old_spacing, new_size)]
        # Centre of the first output voxel, expressed as a continuous index on the input grid.
        first_center = [(osz / nsz - 1.0) / 2.0 for osz, nsz in zip(old_size, new_size)]

        resampler = sitk.ResampleImageFilter()
        resampler.SetSize(new_size)
        resampler.SetOutputSpacing(new_spacing)
        resampler.SetOutputOrigin(image.TransformContinuousIndexToPhysicalPoint(first_center))
        resampler.SetOutputDirection(image.GetDirection())
        resampler.SetInterpolator(sitk.sitkLinear)
        resampler.SetDefaultPixelValue(0.0)
        return sitk.GetArrayFromImage(resampler.Execute(image))

    def __len__(self):
        """Returns the number of patients in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Retrieves one patient's images.

        Args:
            idx (int): Index of the desired patient.

        Returns:
            tuple: ``(ct, pet)`` float32 tensors.
        """
        return self.data[idx]


In [None]:
class RegistrationNet(nn.Module):
    """Small 3-D CNN that regresses six registration parameters from one volume.

    Layers: conv(1->16) -> ReLU -> 2x2x2 max-pool -> conv(16->32) -> ReLU ->
    flatten -> fc(->128) -> ReLU -> fc(->6).

    Args:
        input_size (tuple of int): Spatial size ``(D, H, W)`` of the volumes the
            network receives. ``fc1`` is sized from it, so it must match the data.
            The default corresponds to ``MedicalDataset``'s default ``target_size``
            of (128, 128, 64) in SimpleITK (x, y, z) order, i.e. arrays of shape
            (64, 128, 128).

    Attributes:
        conv1 (nn.Conv3d): First convolutional layer (1 -> 16 channels).
        conv2 (nn.Conv3d): Second convolutional layer (16 -> 32 channels).
        fc1 (nn.Linear): First fully connected layer (flattened features -> 128).
        fc2 (nn.Linear): Output layer (128 -> 6).

    Raises:
        ValueError: If ``input_size`` is not three dimensions of at least 2 voxels.
    """

    def __init__(self, input_size=(64, 128, 128)):
        """Initializes the RegistrationNet model."""
        super().__init__()
        if len(input_size) != 3 or any(int(s) < 2 for s in input_size):
            raise ValueError(f"input_size must be three sizes >= 2, got {tuple(input_size)}")
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        # padding=1 keeps each spatial size through both convs. Only the single
        # max-pool (kernel 2, stride 2, floor) shrinks each dim, so the voxel count
        # drops by ~8x, not 4x.
        pooled_voxels = 1
        for size in input_size:
            pooled_voxels *= int(size) // 2
        self.fc1 = nn.Linear(32 * pooled_voxels, 128)
        self.fc2 = nn.Linear(128, 6)  # 3 for translation and 3 for rotation

    def forward(self, x):
        """Defines the forward pass of the model.

        Args:
            x (torch.Tensor): ``(N, D, H, W)`` (a channel dim is added) or
                ``(N, 1, D, H, W)``. The spatial size must equal ``input_size``.

        Returns:
            torch.Tensor: ``(N, 6)`` registration parameters.
        """
        if x.dim() == 4:
            x = x.unsqueeze(1)  # add the single input channel conv1 expects
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool3d(x, kernel_size=2)
        x = nn.functional.relu(self.conv2(x))
        x = torch.flatten(x, start_dim=1)
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)


In [None]:
def load_data(patient_dirs, test_size=0.2, batch_size=2, seed=None, target_size=(128, 128, 64)):
    """Loads every patient and splits them into training and testing loaders.

    Args:
        patient_dirs (list): Patient folder paths, passed to ``MedicalDataset``.
        test_size (float): Fraction of patients held out for testing, rounded
            down to a whole patient count. Must lie in [0, 1].
        batch_size (int): Batch size used by both loaders.
        seed (int, optional): Seed for the train/test shuffle. ``None`` draws from
            NumPy's global RNG (so ``np.random.seed`` still controls the split).
        target_size (tuple): Resampling size forwarded to ``MedicalDataset``.

    Returns:
        tuple[DataLoader, DataLoader]: ``(train_loader, test_loader)``. Each one
        visits its subset in random order via ``SubsetRandomSampler``.

    Raises:
        ValueError: If ``test_size`` is outside [0, 1].
    """
    # Validate before the (expensive) image loading.
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size!r}")

    dataset = MedicalDataset(patient_dirs, target_size=target_size)
    dataset_size = len(dataset)

    indices = list(range(dataset_size))
    # A private RandomState makes the split reproducible without touching global state.
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(indices)

    split = int(np.floor(test_size * dataset_size))
    test_indices = indices[:split]
    train_indices = indices[split:]

    train_data = DataLoader(
        dataset, batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(train_indices),
    )
    test_data = DataLoader(
        dataset, batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(test_indices),
    )
    return train_data, test_data

In [None]:
# Reproducibility: load_data's default split shuffles with NumPy's global RNG;
# weight init and SubsetRandomSampler draw from torch's global RNG.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

patient_folders = [f'/content/mahak/patient{i}' for i in range(1, 11)]
train_loader, test_loader = load_data(patient_folders)


def _volume_batch(images):
    """Returns a (B, D, H, W) batch from a loader batch.

    MedicalDataset stacks every file of a modality folder along dim 1, so loader
    batches arrive as (B, n_files, D, H, W).
    """
    # NOTE(review): assumes each modality folder holds a single volume file, so
    # only index 0 is used — TODO confirm against the real data layout.
    return images[:, 0] if images.dim() == 5 else images


def _rigid_affine(params):
    """Builds (B, 3, 4) matrices for ``affine_grid`` from (B, 6) parameters.

    Columns 0-2 are translations in normalised grid units ([-1, 1] spans the
    volume). Columns 3-5 are rotations about x, y, z in radians, composed as
    Rz @ Ry @ Rx.
    """
    translation, angles = params[:, :3], params[:, 3:]
    rx, ry, rz = angles.unbind(dim=1)
    one, zero = torch.ones_like(rx), torch.zeros_like(rx)
    cx, sx = torch.cos(rx), torch.sin(rx)
    cy, sy = torch.cos(ry), torch.sin(ry)
    cz, sz = torch.cos(rz), torch.sin(rz)
    rot_x = torch.stack([one, zero, zero, zero, cx, -sx, zero, sx, cx], dim=1).view(-1, 3, 3)
    rot_y = torch.stack([cy, zero, sy, zero, one, zero, -sy, zero, cy], dim=1).view(-1, 3, 3)
    rot_z = torch.stack([cz, -sz, zero, sz, cz, zero, zero, zero, one], dim=1).view(-1, 3, 3)
    rotation = rot_z @ rot_y @ rot_x
    return torch.cat([rotation, translation.unsqueeze(-1)], dim=2)


def warp_rigid(moving, params):
    """Resamples ``moving`` (B, D, H, W) through rigid transforms ``params`` (B, 6).

    The warp is differentiable w.r.t. ``params``, so it can sit inside the
    training loss. ``affine_grid`` maps output positions to input sample
    positions; its x/y/z grid axes correspond to the W/H/D tensor axes.
    """
    volume = moving.unsqueeze(1)  # (B, 1, D, H, W)
    grid = nn.functional.affine_grid(_rigid_affine(params), volume.shape, align_corners=False)
    warped = nn.functional.grid_sample(
        volume, grid, mode='bilinear', padding_mode='zeros', align_corners=False
    )
    return warped.squeeze(1)


# Initialize the model, loss function, and optimizer
model = RegistrationNet()
# NOTE(review): voxel-wise MSE between CT and PET intensities (different physical
# scales) is a weak multi-modal similarity; normalised cross-correlation or mutual
# information is the usual choice. Kept as a placeholder.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training: the network predicts a rigid transform, the PET is warped with it,
# and the loss compares the warped PET with the CT. No ground-truth parameters
# are needed.
# NOTE(review): the network only sees the CT, so it cannot observe the PET
# misalignment it is asked to correct. Feeding both volumes (2 input channels)
# would be needed for meaningful registration.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss, num_batches = 0.0, 0
    for ct_images, pet_images in train_loader:
        ct_volumes = _volume_batch(ct_images)
        pet_volumes = _volume_batch(pet_images)
        optimizer.zero_grad()
        outputs = model(ct_volumes)
        loss = criterion(warp_rigid(pet_volumes, outputs), ct_volumes)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches; check patient_folders and test_size")
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / num_batches:.4f}')

# Inference on the held-out test patients
model.eval()
output_path = '/content/transformed_pet_image.dcm'
saved_example = False
with torch.no_grad():
    for ct_images, pet_images in test_loader:
        ct_volumes = _volume_batch(ct_images)
        pet_volumes = _volume_batch(pet_images)
        registration_params = model(ct_volumes)
        print("Predicted Registration Parameters:", registration_params)

        # Apply the prediction with the same warp used in training, so the
        # parameter convention (normalised units, radians) matches what was learned.
        transformed_pet_images = warp_rigid(pet_volumes, registration_params)

        # Save the first transformed test image once, rather than overwriting it every batch.
        # NOTE(review): GetImageFromArray drops the original spacing/origin, and a
        # DICOM writer may not accept float32 pixels — consider '.nii.gz' or '.mha'.
        if not saved_example:
            sitk.WriteImage(sitk.GetImageFromArray(transformed_pet_images[0].numpy()), output_path)
            saved_example = True

print("Registration completed for test patient.")