In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image 
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import random
import nibabel as nib
from torch.utils.tensorboard import SummaryWriter
import glob as glob

class VolumeDataset(Dataset):
    """Dataset of (low-res, high-res) volume pairs built from files under ``root_dir``.

    Discovery collects ``.mhd`` files from sub-directories of ``root_dir`` whose
    relative path contains "20" but not "80" (the meaning of these markers is not
    visible here — TODO document). Loading also accepts NIfTI paths
    (``.nii`` / ``.nii.gz``) if they are placed in ``volume_files``.

    Each item is produced by min-max normalising the volume to [0, 1], trilinearly
    resampling it to ``high_res_size`` (the target), and resampling that target
    again to ``low_res_size`` (the network input).

    Args:
        root_dir: directory searched recursively with ``os.walk``.
        transform: optional callable applied to the low-resolution tensor only.
        high_res_size: spatial size (D, H, W) of the returned target volume.
        low_res_size: spatial size (D, H, W) of the returned input volume.
    """

    def __init__(self, root_dir, transform=None, high_res_size=(128, 128, 128), low_res_size=(80, 80, 80)):
        self.root_dir = root_dir
        self.transform = transform
        self.high_res_size = tuple(high_res_size)
        self.low_res_size = tuple(low_res_size)
        self.volume_files = []
        for dirpath, _, filenames in os.walk(root_dir):
            # Match on the path relative to root_dir, so digits inside root_dir itself
            # (e.g. "/data/2020/raw") cannot include or exclude every sub-directory.
            rel_dir = os.path.relpath(dirpath, root_dir)
            if "20" not in rel_dir or "80" in rel_dir:
                continue
            self.volume_files.extend(
                os.path.join(dirpath, f) for f in filenames if f.endswith('.mhd')
            )
        # os.walk order depends on the filesystem; sorting keeps index -> file stable,
        # which a seeded random_split needs in order to reproduce the same partition.
        self.volume_files.sort()

    def __len__(self):
        return len(self.volume_files)

    @staticmethod
    def _normalize(volume):
        """Min-max scale ``volume`` to [0, 1]; a constant volume becomes all zeros (not NaN)."""
        lo, hi = volume.min(), volume.max()
        span = hi - lo
        if span == 0:
            return torch.zeros_like(volume)
        return (volume - lo) / span

    def __getitem__(self, idx):
        """Return ``(low_res_volume, high_res_volume)``, each shaped (1, D, H, W).

        Raises:
            ValueError: if the file extension is neither NIfTI nor ``.mhd``.
        """
        volume_path = self.volume_files[idx]
        if volume_path.endswith(('.nii', '.nii.gz')):
            volume = nib.load(volume_path).get_fdata()
        elif volume_path.endswith('.mhd'):
            volume = sitk.GetArrayFromImage(sitk.ReadImage(volume_path))
        else:
            raise ValueError(f"Unsupported volume format: {volume_path}")

        # Add a channel axis -> (1, D, H, W), then scale intensities to [0, 1].
        volume = self._normalize(torch.from_numpy(volume).float().unsqueeze(0))
        # F.interpolate expects a batch axis, hence the unsqueeze(0)/squeeze(0) pairs.
        high_res_volume = F.interpolate(volume.unsqueeze(0), size=self.high_res_size, mode='trilinear', align_corners=False).squeeze(0)
        low_res_volume = F.interpolate(high_res_volume.unsqueeze(0), size=self.low_res_size, mode='trilinear', align_corners=False).squeeze(0)

        if self.transform:
            low_res_volume = self.transform(low_res_volume)

        return low_res_volume, high_res_volume

def random_crop(volume, crop_size):
    """Randomly crop a (C, X, Y, Z) tensor to spatial size ``crop_size`` = (cx, cy, cz).

    On each spatial axis where the volume is at least as large as the crop, a start
    offset is drawn uniformly with ``torch.randint``. On each axis where it is smaller,
    the whole axis is kept and zero-padded at the end. The result therefore always
    has spatial shape ``crop_size``.

    Returns the input tensor object unchanged when its spatial shape already equals
    ``crop_size``.

    Raises:
        ValueError: if ``volume`` does not have exactly three spatial axes after the
            leading channel axis.
    """
    size_x, size_y, size_z = volume.shape[1:]
    spatial_shape = (int(size_x), int(size_y), int(size_z))
    target = tuple(int(s) for s in crop_size)

    if spatial_shape == target:
        return volume

    starts = []
    for size, want in zip(spatial_shape, target):
        slack = size - want
        # An offset is drawn whenever the crop fits (slack == 0 still makes one draw).
        # Otherwise start at 0 and pad the shortfall below.
        starts.append(int(torch.randint(0, slack + 1, (1,))) if slack >= 0 else 0)

    (x0, y0, z0), (cx, cy, cz) = starts, target
    cropped = volume[:, x0:x0 + cx, y0:y0 + cy, z0:z0 + cz]

    # F.pad lists pads from the LAST dimension backwards: (z_lo, z_hi, y_lo, y_hi, x_lo, x_hi).
    pad_x, pad_y, pad_z = (want - got for want, got in zip(target, cropped.shape[1:]))
    if pad_x or pad_y or pad_z:
        cropped = torch.nn.functional.pad(cropped, (0, pad_z, 0, pad_y, 0, pad_x))

    return cropped

# Build the dataset and split it 80/10/10 into train / validation / test.
# The split uses a fixed-seed generator so every kernel restart yields the same
# partition. Without it, each run draws fresh subsets, and the "test" volumes
# evaluated below may overlap whatever data a saved checkpoint was trained on.
SPLIT_SEED = 42
root_dir = "raw"
dataset = VolumeDataset(root_dir=root_dir)
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size  # remainder, so the three sizes sum to len(dataset)
print(f"Train size: {train_size}, Validation size: {val_size}, Test size: {test_size}")
if val_size == 0:
    # int(0.1 * n) is 0 for n < 10, which leaves val_loader without any batches.
    print("Warning: validation split is empty (fewer than 10 volumes); val_loader yields no batches.")
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)
num_train_images = len(train_dataset)
num_val_images = len(val_dataset)
num_test_images = len(test_dataset)

print(f"Number of images in Training set: {num_train_images}")
print(f"Number of images in Validation set: {num_val_images}")
print(f"Number of images in Test set: {num_test_images}")
num_train_batches = len(train_loader)
num_val_batches = len(val_loader)
num_test_batches = len(test_loader)

print(f"Number of batches in Training set: {num_train_batches}")
print(f"Number of batches in Validation set: {num_val_batches}")
print(f"Number of batches in Test set: {num_test_batches}")


class FSRCNN_3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = nn.Conv3d(in_channels=1, out_channels=56, kernel_size=5, padding=2)
        self.conv_2 = nn.Conv3d(in_channels=56, out_channels=12, kernel_size=1, padding=0)
        self.conv_3 = nn.Conv3d(in_channels=12, out_channels=12, kernel_size=3, padding=1)
        self.conv_4 = nn.Conv3d(in_channels=12, out_channels=12, kernel_size=3, padding=1)
        self.conv_5 = nn.Conv3d(in_channels=12, out_channels=12, kernel_size=3, padding=1)
        self.conv_6 = nn.Conv3d(in_channels=12, out_channels=12, kernel_size=3, padding=1)
        self.conv_7 = nn.Conv3d(in_channels=12, out_channels=56, kernel_size=1, padding=0)
        self.de_conv_1 = nn.ConvTranspose3d(in_channels=56, out_channels=1, kernel_size=9, stride=3, padding=3, output_padding=0)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        output = self.conv_1(x)
        output = F.relu(output)
        output = self.conv_2(output)
        output = F.relu(output)
        output = self.conv_3(output)
        output = F.relu(output)
        output = self.conv_4(output)
        output = F.relu(output)
        output = self.conv_5(output)
        output = F.relu(output)
        output = self.conv_6(output)
        output = F.relu(output)
        output = self.conv_7(output)
        output = self.dropout(output)
        output = self.de_conv_1(output)
        output = F.relu(output)
        return output


# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# DataParallel wraps the network, so every state_dict key carries a "module." prefix.
model = nn.DataParallel(FSRCNN_3D()).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)


def calculate_metrics(output, target):
    """Return mean ``(ssim, psnr, mse)`` between predicted and target volume batches.

    Args:
        output: predicted batch. Assumed (B, 1, D, H, W), or any shape whose
            trailing three axes are one volume (TODO confirm for multi-channel data).
        target: ground-truth batch with the same shape as ``output``.

    Returns:
        ``ssim`` and ``psnr`` are averaged over individual volumes, each using that
        target's own ``max - min`` as ``data_range``. ``mse`` is averaged over every voxel.
    """
    # Flatten to (N, D, H, W) rather than squeeze(). squeeze() also drops the batch
    # axis when B == 1, after which zip() would iterate over 2-D slices instead of volumes.
    # detach() lets this accept tensors that still require grad.
    output_np = output.detach().cpu().numpy().reshape(-1, *output.shape[-3:])
    target_np = target.detach().cpu().numpy().reshape(-1, *target.shape[-3:])
    # channel_axis=None: each volume is one single-channel 3-D image. SSIM is then taken
    # over all three spatial axes, rather than treating the last spatial axis as channels.
    ssim_val = np.mean([ssim(o, t, data_range=t.max() - t.min(), channel_axis=None, win_size=5)
                        for o, t in zip(output_np, target_np)])
    psnr_val = np.mean([psnr(t, o, data_range=t.max() - t.min()) for o, t in zip(output_np, target_np)])
    mse_val = np.mean((output_np - target_np) ** 2)

    return ssim_val, psnr_val, mse_val

def visualize_images(original, downsampled, output):
    """Plot the middle depth slice of three single-sample volumes side by side.

    Args:
        original: high-resolution target, indexed as (C, D, H, W).
        downsampled: low-resolution network input, same layout (its depth may differ).
        output: model prediction, same layout.
    """
    panels = (
        (original, 'Original Image'),
        (downsampled, 'Low Resolution Image'),
        (output, 'Model Output'),
    )
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (volume, title) in zip(axes, panels):
        # Each volume uses its own middle slice, because the low-res input and the model
        # output can have a different depth from the target.
        mid_slice = volume.shape[1] // 2
        # (C, H, W) -> (H, W, C). Drop a singleton channel so imshow receives a 2-D image.
        image = volume[:, mid_slice].detach().cpu().permute(1, 2, 0).squeeze(-1).numpy()
        ax.imshow(image, cmap="gray")
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    # Release the figure; this function is called once per batch inside evaluation loops.
    plt.close(fig)


Train size: 6, Validation size: 0, Test size: 2
Number of images in Training set: 6
Number of images in Validation set: 0
Number of images in Test set: 2
Number of batches in Training set: 2
Number of batches in Validation set: 0
Number of batches in Test set: 1
cuda


In [2]:
def test_model(checkpoint_path='model3/SRCNN3D_epoch_100.pth', max_batches=20):
    """Evaluate the global ``model`` on ``test_loader`` and print mean SSIM / PSNR / MSE.

    The function loads the weights at ``checkpoint_path`` into ``model`` and runs up to
    ``max_batches`` test batches without gradients. Each prediction is resized to its
    target's spatial size. Per-batch metrics from ``calculate_metrics`` are averaged
    over the batches actually evaluated. One example slice per batch is plotted with
    ``visualize_images``.

    ``load_state_dict`` is strict: a checkpoint saved from a different architecture
    raises RuntimeError listing the missing/unexpected keys.
    NOTE(review): the saved output of this cell shows checkpoint keys
    ``conv1``/``conv2``/``conv3``. That does not match ``FSRCNN_3D`` — point
    ``checkpoint_path`` at a matching checkpoint.
    """
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # weights_only avoids unpickling arbitrary objects from the file.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    ssim_total, psnr_total, mse_total = 0.0, 0.0, 0.0
    num_batches = 0  # calculate_metrics returns per-batch means, so average over batches

    with torch.no_grad():
        for batch_idx, (low_res_inputs, high_res_targets) in enumerate(test_loader):
            if batch_idx >= max_batches:
                break
            low_res_inputs = low_res_inputs.to(device)
            high_res_targets = high_res_targets.to(device)
            outputs = model(low_res_inputs)
            # The network's output size differs from the target's; resample to compare voxel-wise.
            outputs_resized = F.interpolate(outputs, size=high_res_targets.shape[2:], mode='trilinear', align_corners=False)

            ssim_val, psnr_val, mse_val = calculate_metrics(outputs_resized, high_res_targets)
            ssim_total += ssim_val
            psnr_total += psnr_val
            mse_total += mse_val
            num_batches += 1

            visualize_images(high_res_targets[0], low_res_inputs[0], outputs_resized[0])

    if num_batches == 0:
        print('Test Results - no test batches to evaluate')
        return

    avg_ssim = ssim_total / num_batches
    avg_psnr = psnr_total / num_batches
    avg_mse = mse_total / num_batches

    print(f'Test Results - SSIM: {avg_ssim:.4f}, PSNR: {avg_psnr:.4f}, MSE: {avg_mse:.4f}')

test_model()

  model.load_state_dict(torch.load('model3/SRCNN3D_epoch_100.pth'))


RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.conv_1.weight", "module.conv_1.bias", "module.conv_2.weight", "module.conv_2.bias", "module.conv_3.weight", "module.conv_3.bias", "module.conv_4.weight", "module.conv_4.bias", "module.conv_5.weight", "module.conv_5.bias", "module.conv_6.weight", "module.conv_6.bias", "module.conv_7.weight", "module.conv_7.bias", "module.de_conv_1.weight", "module.de_conv_1.bias". 
	Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias", "module.conv2.weight", "module.conv2.bias", "module.conv3.weight", "module.conv3.bias". 

In [None]:
def test_model(checkpoint_path='model3/SRCNN3D_epoch_100.pth', max_batches=20, max_plots=10, scale_factor=0.5):
    """Visually compare test volumes with a blurred copy and the model's reconstruction.

    For each test batch, up to ``max_batches``:
    1. Take the high-resolution volume.
    2. Downsample it by ``scale_factor`` and trilinearly upsample it back.
    3. Feed the result through the global ``model``.
    4. Resize the prediction to the original size.

    The first ``max_plots`` batches are plotted with ``visualize_images``. No metrics are computed.

    NOTE(review): this cell redefines ``test_model``. It shadows the metric-computing
    version defined in the previous cell; consider renaming one of them.
    """
    # map_location / weights_only: load GPU-saved weights on any device without
    # unpickling arbitrary objects.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            if batch_idx >= max_batches:
                break
            # VolumeDataset items are (low_res, high_res) pairs, so the default collate
            # yields a list of two tensors (which has no ``.to``). Use the high-res
            # volume as the reference; a loader yielding bare tensors is also accepted.
            inputs = data[1] if isinstance(data, (list, tuple)) else data
            inputs = inputs.to(device)
            inputs_downsampled = F.interpolate(inputs, scale_factor=scale_factor, mode='trilinear', align_corners=False)
            inputs_upsampled = F.interpolate(inputs_downsampled, size=inputs.shape[2:], mode='trilinear', align_corners=False)
            outputs = model(inputs_upsampled)
            outputs_resized = F.interpolate(outputs, size=inputs.shape[2:], mode='trilinear', align_corners=False)

            if batch_idx < max_plots:
                visualize_images(inputs[0], inputs_upsampled[0], outputs_resized[0])
test_model()

  model.load_state_dict(torch.load('model3/SRCNN3D_epoch_100.pth'))


RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.conv_1.weight", "module.conv_1.bias", "module.conv_2.weight", "module.conv_2.bias", "module.conv_3.weight", "module.conv_3.bias", "module.conv_4.weight", "module.conv_4.bias", "module.conv_5.weight", "module.conv_5.bias", "module.conv_6.weight", "module.conv_6.bias", "module.conv_7.weight", "module.conv_7.bias", "module.de_conv_1.weight", "module.de_conv_1.bias". 
	Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias", "module.conv2.weight", "module.conv2.bias", "module.conv3.weight", "module.conv3.bias". 