# EC500 - Group 6 - Denoising CT Images 
**Avantika Kothandaraman, Caiwei Zhang, Long Chen**

## Section-1: Installing the necessary packages

In [None]:
# !pip install pynrrd
# !pip install SimpleITK
# !python -c "import monai" || pip install -q "monai"
# !python -c "import matplotlib" || pip install -q matplotlib
# %matplotlib inline

In [None]:
# !pip install monai --upgrade

In [None]:
# !pip install patchify

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.transforms import v2
import torch.optim as optim
import os
import nibabel as nib
import nrrd
from torchvision.datasets import ImageFolder
import SimpleITK as sitk
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from scipy import ndimage
import tempfile
import shutil
import glob

In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    Activations,
    AsDiscreted,
    EnsureChannelFirstd,
    EnsureChannelFirst,
    Compose,
    AsChannelLastd,
    CropForegroundd,
    ScaleIntensityd,
    LoadImaged,
    Orientationd,
    Spacingd,
    Invertd,
    RandSpatialCropSamplesd,
    RandSpatialCropSamples,
    RandSpatialCropd,
    ScaleIntensityRanged,
    ScaleIntensityRange,
    RandRotated,
    RandFlipd,
    RandZoomd,
    RandScaleIntensityd, 
    RandShiftIntensityd,
    #AddChannel,
    ToTensord,
    NormalizeIntensityd
)
from monai.handlers.utils import from_engine 
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference 
from monai.inferers import SlidingWindowInferer
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch, pad_list_data_collate
from monai.config import print_config
from monai.apps import download_and_extract
from patchify import patchify
import math


In [None]:
# Put the allocator setting into the environment before the first CUDA call in
# this cell, so it is already in place when the caching allocator is configured.
# NOTE(review): presumably this only takes effect if CUDA has not been used yet
# in this process - confirm against the PyTorch CUDA memory-management notes.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Release cached-but-unused GPU memory left over from earlier cells.
torch.cuda.empty_cache()

In [None]:
import gc

# Run a full garbage-collection pass first ...
gc.collect()

# ... then ask PyTorch to release the CUDA memory it is only caching.
torch.cuda.empty_cache()

## Section-2: Initial data inspection and experimentation

In [None]:
# Directory that holds the .nrrd volumes read by the cells below
data_dir = "./scans"

In [None]:
# Scan the dataset directory and check that every .nrrd volume has the same
# number of dimensions, element count and shape.
dims = []
sizes = []
shapes = []

# Sort the listing so the order of `shapes` is the same on every run
# (os.listdir gives no ordering guarantee).
for filename in sorted(os.listdir(data_dir)):
    if not filename.endswith('.nrrd'):
        continue
    img, header = nrrd.read(os.path.join(data_dir, filename))
    dims.append(img.ndim)
    sizes.append(img.size)
    shapes.append(img.shape)

count = len(shapes)

# A collection is uniform when it holds at most one distinct value.
dims_check = len(set(dims)) <= 1
size_check = len(set(sizes)) <= 1
shape_check = len(set(shapes)) <= 1

if count == 0:
    # Without this branch the checks above are vacuously true and an empty
    # directory would be reported as "uniform".
    print('No .nrrd files found in {}'.format(data_dir))
elif dims_check and size_check and shape_check:
    print('Dimensions, shapes and sizes are uniform')
else:
    print('Dimensions, shapes and sizes are NOT uniform')

print('The total number of images in the dataset is {}'.format(count))

In [None]:
# Display the shape recorded for every volume
list(shapes)

In [None]:
def convert_to_2d(img_volume, axis=1):
    """Collapse a volume to 2D by taking the maximum along ``axis``.

    Args:
        img_volume: Array-like volume to project.
        axis: Axis that is reduced away (defaults to 1).

    Returns:
        The maximum-intensity projection, i.e. ``img_volume`` with ``axis``
        removed.
    """
    projection = np.amax(img_volume, axis=axis)
    return projection

# Shared preprocessing: convert to a tensor, then resize to a common 512x512
# grid. Normalize(mean=[0.5], std=[0.5]) is currently not part of the pipeline.
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((512, 512)),
    ]
)

In [None]:
# Load one example volume, project it to 2D, preprocess it and preview it.
volume, header = nrrd.read(os.path.join(data_dir, 'L506_signal.nrrd'))
volume_2d = convert_to_2d(volume)
volume_2d = trans(volume_2d)
# Tensor.type is a method: call it so the type string is printed instead of
# the bound-method repr.
print(volume.shape, volume_2d.squeeze().shape, volume_2d.type())
plt.imshow(volume_2d.squeeze())
plt.show()

In [None]:
# Rebind volume_2d to its NumPy form and show the element dtype.
# NOTE(review): later cells that use volume_2d therefore receive an ndarray, not a tensor.
volume_2d = volume_2d.numpy()
volume_2d.dtype

In [None]:
def patches(image, patch_size=(64, 64), step=64):
    """Split a 2D image into a flat list of patches.

    Args:
        image: Image as a ``torch.Tensor`` or anything ``np.asarray`` accepts.
            Singleton dimensions are squeezed away before patching.
        patch_size: Patch height and width handed to ``patchify``.
        step: Stride between neighbouring patches handed to ``patchify``.

    Returns:
        list: The patches in row-major order of the patch grid.
    """
    if isinstance(image, torch.Tensor):
        # detach/cpu keep this working for tensors that track gradients or
        # live on the GPU; a plain CPU tensor is converted exactly as before.
        array = image.detach().cpu().numpy()
    else:
        array = np.asarray(image)
    array = array.squeeze()

    grid = patchify(array, patch_size, step=step)
    return [
        grid[row, col, :, :]
        for row in range(grid.shape[0])
        for col in range(grid.shape[1])
    ]

## Set deterministic seed for reproducibility

In [None]:
# Fix the seed up front so the randomised steps below are repeatable
set_determinism(seed=0)

## Section-3: Creating a custom dataset and making transforms for augmentation


In [None]:
class CustomData(torch.utils.data.Dataset):
    """Patch dataset pairing synthetic noisy CT inputs with their std maps.

    For every ``*signal.nrrd`` file in ``root_dir`` the matching
    ``<id>_noise.nrrd`` and ``<id>_std.nrrd`` files are read. Each volume is
    reduced to 2D with a maximum along axis 1, resized to 512x512 and cut into
    non-overlapping 64x64 patches. The network input is ``signal + k * noise``
    with ``k`` drawn once per case from ``random.uniform(0, 5)``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, root_dir):
        """Read every case under ``root_dir`` and store its patches.

        Args:
            root_dir: Directory containing the ``*_signal.nrrd``,
                ``*_noise.nrrd`` and ``*_std.nrrd`` files.
        """
        self.root_dir = root_dir
        # One dict per patch: {'ct_generated': ndarray, 'std_map': ndarray}
        self.data = []

        # converting to tensor and resizing to 512,512 for uniformity
        # (built once instead of once per file)
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((512, 512))])

        # sorted() keeps the sample order, and therefore the k drawn for each
        # case, stable for a fixed seed; os.listdir has no defined order.
        for file in sorted(os.listdir(self.root_dir)):
            if not file.endswith('signal.nrrd'):
                continue
            image_id = file.split('_')[0]

            # reading in the images
            signal_nrrd, _ = nrrd.read(os.path.join(self.root_dir, file))
            noise_nrrd, _ = nrrd.read(os.path.join(self.root_dir, f"{image_id}_noise.nrrd"))
            std_nrrd, _ = nrrd.read(os.path.join(self.root_dir, f"{image_id}_std.nrrd"))

            # converting to 2D with a maximum along axis 1
            signal = trans(np.max(signal_nrrd, axis=1))
            noise = trans(np.max(noise_nrrd, axis=1))
            std = trans(np.max(std_nrrd, axis=1))

            # generating input image from signal and noise
            k = random.uniform(0, 5)
            ct_generated = signal + (k * noise)

            # generating patches
            ct_patches = self.patches(ct_generated)
            std_patches = self.patches(std)

            # storing the new dataset as one dictionary per patch
            for ct_patch, std_patch in zip(ct_patches, std_patches):
                self.data.append({'ct_generated': ct_patch, 'std_map': std_patch})

    def data_info(self, idx):
        """Print the shapes of the input patch and std-map patch at ``idx``."""
        item = self.data[idx]
        ct_gen = item['ct_generated']
        std_ma = item['std_map']

        # Print the index of the data item
        print(f"Data item {idx}:")

        # Print the shape of the ct_generated patch
        print(f"ct_generated shape: {ct_gen.shape}")

        # Print the shape of the std_map patch
        print(f"std_map shape: {std_ma.shape}")

        print()

    def plot_ct(self, idx):
        """Show the input patch and the std-map patch at ``idx`` side by side."""
        item = self.data[idx]
        ct_generated = item['ct_generated']
        std_map = item['std_map']

        # Create a figure
        plt.figure(figsize=(5, 5))

        # Left panel: generated (noisy) input patch.
        # The title carries the real index; it used to say "Patch-0" for every idx.
        plt.subplot(1, 2, 1)
        plt.imshow(ct_generated, cmap='gray')
        plt.title(f'ct_generated Patch-{idx}')
        plt.axis('off')

        # Right panel: matching std-map patch
        plt.subplot(1, 2, 2)
        plt.imshow(std_map, cmap='hot')
        plt.title(f'std_map Patch-{idx}')
        plt.axis('off')

        # Show the plots
        plt.show()

    def __len__(self):
        """Return the number of stored patches."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the patch pair at ``idx`` as float tensors with a channel axis.

        Returns:
            dict: ``{'ct_generated': Tensor, 'std_map': Tensor}``, each with a
            leading dimension of size 1 added by ``unsqueeze(0)``.
        """
        # No per-sample printing here: this runs once per sample per epoch, so
        # debug output would flood the training log.
        item = self.data[idx]
        ct_generated = torch.from_numpy(item['ct_generated']).unsqueeze(0).float()
        std_map = torch.from_numpy(item['std_map']).unsqueeze(0).float()
        return {'ct_generated': ct_generated, 'std_map': std_map}

    @staticmethod
    def patches(image):
        """Cut a tensor image into a flat list of non-overlapping 64x64 patches.

        Declared static because it takes no ``self``; ``CustomData.patches(x)``
        keeps working and ``instance.patches(x)`` now works as well.
        """
        image = image.squeeze()
        grid = patchify(image.numpy(), (64, 64), step=64)
        return [
            grid[row, col, :, :]
            for row in range(grid.shape[0])
            for col in range(grid.shape[1])
        ]

In [None]:
# Build the patch dataset from every case found under data_dir
custom_dataset = CustomData(root_dir = data_dir)


In [None]:
# Total number of patches held by the dataset
len(custom_dataset)

In [None]:
# Print the patch shapes stored for sample 5
custom_dataset.data_info(5)

In [None]:
# Visual check of sample 500: generated input next to its std map
custom_dataset.plot_ct(500)

In [None]:
# 80 % of the samples go to training; the held-out 20 % is halved into
# validation and test sets. A fixed random_state keeps the split repeatable.
train_files, remaining_files = train_test_split(
    custom_dataset,
    test_size=0.2,
    random_state=42,
)
val_files, test_files = train_test_split(
    remaining_files,
    test_size=0.5,
    random_state=42,
)

print(*(len(subset) for subset in (train_files, val_files, test_files)))

In [None]:
# transforms for data augmentation and refining

# train_transforms = transforms.Compose([
#     ScaleIntensityd(keys = ['ct_generated', 'std_map']), 
#     RandRotated(keys = ['ct_generated', 'std_map'], range_x=(-np.pi/12, np.pi/12), prob = 0.5, keep_size = True), 
#     RandFlipd(keys = ['ct_generated', 'std_map'], spatial_axis = 0, prob = 0.5), 
#     RandZoomd(keys = ['ct_generated', 'std_map'], zoom = (0.9,1.1), prob = 0.5) 
# ])

# val_transforms = transforms.Compose([
#     ScaleIntensityd(keys = ['ct_generated', 'std_map']) 
# ])

In [None]:
# transforms for data augmentation and refining

# train_transforms = v2.Compose([v2.RandomHorizontalFlip(p=0.5),
#                             v2.RandomVerticalFlip(p=0.5),
#                             v2.RandomRotation(30),
#                             v2.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=10),
#                             v2.ToDtype(torch.float32, scale=True),
#                             #v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#                             v2.Normalize(mean=[0.456], std=[0.224])
# ])

# # val_transforms = v2.Compose([v2.RandomHorizontalFlip(p=0.5),
# #                              v2.RandomVerticalFlip(p=0.5),
# #                             v2.RandomRotation(30),
# #                             v2.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=10),
# #                             v2.ToDtype(torch.float32, scale=True),
# #                             #v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# #                             v2.Normalize(mean=[0.456], std=[0.224])
# # ])

# val_transforms = v2.Compose([
#     v2.Resize((64,64)),  
#     v2.ToTensor(), 
#     v2.Normalize(mean=[0.456], std=[0.224]) 
# ])

## Section-4: Data Loader

In [None]:
# Training split: cached in full (cache_rate=1.0, no transform), served one
# shuffled sample per batch.
train_ds = CacheDataset(data=train_files, transform=None, cache_rate=1.0)
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    collate_fn=pad_list_data_collate,
)

# Sanity check: print the tensor shapes of the first batch only.
for data in train_loader:
    ct_generated, std_map = data['ct_generated'], data['std_map']
    print(ct_generated.shape)
    print(std_map.shape)
    break

# Validation split: same setup, but without shuffling.
val_ds = CacheDataset(data=val_files, transform=None, cache_rate=1.0)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=1,
    collate_fn=pad_list_data_collate,
)

## Section-5: UNet Training and Testing

In [None]:
# Unet with Monte Carlo Dropout

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two 3x3 conv -> batch-norm -> ReLU -> dropout stages applied in sequence.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when not given.
        dropout_prob: Probability used by both dropout layers.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, dropout_prob=0.1):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        # Layers are created in forward order so that parameter names and
        # initialisation order match a hand-written Sequential.
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_prob),
            ])
        self.conv_op = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages."""
        return self.conv_op(x)

class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a DoubleConv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the DoubleConv.
        dropout_prob: Dropout probability forwarded to the DoubleConv.
    """

    def __init__(self, in_channels, out_channels, dropout_prob):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels, dropout_prob=dropout_prob)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and convolve the result."""
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concatenate, DoubleConv.

    The DoubleConv is built for ``in_channels // 2 + out_channels`` input
    channels. NOTE(review): the bilinear branch uses ``nn.Upsample``, which
    does not change the channel count, so that branch only lines up when the
    caller's two tensors already add up to that total - confirm before
    enabling ``bilinear=True``.

    Args:
        in_channels: Channel count used to size the transposed convolution
            and the DoubleConv.
        out_channels: Channels produced by the DoubleConv.
        bilinear: Use bilinear upsampling instead of a transposed convolution.
        dropout_prob: Dropout probability forwarded to the DoubleConv.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, dropout_prob=0.1):
        super().__init__()
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        )
        merged_channels = in_channels // 2 + out_channels
        self.conv = DoubleConv(merged_channels, out_channels, dropout_prob=dropout_prob)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width and fuse both.

        Args:
            x1: Feature map coming from the deeper level.
            x2: Skip-connection feature map; it is placed first in the
                channel concatenation.
        """
        upsampled = self.up(x1)
        # Size difference to make up, split so the extra row/column (if any)
        # goes to the bottom/right.
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        left, top = pad_w // 2, pad_h // 2
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])
        return self.conv(torch.cat([x2, upsampled], dim=1))

class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net with dropout in every DoubleConv, usable for Monte Carlo dropout.

    Four encoder steps (64 -> 128 -> 256 -> 512 -> 1024 channels) are mirrored
    by four decoder steps with skip connections, followed by a 1x1 output
    convolution.

    Args:
        n_channels: Channels of the input image.
        n_classes: Channels of the output map.
        bilinear: Must be False. With True the decoder built from the ``Up``
            block defined alongside this class concatenates 1024 channels in
            ``up1`` but its convolution is sized for 512, so every forward
            pass would fail; the constructor therefore rejects it up front.
        dropout_prob: Dropout probability used throughout the network.

    Raises:
        ValueError: If ``bilinear`` is true.
    """

    # Layer types that mc_dropout_forward switches into sampling mode.
    _DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)

    def __init__(self, n_channels, n_classes, bilinear=False, dropout_prob=0.1):
        super().__init__()
        if bilinear:
            # Fail fast with an explanation instead of a shape error deep
            # inside the first forward pass.
            raise ValueError(
                "bilinear=True is not supported: the decoder channel sizes do "
                "not match the bilinear Up block; use bilinear=False"
            )
        self.inc = DoubleConv(n_channels, 64, dropout_prob=dropout_prob)
        self.down1 = Down(64, 128, dropout_prob=dropout_prob)
        self.down2 = Down(128, 256, dropout_prob=dropout_prob)
        self.down3 = Down(256, 512, dropout_prob=dropout_prob)
        # Always 1 while bilinear is rejected above; the expressions are kept
        # so the layer layout reads the same once bilinear is supported.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor, dropout_prob=dropout_prob)
        self.up1 = Up(1024 // factor, 512 // factor, bilinear, dropout_prob=dropout_prob)
        self.up2 = Up(512 // factor, 256 // factor, bilinear, dropout_prob=dropout_prob)
        self.up3 = Up(256 // factor, 128 // factor, bilinear, dropout_prob=dropout_prob)
        self.up4 = Up(128, 64, bilinear, dropout_prob=dropout_prob)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns:
            The raw output of the final 1x1 convolution.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def mc_dropout_forward(self, x, n_samples=30):
        """Estimate the predictive mean and spread with Monte Carlo dropout.

        Only the dropout layers are put into training mode for the sampling
        passes. Every other layer (notably BatchNorm) keeps the mode the
        caller selected, so calling this on a model in ``eval()`` no longer
        updates BatchNorm running statistics. All dropout layers are restored
        to their previous mode afterwards, even if a forward pass raises.

        Args:
            x: Input batch.
            n_samples: Number of stochastic forward passes.

        Returns:
            tuple: ``(mean, std)`` over the ``n_samples`` outputs, each with
            the shape of a single forward output.
        """
        dropout_layers = [m for m in self.modules() if isinstance(m, self._DROPOUT_TYPES)]
        previous_modes = [layer.training for layer in dropout_layers]
        for layer in dropout_layers:
            layer.train()
        try:
            mc_samples = torch.stack([self.forward(x) for _ in range(n_samples)])
        finally:
            for layer, mode in zip(dropout_layers, previous_modes):
                layer.train(mode)
        mc_mean = mc_samples.mean(dim=0)
        mc_std = mc_samples.std(dim=0)
        return mc_mean, mc_std


In [None]:
#Unet with  Bayesian network
'''
Perform multiple forward passes on the input data. After each propagation, model output is collected.
Calculate the mean and standard deviation of the output, which represents the central tendency and variation of the predictions.
Calculate the loss and KL divergence of each propagation, which measures the information loss of the model parameter distribution. Adding the loss and KL divergence gives ELBO, which I think is a trade-off between the model fitting the data and keeping the parameter distribution simple.
Finally, the average loss (as an estimate of ELBO) and average KL divergence over all propagations are returned, along with the mean and standard deviation of the predictions.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from blitz.modules import BayesianConv2d
from blitz.utils import variational_estimator

@variational_estimator
class DoubleConv(nn.Module):
    """Two Bayesian 3x3 conv -> batch-norm -> ReLU stages applied in sequence.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when not given.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        # Layers are created in forward order so that parameter names and
        # initialisation order match a hand-written Sequential.
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend([
                BayesianConv2d(c_in, c_out, kernel_size=(3,3), padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.conv_op = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages."""
        return self.conv_op(x)

class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a DoubleConv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and convolve the result."""
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concatenate, DoubleConv.

    Args:
        in_channels: Total channels seen at this level. With a transposed
            convolution the deeper feature map has ``in_channels`` channels
            and is halved while upsampling. With bilinear upsampling the
            deeper map and the skip tensor are expected to add up to
            ``in_channels`` after concatenation, which is how the ``UNet``
            defined alongside this class calls it.
        out_channels: Channels produced by the DoubleConv.
        bilinear: Use bilinear upsampling instead of a transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(
            self._conv_in_channels(in_channels, out_channels, bilinear),
            out_channels,
        )

    @staticmethod
    def _conv_in_channels(in_channels, out_channels, bilinear):
        """Return how many channels the DoubleConv receives after concatenation.

        nn.Upsample leaves the channel count unchanged, so in the bilinear
        case the concatenated tensor carries ``in_channels`` channels; sizing
        the convolution for ``in_channels // 2 + out_channels`` there made
        every bilinear forward pass fail. The transposed-convolution case is
        unchanged.
        """
        if bilinear:
            return in_channels
        return in_channels // 2 + out_channels

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width and fuse both.

        Args:
            x1: Feature map coming from the deeper level.
            x2: Skip-connection feature map; it is placed first in the
                channel concatenation.
        """
        x1 = self.up(x1)
        # Size difference to make up, split so the extra row/column (if any)
        # goes to the bottom/right.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """Final Bayesian 1x1 convolution mapping feature channels to output channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = BayesianConv2d(in_channels, out_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net built from Bayesian convolution blocks.

    Four encoder steps (64 -> 128 -> 256 -> 512 -> 1024 // factor channels)
    are mirrored by four decoder steps with skip connections, followed by a
    1x1 output convolution.

    Args:
        n_channels: Channels of the input image.
        n_classes: Channels of the output map.
        bilinear: Forwarded to every ``Up`` block; also halves the channel
            count at the bottleneck and in the first three decoder steps.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns:
            The raw output of the final 1x1 convolution.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def _kl_divergence(self):
        """Sum the KL contribution recorded by the layers in the last forward pass.

        A layer contributes its ``kl_divergence`` attribute when it has one.
        Otherwise, when it exposes both ``log_variational_posterior`` and
        ``log_prior``, it contributes their difference.
        NOTE(review): BLiTZ Bayesian layers appear to expose the latter pair
        rather than ``kl_divergence``, which would have made the KL term
        always zero before - confirm against the installed BLiTZ version.
        """
        kl = 0
        for layer in self.modules():
            if hasattr(layer, 'kl_divergence'):
                kl = kl + layer.kl_divergence
            elif hasattr(layer, 'log_variational_posterior') and hasattr(layer, 'log_prior'):
                kl = kl + (layer.log_variational_posterior - layer.log_prior)
        return kl

    def sample_elbo(self, inputs, labels, criterion, sample_nbr=10, complexity_cost_weight=1.0):
        """Average the loss plus weighted KL term over several stochastic passes.

        Args:
            inputs: Input batch.
            labels: Target passed to ``criterion``.
            criterion: Callable ``criterion(output, labels)`` returning the
                data-fit loss.
            sample_nbr: Number of forward passes.
            complexity_cost_weight: Factor applied to the KL term before it is
                added to the data-fit loss. The default of 1.0 adds it
                unscaled, as the original expression ``loss + kl`` did.

        Returns:
            tuple: ``(result, kl)`` where ``result`` is a dict with the
            per-element mean (``'pred'``) and standard deviation (``'std'``)
            of the sampled outputs and the averaged objective (``'loss'``),
            and ``kl`` is the unweighted KL term averaged over the passes.
        """
        total_loss = 0
        kl_divergence = 0
        predictions = []

        for _ in range(sample_nbr):
            output = self.forward(inputs)
            predictions.append(output)

            loss = criterion(output, labels)
            kl = self._kl_divergence()
            kl_divergence = kl_divergence + kl
            total_loss = total_loss + loss + complexity_cost_weight * kl

        # Stack once and reuse the result for both statistics.
        stacked = torch.stack(predictions)
        mean_preds = stacked.mean(0)
        std_preds = stacked.std(0)
        return {'pred': mean_preds, 'std': std_preds, 'loss': total_loss / sample_nbr}, kl_divergence / sample_nbr


In [None]:
# initializing the model and optimizer

# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single-channel input, single-channel output, trained with Adam
model = UNet(n_channels = 1, n_classes = 1).to(device)
optimizer = optim.Adam(model.parameters(), lr = 0.001)
# NOTE(review): not referenced by the cells shown here - the training loops use their own hard-coded checkpoint names
model_saving_path = './best'


In [None]:
# LOSS Function - Average Relative Error
def average_relative_error(output, target):
    """Mean of ``|output - target| / |target|`` over the non-zero targets.

    Elements whose target is exactly zero are ignored to avoid dividing by
    zero.

    Args:
        output: Predicted tensor.
        target: Reference tensor with the same shape as ``output``.

    Returns:
        A scalar tensor. When ``target`` has no non-zero element the result
        is a zero that is still connected to ``output``, so ``backward()``
        works and the running loss does not turn into NaN (the mean of an
        empty selection is NaN).
    """
    nonzero_mask = target != 0
    if not nonzero_mask.any():
        return output.sum() * 0.0
    selected_target = target[nonzero_mask]
    return torch.mean(torch.abs((output[nonzero_mask] - selected_target) / selected_target))

import matplotlib.pyplot as plt

def plot_losses(train_losses, val_losses):
    """Draw the training and validation loss curves on one figure.

    Args:
        train_losses: Sequence of training losses, one value per epoch.
        val_losses: Sequence of validation losses, one value per epoch.
    """
    _, ax = plt.subplots(figsize=(10, 5))
    for series, label in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
        ax.plot(series, label=label)
    ax.set_title('Training and Validation Losses')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Average Relative Error Loss')
    ax.legend()
    ax.grid(True)
    plt.show()

# test_output_2d = torch.rand(64, 64)
# test_target_2d = torch.rand(64, 64)
# error_2d_single = average_relative_error(test_output_2d, test_target_2d)
# print(type(error_2d_single))

In [None]:

def plot_comparison(ori, std, pred, uncertainty, epoch, save_path):
    """Show input, reference std map, predicted std map and uncertainty side by side.

    Args:
        ori: 2D image shown in the first panel (gray colour map).
        std: 2D reference std map.
        pred: 2D predicted std map.
        uncertainty: 2D uncertainty map.
        epoch: Value substituted for ``{epoch}`` in ``save_path``.
        save_path: Format string for the output file, or a falsy value to
            skip saving. The target directory is created when missing.
    """
    panels = (
        (ori, 'gray', 'Original Image'),
        (std, 'hot', 'STD Image'),
        (pred, 'hot', 'Predicted STD'),
        (uncertainty, 'hot', 'Uncertainty'),
    )
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    for ax, (image, cmap, title) in zip(axs, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    # Save before showing so the written file never depends on what the
    # backend does to the figure in show().
    if save_path:
        formatted_path = save_path.format(epoch=epoch)
        out_dir = os.path.dirname(formatted_path)
        if out_dir:
            # savefig does not create missing folders
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(formatted_path, bbox_inches='tight', dpi=150)

    plt.show()
    # Always release the figure; this is called repeatedly during training.
    plt.close(fig)
        




In [None]:
# MonteCarlo training and validation

train_losses = []
val_losses = []
num_epochs = 50
# Lowest epoch-level validation loss seen so far. Starting at infinity means
# the first finite validation loss always produces a checkpoint.
best_loss = float('inf')
import cv2

# plot_comparison below writes into this folder
os.makedirs("nunet_results", exist_ok=True)

for epoch in range(num_epochs):

    running_loss = 0.0
    running_val_loss = 0.0
    model.train()  # Ensure the model is in training mode

    for idx, images in enumerate(tqdm(train_loader)):
        img = images['ct_generated'].float().to(device)
        std_map = images['std_map'].float().to(device)

        optimizer.zero_grad()

        y_pred = model(img)

        loss = average_relative_error(y_pred, std_map)
        loss.backward()
        optimizer.step()

        # .item() detaches the value, so no computation graph is kept alive
        running_loss += loss.item() * img.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_loss)

    # Validation phase
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        for idx, image in enumerate(tqdm(val_loader)):
            img = image['ct_generated'].float().to(device)
            std_map = image['std_map'].float().to(device)
            mc_mean, mc_std = model.mc_dropout_forward(img, n_samples=50)

            val_loss = average_relative_error(mc_mean, std_map)

            running_val_loss += val_loss.item() * img.size(0)

            # Plot and save every 100th validation sample
            if idx % 100 == 0:
                pred_mean_np = mc_mean.squeeze().cpu().numpy()
                pred_std_np = mc_std.squeeze().cpu().numpy()
                img_np = img.squeeze().cpu().numpy()
                std_np = std_map.squeeze().cpu().numpy()
                plot_comparison(img_np, std_np, pred_mean_np, pred_std_np, epoch=epoch + 1, save_path="nunet_results/epoch_{epoch}.png")

        epoch_val_loss = running_val_loss / len(val_loader.dataset)
        val_losses.append(epoch_val_loss)

    # Checkpoint on the epoch's validation loss. Previously the weights were
    # saved whenever a single training batch scored below the best so far,
    # before the optimizer step, which let one easy batch pick the "best" model.
    if epoch_val_loss < best_loss:
        best_loss = epoch_val_loss
        torch.save(model.state_dict(), 'best_model.pth')

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}")
plot_losses(train_losses, val_losses)

In [None]:
# Bayes training and validation
train_losses = []
val_losses = []
num_epochs = 50
# Lowest epoch-level validation loss seen so far. Starting at infinity means
# the first finite validation loss always produces a checkpoint.
best_loss = float('inf')

# plot_comparison below writes into this folder
os.makedirs("nunet_results", exist_ok=True)

for epoch in range(num_epochs):
    running_loss = 0.0
    running_val_loss = 0.0
    model.train()

    for idx, images in enumerate(tqdm(train_loader)):
        img = images['ct_generated'].float().to(device)
        std_map = images['std_map'].float().to(device)

        optimizer.zero_grad()

        y_pred = model(img)
        loss = average_relative_error(y_pred, std_map)
        loss.backward()
        optimizer.step()

        # .item() detaches the value, so no computation graph is kept alive
        running_loss += loss.item() * img.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_loss)

    model.eval()
    with torch.no_grad():
        for idx, image in enumerate(tqdm(val_loader)):
            img = image['ct_generated'].float().to(device)
            std_map = image['std_map'].float().to(device)

            outputs, kl_divergence = model.sample_elbo(inputs=img, labels=std_map, criterion=average_relative_error, sample_nbr=10)
            val_loss = outputs['loss']

            running_val_loss += val_loss.item() * img.size(0)

            # Plot and save every 100th validation sample
            if idx % 100 == 0:
                # 'pred' and 'std' are already the mean and spread over the
                # sampled passes. Taking .std(0) of 'pred' measured spread
                # across the batch axis instead, which is NaN for batch size 1.
                pred_mean_np = outputs['pred'].squeeze().cpu().numpy()
                pred_std_np = outputs['std'].squeeze().cpu().numpy()
                img_np = img.squeeze().cpu().numpy()
                std_np = std_map.squeeze().cpu().numpy()
                plot_comparison(img_np, std_np, pred_mean_np, pred_std_np, epoch=epoch + 1, save_path="nunet_results/epoch_{epoch}.png")

        epoch_val_loss = running_val_loss / len(val_loader.dataset)
        val_losses.append(epoch_val_loss)

    # Checkpoint on the epoch's validation loss rather than on a single
    # pre-step training batch. The file name matches the 'best_bayes.pth' that
    # the inference cell loads.
    # NOTE(review): presumably that cell is meant to load this run's weights - confirm.
    if epoch_val_loss < best_loss:
        best_loss = epoch_val_loss
        torch.save(model.state_dict(), 'best_bayes.pth')

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}")

plot_losses(train_losses, val_losses)


In [None]:
# Redraw the loss curves from the lists filled by the most recent training run
plot_losses(train_losses, val_losses)

In [None]:

# Build a noisy input for case L506 the same way the dataset does:
# 2D maximum along axis 1, tensor conversion, resize to 512x512.
noise_nrrd, _ = nrrd.read(os.path.join(data_dir, "L506_noise.nrrd"))

noise = np.max(noise_nrrd, axis=1)

trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Resize((512,512))])

noise = trans(noise)
k = random.uniform(0,5)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(n_channels=1, n_classes=1).to(device)
model_path = 'best_bayes.pth'
# map_location lets a checkpoint written on a GPU machine load on a CPU-only
# one (and the other way round).
model.load_state_dict(torch.load(model_path, map_location=device))

# volume_2d may be a NumPy array at this point (an earlier cell rebinds it via
# .numpy()), so convert explicitly instead of relying on mixed ndarray/Tensor
# arithmetic.
volume_2d = torch.as_tensor(volume_2d) + (k*noise)
volume_2d = volume_2d.squeeze().float()
# Add batch and channel axes: (H, W) -> (1, 1, H, W)
volume_2d = volume_2d.unsqueeze(0).unsqueeze(0)
volume_2d = volume_2d.to(device)

model.eval()
with torch.no_grad():
    output = model(volume_2d)

output_cpu = output.cpu().numpy().squeeze()

# Reference std map for the same case, preprocessed identically
stdvolume, stdheader = nrrd.read(os.path.join(data_dir,'L506_std.nrrd'))
stdvolume_2d = convert_to_2d(stdvolume)
stdvolume_2d = trans(stdvolume_2d)

if isinstance(stdvolume_2d, torch.Tensor):
    stdvolume_2d = stdvolume_2d.cpu().numpy()

# Panels: network input, reference std map, predicted std map
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].imshow(volume_2d.cpu().squeeze().numpy(), cmap='gray')
axs[0].set_title('Original Image')
axs[0].axis('off')

axs[1].imshow(stdvolume_2d.squeeze(), cmap='hot')
axs[1].set_title('STD Image')
axs[1].axis('off')

axs[2].imshow(output_cpu, cmap='hot')
axs[2].set_title('Predicted STD')
axs[2].axis('off')

plt.show()

    


## Section-6: RATUNet Training and Testing