# Results for Generalization using KL Divergence

In [1]:
import pandas as pd
import numpy as np
import os

# KL divergence via scipy.stats.entropy (approach suggested by ChatGPT)
from scipy.stats import entropy

In [2]:
def concat(path, output_files):
    """Read each CSV in ``output_files`` (located in ``path``) into a DataFrame.

    Despite the name, nothing is concatenated: the frames are returned
    individually in a dict keyed ``'df_1'``, ``'df_2'``, ... following the
    order of ``output_files``. Column labels read from each CSV header are
    cast to ``int`` (a non-numeric header label raises ``ValueError``).

    Parameters
    ----------
    path : str
        Directory containing the files; a trailing separator is optional.
    output_files : iterable of str
        File names relative to ``path``.

    Returns
    -------
    dict[str, pandas.DataFrame]
    """
    df_dict = {}
    for i, file in enumerate(output_files, start=1):
        # os.path.join works whether or not `path` ends with a separator;
        # plain `path + file` silently produced a wrong path without one.
        df = pd.read_csv(os.path.join(path, file))
        df.columns = df.columns.astype(int)
        df_dict[f'df_{i}'] = df

    return df_dict

In [3]:
# Calculate KL divergence
def kl_divergence(p, q, epsilon=1e-9):
    """Return the Kullback-Leibler divergence KL(p || q) in bits (log base 2).

    Computes ``sum(p * log2(p / q))`` element-wise over ``p`` and ``q``
    (scalars or arrays of matching/broadcastable shape). The inputs are NOT
    renormalized, so pass distributions that already sum to 1 if a true KL
    divergence is wanted.

    Parameters
    ----------
    p, q : float or array_like
        Non-negative weights.
    epsilon : float, default 1e-9
        Substituted for entries of ``q`` that are <= 0 so ``log2(p / q)``
        stays finite. Strictly positive ``q`` entries are used unchanged.

    Returns
    -------
    numpy.float64
        The summed divergence. Terms where ``p <= 0`` contribute 0
        (the ``0 * log 0 = 0`` convention) instead of NaN.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Previously `epsilon` was accepted but ignored, so q == 0 produced inf.
    q = np.where(q > 0, q, epsilon)
    # Replace non-positive p by 1 inside the log so no NaN/inf is computed;
    # those terms are masked to 0 below anyway.
    safe_p = np.where(p > 0, p, 1.0)
    terms = np.where(p > 0, p * np.log2(safe_p / q), 0.0)
    return np.sum(terms)

# ST (Set Transformer)

In [5]:
# Per-run probability matrices (output_file_*.csv) for each
# (model, generalization setting) directory. os.listdir returns entries in an
# arbitrary, filesystem-dependent order, so the names are sorted
# (lexicographically) to make the df_1..df_N numbering in `concat`
# reproducible; the later run average is order-independent. Each listing is
# asserted non-empty so a wrong path fails here with a clear message instead
# of surfacing later as NaN or a confusing shape error.

# Weak generalization - standard set transformer
wg_pt_path1 = './strong_generalization/multi_class_probability_files/st/ModelNet40/Weak_Generalization/64_16_128_3_0.05/'
wg_pt_output_files1 = sorted(file for file in os.listdir(wg_pt_path1) if file.startswith('output_file_') and file.endswith('.csv'))
assert wg_pt_output_files1, f"no output_file_*.csv found in {wg_pt_path1}"

# Strong generalization - standard set transformer
sg_pt_path1 = './strong_generalization/multi_class_probability_files/st/ModelNet40/Strong_Generalization/64_16_128_3_0.05_bs32_ss256_vss128/'
sg_pt_output_files1 = sorted(file for file in os.listdir(sg_pt_path1) if file.startswith('output_file_') and file.endswith('.csv'))
assert sg_pt_output_files1, f"no output_file_*.csv found in {sg_pt_path1}"

# Weak generalization - contrastive pre-trained model
wg_pt_path2 = './strong_generalization/multi_class_probability_files/st/ModelNet40/Weak_Generalization/masked_encoder_160_512_150_200_64_16_128_3_0.05_True_False_pd1024_ed64_lr0.001_sndl_ST/'
wg_pt_output_files2 = sorted(file for file in os.listdir(wg_pt_path2) if file.startswith('output_file_') and file.endswith('.csv'))
assert wg_pt_output_files2, f"no output_file_*.csv found in {wg_pt_path2}"

# Strong generalization - contrastive pre-trained model
sg_pt_path2 = './strong_generalization/multi_class_probability_files/st/ModelNet40/Strong_Generalization/masked_encoder_160_512_150_200_64_16_128_3_0.05_True_False_pd1024_ed64_lr0.001_sndl_ST_pd128_bs32_ss256_vss128/'
sg_pt_output_files2 = sorted(file for file in os.listdir(sg_pt_path2) if file.startswith('output_file_') and file.endswith('.csv'))
assert sg_pt_output_files2, f"no output_file_*.csv found in {sg_pt_path2}"

In [6]:
# Load every run's CSV into a dict of DataFrames ('df_1', 'df_2', ...),
# one dict per (model, generalization setting) directory.
wg_pt_df_dict1 = concat(path=wg_pt_path1, output_files=wg_pt_output_files1)
sg_pt_df_dict1 = concat(path=sg_pt_path1, output_files=sg_pt_output_files1)
wg_pt_df_dict2 = concat(path=wg_pt_path2, output_files=wg_pt_output_files2)
sg_pt_df_dict2 = concat(path=sg_pt_path2, output_files=sg_pt_output_files2)

In [7]:
# Stack the per-run matrices into one (n_runs, rows, cols) array.
# np.stack raises immediately when a directory yielded no runs (or the runs
# disagree in shape), whereas np.array on an empty list silently returned an
# empty array that only surfaced later as NaN or a confusing error.
wg_pt_data_matrix1 = np.stack([run_df.to_numpy() for run_df in wg_pt_df_dict1.values()])
sg_pt_data_matrix1 = np.stack([run_df.to_numpy() for run_df in sg_pt_df_dict1.values()])
wg_pt_data_matrix2 = np.stack([run_df.to_numpy() for run_df in wg_pt_df_dict2.values()])
sg_pt_data_matrix2 = np.stack([run_df.to_numpy() for run_df in sg_pt_df_dict2.values()])

In [8]:
# Sanity check: each stack is (n_runs, rows, cols).
print(np.shape(wg_pt_data_matrix1))
print(np.shape(sg_pt_data_matrix1))
print(np.shape(wg_pt_data_matrix2))
print(np.shape(sg_pt_data_matrix2))

(10, 40, 40)
(10, 40, 40)
(10, 40, 40)
(10, 40, 40)


In [9]:
# Average over the run axis (axis 0) so each setting is summarized by a
# single rows x cols matrix combining all individual runs.
wg_pt_combined_matrix1 = wg_pt_data_matrix1.mean(axis=0)
sg_pt_combined_matrix1 = sg_pt_data_matrix1.mean(axis=0)
wg_pt_combined_matrix2 = wg_pt_data_matrix2.mean(axis=0)
sg_pt_combined_matrix2 = sg_pt_data_matrix2.mean(axis=0)

In [10]:
# Overwrite the diagonal of the weak-generalization averages with a tiny
# positive value (in place; re-running this cell is harmless) so only the
# off-diagonal mass remains when the matrices are normalized below.
# NOTE(review): the strong-generalization matrices keep their diagonal
# (their sums printed below are ~40, versus ~9 here) — confirm this
# asymmetry between p and q is intended for the KL comparison.
np.fill_diagonal(wg_pt_combined_matrix1, val=1e-9)
np.fill_diagonal(wg_pt_combined_matrix2, val=1e-9)

In [11]:
# Grand total of each averaged matrix (column sums, then their sum); used as
# the normalizing constant in the next cell. Printed in the order
# weak/strong for model 1, then weak/strong for model 2.
wg_pt_matrix_sum1 = pd.DataFrame(wg_pt_combined_matrix1).sum(axis=0).sum()
sg_pt_matrix_sum1 = pd.DataFrame(sg_pt_combined_matrix1).sum(axis=0).sum()
wg_pt_matrix_sum2 = pd.DataFrame(wg_pt_combined_matrix2).sum(axis=0).sum()
sg_pt_matrix_sum2 = pd.DataFrame(sg_pt_combined_matrix2).sum(axis=0).sum()
print(wg_pt_matrix_sum1, sg_pt_matrix_sum1, wg_pt_matrix_sum2, sg_pt_matrix_sum2, sep="\n")


9.042474178867774
40.00000005002191
8.636569670430092
40.00000013114292


In [12]:
# Divide each matrix by its grand total so all entries sum to 1 and the
# matrix can be treated as a single probability distribution.
wg_pt_mean_matrix_partial1 = np.divide(wg_pt_combined_matrix1, wg_pt_matrix_sum1)
sg_pt_mean_matrix_partial1 = np.divide(sg_pt_combined_matrix1, sg_pt_matrix_sum1)

wg_pt_mean_matrix_partial2 = np.divide(wg_pt_combined_matrix2, wg_pt_matrix_sum2)
sg_pt_mean_matrix_partial2 = np.divide(sg_pt_combined_matrix2, sg_pt_matrix_sum2)

In [13]:
# Row-major 1-D copies for the KL computation. flatten() always copies, so
# the in-place flooring in the next cell leaves the 2-D matrices untouched.
wg_pt_mean_matrix_partial_flatten1 = wg_pt_mean_matrix_partial1.flatten(order='C')
sg_pt_mean_matrix_partial_flatten1 = sg_pt_mean_matrix_partial1.flatten(order='C')

wg_pt_mean_matrix_partial_flatten2 = wg_pt_mean_matrix_partial2.flatten(order='C')
sg_pt_mean_matrix_partial_flatten2 = sg_pt_mean_matrix_partial2.flatten(order='C')

In [14]:
# Floor every probability at 1e-12 so no entry is zero: a zero in q makes
# log2(p / q) infinite and a zero in p gives 0 * log2(0) = NaN.
# np.maximum is a monotone floor. The previous rule (values < 1e-15 -> 1e-12)
# left values in [1e-15, 1e-12) untouched while lifting smaller ones above them.
# out= updates the flattened arrays in place, as before (re-running is harmless).
np.maximum(wg_pt_mean_matrix_partial_flatten1, 1e-12, out=wg_pt_mean_matrix_partial_flatten1)
np.maximum(sg_pt_mean_matrix_partial_flatten1, 1e-12, out=sg_pt_mean_matrix_partial_flatten1)

np.maximum(wg_pt_mean_matrix_partial_flatten2, 1e-12, out=wg_pt_mean_matrix_partial_flatten2)
np.maximum(sg_pt_mean_matrix_partial_flatten2, 1e-12, out=sg_pt_mean_matrix_partial_flatten2)

In [15]:
# KL(weak || strong) in bits for each model, computed on the flattened,
# normalized matrices. kl_divergence works element-wise and then sums, so it
# is called once on the whole arrays instead of once per scalar entry (the
# old `p_row`/`q_row` were single entries, not rows). The value agrees with
# the per-entry version up to floating-point rounding.

# Standard Set Transformer
kl_divergence_value_naive = kl_divergence(wg_pt_mean_matrix_partial_flatten1, sg_pt_mean_matrix_partial_flatten1)
print("KL Divergence Naive:", kl_divergence_value_naive)

# Contrastive pre-trained model
kl_divergence_value2_cont = kl_divergence(wg_pt_mean_matrix_partial_flatten2, sg_pt_mean_matrix_partial_flatten2)
print("KL Divergence Contrastive:", kl_divergence_value2_cont)

KL Divergence Naive: 0.8018927660501509
KL Divergence Contrastive: 0.8093742643299559


In [16]:
# Cross-check with SciPy: entropy(p, q, base=2) is KL(p || q) in bits, and for
# 1-D inputs it already returns a single scalar. The former np.mean(...) was
# therefore a no-op and has been dropped; the printed values are unchanged.
# entropy() rescales p and q to sum to 1 first. The flooring step adds a
# little mass, which is likely why these values differ from the manual
# kl_divergence results in the last few digits.
kl_divergence_array1 = entropy(wg_pt_mean_matrix_partial_flatten1, sg_pt_mean_matrix_partial_flatten1, base=2)
kl_divergence_value1 = kl_divergence_array1
print("Overall KL Divergence of Naive:", kl_divergence_value1)

kl_divergence_array2 = entropy(wg_pt_mean_matrix_partial_flatten2, sg_pt_mean_matrix_partial_flatten2, base=2)
kl_divergence_value2 = kl_divergence_array2
print("Overall KL Divergence of Contrastive:", kl_divergence_value2)

Overall KL Divergence of Naive: 0.8018927640495719
Overall KL Divergence of Contrastive: 0.8093742622572059


# PCT

In [17]:
# Per-run probability matrices (output_file_*.csv) for each PCT
# (model, generalization setting) directory. The names are sorted because
# os.listdir order is arbitrary; this makes the df_1..df_N numbering in
# `concat` reproducible (the run average itself is order-independent). Each
# listing is asserted non-empty so a wrong path fails here, not downstream.

# Weak generalization - naive PCT baseline
wg_pt_path_pct1 = './strong_generalization/multi_class_probability_files/pct/ModelNet40/Weak_Generalization/ed256_0.2/'
wg_pt_output_files_pct1 = sorted(file for file in os.listdir(wg_pt_path_pct1) if file.startswith('output_file_') and file.endswith('.csv'))
assert wg_pt_output_files_pct1, f"no output_file_*.csv found in {wg_pt_path_pct1}"

# Strong generalization - naive PCT baseline
sg_pt_path_pct1 = './strong_generalization/multi_class_probability_files/pct/ModelNet40/Strong_Generalization/ed256_0.2_bs32_ss256_vss128/'
sg_pt_output_files_pct1 = sorted(file for file in os.listdir(sg_pt_path_pct1) if file.startswith('output_file_') and file.endswith('.csv'))
assert sg_pt_output_files_pct1, f"no output_file_*.csv found in {sg_pt_path_pct1}"

# Weak generalization - contrastive pre-trained PCT model
wg_pt_path_pct2 = './strong_generalization/multi_class_probability_files/pct/ModelNet40/Weak_Generalization/masked_encoder_160_512_150_200_pd1024_ed256_lr0.001_sndl_PCT_loss/'
wg_pt_output_files_pct2 = sorted(file for file in os.listdir(wg_pt_path_pct2) if file.startswith('output_file_') and file.endswith('.csv'))
assert wg_pt_output_files_pct2, f"no output_file_*.csv found in {wg_pt_path_pct2}"

# Strong generalization - contrastive pre-trained PCT model
sg_pt_path_pct2 = './strong_generalization/multi_class_probability_files/pct/ModelNet40/Strong_Generalization/masked_encoder_160_512_150_200_pd1024_ed256_lr0.001_sndl_PCT_loss_pd1024_bs32_ss256_vss128/'
sg_pt_output_files_pct2 = sorted(file for file in os.listdir(sg_pt_path_pct2) if file.startswith('output_file_') and file.endswith('.csv'))
assert sg_pt_output_files_pct2, f"no output_file_*.csv found in {sg_pt_path_pct2}"

In [18]:
# Load every PCT run's CSV into a dict of DataFrames ('df_1', 'df_2', ...),
# one dict per (model, generalization setting) directory.
wg_pt_df_dict_pct1 = concat(path=wg_pt_path_pct1, output_files=wg_pt_output_files_pct1)
sg_pt_df_dict_pct1 = concat(path=sg_pt_path_pct1, output_files=sg_pt_output_files_pct1)
wg_pt_df_dict_pct2 = concat(path=wg_pt_path_pct2, output_files=wg_pt_output_files_pct2)
sg_pt_df_dict_pct2 = concat(path=sg_pt_path_pct2, output_files=sg_pt_output_files_pct2)

In [19]:
# Stack the per-run PCT matrices into one (n_runs, rows, cols) array.
# np.stack raises immediately when a directory yielded no runs (or the runs
# disagree in shape), whereas np.array on an empty list silently returned an
# empty array that only surfaced later as NaN or a confusing error.
wg_pt_data_matrix_pct1 = np.stack([run_df.to_numpy() for run_df in wg_pt_df_dict_pct1.values()])
sg_pt_data_matrix_pct1 = np.stack([run_df.to_numpy() for run_df in sg_pt_df_dict_pct1.values()])
wg_pt_data_matrix_pct2 = np.stack([run_df.to_numpy() for run_df in wg_pt_df_dict_pct2.values()])
sg_pt_data_matrix_pct2 = np.stack([run_df.to_numpy() for run_df in sg_pt_df_dict_pct2.values()])

In [20]:
# Sanity check: each PCT stack is (n_runs, rows, cols).
print(np.shape(wg_pt_data_matrix_pct1))
print(np.shape(sg_pt_data_matrix_pct1))
print(np.shape(wg_pt_data_matrix_pct2))
print(np.shape(sg_pt_data_matrix_pct2))

(10, 40, 40)
(10, 40, 40)
(10, 40, 40)
(10, 40, 40)


In [21]:
# Average over the run axis (axis 0) so each PCT setting is summarized by a
# single rows x cols matrix combining all individual runs.
wg_pt_combined_matrix_pct1 = wg_pt_data_matrix_pct1.mean(axis=0)
sg_pt_combined_matrix_pct1 = sg_pt_data_matrix_pct1.mean(axis=0)
wg_pt_combined_matrix_pct2 = wg_pt_data_matrix_pct2.mean(axis=0)
sg_pt_combined_matrix_pct2 = sg_pt_data_matrix_pct2.mean(axis=0)

In [22]:
# Display the averaged matrix shape (rows, cols).
np.shape(wg_pt_combined_matrix_pct1)

(40, 40)

In [23]:
# Overwrite the diagonal of the weak-generalization PCT averages with a tiny
# positive value (in place; re-running this cell is harmless) so only the
# off-diagonal mass remains when the matrices are normalized below.
# NOTE(review): the strong-generalization matrices keep their diagonal
# (their sums printed below are ~40, versus ~7-8 here) — confirm this
# asymmetry between p and q is intended for the KL comparison.
np.fill_diagonal(wg_pt_combined_matrix_pct1, val=1e-9)
np.fill_diagonal(wg_pt_combined_matrix_pct2, val=1e-9)

In [24]:
# Grand total of each averaged PCT matrix (column sums, then their sum);
# used as the normalizing constant in the next cell. Printed in the order
# weak/strong for model 1, then weak/strong for model 2.
wg_pt_matrix_sum_pct1 = pd.DataFrame(wg_pt_combined_matrix_pct1).sum(axis=0).sum()
sg_pt_matrix_sum_pct1 = pd.DataFrame(sg_pt_combined_matrix_pct1).sum(axis=0).sum()
wg_pt_matrix_sum_pct2 = pd.DataFrame(wg_pt_combined_matrix_pct2).sum(axis=0).sum()
sg_pt_matrix_sum_pct2 = pd.DataFrame(sg_pt_combined_matrix_pct2).sum(axis=0).sum()
print(wg_pt_matrix_sum_pct1, sg_pt_matrix_sum_pct1, wg_pt_matrix_sum_pct2, sg_pt_matrix_sum_pct2, sep="\n")


7.4093553364625535
40.00000001229662
8.322018361739964
39.99999999075611


In [25]:
# Divide each PCT matrix by its grand total so all entries sum to 1 and the
# matrix can be treated as a single probability distribution.
wg_pt_mean_matrix_partial_pct1 = np.divide(wg_pt_combined_matrix_pct1, wg_pt_matrix_sum_pct1)
sg_pt_mean_matrix_partial_pct1 = np.divide(sg_pt_combined_matrix_pct1, sg_pt_matrix_sum_pct1)

wg_pt_mean_matrix_partial_pct2 = np.divide(wg_pt_combined_matrix_pct2, wg_pt_matrix_sum_pct2)
sg_pt_mean_matrix_partial_pct2 = np.divide(sg_pt_combined_matrix_pct2, sg_pt_matrix_sum_pct2)

In [26]:
# Row-major 1-D copies for the KL computation. flatten() always copies, so
# the in-place flooring in the next cell leaves the 2-D matrices untouched.
wg_pt_mean_matrix_partial_flatten_pct1 = wg_pt_mean_matrix_partial_pct1.flatten(order='C')
sg_pt_mean_matrix_partial_flatten_pct1 = sg_pt_mean_matrix_partial_pct1.flatten(order='C')

wg_pt_mean_matrix_partial_flatten_pct2 = wg_pt_mean_matrix_partial_pct2.flatten(order='C')
sg_pt_mean_matrix_partial_flatten_pct2 = sg_pt_mean_matrix_partial_pct2.flatten(order='C')

In [27]:
# Floor every probability at 1e-12 so no entry is zero: a zero in q makes
# log2(p / q) infinite and a zero in p gives 0 * log2(0) = NaN.
# np.maximum is a monotone floor. The previous rule (values < 1e-15 -> 1e-12)
# left values in [1e-15, 1e-12) untouched while lifting smaller ones above them.
# out= updates the flattened arrays in place, as before (re-running is harmless).
np.maximum(wg_pt_mean_matrix_partial_flatten_pct1, 1e-12, out=wg_pt_mean_matrix_partial_flatten_pct1)
np.maximum(sg_pt_mean_matrix_partial_flatten_pct1, 1e-12, out=sg_pt_mean_matrix_partial_flatten_pct1)

np.maximum(wg_pt_mean_matrix_partial_flatten_pct2, 1e-12, out=wg_pt_mean_matrix_partial_flatten_pct2)
np.maximum(sg_pt_mean_matrix_partial_flatten_pct2, 1e-12, out=sg_pt_mean_matrix_partial_flatten_pct2)

In [28]:
# KL(weak || strong) in bits for each PCT model, computed on the flattened,
# normalized matrices. kl_divergence works element-wise and then sums, so it
# is called once on the whole arrays instead of once per scalar entry (the
# old `p_row`/`q_row` were single entries, not rows). The value agrees with
# the per-entry version up to floating-point rounding.

# Naive PCT baseline
kl_divergence_value_pct1 = kl_divergence(wg_pt_mean_matrix_partial_flatten_pct1, sg_pt_mean_matrix_partial_flatten_pct1)
print("KL Divergence Naive PCT:", kl_divergence_value_pct1)

# Contrastive pre-trained PCT model
kl_divergence_value_pct2 = kl_divergence(wg_pt_mean_matrix_partial_flatten_pct2, sg_pt_mean_matrix_partial_flatten_pct2)
print("KL Divergence Contrastive PCT:", kl_divergence_value_pct2)

KL Divergence Naive PCT: 1.477473745756666
KL Divergence Contrastive PCT: 1.283625666232929


In [29]:
# Cross-check with SciPy: entropy(p, q, base=2) is KL(p || q) in bits, and for
# 1-D inputs it already returns a single scalar. The former np.mean(...) was
# therefore a no-op and has been dropped; the printed values are unchanged.
# entropy() rescales p and q to sum to 1 first. The flooring step adds a
# little mass, which is likely why these values differ from the manual
# kl_divergence results in the last few digits.
kl_divergence_array_pct1 = entropy(wg_pt_mean_matrix_partial_flatten_pct1, sg_pt_mean_matrix_partial_flatten_pct1, base=2)
kl_divergence_value_pct1 = kl_divergence_array_pct1
print("Overall KL Divergence of Naive:", kl_divergence_value_pct1)

kl_divergence_array_pct2 = entropy(wg_pt_mean_matrix_partial_flatten_pct2, sg_pt_mean_matrix_partial_flatten_pct2, base=2)
kl_divergence_value_pct2 = kl_divergence_array_pct2
print("Overall KL Divergence of Contrastive:", kl_divergence_value_pct2)

Overall KL Divergence of Naive: 1.4774737427043938
Overall KL Divergence of Contrastive: 1.2836256635970325
