In [1]:
import pandas as pd 
import torch

from easydict import EasyDict
import os

In [2]:
from monai.transforms import (
                            Compose,
                            OneOf,
                            
                            AsDiscreted,

                            LoadImaged,
                            EnsureTyped,
                            ScaleIntensityRanged,
                            
                            Orientationd,
                            CropForegroundd, 
                            RandCropByPosNegLabeld,
                            RandSpatialCropd,

                            ## nnUNet v1 impl Aug
                            RandFlipd,
                            RandRotated,
                            Rand3DElasticd,
                            RandScaleIntensityd,
                            RandAdjustContrastd, # gamma correction
                            
                            Resized,
                            Spacingd,
                            RandGaussianNoised,
                            RandGaussianSmoothd,
                            RandShiftIntensityd,
                            
                            AsDiscrete,

                            # RandAdjustContrastd,
                            # RandGaussianSharpend,

                            # RandCoarseDropoutd,
                            # RandCoarseShuffled
                        )

import torch

In [3]:
all_key = ['img','seg']

full_transform = Compose([LoadImaged(keys=all_key, image_only=True, ensure_channel_first=True),
         EnsureTyped(keys=all_key, device=None, track_meta=False),
        #  Orientationd(keys=all_key, axcodes="RAS"),
         CropForegroundd(keys=all_key, source_key='seg', allow_smaller=False, select_fn=lambda x: x>0),
        #  AsDiscreted(keys='seg', to_onehot=3),
         Resized(keys='seg', spatial_size=(144, 128, 144), mode="nearest",),
         Resized(keys='img', spatial_size=(144, 128, 144), mode="trilinear",),
         ScaleIntensityRanged(keys='img',
                                a_max=255, a_min=0,
                                b_max=1, b_min=0, clip=True),
                ])

In [4]:
df = pd.read_csv('/root/snsb/data/transposed_20231113.csv')
print(df.shape)
df.head(5)

(167, 27)


Unnamed: 0,DID,Age,"Sex(M=1, F=2)",Education,Diagnosis,Diagnosis_Sub,CDR,GDS,MMSE_Reg,MMSE_Time,...,Diabete,Hyperlipidemia,Alchol,Smoking,SNSB_Attention,SNSB_Language,SNSB_Visuospatial,SNSB_Memory,SNSB_Frontal,FS Failed
0,DUIH_0001,71.0,1.0,6.0,AD,normal pressure hydrocephalus,2.0,5.0,3,1,...,1.0,0.0,0.0,0.0,,,,,,0
1,DUIH_0002,90.0,1.0,6.0,AD,,2.0,5.0,3,2,...,0.0,0.0,0.0,0.0,,,,,,0
2,DUIH_0003,83.0,2.0,0.5,Epilepsy,Atrial fibrillation,0.5,3.0,3,5,...,0.0,0.0,0.0,0.0,,,,,,0
3,DUIH_0004,81.0,2.0,6.0,Parkinsonism,Dementia,3.0,6.0,3,0,...,1.0,0.0,0.0,0.0,,,,,,1
4,DUIH_0005,89.0,2.0,12.0,Seizure,AD,1.0,3.0,3,4,...,0.0,1.0,0.0,0.0,,,,,,0


In [5]:
df = df[df['FS Failed'] == 0]
df.shape

(164, 27)

In [6]:
list(df['CDR'].isna())

[False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 

In [8]:
src_dir = '/root/snsb/data/mri'
target_dir = '/root/snsb/data/tensor'

cdr_onehot = AsDiscrete(to_onehot=5)
gds_onehot = AsDiscrete(to_onehot=7)

for idx in range(df.shape[0]):
    did = df['DID'].iloc[idx]
    cdr = df['CDR'].iloc[idx]
    gds = df['GDS'].iloc[idx]

    
    if 'DUIH_0049' == did:
        continue # since there is no cdr, gds
    print(did, cdr, gds)
    
    path_d = {
            'img': f'{src_dir}/{did}/T1.nii.gz',
            'seg': f'{src_dir}/{did}/aseg.nii.gz'
            }
    
    target_path = f'{target_dir}/{did}'
    
    try:
        d = full_transform(path_d)
    except Exception as ex:
        print(path_d)
    
    if not os.path.isdir(target_path):
        os.mkdir(target_path)
    
    if cdr > 0:
        if cdr == 0.5:
            cdr = 1 
        else:
            cdr += 1 
            
    gds -= 1 
        
    print(cdr)
    
    cdr = cdr_onehot(cdr)
    gds = gds_onehot(gds)
    cdr = torch.tensor(cdr, dtype=torch.long)
    gds = torch.tensor(gds, dtype=torch.long)
    
    d['CDR'] = cdr
    d['GDS'] = gds
        
    
    torch.save(d, f'{target_path}/data.pt')
    

DUIH_0001 2.0 5.0
3.0
DUIH_0002 2.0 5.0


  cdr = torch.tensor(cdr, dtype=torch.long)
  gds = torch.tensor(gds, dtype=torch.long)


3.0
DUIH_0003 0.5 3.0
1
DUIH_0005 1.0 3.0
2.0
DUIH_0006 0.5 3.0
1
DUIH_0007 1.0 4.0
2.0
DUIH_0008 3.0 6.0
4.0
DUIH_0009 1.0 4.0
2.0
DUIH_0010 0.5 3.0
1
DUIH_0011 1.0 4.0
2.0
DUIH_0012 2.0 5.0
3.0
DUIH_0013 1.0 4.0
2.0
DUIH_0014 0.5 4.0
1
DUIH_0015 1.0 4.0
2.0
DUIH_0016 2.0 5.0
3.0
DUIH_0017 3.0 6.0
4.0
DUIH_0020 0.5 4.0
1
DUIH_0021 2.0 5.0
3.0
DUIH_0022 0.0 2.0
0.0
DUIH_0023 0.5 3.0
1
DUIH_0024 1.0 3.0
2.0
DUIH_0025 1.0 4.0
2.0
DUIH_0026 1.0 5.0
2.0
DUIH_0027 2.0 5.0
3.0
DUIH_0028 1.0 5.0
2.0
DUIH_0029 1.0 4.0
2.0
DUIH_0030 0.5 3.0
1
DUIH_0031 0.5 2.0
1
DUIH_0032 1.0 3.0
{'img': '/root/snsb/data/mri/DUIH_0032/T1.nii.gz', 'seg': '/root/snsb/data/mri/DUIH_0032/aseg.nii.gz'}
2.0
DUIH_0033 1.0 4.0
2.0
DUIH_0034 1.0 5.0
2.0
DUIH_0035 0.0 2.0
0.0
DUIH_0036 0.5 3.0
1
DUIH_0037 1.0 4.0
2.0
DUIH_0038 1.0 4.0
2.0
DUIH_0039 0.5 3.0
1
DUIH_0040 0.5 3.0
1
DUIH_0041 2.0 5.0
3.0
DUIH_0042 0.5 3.0
1
DUIH_0043 1.0 5.0
2.0
DUIH_0044 0.5 3.0
1
DUIH_0045 0.5 3.0
1
DUIH_0046 0.5 3.0
1
DUIH_0047 0.5 3.0
1
D

In [42]:
ret = torch.load('/root/snsb/data/tensor/DUIH_0001/data.pt')
ret

{'img': metatensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          ...,
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           .