In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from preprocessing import train_loader, train_dataset

Number of MRI images: 1744
Number of CT images: 1742
Shape of an MRI image array: (280, 212, 3)
Shape of a CT image array: (512, 512)
Shape of a resized MRI image array: (256, 256)
Shape of a resized CT image array: (256, 256)
Minimum pixel value in a normalized MRI image: 0.0
Maximum pixel value in a normalized MRI image: 0.8392156862745098
Minimum pixel value in a normalized CT image: 0.0
Maximum pixel value in a normalized CT image: 1.0
Number of paired data samples: 1742
Number of training samples: 1219
Number of validation samples: 261
Number of test samples: 262


In [2]:
# Define the architecture for the shared encoder
class SharedEncoder(nn.Module):
    def __init__(self, final_image_size, latent_dim):
        super(SharedEncoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.fc = nn.Linear(256 * (final_image_size // 8) * (final_image_size // 8), latent_dim)  # Adjust the linear layer input size

        self.final_image_size = final_image_size
        self.latent_dim = latent_dim

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        encoded_features = self.fc(x)
        return encoded_features


In [3]:
class SharedDecoder(nn.Module):
    def __init__(self, final_image_size, latent_dim):
        super(SharedDecoder, self).__init__()
        self.final_image_size = final_image_size
        self.latent_dim = latent_dim

        # Calculate the number of output features from the Linear layer
        num_features = 256 * (final_image_size // 8) * (final_image_size // 8)

        # Adjust the Linear layer to match the encoder's output
        self.fc = nn.Linear(latent_dim*2, num_features)

        self.conv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 256, self.final_image_size // 8, self.final_image_size // 8)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv2(x))
        decoded_image = torch.sigmoid(self.conv1(x))
        return decoded_image


In [4]:
# Define the cINN model
class ConditionalINN(nn.Module):
    def __init__(self, final_image_size, latent_dim):
        super(ConditionalINN, self).__init__()
        self.shared_encoder = SharedEncoder(final_image_size, latent_dim)
        self.shared_decoder = SharedDecoder(final_image_size, latent_dim)
        # Inside the forward method of ConditionalINN
    def forward(self, x, direction):
        if direction == 0:  # MRI to CT translation
            encoded_features = self.shared_encoder(x)
            noise = torch.randn(x.size(0), self.shared_encoder.latent_dim).to(x.device)

            # Check the shapes of encoded_features and noise
            combined_input = torch.cat((encoded_features, noise), dim=1)


            # Debug prints
            print("encoded_features shape:", encoded_features.shape)
            print("noise shape:", noise.shape)
            print("combined_input shape:", combined_input.shape)

            generated_image = self.shared_decoder(combined_input)



            return generated_image
        elif direction == 1:  # Add other translation directions if needed
            # Add code for other translation directions here
            pass

In [5]:
# Instantiate the cINN model
final_image_size = 256 # ...  # Set the image size after convolutions
latent_dim = 64 # ...  # Set the desired latent dimension
cinn = ConditionalINN(final_image_size, latent_dim)

In [6]:
# Define training loop
def train_epoch(model, dataloader, optimizer, criterion):
    model.train()
    for batch_mri, batch_ct in dataloader:
        optimizer.zero_grad()
        # Perform MRI-to-CT translation
        generated_ct = model(batch_mri, direction=0)
        # Calculate the loss (e.g., reconstruction loss)
        loss = criterion(generated_ct, batch_ct)

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()

In [9]:
# Define training parameters
epochs = 10
learning_rate = 0.001
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(cinn.parameters(), lr=learning_rate)

In [10]:
# Training loop
for epoch in range(epochs):
    train_epoch(cinn, train_loader, optimizer, criterion)
    print(f"Epoch [{epoch+1}/{epochs}] completed")

encoded_features shape: torch.Size([32, 64])
noise shape: torch.Size([32, 64])
combined_input shape: torch.Size([32, 128])
encoded_features shape: torch.Size([32, 64])
noise shape: torch.Size([32, 64])
combined_input shape: torch.Size([32, 128])
encoded_features shape: torch.Size([32, 64])
noise shape: torch.Size([32, 64])
combined_input shape: torch.Size([32, 128])
encoded_features shape: torch.Size([32, 64])
noise shape: torch.Size([32, 64])
combined_input shape: torch.Size([32, 128])
encoded_features shape: torch.Size([32, 64])
noise shape: torch.Size([32, 64])
combined_input shape: torch.Size([32, 128])
encoded_features shape: torch.Size([32, 64])
noise shape: torch.Size([32, 64])
combined_input shape: torch.Size([32, 128])
encoded_features shape: torch.Size([32, 64])
noise shape: torch.Size([32, 64])
combined_input shape: torch.Size([32, 128])
encoded_features shape: torch.Size([32, 64])
noise shape: torch.Size([32, 64])
combined_input shape: torch.Size([32, 128])
encoded_features

In [11]:
# After training, you can use the cINN model to perform translation
input_mri = "dataset/images/trainB/mri12.jpg"  # Load or generate an input MRI image for inference
generated_ct = cinn(input_mri, direction=0)  # Perform MRI-to-CT translation

TypeError: conv2d() received an invalid combination of arguments - got (str, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!str!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!str!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)


In [None]:
# You can use or save the 'generated_ct' image as needed
# save the generated image in the new output folder
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision.utils import save_image

In [None]:
# create the output folder
output_folder = "output"
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# save the generated image
save_image(generated_ct, os.path.join(output_folder, "generated_ct.png"))