In [None]:
### load library
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy # Add Deepcopy for args
import densenet3d

import numpy as np
import pandas as pd
from nilearn import plotting
import nibabel as nib
import seaborn as sns # visualization
import matplotlib.pyplot as plt # graph
import sklearn

import os
import glob
import sys
import argparse
import time
from tqdm.auto import tqdm # process bar
import random

import monai
from monai.data import CSVSaver, ImageDataset, DistributedWeightedRandomSampler
from monai.transforms import AddChannel, Compose, RandRotate90, Resize, ScaleIntensity, Flip, ToTensor
from monai.utils import set_determinism
from monai.apps import CrossValidation
import imageio
from collections import OrderedDict

%matplotlib inline

In [None]:
# Notebook-wide settings, kept in an argparse namespace so they can be passed
# around as `args`. Parsing an empty argument list yields an empty namespace
# (nothing is read from the command line).
parser = argparse.ArgumentParser()
args = parser.parse_args([])

vars(args).update(
    val_size=0.1,         # fraction of subjects used for validation
    test_size=0.1,        # fraction of subjects used for testing
    resize=(80, 80, 80),  # shape every image is resized to
    target='sex',         # phenotype column used as the label
)


In [None]:

# Anatomical images: switch into the image directory and list the files by
# bare name, so every entry of `orig_image_files` is relative to this directory.
orig_image_dir = '/share/master_ssd/3DCNN/data/2.UKB/1.sMRI_fs_cropped'
os.chdir(orig_image_dir)
orig_image_files = sorted(glob.glob('*.nii.gz'))

# Saved heat maps (.npy) for the "female_upper_0.75" group, as full paths in sorted order.
female_cam_upper_75_dir = '/share/scratch/connectome/dhkdgmlghks/UKB_interpretation/sex/OcclusionSensitivity/female_upper_0.75'
female_cam_upper_75_files = sorted(glob.glob(os.path.join(female_cam_upper_75_dir, '*.npy')))

# Saved heat maps (.npy) for the "male_upper_0.75" group, as full paths in sorted order.
male_cam_upper_75_dir = '/share/scratch/connectome/dhkdgmlghks/UKB_interpretation/sex/OcclusionSensitivity/male_upper_0.75'
male_cam_upper_75_files = sorted(glob.glob(os.path.join(male_cam_upper_75_dir, '*.npy')))


In [1]:
# Build `imageFiles_labels`: one dict per subject that has both a phenotype row
# and an image file, holding the target label, the subject id ('eid') and the
# image file name.
col_list = [args.target] + ['eid']

subject_data = pd.read_csv('/share/master_ssd/3DCNN/data/2.UKB/2.demo_qc/UKB_phenotype.csv')
subject_data = subject_data.loc[:, col_list]
subject_data = subject_data.sort_values(by='eid')
subject_data = subject_data.dropna(axis=0)  # drop subjects with a missing value in the target or eid column
subject_data = subject_data.reset_index(drop=True)

# The subject id is the image file name without its last 12 characters.
# It has to be cast to the dtype of the phenotype 'eid' column so that the
# merge below can match keys. Use dtype predicates rather than exact-type
# comparisons, so any integer width is accepted and an unsupported dtype
# fails with a clear message instead of an array-length mismatch.
subj_stems = [image_file[:-12] for image_file in orig_image_files]
if pd.api.types.is_integer_dtype(subject_data['eid']):
    subj = [int(stem) for stem in subj_stems]
elif pd.api.types.is_string_dtype(subject_data['eid']):
    subj = [str(stem) for stem in subj_stems]
else:
    raise TypeError(
        "Unsupported dtype for 'eid': {}; expected an integer or string column".format(
            subject_data['eid'].dtype))

image_list = pd.DataFrame({'eid': subj, 'image_files': orig_image_files})
# Inner merge keeps only subjects present in both the phenotype table and the image directory.
subject_data = pd.merge(subject_data, image_list, how='inner', on='eid')

col_list = col_list + ['image_files']

# One record per merged row, keyed by column name.
imageFiles_labels = []
for i in tqdm(range(len(subject_data))):
    imageFiles_labels.append({col: subject_data[col][i] for col in col_list})

NameError: name 'pd' is not defined

In [None]:
def partition_dataset(imageFiles_labels,args):
    """Split subject records into train / val / test ``ImageDataset`` objects.

    Records are split in the order they are given (no shuffling): the first
    ``1 - val_size - test_size`` fraction goes to 'train', the following
    ``val_size`` fraction to 'val', and every remaining record to 'test'.

    Args:
        imageFiles_labels: iterable of dicts, each with an 'image_files' entry
            and an entry keyed by ``args.target``.
        args: namespace providing ``target``, ``resize``, ``val_size`` and
            ``test_size``.

    Returns:
        dict mapping 'train', 'val' and 'test' to their ``ImageDataset``.
    """
    # Single pass over the records, pulling out (image, label) pairs.
    pairs = [(record['image_files'], record[args.target]) for record in imageFiles_labels]
    images = [pair[0] for pair in pairs]
    labels = [pair[1] for pair in pairs]

    target_shape = tuple(args.resize)

    def build_transform():
        # Every split gets its own instance of the same preprocessing chain.
        return Compose([ScaleIntensity(),
                        AddChannel(),
                        Resize(target_shape),
                        ToTensor()])

    # Split boundaries; the test split takes whatever is left after train + val.
    n_total = len(images)
    train_end = int(n_total * (1 - args.val_size - args.test_size))
    val_end = train_end + int(n_total * args.val_size)

    spans = {
        'train': slice(None, train_end),
        'val': slice(train_end, val_end),
        'test': slice(val_end, None),
    }

    return {
        split: ImageDataset(image_files=images[span],
                            labels=labels[span],
                            transform=build_transform())
        for split, span in spans.items()
    }

In [None]:
# Split the subject records into the 'train', 'val' and 'test' datasets,
# using the target, resize shape and split fractions stored in `args`.
partition = partition_dataset(imageFiles_labels,args)

In [None]:
# Fetch one test subject; its image is used as the anatomical background for
# the heat-map overlays in the cells below.
os.chdir(orig_image_dir)  # the dataset holds bare file names listed from this directory

# shuffle=False so the same subject is picked on every run (the unseeded
# shuffle made the background image differ between runs). Only one sample is
# read, so loading it in the main process avoids spawning worker processes.
testloader = torch.utils.data.DataLoader(partition['test'],
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=0)

# First (image, label) batch of the test set; raises StopIteration if the test split is empty.
image, label = next(iter(testloader))

## Computing the mean heat maps

In [None]:
# Edge length of the cubic (img_size x img_size x img_size) accumulators used
# for the mean heat maps below.
# NOTE(review): presumably this must match the shape of the saved .npy maps and args.resize — confirm.
img_size = 80

### female upper 0.75

In [None]:
# Voxel-wise mean of the female heat maps: sum every saved map into one
# accumulator, then divide by the number of maps.
mean_heatmap_female = np.zeros((img_size,) * 3)

for img_dir in tqdm(female_cam_upper_75_files):
    mean_heatmap_female += np.load(img_dir)

mean_heatmap_female /= len(female_cam_upper_75_files)



In [None]:
# Overlay the mean female heat map on the sampled subject's image, one figure
# per orthogonal plane, each taken at the central slice of the volume.
background = image.cpu().squeeze().numpy()  # drop the batch/channel dimensions of size 1
mid = img_size // 2  # central slice index (40 for an 80-voxel edge)

# Axis 0 -> sagittal, axis 1 -> coronal, axis 2 -> horizontal.
for axis, plane in enumerate(('sagittal', 'coronal', 'horizontal')):
    slice_orig = np.take(background, mid, axis=axis)
    slice_heat = np.take(mean_heatmap_female, mid, axis=axis)

    fig, ax = plt.subplots()
    # Grey-scale anatomy underneath, semi-transparent heat map on top.
    ax.imshow(np.rot90(slice_orig), interpolation='nearest', cmap=plt.cm.gray)
    ax.imshow(np.rot90(slice_heat), interpolation='bilinear', cmap='jet', alpha=0.5)
    ax.set_title('Mean heat map, female upper 0.75: mid-{} slice'.format(plane))




### male upper 0.75

In [None]:
# Voxel-wise mean of the male heat maps: sum every saved map into one
# accumulator, then divide by the number of maps.
mean_heatmap_male = np.zeros((img_size, img_size, img_size))

for img_dir in tqdm(male_cam_upper_75_files):
    img = np.load(img_dir)
    mean_heatmap_male += img

# Normalise by the number of *male* maps that were summed, not the female
# count, so the mean stays correct when the two groups differ in size.
mean_heatmap_male = mean_heatmap_male / len(male_cam_upper_75_files)



In [None]:
# Overlay the mean male heat map on the sampled subject's image, one figure
# per orthogonal plane, each taken at the central slice of the volume.
background = image.cpu().squeeze().numpy()  # drop the batch/channel dimensions of size 1
mid = img_size // 2  # central slice index (40 for an 80-voxel edge)

# Axis 0 -> sagittal, axis 1 -> coronal, axis 2 -> horizontal.
for axis, plane in enumerate(('sagittal', 'coronal', 'horizontal')):
    slice_orig = np.take(background, mid, axis=axis)
    slice_heat = np.take(mean_heatmap_male, mid, axis=axis)

    fig, ax = plt.subplots()
    # Grey-scale anatomy underneath, semi-transparent heat map on top.
    ax.imshow(np.rot90(slice_orig), interpolation='nearest', cmap=plt.cm.gray)
    ax.imshow(np.rot90(slice_heat), interpolation='bilinear', cmap='jet', alpha=0.5)
    ax.set_title('Mean heat map, male upper 0.75: mid-{} slice'.format(plane))



