In [1]:
import os
import sys; sys.path.insert(0, os.path.abspath("../"))
from pathlib import Path
this_path = Path().resolve()
import numpy as np
import torch
import SimpleITK as sitk
import pandas as pd
import nibabel as nib
import pytorch_lightning as pl
import torchio as tio
import matplotlib.pyplot as plt
import yaml

from dataset.hmri_dataset import HMRIControlsDataModule, HMRIPDDataModule
from models.pl_model import Model_AE
from utils.utils import save_nifti_from_array


In [6]:
def reconstruct(data, model, save_img=False, out_dir=None, type='pd'):
    patches, locations, sampler, subject, subj_id = data
    input_imgs = patches.to(model.device)
    aggregator = tio.data.GridAggregator(sampler)  
    with torch.no_grad():
        x_hat = model(input_imgs)
    aggregator.add_batch(x_hat, locations)
    reconstructed = aggregator.get_output_tensor()

    # Compute reconstruction error
    subject = subject['image'][tio.DATA]
    diff = [torch.pow(subject[i] - reconstructed[i], 2) for i in range(subject.shape[0])]
    rerror = torch.sqrt(torch.sum(torch.stack(diff), dim=0))
    rerror = rerror.cpu().numpy()
    
    if out_dir is None:
        out_dir = Path('/home/alejandrocu/Documents/parkinson_classification/reconstructions') / Path(ckpt_path).parent.parent.parent.name
        out_dir.mkdir(parents=True, exist_ok=True)
        
    if save_img:
        save_nifti_from_array(subj_id=subj_id,
                              arr=reconstructed[0].cpu().numpy(),
                              path=out_dir / f'{type}_{subj_id}_recon.nii.gz')
        save_nifti_from_array(subj_id=subj_id,
                              arr=rerror,
                              path=out_dir / f'{type}_{subj_id}_re_error.nii.gz')
        save_nifti_from_array(subj_id=subj_id,
                              arr=subject[0].cpu().numpy(),
                              path=out_dir / f'{type}_{subj_id}_original.nii.gz')
    
    return rerror

In [3]:
hc_idx = 1
pd_idx = 1
ckpt_path = Path('/home/alejandrocu/Documents/parkinson_classification/p2_hmri_outs/aehmri-da00_bz32_mse_adam_lr0.001-v4/version_0/checkpoints/epoch=187-val_loss=0.0054-val_mse=0.0054.ckpt')

In [4]:
# read config file
exp_dir = ckpt_path.parent.parent.parent
with open(exp_dir /'config_dump.yml', 'r') as f:
    cfg = list(yaml.load_all(f, yaml.SafeLoader))[0]

# create dataset
root_dir = Path('/mnt/scratch/7TPD/mpm_run_acu/bids/derivatives/hMRI')
md_df = pd.read_csv(this_path.parent/'bids_3t.csv')
md_df_hc = md_df[md_df['group'] == 0]
md_df_pd = md_df[md_df['group'] == 1]
data = HMRIControlsDataModule(md_df=md_df_hc,
                                root_dir=root_dir, 
                                **cfg['dataset'])
data.prepare_data()
data.setup()
hc_patches, hc_locations, hc_sampler, hc_subject = data.get_grid(subj=hc_idx)
hc_subj_id = data.md_df_train.iloc[hc_idx]['id']

hc_data = [hc_patches, hc_locations, hc_sampler, hc_subject, hc_subj_id]

data_pd = HMRIPDDataModule(md_df=md_df_pd,
                            root_dir=root_dir,  
                            **cfg['dataset'])
data_pd.prepare_data()
data_pd.setup()
pd_patches, pd_locations, pd_sampler, pd_subject = data_pd.get_grid(subj=pd_idx)

pd_subj_id = data_pd.md_df.iloc[pd_idx]['id']

pd_data = [pd_patches, pd_locations, pd_sampler, pd_subject, pd_subj_id]
# create model
model = Model_AE.load_from_checkpoint(ckpt_path, net='autoencoder', **cfg['model'])
model.eval()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.md_df.drop(self.md_df[self.md_df.id == drop_id].index, inplace=True)


Drop subjects ['sub-058', 'sub-016', 'sub-025']


Model_AE(
  (criterion): MSELoss()
  (train_acc): BinaryAccuracy()
  (val_acc): BinaryAccuracy()
  (train_auroc): BinaryAUROC()
  (val_auroc): BinaryAUROC()
  (train_f1): BinaryF1Score()
  (val_f1): BinaryF1Score()
  (train_mse): MeanSquaredError()
  (val_mse): MeanSquaredError()
  (net): AutoEncoder(
    (encode): Sequential(
      (encode_0): Convolution(
        (conv): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
        (adn): ADN(
          (N): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (A): PReLU(num_parameters=1)
        )
      )
      (encode_1): Convolution(
        (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
        (adn): ADN(
          (N): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (A): PReLU(num_parameters=1)
        )
      )
      (encode_2): Convolution(
        (conv): Conv3d(64, 12

In [9]:
loss = 'l1'
cfg['exp_name'] = f"aehmri-da00_bz32_{loss}_adam_lr0.001-v4"
cfg['exp_name']

'aehmri-da00_bz32_l1_adam_lr0.001-v4'

In [5]:
_ = reconstruct(pd_data, model, save_img=True)
_ = reconstruct(hc_data, model, save_img=True)