In [4]:
import os
import pandas as pd
import nibabel as nib
import numpy as np

import warnings
warnings.filterwarnings("ignore")

from skimage.measure import block_reduce
import numpy as np
from concurrent.futures import ProcessPoolExecutor
import time
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import tensorly as tl

# Debugging import: load the module named in `var` at runtime and copy every
# attribute whose name does not start with "__" into this notebook's globals.
# Unlike `from module import *`, this loop ignores `__all__` and also copies
# single-underscore names. It leaves `var`, `package`, `name` and `value`
# defined as globals as well.
import importlib
var = 'TensorDecisionTreeRegressorO' # the published version of the code
package = importlib.import_module(var)
for name, value in package.__dict__.items():
    if not name.startswith("__"):
        globals()[name] = value

# This star import runs after the loop above, so any name exported by both
# modules ends up bound to the TensorDecisionTreeRegressorP version.
# NOTE(review): presumably the `TensorDecisionTreeRegressor(...)` call later in
# the notebook therefore uses the P module's class — confirm which one is intended.
from TensorDecisionTreeRegressorP import *

# NOTE(review): `os`, `nibabel`, `numpy`, `pandas` and `train_test_split` are
# already imported at the top of this cell; the repeats below only re-bind the
# same names.
import os
import nibabel as nib
import numpy as np
# NOTE(review): this binds the top-level `matplotlib` package to `plt`, not
# `matplotlib.pyplot`, so calls such as `plt.plot(...)` would fail. `plt` is
# not used in the code visible here — confirm the intended import.
import matplotlib as plt
import pandas as pd
from sklearn.model_selection import train_test_split


# Location of the CSV table and the frame read from it
csv_file = '/Users/zc56/Documents/CommenDesktop/RICE/MyProject/Bayes_Tensor_Tree/3D-images/ADNIData.csv'
df = pd.read_csv(csv_file)

# Keep only the rows that actually carry an ADAS11_bl value
df_cleaned = df[df['ADAS11_bl'].notna()]

# Regression target: the ADAS11_bl column of the cleaned frame as an array
y_variable = df_cleaned['ADAS11_bl'].to_numpy()

# One sub-frame per DX_bl label, in the order CN, AD, LMCI
cn_group, ad_group, lmci_group = (
    df_cleaned.loc[df_cleaned['DX_bl'].eq(dx_label)]
    for dx_label in ('CN', 'AD', 'LMCI')
)

# Report how many rows each group keeps once missing targets are removed
print(f"CN group size: {len(cn_group)}")
print(f"AD group size: {len(ad_group)}")
print(f"LMCI group size: {len(lmci_group)}")

# Folder holding the 3D image files
directory = '/Users/zc56/Documents/CommenDesktop/RICE/MyProject/Bayes_Tensor_Tree/3D-images/3D-Images/bl'

# Empty per-group accumulators: image arrays first, then the matching y values
cn_images = []
ad_images = []
lmci_images = []
cn_y = []
ad_y = []
lmci_y = []

# Function to load the images based on PTID matching and append y values
def load_images_and_y(group, image_list, y_list, image_dir=None):
    """Load one NIfTI volume per row of ``group`` and collect the matching targets.

    For every row, the file ``<PTID>.nii.gz`` is looked up in ``image_dir``.
    When the file exists, its voxel array (``nib.load(...).get_fdata()``) is
    appended to ``image_list`` and the row's ``ADAS11_bl`` value is appended to
    ``y_list``, so the two lists stay index-aligned. Rows whose file is missing
    are reported with a printed message and skipped.

    Parameters
    ----------
    group : pandas.DataFrame
        Must provide the columns ``PTID`` and ``ADAS11_bl``.
    image_list : list
        Extended in place with the loaded image arrays.
    y_list : list
        Extended in place with the corresponding ``ADAS11_bl`` values.
    image_dir : str or path-like, optional
        Folder containing the ``.nii.gz`` files. When omitted, the module-level
        ``directory`` variable is used, which keeps existing three-argument
        calls working unchanged.

    Returns
    -------
    None
        Results are delivered through ``image_list`` and ``y_list``.
    """
    # Resolve the folder once; the global is only touched when no override is given.
    base_dir = directory if image_dir is None else image_dir

    # Iterate over just the two needed columns instead of `iterrows()`, which
    # builds a Series per row and does not preserve dtypes across the row.
    for ptid, target in zip(group['PTID'], group['ADAS11_bl']):
        # Find the corresponding file based on PTID
        filename = f'{ptid}.nii.gz'
        file_path = os.path.join(base_dir, filename)

        if not os.path.exists(file_path):
            print(f"File {filename} not found.")
            continue

        # Append image and target together so the lists never drift apart.
        image_list.append(nib.load(file_path).get_fdata())
        y_list.append(target)

def _stack_group(label, images, targets):
    """Stack one group's image arrays and targets into NumPy arrays.

    Parameters
    ----------
    label : str
        Group name used in the printed messages (e.g. ``"CN"``).
    images : list
        Per-subject image arrays; stacked along a new leading axis.
    targets : list
        Target values aligned with ``images``.

    Returns
    -------
    tuple
        ``(tensor, y)``. ``y`` is always ``np.array(targets)``. ``tensor`` is
        ``np.stack(images, axis=0)``, or ``None`` when ``images`` is empty so
        that a rerun can never silently keep a tensor from an earlier run.
    """
    y = np.array(targets)
    if not images:
        print(f"No {label} images loaded.")
        return None, y

    tensor = np.stack(images, axis=0)
    print(f"{label} 4D tensor shape: {tensor.shape}")
    print(f"{label} y shape: {y.shape}")
    return tensor, y


# Load images and y values for each group
load_images_and_y(cn_group, cn_images, cn_y)
load_images_and_y(ad_group, ad_images, ad_y)
load_images_and_y(lmci_group, lmci_images, lmci_y)

# Convert lists of 3D images and y values to NumPy arrays
cn_tensor, cn_y = _stack_group("CN", cn_images, cn_y)
ad_tensor, ad_y = _stack_group("AD", ad_images, ad_y)
lmci_tensor, lmci_y = _stack_group("LMCI", lmci_images, lmci_y)

CN group size: 229
AD group size: 187
LMCI group size: 401
CN 4D tensor shape: (229, 48, 48, 48)
CN y shape: (229,)
AD 4D tensor shape: (187, 48, 48, 48)
AD y shape: (187,)
LMCI 4D tensor shape: (401, 48, 48, 48)
LMCI y shape: (401,)


In [5]:
def _relative_squared_error(y_true, y_pred):
    """Return the mean squared error divided by the variance of ``y_true``.

    Both inputs are flattened first so that a column-vector prediction of
    shape ``(n, 1)`` cannot broadcast against ``(n,)`` targets into an
    ``(n, n)`` matrix and silently produce a wrong score. A value above 1
    means the predictions have a larger squared error than always predicting
    the mean of ``y_true``.

    Raises
    ------
    ValueError
        If the two inputs do not contain the same number of elements.
    """
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true has {y_true.size} elements but y_pred has {y_pred.size}"
        )
    return np.mean((y_pred - y_true) ** 2) / np.var(y_true)


# Hold out 20% of the LMCI subjects for testing (fixed seed for the split).
X_train, X_test, y_train, y_test = train_test_split(lmci_tensor, lmci_y, test_size=0.2, random_state=42)


print(X_train.shape,y_train.shape)
model = TensorDecisionTreeRegressor(
    max_depth=2,
    min_samples_split=12,
    split_method='variance_LS',
    split_rank=4,
    CP_reg_rank=12,
    Tucker_reg_rank=12,
    n_mode=3,
)
model.use_mean_as_threshold = False
# NOTE(review): if `sample_rate` enables random subsampling inside the model,
# seed the RNG it uses so the fit is reproducible — confirm against the model code.
model.sample_rate = .1

# Disabled experiment (kept for reference): coarsen the volumes before fitting.
# X_coarsen_shape = (1,4,4,4)
# X_coarsen_func = np.max
# X_train_c = block_reduce(X_train,block_size=X_coarsen_shape, func=X_coarsen_func)
# X_test_c = block_reduce(X_test,block_size=X_coarsen_shape, func=X_coarsen_func)
#middle_z = X_train_c.shape[2] // 2
#X_train_c = X_train_c[:,:,:,middle_z:middle_z+2]
#X_test_c = X_test_c[:,:,:,middle_z:middle_z+2]
# X_train_c = X_train_c + np.random.random(X_train_c.shape) * 1e-3
# X_test_c = X_test_c + np.random.random(X_test_c.shape) * 1e-3
# print(X_train_c.shape,y_train.shape,X_test_c.shape)
model.fit(X_train,y_train)
#model.prune(X_val=None, y_val=None, model_type='mean', alpha=0.0000000000001)


# Score every regression method on train first, then on test. The two-argument
# print keeps the original output format (label, two spaces, value).
for split_name, X_split, y_split in (("train", X_train, y_train), ("test", X_test, y_test)):
    for method_label, method in (("mean", "mean"), ("CP", "cp"), ("Tucker", "tucker")):
        predictions = model.predict(X_split, regression_method=method)
        print(f"{method_label} {split_name} RSE: ", _relative_squared_error(y_split, predictions))
model.print_tree()

(320, 48, 48, 48) (320,)
mean train RSE:  0.7491463249333379
CP train RSE:  5.34756042120831e-05
Tucker train RSE:  3.595751795835852e-05
mean test RSE:  1.0532402614147889
CP test RSE:  1.3587628162383019
Tucker test RSE:  1.6842076825374153
  if X[:, 28 , 20 , 15 ] <=  2.3704092502593994
   if X[:, 36 , 39 , 16 ] <=  7.901856422424316
     has  0  child nodes, and  138  samples.
   else: # if X[:, 36 , 39 , 16 ] >  7.901856422424316
     has  0  child nodes, and  4  samples.
  else: # if X[:, 28 , 20 , 15 ] >  2.3704092502593994
   if X[:, 12 , 38 , 11 ] <=  2.285515069961548
     has  0  child nodes, and  134  samples.
   else: # if X[:, 12 , 38 , 11 ] >  2.285515069961548
     has  0  child nodes, and  44  samples.
