In [None]:
from generative.networks.nets.diffusion_model_aniso_unet_AE_official import DiffusionModelUNet_aniso_AE, DiffusionModelEncoder_ansio
import torch
import os

# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process. These must
# take effect before CUDA is first initialized (the model is moved to 'cuda' further below).
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

from monai import data, transforms
import torch
import os
import json
import ptwt
import pywt
import random

class DIF_oriAE(torch.nn.Module):
    """Container holding the anisotropic diffusion U-Net and its semantic encoder.

    Both sub-networks take 8 input channels (elsewhere in this notebook these are the
    eight level-1 3-D Haar sub-bands of a volume). The class defines no ``forward``;
    callers use ``self.unet`` and ``self.semantic_encoder`` directly.
    """

    def __init__(self):
        super().__init__()

        # Denoising U-Net: attention only at the deepest (512-channel) level.
        unet_kwargs = dict(
            spatial_dims=3,
            in_channels=8,
            out_channels=8,
            num_channels=[128, 128, 256, 256, 512],
            attention_levels=[False, False, False, False, True],
            num_head_channels=[0, 0, 0, 0, 64],
            norm_num_groups=32,
            use_flash_attention=True,
            iso_conv_down=(False, True, True, True, None),
            iso_conv_up=(True, True, True, False, None),
            num_res_blocks=2,
        )

        # Semantic encoder producing the conditioning latent; no attention levels.
        encoder_kwargs = dict(
            spatial_dims=3,
            in_channels=8,
            out_channels=8,
            num_channels=[128, 256, 256, 512],
            attention_levels=[False, False, False, False],
            num_head_channels=[0, 0, 0, 0],
            norm_num_groups=32,
            iso_conv_down=(False, True, True, True),
            num_res_blocks=(2, 2, 2, 2),
        )

        # Construction order (U-Net first) fixes parameter-init RNG use and state-dict order.
        self.unet = DiffusionModelUNet_aniso_AE(**unet_kwargs)
        self.semantic_encoder = DiffusionModelEncoder_ansio(**encoder_kwargs)

# Build the U-Net + semantic-encoder container; weights are loaded from the checkpoint below.
model = DIF_oriAE()

def filter_ema_keys(checkpoint):
    """Turn an EMA-style state dict into a plain model state dict.

    Three things happen to the keys:

    - Keys containing ``'online_model'`` are dropped.
    - A leading ``'ema_model.'`` prefix is stripped from the remaining keys.
    - The bookkeeping entries ``'initted'`` and ``'step'`` are removed if present.

    Args:
        checkpoint: mapping of state-dict keys to values (this notebook passes
            ``ckpt['ema']``).

    Returns:
        A new dict suitable for ``load_state_dict``; ``checkpoint`` is not modified.
    """
    prefix = 'ema_model.'
    ema_model_state_dict = {
        # Strip only a leading prefix; str.replace would also rewrite inner occurrences.
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
        if 'online_model' not in key
    }
    # These are not model parameters; tolerate checkpoints that lack them.
    for bookkeeping_key in ('initted', 'step'):
        ema_model_state_dict.pop(bookkeeping_key, None)

    return ema_model_state_dict

ckpt_path = '/workspace/PD_SSL_ZOO/2_DOWNSTREAM/WEIGHTS/1_HWDAE.pt'

print(f"ckpt_path : {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location='cpu')
new_ckpt = filter_ema_keys(ckpt['ema'])

# strict=False tolerates key mismatches silently. Report them so that a wrong checkpoint or a
# prefix mismatch (which would leave layers at their random init) is visible.
load_result = model.load_state_dict(new_ckpt, strict=False)
if load_result.missing_keys:
    print(f"missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys[:10]}")
if load_result.unexpected_keys:
    print(f"unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys[:10]}")

model = model.to('cuda')

def mean_gen_latent(tsne_dataset_gen):
    """Average a collection of latent samples component by component.

    Args:
        tsne_dataset_gen: non-empty list of samples; each sample is an indexable sequence
            of tensors (one per latent component), all shaped like the first sample's.

    Returns:
        list whose k-th element is the mean over all samples of their k-th component.
        An empty input raises IndexError (no first sample to size the accumulators).
    """
    sample_count = len(tsne_dataset_gen)
    running_totals = [torch.zeros_like(reference) for reference in tsne_dataset_gen[0]]

    for sample in tsne_dataset_gen:
        for slot, component in enumerate(sample):
            running_totals[slot] += component

    return [total / sample_count for total in running_totals]


In [3]:
# For Mean Vectors
##########################################################################################
##########################################################################################
'''
This section builds per-group mean latent vectors (NC, LESS, MORE), from which
the following direction vectors are computed:
1. NC_LESS
2. LESS_MORE
3. LESS_NC
4. MORE_LESS
'''
##########################################################################################
##########################################################################################

def _read_train_split(json_path):
    """Return the ``'train'`` mapping (file path -> label) stored in a split JSON file."""
    with open(json_path, 'r') as train_file:
        return json.load(train_file)['train']


def _draw_subsets(pool, num_subsets, subset_size, group_name):
    """Draw ``num_subsets`` lists of ``subset_size`` entries from ``pool``, with replacement."""
    if not pool:
        # random.choices would otherwise fail with an uninformative IndexError.
        raise ValueError(f"no training files available for group {group_name}")
    return [random.choices(pool, k=subset_size) for _ in range(num_subsets)]


def _make_loader(files, tr_transforms):
    """Wrap ``files`` in a non-shuffling, batch-size-1 MONAI Dataset/DataLoader."""
    dataset = data.Dataset(data=files, transform=tr_transforms)
    return data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)


def get_loader_for_direc_vector(num_subsets=5, subset_size=100):
    """Build resampled loaders for the NC, LESS and MORE groups.

    NC files come from ``NC_03.json``. Files from ``ONSET_mean.json`` are split by label:
    labels below 5 go to LESS and all others to MORE. Presumably the label is an onset
    duration; confirm its units. For each group, ``num_subsets`` subsets of
    ``subset_size`` files are drawn with replacement using the global ``random`` state,
    so subsets may contain duplicates.

    Args:
        num_subsets: number of resampled subsets per group (default 5).
        subset_size: number of files per subset (default 100).

    Returns:
        ``[nc_loaders, less_loaders, more_loaders]``, each a list of ``num_subsets``
        DataLoaders. Each loader yields dicts with ``'image_train'`` (intensity mapped from
        [0, 22] to [0, 1] and clipped) and a one-hot ``'label_train'``.

    Raises:
        ValueError: if a group has no files to sample from.
    """
    # NC group: labels are one-hot encoded with a single class.
    nc_num_class = 1
    files_tr_nc = []
    nc_split = _read_train_split("/workspace/Ablation/ABLATION_PD/INTERPRETATION/JSON/NP/NC_03.json")
    for file_name, label in nc_split.items():
        one_hot = torch.nn.functional.one_hot(torch.as_tensor(label), num_classes=nc_num_class)
        files_tr_nc.append({"image_train": file_name, "label_train": one_hot})
    # Sample NC before the onset groups to keep the original random-number consumption order.
    nc_subsets = _draw_subsets(files_tr_nc, num_subsets, subset_size, "NC")

    # Onset groups: the split threshold uses the raw label, encoding uses 25 classes.
    onset_num_class = 25
    files_tr_less = []
    files_tr_more = []
    onset_split = _read_train_split("/workspace/Ablation/ABLATION_PD/INTERPRETATION/JSON/REG/ONSET_mean.json")
    for file_name, label in onset_split.items():
        target = files_tr_less if label < 5 else files_tr_more
        one_hot = torch.nn.functional.one_hot(torch.as_tensor(label), num_classes=onset_num_class)
        target.append({"image_train": file_name, "label_train": one_hot})

    less_subsets = _draw_subsets(files_tr_less, num_subsets, subset_size, "LESS")
    more_subsets = _draw_subsets(files_tr_more, num_subsets, subset_size, "MORE")

    train_idx = len(files_tr_nc) + len(files_tr_less) + len(files_tr_more)
    print("Train [Total]  number = ", train_idx)

    tr_transforms = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image_train"]),
            transforms.EnsureChannelFirstd(keys=["image_train"]),
            transforms.Orientationd(keys=["image_train"], axcodes="LPS"),
            transforms.ScaleIntensityRanged(keys=["image_train"], a_min=0.0, a_max=22.0, b_min=0.0, b_max=1.0, clip=True),
            transforms.EnsureTyped(keys=["image_train", "label_train"]),
            transforms.ToTensord(keys=["image_train", "label_train"], track_meta=False)
        ]
    )

    loaders = [
        [_make_loader(subset, tr_transforms) for subset in group_subsets]
        for group_subsets in (nc_subsets, less_subsets, more_subsets)
    ]

    print("loader is ver(train, val)")

    return loaders
    

In [None]:
loader_list = get_loader_for_direc_vector()

nc_mean_gen_list = []
less_mean_gen_list = []
more_mean_gen_list = []

# Detail sub-bands appended after the approximation band to form the 8 encoder channels.
_HAAR_DETAIL_BANDS = ('aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd')

for i in range(5):
    # loader_list is indexed [group][subset]; groups 0/1/2 are NC/LESS/MORE.
    group_latents = []
    for group_pos, group_label in enumerate(("NC", "LESS", "MORE")):
        print(group_label)
        subset_loader = loader_list[group_pos][i]
        encoded = []

        for idx, batch_data in enumerate(subset_loader):
            model.eval()
            with torch.no_grad():
                images = batch_data['image_train'].to('cuda')
                coeffs3 = ptwt.wavedec3(images, pywt.Wavelet('haar'), level=1, mode='zero')
                bands = [coeffs3[0]] + [coeffs3[1][band] for band in _HAAR_DETAIL_BANDS]
                images = torch.cat(bands, dim=1)

                latent = model.semantic_encoder(images)
                encoded.append(latent)

        print(f"{i}th len : {len(subset_loader)}")
        group_latents.append(encoded)

    dataset_nc, dataset_less, dataset_more = group_latents

    nc_mean_gen = mean_gen_latent(dataset_nc)
    less_mean_gen = mean_gen_latent(dataset_less)
    more_mean_gen = mean_gen_latent(dataset_more)

    nc_mean_gen_list.append(nc_mean_gen)
    less_mean_gen_list.append(less_mean_gen)
    more_mean_gen_list.append(more_mean_gen)

In [5]:
import torch.nn.functional as F

def a_direction(init_vector, fin_vector, dim=1):
    """Return L2-normalized direction tensors pointing from each init tensor to its fin tensor.

    Args:
        init_vector: sequence of start tensors (in this notebook, per-group mean latents).
        fin_vector: sequence of end tensors, aligned element-wise with ``init_vector``.
        dim: dimension along which each difference is normalized. Defaults to 1, which is
            the last dimension only for 2-D tensors; pass another value for other layouts.

    Returns:
        list of ``F.normalize(fin - init, p=2, dim=dim)``. Pairs beyond the shorter input
        are dropped (``zip`` semantics). An all-zero difference stays zero, because
        ``F.normalize`` clamps the norm with a small eps.
    """
    return [
        F.normalize(fin_vec - init_vec, p=2, dim=dim)
        for init_vec, fin_vec in zip(init_vector, fin_vector)
    ]

In [6]:
def _mean_over_subsets(per_subset_dirs):
    """Average five per-subset direction lists slot by slot (slots 0-4), summing left to right."""
    averaged = []
    for slot in range(5):
        running = per_subset_dirs[0][slot][:, :]
        for subset_dirs in per_subset_dirs[1:5]:
            running = running + subset_dirs[slot][:, :]
        averaged.append(running / 5)
    # NOTE(review): averages of unit vectors are not re-normalized, so these may have norm < 1.
    return averaged

#NC_LESS
nc_less_dir_list = [a_direction(nc_mean_gen_list[k], less_mean_gen_list[k]) for k in range(5)]
nc_less_final_direction = _mean_over_subsets(nc_less_dir_list)

#LESS_MORE
less_more_dir_list = [a_direction(less_mean_gen_list[k], more_mean_gen_list[k]) for k in range(5)]
less_more_final_direction = _mean_over_subsets(less_more_dir_list)

#LESS_NC
less_nc_dir_list = [a_direction(less_mean_gen_list[k], nc_mean_gen_list[k]) for k in range(5)]
less_nc_final_direction = _mean_over_subsets(less_nc_dir_list)

#MORE_LESS
more_less_dir_list = [a_direction(more_mean_gen_list[k], less_mean_gen_list[k]) for k in range(5)]
more_less_final_direction = _mean_over_subsets(more_less_dir_list)

In [7]:
import natsort
import glob
import SimpleITK as sitk

def _patient_id(path):
    """Return the file name of ``path`` truncated at the first ``"_centered"``."""
    return path.split('/')[-1].split("_centered")[0]


def get_loader_for_manipulation(root="/workspace/Ablation/ABLATION_PD/All_The_tSNE/MANIPULATION"):
    """Build one loader per group (NC, LESS, MORE) over the manipulation source volumes.

    Args:
        root: directory containing ``NC/``, ``LESS/`` and ``MORE/`` sub-folders of
            ``*.nii.gz`` files. Defaults to the notebook's original location.

    Returns:
        ``([nc_loader, less_loader, more_loader], [NC_patient_list, LESS_patient_list,
        MORE_patient_list])``. Files are natural-sorted and loaders neither shuffle nor
        batch, so the k-th ID of each list names the k-th item of the matching loader.
    """
    NC_list, LESS_list, MORE_list = (
        natsort.natsorted(glob.glob(f"{root}/{group}/*.nii.gz"))
        for group in ("NC", "LESS", "MORE")
    )

    # Derive each ID list from its own file list. Indexing all three with len(NC_list)
    # raised IndexError, or silently dropped IDs, when the group sizes differed.
    NC_patient_list = [_patient_id(path) for path in NC_list]
    LESS_patient_list = [_patient_id(path) for path in LESS_list]
    MORE_patient_list = [_patient_id(path) for path in MORE_list]

    print(f"Train [Total]  number = {len(NC_list)}, {len(LESS_list)}, {len(MORE_list)}")

    tr_transforms = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image_train"]),
            transforms.EnsureChannelFirstd(keys=["image_train"]),
            transforms.Orientationd(keys=["image_train"], axcodes="LPS"),
            transforms.ScaleIntensityRanged(keys=["image_train"], a_min=0.0, a_max=22.0, b_min=0.0, b_max=1.0, clip=True),
            transforms.EnsureTyped(keys=["image_train"]),
            transforms.ToTensord(keys=["image_train"], track_meta=False)
        ]
    )

    def _loader(file_list):
        # Non-shuffling, batch-size-1 loader so iteration order matches the ID list.
        dataset = data.Dataset(data=[{"image_train": path} for path in file_list], transform=tr_transforms)
        return data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

    nc_loader = _loader(NC_list)
    less_loader = _loader(LESS_list)
    more_loader = _loader(MORE_list)

    print("loader is ver(train, val)")

    return [nc_loader, less_loader, more_loader], [NC_patient_list, LESS_patient_list, MORE_patient_list]
    

In [None]:
loader_list_init, [NC_ID, LESS_ID, MORE_ID] = get_loader_for_manipulation()

tsne_dataset_nc_init = []
tsne_dataset_less_init = []
tsne_dataset_more_init = []

# Detail sub-bands appended after the approximation band to form the 8 encoder channels.
_HAAR_DETAIL_BANDS = ('aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd')

# Encode every source volume of each group; results are appended to that group's list.
for group_label, group_loader, sink in (
    ("NC", loader_list_init[0], tsne_dataset_nc_init),
    ("LESS", loader_list_init[1], tsne_dataset_less_init),
    ("MORE", loader_list_init[2], tsne_dataset_more_init),
):
    print(group_label)
    for idx, batch_data in enumerate(group_loader):
        model.eval()
        with torch.no_grad():
            images = batch_data['image_train'].to('cuda')
            coeffs3 = ptwt.wavedec3(images, pywt.Wavelet('haar'), level=1, mode='zero')
            bands = [coeffs3[0]] + [coeffs3[1][band] for band in _HAAR_DETAIL_BANDS]
            images = torch.cat(bands, dim=1)

            latent = model.semantic_encoder(images)
            sink.append(latent)

In [None]:
import pywt
import ptwt
import SimpleITK as sitk
from generative.networks.schedulers import DDPMScheduler
from generative.inferers import DiffusionInferer_ae

scheduler = DDPMScheduler(num_train_timesteps=1000,
                          schedule="linear_beta",
                          beta_start=0.0005,
                          beta_end=0.0195)

inferer = DiffusionInferer_ae(scheduler)

# Noise volume the U-Net denoises: 8 Haar sub-band channels at half of the 192x192x96
# output resolution.
NOISE_SHAPE = (1, 8, 96, 96, 48)
# Sub-band names of U-Net output channels 1..7 (channel 0 is the approximation band).
# This matches the order in which the encoder inputs were concatenated.
HAAR_DETAIL_ORDER = ('aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd')
MANIPULATION_DIR = "/workspace/PD_SSL_ZOO/4_LATENT_MANIPULATION"
NUM_SOURCE_IMAGES = 5
MANIPULATION_FACTORS = [step * 0.5 for step in range(5)]  # 0.0, 0.5, 1.0, 1.5, 2.0
# The torch RNG is re-seeded with SAMPLING_SEED + img_idx before every decode. All factors
# of one source image therefore start from the same noise, so their differences reflect the
# latent shift rather than a fresh noise draw. Previously every factor drew new noise.
SAMPLING_SEED = 0

# (output prefix, source latents, direction, patient IDs of the source group)
MANIPULATION_SWEEPS = [
    ("nc_less", tsne_dataset_nc_init, nc_less_final_direction, NC_ID),
    ("less_more", tsne_dataset_less_init, less_more_final_direction, LESS_ID),
    ("less_nc", tsne_dataset_less_init, less_nc_final_direction, LESS_ID),
    ("more_less", tsne_dataset_more_init, more_less_final_direction, MORE_ID),
]


def _shift_latent(base_latent, direction, factor):
    """Return ``component + factor * direction`` for each latent component (zip-truncated)."""
    return [component + shift * factor for component, shift in zip(base_latent, direction)]


def _decode_latent(cond, seed):
    """Sample a volume conditioned on ``cond`` and return it as an array for SimpleITK."""
    # Seeding here fixes the initial noise. Any per-step sampler noise drawn from torch's
    # global RNG is fixed too (presumably the case for this scheduler; confirm).
    torch.manual_seed(seed)
    noise = torch.randn(NOISE_SHAPE).to("cuda")
    scheduler.set_timesteps(num_inference_steps=1000)
    image_pred = inferer.sample(input_noise=noise,
                                diffusion_model=model.unet,
                                scheduler=scheduler,
                                save_intermediates=False,
                                cond=cond)
    # Build the level-1 coefficient tuple directly from the 8 predicted channels.
    coeffs3 = (
        image_pred[:, 0:1],
        {band: image_pred[:, ch:ch + 1] for ch, band in enumerate(HAAR_DETAIL_ORDER, start=1)},
    )
    reconstruction = ptwt.waverec3(coeffs3, pywt.Wavelet("haar"))
    # Reverse the three spatial axes (GetImageFromArray reads arrays in z, y, x order) and
    # squeeze away the singleton batch/channel dimensions.
    return reconstruction.cpu().detach().numpy().transpose(0, 4, 3, 2, 1).squeeze()


model.eval()
with torch.no_grad():
    for img_idx in range(NUM_SOURCE_IMAGES):
        for prefix, source_latents, direction, patient_ids in MANIPULATION_SWEEPS:
            img_id = patient_ids[img_idx]
            for factor in MANIPULATION_FACTORS:
                print(f"factor is {factor}")
                latent = _shift_latent(source_latents[img_idx], direction, factor)
                pred_img = _decode_latent(latent, SAMPLING_SEED + img_idx)
                # NOTE(review): no spacing/origin/direction is set, so SimpleITK defaults are written.
                save_pred = sitk.GetImageFromArray(pred_img)
                sitk.WriteImage(save_pred, f"{MANIPULATION_DIR}/{prefix}_{img_id}_{factor}_pred.nii.gz")