In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

In [None]:
# Hidden-layer widths whose saved activation files are loaded below.
layer_sizes = list(range(4, 17)) + [32, 64, 128]

# Identifiers that appear in the activation file names.
seed = 33
ntest = 2000

# Load one activation file per hidden-layer width, mapping every tensor onto the CPU.
activations = []
for n_hidden in layer_sizes:
    activation_path = f"trained_activations/mlp_xor_activations_seed{seed}_nhidden{n_hidden}_ntest{ntest}.pt"
    activations.append(torch.load(activation_path, map_location=torch.device('cpu')))

In [None]:
# Center each stored "relu" activation matrix by subtracting its mean over dim 0.
internal_reps = []
for activation in activations:
    hidden = activation["relu"]
    internal_reps.append(hidden - hidden.mean(dim=0))


## Geometric similarity (Centered Kernel Alignment)

In [None]:
# Linear kernel (Gram) matrix of every centered representation, scaled by its
# number of columns. The scale cancels in the normalized similarity below.
Ks = [(1 / rep.shape[1]) * rep @ rep.T for rep in internal_reps]

# Frobenius norm of each kernel, computed once instead of once per pair.
# Stored as a NumPy array so a zero norm yields NaN (handled below) rather
# than raising ZeroDivisionError.
K_norms = np.array([torch.linalg.norm(K).item() for K in Ks])

# Angular distance derived from the linear-kernel CKA similarity:
#   d(i, j) = arccos( <Ki, Kj>_F / (||Ki||_F * ||Kj||_F) )
n_models = len(Ks)
dist_mat = np.zeros((n_models, n_models))

for i in range(n_models):
    # Only the upper triangle is computed and then mirrored, so the matrix is
    # exactly symmetric, as a precomputed dissimilarity matrix should be.
    # The diagonal stays 0: the angle between a kernel and itself.
    for j in range(i + 1, n_models):
        # trace(Ki @ Kj) equals the elementwise sum of Ki * Kj because Kj is
        # symmetric; this avoids an O(n^3) matrix product per pair.
        inner = torch.sum(Ks[i] * Ks[j]).item()
        similarity = inner / (K_norms[i] * K_norms[j])
        # Rounding can push the ratio marginally outside [-1, 1]; clip so that
        # arccos does not return NaN for (near-)identical representations.
        dist_mat[i, j] = dist_mat[j, i] = np.arccos(np.clip(similarity, -1.0, 1.0))
    print(str(i))

# Any remaining NaN (e.g. from an all-zero representation) is mapped to 0.
dist_mat = np.nan_to_num(dist_mat)

                                                                              

In [None]:
# Heatmap of the pairwise distance matrix; both axes are labelled with the
# hidden-layer widths in layer_sizes.
tick_positions = np.arange(len(layer_sizes))

heatmap = plt.imshow(dist_mat)
plt.colorbar(heatmap)
# Optional fixed color range: plt.clim(0, np.pi/2)
for set_ticks, set_label in ((plt.xticks, plt.xlabel), (plt.yticks, plt.ylabel)):
    set_ticks(ticks=tick_positions, labels=layer_sizes)
    set_label("Number of hidden units")
plt.title("Representation distance matrix (CKA linear kernel)")
# Optional export: plt.savefig("mlp_xor_rep_distance_matrix_linear_kernel.png")

In [None]:
# MDS-PCA embedding for visualization

from sklearn.manifold import MDS
from sklearn.decomposition import PCA

# Metric MDS on the precomputed distance matrix. The optimizer starts from a
# random configuration, so random_state is fixed (reusing the notebook's
# `seed`) to make the embedding, and every plot derived from it, reproducible.
embedding = MDS(
    n_components=100,
    metric=True,
    eps=0.00001,
    normalized_stress='auto',
    dissimilarity='precomputed',
    random_state=seed,
)
Z = embedding.fit_transform(np.abs(np.real(dist_mat)))
print("Embedding stress: ", embedding.stress_)

# Reduce the MDS coordinates to three principal components for plotting.
pca = PCA(n_components=3)
pcs1 = pca.fit_transform(Z)
print("Explained variance ratio sum: ", np.sum(pca.explained_variance_ratio_))

In [None]:
# Interactive 3D scatter (Plotly) of the three PCA coordinates in pcs1,
# with one text-labelled marker per entry of layer_sizes.

import plotly
import plotly.graph_objs as go

# Render Plotly figures inline in the notebook.
plotly.offline.init_notebook_mode()

# One trace: columns 0-2 of pcs1 as x / y / z, labelled with the layer size.
point_labels = list(map(str, layer_sizes))
pc_x, pc_y, pc_z = pcs1[:, 0], pcs1[:, 1], pcs1[:, 2]
trace1 = go.Scatter3d(
    x=pc_x,
    y=pc_y,
    z=pc_z,
    mode='markers+text',
    marker=dict(size=10, opacity=0.8),
    text=point_labels,
    textposition='top center',
)

# Zero margins on all four sides.
layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))

# Assemble the figure and display it.
data = [trace1]
plot_figure = go.Figure(data=data, layout=layout)
plotly.offline.iplot(plot_figure)

## Decoding-based comparison

In [None]:
# Attempt to decode the irrelevant bit from the hidden representations

import sys
sys.path.append('..')
import dsutils
from sklearn.preprocessing import StandardScaler
np.random.seed(41)  # fixes which rows are held out below

eval_on_test_set = True
test_set_size = 400

# Never request more held-out rows than the data contains.
max_samples = internal_reps[0].shape[0]
test_set_size = min(test_set_size, max_samples)
test_indices = np.random.choice(np.arange(max_samples), test_set_size, replace=False)

# Hyperparameter grid passed to dsutils.cross_val_score_custom; only 'b' is searched.
# NOTE(review): the meaning of a / b / gamma is defined in dsutils (not visible here).
param_grid = {
    'a': [1],
    # 'a' : list(np.logspace(-5, 4, num=10, endpoint=True, base=10.0, dtype=None, axis=0)),
    'b' : list(np.logspace(-6, 6, num=30, endpoint=True, base=10.0, dtype=None, axis=0)),
    'gamma': [1]
    # 'gamma' : list(np.logspace(-6, 6, num=15, endpoint=True, base=10.0, dtype=None, axis=0))
}

kernel = 'linear'

# Per-model results, appended in the order of internal_reps / layer_sizes.
resultDict = {
    "Z_pred": [],
    "R2": [],
    "best_params": [],
    "Z": [],
    "layer_names": [str(size) for size in layer_sizes],
    "kernel": kernel,
}
if kernel == 'linear':
    resultDict["reg_weights"] = []

for i in range(len(internal_reps)):

    # Decoding target: column 2 of the stored test inputs, mean-centered.
    # (Alternative target: activations[i]["test_set"]["y_test"], centered the same way.)
    target_raw = activations[i]["test_set"]["X_test"][:, 2].cpu().numpy()
    Zfull = target_raw - np.mean(target_raw)

    # Convert the representation to NumPy once so that indexing, np.delete and
    # the scaler all operate on the same array type.
    X_full = internal_reps[i].detach().cpu().numpy()

    # Split data into train and test
    if eval_on_test_set:
        # Hold out the rows in test_indices; fit on the remaining rows.
        X_test = X_full[test_indices]
        Z_test = Zfull[test_indices]
        X = np.delete(X_full, test_indices, axis=0)
        Z = np.delete(Zfull, test_indices, axis=0)
    else:
        # No hold-out: fit and evaluate on every row.
        X = X_full
        Z = Zfull
        X_test = X_full
        Z_test = Zfull

    # Fit the scaler on training data only, then apply the same statistics to the test data.
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    X_test = scaler.transform(X_test)

    # 5-fold cross-validated search over param_grid. Only the first two returned
    # values are used in this cell.
    best_params, best_score, X_train, Z_train, all_params_coefs = dsutils.cross_val_score_custom(
        dsutils.genKernelRegression, X, Z, param_grid, loss_fn=dsutils.mse_loss, cv=5, kernel=kernel)

    print("Best Params:", best_params)
    print("Best CV Loss:", best_score)

    # Refit a probe on the full training split with the selected hyperparameters.
    probe = dsutils.genKernelRegression(center_columns=True, kernel=kernel, a=best_params['a'],
                                      b=best_params['b'], gamma=best_params['gamma'], fit_intercept=False)
    probe.fit(X, Z)

    # Predict and evaluate
    Z_pred = probe.predict(X_test)
    R2 = probe.score(X_test, Z_test)
    print("R^2 score:", R2)

    resultDict["Z_pred"].append(Z_pred)
    resultDict["R2"].append(R2)
    resultDict["best_params"].append(best_params)
    resultDict["Z"].append(Zfull)
    if kernel == 'linear':
        # NOTE(review): presumably maps the probe's coefficients back to
        # feature-space weights; confirm against dsutils.genKernelRegression.
        reg_weights = probe.Xtrain.T @ probe.coef_
        resultDict["reg_weights"].append(reg_weights)

    # 1-based progress counter, so the final line reads N/N.
    print(str(i + 1) + '/' + str(len(internal_reps)))




In [None]:
# Imported once here rather than on every loop iteration.
from sklearn.metrics import accuracy_score, confusion_matrix

# Classification accuracy of each probe, in the order of internal_reps.
pred_accuracy = []

for i in range(len(internal_reps)):

    # Ground-truth bit: column 2 of the stored test inputs, thresholded at 0.5.
    bit_value = (activations[i]["test_set"]["X_test"][:,2] > 0.5).float()
    # The stored predictions cover only the held-out rows when eval_on_test_set
    # is True and every row otherwise; select the matching ground-truth rows.
    true_bits = bit_value[test_indices] if eval_on_test_set else bit_value
    test_bits = true_bits.cpu().numpy()
    # The regression target was mean-centered, so 0 is the decision threshold.
    predicted_bit = (resultDict["Z_pred"][i] > 0.0).astype(float)

    accuracy = accuracy_score(test_bits, predicted_bit)
    conf_matrix = confusion_matrix(test_bits, predicted_bit)
    pred_accuracy.append(accuracy)
    print("Accuracy:", accuracy)
    print("Confusion Matrix:\n", conf_matrix)

In [None]:
# Plot accuracy vs hidden layer size.
# The x values are taken from layer_sizes (the list used to load the models) so
# they cannot drift out of sync with pred_accuracy.
hidden_layer_sizes = list(layer_sizes)
plt.figure(figsize=(8,6))
plt.plot(hidden_layer_sizes, pred_accuracy, marker='o')
plt.xscale('log', base=2)
plt.xlabel('Hidden Layer Size')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Hidden Layer Size for XOR MLPs')
plt.grid(True)
plt.show()

In [None]:
# Pairwise distance between the probes' predictions, normalized by the norm of
# the (centered) decoding target of the first model.
plot_list = resultDict["Z_pred"]

# The normalizer does not depend on (i, j), so compute it once. It is taken
# over the same rows the predictions cover: the held-out rows when
# eval_on_test_set is True, every row otherwise.
target_ref = resultDict["Z"][0]
if eval_on_test_set:
    target_ref = target_ref[test_indices]
target_norm = np.linalg.norm(target_ref)

n_models = len(plot_list)
dist_mat = np.zeros((n_models, n_models))

for i in range(n_models):
    # The distance is symmetric and zero on the diagonal, so only the upper
    # triangle is computed and then mirrored.
    for j in range(i + 1, n_models):
        dist_mat[i, j] = dist_mat[j, i] = np.linalg.norm(plot_list[i] - plot_list[j]) / target_norm

# Replace any NaN / inf (e.g. from a zero normalizer) with finite numbers.
dist_mat = np.nan_to_num(dist_mat)

                                                                              

In [None]:
# Heatmap of the decoding-based distance matrix with the color range fixed to
# [0, 1]; both axes are labelled with the hidden-layer widths in layer_sizes.
tick_positions = np.arange(len(layer_sizes))

heatmap = plt.imshow(dist_mat)
plt.colorbar(heatmap)
plt.clim(0, 1)
for set_ticks, set_label in ((plt.xticks, plt.xlabel), (plt.yticks, plt.ylabel)):
    set_ticks(ticks=tick_positions, labels=layer_sizes)
    set_label("Number of hidden units")
plt.title("Representation distance matrix (linear kernel decoding distance)")
# Optional export: plt.savefig("mlp_xor_rep_distance_matrix_linear_kernel.png")

In [None]:
# MDS-PCA Embedding

from sklearn.manifold import MDS
from sklearn.decomposition import PCA

# Metric MDS on the precomputed distance matrix. The optimizer starts from a
# random configuration, so random_state is fixed (reusing the notebook's
# `seed`) to make the embedding, and every plot derived from it, reproducible.
embedding = MDS(
    n_components=100,
    metric=True,
    eps=0.00001,
    normalized_stress='auto',
    dissimilarity='precomputed',
    random_state=seed,
)
Z = embedding.fit_transform(np.abs(np.real(dist_mat)))
print("Embedding stress:", embedding.stress_)

# Reduce the MDS coordinates to three principal components for plotting.
pca = PCA(n_components=3)
pcs1 = pca.fit_transform(Z)
print("Explained variance ratio sum:", np.sum(pca.explained_variance_ratio_))

In [None]:
# Interactive 3D scatter (Plotly) of the three PCA coordinates in pcs1,
# with one text-labelled marker per entry of layer_sizes.

import plotly
import plotly.graph_objs as go

# Configure Plotly to be rendered inline in the notebook.
plotly.offline.init_notebook_mode()

# Configure the trace. Labels are derived from layer_sizes rather than typed
# out by hand, so they always match the models that were actually loaded.
trace1 = go.Scatter3d(
    x=pcs1[:,0],
    y=pcs1[:,1],
    z=pcs1[:,2],
    mode='markers+text',
    marker={
        'size': 10,
        'opacity': 0.8
    }, text=[str(size) for size in layer_sizes], textposition='top center'
)

# Configure the layout.
layout = go.Layout(
    margin={'l': 0, 'r': 0, 'b': 0, 't': 0}
)

data = [trace1]

# Create the figure.
plot_figure = go.Figure(data=data, layout=layout)

# Render the plot.
plotly.offline.iplot(plot_figure)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from matplotlib.animation import FuncAnimation

# Best hyperparameters per model. They are only needed for the alternative
# colorings mentioned below.
gammas = [params["gamma"] for params in resultDict["best_params"]]
bs = [params["b"] for params in resultDict["best_params"]]

# 2D scatter of the first two embedding dimensions, drawn on one explicit Axes.
fig, ax = plt.subplots(figsize=(9, 9))

x = pcs1[:,0]
y = pcs1[:,1]

# Annotate every point with its model label.
layer_names = resultDict["layer_names"]
for i, txt in enumerate(layer_names):
    ax.text(x[i], y[i], txt, fontsize=12, alpha=0.7)

# Color each point by its probe's R^2 score.
# Alternatives: np.array(np.log(gammas)) or np.array(np.log(bs)).
colors = np.array(resultDict["R2"])

scatter = ax.scatter(x, y, c=colors, cmap='viridis', s=90)
# The label previously contained a mis-encoded superscript two ('RÂ²').
fig.colorbar(scatter, ax=ax, label='R² Score')
ax.set_xlabel('MDS-PCA Dimension 1')
ax.set_ylabel('MDS-PCA Dimension 2')
ax.set_title('2D MDS-PCA Embedding of distance matrix')