In [1]:
from monai import data, transforms
import glob
import numpy as np
import os
import re
import natsort
import SimpleITK as sitk

def get_loader(data_dir='/workspace/PD_SSL_ZOO/3_RECONSTRUCTION/DATA', num_workers=2):
    """Build a deterministic DataLoader over every ``*.nii.gz`` volume in ``data_dir``.

    Each volume is loaded, given a leading channel axis, reoriented to LPS and
    converted to a plain tensor (``track_meta=False`` drops MONAI metadata).
    ``batch_size=1`` and ``shuffle=False`` keep the i-th batch aligned with the
    i-th entry of the returned path list, which callers rely on to name outputs.

    Args:
        data_dir: directory searched (non-recursively) for ``*.nii.gz`` files.
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(loader, train_real)``: the DataLoader and the naturally sorted file paths.
    """
    # glob.escape so that '[', '*' or '?' in the directory name are taken literally.
    pattern = os.path.join(glob.escape(data_dir), '*.nii.gz')
    train_real = natsort.natsorted(glob.glob(pattern))

    print("Train [Total]  number = ", len(train_real))

    tr_transforms = transforms.Compose(
        [
            transforms.LoadImage(image_only=True),
            transforms.EnsureChannelFirst(),
            transforms.Orientation(axcodes="LPS"),
            transforms.EnsureType(),
            transforms.ToTensor(track_meta=False),
        ]
    )

    # Plain file paths are fed to LoadImage directly; swap in data.CacheDataset
    # to keep transformed volumes in memory between epochs.
    train_ds = data.Dataset(data=list(train_real), transform=tr_transforms)

    train_loader = data.DataLoader(
        train_ds,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
    )

    print("loader is ver (train)")

    return train_loader, train_real
# get_loader returns (DataLoader, sorted source paths); unpack both so `loader` is the DataLoader.
loader, train_list = get_loader()

Train [Total]  number =  30
loader is ver (train)


In [None]:
import torch
    
from generative.networks.nets.diffusion_model_aniso_unet_AE_official import DiffusionModelUNet_aniso_AE, DiffusionModelEncoder_ansio
class DIF_oriAE(torch.nn.Module):
    """Container for the two networks of the wavelet-domain diffusion autoencoder.

    ``unet`` is the anisotropic diffusion UNet used as the denoiser, and
    ``semantic_encoder`` produces the latent that is passed as ``cond`` when
    sampling. Both take 8-channel inputs: the sub-bands of a level-1 3-D wavelet
    transform concatenated along the channel axis. No ``forward`` is defined;
    callers invoke the two sub-networks directly.
    """

    def __init__(self):
        super(DIF_oriAE, self).__init__()
        n_bands = 8  # wavelet mode: one channel per DWT sub-band

        # Denoising UNet: self-attention only at the deepest (512-channel) level.
        self.unet = DiffusionModelUNet_aniso_AE(
            spatial_dims=3,
            in_channels=n_bands,
            out_channels=n_bands,
            num_channels=[128, 128, 256, 256, 512],
            attention_levels=[False, False, False, False, True],
            num_head_channels=[0, 0, 0, 0, 64],
            norm_num_groups=32,
            use_flash_attention=True,
            iso_conv_down=(False, True, True, True, None),
            iso_conv_up=(True, True, True, False, None),
            num_res_blocks=2,
        )

        # Semantic encoder: four levels, no attention anywhere.
        encoder_levels = 4
        self.semantic_encoder = DiffusionModelEncoder_ansio(
            spatial_dims=3,
            in_channels=n_bands,
            out_channels=n_bands,
            num_channels=[128, 256, 256, 512],
            attention_levels=[False] * encoder_levels,
            num_head_channels=[0] * encoder_levels,
            norm_num_groups=32,
            iso_conv_down=(False, True, True, True),
            num_res_blocks=(2,) * encoder_levels,
        )

def filter_ema_keys(checkpoint):
    """Extract the averaged (EMA) weights from an EMA-wrapper state dict.

    Entries whose key contains ``'online_model'`` are dropped, ``'ema_model.'``
    is removed from the remaining keys, and the wrapper's bookkeeping entries
    ``initted`` and ``step`` are discarded. The result can then be loaded into
    the bare model with ``load_state_dict``.

    Args:
        checkpoint: mapping of state-dict key -> value (here ``ckpt['ema']``).

    Returns:
        A new dict; ``checkpoint`` itself is not modified.
    """
    ema_model_state_dict = {
        key.replace('ema_model.', ''): value
        for key, value in checkpoint.items()
        if 'online_model' not in key
    }
    # pop(..., None): some EMA wrapper versions do not store these buffers, and
    # their absence should not abort weight loading.
    for bookkeeping_key in ('initted', 'step'):
        ema_model_state_dict.pop(bookkeeping_key, None)

    return ema_model_state_dict
    
def save_image(pred, path):
    """Write a 5-D prediction tensor to ``path`` with SimpleITK.

    The tensor is detached and copied to host memory. Axes 1-4 are then reversed
    (axis 0 stays first) and all size-1 axes are squeezed away. For batch size 1
    and channel size 1, tensor dims 2, 3 and 4 become image x, y and z, because
    SimpleITK reads numpy arrays in (z, y, x) order.

    NOTE(review): no spacing/origin/direction is set, so the written image uses
    SimpleITK defaults rather than the source volume's geometry.
    """
    host_array = pred.detach().cpu().numpy()
    zyx_array = host_array.transpose(0, 4, 3, 2, 1).squeeze()
    sitk.WriteImage(sitk.GetImageFromArray(zyx_array), path)
    
# Instantiate the network container; its weights are replaced later via model.load_state_dict(...).
model = DIF_oriAE()

In [None]:
from generative.inferers import DiffusionInferer_ae
from generative.networks.schedulers import DDPMScheduler
import matplotlib.pyplot as plt
import torch
import ptwt
import pywt

ckpt_path = '/workspace/PD_SSL_ZOO/2_DOWNSTREAM/WEIGHTS/1_HWDAE.pt'
OUTPUT_DIR_NAME = "1_HWDAE"      # written next to the input DATA directory
NUM_INFERENCE_STEPS = 1000
DEVICE = "cuda"
SEED = 0                         # sampling is stochastic; fix the seed for reproducible outputs

HAAR = pywt.Wavelet("haar")
# Channel order of the stacked sub-bands (approximation first, then these details).
# The decomposition and reconstruction helpers below both use this single definition.
DETAIL_KEYS = ("aad", "ada", "add", "daa", "dad", "dda", "ddd")


def to_wavelet_bands(volume):
    """Level-1 3-D Haar DWT of ``volume``, sub-bands concatenated on dim 1 (approximation first)."""
    coeffs = ptwt.wavedec3(volume, HAAR, level=1, mode="zero")
    return torch.cat([coeffs[0]] + [coeffs[1][key] for key in DETAIL_KEYS], dim=1)


def from_wavelet_bands(bands):
    """Inverse of ``to_wavelet_bands``: split dim 1 back into sub-bands and apply the inverse DWT."""
    coeffs = [
        bands[:, 0:1],
        {key: bands[:, i:i + 1] for i, key in enumerate(DETAIL_KEYS, start=1)},
    ]
    return ptwt.waverec3(coeffs, HAAR)


# NOTE: torch.load unpickles the file; only load trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location='cpu')
checkpoint = ckpt['ema']
new_ckpt = filter_ema_keys(checkpoint)
model.load_state_dict(new_ckpt)
print(f"ckpt_path : {ckpt_path}")

scheduler = DDPMScheduler(num_train_timesteps=1000,
                          schedule="linear_beta",
                          beta_start=0.0005,
                          beta_end=0.0195)

inferer = DiffusionInferer_ae(scheduler)

model.to(DEVICE)
model.eval()
loader, train_list = get_loader()
torch.manual_seed(SEED)

# The whole pass is inference: sampling calls the UNet NUM_INFERENCE_STEPS times,
# and without no_grad autograd could record every step and exhaust GPU memory.
with torch.no_grad():
    for idx, batch_data in enumerate(loader):
        images = to_wavelet_bands(batch_data.to(DEVICE))
        latent = model.semantic_encoder(images)

        # Reset the sampling schedule before each volume, then start from pure noise.
        scheduler.set_timesteps(num_inference_steps=NUM_INFERENCE_STEPS)
        noise = torch.randn(images.shape, device=DEVICE)
        image_pred = inferer.sample(input_noise=noise,
                                    diffusion_model=model.unet,
                                    scheduler=scheduler,
                                    save_intermediates=False,
                                    cond=latent)

        reconstruction_ema = from_wavelet_bands(image_pred)

        # get_loader uses batch_size=1 and shuffle=False, so batch idx <-> train_list[idx].
        src_path = train_list[idx]
        out_dir = os.path.join(os.path.dirname(os.path.dirname(src_path)), OUTPUT_DIR_NAME)
        os.makedirs(out_dir, exist_ok=True)
        pred_path = os.path.join(out_dir, os.path.basename(src_path))

        save_image(reconstruction_ema, pred_path)