In [None]:
import torch
from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt
import cv2
from monai.utils import first, set_determinism
from monai.transforms import Compose, LoadImage, ToTensor, ScaleIntensity, CenterSpatialCrop, Resize, EnsureChannelFirst, RandAffined, SaveImage, Rotate90
from monai.data import CacheDataset, DataLoader, Dataset, ArrayDataset
import os
from torch.utils.data import ConcatDataset,random_split
from monai.config import print_config
import nibabel as nib
import numpy as np
# Print MONAI / PyTorch / optional-dependency versions so the environment is recorded in the notebook output.
print_config()

In [None]:
# Seed Python, NumPy and PyTorch RNGs (via MONAI) so sampling and splits are reproducible.
set_determinism(seed=42)

### Loading trained model 

In [None]:
# 2-D single-channel diffusion UNet; the architecture must match the one used to train the checkpoint.
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 256),  # larger variant tried: (256, 256, 512)
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=256,
)
# Fall back to CPU instead of crashing when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Other checkpoints that were evaluated:
#   "Models/bs16_Epoch124_of_2503nov"  (epochs 74 / 124 / 174 also saved)
#   "Models/bs16_Epoch149_of_2008nov_timestep500"
#   "Models/bs8_Epoch149_of_2008nov"
modelname = "Models/bs16_Epoch149_of_2503nov"
# map_location lets a checkpoint saved on GPU load on whichever device is in use.
pre_trained_model = torch.load(modelname, map_location=device)
# strict=False silently skips mismatched keys, which would leave those layers randomly
# initialised; report any mismatch instead of hiding it.
load_result = model.load_state_dict(pre_trained_model, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("WARNING: checkpoint does not match the model architecture")
    print("  missing keys:", load_result.missing_keys)
    print("  unexpected keys:", load_result.unexpected_keys)
model.to(device)


scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler)

In [None]:
from Metrics import *

In [None]:
model.eval()
noise = torch.randn((1, 1, 128, 128)).to(device)
scheduler.set_timesteps(num_inference_steps=1000)
with autocast(enabled=True):
    image, intermediates = inferer.sample(
        input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=200
    )

# Place the saved denoising steps side by side along the width axis.
chain = torch.cat(intermediates, dim=-1)

plt.style.use("default")
# Create the figure BEFORE drawing: calling plt.figure() after imshow (as before) opened a
# second, empty figure and the requested figsize never applied to the chain.
plt.figure(figsize=(30, 10))
plt.imshow(chain[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()

### Loading dataset

In [None]:
# Per-volume pipeline: read the file, put channels first, convert to a tensor and resize
# in-plane to 128x128 while keeping the original number of slices (-1).
_volume_steps = [
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    ToTensor(),
    Resize(spatial_size=(128, 128, -1)),
]
transform = Compose(_volume_steps)

class NiFTIDataset(Dataset):
    """Dataset with one item per directory entry in ``data_dir``.

    Each item is the entry's full path passed through ``transform`` when one is given
    (so the transform must accept a file path); without a transform the path is returned.
    """

    def __init__(self, data_dir, transform = None):
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary, platform-dependent order; sort so index i
        # (and therefore any seeded random_split over this dataset) refers to the same file
        # on every machine and run.
        self.data = sorted(os.listdir(data_dir))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``transform(path)`` for the ``index``-th entry, or the path if no transform."""
        nifti_file = os.path.join(self.data_dir, self.data[index])
        if self.transform is not None:
            nifti_file = self.transform(nifti_file)
        return nifti_file

# Per-slice pipeline: min-max scale intensities to [0, 1], then rotate by three quarter-turns
# over spatial axes (0, 1) — presumably a (C, H, W) slice; confirm against extract_slices.
_slice_steps = [
    ScaleIntensity(minv=0.0, maxv=1.0),
    Rotate90(k=3, spatial_axes=(0, 1), lazy=False),
]
image_transforms = Compose(_slice_steps)

def extract_slices(nifti_dataset, margin=3):
    """Flatten a dataset of volumes into a dataset of individually transformed 2-D slices.

    Each item of ``nifti_dataset`` is indexed as ``[C, H, W, D]`` (slices along the last
    axis). The first and last ``margin`` slices of every volume are dropped and each kept
    slice is passed through ``image_transforms``.

    Args:
        nifti_dataset: any indexable dataset of volumes (e.g. ``NiFTIDataset`` or a
            ``Subset`` of one produced by ``random_split``).
        margin: number of slices discarded at each end of every volume (default 3, the
            value that used to be hard-coded).

    Returns:
        A ``Dataset`` of ``[C, H, W]`` slices, ordered by volume and then by slice.
    """
    slices = []
    for i in range(len(nifti_dataset)):
        # Index directly; wrapping the dataset in a new Dataset on every iteration was redundant.
        image_stack = nifti_dataset[i]
        depth = image_stack.shape[3]
        # Transform only the slices that are kept (edge slices used to be transformed and then discarded).
        slices.extend(image_transforms(image_stack[:, :, :, k]) for k in range(margin, depth - margin))
    # One flat Dataset instead of a ConcatDataset nested once per volume, whose item lookup
    # cost grew with the number of volumes.
    return Dataset(slices)

nifti_dataset = NiFTIDataset(data_dir="T2_images", transform=transform)
# Sanity check: peak intensity of the first slice of the first (untransformed-per-slice) volume.
first_volume = nifti_dataset[0]
print(torch.amax(first_volume[0, :, :, 0]))

In [None]:
nifti_dataset = extract_slices(nifti_dataset)

In [None]:
print((nifti_dataset.__getitem__(0).shape))
plt.imshow(nifti_dataset.__getitem__(0)[0], cmap = "bone")
plt.colorbar()
print(torch.amax(nifti_dataset.__getitem__(6)))

### Saving preprocessed dataset to folder

In [None]:
# Requires `nifti_dataset` to still hold whole volumes (not slices).
real_images_to_save = extract_slices(nifti_dataset=nifti_dataset)

In [None]:
# Set to True to write every preprocessed slice to disk as a NIfTI file.
SAVE_REAL_IMAGES = False
REAL_IMAGES_DIR = "Real_images/Real_training_data"

print(len(real_images_to_save))
print(real_images_to_save[0].shape)

if SAVE_REAL_IMAGES:
    os.makedirs(REAL_IMAGES_DIR, exist_ok=True)
# One summary line instead of two prints per slice, which flooded the output.
max_intensity = float("-inf")
# enumerate supplies the file index (replaces a manual counter whose increment was commented out).
for i, real_image in enumerate(real_images_to_save):
    slice_arr = real_image.numpy()[0]
    max_intensity = max(max_intensity, float(np.amax(slice_arr)))
    if SAVE_REAL_IMAGES:
        nifti_image = nib.Nifti1Image(slice_arr, np.eye(4))
        nib.save(nifti_image, os.path.join(REAL_IMAGES_DIR, "nifti_file_" + str(i) + ".nii"))
print("max intensity over all slices:", max_intensity)

In [None]:
# Display the first preprocessed slice (channel 0) and report its peak intensity.
first_real = real_images_to_save[0][0]
plt.imshow(first_real, cmap="bone")
plt.colorbar()
print(np.amax(first_real))

In [None]:
def hold_out(train_ratio, nifti_dataset):
    """Randomly split volumes into train/validation sets, then slice each set.

    The split is done on whole items of ``nifti_dataset`` *before* slicing, so (assuming
    each item is one patient's volume) no patient contributes slices to both sets.

    Args:
        train_ratio: fraction of items assigned to training; the count is truncated with int().
        nifti_dataset: dataset of volumes accepted by ``extract_slices``.

    Returns:
        ``(train_slices, val_slices)`` as returned by ``extract_slices``.
    """
    n_volumes = len(nifti_dataset)
    n_train = int(train_ratio * n_volumes)
    train_volumes, val_volumes = random_split(nifti_dataset, [n_train, n_volumes - n_train])
    return extract_slices(train_volumes), extract_slices(val_volumes)

In [None]:
# 80/20 patient-level split, sliced afterwards.
train, val = hold_out(train_ratio=0.8, nifti_dataset=nifti_dataset)

In [None]:
first_train_slice = train[0]
print(first_train_slice.shape)
print(len(val))

### Sampling images

In [None]:
N_SYNTHETIC = 40  # number of images sampled in one batch
SYNTHETIC_DIR = "Synthetic_images/bs16_150epochs_22_nov_larger_dataset"
# Added to every file index — presumably so this run does not overwrite files from earlier runs; confirm.
FILE_INDEX_OFFSET = 400

# Input noise of shape (n, 1, 128, 128) yields n synthetic images.
noise = torch.randn((N_SYNTHETIC, 1, 128, 128)).to(device)
scheduler.set_timesteps(num_inference_steps=1000)

images = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)
print(len(images))
plt.figure()
plt.imshow(images[0, 0].cpu(), vmin=0, vmax=1, cmap="bone")
plt.colorbar()
plt.axis("off")
plt.show()

# nib.save does not create missing directories.
os.makedirs(SYNTHETIC_DIR, exist_ok=True)
for i in range(len(images)):
    # .dtype, not .type: the latter is a method and printed as "<built-in method ...>".
    print(images[i, 0].shape, images[i, 0].dtype)
    numpy_arr = images[i, 0].detach().cpu().numpy()
    nifti_image = nib.Nifti1Image(numpy_arr, np.eye(4))
    nib.save(nifti_image, SYNTHETIC_DIR + "/nifti_file_" + str(i + FILE_INDEX_OFFSET) + ".nii")

In [None]:
# Power spectrum of the last sampled synthetic image, intensities scaled to 0-255.
# (`numpy_arr_scaled` was only ever defined in a commented-out line, so this cell raised NameError.)
synthetic_scaled = 255 * images[-1, 0].detach().cpu().numpy()
f_transform = np.fft.fft2(synthetic_scaled)
print(f_transform.shape)
f_transform_shifted = np.fft.fftshift(f_transform)
power_spectrum = np.abs(f_transform_shifted) ** 2
plt.imshow(np.log1p(power_spectrum), cmap='gray')
plt.title('Fourier transform of synthetic image')
plt.show()

In [None]:
# Intensity histogram of the last sampled synthetic image, scaled to 0-255
# (computed here because `numpy_arr_scaled` was never defined).
synthetic_scaled = 255 * images[-1, 0].detach().cpu().numpy()
plt.hist(synthetic_scaled.flatten(), bins=256)
plt.title("Histogram for synthetic image")
plt.show()

In [None]:
# Show samples 1 and 2 (when present), each in its own figure. The former trailing
# plt.figure() with nothing drawn on it left an empty figure in the output.
for sample_idx in range(1, min(3, len(images))):
    plt.figure()
    plt.imshow(images[sample_idx, 0].cpu(), vmin=0, vmax=1, cmap="bone")
plt.show()


In [None]:
# Keep 4 random validation slices as the real reference set; the remainder is discarded.
n_real = 4
_, real_images = random_split(val, lengths=[len(val) - n_real, n_real])

In [None]:
from monai.data import Dataset

In [None]:
# NOTE(review): torch.hub.load downloads and executes code from the Warvito/radimagenet-models
# GitHub repository; pin a tag/commit if reproducibility or supply-chain trust matters.
radnet = torch.hub.load("Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True)
radnet = radnet.to(device)
radnet.eval()

In [None]:
print(real_images[0].shape)
# real_images is a torch Subset, which has no `.type` attribute (AttributeError before);
# print the container class instead.
print(type(real_images))

In [None]:
def _to_rgb_batch(img):
    """(C, H, W) tensor -> (1, 3*C, H, W) tensor scaled by 255, on the same device.

    Uses torch.repeat_interleave (same result as np.repeat(..., axis=0)) because np.repeat
    has to convert a CUDA tensor to a NumPy array, which fails.
    """
    rgb = torch.repeat_interleave(img, 3, dim=0)
    return torch.unsqueeze(rgb * 255, dim=0)


print(images[0].shape)
batched_image = _to_rgb_batch(images[0])
print("final shape", batched_image.shape)
print(batched_image.dtype)

# NOTE(review): despite its name this "real" image is another synthetic sample (images[1]).
real_image = images[1]
print(real_image.shape)
print(real_image.dtype)
d = _to_rgb_batch(real_image)
print(d.shape)
print(d.dtype)

In [None]:
# NOTE(review): calculate_FID comes from the local Metrics module (star-imported earlier).
# Its expected input format was left unfinished ("a RGB image on the form []"); presumably
# (N, 3, H, W) batches like batched_image / d — TODO confirm against Metrics.
fid = calculate_FID(batched_image, d, device)

In [None]:
import numpy as np
from scipy.linalg import sqrtm

# Define functions to calculate mean and covariance of feature embeddings

def _fid_features(model, batch):
    """Run ``model`` on ``batch`` without autograd and return a (N, D) NumPy feature matrix."""
    with torch.no_grad():
        features = model(batch)
    # Flatten everything after the batch axis so conv-shaped outputs (N, C, 1, 1) also work.
    features = features.reshape(features.shape[0], -1)
    return features.detach().cpu().numpy()


def calculate_fid(model, real_images, generated_images):
    """Fréchet distance between Gaussian fits of the model's outputs on two batches.

    FID = ||mu_r - mu_g||^2 + Tr(C_r + C_g - 2 (C_r C_g)^(1/2)).

    Args:
        model: callable mapping a batch tensor to a per-sample output tensor; should be in
            eval mode (e.g. torchvision Inception returns a tuple in train mode).
        real_images: batch passed to ``model`` as the reference set.
        generated_images: batch passed to ``model`` as the generated set.

    Returns:
        The FID as a NumPy float. Needs at least two samples per set for a defined covariance.
    """
    real_feature_activations = _fid_features(model, real_images)
    generated_feature_activations = _fid_features(model, generated_images)

    mu_real = np.mean(real_feature_activations, axis=0)
    mu_generated = np.mean(generated_feature_activations, axis=0)
    # atleast_2d: np.cov returns a 0-d array for a single feature, which sqrtm rejects.
    cov_real = np.atleast_2d(np.cov(real_feature_activations, rowvar=False))
    cov_generated = np.atleast_2d(np.cov(generated_feature_activations, rowvar=False))

    cov_sqrt = sqrtm(cov_real.dot(cov_generated))
    # Numerical error can leave a tiny imaginary part; it carries no meaning here.
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    fid = np.sum((mu_real - mu_generated) ** 2) + np.trace(cov_real + cov_generated - 2 * cov_sqrt)

    return fid

# Load a pre-trained InceptionV3 model. eval(): in train mode torchvision's Inception returns an
# InceptionOutputs tuple (no .detach()), and it must live on the same device as its inputs.
inception_model = torch.hub.load('pytorch/vision', 'inception_v3', pretrained=True)
inception_model = inception_model.to(device).eval()


def _to_inception_input(batch):
    """(N, C, H, W) greyscale batch -> (N, 3*C, 299, 299) float32 batch on `device`."""
    batch = batch.float().to(device)
    batch = batch.repeat_interleave(3, dim=1)
    # NOTE(review): no ImageNet mean/std normalisation is applied — TODO confirm whether needed.
    return torch.nn.functional.interpolate(batch, size=(299, 299), mode="bilinear", align_corners=False)


# real_images is a Subset of (C, H, W) slices, not a tensor; stack it into one batch.
real_batch = torch.from_numpy(np.stack([np.asarray(real_images[k]) for k in range(len(real_images))]))
# NOTE(review): with only a handful of real samples and 1000-D outputs the covariance is
# rank-deficient, so this FID is a rough indication at best.
fid_score = calculate_fid(inception_model, _to_inception_input(real_batch), _to_inception_input(images))
print(f'FID Score: {fid_score}')


In [None]:
# Sanity check of torchmetrics' FID on two random uint8 RGB batches (its documented example).
# FrechetInceptionDistance expects 3-channel images; it does not work directly on greyscale.
from torchmetrics.image.fid import FrechetInceptionDistance

_ = torch.manual_seed(123)
fid = FrechetInceptionDistance(feature=64)
# generate two slightly overlapping image intensity distributions
imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
print(imgs_dist2.shape)
# .dtype, not .type (which is a method and printed as "<built-in method ...>").
print(imgs_dist2.dtype)

# The first update had no image argument at all (syntax error); the real set is imgs_dist1.
fid.update(imgs_dist1, real=True)
fid.update(imgs_dist2, real=False)
fid.compute()

In [None]:
def _to_uint8_rgb_batch(img):
    """(C, H, W) float tensor -> (1, 3*C, H, W) uint8 tensor on the CPU.

    Values are clamped to [0, 1] before scaling: the diffusion output is not guaranteed to lie
    in that range, and casting out-of-range floats to uint8 is not guaranteed to saturate.
    The result is on the CPU because matplotlib, Tensor.numpy() and a default (CPU)
    torchmetrics metric all need CPU tensors.
    """
    rgb = torch.repeat_interleave(img.detach().cpu(), 3, dim=0)  # torch op; np.repeat fails on CUDA tensors
    scaled = (rgb.clamp(0, 1) * 255).round()
    return torch.unsqueeze(scaled.to(torch.uint8), dim=0)


print(images[0].shape)
batched_image = _to_uint8_rgb_batch(images[0])
print("final shape", batched_image.shape)
print(batched_image.dtype)

# NOTE(review): despite its name this "real" image is another synthetic sample (images[1]).
real_image = images[1]
print(real_image.shape)
print(real_image.dtype)
d = _to_uint8_rgb_batch(real_image)
print(d.shape)
print(d.dtype)

In [None]:
# .cpu(): matplotlib can only convert CPU tensors to arrays (no-op if already on the CPU).
plt.imshow(batched_image[0, 0].cpu(), vmin=0, vmax=255)
plt.colorbar()

In [None]:
# Round-trip through NumPy to confirm the uint8 dtype survives; Tensor.numpy() requires a
# CPU tensor, so move it first (no-op if already on the CPU).
real = d.detach().cpu().numpy()
print(real.dtype)
real = torch.from_numpy(real)
print(real.dtype)

In [None]:
# FrechetInceptionDistance expects uint8 RGB batches (N, 3, H, W); it does not work on greyscale.
from torchmetrics.image.fid import FrechetInceptionDistance

_ = torch.manual_seed(123)
fid = FrechetInceptionDistance(feature=64)

# update() takes one tensor, not a list: concatenate along the batch axis. The metric lives on
# the CPU, so the inputs must too.
fid.update(torch.cat([batched_image.cpu(), real.cpu()], dim=0), real=True)
fid.update(real.cpu(), real=False)
# NOTE(review): torchmetrics requires more than one sample in each distribution, so compute()
# is expected to raise with a single fake image — TODO feed more samples.
fid.compute()

In [None]:
first_val_slice = val[0]
print(first_val_slice.shape)