In [None]:
!pip install pydicom torchvision tqdm scikit-image

Collecting pydicom
  Downloading pydicom-2.4.4-py3-none-any.whl.metadata (7.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.3.1->torchvision)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.3.1->torchvision)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.3.1->torchvision)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.3.1->torchvision)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.3.1->torchvision)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.3.1->torchvi

In [None]:
import os
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pydicom
import cv2
import numpy as np
import torch
from glob import glob
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms


In [None]:
def load_dicom_image(filepath):
    dicom = pydicom.dcmread(filepath)
    img = dicom.pixel_array
    img = img[:, :, 0] if len(img.shape) == 3 else img
    img_resized = cv2.resize(img, (256, 256))
    img_normalized = img_resized / img_resized.max()
    return img_normalized.astype(np.float32)

def load_dicom_mask(filepath):
    dicom = pydicom.dcmread(filepath)
    mask = dicom.pixel_array
    mask = mask[:, :, 0] if len(mask.shape) == 3 else mask
    mask_resized = cv2.resize(mask, (256, 256))
    mask_binary = (mask_resized > 50).astype(np.float32)
    return mask_binary


In [None]:
class MRCTDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = load_dicom_image(self.image_paths[idx])
        mask = load_dicom_mask(self.mask_paths[idx])
        if self.transform:
            image = self.transform(image)
            mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        return image, mask


In [None]:
# Paths to datasets
root_dir = '/content/drive/MyDrive/z4wc364g79-1/JUH_MR-CT_dataset'
mr_image_paths = sorted(glob(os.path.join(root_dir, 'MR/image_MR/*.dcm')))
mr_mask_paths = sorted(glob(os.path.join(root_dir, 'MR/mask_MR/*.dcm')))
ct_image_paths = sorted(glob(os.path.join(root_dir, 'CT/image_CT/*.dcm')))
ct_mask_paths = sorted(glob(os.path.join(root_dir, 'CT/mask_CT/*.dcm')))

# Split dataset into training and validation sets
train_mr_image_paths, val_mr_image_paths, train_mr_mask_paths, val_mr_mask_paths = train_test_split(
    mr_image_paths, mr_mask_paths, test_size=0.2, random_state=42
)
train_ct_image_paths, val_ct_image_paths, train_ct_mask_paths, val_ct_mask_paths = train_test_split(
    ct_image_paths, ct_mask_paths, test_size=0.2, random_state=42
)


In [None]:
# Define augmentation transformations
augmentation_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation(30),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Parameters
num_augmentations_mri = 40  # For around 2500 images
num_augmentations_ct = 40  # For around 2500 images

# Custom dataset class to apply augmentations multiple times
class AugmentedDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform, num_augmentations):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.num_augmentations = num_augmentations

    def __len__(self):
        return len(self.image_paths) * self.num_augmentations

    def __getitem__(self, idx):
        original_idx = idx % len(self.image_paths)
        image = load_dicom_image(self.image_paths[original_idx])
        mask = load_dicom_mask(self.mask_paths[original_idx])

        if self.transform:
            augmented = self.transform(image)
            mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        return augmented, mask

# Create augmented datasets
augmented_mri_dataset = AugmentedDataset(train_mr_image_paths, train_mr_mask_paths, augmentation_transforms, num_augmentations_mri)
augmented_ct_dataset = AugmentedDataset(train_ct_image_paths, train_ct_mask_paths, augmentation_transforms, num_augmentations_ct)

# Create DataLoaders
batch_size = 8
augmented_mri_dataloader = DataLoader(augmented_mri_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
augmented_ct_dataloader = DataLoader(augmented_ct_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Check dataset sizes
print(f"Number of augmented MRI images: {len(augmented_mri_dataset)}")
print(f"Number of augmented CT images: {len(augmented_ct_dataset)}")


Number of augmented MRI images: 2880
Number of augmented CT images: 2800


In [None]:
# Create validation datasets without augmentation
simple_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

val_mr_dataset = MRCTDataset(val_mr_image_paths, val_mr_mask_paths, transform=simple_transform)
val_ct_dataset = MRCTDataset(val_ct_image_paths, val_ct_mask_paths, transform=simple_transform)

val_mr_dataloader = DataLoader(val_mr_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
val_ct_dataloader = DataLoader(val_ct_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


In [None]:
def check_consistency(image_paths, mask_paths, dataset_name):
    print(f"Number of {dataset_name} images: {len(image_paths)}")
    print(f"Number of {dataset_name} masks: {len(mask_paths)}")
    assert len(image_paths) == len(mask_paths), f"Mismatch in number of images and masks for {dataset_name}"

# Check consistency of original datasets
check_consistency(mr_image_paths, mr_mask_paths, "MRI")
check_consistency(ct_image_paths, ct_mask_paths, "CT")

# Check consistency of training datasets
check_consistency(train_mr_image_paths, train_mr_mask_paths, "Training MRI")
check_consistency(train_ct_image_paths, train_ct_mask_paths, "Training CT")

# Check consistency of validation datasets
check_consistency(val_mr_image_paths, val_mr_mask_paths, "Validation MRI")
check_consistency(val_ct_image_paths, val_ct_mask_paths, "Validation CT")


Number of MRI images: 90
Number of MRI masks: 90
Number of CT images: 88
Number of CT masks: 88
Number of Training MRI images: 72
Number of Training MRI masks: 72
Number of Training CT images: 70
Number of Training CT masks: 70
Number of Validation MRI images: 18
Number of Validation MRI masks: 18
Number of Validation CT images: 18
Number of Validation CT masks: 18


In [None]:
def check_image_sizes(dataloader, dataset_name):
    images, masks = next(iter(dataloader))
    print(f"Shape of {dataset_name} images: {images.shape}")
    print(f"Shape of {dataset_name} masks: {masks.shape}")

# Check image sizes for augmented datasets
check_image_sizes(augmented_mri_dataloader, "Augmented MRI")
check_image_sizes(augmented_ct_dataloader, "Augmented CT")

# Check image sizes for validation datasets
check_image_sizes(val_mr_dataloader, "Validation MRI")
check_image_sizes(val_ct_dataloader, "Validation CT")


Shape of Augmented MRI images: torch.Size([8, 1, 256, 256])
Shape of Augmented MRI masks: torch.Size([8, 1, 256, 256])
Shape of Augmented CT images: torch.Size([8, 1, 256, 256])
Shape of Augmented CT masks: torch.Size([8, 1, 256, 256])
Shape of Validation MRI images: torch.Size([8, 1, 256, 256])
Shape of Validation MRI masks: torch.Size([8, 1, 256, 256])
Shape of Validation CT images: torch.Size([8, 1, 256, 256])
Shape of Validation CT masks: torch.Size([8, 1, 256, 256])


Step 1: Defining the CycleGAN architetcure for Image Synthesis.    

Implementation of ResNet-based Generator and PatchGAN Discriminator.

Ablation Study - II

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# ResNet-based generator
class ResNetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(ResNetGenerator, self).__init__()
        model = models.resnet18(pretrained=True)
        self.input_nc = input_nc
        self.output_nc = output_nc

        # Modify the first convolution layer to accept 1-channel input
        if input_nc != 3:
            self.model = model
            self.model.conv1 = nn.Conv2d(input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False)
        else:
            self.model = model

        # Extract layers except the last fully connected layer
        self.resnet_layers = nn.Sequential(*list(self.model.children())[:-2])

        # Additional layers to upsample to the original image size
        self.upsample_layers = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, output_nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.resnet_layers(x)
        x = self.upsample_layers(x)
        return x

# PatchGAN discriminator
class PatchGANDiscriminator(nn.Module):
    def __init__(self, input_nc):
        super(PatchGANDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x):
        return self.model(x)


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# VGG-based feature extractor
class VGGFeatureExtractor(nn.Module):
    def __init__(self, layers):
        super(VGGFeatureExtractor, self).__init__()
        vgg = models.vgg19(pretrained=True).features
        self.layers = layers
        self.vgg = nn.ModuleList([vgg[i] for i in layers])

    def forward(self, x):
        features = []
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i in self.layers:
                features.append(x)
        return features

# Instantiate the feature extractor
feature_extractor = VGGFeatureExtractor([3, 8, 17]).to(device)





In [None]:
!pip install torchmetrics



Ablation Study - III (Incorporating SSIM Loss in the Training Loop)

Ablation Study - II (Training Loop without SSIM Loss)

In [None]:
!pip install torchmetrics




In [None]:
import torch
import torchmetrics.image as tmi

# Define SSIM and PSNR metrics
ssim_metric = tmi.StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
psnr_metric = tmi.PeakSignalNoiseRatio(data_range=1.0).to(device)

# Function to compute SSIM and PSNR using torchmetrics
def compute_ssim_psnr(real_img, fake_img):
    ssim_score = ssim_metric(fake_img, real_img)
    psnr_score = psnr_metric(fake_img, real_img)
    return ssim_score.item(), psnr_score.item()

In [None]:
def evaluate_model(generator, dataloader, device):
    ssim_scores = []
    psnr_scores = []

    generator.eval()

    with torch.no_grad():
        for real_img, _ in dataloader:
            real_img = real_img.to(device)
            fake_img = generator(real_img)

            # Compute SSIM and PSNR
            ssim_score, psnr_score = compute_ssim_psnr(real_img, fake_img)

            ssim_scores.append(ssim_score)
            psnr_scores.append(psnr_score)

    # Compute average SSIM and PSNR
    avg_ssim = np.mean(ssim_scores)
    avg_psnr = np.mean(psnr_scores)

    return avg_ssim, avg_psnr

Result - Ablation Study - II

In [None]:
# Define validation dataloaders
val_mr_dataloader = DataLoader(val_mr_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
val_ct_dataloader = DataLoader(val_ct_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Evaluate MRI to CT generator
avg_ssim_mri_to_ct, avg_psnr_mri_to_ct = evaluate_model(G_MRI_to_CT, val_mr_dataloader, device)
print(f"MRI to CT - SSIM: {avg_ssim_mri_to_ct}, PSNR: {avg_psnr_mri_to_ct}")

# Evaluate CT to MRI generator
avg_ssim_ct_to_mri, avg_psnr_ct_to_mri = evaluate_model(G_CT_to_MRI, val_ct_dataloader, device)
print(f"CT to MRI - SSIM: {avg_ssim_ct_to_mri}, PSNR: {avg_psnr_ct_to_mri}")

MRI to CT - SSIM: 0.38917426268259686, PSNR: 5.718809286753337
CT to MRI - SSIM: 0.5082509318987528, PSNR: 8.227468490600586


In [None]:
import torch
import torchmetrics.image as tmi

# Define SSIM and PSNR metrics
ssim_metric = tmi.StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
psnr_metric = tmi.PeakSignalNoiseRatio(data_range=1.0).to(device)

# Function to compute SSIM and PSNR using torchmetrics
def compute_ssim_psnr(real_img, fake_img):
    ssim_score = ssim_metric(fake_img, real_img)
    psnr_score = psnr_metric(fake_img, real_img)
    return ssim_score.item(), psnr_score.item()


In [None]:
def evaluate_model(generator, dataloader, device):
    ssim_scores = []
    psnr_scores = []

    generator.eval()

    with torch.no_grad():
        for real_img, _ in dataloader:
            real_img = real_img.to(device)
            fake_img = generator(real_img)

            # Compute SSIM and PSNR
            ssim_score, psnr_score = compute_ssim_psnr(real_img, fake_img)

            ssim_scores.append(ssim_score)
            psnr_scores.append(psnr_score)

    # Compute average SSIM and PSNR
    avg_ssim = np.mean(ssim_scores)
    avg_psnr = np.mean(psnr_scores)

    return avg_ssim, avg_psnr


Result - Ablation Study - III

In [None]:
# Define validation dataloaders
val_mr_dataloader = DataLoader(val_mr_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
val_ct_dataloader = DataLoader(val_ct_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Evaluate MRI to CT generator
avg_ssim_mri_to_ct, avg_psnr_mri_to_ct = evaluate_model(G_MRI_to_CT, val_mr_dataloader, device)
print(f"MRI to CT - SSIM: {avg_ssim_mri_to_ct}, PSNR: {avg_psnr_mri_to_ct}")

# Evaluate CT to MRI generator
avg_ssim_ct_to_mri, avg_psnr_ct_to_mri = evaluate_model(G_CT_to_MRI, val_ct_dataloader, device)
print(f"CT to MRI - SSIM: {avg_ssim_ct_to_mri}, PSNR: {avg_psnr_ct_to_mri}")


MRI to CT - SSIM: 0.4133886396884918, PSNR: 5.906961441040039
CT to MRI - SSIM: 0.5284612476825714, PSNR: 8.28885793685913
