In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import SimpleITK as sitk
from scipy.ndimage import zoom
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn

from dataset import MRIDataset, get_loader
from models import C3D
from train import epoch_iter, add_metrics

In [None]:
# Data locations and image selection for this notebook.
#
# The two paths default to the original absolute locations but can be
# overridden through the environment (TBM_BASEPATH / TBM_CSVPATH), so the
# notebook can run on a machine where the data is mounted elsewhere without
# editing the code.
basepath = os.environ.get("TBM_BASEPATH", "/data1/TBM/ttest_24FEB2023/data")
csvpath = os.environ.get(
    "TBM_CSVPATH",
    '/data1/TBM/data_for_AI/subjects_info/final_TBM_subjects_info.csv',
)

# Path components appended after the subject directory when locating an
# image: <basepath>/<class dir>/<subject>/<modality>/<use_file>.
modality = "T2s"
use_file = "R2S.nii"

In [2]:
# Scan the subject table, keep every image that can be read and has the
# expected shape, and derive the intensity statistics passed to get_loader.
df_data = pd.read_csv(csvpath)

# Lower-cased CSV label -> (numeric class, directory name under basepath).
# Note that subjects labelled 'mmd' are read from the 'AD' directory.
LABEL_INFO = {
    'normal': (0, 'Normal'),
    'mci': (1, 'MCI'),
    'mmd': (2, 'AD'),
}

# Images whose leading three array dimensions differ from this are skipped.
EXPECTED_SHAPE = (28, 256, 256)

filenames = []
labels_data = []

# Per-image mean / variance of the strictly positive voxels.
image_means = []
var_data = []

for name, label in zip(df_data.label_id, df_data['label']):

    # str() so that a missing (NaN) label reaches the ValueError below
    # instead of failing with AttributeError on .lower().
    label_key = str(label).strip().lower()
    if label_key not in LABEL_INFO:
        raise ValueError(f"No label name {label}")
    category, class_dir = LABEL_INFO[label_key]
    fdt_paths_path = os.path.join(basepath, class_dir, name, modality, use_file)

    # Best effort: an unreadable file is reported and skipped. Catching
    # Exception (not a bare except) keeps Ctrl-C / SystemExit working.
    try:
        img = sitk.ReadImage(fdt_paths_path)
        img_array = sitk.GetArrayFromImage(img)
    except Exception as exc:
        print(f"{name}: could not read {fdt_paths_path} ({exc})")
        continue

    if img_array.shape[:3] != EXPECTED_SHAPE:
        print(img_array.shape)
        continue

    # Statistics are taken over strictly positive voxels only.
    foreground = img_array[img_array > 0]
    if foreground.size == 0:
        # mean()/var() of an empty selection are NaN and would turn the
        # dataset-level statistics below into NaN, so skip such an image.
        print(f"{name}: no positive voxels in {fdt_paths_path}, skipped")
        continue

    labels_data.append(category)
    filenames.append(fdt_paths_path)
    image_means.append(foreground.mean())
    var_data.append(foreground.var())

assert len(labels_data) == len(filenames)
if not filenames:
    raise RuntimeError(f"No usable images found under {basepath}")

# Dataset-level statistics: the mean of the per-image means, and the square
# root of the mean per-image variance (the spread between image means is not
# included in this estimate).
mean_data = np.mean(image_means)
std_data = np.sqrt(np.mean(var_data))

train_loader, val_loader = get_loader(filenames, labels_data, mean_data, std_data, batch_size = 8)


In [None]:
# Materialise every training batch as NumPy arrays.
#
# Batches are converted with .numpy() and concatenated once along the batch
# axis. Going through .tolist() would create one Python object per voxel
# before rebuilding an array, which is far slower and uses much more memory;
# this form also keeps the tensors' dtype.
image_batches = []
label_batches = []
for batch_images, batch_labels in train_loader:
    image_batches.append(batch_images.detach().cpu().numpy())
    label_batches.append(batch_labels.detach().cpu().numpy())

# An empty loader yields empty arrays rather than an error from concatenate.
labels_img = np.concatenate(label_batches, axis=0) if label_batches else np.array([])
imgs = np.concatenate(image_batches, axis=0) if image_batches else np.array([])

In [None]:
def _values_above(volumes, keep_mask, floor):
    """Return, as a 1-D array, the entries of ``volumes[keep_mask]`` greater than ``floor``."""
    chosen = volumes[keep_mask]
    return chosen[chosen > floor].reshape((-1))


# Flattened intensities above -0.63 for the images labelled 0 and 1.
# NOTE(review): -0.63 presumably corresponds to the background level after
# the loader's normalisation — confirm against get_loader.
imgs_normal = _values_above(imgs, labels_img == 0, -0.63)
imgs_d = _values_above(imgs, labels_img == 1, -0.63)

In [None]:
# Density histogram of the retained intensities from the label-1 images.
# Only the first 4,000,000 values are plotted (in array order, not a random
# sample) to keep the cell fast.
fig, ax = plt.subplots()
ax.hist(imgs_d[:4_000_000], bins=1000, density=True)
ax.set_ylim(0, 1)
ax.set(
    title="Intensity distribution, label 1 images",
    xlabel="Voxel intensity (as returned by train_loader)",
    ylabel="Density",
);

In [None]:
# Density histogram of the retained intensities from the label-0 images.
# Only the first 4,000,000 values are plotted (in array order, not a random
# sample) to keep the cell fast.
fig, ax = plt.subplots()
ax.hist(imgs_normal[:4_000_000], bins=1000, density=True)
ax.set_ylim(0, 1)
ax.set(
    title="Intensity distribution, label 0 images",
    xlabel="Voxel intensity (as returned by train_loader)",
    ylabel="Density",
);