# Results for Generalization using KL Divergence for Binned Age from 3D body scan

In [2]:
import pandas as pd
import numpy as np
import os

In [7]:
# Collect the per-run probability CSVs (names matching output_file_*.csv) from the four
# result directories under multi_class_probability_files/pct/BinnedAge/.
# os.listdir() returns entries in arbitrary order, so every listing is sorted to keep the
# run order stable across machines and re-runs.
# NOTE(review): the original comments called the first model a "Standard set transformer"
# although these directories live under pct/ - confirm which architecture this section covers.

# Weak generalization - standard model
wg_pt_path1 = './multi_class_probability_files/pct/BinnedAge/Weak_Generalization/ed256_0.05_bs5_ss8000_ep250/'
wg_pt_output_files1 = sorted(
    name for name in os.listdir(wg_pt_path1)
    if name.startswith('output_file_') and name.endswith('.csv')
)

# Strong generalization - standard model
sg_pt_path1 = './multi_class_probability_files/pct/BinnedAge/Strong_Generalization/ed256_0.05_bs5_ss8000/'
sg_pt_output_files1 = sorted(
    name for name in os.listdir(sg_pt_path1)
    if name.startswith('output_file_') and name.endswith('.csv')
)

# Weak generalization - contrastive pre-trained model
wg_pt_path2 = './multi_class_probability_files/pct/BinnedAge/Weak_Generalization/unmasked_encoder_32_1024_150_200_pd1024_ed256_lr0.0001_dl_PCT_loss_pd1024_bs5_ss8000_ep250/'
wg_pt_output_files2 = sorted(
    name for name in os.listdir(wg_pt_path2)
    if name.startswith('output_file_') and name.endswith('.csv')
)

# Strong generalization - contrastive pre-trained model
sg_pt_path2 = './multi_class_probability_files/pct/BinnedAge/Strong_Generalization/unmasked_encoder_32_1024_150_200_pd1024_ed256_lr0.0001_dl_PCT_loss_pd1024_bs5_ss8000/'
sg_pt_output_files2 = sorted(
    name for name in os.listdir(sg_pt_path2)
    if name.startswith('output_file_') and name.endswith('.csv')
)

In [8]:
def concat(path, output_files):
    """Read every CSV named in ``output_files`` from ``path`` into a dict of DataFrames.

    Despite the name, nothing is concatenated: each file becomes its own
    DataFrame, stored under the key ``'df_1'``, ``'df_2'``, ... following the
    order of ``output_files``. Column labels are cast to ``int``.

    Parameters
    ----------
    path : str
        Directory holding the CSV files, with or without a trailing separator.
    output_files : iterable of str
        File names relative to ``path``.

    Returns
    -------
    dict
        Maps ``'df_<n>'`` to the DataFrame read from the n-th file.
    """
    df_dict = {}
    for i, file in enumerate(output_files, start=1):
        # os.path.join works whether or not `path` ends with a separator,
        # unlike plain string concatenation.
        frame = pd.read_csv(os.path.join(path, file))
        frame.columns = frame.columns.astype(int)
        df_dict[f'df_{i}'] = frame

    return df_dict
    
# Load the per-run DataFrames for each of the four result directories.
wg_pt_df_dict1 = concat(path=wg_pt_path1, output_files=wg_pt_output_files1)
sg_pt_df_dict1 = concat(path=sg_pt_path1, output_files=sg_pt_output_files1)
wg_pt_df_dict2 = concat(path=wg_pt_path2, output_files=wg_pt_output_files2)
sg_pt_df_dict2 = concat(path=sg_pt_path2, output_files=sg_pt_output_files2)

In [9]:
# Stack the per-run tables of each setting into a single array,
# one run per entry along the first axis.
wg_pt_data_matrix1 = np.array(
    [run_df.values for run_df in wg_pt_df_dict1.values()]
)
sg_pt_data_matrix1 = np.array(
    [run_df.values for run_df in sg_pt_df_dict1.values()]
)
wg_pt_data_matrix2 = np.array(
    [run_df.values for run_df in wg_pt_df_dict2.values()]
)
sg_pt_data_matrix2 = np.array(
    [run_df.values for run_df in sg_pt_df_dict2.values()]
)

In [10]:
# Sanity check: show the shape of every stacked array, one per line.
print(
    wg_pt_data_matrix1.shape,
    sg_pt_data_matrix1.shape,
    wg_pt_data_matrix2.shape,
    sg_pt_data_matrix2.shape,
    sep='\n',
)

(10, 4, 4)
(10, 4, 4)
(10, 4, 4)
(10, 4, 4)


In [11]:
# Average over the runs (first axis) so that each setting is summarized
# by a single matrix.
wg_pt_combined_matrix1 = wg_pt_data_matrix1.mean(axis=0)
sg_pt_combined_matrix1 = sg_pt_data_matrix1.mean(axis=0)
wg_pt_combined_matrix2 = wg_pt_data_matrix2.mean(axis=0)
sg_pt_combined_matrix2 = sg_pt_data_matrix2.mean(axis=0)

In [12]:
# Overwrite the diagonal of the weak-generalization matrices with 1e-9 before
# normalization. np.fill_diagonal works in place, so the averaged matrices
# themselves are modified.
# NOTE(review): the strong-generalization matrices keep their diagonal - confirm
# that this asymmetry is intended.
for _wg_matrix in (wg_pt_combined_matrix1, wg_pt_combined_matrix2):
    np.fill_diagonal(_wg_matrix, 1e-9)

In [13]:
# Grand total of every averaged matrix (column sums, then their sum);
# these totals are the normalization constants used in the next step.
wg_pt_matrix_sum1 = pd.DataFrame(wg_pt_combined_matrix1).sum(axis=0).sum()
print(wg_pt_matrix_sum1)
sg_pt_matrix_sum1 = pd.DataFrame(sg_pt_combined_matrix1).sum(axis=0).sum()
print(sg_pt_matrix_sum1)

wg_pt_matrix_sum2 = pd.DataFrame(wg_pt_combined_matrix2).sum(axis=0).sum()
print(wg_pt_matrix_sum2)
sg_pt_matrix_sum2 = pd.DataFrame(sg_pt_combined_matrix2).sum(axis=0).sum()
print(sg_pt_matrix_sum2)


0.8447227986883679
3.999999998473946
4e-09
4.000000008641688


In [14]:
# Normalize: divide each averaged matrix by its own grand total so that
# its entries sum to 1.
wg_pt_mean_matrix_partial1 = np.divide(wg_pt_combined_matrix1, wg_pt_matrix_sum1)
sg_pt_mean_matrix_partial1 = np.divide(sg_pt_combined_matrix1, sg_pt_matrix_sum1)

wg_pt_mean_matrix_partial2 = np.divide(wg_pt_combined_matrix2, wg_pt_matrix_sum2)
sg_pt_mean_matrix_partial2 = np.divide(sg_pt_combined_matrix2, sg_pt_matrix_sum2)

In [15]:
# Flatten each normalized matrix into a 1-D vector for the KL divergence.
# .flatten() returns a copy, so the in-place flooring applied to these vectors
# in the following cell does not touch the 2-D matrices.
wg_pt_mean_matrix_partial_flatten1 = wg_pt_mean_matrix_partial1.flatten()
sg_pt_mean_matrix_partial_flatten1 = sg_pt_mean_matrix_partial1.flatten()

wg_pt_mean_matrix_partial_flatten2 = wg_pt_mean_matrix_partial2.flatten()
sg_pt_mean_matrix_partial_flatten2 = sg_pt_mean_matrix_partial2.flatten()

In [16]:
# Floor near-zero entries in place: anything below 1e-15 becomes 1e-12, so the
# KL divergence never meets an exact 0 (which would produce NaN).
# NOTE(review): the threshold (1e-15) is smaller than the replacement (1e-12),
# and the vectors are not re-normalized afterwards - confirm this is intended.
for _flat in (
    wg_pt_mean_matrix_partial_flatten1,
    sg_pt_mean_matrix_partial_flatten1,
    wg_pt_mean_matrix_partial_flatten2,
    sg_pt_mean_matrix_partial_flatten2,
):
    _flat[_flat < 1e-15] = 1e-12

In [17]:
# Calculate KL divergence
def kl_divergence(p, q, epsilon=1e-9):
    """Return the KL divergence D(p || q) in bits: ``sum(p * log2(p / q))``.

    Parameters
    ----------
    p, q : float or array-like
        Probabilities of the two distributions; scalars or arrays of
        matching (broadcastable) shape.
    epsilon : float, optional
        Value substituted for any non-positive entry of ``p`` or ``q`` so that
        neither ``log2(0)`` nor a division by zero can occur.

    Returns
    -------
    numpy scalar
        The summed divergence. Strictly positive inputs give exactly the value
        of the plain formula; ``epsilon`` only affects entries that are <= 0.
    """
    p = np.asarray(p)
    q = np.asarray(q)
    # `<= 0` (rather than `> 0`) leaves NaN entries untouched, so invalid
    # input is still visible in the result instead of being masked.
    p = np.where(p <= 0, epsilon, p)
    q = np.where(q <= 0, epsilon, q)
    return np.sum(p * np.log2(p / q))

In [18]:
# KL divergence (weak vs. strong generalization) - standard model.
# kl_divergence is evaluated once per pair of entries and the terms are summed.
kl_divergence_value1 = np.sum([
    kl_divergence(p_i, q_i)
    for p_i, q_i in zip(
        wg_pt_mean_matrix_partial_flatten1,
        sg_pt_mean_matrix_partial_flatten1,
    )
])
print("KL Divergence:", kl_divergence_value1)

# KL divergence (weak vs. strong generalization) - contrastive pre-trained model.
kl_divergence_value2 = np.sum([
    kl_divergence(p_i, q_i)
    for p_i, q_i in zip(
        wg_pt_mean_matrix_partial_flatten2,
        sg_pt_mean_matrix_partial_flatten2,
    )
])
print("KL Divergence:", kl_divergence_value2)

KL Divergence: 0.8491089167624167
KL Divergence: 37.863137138219194


In [19]:
# Cross-check: KL divergence computed with scipy.stats.entropy (base 2).
# Note: kl_divergence_value1 / kl_divergence_value2 are overwritten here.
from scipy.stats import entropy

# Standard model: weak- vs. strong-generalization distribution.
kl_divergence_array1 = entropy(
    wg_pt_mean_matrix_partial_flatten1,
    qk=sg_pt_mean_matrix_partial_flatten1,
    base=2,
)

# The inputs are 1-D, so taking the mean of the result is presumably a no-op.
kl_divergence_value1 = np.mean(kl_divergence_array1)
print("Overall KL Divergence of Naive:", kl_divergence_value1)

# Contrastive pre-trained model: weak- vs. strong-generalization distribution.
kl_divergence_array2 = entropy(
    wg_pt_mean_matrix_partial_flatten2,
    qk=sg_pt_mean_matrix_partial_flatten2,
    base=2,
)

kl_divergence_value2 = np.mean(kl_divergence_array2)
print("Overall KL Divergence of Contrastive:", kl_divergence_value2)

Overall KL Divergence of Naive: 0.8491089167544367
Overall KL Divergence of Contrastive: 37.863137137753306


# ST

In [20]:
# Collect the per-run probability CSVs (names matching output_file_*.csv) from the four
# result directories under multi_class_probability_files/st/binned_age/.
# os.listdir() returns entries in arbitrary order, so every listing is sorted to keep the
# run order stable across machines and re-runs.
# NOTE(review): the variables in this section carry a `_pct` suffix although the
# directories live under st/ - confirm the naming.

# Weak generalization - standard model
wg_pt_path_pct1 = './multi_class_probability_files/st/binned_age/Weak_Generalization/64_16_128_3_0.05/'
wg_pt_output_files_pct1 = sorted(
    name for name in os.listdir(wg_pt_path_pct1)
    if name.startswith('output_file_') and name.endswith('.csv')
)

# Strong generalization - standard model
sg_pt_path_pct1 = './multi_class_probability_files/st/binned_age/Strong_Generalization/64_16_128_3_0.05/'
sg_pt_output_files_pct1 = sorted(
    name for name in os.listdir(sg_pt_path_pct1)
    if name.startswith('output_file_') and name.endswith('.csv')
)

# Weak generalization - contrastive pre-trained model
wg_pt_path_pct2 = './multi_class_probability_files/st/binned_age/Weak_Generalization/masked_encoder_64_16_128_3_0.05/'
wg_pt_output_files_pct2 = sorted(
    name for name in os.listdir(wg_pt_path_pct2)
    if name.startswith('output_file_') and name.endswith('.csv')
)

# Strong generalization - contrastive pre-trained model
sg_pt_path_pct2 = './multi_class_probability_files/st/binned_age/Strong_Generalization/masked_encoder_64_16_128_3_0.05/'
sg_pt_output_files_pct2 = sorted(
    name for name in os.listdir(sg_pt_path_pct2)
    if name.startswith('output_file_') and name.endswith('.csv')
)

In [21]:
# Load the per-run DataFrames for each of the four result directories.
wg_pt_df_dict_pct1 = concat(path=wg_pt_path_pct1, output_files=wg_pt_output_files_pct1)
sg_pt_df_dict_pct1 = concat(path=sg_pt_path_pct1, output_files=sg_pt_output_files_pct1)
wg_pt_df_dict_pct2 = concat(path=wg_pt_path_pct2, output_files=wg_pt_output_files_pct2)
sg_pt_df_dict_pct2 = concat(path=sg_pt_path_pct2, output_files=sg_pt_output_files_pct2)

In [22]:
# Stack the per-run tables of each setting into a single array,
# one run per entry along the first axis.
wg_pt_data_matrix_pct1 = np.array(
    [run_df.values for run_df in wg_pt_df_dict_pct1.values()]
)
sg_pt_data_matrix_pct1 = np.array(
    [run_df.values for run_df in sg_pt_df_dict_pct1.values()]
)
wg_pt_data_matrix_pct2 = np.array(
    [run_df.values for run_df in wg_pt_df_dict_pct2.values()]
)
sg_pt_data_matrix_pct2 = np.array(
    [run_df.values for run_df in sg_pt_df_dict_pct2.values()]
)

In [23]:
# Sanity check: show the shape of every stacked array, one per line.
print(
    wg_pt_data_matrix_pct1.shape,
    sg_pt_data_matrix_pct1.shape,
    wg_pt_data_matrix_pct2.shape,
    sg_pt_data_matrix_pct2.shape,
    sep='\n',
)

(10, 4, 4)
(10, 4, 4)
(10, 4, 4)
(10, 4, 4)


In [24]:
# Average over the runs (first axis) so that each setting is summarized
# by a single matrix.
wg_pt_combined_matrix_pct1 = wg_pt_data_matrix_pct1.mean(axis=0)
sg_pt_combined_matrix_pct1 = sg_pt_data_matrix_pct1.mean(axis=0)
wg_pt_combined_matrix_pct2 = wg_pt_data_matrix_pct2.mean(axis=0)
sg_pt_combined_matrix_pct2 = sg_pt_data_matrix_pct2.mean(axis=0)

In [25]:
# Sanity check: shape of the run-averaged matrix (shown as the cell output).
wg_pt_combined_matrix_pct1.shape

(4, 4)

In [26]:
# Overwrite the diagonal of the weak-generalization matrices with 1e-9 before
# normalization. np.fill_diagonal works in place, so the averaged matrices
# themselves are modified.
# NOTE(review): the strong-generalization matrices keep their diagonal - confirm
# that this asymmetry is intended.
for _wg_matrix in (wg_pt_combined_matrix_pct1, wg_pt_combined_matrix_pct2):
    np.fill_diagonal(_wg_matrix, 1e-9)

In [27]:
# Grand total of every averaged matrix (column sums, then their sum);
# these totals are the normalization constants used in the next step.
wg_pt_matrix_sum_pct1 = pd.DataFrame(wg_pt_combined_matrix_pct1).sum(axis=0).sum()
print(wg_pt_matrix_sum_pct1)
sg_pt_matrix_sum_pct1 = pd.DataFrame(sg_pt_combined_matrix_pct1).sum(axis=0).sum()
print(sg_pt_matrix_sum_pct1)

wg_pt_matrix_sum_pct2 = pd.DataFrame(wg_pt_combined_matrix_pct2).sum(axis=0).sum()
print(wg_pt_matrix_sum_pct2)
sg_pt_matrix_sum_pct2 = pd.DataFrame(sg_pt_combined_matrix_pct2).sum(axis=0).sum()
print(sg_pt_matrix_sum_pct2)


1.2005171422324274
3.9999999875646353
0.025810461601702223
4.000000006680706


In [28]:
# Normalize: divide each averaged matrix by its own grand total so that
# its entries sum to 1.
wg_pt_mean_matrix_partial_pct1 = np.divide(wg_pt_combined_matrix_pct1, wg_pt_matrix_sum_pct1)
sg_pt_mean_matrix_partial_pct1 = np.divide(sg_pt_combined_matrix_pct1, sg_pt_matrix_sum_pct1)

wg_pt_mean_matrix_partial_pct2 = np.divide(wg_pt_combined_matrix_pct2, wg_pt_matrix_sum_pct2)
sg_pt_mean_matrix_partial_pct2 = np.divide(sg_pt_combined_matrix_pct2, sg_pt_matrix_sum_pct2)

In [29]:
# Flatten each normalized matrix into a 1-D vector for the KL divergence.
# .flatten() returns a copy, so the in-place flooring applied to these vectors
# in the following cell does not touch the 2-D matrices.
wg_pt_mean_matrix_partial_flatten_pct1 = wg_pt_mean_matrix_partial_pct1.flatten()
sg_pt_mean_matrix_partial_flatten_pct1 = sg_pt_mean_matrix_partial_pct1.flatten()

wg_pt_mean_matrix_partial_flatten_pct2 = wg_pt_mean_matrix_partial_pct2.flatten()
sg_pt_mean_matrix_partial_flatten_pct2 = sg_pt_mean_matrix_partial_pct2.flatten()

In [30]:
# Floor near-zero entries in place: anything below 1e-15 becomes 1e-12, so the
# KL divergence never meets an exact 0 (which would produce NaN).
# NOTE(review): the threshold (1e-15) is smaller than the replacement (1e-12),
# and the vectors are not re-normalized afterwards - confirm this is intended.
for _flat in (
    wg_pt_mean_matrix_partial_flatten_pct1,
    sg_pt_mean_matrix_partial_flatten_pct1,
    wg_pt_mean_matrix_partial_flatten_pct2,
    sg_pt_mean_matrix_partial_flatten_pct2,
):
    _flat[_flat < 1e-15] = 1e-12

In [31]:
# KL divergence (weak vs. strong generalization) - standard model.
# kl_divergence is evaluated once per pair of entries and the terms are summed.
# NOTE(review): the printed labels say "PCT" although the inputs of this section
# were read from directories under st/ - confirm the labels.
kl_divergence_value_pct1 = np.sum([
    kl_divergence(p_i, q_i)
    for p_i, q_i in zip(
        wg_pt_mean_matrix_partial_flatten_pct1,
        sg_pt_mean_matrix_partial_flatten_pct1,
    )
])
print("KL Divergence Naive PCT:", kl_divergence_value_pct1)

# KL divergence (weak vs. strong generalization) - contrastive pre-trained model.
kl_divergence_value_pct2 = np.sum([
    kl_divergence(p_i, q_i)
    for p_i, q_i in zip(
        wg_pt_mean_matrix_partial_flatten_pct2,
        sg_pt_mean_matrix_partial_flatten_pct2,
    )
])
print("KL Divergence Contrastive PCT:", kl_divergence_value_pct2)

KL Divergence Naive PCT: 0.32651569000050235
KL Divergence Contrastive PCT: 2.169921924516986


In [138]:
# Cross-check: KL divergence computed with scipy.stats.entropy (base 2).
# `entropy` is not imported in this cell; it relies on the earlier
# `from scipy.stats import entropy` cell having been run.
# Note: kl_divergence_value_pct1 / kl_divergence_value_pct2 are overwritten here.

# Standard model: weak- vs. strong-generalization distribution.
kl_divergence_array_pct1 = entropy(
    wg_pt_mean_matrix_partial_flatten_pct1,
    qk=sg_pt_mean_matrix_partial_flatten_pct1,
    base=2,
)

# The inputs are 1-D, so taking the mean of the result is presumably a no-op.
kl_divergence_value_pct1 = np.mean(kl_divergence_array_pct1)
print("Overall KL Divergence of Naive:", kl_divergence_value_pct1)

# Contrastive pre-trained model: weak- vs. strong-generalization distribution.
kl_divergence_array_pct2 = entropy(
    wg_pt_mean_matrix_partial_flatten_pct2,
    qk=sg_pt_mean_matrix_partial_flatten_pct2,
    base=2,
)

kl_divergence_value_pct2 = np.mean(kl_divergence_array_pct2)
print("Overall KL Divergence of Contrastive:", kl_divergence_value_pct2)

Overall KL Divergence of Naive: 0.32651569000273467
Overall KL Divergence of Contrastive: 2.169921924490244
