In [None]:
import torch
import argparse
import torchio as tio
import numpy as np
import nibabel as nib
from box import Box
import os.path as path

import sys, os, json
sys.path.append(os.path.join(sys.path[0], '../'))
# from dataloader.patch_dataloader import patch_dataloader
from dataloader.data_utils_eval import load_MRI_from_paths_list
# from network.VNet_model import VNetModel
# from network.WatNet_model import WatNet3DModel, WatNet2DModel


In [2]:
def load_params(param_dir):
    with open(param_dir, 'r') as json_file:
        params = json.load(json_file)
    return Box(params)

In [3]:
param_path = '../../config/params_VNet_eval.json'
params = load_params(param_path)

In [4]:
print(params)

{'fp': {'subjects': '../../data/subjects.txt', 'filepaths': ['../../data/original'], 'postfixes': ['.nii.gz'], 'eval_modes': ['org', 'down2', 'down4'], 'ckpt_dir': '../../ckpts/tensorlog/VNet_final_ckpts', 'leave_out_subjects': []}, 'data': {'batch_size': 40, 'patch_size': 64, 'patch_overlap': 8}, 'training': {'num_workers': 32, 'max_queue_length': 1000, 'matmul_precision': 'medium', 'precision': '16-mixed'}, 'checkpoint': {'ckpt_dir': '../../ckpts/tensorlog'}}


In [5]:
empty_params = Box()

In [6]:
empty_params.faa = params.checkpoint

In [7]:
empty_params

Box({'faa': {'ckpt_dir': '../../ckpts/tensorlog'}})

# Loading args

In [8]:
ckpt_path = r'/home/dual4090/lab/github/synth7T-MICCAI/ckpts/tensorlog/VNet_final_ckpts/MAEloss_finalmodel/Synthetic_7T_MRI_VNet_weights.ckpt'

In [9]:
if not os.path.exists(ckpt_path):
    raise ValueError('Checkpoint path does not exist')

if not ckpt_path.endswith('.ckpt'):
    raise ValueError('Please provide a .ckpt file')

In [None]:
def get_filelist_from_path(arg_path):
    if not os.path.exists(arg_path):
        raise ValueError('Input path does not exist')

    if os.path.isdir(arg_path):
        # input is dir, get all .nii.gz files in input dir
        input_files = [os.path.join(arg_path, f) for f in os.listdir(arg_path) if f.endswith('.nii.gz')]
        if not input_files:
            raise ValueError('No .nii.gz files found in input directory')
        
    elif os.path.isfile(arg_path):
        # input is file, check if it is .nii.gz
        if not arg_path.endswith('.nii.gz'):
            raise ValueError('Input file is not a .nii.gz file')
        input_files = [arg_path]

    return input_files

In [None]:
def add_mask_to_subject(subjects_t1_3T, subjects_mask):
    subject_attribute_combined = {}

    for subject in subjects_t1_3T:
        subject_key = list(subject.keys())
        subject_key.remove('id')
        subject_key = subject_key[0]
        subject_attribute_combined[subject.id] = {'id': subject.id, 
                                                subject_key: subject[subject_key]}
        

    for subject in subjects_mask:
        subject_key = list(subject.keys())
        subject_key.remove('id')
        subject_key = subject_key[0]
        if subject.id in subject_attribute_combined:
            # If the subject exists, update the mask
            subject_attribute_combined[subject.id][subject_key] = subject[subject_key]
    
    
    subjects = []

    for subject_data in subject_attribute_combined.values():
        subjects.append(tio.Subject(subject_data))
    
    return subjects

In [68]:
input_3T_path = r'/home/dual4090/lab/github/3T-7T_registration/data/original/3T'
input_3T_files = get_filelist_from_path(input_3T_path)
subjects_t1_3T = load_MRI_from_paths_list(input_3T_files, 't1_3T')

In [69]:
input_mask_path = r'/home/dual4090/lab/github/3T-7T_registration/data/original/mask/'
input_mask_files = get_filelist_from_path(input_mask_path)
subjects_mask = load_MRI_from_paths_list(input_mask_files, 'mask')

In [65]:
subjects

[Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2),
 Subject(Keys: ('id', 't1_3T', 'mask'); images: 2)]

In [29]:
subjects_mask

[Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1),
 Subject(Keys: ('id', 'mask'); images: 1)]

In [25]:
[s.id for s in subjects_t1]

['06',
 '04',
 '9016',
 '09',
 '07',
 '9015',
 '10',
 '03',
 '05',
 '08',
 '01',
 '02',
 '9029',
 '9047',
 '9027',
 '9042',
 '9014',
 '9044']

# Functions

In [39]:
path.exists('')

False

In [None]:
def check_filePaths(file_paths):
    # make sure all file paths in list is a nifti file

In [None]:
def load_MRIs(fps_input, fps_mask = None):
    
    if fps_mask is not None:
        flag_mask = True

    subjects = []
    for i in range(len(fns)):
        # normalize value between 0,1
        # outlier removal done during preprocessing (5%-95% percentile kept)
        rescale1 = tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(0, 100))

        subject_dict = {}
        if flag3:
            subject_dict['t1_3T'] = rescale1(tio.ScalarImage(path.join(fp3, str(fns[i]) + postfix)))
        if flag_mask:
            subject_dict['mask'] = tio.LabelMap(path.join(fpm, str(fns[i]) + postfix))

### what happens if there is no mask?

        subject = tio.Subject(id=fns[i], **subject_dict)

        subjects.append(subject)
    # print('Dataset size:', len(subjects), 'subjects')
    return subjects


In [36]:
def load_all_from_paths(filePaths):
    #filePaths is a list of file paths
    subjects =[]
    for i in range(len(filePaths)):
        subjects.extend(load_MRIs(filePaths[i]))

    tio_dataset = tio.SubjectsDataset(subjects)
    print('total number of subjects: ', len(subjects))
    return tio_dataset