In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from preprocessing import train_loader, val_loader

Number of MRI images: 1744
Number of CT images: 1742
Shape of an MRI image array: (280, 212, 3)
Shape of a CT image array: (512, 512)
Shape of a resized MRI image array: (256, 256)
Shape of a resized CT image array: (256, 256)
Minimum pixel value in a normalized MRI image: 0.0
Maximum pixel value in a normalized MRI image: 0.8392156862745098
Minimum pixel value in a normalized CT image: 0.0
Maximum pixel value in a normalized CT image: 1.0
Number of paired data samples: 1742
Number of training samples: 1219
Number of validation samples: 261
Number of test samples: 262


In [2]:
# Define the architecture for the shared encoder
class SharedEncoder(nn.Module):
    """Convolutional encoder mapping a 1-channel image to a flat feature vector.

    Three 3x3, stride-2, padding-1 convolutions (1 -> 64 -> 128 -> 256 channels,
    each followed by ReLU) are flattened and projected by a linear layer.

    Args:
        final_image_size: Side length of the square input image. The linear
            layer is sized for a ``final_image_size // 8`` feature map, which
            matches the convolution output only when the size is a multiple of 8.
        latent_dim: Length of the returned feature vector.
    """

    def __init__(self, final_image_size, latent_dim):
        super().__init__()
        # Each convolution maps H -> ceil(H / 2) (kernel 3, stride 2, padding 1).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        side = final_image_size // 8
        self.fc = nn.Linear(256 * side * side, latent_dim)

        self.final_image_size = final_image_size
        self.latent_dim = latent_dim

    def forward(self, x):
        """Encode ``x`` (assumed (batch, 1, H, W)) into a (batch, latent_dim) tensor."""
        h = x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h))
        return self.fc(h.view(h.shape[0], -1))


In [3]:
class SharedDecoder(nn.Module):
    """Decoder mapping a ``2 * latent_dim`` vector back to a 1-channel image.

    A linear layer expands the input to a 256-channel feature map of side
    ``final_image_size // 8``. Three transposed convolutions
    (256 -> 128 -> 64 -> 1 channels) each double the spatial size. A final
    sigmoid squashes the output into (0, 1).

    The input width is ``2 * latent_dim`` because ConditionalINN feeds the
    decoder the encoder features concatenated with an equally sized noise vector.

    Args:
        final_image_size: Side length of the square output image (a multiple
            of 8 gives an output of exactly this size).
        latent_dim: Length of the encoder's feature vector.
    """

    def __init__(self, final_image_size, latent_dim):
        super().__init__()
        self.final_image_size = final_image_size
        self.latent_dim = latent_dim

        side = final_image_size // 8
        self.fc = nn.Linear(latent_dim * 2, 256 * side * side)

        # Each layer exactly doubles H and W: (H - 1) * 2 - 2 * 1 + 3 + 1 == 2 * H.
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, **upsample)
        self.conv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, **upsample)
        self.conv1 = nn.ConvTranspose2d(in_channels=64, out_channels=1, **upsample)

    def forward(self, x):
        """Decode ``x`` of shape (batch, 2 * latent_dim) into a (batch, 1, H, W) image."""
        side = self.final_image_size // 8
        h = self.fc(x).view(x.size(0), 256, side, side)
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv2(h))
        return torch.sigmoid(self.conv1(h))


In [4]:
# Define the cINN model
class ConditionalINN(nn.Module):
    """Conditional MRI-to-CT translation model built on a shared encoder/decoder.

    As implemented here this is an encoder-decoder with noise injection, not an
    invertible network. The input is encoded to ``latent_dim`` features, those
    are concatenated with an equally sized standard-normal noise vector, and the
    result is decoded back to an image.

    Args:
        final_image_size: Side length of the square input/output images; passed
            to both the encoder and the decoder.
        latent_dim: Length of the encoder's feature vector. The decoder consumes
            ``2 * latent_dim`` values (features + noise).
    """

    def __init__(self, final_image_size, latent_dim):
        super(ConditionalINN, self).__init__()
        self.shared_encoder = SharedEncoder(final_image_size, latent_dim)
        self.shared_decoder = SharedDecoder(final_image_size, latent_dim)

    def forward(self, x, direction):
        """Translate a batch of images in the requested ``direction``.

        Args:
            x: Input batch, presumably (batch, 1, H, W) with
                H == W == final_image_size (TODO confirm against the data loader).
            direction: ``0`` for MRI-to-CT. ``1`` is reserved for another
                translation direction that is not implemented yet.

        Returns:
            The decoded image batch (values in (0, 1)) for ``direction == 0``.

        Raises:
            NotImplementedError: If ``direction == 1``.
            ValueError: For any other ``direction``.
        """
        if direction == 0:  # MRI to CT translation
            encoded_features = self.shared_encoder(x)
            # Fresh noise is drawn on every call (also in eval mode), so outputs
            # are stochastic unless torch's RNG is seeded.
            noise = torch.randn(
                x.size(0),
                self.shared_encoder.latent_dim,
                device=encoded_features.device,
                dtype=encoded_features.dtype,
            )
            combined_input = torch.cat((encoded_features, noise), dim=1)
            return self.shared_decoder(combined_input)
        if direction == 1:
            # Fail loudly instead of returning None, which would only surface
            # later as an unrelated error inside the loss or metric.
            raise NotImplementedError("direction=1 translation is not implemented yet")
        raise ValueError(f"Unsupported direction {direction!r}; expected 0 (MRI to CT)")

In [5]:
# Instantiate the cINN model
final_image_size = 256  # side length of the square input/output images (must be a multiple of 8)
latent_dim = 64         # length of the encoder's feature vector; the decoder input is 2 * latent_dim
SEED = 0

# Seed torch before building the model so that weight initialisation and the
# decoder's noise input are reproducible. DataLoader shuffling is presumably
# covered too, if train_loader draws from torch's default generator.
torch.manual_seed(SEED)
cinn = ConditionalINN(final_image_size, latent_dim)

In [6]:
# Define a function to calculate PSNR
def psnr(img1, img2, max_pixel=1.0):
    """Peak signal-to-noise ratio (dB) between two tensors.

    The MSE is taken over *all* elements of the inputs at once. For a batch,
    this is the PSNR of the whole batch, not the mean of per-image PSNRs.

    Args:
        img1, img2: Tensors of the same shape. Note that mismatched shapes are
            broadcast by the subtraction. Integer tensors (e.g. uint8) are
            converted to float first, so the difference cannot wrap around.
        max_pixel: Maximum possible pixel value (1.0 for images in [0, 1],
            255.0 for 8-bit images).

    Returns:
        The PSNR as a Python float, or ``inf`` when the inputs are identical.
    """
    # A metric must not extend the autograd graph of its inputs.
    with torch.no_grad():
        a = img1.detach()
        b = img2.detach()
        if not a.is_floating_point():
            a = a.float()
        if not b.is_floating_point():
            b = b.float()
        mse = torch.mean((a - b) ** 2)
        if mse == 0:
            return float("inf")
        return (20 * torch.log10(max_pixel / torch.sqrt(mse))).item()

In [7]:
# Define training loop
# NOTE(review): this cell is dead code. The function below is wrapped in a
# triple-quoted string, so `train_epoch` is never defined; running the cell only
# evaluates (and echoes) the string. The loop that actually trains the model is
# written inline in the training cell further down. Either delete this cell or
# turn it into a real function and call it from the training cell.
'''def train_epoch(model, dataloader, optimizer, criterion):
    model.train()
    for batch_mri, batch_ct in dataloader:
        optimizer.zero_grad()
        # Perform MRI-to-CT translation
        generated_ct = model(batch_mri, direction=0)
        # Calculate the loss (e.g., reconstruction loss)
        loss = criterion(generated_ct, batch_ct)

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()'''

'def train_epoch(model, dataloader, optimizer, criterion):\n    model.train()\n    for batch_mri, batch_ct in dataloader:\n        optimizer.zero_grad()\n        # Perform MRI-to-CT translation\n        generated_ct = model(batch_mri, direction=0)\n        # Calculate the loss (e.g., reconstruction loss)\n        loss = criterion(generated_ct, batch_ct)\n\n        # Backpropagation and optimization\n        loss.backward()\n        optimizer.step()'

In [8]:
# Define training parameters
# NOTE(review): the saved training output shows "Epoch [1/1]", so it was
# produced with a different value than 20 — re-run the notebook to refresh it.
epochs = 20
learning_rate = 1e-3

# Pixel-wise mean squared error between generated and target CT images.
criterion = nn.MSELoss()
# Adam over every parameter of the cINN (shared encoder + shared decoder).
optimizer = torch.optim.Adam(params=cinn.parameters(), lr=learning_rate)

In [9]:
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one MRI-to-CT pass over ``loader`` and return ``(mean_loss, mean_psnr)``.

    When ``optimizer`` is given, the model is put in train mode and every batch
    is back-propagated and stepped. Otherwise the model is put in eval mode and
    evaluated with autograd disabled.

    Both metrics are averaged per sample: each batch is weighted by its size,
    so a smaller final batch does not skew the epoch average. An empty loader
    yields ``(nan, nan)`` instead of a ZeroDivisionError.

    Relies on the notebook-level ``psnr`` helper.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    psnr_sum = 0.0
    n_samples = 0
    with torch.set_grad_enabled(training):
        for batch_mri, batch_ct in loader:
            batch_size = batch_mri.size(0)
            if training:
                optimizer.zero_grad()
            # direction=0 is the MRI-to-CT path of ConditionalINN.
            # NOTE(review): the model injects fresh noise on every call, so
            # validation metrics are stochastic unless torch's RNG is seeded.
            generated_ct = model(batch_mri, direction=0)
            loss = criterion(generated_ct, batch_ct)
            if training:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item() * batch_size
            # Detach so the metric does not extend the autograd graph.
            psnr_sum += psnr(generated_ct.detach(), batch_ct) * batch_size
            n_samples += batch_size
    if n_samples == 0:
        return float("nan"), float("nan")
    return loss_sum / n_samples, psnr_sum / n_samples


model = cinn
for epoch in range(epochs):
    # Training pass (updates the weights).
    avg_loss, avg_psnr = _run_epoch(model, train_loader, criterion, optimizer)
    print(f"Epoch [{epoch+1}/{epochs}] - Training Loss: {avg_loss:.4f}, Training PSNR: {avg_psnr:.4f}")

    # Validation pass (eval mode, no gradients).
    avg_val_loss, avg_val_psnr = _run_epoch(model, val_loader, criterion)
    print(f"Epoch [{epoch+1}/{epochs}] - Validation Loss: {avg_val_loss:.4f}, Validation PSNR: {avg_val_psnr:.4f}")


Epoch [1/1] - Training Loss: 0.0761, Training PSNR: 11.5130
Epoch [1/1] - Validation Loss: 0.0624, Validation PSNR: 12.0570


In [10]:
from PIL import Image
from torchvision import transforms

In [11]:

# Resize to the model's image size and convert to single-channel grayscale.
target_size = (final_image_size, final_image_size)
# Close the file handle once the resized copy has been made.
with Image.open("dataset/images/trainB/mri12.jpg") as mri_image:
    input_mri = mri_image.resize(target_size).convert('L')

# Convert the input image to a PyTorch tensor: ToTensor gives (1, H, W) floats
# in [0, 1]; unsqueeze adds the batch dimension -> (1, 1, H, W).
input_mri = transforms.ToTensor()(input_mri)
input_mri = input_mri.unsqueeze(0)

# Perform MRI-to-CT translation in inference mode. eval() matches the
# validation setting; no_grad() avoids building an autograd graph for the output.
cinn.eval()
with torch.no_grad():
    generated_ct = cinn(input_mri, direction=0)

In [12]:
# You can use or save the 'generated_ct' image as needed
# save the generated image in the new output folder
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision.utils import save_image

In [13]:
# Create the output folder. exist_ok makes this a no-op when the folder already
# exists, without the race of a separate exists() check.
output_folder = "output"
os.makedirs(output_folder, exist_ok=True)

# Save the generated image. Its values are in (0, 1) because the decoder ends
# in a sigmoid, which is the range save_image expects for float tensors.
save_image(generated_ct, os.path.join(output_folder, "generated_ct.png"))