## Training a classification model using the TMC data

### Dependencies

In [1]:
import glob
import numpy as np
import os
import collections
import sys

import pandas as pd
import pickle
import matplotlib
import matplotlib.pyplot as plt

import cv2
import imageio
import csv
from tifffile import TiffFile, imsave, imread, imwrite

import argparse
import yaml

%load_ext autoreload


### Constants

In [2]:
root_path = '/project/ahoover/mhealth/zeyut/tmc/TMC AI Files'
processed_data_path = '/project/ahoover/mhealth/zeyut/tmc/TMC AI Files/ProcessedData'
raw_data_path = '/project/ahoover/mhealth/zeyut/tmc/TMC AI Files/Data'
gt_path = '/project/ahoover/mhealth/zeyut/tmc/TMC AI Files/ProcessedData/gt.pkl'
result_path = '/project/ahoover/mhealth/zeyut/tmc/results/'
# result_path = './results/'

exclude_ids = ['H3', 'H14', 'H20', 'J1']

Label2Numbers = {'Healthy': 0, 'Peri-OA': 1, 'OA': 2}
Numbers2Labels = {0: 'Healthy', 1: 'Peri-OA', 2: 'OA'}

gesture_list = ['Key Pinch', 'Stat Abd', 'Stat Add', 'Stat Ext', 'Stat Flex']

kinematic_channels = ['helical_angle',
             'helical_translation',
             'volar-dorsal_angle',
             'volar-dorsal_translation',
             'radial-ulnar_angle',
             'radial-ulnar_translation',
             'inferior-superior_angle',
             'inferior-superior_translation']
target_cycle_lens = {'Key Pinch': 11500,
                     'Stat Abd': 9000,
                     'Stat Add': 9200,
                     'Stat Ext': 9200,
                     'Stat Flex': 11500}
target_cycle_len = 12000
downsample_rate = 10
cv_configs = {'random_seed': 42,
              'num_splits': 5}

kinematic_model_configs = {'gt_type': 'kinematic',
                            'num_epochs': 10,
                           'batch_size': 8,
                           'beta': 0.9999,
                           'num_workers': 4,
                           'max_step': 50,
                           'learning_rate': 0.001,
                           'patience': 3,
                           'constrastive_margin': 0.2,
                          'top_k': 3}
image_model_configs = {'gt_type': 'bony',
                        'num_epochs': 20,
                        'batch_size': 4,
                        'max_step': 30,
                        'num_workers': 4,
                        'patience': 3,
                        'learning_rate': 0.001,
                        'seed_path': "/project/ahoover/mhealth/zeyut/tmc/pre_trained/swin_unetr_btcv_segmentation/models/model.pt"
                        }

configs = {'root_path': root_path,
           'raw_data_path': raw_data_path,
           'processed_data_path': processed_data_path,
           'gt_path': gt_path,
           'results_path': result_path,
           'Label2Numbers': Label2Numbers,
           'gesture_list': gesture_list,
           'kinematic_channels': kinematic_channels,
           'target_cycle_len': target_cycle_len,
           'downsample_rate': 10,
           'cv_configs': cv_configs,
           'kinematic_model_configs': kinematic_model_configs,
          'image_model_configs': image_model_configs}

with open(os.path.join(root_path, 'configs.yaml'), 'w') as file:
    yaml.dump(configs, file)

In [3]:
font = {'weight' : 'normal','size': 21}
matplotlib.rc('font', **font)
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

matplotlib.rcParams['axes.spines.top'] = True
matplotlib.rcParams['axes.spines.left'] = True
matplotlib.rcParams['axes.spines.right'] = True
matplotlib.rcParams['axes.spines.bottom'] = True

### Load data

#### Ground truth files

In [63]:
grades = collections.defaultdict(dict)
ages = {}

# # Load new scores
# grade_df = pd.read_excel(os.path.join(root_path, 'Supplements/Hand Groups.xlsx'))
# id_list = (grade_df.iloc[:, 0].tolist())
# for hid in id_list:
#     group = grade_df.iloc[grade_df.index[grade_df.iloc[:, 0] == hid][0], 1]
#     grades['new'][hid] = Label2Numbers[group]
    
# Load new scores
grade_df = pd.read_excel(os.path.join(root_path, 'Supplements/summary_classification_per_cohort.xlsx'))
id_list = (grade_df.iloc[:, 0].tolist())
for hid in exclude_ids:
    try:
        id_list.remove(hid)
    except:
        pass
for hid in id_list:
    group = grade_df.iloc[grade_df.index[grade_df.iloc[:, 0] == hid][0], 2]
    grades['group'][hid] = Label2Numbers[group]
    kinematic_class = grade_df.iloc[grade_df.index[grade_df.iloc[:, 0] == hid][0], 4]
    grades['kinematic'][hid] = kinematic_class
    bony_class = grade_df.iloc[grade_df.index[grade_df.iloc[:, 0] == hid][0], 6]
    grades['bony'][hid] = bony_class
    ligament_class = grade_df.iloc[grade_df.index[grade_df.iloc[:, 0] == hid][0], 8]
    grades['ligament'][hid] = ligament_class
        
# Load old OA scores
old_grades = {}
grade_df = pd.read_excel(os.path.join(root_path, 'Supplements/all_results_clean_for_stats.xlsx'))
for hid in id_list:
    try:
        grades['old'][hid] = grade_df.iloc[grade_df.index[grade_df.iloc[:, 0] == hid][0], 2]
        ages[hid] = grade_df.iloc[grade_df.index[grade_df.iloc[:, 0] == hid][0], 1]
    except:
        print(f'Cannot find old grade for {hid}')
        del grades['new'][hid]
    

In [64]:
print(f'class size')
class_sizes_old = [collections.Counter(list(grades['old'].values())).get(i, 0) for i in range(5)]
class_sizes_new = [collections.Counter(list(grades['group'].values())).get(i, 0) for i in range(3)]
print(f"old_grades: {class_sizes_old}")
print(f"new_grades: {class_sizes_new}")

class size
old_grades: [7, 12, 12, 7, 2]
new_grades: [11, 10, 19]


In [65]:
if not os.path.exists(gt_path):
    with open(gt_path, 'wb') as file:
        pickle.dump(grades, file)

#### Processed kinematic data
* Calculated from motion and rigid body data 
* Processing script: kinematic_multiprocessing.py (Edited on 11/2024 by Daniel Gordon)
* subject H7, H8, H16, J2, J4, J8, J9 do not have kinematic data

In [8]:
raw_pkl_path = os.path.join(processed_data_path, 'kinematic_raw.pkl')
max_, min_ = 0, 10000
if os.path.exists(raw_pkl_path):
    with open(raw_pkl_path, 'rb') as file:
        kinematic_raw = pickle.load(file)
else:
    # Load raw data
    kinematic_raw = collections.defaultdict(dict)
    for hid in id_list:
        cur_path = os.path.join(root_path, f'ProcessingScripts/Results/{hid} Kinematic Results.xlsx')
        try:
            kinematic_sheets = pd.read_excel(cur_path, sheet_name=None)
            for sheet_name in kinematic_sheets.keys():
                gesture = sheet_name.split(hid)[1].strip()
                raw = kinematic_sheets[sheet_name].values
                kinematic_raw[hid][gesture] = kinematic_sheets[sheet_name].values
                max_ = max(max_, np.max(kinematic_raw[hid][gesture]))
                min_ = min(min_, np.min(kinematic_raw[hid][gesture]))
        except:
            print(f'Can not open {hid} kinematic result file')
    print(f'max value: {max_}, min value: {min_}')
    # Save data to pkl file to reduce loading time
    with open(raw_pkl_path, 'wb') as file:
        pickle.dump(kinematic_raw, file)

In [9]:
# Load start/end timestamps for each cycle 
cycles = collections.defaultdict(dict)
all_cycle_sheets = pd.read_excel(os.path.join(root_path, f'Supplements/Kinematic Cycles.xlsx'), sheet_name=None)
for hid in kinematic_raw:
    try:
        cycle_df = all_cycle_sheets[hid]
        for _, row in cycle_df.iterrows():
            gesture = row.iloc[0]
            cur_cycles = []
            for i, value in enumerate(row[1:]):  # skipping index
                if not pd.isna(value):
                    start_idx, end_idx = value.split(':')
                    cur_cycles.append([int(start_idx.strip()), int(end_idx.strip())])
            if cur_cycles:
                cycles[hid][gesture] = cur_cycles
    except:
        print(f'Can not find cycle info for {hid}')

# Remove those subjects that do not have at least one gesture 
for hid in list(cycles.keys()):
    if len(cycles[hid]) < 5:
        print(len(cycles[hid]))
        # del cycles[hid]


4
2
4
4
1


In [10]:
# Count maximum cycle length for each gesture
print('Max cycle length')
for hid in cycles:
    # print(hid)
    all_lens = []
    for gesture in cycles[hid]:
        lens = [end-start for start, end in cycles[hid][gesture]]
        # print(f"{gesture} max: {np.max(lens)} avg: {np.mean(lens):.02f} cycle len: {lens}")    
    
        all_lens.append(max([end-start for start, end in cycles[hid][gesture]]))
    print(f'max cycle length per gesture: {all_lens}')

Max cycle length
max cycle length per gesture: [1157, 695, 740, 671, 793]
max cycle length per gesture: [8679, 6921, 6629, 6921, 7372]
max cycle length per gesture: [11332, 8885, 8280, 8391, 10320]
max cycle length per gesture: [7113, 8317, 8011, 7816, 9264]
max cycle length per gesture: [11432, 8312, 8306, 8189, 10547]
max cycle length per gesture: [9520, 9115, 9157, 7509, 9398]
max cycle length per gesture: [8518, 8607, 8000, 8640, 9246]
max cycle length per gesture: [10721, 8120, 8234, 8117, 11429]
max cycle length per gesture: [28, 46, 17, 33]
max cycle length per gesture: [9494, 7157, 7144, 8943, 7901]
max cycle length per gesture: [9480, 8368]
max cycle length per gesture: [469, 388, 406, 382, 414]
max cycle length per gesture: [382, 319, 301, 349, 219]
max cycle length per gesture: [270, 181, 233, 324, 255]
max cycle length per gesture: [362, 273, 380, 469, 283]
max cycle length per gesture: [456, 338, 247, 432, 336]
max cycle length per gesture: [7962, 7133, 7173, 7803]
max cyc

In [11]:
# Slice data within each cycle and save to files
kinematic_data = collections.defaultdict(dict)
for hid in kinematic_raw:
    if hid not in cycles:
        continue
    for gesture in kinematic_raw[hid]:
        if gesture not in cycles[hid]:
            continue
        data = kinematic_raw[hid][gesture]
        kinematic_data[hid][gesture] = []
        for start_idx, end_idx in cycles[hid][gesture]:
            kinematic_data[hid][gesture].append(data[start_idx: end_idx])

if not os.path.exists(os.path.join(processed_data_path, 'kinematic_data.pkl')):
    with open(os.path.join(processed_data_path, 'kinematic_data.pkl'), 'wb') as file:
        pickle.dump(kinematic_data, file)



#### Static data 
* Include kinematic indices from motion tracking data and mechanical/dimensional measurements
* Select sample codes which have data available on all types

In [66]:
import collections
static_data = collections.defaultdict(list)

'''
Data drieved from dynamic measurement
    File: all_results_clean_for_stats.xlsx 
    Use sheet 'bony' and 'kinm', sheet 'mech' has many empty cells.

'''
static_df = pd.read_excel(os.path.join(root_path, f'Supplements/All_Data_New_Gold.xlsx'))

static_df.dropna(thresh=150, inplace=True) # Drop rows with more than 150 cells being nan
static_df.dropna(thresh=10, inplace=True, axis=1) # Drop columns with more than 10 cells being nan
sid_list = static_df.iloc[:, 0].values
gt_col = map(lambda x: Label2Numbers[x], static_df.iloc[:, 7].values)
static_df = static_df.drop(columns=static_df.columns[:10]) # Leave numeric data only
static_df.fillna(static_df.mean(), inplace=True) # Fill empty values with column average
for hid, data in zip(sid_list, static_df.values):
    static_data[hid].append(data)

In [67]:
if not os.path.exists(os.path.join(processed_data_path, 'static_data.pkl')):
    with open(os.path.join(processed_data_path, 'static_data.pkl'), 'wb') as file:
        pickle.dump(static_data, file)

#### Image data

In [69]:
image_hid = []
for hid in id_list:
    image_list = glob.glob(os.path.join(root_path, f"Data/{hid}/BonyGeometry/DICOMs/CT*/*.IMA"))
    if len(image_list) > 0:
        image_hid.append(hid)

print(f'#image samples: {len(image_hid)}')

#image samples: 37


In [54]:
image_paths = sorted(glob.glob(os.path.join(root_path, f"Data/{hid}/BonyGeometry/DICOMs/CT*/*.IMA")))  

In [60]:
valid_sc = []

for sc in id_list:
    image_list = glob.glob(os.path.join(root_path, f"Data/{hid}/BonyGeometry/DICOMs/CT*/*.IMA"))
    if len(image_list) > 0:
        valid_sc.append(sc)

In [61]:
len(valid_sc)

40

In [24]:
# Save image data to pkl file to reduce loading time
# Currently, masks are not included 
max_ = 0
min_ = float('inf')
for hid in id_list:
    image_pkl_path = os.path.join(processed_data_path, f'images/{hid}.pkl')
    if not os.path.exists(image_pkl_path):
        image_paths = sorted(glob.glob(os.path.join(root_path, f"Data/{hid}/BonyGeometry/DICOMs/CT*/*.IMA")))  
        mask_img = imread(os.path.join(root_path, f"CT Labelling/{hid} Labels.tif"))
        os.makedirs(os.path.join(processed_data_path,'images'), exist_ok=True)
        cur_vol = np.zeros((512, 512, 512))
        for i, slice_file in enumerate(image_paths):
            if i >= 512:
                break  # Avoid reading more slices than the target depth
            image = imageio.v2.imread(slice_file).astype(np.float32) 
            if image.shape != (512,512):
                print(hid, image.shape)
            cur_vol[i, :, :] = image
            max_ = max(max_, np.max(cur_vol))
            min_ = min(min_, np.min(cur_vol))
        with open(image_pkl_path, 'wb') as file:
            pickle.dump(cur_vol, file)

In [20]:
# Load segmentation masks
for hid in id_list:
    label_path = os.path.join(root_path, f"CT Labelling/{hid} Labels.tif")
    if os.path.exists(label_path):
        label_img = imread(label_path)
    else:
        print(f'cannot find labels for {hid}')
    # break

cannot find labels for H8
cannot find labels for H17
cannot find labels for H21
cannot find labels for H35
cannot find labels for H36
cannot find labels for J4
