<a href="https://colab.research.google.com/github/zubaerimran/SSWL-IDN/blob/main/SSWL-IDN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
  from google.colab import drive
# Mount Google Drive at /content/drive; the dataset paths used later in this
# notebook live under that mount point.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install dicom_numpy
!pip install lpips



In [None]:
import os, re
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import dicom_numpy as dcm2np
import pydicom as dicom
import scipy.io as sio
import pdb

from skimage.transform import resize
from sklearn.model_selection import train_test_split

# Data Preparation

In [None]:
def extract_voxel_data(list_of_dicom_files):
    """Read a list of DICOM files and combine them into one voxel array.

    Args:
        list_of_dicom_files: Iterable of DICOM file paths (anything accepted
            by ``pydicom.dcmread``).

    Returns:
        tuple: ``(ndarray, afn)`` -- the volume and image affine returned by
        ``dicom_numpy.combine_slices``, with the first two axes of the volume
        swapped (from (y, x) to (x, y)).

    Raises:
        dicom_numpy.DicomImportException: If the slices cannot be combined.
            Previously this error was swallowed and resurfaced as an
            unrelated ``UnboundLocalError`` a few lines later.
    """
    # `dcmread` is the supported reader; `read_file` is a deprecated alias.
    datasets = [dicom.dcmread(f) for f in list_of_dicom_files]
    try:
        ndarray, afn = dcm2np.combine_slices(datasets)
    except dcm2np.DicomImportException:
        # Invalid DICOM data: report how many slices were read, then let the
        # real error propagate instead of continuing with an unbound result.
        print(len(datasets))
        raise

    # Swap the in-plane axes of every slice at once. Unlike the per-slice
    # in-place assignment, this also works for non-square slices.
    ndarray = np.transpose(ndarray, (1, 0, 2)).copy()
    return ndarray, afn

In [None]:
#Simulate low-dose
def simulate_ld(I_qd, I_fd, Dose=1.):
    """Simulate a CT image at a requested dose level.

    The noise is estimated as ``I_qd - I_fd`` and added back to the full-dose
    image scaled by ``a = sqrt((1/Dose - 1) / 3)``; ``a`` is 0 at ``Dose == 1``
    and 1 at ``Dose == 0.25``, which is why those two levels return the inputs
    unchanged.

    Args:
        I_qd: Quarter-dose image(s).
        I_fd: Full-dose image(s), same shape as ``I_qd``.
        Dose: Requested dose level as a fraction of full dose, in (0, 1].

    Returns:
        The image(s) at the requested dose level. For ``Dose == 1`` and
        ``Dose == 0.25`` the corresponding input object itself is returned.

    Raises:
        ValueError: If ``Dose`` is not in (0, 1]; the scaling coefficient is
            undefined there (division by zero or square root of a negative).
    """
    if not 0 < Dose <= 1:
        raise ValueError("Dose must be in the interval (0, 1], got {!r}".format(Dose))

    if Dose == 1:
        return I_fd
    if Dose == 0.25:
        return I_qd

    a = np.sqrt(((1 / Dose) - 1) / 3)

    # extract the noise array by just subtracting the images
    I_noise = I_qd - I_fd

    # full dose plus the noise scaled by the dose-dependent coefficient
    return I_fd + (a * I_noise)


#Window-leveling
def window_leveling(x, w, c, ymin=0, ymax=1.):
    """Apply a linear window/level mapping to an image.

    Values at or below the lower window edge map to ``ymin``, values above
    the upper edge map to ``ymax``, and values in between are mapped linearly.

    Args:
        x: Image array before window-leveling.
        w: Window width.
        c: Window center.
        ymin: Output value for the bottom of the window.
        ymax: Output value for the top of the window.

    Returns:
        A new float array with the shape of ``x`` holding the leveled image.
    """
    leveled = np.zeros(x.shape)

    lower_edge = c - 0.5 - (w-1) /2
    upper_edge = c - 0.5 + (w-1) /2

    # Three mutually exclusive regions of the input range.
    below = x <= lower_edge
    above = x > upper_edge
    inside = (x > lower_edge) & (x <= upper_edge)

    leveled[below] = ymin
    leveled[above] = ymax

    # Linear ramp across the window; only applied where `inside` is set.
    ramp = ((x - (c - 0.5)) / (w-1) + 0.5) * (ymax- ymin) + ymin
    np.putmask(leveled, inside, ramp)

    return leveled


#For each CT slice, do w/l and resizing if required
def prep_slices(volume, scale=512):
    """Window-level every slice of a volume and stack the results.

    Args:
        volume: Array indexed by slice along its first axis.
        scale: Expected height/width of every slice.

    Returns:
        Array of shape ``(-1, scale, scale)`` with the window-leveled slices.

    Raises:
        AssertionError: If a slice is not ``(scale, scale)``.
    """
    resized_data = []

    for i in range(volume.shape[0]):
        img = window_leveling(volume[i], 300, 40) #width=300, center=40

        #img = resize(img, [scale, scale]) #uncomment for resizing
        assert img.shape == (scale, scale)

        # A slice that does not span the full [0, 1] output range is reported
        # but still kept. This is an explicit check rather than an `assert`
        # inside a bare `except`, which would also swallow unrelated errors.
        if np.max(img) != 1. or np.min(img) != 0.:
            print('Wrong!')
        resized_data.append(img)

    return np.reshape(resized_data, [-1, scale, scale])

#For each CT slice, validate and stack it WITHOUT window-leveling
def prep_slices_no_window_leveling(volume, scale=512):
    """Stack the slices of a volume unchanged, reporting their value range.

    Args:
        volume: Array indexed by slice along its first axis.
        scale: Expected height/width of every slice.

    Returns:
        Array of shape ``(-1, scale, scale)`` with the untouched slices.

    Raises:
        AssertionError: If a slice is not ``(scale, scale)``.
    """
    resized_data = []

    for i in range(volume.shape[0]):
        img = volume[i]
        print('Img max-min: ', np.max(img), np.min(img))

        #img = resize(img, [scale, scale]) #uncomment for resizing
        assert img.shape == (scale, scale)

        # A slice that does not span exactly [0, 1] is reported but still
        # kept. This is an explicit check rather than an `assert` inside a
        # bare `except`, which would also swallow unrelated errors.
        if np.max(img) != 1. or np.min(img) != 0.:
            print('Wrong!')
        resized_data.append(img)

    return np.reshape(resized_data, [-1, scale, scale])

# Load Data

In [None]:
# Each .mat file is parsed once and the needed variables are pulled out of the
# resulting dict (re-reading the same file for every variable is wasteful).
CT_DATA_MAT = "/content/drive/MyDrive/Research/SSL_Ayaan/all_ct_data_5dose.mat"
CT_DATA_NOWINDOW_MAT = "/content/drive/MyDrive/Research/SSL_Ayaan/all_ct_data_nowindow_5dose.mat"


def _min_max_normalize(scans):
    """Rescale an array to [0, 1] using its global minimum and peak-to-peak range."""
    return (scans - np.min(scans)) / np.ptp(scans)


_ct_data = sio.loadmat(CT_DATA_MAT)
ld_train, fd_train = _ct_data['ld_train'], _ct_data['fd_train']
ld_test, fd_test = _ct_data['ld_test'], _ct_data['fd_test']
del _ct_data  # drop the reference to the variables that are not used

# normalize non-window-leveled scans
_ct_data_nw = sio.loadmat(CT_DATA_NOWINDOW_MAT)
ld_train_nw = _min_max_normalize(_ct_data_nw['ld_train_nw'])
fd_train_nw = _min_max_normalize(_ct_data_nw['fd_train_nw'])
ld_test_nw = _min_max_normalize(_ct_data_nw['ld_test_nw'])
fd_test_nw = _min_max_normalize(_ct_data_nw['fd_test_nw'])
del _ct_data_nw  # the raw (un-normalized) arrays are no longer needed

# Low-dose / full-dose shapes should match pairwise.
print(ld_train.shape, fd_train.shape)
print(ld_test.shape, fd_test.shape)

# Data Loading

In [None]:
import torch
from torch import nn,optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torchsummary import summary
import lpips

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from skimage.util import random_noise
from sklearn.preprocessing import normalize
import skimage.metrics as skmetrics

import time
from collections import defaultdict
from itertools import cycle
import gc
from math import exp
from torch.autograd import Variable
import torch.nn.init as init

In [None]:
class CT_Dataset(Dataset):
    """Dataset of paired (input, target) images.

    Args:
        inputs: Indexable collection of input images.
        targets: Indexable collection of target images; its length defines
            the dataset length.
        transform: Optional callable applied to both images of a pair.
    """

    def __init__(self, inputs, targets, transform=None):
        self.transform = transform
        self.input_ = inputs
        self.target_ = targets

    def __len__(self):
        # The number of targets defines the number of samples.
        return len(self.target_)

    def __getitem__(self, idx):
        pair = (self.input_[idx], self.target_[idx])
        if not self.transform:
            return pair
        # Same transform for both images, input first and target second.
        return tuple(self.transform(image) for image in pair)

# Define Model

In [None]:
class SSWL_IDN(nn.Module):
    """Convolutional encoder/decoder with a variational bottleneck and residual skips.

    Five unpadded 5x5 convolutions encode the image, the features are pooled
    to 4x4 and projected to the mean / log-variance of a latent Gaussian, a
    sample is projected back, upsampled, and decoded by five transposed
    convolutions with three residual connections from the encoder.

    The fixed upsampling factor ``(256 - 4 * 5) / 4`` is sized for 256x256
    inputs; other spatial sizes make the residual additions fail on a shape
    mismatch.

    Args:
        img_ch: Number of channels of the input/output image.
        out_ch: Number of feature channels of the hidden layers.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, img_ch = 1, out_ch=96, latent_dim = 256):
        # The class was renamed; the old explicit `super(VAE_RED_CNN, self)`
        # raised NameError as soon as the model was constructed.
        super().__init__()
        self.out_ch = out_ch

        self.conv1 = nn.Conv2d(img_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv3 = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv4 = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv5 = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)

        # Build Bottleneck (sized from out_ch rather than a hard-coded 96,
        # so a non-default out_ch works; identical for the default)
        self.fc_mu = nn.Linear(out_ch * 4 * 4, latent_dim)
        self.fc_var = nn.Linear(out_ch * 4 * 4, latent_dim)

        # Build Decoder
        self.decoder_input = nn.Linear(latent_dim, out_ch * 4 * 4)

        self.avgpool = nn.AdaptiveAvgPool2d((4,4))
        # 4x4 -> 236x236, the spatial size of the conv5 output for a 256x256 input
        self.upsample = nn.Upsample(scale_factor= (256 - 4 * 5)/4, mode='bilinear', align_corners=True)

        self.tconv1 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.tconv2 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.tconv3 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.tconv4 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.tconv5 = nn.ConvTranspose2d(out_ch, img_ch, kernel_size=5, stride=1, padding=0)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the network.

        Args:
            x: Batch of images, shape ``(N, img_ch, 256, 256)``.

        Returns:
            tuple: ``(out, mu, logvar)`` -- the non-negative output image with
            the shape of ``x``, and the mean and log-variance of the latent
            Gaussian, each of shape ``(N, latent_dim)``.
        """
        # encoder
        residual_1 = x
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        residual_2 = out
        out = self.relu(self.conv3(out))
        out = self.relu(self.conv4(out))
        residual_3 = out
        out = self.relu(self.conv5(out))

        result = self.avgpool(out)
        result = torch.flatten(result, start_dim=1)
        # Split the result into mu and var components of the latent Gaussian distribution
        mu = self.fc_mu(result)
        logvar = self.fc_var(result)

        # Reparameterization: z = mu + eps * std. Note that a sample is drawn
        # in both train and eval mode, so predictions are stochastic.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = eps * std + mu

        result = self.decoder_input(z)
        result = result.view(-1, self.out_ch, 4, 4)
        result = self.upsample(result)

        # decoder
        out = self.tconv1(result)
        out += residual_3
        out = self.tconv2(self.relu(out))
        out = self.tconv3(self.relu(out))
        out += residual_2
        out = self.tconv4(self.relu(out))
        out = self.tconv5(self.relu(out))
        out += residual_1
        out = self.relu(out)

        return out, mu, logvar

# Training

## Loops

In [None]:
def compute_measure(x, y, pred, data_range = None):
    """Score the input and the prediction against the reference image(s).

    Args:
        x: Network input (compared against ``y`` as the baseline).
        y: Reference image(s).
        pred: Network prediction.
        data_range: Forwarded to the PSNR and SSIM helpers.

    Returns:
        tuple: ``(original_result, pred_result)``, each a
        ``(psnr, ssim, mse, rmse, nrmse)`` tuple.
    """
    def _score_against_reference(candidate):
        # Metric order matters: callers index the result positionally.
        psnr = compute_PSNR(y, candidate, data_range)
        ssim = compute_SSIM(y, candidate, data_range)
        mse = compute_MSE(y, candidate)
        rmse = compute_RMSE(y, candidate)
        nrmse = compute_NRMSE(y, candidate)
        return (psnr, ssim, mse, rmse, nrmse)

    original_result = _score_against_reference(x)
    pred_result = _score_against_reference(pred)
    return original_result, pred_result

def compute_MSE(img1, img2):
    """Return the mean squared error between two images (skimage.metrics)."""
    return skmetrics.mean_squared_error(img1, img2)


def compute_RMSE(img1, img2):
    """Return the root mean squared error, i.e. the square root of ``compute_MSE``."""
    return np.sqrt(compute_MSE(img1, img2))

def compute_NRMSE(img1, img2):
    """Return the normalized RMSE of ``img2`` w.r.t. ``img1`` (skimage.metrics).

    ``img1`` is passed as the first (reference) argument; the library's
    default normalization is used.
    """
    return skmetrics.normalized_root_mse(img1, img2)

def compute_PSNR(img1, img2, data_range = None):
    """Return the peak signal-to-noise ratio of ``img2`` w.r.t. ``img1``.

    Args:
        img1: Reference image(s).
        img2: Image(s) to score.
        data_range: Value range of the images (max - min). ``None`` keeps the
            library default, which is derived from the image dtype.

    Returns:
        The PSNR as computed by ``skimage.metrics.peak_signal_noise_ratio``.
    """
    # `data_range` used to be accepted but silently ignored.
    return skmetrics.peak_signal_noise_ratio(img1, img2, data_range=data_range)


def compute_SSIM(img1, img2, data_range = None, window_size=11, channel=1, size_average=True, image_shape=(256, 256)):
    """Return the mean structural similarity over a batch of 2-D images.

    Both inputs are reshaped to ``(-1, *image_shape)`` and SSIM is computed
    slice by slice with ``skimage.metrics.structural_similarity``.

    Args:
        img1: Reference image(s).
        img2: Image(s) to score; must contain at least as many slices as ``img1``.
        data_range: Value range of the images (max - min). ``None`` keeps the
            library default.
        window_size: Unused; kept for backward compatibility.
        channel: Unused; kept for backward compatibility.
        size_average: Unused; kept for backward compatibility.
        image_shape: Height and width of a single slice (default 256x256).

    Returns:
        The SSIM averaged over all slices.

    Raises:
        Any exception raised by ``structural_similarity`` is propagated.
        Previously it was swallowed and the debugger was started, which hangs
        a non-interactive run.
    """
    slice_shape = (-1,) + tuple(image_shape)
    img1 = np.reshape(img1, slice_shape)
    img2 = np.reshape(img2, slice_shape)

    total_ssim = 0
    for i in range(len(img1)):
        # `data_range` used to be accepted but silently ignored.
        total_ssim += skmetrics.structural_similarity(img1[i], img2[i], data_range=data_range)

    return total_ssim / len(img1)
    

def save_fig(x, y, pred, fig_name, original_result, pred_result, save_path = "/content/"):
    """Save a three-panel comparison figure for the first sample of a batch.

    The panels show, left to right, the input, the prediction and the
    reference; the first two are annotated with their metric values.

    Args:
        x, y, pred: Batches of input, reference and predicted images; only
            element 0 of each is drawn.
        fig_name: Value inserted into the file name ``result_<fig_name>.png``.
        original_result: ``(psnr, ssim, mse, rmse, nrmse)`` of the input.
        pred_result: ``(psnr, ssim, mse, rmse, nrmse)`` of the prediction.
        save_path: Directory the figure is written to.
    """
    metrics_template = "PSNR: {:.4f}\nSSIM: {:.4f}\nMSE: {:.4f}\nRMSE: {:.4f}\nNRMSE: {:.4f}"

    # (batch, panel title, metrics to print below the panel or None)
    panels = (
        (x, 'Quarter-dose', original_result),
        (pred, 'Result', pred_result),
        (y, 'Full-dose', None),
    )

    f, ax = plt.subplots(1, 3, figsize=(30, 10))
    for axis, (batch, title, scores) in zip(ax, panels):
        axis.imshow(np.squeeze(batch[0]), cmap=plt.cm.gray)
        axis.set_title(title, fontsize=30)
        if scores is not None:
            axis.set_xlabel(metrics_template.format(scores[0],
                                                    scores[1],
                                                    scores[2],
                                                    scores[3],
                                                    scores[4]), fontsize=20)

    f.savefig(os.path.join(save_path, 'result_{}.png'.format(fig_name)))
    plt.close()

In [None]:
def hybrid_loss(out, targets, criterion, percep_loss, metrics, weight = 0.6):
  """Combine a pixel-wise loss with a weighted perceptual loss.

  Args:
    out: Network output, channel dimension at index 1.
    targets: Reference images with the same shape as ``out``.
    criterion: Pixel-wise loss callable.
    percep_loss: Object whose ``forward(a, b)`` returns a perceptual distance.
    metrics: Unused here; the running loss is accumulated by the caller.
    weight: Multiplier of the perceptual term.

  Returns:
    ``criterion(out, targets) + weight * mean(perceptual distance)``.
  """
  # The perceptual loss is fed copies with the channel axis tiled three times.
  tiled_out = out.repeat(1, 3, 1, 1)
  tiled_targets = targets.repeat(1, 3, 1, 1)
  perceptual_term = torch.mean(percep_loss.forward(tiled_out, tiled_targets))

  pixel_term = criterion(out, targets)

  return pixel_term + weight * perceptual_term

def vae_loss(out, targets, mu, logvar, criterion, percep_loss, metrics, hybrid = False, kld_weight = 1.0):
  """Reconstruction loss plus weighted KL divergence of the latent Gaussian.

  Args:
    out: Network output.
    targets: Reference images.
    mu: Latent means.
    logvar: Latent log-variances.
    criterion: Pixel-wise loss callable.
    percep_loss: Perceptual loss object (only used when ``hybrid`` is true).
    metrics: Mapping whose ``'loss'`` entry is increased by
      ``loss * batch_size``.
    hybrid: Use ``hybrid_loss`` instead of ``criterion`` alone for the
      reconstruction term.
    kld_weight: Multiplier of the KL term.

  Returns:
    The total loss tensor.
  """
  if hybrid:
    reconstruction = hybrid_loss(out, targets, criterion, percep_loss, metrics)
  else:
    reconstruction = criterion(out, targets)

  # KL divergence to a standard normal, averaged over all elements.
  divergence = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

  total = reconstruction + kld_weight * divergence

  batch_size = targets.size(0)
  metrics['loss'] += total.data.cpu().numpy() * batch_size

  return total

def calc_loss(out, targets, criterion, percep_loss, metrics, perceptual = False):
  """Compute either the pixel-wise or the perceptual loss and record it.

  Args:
    out: Network output, channel dimension at index 1.
    targets: Reference images with the same shape as ``out``.
    criterion: Pixel-wise loss callable (used when ``perceptual`` is false).
    percep_loss: Object whose ``forward(a, b)`` returns a perceptual distance
      (used when ``perceptual`` is true).
    metrics: Mapping whose ``'loss'`` entry is increased by
      ``loss * batch_size``.
    perceptual: Select the perceptual loss instead of ``criterion``.

  Returns:
    The loss tensor.
  """
  if not perceptual:
    loss = criterion(out, targets)
  else:
    # The perceptual loss is fed copies with the channel axis tiled three times.
    tiled_out = out.repeat(1, 3, 1, 1)
    tiled_targets = targets.repeat(1, 3, 1, 1)
    loss = torch.mean(percep_loss.forward(tiled_out, tiled_targets))

  batch_size = targets.size(0)
  metrics['loss'] += loss.data.cpu().numpy() * batch_size

  return loss

def print_metrics(metrics, epoch_samples, phase):
  """Print every metric averaged per sample on one line, prefixed by the phase.

  Args:
    metrics: Mapping of metric name to its value summed over the epoch.
    epoch_samples: Number of samples the sums were accumulated over.
    phase: Label printed in front of the metrics.
  """
  formatted = [
      "{}: {:4f}".format(name, metrics[name] / epoch_samples)
      for name in metrics.keys()
  ]
  print("{}: {}".format(phase, ", ".join(formatted)))

def train(model, dataloader, optimizer, criterion, percep_loss, scheduler, checkpoint_path, print_iters = 10, epochs = 15, hybrid = False, perceptual = False, vae = False):
  """Train ``model`` and return it with the best checkpoint loaded.

  Only ``dataloader['train']`` is consumed; no validation pass is run, so the
  "best" checkpoint is the one with the lowest mean *training* loss.

  Fixes over the previous version:
    * exactly ``epochs`` epochs are run (``range(1, epochs)`` ran one fewer);
    * the learning-rate scheduler is stepped once per epoch (the step used to
      sit behind a condition that was never true, so the LR never changed);
    * the running loss is also accumulated for the plain hybrid loss, so the
      epoch loss used for checkpointing is no longer a constant 0 there;
    * tensors are moved to the device of the model instead of relying on a
      notebook-global ``device``.

  Args:
    model: Network to train. With ``vae=True`` its forward pass must return
      ``(out, mu, logvar)``, otherwise just ``out``.
    dataloader: Mapping with a ``'train'`` entry yielding ``(inputs, targets)``.
    optimizer: Optimizer over the model parameters.
    criterion: Pixel-wise loss callable.
    percep_loss: Perceptual loss object (used by the hybrid/perceptual losses).
    scheduler: Learning-rate scheduler, stepped once per epoch.
    checkpoint_path: File the best ``state_dict`` is written to and read from.
    print_iters: Print the running loss every ``print_iters`` iterations.
    epochs: Number of epochs to train.
    hybrid: Use the hybrid (pixel + perceptual) loss.
    perceptual: Use the perceptual loss alone (ignored when ``vae`` or
      ``hybrid`` is set).
    vae: Use the VAE loss and the three-value model output.

  Returns:
    ``model`` in eval mode with the weights stored at ``checkpoint_path``.

  Raises:
    ValueError: If the training loader yields no samples.
  """
  device = next(model.parameters()).device
  train_loader = dataloader['train']

  best_loss = float('inf')
  start_time = time.time()

  for epoch in range(1, epochs + 1):

    # get CUDA space
    gc.collect()
    torch.cuda.empty_cache()

    model.train()
    metrics = defaultdict(float)
    epoch_samples = 0

    for i, (inputs, targets) in enumerate(train_loader):
      inputs = inputs.to(device=device, dtype=torch.float)
      targets = targets.to(device=device, dtype=torch.float)

      optimizer.zero_grad()

      if vae:
        out, mu, logvar = model(inputs)
        loss = vae_loss(out, targets, mu, logvar, criterion, percep_loss, metrics, hybrid = hybrid)
      elif hybrid:
        out = model(inputs)
        loss = hybrid_loss(out, targets, criterion, percep_loss, metrics)
        # hybrid_loss does not update `metrics` itself (vae_loss would
        # otherwise count it twice), so record the running loss here.
        metrics['loss'] += loss.item() * targets.size(0)
      else:
        out = model(inputs)
        loss = calc_loss(out, targets, criterion, percep_loss, metrics, perceptual = perceptual)

      loss.backward()
      optimizer.step()

      # print
      if i % print_iters == 0:
        print("STEP [{}], EPOCH [{}/{}], ITER [{}/{}] \nLOSS: {:.8f}, TIME: {:.1f}s".format(i, epoch,
                                                                                            epochs, i+1,
                                                                                            len(train_loader), loss.item(),
                                                                                            time.time() - start_time))
      epoch_samples += len(inputs)

    if epoch_samples == 0:
      raise ValueError("dataloader['train'] yielded no samples")

    print_metrics(metrics, epoch_samples, 'train')
    epoch_loss = metrics['loss'] / epoch_samples

    scheduler.step()
    for param_group in optimizer.param_groups:
      print("LR", param_group['lr'])

    # save the model weights whenever the mean training loss improves
    if epoch_loss < best_loss:
      print(f"saving best model to {checkpoint_path}")
      best_loss = epoch_loss
      torch.save(model.state_dict(), checkpoint_path)

  # load best model weights
  model.load_state_dict(torch.load(checkpoint_path, map_location=device))
  # The previous version also left the model in eval mode on return.
  model.eval()
  return model

In [None]:
def test(model, test_loader, mat_path, vae = False):
  """Evaluate ``model`` on ``test_loader``, print and save the averaged metrics.

  For every batch the input and the prediction are scored against the
  reference with ``compute_measure`` and a comparison figure is written with
  ``save_fig``. Metrics are averaged over batches (not over samples).

  Args:
    model: Trained network. With ``vae=True`` its forward pass must return
      three values, of which the first is the prediction.
    test_loader: Loader yielding ``(x, y)`` batches.
    mat_path: Path of the ``.mat`` file the averaged prediction metrics are
      written to.
    vae: Whether the model returns ``(pred, mu, logvar)``.

  Raises:
    ValueError: If ``test_loader`` is empty.
  """
  metric_names = ('psnr', 'ssim', 'mse', 'rmse', 'nrmse')

  num_batches = len(test_loader)
  if num_batches == 0:
    raise ValueError("test_loader is empty; nothing to evaluate")

  # Use the device the model lives on instead of a notebook-global `device`.
  device = next(model.parameters()).device
  model.eval()

  # Running sums of (PSNR, SSIM, MSE, RMSE, NRMSE)
  ori_totals = np.zeros(len(metric_names))
  pred_totals = np.zeros(len(metric_names))

  with torch.no_grad():
    for i, (x, y) in enumerate(test_loader):

      x, y = x.to(device=device, dtype=torch.float), y.to(device=device, dtype=torch.float)

      if vae:
        pred, _, _ = model(x)
      else:
        pred = model(x)

      x = x.cpu().detach().numpy()
      y = y.cpu().detach().numpy()
      pred = pred.cpu().detach().numpy()

      original_result, pred_result = compute_measure(x, y, pred)

      ori_totals += np.asarray(original_result, dtype=float)
      pred_totals += np.asarray(pred_result, dtype=float)

      save_fig(x, y, pred, i, original_result, pred_result)

  ori_avg = ori_totals / num_batches
  pred_avg = pred_totals / num_batches

  print('\n')
  print('Original === \nPSNR avg: {:.4f} \nSSIM avg: {:.4f} \nMSE avg: {:.4f} \nRMSE avg: {:.4f} \nNRMSE avg: {:.4f}'.format(*ori_avg))
  print('\n')
  print('Predictions === \nPSNR avg: {:.4f} \nSSIM avg: {:.4f} \nMSE avg: {:.4f} \nRMSE avg: {:.4f} \nNRMSE avg: {:.4f}'.format(*pred_avg))

  results = dict(zip(metric_names, pred_avg))
  # 'nrsme' is the misspelled key earlier versions wrote; it is kept next to
  # the correct 'nrmse' so existing readers of the .mat file keep working.
  results['nrsme'] = results['nrmse']
  sio.savemat(mat_path, results)

## Load Data

In [None]:
# Datasets / loaders pairing each `ld_*` scan with its `ld_*_nw`
# (non-window-leveled, min-max normalized) counterpart as the target.
transform = transforms.Compose([transforms.ToTensor()])

nw_train_dataset = CT_Dataset(ld_train, ld_train_nw, transform)
nw_test_dataset = CT_Dataset(ld_test, ld_test_nw, transform)

# Both loaders share the same settings (the test loader is shuffled as well).
nw_loader_options = dict(batch_size=10, shuffle=True, num_workers=0)
nw_train_loader = DataLoader(nw_train_dataset, **nw_loader_options)
nw_test_loader = DataLoader(nw_test_dataset, **nw_loader_options)

nw_dataloader = dict(train=nw_train_loader, test=nw_test_loader)

# number of batches per split
print(len(nw_dataloader['train']), len(nw_dataloader['test']))

In [None]:
# Datasets / loaders for denoising: `ld_*` scans as input, `fd_*` scans as target.
transform = transforms.Compose([transforms.ToTensor()])

# Optional: train on a random subset of n scans
# n = 250
# indices = np.random.choice(ld_train.shape[0], n, replace=False)
# ld_train = ld_train[indices]
# fd_train = fd_train[indices]

train_dataset = CT_Dataset(ld_train, fd_train, transform)
test_dataset = CT_Dataset(ld_test, fd_test, transform)

# Training batches are shuffled; the test order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=0)

dataloader = dict(train=train_loader, test=test_loader)

# number of batches per split
print(len(dataloader['train']), len(dataloader['test']))

## Start Training

Training is disabled by default. To train the model, uncomment the commented-out code in the following cells.

In [None]:
# NOTE: this cell is commented out on purpose; uncomment it to train.
# The evaluation cells further down also read `model`, `device`, `vae_bool`,
# `checkpoint_path` and `mat_path`, which are only defined here, so at least
# those assignments have to be run before evaluating.

# gc.collect()
# torch.cuda.empty_cache()

# #pretext_downstream
# save_path = "/content/save_path"
# mat_path = save_path + ".mat"
# checkpoint_path = save_path + ".pth"

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = SSWL_IDN().to(device)
# vae_bool = True

# lr = 1e-5
# epochs = 15
# criterion = nn.MSELoss()
# loss_fn = lpips.LPIPS(net='vgg')
# NOTE(review): `.cuda()` requires a CUDA device although `device` above may
# fall back to CPU; `loss_fn.to(device)` would follow `device` instead.
# loss_fn.cuda()

# optimizer_ = optimizer = optim.Adam(model.parameters(), lr)
# scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

# summary(model, input_size=(1, 256, 256))

In [None]:
# Two training stages, run back to back on the same model. Both calls share
# the same optimizer, scheduler and checkpoint_path, so the second stage
# overwrites the checkpoint file written by the first one.
# # SSWL Training
# model = train(model, nw_dataloader, optimizer, criterion, loss_fn, scheduler, checkpoint_path, print_iters = 1, epochs = epochs, perceptual = False, hybrid = True, vae = vae_bool)
# # Denoising Training
# model = train(model, dataloader, optimizer, criterion, loss_fn, scheduler, checkpoint_path, print_iters = 1, epochs = epochs, perceptual = False, hybrid = True, vae = vae_bool)

# Evaluate

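To evaluate the model, first instantiate it and define `device`, `vae_bool`, `checkpoint_path` and `mat_path` (see the setup cell under *Start Training*), with `checkpoint_path` pointing to the provided `.pth` file. Then run the cells below to load the weights, generate predictions and compute the metrics.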

In [None]:
# Load the trained weights. `map_location=device` lets a checkpoint that was
# saved on a GPU be loaded on a CPU-only runtime as well. Requires `model`,
# `device` and `checkpoint_path` to be defined (setup cell above).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# visualize predictions: run the test set through the model and collect the
# inputs, targets and predictions as numpy arrays
model.eval()

inputs_arr, targets_arr, pred_masks = [], [], []

with torch.no_grad():
  print("starting validation")
  for inputs, targets in dataloader['test']:
    gc.collect()
    torch.cuda.empty_cache()

    inputs = inputs.to(device=device, dtype=torch.float)
    targets = targets.to(device=device, dtype=torch.float)

    # A VAE-style model returns (prediction, mu, logvar); keep the prediction.
    if vae_bool:
      pred, _, _ = model(inputs)
    else:
      pred = model(inputs)

    pred = pred.data.cpu().numpy()
    inputs_np = inputs.data.cpu().numpy()
    targets_np = targets.data.cpu().numpy()

    # one list entry per sample of the batch
    batch_indices = range(len(pred))
    inputs_arr.extend(inputs_np[k] for k in batch_indices)
    targets_arr.extend(targets_np[k] for k in batch_indices)
    pred_masks.extend(pred[k] for k in batch_indices)

# stack the per-sample arrays into (num_samples, 256, 256, 1)
inputs_arr = np.reshape(inputs_arr, [-1, 256, 256, 1])
targets_arr = np.reshape(targets_arr, [-1, 256, 256, 1])
pred_masks = np.reshape(pred_masks, [-1, 256, 256, 1])

In [None]:
# Score the model on the test loader; the averaged metrics are printed and
# written to `mat_path`, and one comparison figure is saved per batch.
test(model, dataloader['test'], mat_path, vae = vae_bool)