In [1]:
import os
import glob
import pickle
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 50)
pd.set_option('display.width', 1000)

# Custom imports
from werdich_cfr.models.Modeltrainer_Inc2 import VideoTrainer
from werdich_cfr.tfutils.tfutils import use_gpu_devices

%load_ext autoreload
%autoreload 2

physical_devices, device_list = use_gpu_devices(gpu_device_string='0,1,2,3')

AVAILABLE GPUs:
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')
PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU')
PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU')
TRAIN DEVICE LIST:
/GPU:0
/GPU:1
/GPU:2
/GPU:3


In [2]:
# checkpoint files from model_dir
def get_checkpoint_file_list(model_dir, epoch_list):
    
    checkpoint_file_list = sorted(glob.glob(os.path.join(model_dir, '*_chkpt_*.h5')))
    checkpoint_file_list_xt = [file.split('.')[0] for file in checkpoint_file_list]
    checkpoint_file_cut = [file.rsplit('_', maxsplit=1)[0] for file in checkpoint_file_list_xt][0]
    checkpoint_epoch_list = [int(os.path.basename(file).rsplit('_')[-1]) \
                             for file in checkpoint_file_list_xt]
    mag=len(str(max(checkpoint_epoch_list)))

    # Select only those epochs that we want
    epoch_list = sorted(list(set(checkpoint_epoch_list).intersection(set(epoch_list))))
    print(f'Found checkpoints for epochs: {epoch_list}')
    epoch_checkpoint_file_list = [checkpoint_file_cut+'_'+str(epoch).zfill(mag)+'.h5' for epoch in epoch_list]
    
    return epoch_checkpoint_file_list

In [3]:
data_root = os.path.normpath('/mnt/obi0/andreas/data/cfr')
log_dir = os.path.join(data_root, 'log')
model_dir_list = glob.glob(os.path.join(log_dir, '*/'))
print(*model_dir_list, sep = '\n')

/mnt/obi0/andreas/data/cfr/log/global_a4c_gpu2_global_cfr_calc/
/mnt/obi0/andreas/data/cfr/log/nondefect_a4c_dgx-1_unaffected_cfr/
/mnt/obi0/andreas/data/cfr/log/global_a4c_dgx-1_rest_global_mbf/
/mnt/obi0/andreas/data/cfr/log/nondefect_a4c_dgx-1_stress_mbf_unaff/
/mnt/obi0/andreas/data/cfr/log/global_a4c_gpu2_rest_global_mbf/
/mnt/obi0/andreas/data/cfr/log/global_a4c_gpu2_fc128aug_rest_global_mbf/
/mnt/obi0/andreas/data/cfr/log/nondefect_a4c_dgx-1_rest_mbf_unaff/
/mnt/obi0/andreas/data/cfr/log/global_a4c_gpu2_fc128_rest_global_mbf/
/mnt/obi0/andreas/data/cfr/log/global_a4c_gpu2_stress_global_mbf/
/mnt/obi0/andreas/data/cfr/log/global_a4c_dgx-1_global_cfr_calc/
/mnt/obi0/andreas/data/cfr/log/global_a4c_dgx-1_stress_global_mbf/


In [4]:
m = 0
model_dir = os.path.normpath(model_dir_list[0])

epoch_list = [50, 100, 150]

# This can be in a function

model_dict_name = os.path.basename(model_dir)+'_model_dict.pkl'
model_dict_file = os.path.join(model_dir, model_dict_name)
train_dict_name = model_dict_name.replace('_model_dict.pkl', '_train_dict.pkl')
train_dict_file = os.path.join(model_dir, train_dict_name)
with open(model_dict_file, 'rb') as fl:
    model_dict = pickle.load(fl)
with open(train_dict_file, 'rb') as fl:
    train_dict = pickle.load(fl)

#test files
train_file = train_dict['train_file_list'][0]
test_basename = os.path.basename(train_file).rsplit('_', maxsplit=1)[0].replace('train', 'test')
tfr_dir = os.path.dirname(train_file)
test_tfr_file_list = sorted(glob.glob(os.path.join(tfr_dir, test_basename+'*.tfrecords')))
test_parquet_file_list = [file.replace('.tfrecords', '.parquet') for file in test_tfr_file_list]
test_df = pd.concat([pd.read_parquet(file) for file in test_parquet_file_list])
print('Test files:')
print(*test_tfr_file_list, sep='\n')

# feature_dict
feature_dict_file = glob.glob(os.path.join(tfr_dir, '*.pkl'))[0]
with open(feature_dict_file, 'rb') as fl:
    feature_dict = pickle.load(fl)
    
# instantiate model trainer class
VT = VideoTrainer(log_dir=None, model_dict=model_dict, train_dict=train_dict, feature_dict=feature_dict)
    
# checkpoint_files
checkpoint_file_list = get_checkpoint_file_list(model_dir, epoch_list)

Test files:
/mnt/obi0/andreas/data/cfr/tfr_200519/global/cfr_global_a4c_test_200519_0.tfrecords
/mnt/obi0/andreas/data/cfr/tfr_200519/global/cfr_global_a4c_test_200519_1.tfrecords
/mnt/obi0/andreas/data/cfr/tfr_200519/global/cfr_global_a4c_test_200519_2.tfrecords
/mnt/obi0/andreas/data/cfr/tfr_200519/global/cfr_global_a4c_test_200519_3.tfrecords
/mnt/obi0/andreas/data/cfr/tfr_200519/global/cfr_global_a4c_test_200519_4.tfrecords
/mnt/obi0/andreas/data/cfr/tfr_200519/global/cfr_global_a4c_test_200519_5.tfrecords
/mnt/obi0/andreas/data/cfr/tfr_200519/global/cfr_global_a4c_test_200519_6.tfrecords
/mnt/obi0/andreas/data/cfr/tfr_200519/global/cfr_global_a4c_test_200519_7.tfrecords
Found checkpoints for epochs: [50, 100, 150]


In [10]:
# predictions
pred_df = VT.predict_on_test(test_tfr_file_list[0], checkpoint_file_list[0], batch_size=12)

Extracting true labels from testset.
Samples: 123, steps: 11


In [11]:
# re_shape the output of the predictions


Unnamed: 0,global_cfr_calc,global_a4c_gpu2_global_cfr_calc_chkpt_050
0,2.384252,2.044266
1,0.742226,1.910853
2,0.994403,1.942179
3,1.07772,1.986941
4,0.888172,1.118363


In [9]:
model_dict['model_output']

'global_cfr_calc'