In [1]:
from monai import data, transforms
import glob
import numpy as np
import os
import re
import natsort
import SimpleITK as sitk

def get_loader(data_dir='/workspace/PD_SSL_ZOO/3_RECONSTRUCTION/DATA', pattern='*.nii.gz'):
    """Build an unshuffled, batch-size-1 DataLoader over every volume matching ``pattern`` in ``data_dir``.

    Args:
        data_dir: directory scanned for input volumes (defaults to the original hard-coded path).
        pattern: glob pattern matched inside ``data_dir``.

    Returns:
        (loader, train_real): the DataLoader and the naturally sorted list of file paths.
        Because ``shuffle=False``, the loader yields volumes in the same order as ``train_real``,
        which callers rely on to map each prediction back to its input file.
    """
    train_real = natsort.natsorted(glob.glob(os.path.join(data_dir, pattern)))  # ALL -> 2125 or 2130

    print("Train [Total]  number = ", len(train_real))

    # NOTE(review): each item is a 1-tuple of paths, kept from the original zip(train_real);
    # LoadImage presumably treats a one-element sequence like a single path — confirm before simplifying.
    files_tr = [(path,) for path in train_real]

    tr_transforms = transforms.Compose(
        [
            transforms.LoadImage(image_only=True),
            transforms.EnsureChannelFirst(),
            transforms.Orientation(axcodes="RAI"),
            transforms.EnsureType(),
            # track_meta=False: the loader yields plain tensors without image metadata.
            transforms.ToTensor(track_meta=False)
        ]
    )

    # Plain (uncached) Dataset; data.CacheDataset is a drop-in alternative if memory allows.
    train_ds = data.Dataset(data=files_tr, transform=tr_transforms)

    train_loader = data.DataLoader(
        train_ds,
        batch_size=1,
        shuffle=False,  # order must match train_real (see Returns)
        num_workers=2,
        pin_memory=False
        # persistent_workers=True,
    )

    print("loader is ver (train)")

    return train_loader, train_real


In [None]:
import torch
from generative.networks.nets.diffusion_model_aniso_unet_AE_no_wavelet import DiffusionModelUNet_aniso_AE_no_wavelet, DiffusionModelEncoder_ansio_no_wavelet

class HDAE(torch.nn.Module):
    """Holds the conditional diffusion U-Net and its semantic encoder as submodules.

    No ``forward`` is defined; the inference code calls ``self.semantic_encoder`` and
    ``self.unet`` directly.
    """

    def __init__(self):
        super().__init__()

        # Denoising U-Net: six levels, attention enabled only on the last two.
        unet_config = dict(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            num_channels=[8, 32, 64, 128, 256, 512],
            attention_levels=[False, False, False, False, True, True],
            num_head_channels=[0, 0, 0, 0, 16, 32],
            norm_num_groups=8,
            use_flash_attention=True,
            iso_conv_down=(False, True, True, True, True, None),
            iso_conv_up=(True, True, True, True, False, None),
            num_res_blocks=2,
        )
        self.unet = DiffusionModelUNet_aniso_AE_no_wavelet(**unet_config)

        # Semantic encoder: five levels, no attention anywhere.
        encoder_config = dict(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            num_channels=[16, 32, 128, 256, 512],
            attention_levels=(False, False, False, False, False),
            num_head_channels=[0, 0, 0, 0, 0],
            norm_num_groups=16,
            iso_conv_down=(False, True, True, True, True),
            resblock_updown=False,
            num_res_blocks=(2, 2, 2, 2, 2),
        )
        self.semantic_encoder = DiffusionModelEncoder_ansio_no_wavelet(**encoder_config)

def filter_ema_keys(checkpoint):
    """Turn the EMA entry of a checkpoint into a plain state dict for the wrapped model.

    Keys containing 'online_model' are dropped. A leading 'ema_model.' prefix is stripped from
    the remaining keys. The bookkeeping entries 'initted' and 'step' are removed if present.

    Args:
        checkpoint: mapping of state-dict names to values (here, ``ckpt['ema']``).

    Returns:
        A new dict that can be passed to ``model.load_state_dict``; ``checkpoint`` is not modified.
    """
    prefix = 'ema_model.'
    ema_model_state_dict = {}
    for key, value in checkpoint.items():
        if 'online_model' in key:
            continue
        # Strip only a leading prefix; str.replace would also rewrite occurrences mid-key.
        if key.startswith(prefix):
            key = key[len(prefix):]
        ema_model_state_dict[key] = value

    # Not model weights; tolerate checkpoints that do not carry them instead of raising KeyError.
    for bookkeeping_key in ('initted', 'step'):
        ema_model_state_dict.pop(bookkeeping_key, None)

    return ema_model_state_dict

def save_image(pred, path, flip_z=True):
    """Write a single-volume prediction tensor to ``path`` with SimpleITK.

    NOTE(review): ``save_image`` is defined again in a later cell, without the axis-2 flip, and
    that later definition shadows this one on Run All. Confirm which orientation is correct
    and delete the other definition.

    Args:
        pred: tensor of shape (1, 1, H, W, D) — one sample, one channel.
        path: output file path; the format is chosen by SimpleITK from the extension.
        flip_z: also flip the last array axis (this definition's original behavior).

    Raises:
        ValueError: if ``pred`` is not shaped (1, 1, H, W, D). ``.squeeze()`` previously
            collapsed singleton spatial dims. With batch > 1 it also made ``flipud`` flip the
            batch axis.
    """
    if pred.ndim != 5 or pred.shape[0] != 1 or pred.shape[1] != 1:
        raise ValueError(f"expected a (1, 1, H, W, D) tensor, got shape {tuple(pred.shape)}")

    # (H, W, D) -> (D, W, H): GetImageFromArray reads array axes in reverse (z, y, x) order.
    pred_img = pred.detach().cpu().numpy()[0, 0].transpose(2, 1, 0)
    pred_img = np.flipud(pred_img)   # flip array axis 0
    pred_img = np.fliplr(pred_img)   # flip array axis 1
    if flip_z:
        pred_img = np.flip(pred_img, axis=2)

    # The flips are negative-stride views; hand SimpleITK a contiguous buffer.
    save_pred = sitk.GetImageFromArray(np.ascontiguousarray(pred_img))
    # NOTE(review): no spacing/origin/direction is set, so the written volume carries
    # SimpleITK defaults rather than the source image's geometry.
    sitk.WriteImage(save_pred, path)
    
# Build the model here; the inference cell below loads the EMA weights into it and moves it to CUDA.
model = HDAE()

In [16]:
def save_image(pred, path, flip_z=False):
    """Write a single-volume prediction tensor to ``path`` with SimpleITK.

    NOTE(review): this redefinition shadows an earlier ``save_image`` that also flips array
    axis 2. ``flip_z`` makes the difference explicit. Confirm which orientation is correct
    and delete the duplicate definition.

    Args:
        pred: tensor of shape (1, 1, H, W, D) — one sample, one channel.
        path: output file path; the format is chosen by SimpleITK from the extension.
        flip_z: also flip the last array axis (off by default, matching this definition).

    Raises:
        ValueError: if ``pred`` is not shaped (1, 1, H, W, D). ``.squeeze()`` previously
            collapsed singleton spatial dims. With batch > 1 it also made ``flipud`` flip the
            batch axis.
    """
    if pred.ndim != 5 or pred.shape[0] != 1 or pred.shape[1] != 1:
        raise ValueError(f"expected a (1, 1, H, W, D) tensor, got shape {tuple(pred.shape)}")

    # (H, W, D) -> (D, W, H): GetImageFromArray reads array axes in reverse (z, y, x) order.
    pred_img = pred.detach().cpu().numpy()[0, 0].transpose(2, 1, 0)
    pred_img = np.flipud(pred_img)   # flip array axis 0
    pred_img = np.fliplr(pred_img)   # flip array axis 1
    if flip_z:
        pred_img = np.flip(pred_img, axis=2)

    # The flips are negative-stride views; hand SimpleITK a contiguous buffer.
    save_pred = sitk.GetImageFromArray(np.ascontiguousarray(pred_img))
    # NOTE(review): no spacing/origin/direction is set, so the written volume carries
    # SimpleITK defaults rather than the source image's geometry.
    sitk.WriteImage(save_pred, path)

In [None]:
import torch
from generative.networks.schedulers import DDPMScheduler
from generative.inferers import DiffusionInferer_ae
import pywt
import ptwt

# Inference configuration.
a_path = '/workspace/PD_SSL_ZOO/2_DOWNSTREAM/WEIGHTS/6_HDAE.pt'
OUTPUT_DIR_NAME = "6_HDAE"   # predictions go to a sibling of the input directory
NUM_INFERENCE_STEPS = 1000
SEED = 0                     # fixes the sampling noise so re-runs reproduce the outputs
DEVICE = "cuda"

# torch.load unpickles the file: only point it at trusted checkpoints.
ckpt = torch.load(a_path, map_location='cpu')
checkpoint = ckpt['ema']
new_ckpt = filter_ema_keys(checkpoint)
model.load_state_dict(new_ckpt)
print(f"ckpt_path : {a_path}")

scheduler = DDPMScheduler(prediction_type="epsilon",
                          num_train_timesteps=1000,
                          clip_sample=True,
                          schedule="scaled_linear_beta")
# The sampling schedule is the same for every volume, so configure it once.
scheduler.set_timesteps(num_inference_steps=NUM_INFERENCE_STEPS)

inferer = DiffusionInferer_ae(scheduler)

model.to(DEVICE)
model.eval()
loader, train_list = get_loader()

torch.manual_seed(SEED)
# get_loader() uses shuffle=False, so batches arrive in the same order as train_list.
for input_path, batch in zip(train_list, loader):
    # Sampling is inference too: keep all 1000 denoising steps out of autograd.
    with torch.no_grad():
        images = batch.to(DEVICE)
        # The semantic latent of the real volume conditions the reverse diffusion.
        latent = model.semantic_encoder(images)

        # Start from Gaussian noise shaped like the input. It is drawn on the CPU, so a fixed
        # SEED yields the same noise regardless of which GPU runs the sampler.
        noise = torch.randn(images.shape).to(DEVICE)
        image_pred = inferer.sample(input_noise=noise,
                                    diffusion_model=model.unet,
                                    scheduler=scheduler,
                                    save_intermediates=False,
                                    cond=latent)

    # .../DATA/<name>.nii.gz -> .../6_HDAE/<name>.nii.gz. Only the directory is renamed;
    # str.replace("DATA", ...) would also rewrite file names containing "DATA".
    in_dir, file_name = os.path.split(input_path)
    out_dir = os.path.join(os.path.dirname(in_dir), OUTPUT_DIR_NAME)
    os.makedirs(out_dir, exist_ok=True)
    pred_path = os.path.join(out_dir, file_name)

    save_image(image_pred, pred_path)
    